From d01a22f33eac1fe8e0d0341db65f86f17c37e4f9 Mon Sep 17 00:00:00 2001 From: Alain Date: Sat, 16 May 2026 01:06:10 +0800 Subject: [PATCH 01/52] Implement feature X to enhance user experience and fix bug Y in module Z --- docs/design.md | 1121 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 1121 insertions(+) create mode 100644 docs/design.md diff --git a/docs/design.md b/docs/design.md new file mode 100644 index 0000000..9309afa --- /dev/null +++ b/docs/design.md @@ -0,0 +1,1121 @@ +# AIT 功能设计与架构文档 + +> 版本:v2.0 设计草案 +> 日期:2026-05-16 + +--- + +## 目录 + +1. [概述](#1-概述) +2. [现有架构分析](#2-现有架构分析) +3. [新架构设计](#3-新架构设计) +4. [交互式 TUI 设计](#4-交互式-tui-设计) +5. [Turbo 模式设计](#5-turbo-模式设计) +6. [新增数据结构与接口](#6-新增数据结构与接口) +7. [开发计划](#7-开发计划) + +--- + +## 1. 概述 + +本文档描述 AIT 工具 v2.0 的三个核心功能迭代: + +### 1.1 交互式 TUI + +将现有"执行完即退出"的单次命令行模式升级为**全屏交互式终端界面**,并以任务管理作为主入口: + +- 首页展示任务列表,可直接选择已有任务再次执行 +- 无需记忆所有参数,通过向导页创建或编辑任务 +- 任务详情页集中展示配置摘要、最近结果与运行记录 +- 测试运行时实时展示指标面板(成功率、TPS、TTFT、缓存命中率、并发状态) +- 请求日志滚动查看 +- 结果页支持键盘操作(生成报告、返回任务详情、再次运行等) +- 支持通过 CLI 参数生成临时任务草稿并进入 TUI 继续操作 + +### 1.2 任务管理 + +新增“任务”作为一等对象,用于保存和复用测试配置: + +- 每个任务只绑定一个模型,保存协议、完整接口地址、模型、Prompt、标准模式或 Turbo 参数 +- 协议值细化为 `openai-completions`、`openai-responses`、`anthropic-messages` +- 首页以任务列表形式呈现,支持新建、编辑、删除、复制、搜索和直接运行 +- 任务详情页展示最近一次运行摘要和最近运行记录 +- 多模型回归通过多个任务组织,而不是在单个任务内批量执行 +- 每次运行都会沉淀为任务记录,便于后续直接选择再次测试或对比回归结果 + +### 1.3 Turbo 模式 + +一种新的测试模式,用于**探测服务的最大稳定承载能力**: + +- 从初始并发数出发,按步进值逐级提升并发 +- 每个并发级别执行固定数量请求,采集该级别的完整指标 +- 当成功率或延迟超过阈值时判定服务出现降级,自动停止 +- 输出"并发爬坡曲线"报告,直观展示吞吐量、延迟与缓存命中率随并发变化的趋势 + +--- + +## 2. 现有架构分析 + +### 2.1 模块职责 + +``` +cmd/ait/ait.go ← 入口:flag 解析、参数校验、编排多模型执行 +internal/ + display/display.go ← 输出层:欢迎信息、进度条、结果表格 + runner/runner.go ← 执行层:并发请求、指标收集、进度回调 + client/ ← 协议层:OpenAI / Anthropic HTTP 客户端 + types/types.go ← 共享类型:Input、StatsData、ReportData + report/ ← 报告层:JSON / CSV 渲染 + prompt/ ← 输入层:字符串/文件/长度生成 + config/ ← 空(待实现) + network/ ← IP 工具 + upload/ ← 匿名数据上传 + logger/ ← 请求日志记录 +``` + +### 2.2 当前执行链 + +``` +main() + ├─ flag 解析 + 参数校验 + ├─ display.ShowWelcome() + ├─ display.ShowInput() + ├─ for each model: + │ runner.RunWithProgress(progressCallback) + │ └─ progressCallback → display.UpdateProgress() + ├─ display.FinishProgress() + ├─ display.ShowErrorsReport() + ├─ display.ShowSignalReport() 或 ShowMultiReport() + └─ report.GenerateReports() +``` + +### 2.3 现有问题 + +| 问题 | 影响 | +|------|------| +| `progressCallback` 仅能更新一个全局进度条,无法展示实时指标 | 测试中只能看到"完成了多少个",看不到当前 TPS/TTFT 等关键指标 | +| 多模型串行执行,前一个模型结束后进度条重置 | 体验割裂 | +| 没有中途中断的能力 | 一旦启动只能 Ctrl+C 强杀 | +| 并发数固定,无法动态探测极限 | 需要手动二分多次执行 | +| 协议抽象过粗,只区分 `openai` / `anthropic` | 无法准确对应 OpenAI Completions、OpenAI Responses、Anthropic Messages 的差异 | +| 测试配置无法保存为任务 | 每次测试都要重复输入参数,难以形成稳定回归基线 | +| 结果展示完即退出,无法二次查阅 | 对比分析需要回滚终端 | + +--- + +## 3. 新架构设计 + +### 3.1 整体模块图 + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ cmd/ait/main.go ─ 入口 & 模式路由 │ +│ │ +│ ┌─ 无必填参数 ──→ TUI 任务中心(任务列表) │ +│ └─ 有完整参数 ──→ TUI 任务详情(生成临时任务草稿) │ +└───────────────┬─────────────────────────────────────────────────┘ + │ + ┌───────────▼───────────────────────────────────────────────┐ + │ internal/tui/ (NEW) │ + │ │ + │ model.go ─ BubbleTea 根模型 + 状态机 │ + │ messages.go ─ 所有 Msg 类型定义 │ + │ styles.go ─ lipgloss 样式常量 │ + │ │ + │ tasklist/ ─ 任务列表页 │ + │ taskdetail/ ─ 任务详情页 │ + │ wizard/ ─ 新建 / 编辑任务向导页 │ + │ dashboard/ ─ 运行中仪表盘页 │ + │ result/ ─ 结果展示页 │ + │ turbo/ ─ Turbo 仪表盘页 │ + └───────────────────────────────────────────────────────────┘ + │ program.Send(msg) + │ ↑ 从任意 goroutine 安全推送 + ┌───────────▼──────────────────────────────┐ + │ internal/runner/runner.go(已有,扩展) │ + │ RunWithCallback(cb RequestDoneCallback) │ + └──────────────────────────────────────────┘ + │ + ┌───────────▼──────────────────────────────┐ + │ internal/turbo/ (NEW) │ + │ Runner ─ 并发爬坡调度器 │ + │ Strategy ─ 步进 & 终止策略 │ + └──────────────────────────────────────────┘ + │ + ┌───────────▼──────────────────────────────┐ + │ internal/task/ (NEW) │ + │ Store ─ 任务 CRUD / 搜索 / 排序 │ + │ History ─ 任务运行记录与最近结果摘要 │ + └──────────────────────────────────────────┘ + │ + ┌───────────▼──────────────────────────────┐ + │ internal/client/ (已有,不变) │ + │ OpenAI Completions / Responses / │ + │ Anthropic Messages HTTP 客户端 │ + └──────────────────────────────────────────┘ + + ┌─────────────────────────────────────────────────────────┐ + │ internal/config/ (NEW) │ + │ 从 ~/.ait/config.json 加载 / 保存全局偏好 │ + └─────────────────────────────────────────────────────────┘ + + ┌─────────────────────────────────────────────────────────┐ + │ internal/report/ (已有,扩展) │ + │ 新增 turbo_renderer.go ─ Turbo 爬坡报告 │ + └─────────────────────────────────────────────────────────┘ +``` + +### 3.2 目录结构变化 + +```diff + cmd/ + ait/ +- ait.go ← 原来全部逻辑 ++ main.go ← 入口:模式检测 + 任务中心路由 ++ flags.go ← 所有 flag 定义 ++ run_tui.go ← TUI 模式启动 / 临时任务草稿注入 + + internal/ ++ tui/ ++ model.go ← 根 BubbleTea Model + 状态机 ++ messages.go ← 所有 Msg 类型 ++ styles.go ← lipgloss 样式常量 ++ tasklist/ ++ model.go ← 任务列表状态 ++ view.go ← 任务列表 UI 渲染 ++ taskdetail/ ++ model.go ← 任务详情状态 ++ view.go ← 任务详情 UI 渲染 ++ wizard/ ++ model.go ← 新建 / 编辑任务向导状态 ++ view.go ← 向导 UI 渲染 ++ dashboard/ ++ model.go ← 运行仪表盘状态 ++ view.go ← 仪表盘 UI 渲染 ++ result/ ++ model.go ← 结果页状态 ++ view.go ← 结果页 UI 渲染 ++ turbo/ ++ model.go ← Turbo 仪表盘状态 ++ view.go ← Turbo 仪表盘 UI 渲染 + ++ turbo/ ++ runner.go ← 并发爬坡调度器 ++ strategy.go ← 步进 & 终止策略 ++ result.go ← TurboResult、LevelResult 类型 + ++ config/ ++ config.go ← ~/.ait/config.json 全局配置读写 + ++ task/ ++ store.go ← ~/.ait/tasks.json 读写 + CRUD ++ history.go ← ~/.ait/history/.json 读写 + + runner/ + runner.go ← 扩展:增加 Stop() 方法 + 每请求完成回调 + + types/ + types.go ← 扩展:增加 TurboConfig、TurboResult + + report/ + report.go ← 已有 ++ turbo_renderer.go ← Turbo 爬坡报告渲染 +``` + +### 3.3 关键设计原则 + +**原则 1:Runner 不感知 TUI** + +Runner 通过回调函数推送进度,不直接依赖 BubbleTea。TUI 模式下由外层把 `program.Send(msg)` 包进回调: + +```go +// runner 接口不变,增加每请求完成的细粒度回调 +type RequestDoneCallback func(metrics *client.ResponseMetrics, index int, err error) + +// TUI 模式下的回调实现 +cb := func(m *client.ResponseMetrics, idx int, err error) { + program.Send(tui.RequestDoneMsg{Metrics: m, Index: idx, Err: err}) +} +runner.RunWithCallback(cb) +``` + +**原则 2:TUI 是纯状态机** + +`tui.Model` 通过消息驱动状态转换,不直接调用任何 I/O 函数,所有副作用都封装在 `tea.Cmd` 中,方便测试。 + +**原则 3:任务是一等领域对象** + +Runner 消费的是一次运行所需的 `Input`,但 UI 和持久化围绕 `TaskDefinition` 展开。列表、详情、编辑、重跑、历史记录都基于任务对象组织,而不是把一次性的 flag 输入直接暴露给用户。 + +**原则 4:一个任务只测一个模型** + +任务是最小回归单元。`TaskDefinition` 只保存一个 `Model`,这样任务详情、运行记录、Turbo 极限和结果对比都能稳定映射到单一模型;若要覆盖多个模型,应创建多个任务分别执行。 + +### 3.4 任务生命周期 + +1. 用户从任务列表进入“新建任务”向导,保存后写入 `~/.ait/tasks.json` +2. 用户在任务详情页查看配置摘要、最近结果和最近运行记录 +3. 用户从任务详情页启动标准模式或 Turbo 模式测试 +4. 测试完成后写入 `~/.ait/history/.json`,并回写任务的最近运行摘要 +5. 用户后续可直接在任务列表或任务详情页再次运行,无需重新输入参数 + +--- + +## 4. 交互式 TUI 设计 + +### 4.1 技术选型 + +| 库 | 用途 | +|----|------| +| `charm.land/bubbletea/v2` | 主框架:消息驱动状态机,从任意 goroutine 安全推送消息 | +| `github.com/charmbracelet/bubbles` | 预制组件:`textinput`、`list`、`spinner`、`viewport`、`progress`、`table` | +| `github.com/charmbracelet/lipgloss` | 样式 & 布局:边框、颜色、多栏弹性布局 | +| `github.com/NimbleMarkets/ntcharts` | Turbo 爬坡折线图 | + +### 4.2 TUI 状态机 + +``` +启动无参数 ───────────────────────────────→ TaskList +启动带完整参数 ─→ 生成临时任务草稿 ───────→ TaskDetail + + ┌─────────────┐ + │ TaskList │ + │ 任务列表页 │ + └──────┬──────┘ + [a 新建] │ │ [Enter] + │ │ + ▼ ▼ + ┌─────────────┐ + │ Wizard │ + │ 新建/编辑任务 │ + └──────┬──────┘ + [保存任务] │ + ▼ + ┌─────────────┐ + │ TaskDetail │ + │ 任务详情页 │ + └──────┬──────┘ + [Enter / r] │ │ [e 编辑] + │ └──────────────┐ + │ │ + [标准模式] ▼ │ + ┌─────────────┐ │ + │ Running │ │ + │ 标准运行中 │ │ + └──────┬──────┘ │ + [完成/s] │ │ + ▼ │ + ┌─────────────┐ │ + │ Completed │ │ + │ 标准结果页 │ │ + └──────┬──────┘ │ + [b 返回详情] │ │ + └────────────────┘ + + [Turbo 模式] ▼ + ┌─────────────┐ + │TurboRunning │ + │ Turbo 爬坡中 │ + └──────┬──────┘ + [完成/s] │ + ▼ + ┌─────────────┐ + │TurboCompleted│ + │ Turbo 结果页 │ + └──────┬──────┘ + [b 返回详情] │ + └──────────────→ TaskDetail +``` + +### 4.3 页面设计 + +--- + +#### 页面 1:任务列表首页 + +``` +╔══ AIT 任务中心 ─────────────────────────────────────────────══╗ +║ 已保存任务: 12 最近运行: 2026-05-16 09:42 [/] 搜索任务 ║ +╠══════════════════════╦═══════════════════════════════════════════╣ +║ 任务列表 ║ 快捷操作 / 最近摘要 ║ +║ ║ ║ +║ ▶ nightly-openai ║ [a] 新建任务 ║ +║ 标准模式 · gpt-4o ║ [Enter] 查看详情 ║ +║ 并发 10 / 请求 200 ║ [r] 直接运行选中任务 ║ +║ 上次: 98.5% · 12m 前 ║ [e] 编辑 [d] 删除 [y] 复制任务 ║ +║ ║ ║ +║ turbo-anthropic ║ 最近执行 ║ +║ Turbo · claude-3-7 ║ nightly-openai ✓ 98.5% 245ms ║ +║ 1→50 +2 / 每级 30 ║ turbo-anthropic ★ 稳定并发 8 ║ +║ 上次: 峰值 TPS 245.3 ║ smoke-regression ✗ timeout ×2 ║ +║ ║ ║ +║ smoke-regression ║ 提示:支持按任务名、协议、模型、模式过滤║ +║ 标准模式 · gpt-4o-mini ║ ║ +║ 并发 2 / 请求 20 ║ ║ +║ 从未运行 ║ ║ +╠══════════════════════╩═══════════════════════════════════════════╣ +║ [↑↓] 选择 [Enter] 详情 [a] 新建 [r] 运行 [q] 退出 ║ +╚══════════════════════════════════════════════════════════════════╝ +``` + +--- + +#### 页面 2:任务详情页 + +``` +╔══ AIT 任务详情 ─ nightly-openai ────────────────────────────══╗ +║ 任务 ID: task_01 更新: 2026-05-16 09:30 最近运行: 12m 前 ║ +╠══════════════════════╦═══════════════════════════════════════════╣ +║ 配置摘要 ║ 最近一次结果 ║ +║ ║ ║ +║ 协议 openai-responses ║ 状态 ✓ 完成 ║ +║ 接口地址 https://api.openai.com/v1/responses ║ +║ 模型 gpt-4o ║ avg TTFT 245ms ║ +║ 模式 标准模式 ║ avg TPS 124.3 tok/s ║ +║ 并发 10 ║ 缓存命中率 42.0% ║ +║ 请求数 200 ║ 总耗时 20.4s ║ +║ Prompt 你好,介绍一下你自己。║ 报告 ait-report-...json ║ +╠══════════════════════╩═══════════════════════════════════════════╣ +║ 最近运行记录 ║ +║ 2026-05-16 09:30 ✓ 98.5% TTFT 245ms Cache 42% 20.4s ║ +║ 2026-05-15 23:10 ✓ 99.0% TTFT 231ms Cache 38% 19.8s ║ +║ 2026-05-15 21:42 ✗ timeout ×2 Cache 12% 31.2s ║ +╠══════════════════════════════════════════════════════════════════╣ +║ [Enter] 运行 [e] 编辑 [h] 完整历史 [d] 删除 [b] 返回 ║ +╚══════════════════════════════════════════════════════════════════╝ +``` + +--- + +#### 页面 3:向导 - 新建任务(Step 1/3) + +``` +╔══════════════════════════════════════════════════════════════════╗ +║ ██████╗ ██╗████████╗ AI 模型性能测试工具 v2.0 ║ +║ ██╔══██╗██║╚══██╔══╝ https://github.com/yinxulai/ait ║ +║ ███████║██║ ██║ ║ +║ ██╔══██║██║ ██║ 向导 1/3 · 新建任务 ║ +║ ██║ ██║██║ ██║ ║ +║ ╚═╝ ╚═╝╚═╝ ╚═╝ ║ +╠══════════════════════════════════════════════════════════════════╣ +║ ║ +║ 任务名称 nightly-openai ║ +║ ────────────────────────────────────────── ║ +║ ║ +║ 协议类型 > openai-responses ║ +║ ○ openai-completions ║ +║ ● openai-responses ○ anthropic-messages ║ +║ ║ +║ 接口地址 https://api.openai.com/v1/responses ║ +║ ────────────────────────────────────────── ║ +║ 提示:填写完整接口地址,而不是 base URL ║ +║ ║ +║ API 密钥 sk-•••••••••••••••••••••••••••••• ║ +║ ────────────────────────────────────────── ║ +║ ║ +║ 测试模型 gpt-4o ║ +║ ────────────────────────────────────────── ║ +║ 提示:每个任务仅允许选择一个模型 ║ +║ ║ +╠══════════════════════════════════════════════════════════════════╣ +║ [Tab] 下一项 [↑↓] 切换协议 [Enter] 下一步 [Esc] 返回 ║ +╚══════════════════════════════════════════════════════════════════╝ +``` + +--- + +#### 页面 4:向导 - 测试参数(Step 2/3) + +``` +╔══════════════════════════════════════════════════════════════════╗ +║ AIT v2.0 向导 2/3 · 任务参数 ║ +╠══════════════════════════════════════════════════════════════════╣ +║ ║ +║ 测试模式 ○ 标准模式 ● Turbo 模式 ║ +║ [←→ 切换] ║ +║ ║ +║ ── 标准模式参数 ────────────────────────────────────────── ║ +║ 并发数 [ 5 ] 请求总数 [ 100 ] ║ +║ 超时时间 [ 300s ] 流式模式 [✓ 开启] ║ +║ ║ +║ ── Turbo 模式参数 ──────────────────────────────────────── ║ +║ 初始并发 [ 1 ] 最大并发 [ 50 ] ║ +║ 步进值 [ 2 ] 每级请求数 [ 30 ] ║ +║ 停止条件 成功率低于 [ 90% ] 或 延迟超过 [ 10s ] ║ +║ ║ +║ ── Prompt 配置 ─────────────────────────────────────────── ║ +║ 输入方式 ● 直接输入 ○ 文件 ○ 按长度生成 ║ +║ 内容 你好,介绍一下你自己。 ║ +║ ────────────────────────────────────────── ║ +║ ║ +║ 运行后记录 [✓ 保存运行记录到任务历史] ║ +║ ║ +╠══════════════════════════════════════════════════════════════════╣ +║ [Tab] 下一项 [←→] 切换模式 [Enter] 下一步 [Esc] 返回 ║ +╚══════════════════════════════════════════════════════════════════╝ +``` + +--- + +#### 页面 5:向导 - 确认(Step 3/3) + +``` +╔══════════════════════════════════════════════════════════════════╗ +║ AIT v2.0 向导 3/3 · 保存任务 ║ +╠══════════════════════════════════════════════════════════════════╣ +║ ║ +║ 🆔 任务 ID a3f2-8b1c-... ║ +║ 🏷️ 任务名称 nightly-openai ║ +║ 🔗 协议 openai-responses ║ +║ 🌐 接口地址 https://api.openai.com/v1/responses ║ +║ 🔑 API 密钥 sk-****...**** ║ +║ 🤖 测试模型 gpt-4o ║ +║ 🚀 测试模式 Turbo 模式 ║ +║ ⚡ 并发爬坡 1 → 50 步进 +2 每级 30 请求 ║ +║ 🛑 停止条件 成功率 < 90% 或 延迟 > 10s ║ +║ 🌊 流式模式 开启 ║ +║ 📝 Prompt 你好,介绍一下你自己。 (长度: 12) ║ +║ ║ +║ 💾 保存任务到 ~/.ait/tasks.json [✓] ║ +║ 📝 创建后自动写入运行历史索引 [✓] ║ +║ ║ +║ ┌────────────────────────────────────────────────────────┐ ║ +║ │ ▶ 保存任务 │ ║ +║ └────────────────────────────────────────────────────────┘ ║ +║ ║ +╠══════════════════════════════════════════════════════════════════╣ +║ [Enter] 保存任务 [r] 保存并运行 [Esc] 返回修改 ║ +╚══════════════════════════════════════════════════════════════════╝ +``` + +--- + +#### 页面 6:标准模式运行仪表盘 + +``` +╔══ AIT 正在测试 ─ gpt-4o ────────────────────────────────────════╗ +║ 任务: nightly-openai 协议: openai-responses 并发: 5 请求: 100║ +╠══════════════════════╦═══════════════════════════════════════════╣ +║ 进度 ║ 实时指标 ║ +║ ║ ║ +║ 完成 ████████░░ 47 ║ 成功率 ██████████████████░░░ 98.0% ║ +║ 失败 ░░░░░░░░░░ 2 ║ ║ +║ 总计 100 ║ avg TPS 124.3 tok/s ║ +║ ║ avg TTFT 245ms ║ +║ ──────────────────── ║ 缓存命中率 42.0% ║ +║ 已用时 12.4s ║ avg 总耗时 1.24s ║ +║ 预计剩余 ~8.2s ║ 并发槽 [●●●●●] 5/5 活跃 ║ +║ ║ ║ +╠══════════════════════╩═══════════════════════════════════════════╣ +║ 请求日志 [l 展开] ║ +║ ✓ #48 245ms TTFT:82ms cache:100% 128tok 12.3tok/s ║ +║ ✗ #47 timeout (30.0s) ║ +║ ✓ #46 312ms TTFT:95ms cache:25% 96tok 9.8tok/s ║ +║ ✓ #45 198ms TTFT:71ms cache:0% 145tok 14.2tok/s ║ +╠══════════════════════════════════════════════════════════════════╣ +║ [p] 暂停 [s] 停止 [l] 切换日志详情 [r] 提前报告 [q] 退出 ║ +╚══════════════════════════════════════════════════════════════════╝ +``` + +#### 页面 7:Turbo 模式运行仪表盘 + +``` +╔══ AIT Turbo 模式 ─ gpt-4o ──────────────────────────────────════╗ +║ 任务: turbo-anthropic 协议: anthropic-messages ║ +║ 爬坡: 1→50 步进: +2/级 每级: 30 请求 ║ +╠══════════════════════╦═══════════════════════════════════════════╣ +║ 爬坡曲线 (TPS) ║ 当前级别 [并发 = 8] ║ +║ ║ ║ +║ 250┤ ╭──● ║ 成功率 █████████████████░░ 96.0% ║ +║ 200┤ ╭────╯ ║ TPS 245.3 tok/s ║ +║ 150┤ ╭──╯ ║ TTFT 312ms ║ +║ 100┤─╯ ║ Cache 44.0% ║ +║ 50┤ ║ 总耗时 1.51s ║ +║ └──┬──┬──┬──┬──→ ║ 本级完成 28 / 30 ║ +║ 1 2 4 6 8 ║ 状态 🟢 稳定,继续探测... ║ +╠══════════════════════╩═══════════════════════════════════════════╣ +║ 并发 成功率 TPS TTFT Cache 总耗时 状态 ║ +║ ────────────────────────────────────────────────────────────── ║ +║ 1 100.0% 31.2 89ms 0.0% 0.82s ✓ 稳定 ║ +║ 2 100.0% 62.5 91ms 18.0% 0.84s ✓ 稳定 ║ +║ 4 99.0% 121.3 98ms 26.0% 0.91s ✓ 稳定 ║ +║ 6 98.0% 178.4 124ms 33.0% 1.08s ✓ 稳定 ║ +║ ▶ 8 96.0% 245.3 312ms 44.0% 1.51s 🔄 探测中 ║ +╠══════════════════════════════════════════════════════════════════╣ +║ [s] 停止 [m] 手动标记为极限 [r] 提前生成报告 [q] 退出 ║ +╚══════════════════════════════════════════════════════════════════╝ +``` + +--- + +#### 页面 8:标准模式结果页 + +``` +╔══ AIT 测试完成 ─ gpt-4o ─────────────────────────────────────════╗ +║ 任务: nightly-openai 协议: openai-responses ║ +║ 耗时: 20.4s 成功率: 98.0% 总请求: 100 ║ +╠══════════════════════════════════════════════════════════════════╣ +║ 指标 最小值 平均值 标准差 最大值 ║ +║ ────────────────────────────────────────────────────────────── ║ +║ 总耗时 0.82s 1.24s ±0.31s 3.12s ║ +║ TTFT 71ms 245ms ±89ms 812ms ║ +║ TPOT 12ms 18ms ±4ms 45ms ║ +║ 输出 TPS 89.2 124.3 ±21.4 198.5 ║ +║ 吞吐 TPS 102.1 148.7 ±25.2 231.4 ║ +║ 缓存命中率 0.0% 42.0% ±18.5% 100.0% ║ +║ 输入 Token 42 64 ±12 98 ║ +║ 输出 Token 78 128 ±32 195 ║ +║ DNS 时间 1.2ms 3.4ms 12.1ms ║ +║ TCP 连接时间 2.1ms 4.8ms 9.3ms ║ +║ TLS 握手时间 8.4ms 12.3ms 28.7ms ║ +╠══════════════════════════════════════════════════════════════════╣ +║ 错误摘要 (2 个错误,占 2.0%) ║ +║ context deadline exceeded (timeout) × 2 ║ +╠══════════════════════════════════════════════════════════════════╣ +║ 任务记录已更新:最近运行摘要 + 历史索引 ║ +╠══════════════════════════════════════════════════════════════════╣ +║ [r] 生成报告 [c] 复制摘要 [b] 返回任务详情 [q] 退出 ║ +╚══════════════════════════════════════════════════════════════════╝ +``` + +--- + +#### 页面 9:Turbo 模式结果页 + +``` +╔══ AIT Turbo 完成 ─ gpt-4o ──────────────────────────────────════╗ +║ 任务: turbo-anthropic 协议: anthropic-messages ║ +║ 🏆 最大稳定并发: 8 峰值 TPS: 245.3 tok/s 探测耗时: 52s ║ +╠══════════════════════════════════════════════════════════════════╣ +║ TPS 爬坡曲线 成功率曲线 ║ +║ ║ +║ 300┤ ╭─●最大稳定 245.3 100%┤████████████ ║ +║ 200┤ ╭────╯ ╲降级 ║ ████░░ ║ +║ 100┤ ╭────╯ ╲ 95%┤ ░░░ ← 阈值 ║ +║ 0┤───╯ ● 90%└──────────────→ ║ +║ └──┬──┬──┬──┬──┬──→ 1 2 4 6 8 10 ║ +║ 1 2 4 6 8 10 并发数 ║ +║ ║ +╠══════════════════════════════════════════════════════════════════╣ +║ 并发 成功率 TPS TTFT Cache 总耗时 结论 ║ +║ ────────────────────────────────────────────────────────────── ║ +║ 1 100.0% 31.2 89ms 0.0% 0.82s ✓ 稳定 ║ +║ 2 100.0% 62.5 91ms 18.0% 0.84s ✓ 稳定 ║ +║ 4 99.0% 121.3 98ms 26.0% 0.91s ✓ 稳定 ║ +║ 6 98.0% 178.4 124ms 33.0% 1.08s ✓ 稳定 ║ +║ ★ 8 96.0% 245.3 312ms 44.0% 1.51s ✓ 最大稳定 ║ +║ 10 84.0% 198.1 892ms 12.0% 4.23s ✗ 降级 ║ +╠══════════════════════════════════════════════════════════════════╣ +║ 任务记录已更新:最近运行摘要 + 历史索引 ║ +╠══════════════════════════════════════════════════════════════════╣ +║ [r] 生成报告 [d] 详细数据 [b] 返回任务详情 [q] 退出 ║ +╚══════════════════════════════════════════════════════════════════╝ +``` + +--- + +### 4.4 键盘交互规范 + +| 按键 | 适用页面 | 功能 | +|------|----------|------| +| `a` | 任务列表 | 新建任务 | +| `/` | 任务列表 | 搜索 / 过滤任务 | +| `Enter` | 任务列表 | 查看任务详情 | +| `r` | 任务列表 / 任务详情 | 直接运行当前任务 | +| `e` | 任务详情 | 编辑当前任务 | +| `d` | 任务详情 | 删除当前任务 | +| `y` | 任务列表 / 任务详情 | 复制当前任务 | +| `h` | 任务详情 | 打开完整运行历史 | +| `b` | 任务详情 / 结果页 | 返回上一级 | +| `Tab` / `Shift+Tab` | 向导 | 在输入项间切换焦点 | +| `↑` / `↓` | 任务列表、向导、结果表格 | 上下选择 | +| `←` / `→` | 向导模式选择 | 切换选项 | +| `Enter` | 向导 | 确认 / 下一步 / 保存 | +| `Esc` | 所有页 | 返回上一步 / 取消 | +| `p` | Running | 暂停/恢复 | +| `s` | Running / Turbo | 停止测试 | +| `l` | Dashboard | 切换日志详情展开/折叠 | +| `r` | Running / 结果页 | 生成报告文件 | +| `m` | Turbo Running | 手动标记当前并发为最大稳定并发并停止 | +| `c` | 结果页 | 复制摘要到剪贴板 | +| `q` / `Ctrl+C` | 所有页 | 退出程序 | + +--- + +### 4.5 布局响应式策略 + +- 终端宽度 `< 80` 列:任务列表与任务详情折叠为单列,摘要面板移动到下方 +- 终端宽度 `≥ 80` 列:任务列表、任务详情和运行页都采用双栏布局 +- 终端高度不足时,历史记录区或日志区自动收缩,至少保留 3 行内容 + +--- + +## 5. Turbo 模式设计 + +### 5.1 算法流程 + +``` +初始化 + concurrency = TurboConfig.InitConcurrency // 默认: 1 + step = TurboConfig.StepSize // 默认: 2 + levelReqs = TurboConfig.LevelRequests // 默认: 30(每级执行的请求数) + +循环 + ① 用当前 concurrency 执行 levelReqs 个请求 + → 并发调用已有 runner.Runner(复用所有指标收集逻辑) + + ② 采集该级别指标(LevelResult) + successRate = 成功请求数 / levelReqs + avgTPS = mean(输出 TPS) + avgTTFT = mean(TTFT) + cacheHitRate = cachedInputTokens / inputTokens + avgTotalTime = mean(总耗时) + + ③ 判断终止条件(任意一条满足即终止) + a. successRate < TurboConfig.MinSuccessRate → 服务降级 + b. avgTotalTime > TurboConfig.MaxLatency → 延迟过高 + c. concurrency >= TurboConfig.MaxConcurrency → 达到探测上限 + d. 用户按下 [s] 或 [m] → 手动终止 + + ④ 如果未终止 + 记录 LevelResult 到 TurboResult.Levels + concurrency += step + 继续循环 + +结束 + 最大稳定并发 = 最后一个通过终止检查的并发数 + TurboResult.MaxStableConcurrency = 最后一个 ✓ 稳定级别的并发数 + TurboResult.PeakTPS = 所有 ✓ 稳定级别中的最大 avgTPS + 生成 TurboResult +``` + +其中缓存命中率定义为“缓存命中的输入 token / 总输入 token”。 +- OpenAI Completions / Responses:基于 usage 中的 `cached_tokens` +- Anthropic Messages:基于 usage 中的 `cache_read_input_tokens` +- 若响应未提供缓存统计字段,则该指标记为 `N/A`,不参与阈值判断 + +### 5.2 停止条件详解 + +``` +type StopReason int + +const ( + StopReasonLowSuccessRate StopReason = iota // 成功率低于阈值 + StopReasonHighLatency // 延迟超过阈值 + StopReasonMaxConcurrency // 达到最大并发上限 + StopReasonManual // 用户手动停止 + StopReasonDegraded // 综合降级判断 +) +``` + +**降级判断示例:** + +``` +并发=8: 成功率=96%, avgTPS=245, avgTTFT=312ms, cache=44% → ✓ 通过(成功率>90%,延迟<10s) +并发=10: 成功率=84% → ✗ 停止(成功率 84% < 阈值 90%) +最大稳定并发 = 8 +``` + +### 5.3 CLI 参数 + +Turbo 模式通过 `--turbo` 标志启用,新增以下参数: + +```bash +# 启用 Turbo 模式 +ait --protocol=openai-responses --endpoint=https://api.openai.com/v1/responses --model=gpt-4o --turbo + +# 完整 Turbo 参数 +ait --protocol=openai-responses --endpoint=https://api.openai.com/v1/responses --model=gpt-4o --turbo \ + --turbo-init-concurrency=1 \ # 初始并发数(默认: 1) + --turbo-max-concurrency=50 \ # 最大探测并发数(默认: 50) + --turbo-step=2 \ # 每级步进值(默认: 2) + --turbo-level-requests=30 \ # 每级执行的请求数(默认: 30) + --turbo-min-success-rate=0.9 \ # 成功率低于此值停止(默认: 0.9) + --turbo-max-latency=10s # 延迟超过此值停止(默认: 10s) +``` + +**与现有参数的关系:** + +- `--concurrency` 在 Turbo 模式下**被忽略**(Turbo 自己控制并发) +- `--count` 在 Turbo 模式下表示**每级**的请求数(等同于 `--turbo-level-requests`,优先级低于后者) +- `--protocol` 允许值为 `openai-completions`、`openai-responses`、`anthropic-messages` +- `--endpoint` 必须填写完整接口地址,例如 `https://api.openai.com/v1/responses` +- 其他参数(protocol、endpoint、apiKey、model、stream、timeout 等)正常生效 +- 单个任务只接受一个 `--model`;如果要测试多个模型,应拆分成多个任务 + +**与任务管理的关系:** + +- `--task=` 直接加载已保存任务并进入详情页或直接运行 +- 通过完整 CLI 参数启动时,AIT 会先生成一个未保存的临时任务草稿,用户可选择保存后复用 +- Turbo 运行完成后,其结果会自动追加到对应任务的历史记录中 + +### 5.4 Turbo 报告格式 + +**JSON 报告(新增字段):** + +```json +{ + "turbo": { + "max_stable_concurrency": 8, + "peak_tps": 245.3, + "stop_reason": "low_success_rate", + "probe_duration": "52.3s", + "protocol": "openai-responses", + "endpoint_url": "https://api.openai.com/v1/responses", + "config": { + "init_concurrency": 1, + "max_concurrency": 50, + "step": 2, + "level_requests": 30, + "min_success_rate": 0.9, + "max_latency": "10s" + }, + "levels": [ + { + "concurrency": 1, + "total_requests": 30, + "success_count": 30, + "success_rate": 1.0, + "avg_tps": 31.2, + "avg_ttft": "89ms", + "cache_hit_rate": 0.0, + "avg_total_time": "0.82s", + "stable": true + } + ] + } +} +``` + +**CSV 报告(新增 turbo 爬坡汇总表):** + +``` +protocol,concurrency,success_rate,avg_tps,peak_tps,avg_ttft,cache_hit_rate,avg_total_time,stable +openai-responses,1,1.00,31.2,38.5,89ms,0.00,0.82s,true +openai-responses,2,1.00,62.5,74.1,91ms,0.18,0.84s,true +openai-responses,4,0.99,121.3,145.2,98ms,0.26,0.91s,true +openai-responses,6,0.98,178.4,201.3,124ms,0.33,1.08s,true +openai-responses,8,0.96,245.3,280.1,312ms,0.44,1.51s,true +openai-responses,10,0.84,198.1,234.5,892ms,0.12,4.23s,false +``` + +--- + +## 6. 新增数据结构与接口 + +### 6.1 TaskDefinition + +```go +// internal/types/types.go 新增 + +// TaskDefinition 可重复执行的测试任务定义 +type TaskDefinition struct { + ID string `json:"id"` + Name string `json:"name"` + Input Input `json:"input"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + LastRunAt *time.Time `json:"last_run_at,omitempty"` + LastRunSummary *TaskRunSummary `json:"last_run_summary,omitempty"` +} +``` + +其中 `Input.Model` 为单个模型标识,不允许逗号分隔列表。 +其中 `Input.Protocol` 允许值为 `openai-completions`、`openai-responses`、`anthropic-messages`。 +其中 `Input.EndpointURL` 必须是完整接口地址,例如: +- `openai-completions` → `https://api.openai.com/v1/chat/completions` +- `openai-responses` → `https://api.openai.com/v1/responses` +- `anthropic-messages` → `https://api.anthropic.com/v1/messages` + +### 6.2 TaskRunSummary + +```go +// TaskRunSummary 单次任务运行后的摘要信息 +type TaskRunSummary struct { + RunID string `json:"run_id"` + TaskID string `json:"task_id"` + Mode string `json:"mode"` + Status string `json:"status"` + Protocol string `json:"protocol"` + Model string `json:"model"` + StartedAt time.Time `json:"started_at"` + FinishedAt time.Time `json:"finished_at"` + SuccessRate float64 `json:"success_rate"` + AvgTTFT time.Duration `json:"avg_ttft"` + AvgTPS float64 `json:"avg_tps"` + CacheHitRate float64 `json:"cache_hit_rate"` + MaxStableConcurrency int `json:"max_stable_concurrency,omitempty"` + ReportJSONPath string `json:"report_json_path,omitempty"` + ReportCSVPath string `json:"report_csv_path,omitempty"` + ErrorSummary string `json:"error_summary,omitempty"` +} +``` + +### 6.3 TurboConfig + +```go +// internal/types/types.go 新增 + +// TurboConfig Turbo 模式的配置参数 +type TurboConfig struct { + InitConcurrency int // 初始并发数,默认 1 + MaxConcurrency int // 最大探测并发数,默认 50 + StepSize int // 每级步进值,默认 2 + LevelRequests int // 每级执行请求数,默认 30 + MinSuccessRate float64 // 停止阈值:成功率,默认 0.9 + MaxLatency time.Duration // 停止阈值:平均延迟,默认 10s +} + +// TurboLevelResult 单个并发级别的测试结果 +type TurboLevelResult struct { + Concurrency int `json:"concurrency"` + TotalRequests int `json:"total_requests"` + SuccessCount int `json:"success_count"` + SuccessRate float64 `json:"success_rate"` + AvgTPS float64 `json:"avg_tps"` + PeakTPS float64 `json:"peak_tps"` + AvgTTFT time.Duration `json:"avg_ttft"` + CacheHitRate float64 `json:"cache_hit_rate"` + AvgTotalTime time.Duration `json:"avg_total_time"` + StdDevTPS float64 `json:"stddev_tps"` + Stable bool `json:"stable"` + StopReason string `json:"stop_reason,omitempty"` +} + +// TurboResult Turbo 模式的最终结果 +type TurboResult struct { + Config TurboConfig `json:"config"` + Levels []TurboLevelResult `json:"levels"` + MaxStableConcurrency int `json:"max_stable_concurrency"` + PeakTPS float64 `json:"peak_tps"` + StopReason string `json:"stop_reason"` + ProbeDuration time.Duration `json:"probe_duration"` + Model string `json:"model"` + Protocol string `json:"protocol"` + EndpointURL string `json:"endpoint_url"` + Timestamp string `json:"timestamp"` +} +``` + +现有指标结构也需要补充缓存命中率字段,用于 dashboard、结果页和报告渲染: + +```go +type ResponseMetrics struct { + CachedInputTokens int // 当前请求命中的输入缓存 token 数 + CacheHitRate float64 // CachedInputTokens / max(PromptTokens, 1) +} + +type StatsData struct { + CacheHitRates []float64 // 所有请求的缓存命中率 +} + +type ReportData struct { + AvgCacheHitRate float64 `json:"avg_cache_hit_rate"` + MinCacheHitRate float64 `json:"min_cache_hit_rate"` + MaxCacheHitRate float64 `json:"max_cache_hit_rate"` + StdDevCacheHitRate float64 `json:"stddev_cache_hit_rate"` +} +``` + + ### 6.4 Input 扩展 + +```go +// internal/types/types.go 扩展 Input + +type Input struct { + // ... 现有字段 ... + + Protocol string // openai-completions | openai-responses | anthropic-messages + EndpointURL string // 完整接口地址,例如 https://api.openai.com/v1/responses + + // Turbo 模式 + Turbo bool // 是否启用 Turbo 模式 + TurboConfig TurboConfig // Turbo 配置(Turbo=true 时生效) +} +``` + +### 6.5 TUI 消息类型 + +```go +// internal/tui/messages.go + +// TasksLoadedMsg 任务列表加载完成 +type TasksLoadedMsg struct { + Tasks []types.TaskDefinition +} + +// TaskSavedMsg 任务保存完成 +type TaskSavedMsg struct { + Task types.TaskDefinition +} + +// TaskHistoryLoadedMsg 任务运行记录加载完成 +type TaskHistoryLoadedMsg struct { + TaskID string + History []types.TaskRunSummary +} + +// RequestDoneMsg 单个请求完成 +type RequestDoneMsg struct { + Metrics *client.ResponseMetrics + Index int + Err error +} + +// AllRequestsDoneMsg 所有请求完成 +type AllRequestsDoneMsg struct { + Result *types.ReportData + Errors []string +} + +// TurboLevelStartMsg Turbo 新一级开始 +type TurboLevelStartMsg struct { + Concurrency int + LevelIndex int +} + +// TurboLevelDoneMsg Turbo 一级完成 +type TurboLevelDoneMsg struct { + Level types.TurboLevelResult + LevelIndex int +} + +// TurboDoneMsg Turbo 全部完成 +type TurboDoneMsg struct { + Result *types.TurboResult +} + +// ProgressTickMsg 定时刷新实时指标 +type ProgressTickMsg struct { + Stats types.StatsData +} + +// ErrorMsg 运行时错误 +type ErrorMsg struct { + Err error +} +``` + +### 6.6 Runner 接口扩展 + +```go +// internal/runner/runner.go 新增 + +// RequestDoneCallback 每个请求完成后的回调(细粒度,供 TUI 使用) +type RequestDoneCallback func(metrics *client.ResponseMetrics, index int, err error) + +// RunWithCallback 运行测试,每个请求完成后调用 cb(线程安全) +// 同时保留原有的 RunWithProgress,供 Legacy 模式使用 +func (r *Runner) RunWithCallback(cb RequestDoneCallback) (*types.ReportData, error) + +// Stop 异步停止正在进行的测试 +func (r *Runner) Stop() +``` + +### 6.7 任务与全局配置持久化 + +```go +// internal/config/config.go + +type Config struct { + SaveAPIKey bool `json:"save_api_key"` + LastSelectedTaskID string `json:"last_selected_task_id,omitempty"` + DefaultProtocol string `json:"default_protocol,omitempty"` // openai-completions | openai-responses | anthropic-messages +} + +func Load() (*Config, error) // 从 ~/.ait/config.json 加载 +func (c *Config) Save() error // 保存到 ~/.ait/config.json + +// internal/task/store.go + +type TaskStore struct { + Tasks []types.TaskDefinition `json:"tasks"` +} + +func LoadTasks() (*TaskStore, error) // 从 ~/.ait/tasks.json 加载 +func (s *TaskStore) Save() error // 保存到 ~/.ait/tasks.json +func (s *TaskStore) Upsert(task types.TaskDefinition) // 新建或更新任务 +func (s *TaskStore) Delete(taskID string) error + +// internal/task/history.go + +func AppendRun(taskID string, run types.TaskRunSummary) error +func LoadHistory(taskID string, limit int) ([]types.TaskRunSummary, error) +``` + +--- + +## 7. 开发计划 + +### Phase 1 — 任务中心与 TUI 基础框架(优先) + +**目标:** 先建立任务管理主流程,再用 BubbleTea 替换现有的进度条 + 静态表格输出 + +**任务清单:** + +- [ ] 引入依赖:`charm.land/bubbletea/v2`、`bubbles`、`lipgloss` +- [ ] 实现 `internal/tui/` 基础骨架(model、messages、styles) +- [ ] 实现任务列表页(tasklist):选择 / 搜索 / 删除 / 复制 +- [ ] 实现任务详情页(taskdetail):配置摘要 + 最近记录 + 直接运行 +- [ ] 实现向导页(wizard):三步创建 / 编辑任务 +- [ ] 实现仪表盘页(dashboard):进度 + 实时指标双栏 +- [ ] 实现结果页(result):完整指标表格 + 键盘操作 +- [ ] 协议枚举细化:`openai-completions`、`openai-responses`、`anthropic-messages` +- [ ] 扩展指标采集与渲染:缓存命中率(dashboard / result / report) +- [ ] `cmd/ait/main.go` 模式检测路由(无参数 → 任务列表,有参数 → 临时任务草稿) +- [ ] 实现任务持久化(`tasks.json` + `history/*.json`) +- [ ] Runner 增加 `Stop()` 方法和 `RunWithCallback` 接口 +- [ ] 全局配置持久化(默认协议、最后选择任务、密钥保存策略) +- [ ] 结果页回写任务最近运行摘要 +- [ ] 响应式布局(终端宽度自适应) +- [ ] `internal/display/` 模块退役,由 TUI 全面接管输出 + +--- + +### Phase 2 — Turbo 模式 + +**目标:** 将并发爬坡能力完整融入任务体系 + +**任务清单:** + +- [ ] 实现 `internal/turbo/runner.go`:封装爬坡调度逻辑 +- [ ] 实现 `internal/turbo/strategy.go`:步进 & 终止策略 +- [ ] `types.TurboConfig`、`TurboLevelResult`、`TurboResult` 数据结构 +- [ ] TUI Turbo 仪表盘页(折线图 + 爬坡表格) +- [ ] `internal/report/turbo_renderer.go`:Turbo CSV/JSON 报告 +- [ ] Turbo 结果写回任务最近摘要和运行记录 +- [ ] 新增 CLI 参数:`--turbo`、`--turbo-*` 系列 + +--- + +### Phase 3 — 增强 + +**目标:** 细节打磨与扩展 + +**任务清单:** + +- [ ] 多任务 Turbo 对比(并排爬坡曲线) +- [ ] 任务收藏和快速筛选视图 +- [ ] 任务复制、模板化创建和批量导入 +- [ ] 运行记录对比视图(同一任务不同 run 对比) +- [ ] 结果页 `c` 键复制摘要到剪贴板 +- [ ] `ntcharts` 折线图替换 ASCII 折线图 +- [ ] 终端尺寸变化自适应重绘 +- [ ] 完善单元测试(TUI model 测试、turbo strategy 测试) + +--- + +### 依赖变更汇总 + +```diff + # go.mod 新增 ++ charm.land/bubbletea/v2 # TUI 主框架 ++ github.com/charmbracelet/bubbles # 预制 UI 组件 ++ github.com/charmbracelet/lipgloss # 样式与布局 ++ github.com/NimbleMarkets/ntcharts # Phase 3 图表(可选) + + # go.mod 移除 +- github.com/schollz/progressbar/v3 # 由 bubbles/progress 替代 +- github.com/olekukonko/tablewriter # 由 bubbles/table + lipgloss 替代 +``` From e6696b54b108e80758a9255cc46775c93801048f Mon Sep 17 00:00:00 2001 From: Alain Date: Sat, 16 May 2026 08:00:55 +0800 Subject: [PATCH 02/52] Add TUI styles, Turbo engine, and task definition handling - Implemented TUI styles for better UI presentation in `internal/tui/styles.go`. - Added Turbo engine functionality in `internal/turbo/engine.go` to manage concurrent task execution with success rate and latency checks. - Created test cases for Turbo engine in `internal/turbo/engine_test.go` to ensure correct behavior under various conditions. - Introduced task definition structure in `internal/types/types.go` to encapsulate task-related data. - Added input normalization and endpoint resolution functions to improve protocol handling in `internal/types/types.go`. - Developed unit tests for task definition building in `internal/tui/model_test.go` to validate input processing. --- cmd/ait/ait.go | 10 + go.mod | 26 +- go.sum | 50 +- internal/client/anthropic.go | 26 +- internal/client/anthropic_test.go | 30 +- internal/client/client.go | 7 +- internal/client/client_test.go | 120 ++- internal/client/openai.go | 343 ++++++-- internal/client/openai_test.go | 210 +++-- internal/config/config.go | 101 +++ internal/config/config_test.go | 51 ++ internal/runner/runner.go | 244 ++++-- internal/runner/runner_callback_test.go | 172 ++++ internal/runner/runner_test.go | 1 + internal/task/history.go | 89 ++ internal/task/history_test.go | 41 + internal/task/input.go | 48 ++ internal/task/input_test.go | 33 + internal/task/store.go | 102 +++ internal/task/store_test.go | 76 ++ internal/tui/messages.go | 22 + internal/tui/model.go | 1054 +++++++++++++++++++++++ internal/tui/model_test.go | 73 ++ internal/tui/styles.go | 32 + internal/turbo/engine.go | 195 +++++ internal/turbo/engine_test.go | 147 ++++ internal/types/types.go | 191 +++- 27 files changed, 3183 insertions(+), 311 deletions(-) create mode 100644 internal/config/config.go create mode 100644 internal/config/config_test.go create mode 100644 internal/runner/runner_callback_test.go create mode 100644 internal/task/history.go create mode 100644 internal/task/history_test.go create mode 100644 internal/task/input.go create mode 100644 internal/task/input_test.go create mode 100644 internal/task/store.go create mode 100644 internal/task/store_test.go create mode 100644 internal/tui/messages.go create mode 100644 internal/tui/model.go create mode 100644 internal/tui/model_test.go create mode 100644 internal/tui/styles.go create mode 100644 internal/turbo/engine.go create mode 100644 internal/turbo/engine_test.go diff --git a/cmd/ait/ait.go b/cmd/ait/ait.go index 7a80c78..c0acabc 100644 --- a/cmd/ait/ait.go +++ b/cmd/ait/ait.go @@ -13,6 +13,7 @@ import ( "github.com/yinxulai/ait/internal/prompt" "github.com/yinxulai/ait/internal/report" "github.com/yinxulai/ait/internal/runner" + "github.com/yinxulai/ait/internal/tui" "github.com/yinxulai/ait/internal/types" ) @@ -332,6 +333,7 @@ func executeModelsTestSuite(taskID string, modelList []string, finalProtocol, fi func main() { taskID := generateTaskID() versionFlag := flag.Bool("version", false, "显示版本信息") + interactiveFlag := flag.Bool("interactive", false, "启动交互式 TUI") baseUrl := flag.String("baseUrl", "", "服务地址") apiKey := flag.String("apiKey", "", "API 密钥") count := flag.Int("count", 10, "请求总数") @@ -357,6 +359,14 @@ func main() { os.Exit(0) } + if *interactiveFlag { + if err := tui.Run(); err != nil { + fmt.Printf("启动交互式 TUI 失败: %v\n", err) + os.Exit(1) + } + return + } + // 合并 --model 和 --models 参数 finalModels := *models if *model != "" { diff --git a/go.mod b/go.mod index 91c2fbb..9e6973a 100644 --- a/go.mod +++ b/go.mod @@ -2,19 +2,37 @@ module github.com/yinxulai/ait go 1.22 -toolchain go1.22.4 +require ( + github.com/charmbracelet/bubbles v0.20.0 + github.com/charmbracelet/bubbletea v1.2.1 + github.com/charmbracelet/lipgloss v1.0.0 + github.com/olekukonko/tablewriter v1.0.9 + github.com/schollz/progressbar/v3 v3.18.0 +) require ( + github.com/atotto/clipboard v0.1.4 // indirect + github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect + github.com/charmbracelet/x/ansi v0.8.0 // indirect + github.com/charmbracelet/x/term v0.2.0 // indirect + github.com/clipperhouse/stringish v0.1.1 // indirect + github.com/clipperhouse/uax29/v2 v2.5.0 // indirect + github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect github.com/fatih/color v1.18.0 // indirect + github.com/lucasb-eyer/go-colorful v1.3.0 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect - github.com/mattn/go-runewidth v0.0.16 // indirect + github.com/mattn/go-localereader v0.0.1 // indirect + github.com/mattn/go-runewidth v0.0.19 // indirect github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db // indirect + github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect + github.com/muesli/cancelreader v0.2.2 // indirect + github.com/muesli/termenv v0.15.2 // indirect github.com/olekukonko/errors v1.1.0 // indirect github.com/olekukonko/ll v0.0.9 // indirect - github.com/olekukonko/tablewriter v1.0.9 github.com/rivo/uniseg v0.4.7 // indirect - github.com/schollz/progressbar/v3 v3.18.0 + golang.org/x/sync v0.9.0 // indirect golang.org/x/sys v0.29.0 // indirect golang.org/x/term v0.28.0 // indirect + golang.org/x/text v0.3.8 // indirect ) diff --git a/go.sum b/go.sum index 2c4f6db..35c3f09 100644 --- a/go.sum +++ b/go.sum @@ -1,28 +1,72 @@ +github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4= +github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI= +github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k= +github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8= +github.com/charmbracelet/bubbles v0.20.0 h1:jSZu6qD8cRQ6k9OMfR1WlM+ruM8fkPWkHvQWD9LIutE= +github.com/charmbracelet/bubbles v0.20.0/go.mod h1:39slydyswPy+uVOHZ5x/GjwVAFkCsV8IIVy+4MhzwwU= +github.com/charmbracelet/bubbletea v1.2.1 h1:J041h57zculJKEKf/O2pS4edXGIz+V0YvojvfGXePIk= +github.com/charmbracelet/bubbletea v1.2.1/go.mod h1:viLoDL7hG4njLJSKU2gw7kB3LSEmWsrM80rO1dBJWBI= +github.com/charmbracelet/lipgloss v1.0.0 h1:O7VkGDvqEdGi93X+DeqsQ7PKHDgtQfF8j8/O2qFMQNg= +github.com/charmbracelet/lipgloss v1.0.0/go.mod h1:U5fy9Z+C38obMs+T+tJqst9VGzlOYGj4ri9reL3qUlo= +github.com/charmbracelet/x/ansi v0.8.0 h1:9GTq3xq9caJW8ZrBTe0LIe2fvfLR/bYXKTx2llXn7xE= +github.com/charmbracelet/x/ansi v0.8.0/go.mod h1:wdYl/ONOLHLIVmQaxbIYEC/cRKOQyjTkowiI4blgS9Q= +github.com/charmbracelet/x/term v0.2.0 h1:cNB9Ot9q8I711MyZ7myUR5HFWL/lc3OpU8jZ4hwm0x0= +github.com/charmbracelet/x/term v0.2.0/go.mod h1:GVxgxAbjUrmpvIINHIQnJJKpMlHiZ4cktEQCN6GWyF0= +github.com/chengxilo/virtualterm v1.0.4 h1:Z6IpERbRVlfB8WkOmtbHiDbBANU7cimRIof7mk9/PwM= +github.com/chengxilo/virtualterm v1.0.4/go.mod h1:DyxxBZz/x1iqJjFxTFcr6/x+jSpqN0iwWCOK1q10rlY= +github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs= +github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA= +github.com/clipperhouse/uax29/v2 v2.5.0 h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w94cO8U= +github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4= +github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM= github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= +github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag= +github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc= -github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= +github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4= +github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88= +github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw= +github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs= github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db h1:62I3jR2EmQ4l5rM/4FEfDWcRD+abF5XlKShorW5LRoQ= github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db/go.mod h1:l0dey0ia/Uv7NcFFVbCLtqEBQbrT4OCwCSKTEv6enCw= +github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI= +github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo= +github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA= +github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo= +github.com/muesli/termenv v0.15.2 h1:GohcuySI0QmI3wN8Ok9PtKGkgkFIk7y6Vpb5PvrY+Wo= +github.com/muesli/termenv v0.15.2/go.mod h1:Epx+iuz8sNs7mNKhxzH4fWXGNpZwUaJKRS1noLXviQ8= github.com/olekukonko/errors v1.1.0 h1:RNuGIh15QdDenh+hNvKrJkmxxjV4hcS50Db478Ou5sM= github.com/olekukonko/errors v1.1.0/go.mod h1:ppzxA5jBKcO1vIpCXQ9ZqgDh8iwODz6OXIGKU8r5m4Y= github.com/olekukonko/ll v0.0.9 h1:Y+1YqDfVkqMWuEQMclsF9HUR5+a82+dxJuL1HHSRpxI= github.com/olekukonko/ll v0.0.9/go.mod h1:En+sEW0JNETl26+K8eZ6/W4UQ7CYSrrgg/EdIYT2H8g= github.com/olekukonko/tablewriter v1.0.9 h1:XGwRsYLC2bY7bNd93Dk51bcPZksWZmLYuaTHR0FqfL8= github.com/olekukonko/tablewriter v1.0.9/go.mod h1:5c+EBPeSqvXnLLgkm9isDdzR3wjfBkHR9Nhfp3NWrzo= -github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/schollz/progressbar/v3 v3.18.0 h1:uXdoHABRFmNIjUfte/Ex7WtuyVslrw2wVPQmCN62HpA= github.com/schollz/progressbar/v3 v3.18.0/go.mod h1:IsO3lpbaGuzh8zIMzgY3+J8l4C8GjO0Y9S69eFvNsec= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +golang.org/x/sync v0.9.0 h1:fEo0HyrW1GIgZdpbhCRO0PkJajUS5H9IFUztCgEo2jQ= +golang.org/x/sync v0.9.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.28.0 h1:/Ts8HFuMR2E6IP/jlo7QVLZHggjKQbhu/7H0LJFr3Gg= golang.org/x/term v0.28.0/go.mod h1:Sw/lC2IAUZ92udQNf3WodGtn4k/XoLyZoh8v/8uiwek= +golang.org/x/text v0.3.8 h1:nAL+RVCQ9uMn3vJZbV+MRnydTJFPf8qqY42YiA6MrqY= +golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/client/anthropic.go b/internal/client/anthropic.go index 773dbfa..235f23a 100644 --- a/internal/client/anthropic.go +++ b/internal/client/anthropic.go @@ -28,8 +28,9 @@ type AnthropicResponse struct { } `json:"content"` Model string `json:"model"` Usage struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` + InputTokens int `json:"input_tokens"` + CacheReadInputTokens int `json:"cache_read_input_tokens"` + OutputTokens int `json:"output_tokens"` } `json:"usage"` } @@ -53,14 +54,15 @@ type AnthropicStreamChunk struct { PartialJSON *string `json:"partial_json,omitempty"` } `json:"delta,omitempty"` Usage *struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` + InputTokens int `json:"input_tokens"` + CacheReadInputTokens int `json:"cache_read_input_tokens"` + OutputTokens int `json:"output_tokens"` } `json:"usage,omitempty"` } // AnthropicClient Anthropic 协议客户端 type AnthropicClient struct { - BaseUrl string + EndpointURL string ApiKey string Model string Provider string @@ -85,10 +87,10 @@ func NewAnthropicClient(config types.Input) *AnthropicClient { } return &AnthropicClient{ - BaseUrl: config.BaseUrl, + EndpointURL: config.ResolvedEndpointURL(), ApiKey: config.ApiKey, Model: config.Model, - Provider: "anthropic", + Provider: config.NormalizedProtocol(), Thinking: config.Thinking, httpClient: &http.Client{ Transport: transport, @@ -110,7 +112,7 @@ func (c *AnthropicClient) Request(prompt string, stream bool) (*ResponseMetrics, c.logger.LogTestStart(c.Model, prompt, map[string]interface{}{ "stream": stream, "protocol": c.Provider, - "base_url": c.BaseUrl, + "endpoint_url": c.EndpointURL, }) } @@ -152,7 +154,7 @@ func (c *AnthropicClient) Request(prompt string, stream bool) (*ResponseMetrics, }, err } - req, err := http.NewRequest("POST", c.BaseUrl+"/v1/messages", bytes.NewBuffer(reqBodyBytes)) + req, err := http.NewRequest("POST", c.EndpointURL, bytes.NewBuffer(reqBodyBytes)) if err != nil { // 记录错误日志 if c.logger != nil && c.logger.IsEnabled() { @@ -300,6 +302,7 @@ func (c *AnthropicClient) Request(prompt string, stream bool) (*ResponseMetrics, var fullContent strings.Builder var outputTokens int var inputTokens int + var cachedInputTokens int var streamChunks []string // 用于记录所有流式数据块 // 记录流式响应开始日志 @@ -357,6 +360,7 @@ func (c *AnthropicClient) Request(prompt string, stream bool) (*ResponseMetrics, // 获取 token 统计信息 if chunk.Usage != nil { inputTokens = chunk.Usage.InputTokens + cachedInputTokens = chunk.Usage.CacheReadInputTokens outputTokens = chunk.Usage.OutputTokens } } @@ -383,6 +387,7 @@ func (c *AnthropicClient) Request(prompt string, stream bool) (*ResponseMetrics, "total_time": totalTime.String(), "time_to_first_token": firstTokenTime.String(), "input_tokens": inputTokens, + "cached_input_tokens": cachedInputTokens, "output_tokens": outputTokens, "full_content": fullContent.String(), }) @@ -396,6 +401,7 @@ func (c *AnthropicClient) Request(prompt string, stream bool) (*ResponseMetrics, TLSHandshakeTime: tlsTime, TargetIP: targetIP, PromptTokens: inputTokens, + CachedInputTokens: cachedInputTokens, CompletionTokens: outputTokens, ErrorMessage: "", }, nil @@ -482,6 +488,7 @@ func (c *AnthropicClient) Request(prompt string, stream bool) (*ResponseMetrics, "total_time": totalTime.String(), "output_tokens": anthropicResp.Usage.OutputTokens, "input_tokens": anthropicResp.Usage.InputTokens, + "cached_input_tokens": anthropicResp.Usage.CacheReadInputTokens, "response_id": anthropicResp.ID, "content_length": len(contentText), }) @@ -495,6 +502,7 @@ func (c *AnthropicClient) Request(prompt string, stream bool) (*ResponseMetrics, TLSHandshakeTime: tlsTime, TargetIP: targetIP, PromptTokens: anthropicResp.Usage.InputTokens, + CachedInputTokens: anthropicResp.Usage.CacheReadInputTokens, CompletionTokens: anthropicResp.Usage.OutputTokens, ErrorMessage: "", }, nil diff --git a/internal/client/anthropic_test.go b/internal/client/anthropic_test.go index 6f5c323..715747d 100644 --- a/internal/client/anthropic_test.go +++ b/internal/client/anthropic_test.go @@ -15,7 +15,7 @@ import ( // createTestConfig 创建用于测试的标准配置 func createTestConfig(baseUrl, apiKey, model string, timeout time.Duration, thinking bool) types.Input { return types.Input{ - Protocol: "anthropic", + Protocol: types.ProtocolAnthropicMessages, BaseUrl: baseUrl, ApiKey: apiKey, Model: model, @@ -115,7 +115,7 @@ func TestNewAnthropicClient(t *testing.T) { { name: "valid anthropic client", config: types.Input{ - Protocol: "anthropic", + Protocol: types.ProtocolAnthropicMessages, BaseUrl: "https://api.anthropic.com", ApiKey: "test-key", Model: "claude-3-sonnet-20240229", @@ -123,10 +123,10 @@ func TestNewAnthropicClient(t *testing.T) { Thinking: false, }, want: &AnthropicClient{ - BaseUrl: "https://api.anthropic.com", + EndpointURL: "https://api.anthropic.com/v1/messages", ApiKey: "test-key", Model: "claude-3-sonnet-20240229", - Provider: "anthropic", + Provider: types.ProtocolAnthropicMessages, Thinking: false, }, }, @@ -136,8 +136,8 @@ func TestNewAnthropicClient(t *testing.T) { t.Run(tt.name, func(t *testing.T) { got := NewAnthropicClient(tt.config) - if got.BaseUrl != tt.want.BaseUrl { - t.Errorf("NewAnthropicClient().BaseUrl = %v, want %v", got.BaseUrl, tt.want.BaseUrl) + if got.EndpointURL != tt.want.EndpointURL { + t.Errorf("NewAnthropicClient().EndpointURL = %v, want %v", got.EndpointURL, tt.want.EndpointURL) } if got.ApiKey != tt.want.ApiKey { @@ -162,8 +162,8 @@ func TestNewAnthropicClient(t *testing.T) { func TestAnthropicClient_GetProtocol(t *testing.T) { client := NewAnthropicClient(createTestConfig("https://api.anthropic.com", "test-key", "claude-3-sonnet-20240229", 30*time.Second, false)) - if got := client.GetProtocol(); got != "anthropic" { - t.Errorf("AnthropicClient.GetProtocol() = %v, want %v", got, "anthropic") + if got := client.GetProtocol(); got != types.ProtocolAnthropicMessages { + t.Errorf("AnthropicClient.GetProtocol() = %v, want %v", got, types.ProtocolAnthropicMessages) } } @@ -395,7 +395,7 @@ func TestNewAnthropicClientTimeout(t *testing.T) { { name: "with custom timeout", config: types.Input{ - Protocol: "anthropic", + Protocol: types.ProtocolAnthropicMessages, BaseUrl: "https://api.anthropic.com", ApiKey: "test-key", Model: "claude-3-sonnet", @@ -407,7 +407,7 @@ func TestNewAnthropicClientTimeout(t *testing.T) { { name: "with zero timeout", config: types.Input{ - Protocol: "anthropic", + Protocol: types.ProtocolAnthropicMessages, BaseUrl: "https://api.anthropic.com", ApiKey: "test-key", Model: "claude-3-opus", @@ -419,7 +419,7 @@ func TestNewAnthropicClientTimeout(t *testing.T) { { name: "with long timeout", config: types.Input{ - Protocol: "anthropic", + Protocol: types.ProtocolAnthropicMessages, BaseUrl: "https://custom.api.com", ApiKey: "test-key", Model: "claude-3-haiku", @@ -1308,16 +1308,16 @@ func TestAnthropicClientWithConfig(t *testing.T) { } // 验证其他基本字段 - if client.BaseUrl != "https://api.anthropic.com" { - t.Errorf("Expected BaseUrl = https://api.anthropic.com, got %s", client.BaseUrl) + if client.EndpointURL != "https://api.anthropic.com/v1/messages" { + t.Errorf("Expected EndpointURL = https://api.anthropic.com/v1/messages, got %s", client.EndpointURL) } if client.Model != "claude-3-sonnet" { t.Errorf("Expected Model = claude-3-sonnet, got %s", client.Model) } - if client.Provider != "anthropic" { - t.Errorf("Expected Provider = anthropic, got %s", client.Provider) + if client.Provider != types.ProtocolAnthropicMessages { + t.Errorf("Expected Provider = %s, got %s", types.ProtocolAnthropicMessages, client.Provider) } }) } diff --git a/internal/client/client.go b/internal/client/client.go index 3418cd1..bbf9b9e 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -22,6 +22,7 @@ type ResponseMetrics struct { // 内容指标 PromptTokens int // 输入 token 数量 + CachedInputTokens int // 缓存命中的输入 token 数量 ThinkingTokens int // 思考/推理 token 数量 CompletionTokens int // 输出 token 数量 (用于TPS计算) @@ -39,12 +40,12 @@ type ModelClient interface { // NewClient 根据配置创建客户端 func NewClient(config types.Input, logger *logger.Logger) (ModelClient, error) { - switch config.Protocol { - case "openai": + switch config.NormalizedProtocol() { + case types.ProtocolOpenAICompletions, types.ProtocolOpenAIResponses: client := NewOpenAIClient(config) client.SetLogger(logger) return client, nil - case "anthropic": + case types.ProtocolAnthropicMessages: client := NewAnthropicClient(config) client.SetLogger(logger) return client, nil diff --git a/internal/client/client_test.go b/internal/client/client_test.go index c178076..1f6b899 100644 --- a/internal/client/client_test.go +++ b/internal/client/client_test.go @@ -9,37 +9,69 @@ import ( func TestNewClient(t *testing.T) { tests := []struct { - name string - config types.Input - wantError bool + name string + config types.Input + wantError bool + expectedProtocol string + expectedEndpoint string }{ { - name: "valid openai client", + name: "valid openai completions client", config: types.Input{ - Protocol: "openai", - BaseUrl: "https://api.openai.com", - ApiKey: "test-key", - Model: "gpt-3.5-turbo", - Timeout: 30 * time.Second, + Protocol: types.ProtocolOpenAICompletions, + EndpointURL: "https://api.openai.com/v1/chat/completions", + ApiKey: "test-key", + Model: "gpt-4.1-mini", + Timeout: 30 * time.Second, + }, + wantError: false, + expectedProtocol: types.ProtocolOpenAICompletions, + expectedEndpoint: "https://api.openai.com/v1/chat/completions", + }, + { + name: "valid openai responses client", + config: types.Input{ + Protocol: types.ProtocolOpenAIResponses, + EndpointURL: "https://api.openai.com/v1/responses", + ApiKey: "test-key", + Model: "gpt-4.1-mini", + Timeout: 30 * time.Second, + }, + wantError: false, + expectedProtocol: types.ProtocolOpenAIResponses, + expectedEndpoint: "https://api.openai.com/v1/responses", + }, + { + name: "valid anthropic messages client", + config: types.Input{ + Protocol: types.ProtocolAnthropicMessages, + EndpointURL: "https://api.anthropic.com/v1/messages", + ApiKey: "test-key", + Model: "claude-3-7-sonnet-latest", + Timeout: 30 * time.Second, }, - wantError: false, + wantError: false, + expectedProtocol: types.ProtocolAnthropicMessages, + expectedEndpoint: "https://api.anthropic.com/v1/messages", }, { - name: "valid anthropic client", + name: "legacy provider maps to explicit protocol and endpoint", config: types.Input{ - Protocol: "anthropic", - BaseUrl: "https://api.anthropic.com", + Protocol: "openai", + BaseUrl: "https://api.openai.com", ApiKey: "test-key", - Model: "claude-3-sonnet-20240229", + Model: "gpt-4.1-mini", Timeout: 30 * time.Second, }, - wantError: false, + wantError: false, + expectedProtocol: types.ProtocolOpenAICompletions, + expectedEndpoint: "https://api.openai.com/v1/chat/completions", }, { name: "invalid provider", config: types.Input{ Protocol: "invalid", - BaseUrl: "https://api.test.com", + EndpointURL: "https://api.test.com/v1/anything", ApiKey: "test-key", Model: "test-model", Timeout: 30 * time.Second, @@ -69,50 +101,64 @@ func TestNewClient(t *testing.T) { return } - if client.GetProtocol() != tt.config.Protocol { - t.Errorf("NewClient().GetProtocol() = %v, want %v", client.GetProtocol(), tt.config.Protocol) + if client.GetProtocol() != tt.expectedProtocol { + t.Errorf("NewClient().GetProtocol() = %v, want %v", client.GetProtocol(), tt.expectedProtocol) } if client.GetModel() != tt.config.Model { t.Errorf("NewClient().GetModel() = %v, want %v", client.GetModel(), tt.config.Model) } + + switch typed := client.(type) { + case *OpenAIClient: + if typed.endpointURL != tt.expectedEndpoint { + t.Errorf("NewClient() endpointURL = %v, want %v", typed.endpointURL, tt.expectedEndpoint) + } + case *AnthropicClient: + if typed.EndpointURL != tt.expectedEndpoint { + t.Errorf("NewClient() endpointURL = %v, want %v", typed.EndpointURL, tt.expectedEndpoint) + } + } }) } } func TestNewClientWithTimeout(t *testing.T) { tests := []struct { - name string - config types.Input - wantError bool + name string + config types.Input + wantError bool + expectedProtocol string }{ { - name: "valid openai client with timeout", + name: "valid openai completions client with timeout", config: types.Input{ - Protocol: "openai", - BaseUrl: "https://api.openai.com", - ApiKey: "test-key", - Model: "gpt-3.5-turbo", - Timeout: 10 * time.Second, + Protocol: types.ProtocolOpenAICompletions, + EndpointURL: "https://api.openai.com/v1/chat/completions", + ApiKey: "test-key", + Model: "gpt-4.1-mini", + Timeout: 10 * time.Second, }, - wantError: false, + wantError: false, + expectedProtocol: types.ProtocolOpenAICompletions, }, { name: "valid anthropic client with timeout", config: types.Input{ - Protocol: "anthropic", - BaseUrl: "https://api.anthropic.com", - ApiKey: "test-key", - Model: "claude-3-sonnet", - Timeout: 30 * time.Second, + Protocol: types.ProtocolAnthropicMessages, + EndpointURL: "https://api.anthropic.com/v1/messages", + ApiKey: "test-key", + Model: "claude-3-sonnet", + Timeout: 30 * time.Second, }, - wantError: false, + wantError: false, + expectedProtocol: types.ProtocolAnthropicMessages, }, { name: "invalid provider with timeout", config: types.Input{ Protocol: "invalid", - BaseUrl: "https://api.test.com", + EndpointURL: "https://api.test.com/v1/anything", ApiKey: "test-key", Model: "test-model", Timeout: 5 * time.Second, @@ -142,8 +188,8 @@ func TestNewClientWithTimeout(t *testing.T) { return } - if client.GetProtocol() != tt.config.Protocol { - t.Errorf("NewClient().GetProtocol() = %v, want %v", client.GetProtocol(), tt.config.Protocol) + if client.GetProtocol() != tt.expectedProtocol { + t.Errorf("NewClient().GetProtocol() = %v, want %v", client.GetProtocol(), tt.expectedProtocol) } if client.GetModel() != tt.config.Model { diff --git a/internal/client/openai.go b/internal/client/openai.go index 39e26cc..9f9bddb 100644 --- a/internal/client/openai.go +++ b/internal/client/openai.go @@ -34,12 +34,20 @@ type ThinkingOptions struct { Type string `json:"type"` } +type ResponsesReasoningOptions struct { + Effort string `json:"effort,omitempty"` +} + // CompletionTokensDetails represents detailed completion token usage breakdown type CompletionTokensDetails struct { ReasoningTokens int `json:"reasoning_tokens"` ThinkingTokens int `json:"thinking_tokens"` } +type PromptTokensDetails struct { + CachedTokens int `json:"cached_tokens"` +} + // ChatCompletionRequest represents the request payload for chat completion type ChatCompletionRequest struct { Model string `json:"model"` @@ -49,6 +57,13 @@ type ChatCompletionRequest struct { Thinking *ThinkingOptions `json:"thinking,omitempty"` } +type ResponsesAPIRequest struct { + Model string `json:"model"` + Input string `json:"input"` + Stream bool `json:"stream,omitempty"` + Reasoning *ResponsesReasoningOptions `json:"reasoning,omitempty"` +} + // ChatCompletionResponse represents the response from chat completion type ChatCompletionResponse struct { ID string `json:"id"` @@ -64,13 +79,36 @@ type ChatCompletionResponse struct { FinishReason string `json:"finish_reason"` } `json:"choices"` Usage struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + PromptTokensDetails *PromptTokensDetails `json:"prompt_tokens_details,omitempty"` CompletionTokensDetails *CompletionTokensDetails `json:"completion_tokens_details,omitempty"` } `json:"usage"` } +type ResponsesAPIResponse struct { + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + Model string `json:"model"` + Output []struct { + Type string `json:"type"` + Role string `json:"role"` + Content []struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` + } `json:"content"` + } `json:"output"` + Usage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + TotalTokens int `json:"total_tokens"` + InputTokensDetails *PromptTokensDetails `json:"input_tokens_details,omitempty"` + OutputTokensDetails *CompletionTokensDetails `json:"output_tokens_details,omitempty"` + } `json:"usage"` +} + // OpenAIErrorResponse represents OpenAI API error response type OpenAIErrorResponse struct { Error struct { @@ -96,13 +134,27 @@ type StreamResponseChunk struct { FinishReason *string `json:"finish_reason"` } `json:"choices"` Usage *struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + PromptTokensDetails *PromptTokensDetails `json:"prompt_tokens_details,omitempty"` CompletionTokensDetails *CompletionTokensDetails `json:"completion_tokens_details,omitempty"` } `json:"usage,omitempty"` } +type ResponsesAPIStreamEvent struct { + Type string `json:"type"` + Delta string `json:"delta,omitempty"` + Response *ResponsesAPIResponse `json:"response,omitempty"` + Usage *struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + TotalTokens int `json:"total_tokens"` + InputTokensDetails *PromptTokensDetails `json:"input_tokens_details,omitempty"` + OutputTokensDetails *CompletionTokensDetails `json:"output_tokens_details,omitempty"` + } `json:"usage,omitempty"` +} + func extractThinkingTokens(details *CompletionTokensDetails) int { if details == nil { return 0 @@ -113,15 +165,172 @@ func extractThinkingTokens(details *CompletionTokensDetails) int { return details.ReasoningTokens } +func extractCachedInputTokens(details *PromptTokensDetails) int { + if details == nil { + return 0 + } + return details.CachedTokens +} + +func (c *OpenAIClient) buildRequestBody(prompt string, stream bool) ([]byte, error) { + if c.Provider == types.ProtocolOpenAIResponses { + reqBody := ResponsesAPIRequest{ + Model: c.Model, + Input: prompt, + Stream: stream, + } + if c.Thinking { + reqBody.Reasoning = &ResponsesReasoningOptions{Effort: "medium"} + } + return json.Marshal(reqBody) + } + + reqBody := ChatCompletionRequest{ + Model: c.Model, + Messages: []ChatCompletionMessage{ + { + Role: "user", + Content: prompt, + }, + }, + Stream: stream, + } + + if stream { + reqBody.StreamOptions = &StreamOptions{ + IncludeUsage: true, + } + } + + if c.Thinking { + reqBody.Thinking = &ThinkingOptions{ + Type: "enabled", + } + } + + return json.Marshal(reqBody) +} + +func (c *OpenAIClient) parseResponsesStream(resp *http.Response, t0 time.Time, dnsTime, connectTime, tlsTime time.Duration, targetIP string) (*ResponseMetrics, error) { + scanner := bufio.NewScanner(resp.Body) + firstTokenTime := time.Duration(0) + gotFirst := false + var completionTokens int + var promptTokens int + var cachedInputTokens int + var thinkingTokens int + var streamChunks []string + + for scanner.Scan() { + line := scanner.Text() + if !strings.HasPrefix(line, "data: ") { + continue + } + data := strings.TrimPrefix(line, "data: ") + if data == "[DONE]" { + break + } + if c.logger != nil && c.logger.IsEnabled() { + streamChunks = append(streamChunks, data) + } + + var event ResponsesAPIStreamEvent + if err := json.Unmarshal([]byte(data), &event); err != nil { + continue + } + + if !gotFirst && event.Delta != "" { + firstTokenTime = time.Since(t0) + gotFirst = true + } + + if event.Usage != nil { + promptTokens = event.Usage.InputTokens + completionTokens = event.Usage.OutputTokens + cachedInputTokens = extractCachedInputTokens(event.Usage.InputTokensDetails) + thinkingTokens = extractThinkingTokens(event.Usage.OutputTokensDetails) + } + + if event.Response != nil { + promptTokens = event.Response.Usage.InputTokens + completionTokens = event.Response.Usage.OutputTokens + cachedInputTokens = extractCachedInputTokens(event.Response.Usage.InputTokensDetails) + thinkingTokens = extractThinkingTokens(event.Response.Usage.OutputTokensDetails) + } + } + + if err := scanner.Err(); err != nil { + if c.logger != nil && c.logger.IsEnabled() { + c.logger.Error(c.Model, "Responses stream scanning failed", err) + } + return nil, err + } + + totalTime := time.Since(t0) + if c.logger != nil && c.logger.IsEnabled() { + c.logger.LogResponse(c.Model, logger.ResponseData{ + StatusCode: resp.StatusCode, + StreamChunks: streamChunks, + }) + } + + return &ResponseMetrics{ + TimeToFirstToken: firstTokenTime, + TotalTime: totalTime, + DNSTime: dnsTime, + ConnectTime: connectTime, + TLSHandshakeTime: tlsTime, + TargetIP: targetIP, + PromptTokens: promptTokens, + CachedInputTokens: cachedInputTokens, + CompletionTokens: completionTokens, + ThinkingTokens: thinkingTokens, + ErrorMessage: "", + }, nil +} + +func (c *OpenAIClient) parseResponsesNonStream(responseData []byte, totalTime, dnsTime, connectTime, tlsTime time.Duration, targetIP string) (*ResponseMetrics, error) { + var apiResp ResponsesAPIResponse + if err := json.Unmarshal(responseData, &apiResp); err != nil { + if c.logger != nil && c.logger.IsEnabled() { + c.logger.Error(c.Model, "Failed to parse responses API JSON", err) + } + return &ResponseMetrics{ + TimeToFirstToken: 0, + TotalTime: totalTime, + DNSTime: dnsTime, + ConnectTime: connectTime, + TLSHandshakeTime: tlsTime, + TargetIP: targetIP, + CompletionTokens: 0, + ErrorMessage: fmt.Sprintf("JSON parsing error: %s", err.Error()), + }, err + } + + return &ResponseMetrics{ + TimeToFirstToken: totalTime, + TotalTime: totalTime, + DNSTime: dnsTime, + ConnectTime: connectTime, + TLSHandshakeTime: tlsTime, + TargetIP: targetIP, + PromptTokens: apiResp.Usage.InputTokens, + CachedInputTokens: extractCachedInputTokens(apiResp.Usage.InputTokensDetails), + CompletionTokens: apiResp.Usage.OutputTokens, + ThinkingTokens: extractThinkingTokens(apiResp.Usage.OutputTokensDetails), + ErrorMessage: "", + }, nil +} + // OpenAIClient OpenAI 协议客户端 type OpenAIClient struct { - httpClient *http.Client - baseURL string - apiKey string - Model string - Provider string - Thinking bool // 是否开启 thinking 模式 - logger *logger.Logger + httpClient *http.Client + endpointURL string + apiKey string + Model string + Provider string + Thinking bool // 是否开启 thinking 模式 + logger *logger.Logger } // NewOpenAIClient 根据配置创建 OpenAI 客户端 @@ -133,10 +342,7 @@ type OpenAIClient struct { // 网络栈性能,包括 DNS 解析、TCP 连接建立、TLS 握手等。 // - DisableCompression=false: 启用压缩以节省带宽 func NewOpenAIClient(config types.Input) *OpenAIClient { - baseUrl := config.BaseUrl - if baseUrl == "" { - baseUrl = "https://api.openai.com" - } + endpointURL := config.ResolvedEndpointURL() // 禁用连接复用以确保每个请求都是独立的 transport := &http.Transport{ @@ -149,12 +355,12 @@ func NewOpenAIClient(config types.Input) *OpenAIClient { Transport: transport, Timeout: config.Timeout, }, - baseURL: baseUrl, - apiKey: config.ApiKey, - Model: config.Model, - Provider: "openai", - Thinking: config.Thinking, - logger: nil, + endpointURL: endpointURL, + apiKey: config.ApiKey, + Model: config.Model, + Provider: config.NormalizedProtocol(), + Thinking: config.Thinking, + logger: nil, } } @@ -168,38 +374,13 @@ func (c *OpenAIClient) Request(prompt string, stream bool) (*ResponseMetrics, er // 记录请求开始日志 if c.logger != nil && c.logger.IsEnabled() { c.logger.LogTestStart(c.Model, prompt, map[string]interface{}{ - "stream": stream, - "protocol": c.Provider, - "base_url": c.baseURL, + "stream": stream, + "protocol": c.Provider, + "endpoint_url": c.endpointURL, }) } - reqBody := ChatCompletionRequest{ - Model: c.Model, - Messages: []ChatCompletionMessage{ - { - Role: "user", - Content: prompt, - }, - }, - Stream: stream, - } - - // 当启用流模式时,添加 stream_options 参数 - if stream { - reqBody.StreamOptions = &StreamOptions{ - IncludeUsage: true, - } - } - - // 当启用 thinking 模式时,添加 reasoning 参数 - if c.Thinking { - reqBody.Thinking = &ThinkingOptions{ - Type: "enabled", - } - } - - jsonData, err := json.Marshal(reqBody) + jsonData, err := c.buildRequestBody(prompt, stream) if err != nil { // 记录错误日志 if c.logger != nil && c.logger.IsEnabled() { @@ -208,8 +389,7 @@ func (c *OpenAIClient) Request(prompt string, stream bool) (*ResponseMetrics, er return nil, err } - url := fmt.Sprintf("%s/chat/completions", c.baseURL) - req, err := http.NewRequestWithContext(context.Background(), "POST", url, bytes.NewBuffer(jsonData)) + req, err := http.NewRequestWithContext(context.Background(), "POST", c.endpointURL, bytes.NewBuffer(jsonData)) if err != nil { // 记录错误日志 if c.logger != nil && c.logger.IsEnabled() { @@ -350,12 +530,17 @@ func (c *OpenAIClient) Request(prompt string, stream bool) (*ResponseMetrics, er }, fmt.Errorf(errorMessage) } + if c.Provider == types.ProtocolOpenAIResponses { + return c.parseResponsesStream(resp, t0, dnsTime, connectTime, tlsTime, targetIP) + } + scanner := bufio.NewScanner(resp.Body) firstTokenTime := time.Duration(0) gotFirst := false var fullContent strings.Builder var completionTokens int var promptTokens int + var cachedInputTokens int var thinkingTokens int var streamChunks []string // 用于记录所有流式数据块 @@ -408,6 +593,7 @@ func (c *OpenAIClient) Request(prompt string, stream bool) (*ResponseMetrics, er if chunk.Usage != nil { promptTokens = chunk.Usage.PromptTokens completionTokens = chunk.Usage.CompletionTokens + cachedInputTokens = extractCachedInputTokens(chunk.Usage.PromptTokensDetails) thinkingTokens = extractThinkingTokens(chunk.Usage.CompletionTokensDetails) } } @@ -434,6 +620,7 @@ func (c *OpenAIClient) Request(prompt string, stream bool) (*ResponseMetrics, er "total_time": totalTime.String(), "time_to_first_token": firstTokenTime.String(), "prompt_tokens": promptTokens, + "cached_input_tokens": cachedInputTokens, "completion_tokens": completionTokens, "thinking_tokens": thinkingTokens, "full_content": fullContent.String(), @@ -441,16 +628,17 @@ func (c *OpenAIClient) Request(prompt string, stream bool) (*ResponseMetrics, er } return &ResponseMetrics{ - TimeToFirstToken: firstTokenTime, - TotalTime: totalTime, - DNSTime: dnsTime, - ConnectTime: connectTime, - TLSHandshakeTime: tlsTime, - TargetIP: targetIP, - PromptTokens: promptTokens, - CompletionTokens: completionTokens, - ThinkingTokens: thinkingTokens, - ErrorMessage: "", + TimeToFirstToken: firstTokenTime, + TotalTime: totalTime, + DNSTime: dnsTime, + ConnectTime: connectTime, + TLSHandshakeTime: tlsTime, + TargetIP: targetIP, + PromptTokens: promptTokens, + CachedInputTokens: cachedInputTokens, + CompletionTokens: completionTokens, + ThinkingTokens: thinkingTokens, + ErrorMessage: "", }, nil } else { // 非流式请求 @@ -532,6 +720,10 @@ func (c *OpenAIClient) Request(prompt string, stream bool) (*ResponseMetrics, er }, fmt.Errorf("empty response body") } + if c.Provider == types.ProtocolOpenAIResponses { + return c.parseResponsesNonStream(responseData, totalTime, dnsTime, connectTime, tlsTime, targetIP) + } + var chatResp ChatCompletionResponse if err := json.Unmarshal(responseData, &chatResp); err != nil { // 记录JSON解析错误日志 @@ -553,17 +745,18 @@ func (c *OpenAIClient) Request(prompt string, stream bool) (*ResponseMetrics, er thinkingTokens := extractThinkingTokens(chatResp.Usage.CompletionTokensDetails) return &ResponseMetrics{ - TimeToFirstToken: totalTime, // 非流式模式下,所有token一次性返回,TTFT等于总时间 - TotalTime: totalTime, - DNSTime: dnsTime, - ConnectTime: connectTime, - TLSHandshakeTime: tlsTime, - TargetIP: targetIP, - PromptTokens: chatResp.Usage.PromptTokens, - CompletionTokens: chatResp.Usage.CompletionTokens, - ThinkingTokens: thinkingTokens, - ErrorMessage: "", - }, nil + TimeToFirstToken: totalTime, // 非流式模式下,所有token一次性返回,TTFT等于总时间 + TotalTime: totalTime, + DNSTime: dnsTime, + ConnectTime: connectTime, + TLSHandshakeTime: tlsTime, + TargetIP: targetIP, + PromptTokens: chatResp.Usage.PromptTokens, + CachedInputTokens: extractCachedInputTokens(chatResp.Usage.PromptTokensDetails), + CompletionTokens: chatResp.Usage.CompletionTokens, + ThinkingTokens: thinkingTokens, + ErrorMessage: "", + }, nil } } diff --git a/internal/client/openai_test.go b/internal/client/openai_test.go index 18e46b7..23505e1 100644 --- a/internal/client/openai_test.go +++ b/internal/client/openai_test.go @@ -2,6 +2,7 @@ package client import ( "fmt" + "io" "net/http" "net/http/httptest" "reflect" @@ -16,7 +17,7 @@ import ( // createOpenAITestConfig 创建用于测试的标准 OpenAI 配置 func createOpenAITestConfig(baseUrl, apiKey, model string, timeout time.Duration, thinking bool) types.Input { return types.Input{ - Protocol: "openai", + Protocol: types.ProtocolOpenAICompletions, BaseUrl: baseUrl, ApiKey: apiKey, Model: model, @@ -25,10 +26,21 @@ func createOpenAITestConfig(baseUrl, apiKey, model string, timeout time.Duration } } +func createOpenAIResponsesTestConfig(endpointURL, apiKey, model string, timeout time.Duration, thinking bool) types.Input { + return types.Input{ + Protocol: types.ProtocolOpenAIResponses, + EndpointURL: endpointURL, + ApiKey: apiKey, + Model: model, + Timeout: timeout, + Thinking: thinking, + } +} + // createOpenAITestConfigWithDefaultTimeout 创建带默认超时的测试配置 func createOpenAITestConfigWithDefaultTimeout(baseUrl, apiKey, model string) types.Input { return types.Input{ - Protocol: "openai", + Protocol: types.ProtocolOpenAICompletions, BaseUrl: baseUrl, ApiKey: apiKey, Model: model, @@ -51,10 +63,10 @@ func TestNewOpenAIClient(t *testing.T) { apiKey: "test-key", model: "gpt-3.5-turbo", want: &OpenAIClient{ - baseURL: "https://custom.api.com", - apiKey: "test-key", - Model: "gpt-3.5-turbo", - Provider: "openai", + endpointURL: "https://custom.api.com/v1/chat/completions", + apiKey: "test-key", + Model: "gpt-3.5-turbo", + Provider: types.ProtocolOpenAICompletions, }, }, { @@ -63,10 +75,10 @@ func TestNewOpenAIClient(t *testing.T) { apiKey: "test-key", model: "gpt-4", want: &OpenAIClient{ - baseURL: "https://api.openai.com", - apiKey: "test-key", - Model: "gpt-4", - Provider: "openai", + endpointURL: "https://api.openai.com/v1/chat/completions", + apiKey: "test-key", + Model: "gpt-4", + Provider: types.ProtocolOpenAICompletions, }, }, } @@ -75,8 +87,8 @@ func TestNewOpenAIClient(t *testing.T) { t.Run(tt.name, func(t *testing.T) { got := NewOpenAIClient(createOpenAITestConfigWithDefaultTimeout(tt.baseUrl, tt.apiKey, tt.model)) - if got.baseURL != tt.want.baseURL { - t.Errorf("NewOpenAIClient().baseURL = %v, want %v", got.baseURL, tt.want.baseURL) + if got.endpointURL != tt.want.endpointURL { + t.Errorf("NewOpenAIClient().endpointURL = %v, want %v", got.endpointURL, tt.want.endpointURL) } if got.apiKey != tt.want.apiKey { @@ -172,17 +184,17 @@ func TestOpenAIClient_ConnectionReuse(t *testing.T) { // 创建一个测试服务器,记录连接数 connectionCount := 0 var connMu sync.Mutex - + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // 每个请求到达时记录 connMu.Lock() connectionCount++ currentCount := connectionCount connMu.Unlock() - + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - + // 返回简单的非流式响应 response := fmt.Sprintf(`{"id":"chatcmpl-%d","choices":[{"message":{"content":"Response %d"}}],"usage":{"completion_tokens":1}}`, currentCount, currentCount) w.Write([]byte(response)) @@ -196,7 +208,7 @@ func TestOpenAIClient_ConnectionReuse(t *testing.T) { if !ok { t.Fatal("Expected client to use http.Transport") } - + if !transport.DisableKeepAlives { t.Error("Expected DisableKeepAlives to be true to prevent connection reuse") } @@ -209,12 +221,12 @@ func TestOpenAIClient_ConnectionReuse(t *testing.T) { t.Errorf("Request %d failed: %v", i, err) continue } - + if metrics == nil { t.Errorf("Request %d returned nil metrics", i) continue } - + // 验证每个请求都有合理的时间指标 if metrics.TotalTime <= 0 { t.Errorf("Request %d has invalid TotalTime: %v", i, metrics.TotalTime) @@ -225,7 +237,7 @@ func TestOpenAIClient_ConnectionReuse(t *testing.T) { connMu.Lock() finalCount := connectionCount connMu.Unlock() - + if finalCount != requestCount { t.Errorf("Expected %d requests to reach server, got %d", requestCount, finalCount) } @@ -235,17 +247,17 @@ func TestOpenAIClient_ConnectionReuse(t *testing.T) { func TestOpenAIClient_NoConnectionReuse(t *testing.T) { // 验证客户端的 Transport 配置确实禁用了连接复用 client := NewOpenAIClient(createOpenAITestConfig("https://api.openai.com", "test-key", "test-model", 0, false)) - + transport, ok := client.httpClient.Transport.(*http.Transport) if !ok { t.Fatal("Expected client to use http.Transport") } - + // 关键验证:DisableKeepAlives 应该为 true if !transport.DisableKeepAlives { t.Error("DisableKeepAlives should be true to prevent connection reuse, which could affect timing measurements") } - + // DisableCompression 应该为 false(我们想要压缩以节省带宽) if transport.DisableCompression { t.Error("DisableCompression should be false to enable compression") @@ -254,12 +266,12 @@ func TestOpenAIClient_NoConnectionReuse(t *testing.T) { func TestOpenAIClient_ConnectionReuseImpactOnTiming(t *testing.T) { // 这个测试演示为什么禁用连接复用对于准确的性能测量很重要 - + // 创建一个有一定延迟的测试服务器 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // 模拟网络延迟 time.Sleep(50 * time.Millisecond) - + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) w.Write([]byte(`{"choices":[{"message":{"content":"test"}}],"usage":{"completion_tokens":1}}`)) @@ -270,14 +282,14 @@ func TestOpenAIClient_ConnectionReuseImpactOnTiming(t *testing.T) { clientWithoutReuse := &OpenAIClient{ httpClient: &http.Client{ Transport: &http.Transport{ - DisableKeepAlives: true, // 禁用连接复用 + DisableKeepAlives: true, // 禁用连接复用 }, Timeout: 30 * time.Second, }, - baseURL: server.URL, - apiKey: "test-key", - Model: "test-model", - Provider: "openai", + endpointURL: server.URL, + apiKey: "test-key", + Model: "test-model", + Provider: types.ProtocolOpenAICompletions, } clientWithReuse := &OpenAIClient{ @@ -287,10 +299,10 @@ func TestOpenAIClient_ConnectionReuseImpactOnTiming(t *testing.T) { }, Timeout: 30 * time.Second, }, - baseURL: server.URL, - apiKey: "test-key", - Model: "test-model", - Provider: "openai", + endpointURL: server.URL, + apiKey: "test-key", + Model: "test-model", + Provider: types.ProtocolOpenAICompletions, } // 测试两个客户端的性能差异 @@ -304,14 +316,14 @@ func TestOpenAIClient_ConnectionReuseImpactOnTiming(t *testing.T) { } totalTimes = append(totalTimes, metrics.TotalTime) } - + // 由于每次都要重新建立连接,时间应该相对稳定且包含连接开销 for i, duration := range totalTimes { if duration < 40*time.Millisecond { t.Errorf("Request %d duration %v is too short, expected at least 40ms (including connection overhead)", i, duration) } } - + t.Logf("Without reuse - timing results: %v", totalTimes) }) @@ -330,7 +342,7 @@ func TestOpenAIClient_ConnectionReuseImpactOnTiming(t *testing.T) { } t.Logf("With reuse - First request: %v, Second request: %v", metrics1.TotalTime, metrics2.TotalTime) - + // 这个测试主要是为了说明问题,不是为了断言特定的性能差异 // 因为在测试环境中,本地连接可能不会显示显著差异 }) @@ -339,8 +351,8 @@ func TestOpenAIClient_ConnectionReuseImpactOnTiming(t *testing.T) { func TestOpenAIClient_GetProtocol(t *testing.T) { client := NewOpenAIClient(createOpenAITestConfig("https://api.openai.com", "test-key", "gpt-3.5-turbo", 0, false)) - if got := client.GetProtocol(); got != "openai" { - t.Errorf("OpenAIClient.GetProtocol() = %v, want %v", got, "openai") + if got := client.GetProtocol(); got != types.ProtocolOpenAICompletions { + t.Errorf("OpenAIClient.GetProtocol() = %v, want %v", got, types.ProtocolOpenAICompletions) } } @@ -394,9 +406,9 @@ func TestOpenAIClient_Request_MalformedJSON(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - - if strings.Contains(r.Header.Get("Accept"), "text/event-stream") || - r.Header.Get("Stream") == "true" { + + if strings.Contains(r.Header.Get("Accept"), "text/event-stream") || + r.Header.Get("Stream") == "true" { // 流式响应:发送畸形的 JSON w.Write([]byte("data: {invalid json}\n\n")) w.Write([]byte("data: {\"choices\":[{\"delta\":{\"content\":\"valid content\"}}]}\n\n")) @@ -430,6 +442,72 @@ func TestOpenAIClient_Request_MalformedJSON(t *testing.T) { }) } +func TestOpenAIClient_Request_OpenAIResponses_NonStream(t *testing.T) { + var requestBody string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, err := io.ReadAll(r.Body) + if err != nil { + t.Fatalf("failed to read request body: %v", err) + } + requestBody = string(body) + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"id":"resp_123","object":"response","created_at":123,"model":"gpt-4.1-mini","output":[{"type":"message","role":"assistant","content":[{"type":"output_text","text":"hello"}]}],"usage":{"input_tokens":12,"input_tokens_details":{"cached_tokens":3},"output_tokens":7,"output_tokens_details":{"reasoning_tokens":2},"total_tokens":19}}`)) + })) + defer server.Close() + + client := NewOpenAIClient(createOpenAIResponsesTestConfig(server.URL, "test-key", "gpt-4.1-mini", 30*time.Second, false)) + metrics, err := client.Request("hello from responses", false) + if err != nil { + t.Fatalf("Request() unexpected error: %v", err) + } + if strings.Contains(requestBody, "messages") { + t.Fatalf("responses request should not use chat-completions payload: %s", requestBody) + } + if !strings.Contains(requestBody, `"input":"hello from responses"`) { + t.Fatalf("responses request body missing input field: %s", requestBody) + } + if metrics.PromptTokens != 12 || metrics.CachedInputTokens != 3 || metrics.CompletionTokens != 7 || metrics.ThinkingTokens != 2 { + t.Fatalf("unexpected metrics: %+v", metrics) + } +} + +func TestOpenAIClient_Request_OpenAIResponses_Stream(t *testing.T) { + var requestBody string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, err := io.ReadAll(r.Body) + if err != nil { + t.Fatalf("failed to read request body: %v", err) + } + requestBody = string(body) + + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, "data: {\"type\":\"response.created\"}\n\n") + fmt.Fprint(w, "data: {\"type\":\"response.output_text.delta\",\"delta\":\"Hello\"}\n\n") + fmt.Fprint(w, "data: {\"type\":\"response.output_text.delta\",\"delta\":\" world\"}\n\n") + fmt.Fprint(w, "data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":10,\"input_tokens_details\":{\"cached_tokens\":4},\"output_tokens\":6,\"output_tokens_details\":{\"reasoning_tokens\":1},\"total_tokens\":16}}}\n\n") + fmt.Fprint(w, "data: [DONE]\n\n") + })) + defer server.Close() + + client := NewOpenAIClient(createOpenAIResponsesTestConfig(server.URL, "test-key", "gpt-4.1-mini", 30*time.Second, false)) + metrics, err := client.Request("stream me", true) + if err != nil { + t.Fatalf("Request() unexpected error: %v", err) + } + if !strings.Contains(requestBody, `"stream":true`) { + t.Fatalf("responses stream request body missing stream flag: %s", requestBody) + } + if metrics.TimeToFirstToken <= 0 { + t.Fatalf("expected positive TTFT, got %v", metrics.TimeToFirstToken) + } + if metrics.PromptTokens != 10 || metrics.CachedInputTokens != 4 || metrics.CompletionTokens != 6 || metrics.ThinkingTokens != 1 { + t.Fatalf("unexpected stream metrics: %+v", metrics) + } +} + func TestOpenAIClient_Request_BodyReadError(t *testing.T) { // 创建一个在读取响应体时出错的服务器 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -461,7 +539,7 @@ func TestOpenAIClient_Request_ScannerError(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/event-stream") w.WriteHeader(http.StatusOK) - + // 发送一个非常长的行,可能导致 scanner 错误 longLine := strings.Repeat("x", 1024*1024) // 1MB 的数据 fmt.Fprintf(w, "data: %s\n\n", longLine) @@ -543,14 +621,14 @@ func TestOpenAIClient_ConcurrentRequests(t *testing.T) { time.Sleep(50 * time.Millisecond) // 模拟慢响应 w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - + response := `{"id":"test","choices":[{"message":{"content":"concurrent response"}}],"usage":{"completion_tokens":2}}` w.Write([]byte(response)) })) defer server.Close() client := NewOpenAIClient(createOpenAITestConfig(server.URL, "test-key", "test-model", 0, false)) - + // 并发执行多个请求 numRequests := 10 var wg sync.WaitGroup @@ -562,9 +640,9 @@ func TestOpenAIClient_ConcurrentRequests(t *testing.T) { wg.Add(1) go func(id int) { defer wg.Done() - + metrics, err := client.Request(fmt.Sprintf("concurrent test %d", id), false) - + mu.Lock() if err != nil { errors = append(errors, err) @@ -586,7 +664,7 @@ func TestOpenAIClient_ConcurrentRequests(t *testing.T) { t.Errorf("Concurrent request error: %v", err) } } - + if successCount != numRequests { t.Errorf("Expected %d successful requests, got %d", numRequests, successCount) } @@ -603,12 +681,12 @@ func TestOpenAIClient_Request_TimeoutHandling(t *testing.T) { // 创建一个超时时间很短的客户端 client := NewOpenAIClient(createOpenAITestConfig(server.URL, "test-key", "test-model", 100*time.Millisecond, false)) - + _, err := client.Request("timeout test", false) if err == nil { t.Error("Expected timeout error but got none") } - + // 确保错误信息包含超时相关内容 if !strings.Contains(err.Error(), "timeout") && !strings.Contains(err.Error(), "context deadline exceeded") { t.Errorf("Expected timeout-related error, got: %v", err) @@ -619,7 +697,7 @@ func TestOpenAIClient_Request_EmptyChoicesArray(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - + if strings.Contains(r.Header.Get("Accept"), "text/event-stream") { // 流式响应:发送空的 choices w.Write([]byte("data: {\"choices\":[]}\n\n")) @@ -655,10 +733,10 @@ func TestOpenAIClient_Request_EmptyChoicesArray(t *testing.T) { // TestOpenAIClient_Request_ThinkingContent 测试 ThinkingContent 字段对 TTFT 统计的影响 func TestOpenAIClient_Request_ThinkingContent(t *testing.T) { tests := []struct { - name string - streamResponses []string - expectedTTFTValid bool - description string + name string + streamResponses []string + expectedTTFTValid bool + description string }{ { name: "reasoning content first, then regular content", @@ -732,7 +810,7 @@ func TestOpenAIClient_Request_ThinkingContent(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/event-stream") w.WriteHeader(http.StatusOK) - + // 添加小延迟以确保 TTFT 有意义的值 for i, response := range tt.streamResponses { if i > 0 { @@ -751,7 +829,7 @@ func TestOpenAIClient_Request_ThinkingContent(t *testing.T) { defer server.Close() client := NewOpenAIClient(createOpenAITestConfig(server.URL, "test-key", "test-model", 0, false)) - + metrics, err := client.Request("test prompt", true) if err != nil { t.Errorf("Request failed: %v", err) @@ -768,7 +846,7 @@ func TestOpenAIClient_Request_ThinkingContent(t *testing.T) { t.Errorf("Expected valid TTFT, got %v. %s", metrics.TimeToFirstToken, tt.description) } if metrics.TimeToFirstToken > metrics.TotalTime { - t.Errorf("TTFT (%v) should not exceed total time (%v). %s", + t.Errorf("TTFT (%v) should not exceed total time (%v). %s", metrics.TimeToFirstToken, metrics.TotalTime, tt.description) } } @@ -783,38 +861,38 @@ func TestOpenAIClient_Request_TTFTAccuracy(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/event-stream") w.WriteHeader(http.StatusOK) - + // 第一个 chunk: 只有空的 delta fmt.Fprint(w, "data: {\"choices\":[{\"delta\":{}}]}\n\n") if flusher, ok := w.(http.Flusher); ok { flusher.Flush() } - + // 等待 50ms 后发送第一个有内容的 chunk time.Sleep(50 * time.Millisecond) fmt.Fprint(w, "data: {\"choices\":[{\"delta\":{\"reasoning_content\":\"Thinking...\"}}]}\n\n") if flusher, ok := w.(http.Flusher); ok { flusher.Flush() } - + // 等待 30ms 后发送常规内容 time.Sleep(30 * time.Millisecond) fmt.Fprint(w, "data: {\"choices\":[{\"delta\":{\"content\":\"Hello\"}}]}\n\n") if flusher, ok := w.(http.Flusher); ok { flusher.Flush() } - + // 结束 fmt.Fprint(w, "data: [DONE]\n\n") })) defer server.Close() client := NewOpenAIClient(createOpenAITestConfig(server.URL, "test-key", "test-model", 0, false)) - + start := time.Now() metrics, err := client.Request("test prompt", true) totalDuration := time.Since(start) - + if err != nil { t.Errorf("Request failed: %v", err) return @@ -837,11 +915,11 @@ func TestOpenAIClient_Request_TTFTAccuracy(t *testing.T) { // TTFT 应该小于总时间 if metrics.TimeToFirstToken >= metrics.TotalTime { - t.Errorf("TTFT (%v) should be less than total time (%v)", + t.Errorf("TTFT (%v) should be less than total time (%v)", metrics.TimeToFirstToken, metrics.TotalTime) } - t.Logf("Actual timing - TTFT: %v, Total: %v, External total: %v", + t.Logf("Actual timing - TTFT: %v, Total: %v, External total: %v", metrics.TimeToFirstToken, metrics.TotalTime, totalDuration) } @@ -944,7 +1022,7 @@ func TestOpenAIClient_Request_ErrorHandlingFixes(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/event-stream") w.WriteHeader(http.StatusOK) - + // 发送一些无效的 JSON 数据块,然后发送有效的 w.Write([]byte("data: {invalid json}\n\n")) w.Write([]byte("data: {\"choices\":[{\"delta\":{\"content\":\"valid\"}}]}\n\n")) diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 0000000..c989c08 --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,101 @@ +package config + +import ( + "encoding/json" + "errors" + "os" + "path/filepath" +) + +const ( + appDirName = ".ait" + configJSON = "config.json" + tasksJSON = "tasks.json" + historyDirName = "history" +) + +type Config struct { + SaveAPIKey bool `json:"save_api_key"` + LastSelectedTaskID string `json:"last_selected_task_id,omitempty"` + DefaultProtocol string `json:"default_protocol,omitempty"` +} + +func Load() (*Config, error) { + path, err := ConfigPath() + if err != nil { + return nil, err + } + + data, err := os.ReadFile(path) + if errors.Is(err, os.ErrNotExist) { + return &Config{}, nil + } + if err != nil { + return nil, err + } + + var cfg Config + if err := json.Unmarshal(data, &cfg); err != nil { + return nil, err + } + return &cfg, nil +} + +func (c *Config) Save() error { + path, err := ConfigPath() + if err != nil { + return err + } + if _, err := EnsureAppDir(); err != nil { + return err + } + + data, err := json.MarshalIndent(c, "", " ") + if err != nil { + return err + } + return os.WriteFile(path, data, 0o644) +} + +func AppDir() (string, error) { + homeDir, err := os.UserHomeDir() + if err != nil { + return "", err + } + return filepath.Join(homeDir, appDirName), nil +} + +func EnsureAppDir() (string, error) { + dir, err := AppDir() + if err != nil { + return "", err + } + if err := os.MkdirAll(dir, 0o755); err != nil { + return "", err + } + return dir, nil +} + +func ConfigPath() (string, error) { + dir, err := AppDir() + if err != nil { + return "", err + } + return filepath.Join(dir, configJSON), nil +} + +func TasksPath() (string, error) { + dir, err := AppDir() + if err != nil { + return "", err + } + return filepath.Join(dir, tasksJSON), nil +} + +func HistoryDir() (string, error) { + dir, err := AppDir() + if err != nil { + return "", err + } + return filepath.Join(dir, historyDirName), nil +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go new file mode 100644 index 0000000..622c961 --- /dev/null +++ b/internal/config/config_test.go @@ -0,0 +1,51 @@ +package config + +import ( + "path/filepath" + "testing" +) + +func TestLoadReturnsDefaultWhenFileMissing(t *testing.T) { + t.Setenv("HOME", t.TempDir()) + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() returned unexpected error: %v", err) + } + if cfg == nil { + t.Fatal("Load() returned nil config") + } + if cfg.SaveAPIKey { + t.Fatal("expected SaveAPIKey to default to false") + } +} + +func TestConfigSaveAndLoadRoundTrip(t *testing.T) { + homeDir := t.TempDir() + t.Setenv("HOME", homeDir) + + cfg := &Config{ + SaveAPIKey: true, + LastSelectedTaskID: "task-1", + DefaultProtocol: "openai-responses", + } + if err := cfg.Save(); err != nil { + t.Fatalf("Save() returned unexpected error: %v", err) + } + + path, err := ConfigPath() + if err != nil { + t.Fatalf("ConfigPath() returned unexpected error: %v", err) + } + if want := filepath.Join(homeDir, ".ait", "config.json"); path != want { + t.Fatalf("expected config path %s, got %s", want, path) + } + + loaded, err := Load() + if err != nil { + t.Fatalf("Load() returned unexpected error: %v", err) + } + if !loaded.SaveAPIKey || loaded.LastSelectedTaskID != "task-1" || loaded.DefaultProtocol != "openai-responses" { + t.Fatalf("unexpected loaded config: %+v", loaded) + } +} diff --git a/internal/runner/runner.go b/internal/runner/runner.go index 58d819e..47f20c9 100644 --- a/internal/runner/runner.go +++ b/internal/runner/runner.go @@ -18,8 +18,12 @@ type Runner struct { input types.Input upload *upload.Uploader client client.ModelClient + stopCh chan struct{} + stopOnce sync.Once } +type RequestDoneCallback func(metrics *client.ResponseMetrics, index int, err error) + // NewRunner 创建新的性能测试执行器 func NewRunner(taskID string, config types.Input) (*Runner, error) { // 创建日志记录器(如果启用) @@ -38,22 +42,46 @@ func NewRunner(taskID string, config types.Input) (*Runner, error) { client: client, input: config, upload: upload.New(), + stopCh: make(chan struct{}), }, nil } +func (r *Runner) Stop() { + r.stopOnce.Do(func() { + close(r.stopCh) + }) +} + +func (r *Runner) acquireSlot(ch chan int) bool { + select { + case <-r.stopCh: + return false + case ch <- 1: + return true + } +} + +func calculateCacheHitRate(metrics *client.ResponseMetrics) float64 { + if metrics == nil || metrics.PromptTokens <= 0 { + return 0 + } + return float64(metrics.CachedInputTokens) / float64(metrics.PromptTokens) +} + // Run 执行性能测试,返回结果数据 func (r *Runner) Run() (*types.ReportData, error) { var wg sync.WaitGroup results := make([]*client.ResponseMetrics, r.input.Count) start := time.Now() ch := make(chan int, r.input.Concurrency) - - completed := int64(0) - failed := int64(0) + launchedCount := 0 for i := 0; i < r.input.Count; i++ { + if !r.acquireSlot(ch) { + break + } + launchedCount++ wg.Add(1) - ch <- 1 go func(idx int) { defer wg.Done() defer func() { <-ch }() @@ -63,7 +91,6 @@ func (r *Runner) Run() (*types.ReportData, error) { metrics, err := r.client.Request(currentPrompt, r.input.Stream) if err != nil { - atomic.AddInt64(&failed, 1) // 即使有错误,也尝试保存 metrics(如果有的话) if metrics != nil { results[idx] = metrics @@ -76,15 +103,51 @@ func (r *Runner) Run() (*types.ReportData, error) { if metrics.ErrorMessage == "" && r.upload != nil { r.upload.UploadReport(r.taskID, metrics, r.input) } - - atomic.AddInt64(&completed, 1) }(i) } wg.Wait() elapsed := time.Since(start) // 计算并返回结果 - return r.calculateResult(results, elapsed), nil + return r.calculateResult(results, elapsed, launchedCount), nil +} + +func (r *Runner) RunWithCallback(cb RequestDoneCallback) (*types.ReportData, error) { + var wg sync.WaitGroup + results := make([]*client.ResponseMetrics, r.input.Count) + start := time.Now() + ch := make(chan int, r.input.Concurrency) + launchedCount := 0 + + for i := 0; i < r.input.Count; i++ { + if !r.acquireSlot(ch) { + break + } + launchedCount++ + wg.Add(1) + go func(idx int) { + defer wg.Done() + defer func() { <-ch }() + + currentPrompt := r.input.PromptSource.GetRandomContent() + metrics, err := r.client.Request(currentPrompt, r.input.Stream) + if metrics != nil { + results[idx] = metrics + } + + if err == nil && metrics != nil && metrics.ErrorMessage == "" && r.upload != nil { + r.upload.UploadReport(r.taskID, metrics, r.input) + } + + if cb != nil { + cb(metrics, idx, err) + } + }(i) + } + + wg.Wait() + elapsed := time.Since(start) + return r.calculateResult(results, elapsed, launchedCount), nil } // RunWithProgress 运行性能测试并实时显示进度 @@ -103,9 +166,12 @@ func (r *Runner) RunWithProgress(progressCallback func(types.StatsData)) (*types var tlsHandshakeTimes []time.Duration var outputTokenCounts []int var inputTokenCounts []int + var cachedInputTokenCounts []int var thinkingTokenCounts []int + var cacheHitRates []float64 var errorMessages []string var ttftsMutex sync.Mutex + launchedCount := 0 // 启动进度更新 goroutine stopProgress := make(chan bool) @@ -126,8 +192,10 @@ func (r *Runner) RunWithProgress(progressCallback func(types.StatsData)) (*types ConnectTimes: make([]time.Duration, len(connectTimes)), TLSHandshakeTimes: make([]time.Duration, len(tlsHandshakeTimes)), InputTokenCounts: make([]int, len(inputTokenCounts)), + CachedInputTokenCounts: make([]int, len(cachedInputTokenCounts)), OutputTokenCounts: make([]int, len(outputTokenCounts)), ThinkingTokenCounts: make([]int, len(thinkingTokenCounts)), + CacheHitRates: make([]float64, len(cacheHitRates)), ErrorMessages: make([]string, len(errorMessages)), StartTime: start, ElapsedTime: time.Since(start), @@ -138,8 +206,10 @@ func (r *Runner) RunWithProgress(progressCallback func(types.StatsData)) (*types copy(stats.ConnectTimes, connectTimes) copy(stats.TLSHandshakeTimes, tlsHandshakeTimes) copy(stats.InputTokenCounts, inputTokenCounts) + copy(stats.CachedInputTokenCounts, cachedInputTokenCounts) copy(stats.OutputTokenCounts, outputTokenCounts) copy(stats.ThinkingTokenCounts, thinkingTokenCounts) + copy(stats.CacheHitRates, cacheHitRates) copy(stats.ErrorMessages, errorMessages) ttftsMutex.Unlock() @@ -151,8 +221,11 @@ func (r *Runner) RunWithProgress(progressCallback func(types.StatsData)) (*types }() for i := 0; i < r.input.Count; i++ { + if !r.acquireSlot(ch) { + break + } + launchedCount++ wg.Add(1) - ch <- 1 go func(idx int) { defer wg.Done() defer func() { <-ch }() @@ -178,7 +251,9 @@ func (r *Runner) RunWithProgress(progressCallback func(types.StatsData)) (*types tlsHandshakeTimes = append(tlsHandshakeTimes, metrics.TLSHandshakeTime) outputTokenCounts = append(outputTokenCounts, metrics.CompletionTokens) inputTokenCounts = append(inputTokenCounts, metrics.PromptTokens) + cachedInputTokenCounts = append(cachedInputTokenCounts, metrics.CachedInputTokens) thinkingTokenCounts = append(thinkingTokenCounts, metrics.ThinkingTokens) + cacheHitRates = append(cacheHitRates, calculateCacheHitRate(metrics)) ttftsMutex.Unlock() } return @@ -194,7 +269,9 @@ func (r *Runner) RunWithProgress(progressCallback func(types.StatsData)) (*types tlsHandshakeTimes = append(tlsHandshakeTimes, metrics.TLSHandshakeTime) outputTokenCounts = append(outputTokenCounts, metrics.CompletionTokens) inputTokenCounts = append(inputTokenCounts, metrics.PromptTokens) + cachedInputTokenCounts = append(cachedInputTokenCounts, metrics.CachedInputTokens) thinkingTokenCounts = append(thinkingTokenCounts, metrics.ThinkingTokens) + cacheHitRates = append(cacheHitRates, calculateCacheHitRate(metrics)) ttftsMutex.Unlock() if metrics.ErrorMessage == "" && r.upload != nil { @@ -219,8 +296,10 @@ func (r *Runner) RunWithProgress(progressCallback func(types.StatsData)) (*types ConnectTimes: make([]time.Duration, len(connectTimes)), TLSHandshakeTimes: make([]time.Duration, len(tlsHandshakeTimes)), InputTokenCounts: make([]int, len(inputTokenCounts)), + CachedInputTokenCounts: make([]int, len(cachedInputTokenCounts)), OutputTokenCounts: make([]int, len(outputTokenCounts)), ThinkingTokenCounts: make([]int, len(thinkingTokenCounts)), + CacheHitRates: make([]float64, len(cacheHitRates)), ErrorMessages: make([]string, len(errorMessages)), StartTime: start, ElapsedTime: elapsed, @@ -231,58 +310,55 @@ func (r *Runner) RunWithProgress(progressCallback func(types.StatsData)) (*types copy(finalStats.ConnectTimes, connectTimes) copy(finalStats.TLSHandshakeTimes, tlsHandshakeTimes) copy(finalStats.InputTokenCounts, inputTokenCounts) + copy(finalStats.CachedInputTokenCounts, cachedInputTokenCounts) copy(finalStats.OutputTokenCounts, outputTokenCounts) copy(finalStats.ThinkingTokenCounts, thinkingTokenCounts) + copy(finalStats.CacheHitRates, cacheHitRates) copy(finalStats.ErrorMessages, errorMessages) ttftsMutex.Unlock() progressCallback(finalStats) // 计算并返回结果 - return r.calculateResult(results, elapsed), nil + return r.calculateResult(results, elapsed, launchedCount), nil } // calculateResult 计算性能统计结果 -func (r *Runner) calculateResult(results []*client.ResponseMetrics, totalTime time.Duration) *types.ReportData { - if len(results) == 0 { +func (r *Runner) calculateResult(results []*client.ResponseMetrics, totalTime time.Duration, totalRequests ...int) *types.ReportData { + requestCount := r.input.Count + if len(totalRequests) > 0 { + requestCount = totalRequests[0] + } + if requestCount <= 0 || len(results) == 0 { return &types.ReportData{} } - // 分别收集所有结果和成功结果 allResults := make([]*client.ResponseMetrics, 0) successResults := make([]*client.ResponseMetrics, 0) - for _, result := range results { - if result != nil { - allResults = append(allResults, result) - // 只有没有错误且有token输出的才算成功 - if result.ErrorMessage == "" && result.CompletionTokens > 0 { - successResults = append(successResults, result) - } + if result == nil { + continue + } + allResults = append(allResults, result) + if result.ErrorMessage == "" && result.CompletionTokens > 0 { + successResults = append(successResults, result) } } - - // 如果完全没有数据,返回空结果 if len(allResults) == 0 { return &types.ReportData{} } - // 使用成功结果计算业务指标,使用所有结果计算网络指标 validResults := successResults if len(validResults) == 0 { - // 如果没有成功的结果,至少尝试使用有部分数据的结果 for _, result := range allResults { if result.TotalTime > 0 { validResults = append(validResults, result) } } - - // 如果仍然没有可用数据 if len(validResults) == 0 { return &types.ReportData{} } } - // 初始化最小值和最大值 firstResult := validResults[0] minTTFT := firstResult.TimeToFirstToken maxTTFT := firstResult.TimeToFirstToken @@ -292,8 +368,12 @@ func (r *Runner) calculateResult(results []*client.ResponseMetrics, totalTime ti maxOutputTokens := firstResult.CompletionTokens minInputTokens := firstResult.PromptTokens maxInputTokens := firstResult.PromptTokens + minCachedInputTokens := firstResult.CachedInputTokens + maxCachedInputTokens := firstResult.CachedInputTokens minThinkingTokens := firstResult.ThinkingTokens maxThinkingTokens := firstResult.ThinkingTokens + minCacheHitRate := calculateCacheHitRate(firstResult) + maxCacheHitRate := minCacheHitRate minDNSTime := firstResult.DNSTime maxDNSTime := firstResult.DNSTime @@ -302,7 +382,6 @@ func (r *Runner) calculateResult(results []*client.ResponseMetrics, totalTime ti minTLSTime := firstResult.TLSHandshakeTime maxTLSTime := firstResult.TLSHandshakeTime - // 计算第一个结果的 TPS 和 TPOT var firstTPS float64 if firstResult.TotalTime.Seconds() > 0 { firstTPS = float64(firstResult.CompletionTokens) / firstResult.TotalTime.Seconds() @@ -310,7 +389,6 @@ func (r *Runner) calculateResult(results []*client.ResponseMetrics, totalTime ti minTPS := firstTPS maxTPS := firstTPS - // 计算第一个结果的总吞吐量 TPS (输入 + 输出 tokens / 时间) var firstTotalThroughputTPS float64 if firstResult.TotalTime.Seconds() > 0 { totalTokens := firstResult.PromptTokens + firstResult.CompletionTokens @@ -319,17 +397,14 @@ func (r *Runner) calculateResult(results []*client.ResponseMetrics, totalTime ti minTotalThroughputTPS := firstTotalThroughputTPS maxTotalThroughputTPS := firstTotalThroughputTPS - // 计算第一个结果的 TPOT (Time Per Output Token) var firstTPOT time.Duration if firstResult.CompletionTokens > 1 { - // TPOT = (总耗时 - 首token耗时) / (总token数 - 1) remainingTime := firstResult.TotalTime - firstResult.TimeToFirstToken firstTPOT = remainingTime / time.Duration(firstResult.CompletionTokens-1) } minTPOT := firstTPOT maxTPOT := firstTPOT - // 获取目标IP(使用第一个有效结果的IP) var targetIP string for _, result := range validResults { if result.TargetIP != "" { @@ -338,16 +413,14 @@ func (r *Runner) calculateResult(results []*client.ResponseMetrics, totalTime ti } } - // 累积统计 var sumTTFT, sumTotalTime time.Duration var sumDNSTime, sumConnectTime, sumTLSTime time.Duration - var sumOutputTokens, sumInputTokens int + var sumOutputTokens, sumInputTokens, sumCachedInputTokens int var sumThinkingTokens int var sumTPOT time.Duration - var sumTotalThroughputTPS float64 + var sumCacheHitRate, sumTotalThroughputTPS float64 for _, result := range validResults { - // TTFT 统计 sumTTFT += result.TimeToFirstToken if result.TimeToFirstToken < minTTFT { minTTFT = result.TimeToFirstToken @@ -356,7 +429,6 @@ func (r *Runner) calculateResult(results []*client.ResponseMetrics, totalTime ti maxTTFT = result.TimeToFirstToken } - // 总时间统计 sumTotalTime += result.TotalTime if result.TotalTime < minTotalTime { minTotalTime = result.TotalTime @@ -365,14 +437,11 @@ func (r *Runner) calculateResult(results []*client.ResponseMetrics, totalTime ti maxTotalTime = result.TotalTime } - // TPOT 统计 var tpot time.Duration if result.CompletionTokens > 1 { - // TPOT = (总耗时 - 首token耗时) / (总token数 - 1) remainingTime := result.TotalTime - result.TimeToFirstToken tpot = remainingTime / time.Duration(result.CompletionTokens-1) sumTPOT += tpot - if tpot < minTPOT || minTPOT == 0 { minTPOT = tpot } @@ -381,7 +450,6 @@ func (r *Runner) calculateResult(results []*client.ResponseMetrics, totalTime ti } } - // 网络指标统计 sumDNSTime += result.DNSTime if result.DNSTime < minDNSTime { minDNSTime = result.DNSTime @@ -406,7 +474,6 @@ func (r *Runner) calculateResult(results []*client.ResponseMetrics, totalTime ti maxTLSTime = result.TLSHandshakeTime } - // Output Token 统计 sumOutputTokens += result.CompletionTokens if result.CompletionTokens < minOutputTokens { minOutputTokens = result.CompletionTokens @@ -415,7 +482,6 @@ func (r *Runner) calculateResult(results []*client.ResponseMetrics, totalTime ti maxOutputTokens = result.CompletionTokens } - // Input Token 统计 sumInputTokens += result.PromptTokens if result.PromptTokens < minInputTokens { minInputTokens = result.PromptTokens @@ -424,7 +490,14 @@ func (r *Runner) calculateResult(results []*client.ResponseMetrics, totalTime ti maxInputTokens = result.PromptTokens } - // Thinking Token 统计 + sumCachedInputTokens += result.CachedInputTokens + if result.CachedInputTokens < minCachedInputTokens { + minCachedInputTokens = result.CachedInputTokens + } + if result.CachedInputTokens > maxCachedInputTokens { + maxCachedInputTokens = result.CachedInputTokens + } + sumThinkingTokens += result.ThinkingTokens if result.ThinkingTokens < minThinkingTokens { minThinkingTokens = result.ThinkingTokens @@ -433,7 +506,15 @@ func (r *Runner) calculateResult(results []*client.ResponseMetrics, totalTime ti maxThinkingTokens = result.ThinkingTokens } - // TPS 统计 + cacheHitRate := calculateCacheHitRate(result) + sumCacheHitRate += cacheHitRate + if cacheHitRate < minCacheHitRate { + minCacheHitRate = cacheHitRate + } + if cacheHitRate > maxCacheHitRate { + maxCacheHitRate = cacheHitRate + } + var tps float64 if result.TotalTime.Seconds() > 0 { tps = float64(result.CompletionTokens) / result.TotalTime.Seconds() @@ -445,7 +526,6 @@ func (r *Runner) calculateResult(results []*client.ResponseMetrics, totalTime ti maxTPS = tps } - // 总吞吐量 TPS 统计 (输入 + 输出 tokens / 时间) var totalThroughputTPS float64 if result.TotalTime.Seconds() > 0 { totalTokens := result.PromptTokens + result.CompletionTokens @@ -461,33 +541,31 @@ func (r *Runner) calculateResult(results []*client.ResponseMetrics, totalTime ti } validCount := len(validResults) + errorRate := float64(requestCount-validCount) / float64(requestCount) * 100 + successRate := float64(validCount) / float64(requestCount) * 100 + resolvedEndpoint := r.input.ResolvedEndpointURL() - // 计算错误率和成功率 - errorRate := float64(r.input.Count-validCount) / float64(r.input.Count) * 100 - successRate := float64(validCount) / float64(r.input.Count) * 100 - - // 如果没有有效结果,返回基础结果 if validCount == 0 { return &types.ReportData{ - TotalRequests: r.input.Count, + TotalRequests: requestCount, Concurrency: r.input.Concurrency, TotalTime: totalTime, IsStream: r.input.Stream, IsThinking: r.input.Thinking, + Protocol: r.input.NormalizedProtocol(), + EndpointURL: resolvedEndpoint, + BaseUrl: resolvedEndpoint, ErrorRate: errorRate, SuccessRate: successRate, } } - // 计算各项指标的平均值 - // 注意:时间指标可以直接用总和除以数量来计算平均值,因为时间是可加性的 avgTTFT := sumTTFT / time.Duration(validCount) avgTotalTime := sumTotalTime / time.Duration(validCount) avgDNSTime := sumDNSTime / time.Duration(validCount) avgConnectTime := sumConnectTime / time.Duration(validCount) avgTLSTime := sumTLSTime / time.Duration(validCount) - // 计算TPOT平均值 - 只对有效的TPOT计算结果求平均 var avgTPOT time.Duration validTPOTCount := 0 for _, result := range validResults { @@ -499,53 +577,47 @@ func (r *Runner) calculateResult(results []*client.ResponseMetrics, totalTime ti avgTPOT = sumTPOT / time.Duration(validTPOTCount) } - // Token数量计算 avgOutputTokens := sumOutputTokens / validCount avgInputTokens := sumInputTokens / validCount + avgCachedInputTokens := sumCachedInputTokens / validCount avgThinkingTokens := sumThinkingTokens / validCount + avgCacheHitRate := sumCacheHitRate / float64(validCount) - // TPS是比率指标,需要特殊处理: - // 错误方式:float64(sumTokens) / sumTotalTime.Seconds() - 这相当于计算总体批处理的TPS - // 正确方式:先计算每个请求的TPS,然后求算术平均值 - 这反映单个请求的平均性能 var sumTPS float64 for _, result := range validResults { if result.TotalTime.Seconds() > 0 { - tps := float64(result.CompletionTokens) / result.TotalTime.Seconds() - sumTPS += tps + sumTPS += float64(result.CompletionTokens) / result.TotalTime.Seconds() } } avgTPS := sumTPS / float64(validCount) - - // 计算总吞吐量 TPS 的平均值 avgTotalThroughputTPS := sumTotalThroughputTPS / float64(validCount) - // 计算方差 - 第一遍遍历计算平均值后的方差 var varianceSumTotalTime, varianceSumTTFT, varianceSumTPOT float64 - var varianceSumInputTokens, varianceSumOutputTokens, varianceSumThinkingTokens float64 - var varianceSumTPS, varianceSumTotalThroughputTPS float64 + var varianceSumInputTokens, varianceSumCachedInputTokens, varianceSumOutputTokens, varianceSumThinkingTokens float64 + var varianceSumCacheHitRate, varianceSumTPS, varianceSumTotalThroughputTPS float64 for _, result := range validResults { - // 总时间方差 diffTotalTime := float64(result.TotalTime - avgTotalTime) varianceSumTotalTime += diffTotalTime * diffTotalTime - // TTFT 方差 diffTTFT := float64(result.TimeToFirstToken - avgTTFT) varianceSumTTFT += diffTTFT * diffTTFT - // Input Token 方差 diffInputTokens := float64(result.PromptTokens - avgInputTokens) varianceSumInputTokens += diffInputTokens * diffInputTokens - // Output Token 方差 + diffCachedInputTokens := float64(result.CachedInputTokens - avgCachedInputTokens) + varianceSumCachedInputTokens += diffCachedInputTokens * diffCachedInputTokens + diffOutputTokens := float64(result.CompletionTokens - avgOutputTokens) varianceSumOutputTokens += diffOutputTokens * diffOutputTokens - // Thinking Token 方差 diffThinkingTokens := float64(result.ThinkingTokens - avgThinkingTokens) varianceSumThinkingTokens += diffThinkingTokens * diffThinkingTokens - // TPS 方差 + diffCacheHitRate := calculateCacheHitRate(result) - avgCacheHitRate + varianceSumCacheHitRate += diffCacheHitRate * diffCacheHitRate + var tps float64 if result.TotalTime.Seconds() > 0 { tps = float64(result.CompletionTokens) / result.TotalTime.Seconds() @@ -553,7 +625,6 @@ func (r *Runner) calculateResult(results []*client.ResponseMetrics, totalTime ti diffTPS := tps - avgTPS varianceSumTPS += diffTPS * diffTPS - // 总吞吐量 TPS 方差 var totalThroughputTPS float64 if result.TotalTime.Seconds() > 0 { totalTokens := result.PromptTokens + result.CompletionTokens @@ -563,7 +634,6 @@ func (r *Runner) calculateResult(results []*client.ResponseMetrics, totalTime ti varianceSumTotalThroughputTPS += diffTotalThroughputTPS * diffTotalThroughputTPS } - // TPOT 方差计算 - 只对有效的 TPOT 计算 for _, result := range validResults { if result.CompletionTokens > 1 { remainingTime := result.TotalTime - result.TimeToFirstToken @@ -573,7 +643,6 @@ func (r *Runner) calculateResult(results []*client.ResponseMetrics, totalTime ti } } - // 计算最终标准差值(方差的平方根) stdDevTotalTime := time.Duration(math.Sqrt(varianceSumTotalTime / float64(validCount))) stdDevTTFT := time.Duration(math.Sqrt(varianceSumTTFT / float64(validCount))) stdDevTPOT := time.Duration(0) @@ -581,22 +650,25 @@ func (r *Runner) calculateResult(results []*client.ResponseMetrics, totalTime ti stdDevTPOT = time.Duration(math.Sqrt(varianceSumTPOT / float64(validTPOTCount))) } stdDevInputTokenCount := math.Sqrt(varianceSumInputTokens / float64(validCount)) + stdDevCachedInputTokenCount := math.Sqrt(varianceSumCachedInputTokens / float64(validCount)) stdDevOutputTokenCount := math.Sqrt(varianceSumOutputTokens / float64(validCount)) stdDevThinkingTokenCount := math.Sqrt(varianceSumThinkingTokens / float64(validCount)) + stdDevCacheHitRate := math.Sqrt(varianceSumCacheHitRate / float64(validCount)) stdDevTPS := math.Sqrt(varianceSumTPS / float64(validCount)) stdDevTotalThroughputTPS := math.Sqrt(varianceSumTotalThroughputTPS / float64(validCount)) - result := &types.ReportData{ - TotalRequests: r.input.Count, + return &types.ReportData{ + TotalRequests: requestCount, Concurrency: r.input.Concurrency, TotalTime: totalTime, IsStream: r.input.Stream, IsThinking: r.input.Thinking, - // 时间指标 + Protocol: r.input.NormalizedProtocol(), + EndpointURL: resolvedEndpoint, + BaseUrl: resolvedEndpoint, AvgTotalTime: avgTotalTime, MinTotalTime: minTotalTime, MaxTotalTime: maxTotalTime, - // 网络指标 AvgDNSTime: avgDNSTime, MinDNSTime: minDNSTime, MaxDNSTime: maxDNSTime, @@ -607,7 +679,6 @@ func (r *Runner) calculateResult(results []*client.ResponseMetrics, totalTime ti MinTLSHandshakeTime: minTLSTime, MaxTLSHandshakeTime: maxTLSTime, TargetIP: targetIP, - // 服务性能指标 AvgTTFT: avgTTFT, MinTTFT: minTTFT, MaxTTFT: maxTTFT, @@ -617,32 +688,35 @@ func (r *Runner) calculateResult(results []*client.ResponseMetrics, totalTime ti AvgInputTokenCount: avgInputTokens, MinInputTokenCount: minInputTokens, MaxInputTokenCount: maxInputTokens, + AvgCachedInputTokenCount: avgCachedInputTokens, + MinCachedInputTokenCount: minCachedInputTokens, + MaxCachedInputTokenCount: maxCachedInputTokens, AvgOutputTokenCount: avgOutputTokens, MinOutputTokenCount: minOutputTokens, MaxOutputTokenCount: maxOutputTokens, AvgThinkingTokenCount: avgThinkingTokens, MinThinkingTokenCount: minThinkingTokens, MaxThinkingTokenCount: maxThinkingTokens, + AvgCacheHitRate: avgCacheHitRate, + MinCacheHitRate: minCacheHitRate, + MaxCacheHitRate: maxCacheHitRate, AvgTPS: avgTPS, MinTPS: minTPS, MaxTPS: maxTPS, - // 总吞吐量指标 AvgTotalThroughputTPS: avgTotalThroughputTPS, MinTotalThroughputTPS: minTotalThroughputTPS, MaxTotalThroughputTPS: maxTotalThroughputTPS, - // 标准差指标 StdDevTotalTime: stdDevTotalTime, StdDevTTFT: stdDevTTFT, StdDevTPOT: stdDevTPOT, StdDevInputTokenCount: stdDevInputTokenCount, + StdDevCachedInputTokenCount: stdDevCachedInputTokenCount, StdDevOutputTokenCount: stdDevOutputTokenCount, StdDevThinkingTokenCount: stdDevThinkingTokenCount, + StdDevCacheHitRate: stdDevCacheHitRate, StdDevTPS: stdDevTPS, StdDevTotalThroughputTPS: stdDevTotalThroughputTPS, - // 可靠性指标 ErrorRate: errorRate, SuccessRate: successRate, } - - return result } diff --git a/internal/runner/runner_callback_test.go b/internal/runner/runner_callback_test.go new file mode 100644 index 0000000..9e09901 --- /dev/null +++ b/internal/runner/runner_callback_test.go @@ -0,0 +1,172 @@ +package runner + +import ( + "math" + "sync/atomic" + "testing" + "time" + + "github.com/yinxulai/ait/internal/client" + "github.com/yinxulai/ait/internal/types" +) + +func TestRunner_RunWithCallback_InvokesCallbackForEachRequest(t *testing.T) { + input := types.Input{ + Protocol: types.ProtocolOpenAICompletions, + BaseUrl: "https://api.openai.com", + ApiKey: "test-key", + Model: "gpt-4.1-mini", + Concurrency: 2, + Count: 4, + PromptSource: createTestPromptSource("test prompt"), + Stream: true, + } + + mockClient := &MockClient{ + responseMetrics: &client.ResponseMetrics{ + TotalTime: 200 * time.Millisecond, + TimeToFirstToken: 50 * time.Millisecond, + PromptTokens: 100, + CachedInputTokens: 25, + CompletionTokens: 100, + ThinkingTokens: 20, + DNSTime: 10 * time.Millisecond, + ConnectTime: 20 * time.Millisecond, + TLSHandshakeTime: 30 * time.Millisecond, + TargetIP: "8.8.8.8", + }, + } + + runner := NewRunnerWithClient(input, mockClient) + var callbackCount atomic.Int64 + + result, err := runner.RunWithCallback(func(metrics *client.ResponseMetrics, index int, err error) { + if err != nil { + t.Errorf("callback received unexpected error: %v", err) + } + if metrics == nil { + t.Errorf("callback metrics should not be nil for index %d", index) + return + } + callbackCount.Add(1) + }) + if err != nil { + t.Fatalf("RunWithCallback() returned unexpected error: %v", err) + } + if callbackCount.Load() != int64(input.Count) { + t.Fatalf("expected %d callbacks, got %d", input.Count, callbackCount.Load()) + } + if result.AvgCachedInputTokenCount != 25 { + t.Fatalf("expected AvgCachedInputTokenCount 25, got %d", result.AvgCachedInputTokenCount) + } + if result.AvgCacheHitRate != 0.25 { + t.Fatalf("expected AvgCacheHitRate 0.25, got %f", result.AvgCacheHitRate) + } +} + +func TestRunner_Stop_StopsLaunchingNewRequests(t *testing.T) { + input := types.Input{ + Protocol: types.ProtocolOpenAICompletions, + BaseUrl: "https://api.openai.com", + ApiKey: "test-key", + Model: "gpt-4.1-mini", + Concurrency: 1, + Count: 20, + PromptSource: createTestPromptSource("test prompt"), + Stream: false, + } + + mockClient := &MockClient{ + requestDelay: 40 * time.Millisecond, + responseMetrics: &client.ResponseMetrics{ + TotalTime: 40 * time.Millisecond, + TimeToFirstToken: 40 * time.Millisecond, + PromptTokens: 80, + CachedInputTokens: 10, + CompletionTokens: 30, + }, + } + + runner := NewRunnerWithClient(input, mockClient) + resultCh := make(chan *types.ReportData, 1) + errCh := make(chan error, 1) + var callbackCount atomic.Int64 + + go func() { + result, err := runner.RunWithCallback(func(metrics *client.ResponseMetrics, index int, err error) { + if callbackCount.Add(1) == 1 { + runner.Stop() + } + }) + resultCh <- result + errCh <- err + }() + + select { + case err := <-errCh: + if err != nil { + t.Fatalf("RunWithCallback() returned unexpected error: %v", err) + } + case <-time.After(3 * time.Second): + t.Fatal("RunWithCallback() did not finish after Stop()") + } + + result := <-resultCh + callCount := mockClient.GetCallCount() + if callCount >= int64(input.Count) { + t.Fatalf("expected Stop() to stop before launching all requests, got %d calls", callCount) + } + if int64(result.TotalRequests) != callCount { + t.Fatalf("expected TotalRequests %d to match launched calls, got %d", callCount, result.TotalRequests) + } + if callbackCount.Load() != callCount { + t.Fatalf("expected callback count %d, got %d", callCount, callbackCount.Load()) + } +} + +func TestRunner_CalculateResult_CacheHitRate(t *testing.T) { + input := types.Input{ + Protocol: types.ProtocolOpenAICompletions, + BaseUrl: "https://api.openai.com", + ApiKey: "test-key", + Model: "gpt-4.1-mini", + Concurrency: 1, + Count: 2, + PromptSource: createTestPromptSource("test prompt"), + Stream: true, + } + + runner := NewRunnerWithClient(input, &MockClient{}) + results := []*client.ResponseMetrics{ + { + TotalTime: 2 * time.Second, + TimeToFirstToken: 500 * time.Millisecond, + PromptTokens: 100, + CachedInputTokens: 50, + CompletionTokens: 100, + TargetIP: "1.1.1.1", + }, + { + TotalTime: 1 * time.Second, + TimeToFirstToken: 250 * time.Millisecond, + PromptTokens: 80, + CachedInputTokens: 20, + CompletionTokens: 60, + TargetIP: "1.1.1.1", + }, + } + + result := runner.calculateResult(results, 3*time.Second, 2) + if result.AvgCachedInputTokenCount != 35 { + t.Fatalf("expected AvgCachedInputTokenCount 35, got %d", result.AvgCachedInputTokenCount) + } + if result.MinCachedInputTokenCount != 20 || result.MaxCachedInputTokenCount != 50 { + t.Fatalf("unexpected cached input token min/max: %d/%d", result.MinCachedInputTokenCount, result.MaxCachedInputTokenCount) + } + if math.Abs(result.AvgCacheHitRate-0.375) > 0.00001 { + t.Fatalf("expected AvgCacheHitRate 0.375, got %f", result.AvgCacheHitRate) + } + if math.Abs(result.MinCacheHitRate-0.25) > 0.00001 || math.Abs(result.MaxCacheHitRate-0.5) > 0.00001 { + t.Fatalf("unexpected cache hit rate min/max: %f/%f", result.MinCacheHitRate, result.MaxCacheHitRate) + } +} diff --git a/internal/runner/runner_test.go b/internal/runner/runner_test.go index 4301087..e5f227d 100644 --- a/internal/runner/runner_test.go +++ b/internal/runner/runner_test.go @@ -103,6 +103,7 @@ func NewRunnerWithClient(input types.Input, client client.ModelClient) *Runner { input: input, client: client, upload: upload.New(), + stopCh: make(chan struct{}), } } diff --git a/internal/task/history.go b/internal/task/history.go new file mode 100644 index 0000000..ae0063b --- /dev/null +++ b/internal/task/history.go @@ -0,0 +1,89 @@ +package task + +import ( + "encoding/json" + "errors" + "os" + "path/filepath" + + "github.com/yinxulai/ait/internal/config" + "github.com/yinxulai/ait/internal/types" +) + +func AppendRun(taskID string, run types.TaskRunSummary) error { + runs, err := loadHistoryFile(taskID) + if err != nil && !errors.Is(err, os.ErrNotExist) { + return err + } + if run.TaskID == "" { + run.TaskID = taskID + } + runs = append(runs, run) + return saveHistoryFile(taskID, runs) +} + +func LoadHistory(taskID string, limit int) ([]types.TaskRunSummary, error) { + runs, err := loadHistoryFile(taskID) + if errors.Is(err, os.ErrNotExist) { + return []types.TaskRunSummary{}, nil + } + if err != nil { + return nil, err + } + + reversed := make([]types.TaskRunSummary, 0, len(runs)) + for i := len(runs) - 1; i >= 0; i-- { + reversed = append(reversed, runs[i]) + } + if limit > 0 && len(reversed) > limit { + reversed = reversed[:limit] + } + return reversed, nil +} + +func loadHistoryFile(taskID string) ([]types.TaskRunSummary, error) { + path, err := historyPath(taskID) + if err != nil { + return nil, err + } + + data, err := os.ReadFile(path) + if errors.Is(err, os.ErrNotExist) { + return []types.TaskRunSummary{}, os.ErrNotExist + } + if err != nil { + return nil, err + } + + var runs []types.TaskRunSummary + if err := json.Unmarshal(data, &runs); err != nil { + return nil, err + } + return runs, nil +} + +func saveHistoryFile(taskID string, runs []types.TaskRunSummary) error { + dir, err := config.HistoryDir() + if err != nil { + return err + } + if err := os.MkdirAll(dir, 0o755); err != nil { + return err + } + + data, err := json.MarshalIndent(runs, "", " ") + if err != nil { + return err + } + + path := filepath.Join(dir, taskID+".json") + return os.WriteFile(path, data, 0o644) +} + +func historyPath(taskID string) (string, error) { + dir, err := config.HistoryDir() + if err != nil { + return "", err + } + return filepath.Join(dir, taskID+".json"), nil +} diff --git a/internal/task/history_test.go b/internal/task/history_test.go new file mode 100644 index 0000000..6426fac --- /dev/null +++ b/internal/task/history_test.go @@ -0,0 +1,41 @@ +package task + +import ( + "testing" + "time" + + "github.com/yinxulai/ait/internal/types" +) + +func TestAppendRunAndLoadHistoryNewestFirst(t *testing.T) { + t.Setenv("HOME", t.TempDir()) + + first := types.TaskRunSummary{RunID: "run-1", StartedAt: time.Unix(100, 0), FinishedAt: time.Unix(110, 0)} + second := types.TaskRunSummary{RunID: "run-2", StartedAt: time.Unix(200, 0), FinishedAt: time.Unix(210, 0)} + + if err := AppendRun("task-1", first); err != nil { + t.Fatalf("AppendRun(first) returned unexpected error: %v", err) + } + if err := AppendRun("task-1", second); err != nil { + t.Fatalf("AppendRun(second) returned unexpected error: %v", err) + } + + history, err := LoadHistory("task-1", 0) + if err != nil { + t.Fatalf("LoadHistory() returned unexpected error: %v", err) + } + if len(history) != 2 { + t.Fatalf("expected 2 history items, got %d", len(history)) + } + if history[0].RunID != "run-2" || history[1].RunID != "run-1" { + t.Fatalf("expected newest-first order, got %+v", history) + } + + limited, err := LoadHistory("task-1", 1) + if err != nil { + t.Fatalf("LoadHistory(limit) returned unexpected error: %v", err) + } + if len(limited) != 1 || limited[0].RunID != "run-2" { + t.Fatalf("unexpected limited history: %+v", limited) + } +} diff --git a/internal/task/input.go b/internal/task/input.go new file mode 100644 index 0000000..bd8f3f5 --- /dev/null +++ b/internal/task/input.go @@ -0,0 +1,48 @@ +package task + +import ( + "fmt" + + "github.com/yinxulai/ait/internal/prompt" + "github.com/yinxulai/ait/internal/types" +) + +func HydrateInput(input types.Input) (types.Input, error) { + if input.PromptSource != nil { + return input, nil + } + + switch input.PromptMode { + case "", "text": + if input.PromptText == "" { + return input, fmt.Errorf("prompt_text is required for prompt_mode=text") + } + source, err := prompt.LoadPrompts(input.PromptText) + if err != nil { + return input, err + } + input.PromptSource = source + case "file": + if input.PromptFile == "" { + return input, fmt.Errorf("prompt_file is required for prompt_mode=file") + } + source, err := prompt.LoadPromptsFromFile(input.PromptFile) + if err != nil { + return input, err + } + input.PromptSource = source + case "generated": + if input.PromptLength <= 0 { + return input, fmt.Errorf("prompt_length must be greater than zero for prompt_mode=generated") + } + source, err := prompt.LoadPromptByLength(input.PromptLength) + if err != nil { + return input, err + } + input.PromptSource = source + default: + return input, fmt.Errorf("unsupported prompt_mode: %s", input.PromptMode) + } + + return input, nil +} diff --git a/internal/task/input_test.go b/internal/task/input_test.go new file mode 100644 index 0000000..8a58f16 --- /dev/null +++ b/internal/task/input_test.go @@ -0,0 +1,33 @@ +package task + +import ( + "testing" + + "github.com/yinxulai/ait/internal/types" +) + +func TestHydrateInputTextMode(t *testing.T) { + input, err := HydrateInput(types.Input{PromptMode: "text", PromptText: "hello"}) + if err != nil { + t.Fatalf("HydrateInput(text) returned unexpected error: %v", err) + } + if input.PromptSource == nil || input.PromptSource.GetRandomContent() != "hello" { + t.Fatal("expected PromptSource to be hydrated from PromptText") + } +} + +func TestHydrateInputGeneratedMode(t *testing.T) { + input, err := HydrateInput(types.Input{PromptMode: "generated", PromptLength: 32}) + if err != nil { + t.Fatalf("HydrateInput(generated) returned unexpected error: %v", err) + } + if input.PromptSource == nil || input.PromptSource.Count() != 1 { + t.Fatal("expected generated PromptSource to be created") + } +} + +func TestHydrateInputRejectsInvalidMode(t *testing.T) { + if _, err := HydrateInput(types.Input{PromptMode: "unknown"}); err == nil { + t.Fatal("expected HydrateInput to reject unsupported prompt_mode") + } +} diff --git a/internal/task/store.go b/internal/task/store.go new file mode 100644 index 0000000..4c37620 --- /dev/null +++ b/internal/task/store.go @@ -0,0 +1,102 @@ +package task + +import ( + "encoding/json" + "errors" + "fmt" + "os" + "time" + + "github.com/yinxulai/ait/internal/config" + "github.com/yinxulai/ait/internal/types" +) + +type TaskStore struct { + Tasks []types.TaskDefinition `json:"tasks"` +} + +func LoadTasks() (*TaskStore, error) { + path, err := config.TasksPath() + if err != nil { + return nil, err + } + + data, err := os.ReadFile(path) + if errors.Is(err, os.ErrNotExist) { + return &TaskStore{Tasks: []types.TaskDefinition{}}, nil + } + if err != nil { + return nil, err + } + + var store TaskStore + if err := json.Unmarshal(data, &store); err != nil { + return nil, err + } + if store.Tasks == nil { + store.Tasks = []types.TaskDefinition{} + } + return &store, nil +} + +func (s *TaskStore) Save() error { + if _, err := config.EnsureAppDir(); err != nil { + return err + } + path, err := config.TasksPath() + if err != nil { + return err + } + data, err := json.MarshalIndent(s, "", " ") + if err != nil { + return err + } + return os.WriteFile(path, data, 0o644) +} + +func (s *TaskStore) Upsert(task types.TaskDefinition) { + now := time.Now() + if task.ID == "" { + task.ID = fmt.Sprintf("task_%d", now.UnixNano()) + } + + for i, existing := range s.Tasks { + if existing.ID != task.ID { + continue + } + if task.CreatedAt.IsZero() { + task.CreatedAt = existing.CreatedAt + } + task.UpdatedAt = now + updated := append([]types.TaskDefinition{task}, append(s.Tasks[:i], s.Tasks[i+1:]...)...) + s.Tasks = updated + return + } + + if task.CreatedAt.IsZero() { + task.CreatedAt = now + } + task.UpdatedAt = now + + s.Tasks = append([]types.TaskDefinition{task}, s.Tasks...) +} + +func (s *TaskStore) Delete(taskID string) error { + for i, task := range s.Tasks { + if task.ID != taskID { + continue + } + s.Tasks = append(s.Tasks[:i], s.Tasks[i+1:]...) + return nil + } + return os.ErrNotExist +} + +func (s *TaskStore) Get(taskID string) (types.TaskDefinition, bool) { + for _, task := range s.Tasks { + if task.ID == taskID { + return task, true + } + } + return types.TaskDefinition{}, false +} diff --git a/internal/task/store_test.go b/internal/task/store_test.go new file mode 100644 index 0000000..9959da0 --- /dev/null +++ b/internal/task/store_test.go @@ -0,0 +1,76 @@ +package task + +import ( + "errors" + "os" + "testing" + "time" + + "github.com/yinxulai/ait/internal/types" +) + +func TestLoadTasksReturnsEmptyStoreWhenMissing(t *testing.T) { + t.Setenv("HOME", t.TempDir()) + + store, err := LoadTasks() + if err != nil { + t.Fatalf("LoadTasks() returned unexpected error: %v", err) + } + if len(store.Tasks) != 0 { + t.Fatalf("expected no tasks, got %d", len(store.Tasks)) + } +} + +func TestTaskStoreUpsertSaveAndReload(t *testing.T) { + t.Setenv("HOME", t.TempDir()) + + store := &TaskStore{} + task := types.TaskDefinition{ + ID: "task-1", + Name: "nightly-openai", + Input: types.Input{ + Protocol: types.ProtocolOpenAIResponses, + EndpointURL: "https://api.openai.com/v1/responses", + Model: "gpt-4.1", + }, + } + store.Upsert(task) + if err := store.Save(); err != nil { + t.Fatalf("Save() returned unexpected error: %v", err) + } + + loaded, err := LoadTasks() + if err != nil { + t.Fatalf("LoadTasks() returned unexpected error: %v", err) + } + if len(loaded.Tasks) != 1 || loaded.Tasks[0].ID != "task-1" { + t.Fatalf("unexpected loaded tasks: %+v", loaded.Tasks) + } + + firstUpdatedAt := loaded.Tasks[0].UpdatedAt + time.Sleep(10 * time.Millisecond) + task.Name = "nightly-openai-updated" + loaded.Upsert(task) + if len(loaded.Tasks) != 1 { + t.Fatalf("expected one task after update, got %d", len(loaded.Tasks)) + } + if loaded.Tasks[0].Name != "nightly-openai-updated" { + t.Fatalf("expected updated task name, got %s", loaded.Tasks[0].Name) + } + if !loaded.Tasks[0].UpdatedAt.After(firstUpdatedAt) { + t.Fatalf("expected UpdatedAt to advance after Upsert") + } +} + +func TestTaskStoreDelete(t *testing.T) { + store := &TaskStore{Tasks: []types.TaskDefinition{{ID: "task-1"}, {ID: "task-2"}}} + if err := store.Delete("task-1"); err != nil { + t.Fatalf("Delete() returned unexpected error: %v", err) + } + if len(store.Tasks) != 1 || store.Tasks[0].ID != "task-2" { + t.Fatalf("unexpected tasks after delete: %+v", store.Tasks) + } + if err := store.Delete("missing"); !errors.Is(err, os.ErrNotExist) { + t.Fatalf("expected os.ErrNotExist, got %v", err) + } +} diff --git a/internal/tui/messages.go b/internal/tui/messages.go new file mode 100644 index 0000000..ad776f9 --- /dev/null +++ b/internal/tui/messages.go @@ -0,0 +1,22 @@ +package tui + +import "github.com/yinxulai/ait/internal/types" + +type progressMsg struct { + stats types.StatsData +} + +type runCompleteMsg struct { + taskID string + result *types.ReportData + reportPaths []string +} + +type turboCompleteMsg struct { + taskID string + result *types.TurboResult +} + +type asyncErrorMsg struct { + err error +} diff --git a/internal/tui/model.go b/internal/tui/model.go new file mode 100644 index 0000000..d9cb4f4 --- /dev/null +++ b/internal/tui/model.go @@ -0,0 +1,1054 @@ +package tui + +import ( + "fmt" + "path/filepath" + "strconv" + "strings" + "time" + + "github.com/charmbracelet/bubbles/textinput" + tea "github.com/charmbracelet/bubbletea" + "github.com/charmbracelet/lipgloss" + "github.com/yinxulai/ait/internal/config" + "github.com/yinxulai/ait/internal/report" + "github.com/yinxulai/ait/internal/runner" + "github.com/yinxulai/ait/internal/task" + "github.com/yinxulai/ait/internal/turbo" + "github.com/yinxulai/ait/internal/types" +) + +type viewState string + +const ( + viewTaskList viewState = "task-list" + viewTaskDetail viewState = "task-detail" + viewWizard viewState = "wizard" + viewDashboard viewState = "dashboard" + viewResult viewState = "result" + viewTurboResult viewState = "turbo-result" +) + +const ( + modeStandard = "standard" + modeTurbo = "turbo" + + promptModeText = "text" + promptModeFile = "file" + promptModeGenerated = "generated" +) + +var protocolOptions = []string{ + types.ProtocolOpenAICompletions, + types.ProtocolOpenAIResponses, + types.ProtocolAnthropicMessages, +} + +var promptModeOptions = []string{promptModeText, promptModeFile, promptModeGenerated} + +type fieldKind int + +const ( + fieldText fieldKind = iota + fieldSelect + fieldToggle +) + +type wizardField struct { + key string + label string + kind fieldKind +} + +type wizardState struct { + editingTaskID string + createdAt time.Time + lastRunAt *time.Time + lastRunSummary *types.TaskRunSummary + fromView viewState + current int + input textinput.Model + values map[string]string + protocolIndex int + mode string + promptModeIndex int + stream bool + thinking bool + report bool +} + +type Model struct { + styles styles + store *task.TaskStore + config *config.Config + tasks []types.TaskDefinition + history []types.TaskRunSummary + selected int + view viewState + wizard *wizardState + width int + height int + status string + err error + program *tea.Program + runningTask *types.TaskDefinition + runStartedAt time.Time + progress types.StatsData + runResult *types.ReportData + turboResult *types.TurboResult + activeRunner *runner.Runner + activeTurbo *turbo.Engine +} + +func NewModel(store *task.TaskStore, cfg *config.Config) *Model { + return &Model{ + styles: newStyles(), + store: store, + config: cfg, + tasks: store.Tasks, + view: viewTaskList, + } +} + +func Run() error { + store, err := task.LoadTasks() + if err != nil { + return err + } + cfg, err := config.Load() + if err != nil { + return err + } + model := NewModel(store, cfg) + program := tea.NewProgram(model, tea.WithAltScreen()) + model.program = program + _, err = program.Run() + return err +} + +func (m *Model) Init() tea.Cmd { + return nil +} + +func (m *Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + switch msg := msg.(type) { + case tea.WindowSizeMsg: + m.width = msg.Width + m.height = msg.Height + return m, nil + case progressMsg: + m.progress = msg.stats + return m, nil + case runCompleteMsg: + m.activeRunner = nil + m.runResult = msg.result + m.view = viewResult + m.status = fmt.Sprintf("标准模式完成,共 %d 请求", msg.result.TotalRequests) + m.persistStandardRun(msg.taskID, msg.result, msg.reportPaths) + return m, nil + case turboCompleteMsg: + m.activeTurbo = nil + m.turboResult = msg.result + m.view = viewTurboResult + m.status = fmt.Sprintf("Turbo 完成,最大稳定并发 %d", msg.result.MaxStableConcurrency) + m.persistTurboRun(msg.taskID, msg.result) + return m, nil + case asyncErrorMsg: + m.err = msg.err + m.status = msg.err.Error() + return m, nil + case tea.KeyMsg: + return m.handleKey(msg) + } + + return m, nil +} + +func (m *Model) handleKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) { + if msg.String() == "ctrl+c" { + return m, tea.Quit + } + + switch m.view { + case viewTaskList: + return m.handleTaskListKey(msg) + case viewTaskDetail: + return m.handleTaskDetailKey(msg) + case viewWizard: + return m.handleWizardKey(msg) + case viewDashboard: + return m.handleDashboardKey(msg) + case viewResult, viewTurboResult: + if msg.String() == "b" || msg.String() == "esc" || msg.String() == "enter" { + m.reloadHistoryForSelectedTask() + m.view = viewTaskDetail + return m, nil + } + } + + return m, nil +} + +func (m *Model) handleTaskListKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) { + switch msg.String() { + case "up", "k": + if m.selected > 0 { + m.selected-- + } + case "down", "j": + if m.selected < len(m.tasks)-1 { + m.selected++ + } + case "a": + m.openWizard(nil) + case "e": + if taskDef, ok := m.currentTask(); ok { + copyTask := taskDef + m.openWizard(©Task) + } + case "y": + if taskDef, ok := m.currentTask(); ok { + copyTask := taskDef + copyTask.ID = "" + copyTask.Name = taskDef.Name + "-copy" + m.openWizard(©Task) + } + case "d": + if taskDef, ok := m.currentTask(); ok { + if err := m.store.Delete(taskDef.ID); err != nil { + m.err = err + break + } + if err := m.store.Save(); err != nil { + m.err = err + break + } + m.tasks = m.store.Tasks + if m.selected >= len(m.tasks) && m.selected > 0 { + m.selected-- + } + m.status = "任务已删除" + } + case "enter": + if _, ok := m.currentTask(); ok { + m.reloadHistoryForSelectedTask() + m.view = viewTaskDetail + } + case "r": + if taskDef, ok := m.currentTask(); ok { + m.startTaskRun(taskDef) + } + case "q": + return m, tea.Quit + } + return m, nil +} + +func (m *Model) handleTaskDetailKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) { + taskDef, ok := m.currentTask() + if !ok { + m.view = viewTaskList + return m, nil + } + + switch msg.String() { + case "b", "esc": + m.view = viewTaskList + case "e": + copyTask := taskDef + m.openWizard(©Task) + case "d": + if err := m.store.Delete(taskDef.ID); err != nil { + m.err = err + break + } + if err := m.store.Save(); err != nil { + m.err = err + break + } + m.tasks = m.store.Tasks + if m.selected >= len(m.tasks) && m.selected > 0 { + m.selected-- + } + m.view = viewTaskList + case "enter", "r": + m.startTaskRun(taskDef) + } + + return m, nil +} + +func (m *Model) handleWizardKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) { + field := m.currentWizardField() + switch msg.String() { + case "esc": + m.view = m.wizard.fromView + m.wizard = nil + return m, nil + case "tab", "enter": + if field.kind == fieldText { + m.wizard.values[field.key] = m.wizard.input.Value() + } + if m.wizard.current == len(m.wizardFields())-1 { + if err := m.saveWizard(); err != nil { + m.err = err + m.status = err.Error() + } + return m, nil + } + m.wizard.current++ + m.refreshWizardInput() + return m, nil + case "shift+tab", "up": + if m.wizard.current > 0 { + m.wizard.values[field.key] = m.wizard.input.Value() + m.wizard.current-- + m.refreshWizardInput() + } + return m, nil + case "left", "h": + m.cycleWizardField(-1) + return m, nil + case "right", "l", "space": + m.cycleWizardField(1) + return m, nil + } + if field.kind == fieldText { + var cmd tea.Cmd + m.wizard.input, cmd = m.wizard.input.Update(msg) + m.wizard.values[field.key] = m.wizard.input.Value() + return m, cmd + } + return m, nil +} + +func (m *Model) handleDashboardKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) { + switch msg.String() { + case "s", "q", "esc": + if m.activeRunner != nil { + m.activeRunner.Stop() + } + if m.activeTurbo != nil { + m.activeTurbo.Stop() + } + } + return m, nil +} + +func (m *Model) View() string { + switch m.view { + case viewTaskDetail: + return m.renderTaskDetail() + case viewWizard: + return m.renderWizard() + case viewDashboard: + return m.renderDashboard() + case viewResult: + return m.renderResult() + case viewTurboResult: + return m.renderTurboResult() + default: + return m.renderTaskList() + } +} + +func (m *Model) renderTaskList() string { + var rows []string + for i, taskDef := range m.tasks { + mode := modeStandard + if taskDef.Input.Turbo { + mode = modeTurbo + } + summary := "从未运行" + if taskDef.LastRunSummary != nil { + summary = fmt.Sprintf("上次 %.1f%% · %.1f tok/s", taskDef.LastRunSummary.SuccessRate, taskDef.LastRunSummary.AvgTPS) + } + line := fmt.Sprintf("%s %s %s %s", taskDef.Name, taskDef.Input.Model, mode, summary) + if i == m.selected { + line = m.styles.selected.Render("▶ " + line) + } else { + line = " " + line + } + rows = append(rows, line) + } + if len(rows) == 0 { + rows = append(rows, m.styles.muted.Render("暂无任务,按 a 新建")) + } + + content := []string{ + m.styles.title.Render("AIT 任务中心"), + m.styles.subtitle.Render(fmt.Sprintf("已保存任务: %d", len(m.tasks))), + m.styles.panel.Render(strings.Join(rows, "\n")), + m.footer("[↑↓] 选择", "[Enter] 详情", "[a] 新建", "[r] 运行", "[e] 编辑", "[d] 删除", "[q] 退出"), + } + if m.status != "" { + content = append(content, m.styles.muted.Render(m.status)) + } + return strings.Join(content, "\n") +} + +func (m *Model) renderTaskDetail() string { + taskDef, ok := m.currentTask() + if !ok { + return m.styles.error.Render("任务不存在") + } + lastRun := "从未运行" + if taskDef.LastRunAt != nil { + lastRun = taskDef.LastRunAt.Format(time.RFC3339) + } + mode := modeStandard + if taskDef.Input.Turbo { + mode = modeTurbo + } + left := []string{ + fmt.Sprintf("名称: %s", taskDef.Name), + fmt.Sprintf("协议: %s", taskDef.Input.NormalizedProtocol()), + fmt.Sprintf("接口: %s", taskDef.Input.ResolvedEndpointURL()), + fmt.Sprintf("模型: %s", taskDef.Input.Model), + fmt.Sprintf("模式: %s", mode), + fmt.Sprintf("Prompt: %s", promptSummary(taskDef.Input)), + fmt.Sprintf("最近运行: %s", lastRun), + } + + historyLines := []string{m.styles.label.Render("最近运行记录")} + if len(m.history) == 0 { + historyLines = append(historyLines, m.styles.muted.Render("暂无历史")) + } else { + for _, item := range m.history { + historyLines = append(historyLines, fmt.Sprintf("%s %s %.1f%% %.1f tok/s cache %.1f%%", item.FinishedAt.Format("2006-01-02 15:04:05"), item.Mode, item.SuccessRate, item.AvgTPS, item.CacheHitRate)) + } + } + + return strings.Join([]string{ + m.styles.title.Render("AIT 任务详情"), + m.styles.panel.Render(strings.Join(left, "\n")), + m.styles.panel.Render(strings.Join(historyLines, "\n")), + m.footer("[Enter] 运行", "[e] 编辑", "[d] 删除", "[b] 返回"), + }, "\n") +} + +func (m *Model) renderWizard() string { + fields := m.wizardFields() + field := fields[m.wizard.current] + var lines []string + for i, f := range fields { + marker := " " + if i == m.wizard.current { + marker = "▶ " + } + lines = append(lines, marker+fmt.Sprintf("%s: %s", f.label, m.displayWizardValue(f))) + } + + editor := "" + if field.kind == fieldText { + editor = m.styles.panel.Render(m.wizard.input.View()) + } else { + editor = m.styles.panel.Render(m.displayWizardValue(field)) + } + + return strings.Join([]string{ + m.styles.title.Render("AIT 任务向导"), + m.styles.subtitle.Render(fmt.Sprintf("步骤 %d/%d", m.wizard.current+1, len(fields))), + m.styles.panel.Render(strings.Join(lines, "\n")), + editor, + m.footer("[Enter/Tab] 下一项或保存", "[←→/Space] 切换选项", "[Esc] 取消"), + }, "\n") +} + +func (m *Model) renderDashboard() string { + title := "AIT 正在运行" + if m.runningTask != nil && m.runningTask.Input.Turbo { + title = "AIT Turbo 正在探测" + } + stats := []string{ + fmt.Sprintf("完成: %d", m.progress.CompletedCount), + fmt.Sprintf("失败: %d", m.progress.FailedCount), + fmt.Sprintf("运行时长: %s", m.progress.ElapsedTime.Truncate(100*time.Millisecond)), + } + if len(m.progress.CacheHitRates) > 0 { + stats = append(stats, fmt.Sprintf("最近缓存命中率: %.1f%%", m.progress.CacheHitRates[len(m.progress.CacheHitRates)-1]*100)) + } + return strings.Join([]string{ + m.styles.title.Render(title), + m.styles.panel.Render(strings.Join(stats, "\n")), + m.footer("[s] 停止"), + }, "\n") +} + +func (m *Model) renderResult() string { + if m.runResult == nil { + return m.styles.error.Render("结果为空") + } + result := m.runResult + lines := []string{ + fmt.Sprintf("协议: %s", result.Protocol), + fmt.Sprintf("接口: %s", result.EndpointURL), + fmt.Sprintf("成功率: %.1f%%", result.SuccessRate), + fmt.Sprintf("平均 TTFT: %s", result.AvgTTFT), + fmt.Sprintf("平均 TPS: %.2f", result.AvgTPS), + fmt.Sprintf("缓存命中率: %.1f%%", result.AvgCacheHitRate*100), + fmt.Sprintf("平均总耗时: %s", result.AvgTotalTime), + } + return strings.Join([]string{ + m.styles.title.Render("AIT 标准模式结果"), + m.styles.panel.Render(strings.Join(lines, "\n")), + m.footer("[b] 返回详情"), + }, "\n") +} + +func (m *Model) renderTurboResult() string { + if m.turboResult == nil { + return m.styles.error.Render("Turbo 结果为空") + } + lines := []string{ + fmt.Sprintf("协议: %s", m.turboResult.Protocol), + fmt.Sprintf("接口: %s", m.turboResult.EndpointURL), + fmt.Sprintf("最大稳定并发: %d", m.turboResult.MaxStableConcurrency), + fmt.Sprintf("峰值平均 TPS: %.2f", m.turboResult.PeakTPS), + fmt.Sprintf("停止原因: %s", m.turboResult.StopReason), + } + for _, level := range m.turboResult.Levels { + status := "✓" + if !level.Stable { + status = "✗" + } + lines = append(lines, fmt.Sprintf("%s 并发 %d 成功率 %.1f%% avgTPS %.2f cache %.1f%%", status, level.Concurrency, level.SuccessRate*100, level.AvgTPS, level.CacheHitRate*100)) + } + return strings.Join([]string{ + m.styles.title.Render("AIT Turbo 结果"), + m.styles.panel.Render(strings.Join(lines, "\n")), + m.footer("[b] 返回详情"), + }, "\n") +} + +func (m *Model) footer(parts ...string) string { + styled := make([]string, 0, len(parts)) + for _, part := range parts { + styled = append(styled, m.styles.key.Render(part+" ")) + } + return lipgloss.JoinHorizontal(lipgloss.Left, styled...) +} + +func (m *Model) currentTask() (types.TaskDefinition, bool) { + if len(m.tasks) == 0 || m.selected < 0 || m.selected >= len(m.tasks) { + return types.TaskDefinition{}, false + } + return m.tasks[m.selected], true +} + +func (m *Model) openWizard(existing *types.TaskDefinition) { + state := newWizardState(existing, m.view, m.config) + m.wizard = state + m.view = viewWizard + m.refreshWizardInput() +} + +func newWizardState(existing *types.TaskDefinition, from viewState, cfg *config.Config) *wizardState { + input := textinput.New() + input.Width = 72 + input.Prompt = "" + values := map[string]string{ + "name": "", + "endpoint": "", + "apiKey": "", + "model": "", + "concurrency": "5", + "count": "100", + "timeout": "30s", + "turbo_init": "1", + "turbo_max": "50", + "turbo_step": "2", + "turbo_level_requests": "30", + "turbo_min_success": "0.9", + "turbo_max_latency": "10s", + "prompt_value": "你好,介绍一下你自己。", + } + state := &wizardState{ + fromView: from, + input: input, + values: values, + protocolIndex: protocolIndex(cfg.DefaultProtocol), + mode: modeStandard, + promptModeIndex: 0, + stream: true, + thinking: false, + report: true, + } + if existing != nil { + state.editingTaskID = existing.ID + state.createdAt = existing.CreatedAt + state.lastRunAt = existing.LastRunAt + state.lastRunSummary = existing.LastRunSummary + state.values["name"] = existing.Name + state.values["endpoint"] = existing.Input.ResolvedEndpointURL() + state.values["apiKey"] = existing.Input.ApiKey + state.values["model"] = existing.Input.Model + state.protocolIndex = protocolIndex(existing.Input.NormalizedProtocol()) + state.stream = existing.Input.Stream + state.thinking = existing.Input.Thinking + state.report = existing.Input.Report + if existing.Input.Turbo { + state.mode = modeTurbo + state.values["turbo_init"] = strconv.Itoa(existing.Input.TurboConfig.InitConcurrency) + state.values["turbo_max"] = strconv.Itoa(existing.Input.TurboConfig.MaxConcurrency) + state.values["turbo_step"] = strconv.Itoa(existing.Input.TurboConfig.StepSize) + state.values["turbo_level_requests"] = strconv.Itoa(existing.Input.TurboConfig.LevelRequests) + state.values["turbo_min_success"] = strconv.FormatFloat(existing.Input.TurboConfig.MinSuccessRate, 'f', -1, 64) + state.values["turbo_max_latency"] = existing.Input.TurboConfig.MaxLatency.String() + } else { + state.values["concurrency"] = strconv.Itoa(existing.Input.Concurrency) + state.values["count"] = strconv.Itoa(existing.Input.Count) + if existing.Input.Timeout > 0 { + state.values["timeout"] = existing.Input.Timeout.String() + } + } + switch existing.Input.PromptMode { + case promptModeFile: + state.promptModeIndex = 1 + state.values["prompt_value"] = existing.Input.PromptFile + case promptModeGenerated: + state.promptModeIndex = 2 + state.values["prompt_value"] = strconv.Itoa(existing.Input.PromptLength) + default: + state.promptModeIndex = 0 + state.values["prompt_value"] = existing.Input.PromptText + } + } + return state +} + +func protocolIndex(protocol string) int { + for i, item := range protocolOptions { + if item == types.NormalizeProtocol(protocol) { + return i + } + } + return 0 +} + +func (m *Model) wizardFields() []wizardField { + fields := []wizardField{ + {key: "name", label: "任务名称", kind: fieldText}, + {key: "protocol", label: "协议类型", kind: fieldSelect}, + {key: "endpoint", label: "完整接口地址", kind: fieldText}, + {key: "apiKey", label: "API 密钥", kind: fieldText}, + {key: "model", label: "测试模型", kind: fieldText}, + {key: "mode", label: "运行模式", kind: fieldSelect}, + } + if m.wizard.mode == modeTurbo { + fields = append(fields, + wizardField{key: "turbo_init", label: "初始并发", kind: fieldText}, + wizardField{key: "turbo_max", label: "最大并发", kind: fieldText}, + wizardField{key: "turbo_step", label: "步进值", kind: fieldText}, + wizardField{key: "turbo_level_requests", label: "每级请求数", kind: fieldText}, + wizardField{key: "turbo_min_success", label: "最小成功率", kind: fieldText}, + wizardField{key: "turbo_max_latency", label: "最大平均延迟", kind: fieldText}, + ) + } else { + fields = append(fields, + wizardField{key: "concurrency", label: "并发数", kind: fieldText}, + wizardField{key: "count", label: "请求总数", kind: fieldText}, + wizardField{key: "timeout", label: "超时时间", kind: fieldText}, + ) + } + fields = append(fields, + wizardField{key: "stream", label: "流式模式", kind: fieldToggle}, + wizardField{key: "thinking", label: "Thinking 模式", kind: fieldToggle}, + wizardField{key: "report", label: "生成报告", kind: fieldToggle}, + wizardField{key: "prompt_mode", label: "Prompt 输入方式", kind: fieldSelect}, + wizardField{key: "prompt_value", label: promptValueLabel(m.wizard.promptModeIndex), kind: fieldText}, + ) + return fields +} + +func promptValueLabel(promptModeIndex int) string { + switch promptModeOptions[promptModeIndex] { + case promptModeFile: + return "Prompt 文件路径" + case promptModeGenerated: + return "Prompt 生成长度" + default: + return "Prompt 文本" + } +} + +func (m *Model) currentWizardField() wizardField { + return m.wizardFields()[m.wizard.current] +} + +func (m *Model) refreshWizardInput() { + field := m.currentWizardField() + m.wizard.input.Blur() + m.wizard.input.Focus() + m.wizard.input.EchoMode = textinput.EchoNormal + if field.key == "apiKey" { + m.wizard.input.EchoMode = textinput.EchoPassword + } + m.wizard.input.SetValue(m.wizard.values[field.key]) + if field.key == "prompt_value" { + m.wizard.input.Placeholder = promptValueLabel(m.wizard.promptModeIndex) + } else { + m.wizard.input.Placeholder = field.label + } + if field.kind != fieldText { + m.wizard.input.SetValue("") + } +} + +func (m *Model) cycleWizardField(delta int) { + field := m.currentWizardField() + switch field.key { + case "protocol": + m.wizard.protocolIndex = wrapIndex(m.wizard.protocolIndex+delta, len(protocolOptions)) + case "mode": + if m.wizard.mode == modeStandard { + m.wizard.mode = modeTurbo + } else { + m.wizard.mode = modeStandard + } + case "prompt_mode": + m.wizard.promptModeIndex = wrapIndex(m.wizard.promptModeIndex+delta, len(promptModeOptions)) + m.wizard.values["prompt_value"] = "" + case "stream": + m.wizard.stream = !m.wizard.stream + case "thinking": + m.wizard.thinking = !m.wizard.thinking + case "report": + m.wizard.report = !m.wizard.report + default: + return + } + m.wizard.current = min(m.wizard.current, len(m.wizardFields())-1) + m.refreshWizardInput() +} + +func (m *Model) displayWizardValue(field wizardField) string { + switch field.key { + case "protocol": + return protocolOptions[m.wizard.protocolIndex] + case "mode": + return m.wizard.mode + case "stream": + return boolLabel(m.wizard.stream) + case "thinking": + return boolLabel(m.wizard.thinking) + case "report": + return boolLabel(m.wizard.report) + case "prompt_mode": + return promptModeOptions[m.wizard.promptModeIndex] + default: + return m.wizard.values[field.key] + } +} + +func boolLabel(v bool) string { + if v { + return "开启" + } + return "关闭" +} + +func wrapIndex(index, length int) int { + if length == 0 { + return 0 + } + for index < 0 { + index += length + } + return index % length +} + +func min(a, b int) int { + if a < b { + return a + } + return b +} + +func buildTaskDefinition(state *wizardState) (types.TaskDefinition, error) { + protocol := protocolOptions[state.protocolIndex] + input := types.Input{ + Protocol: protocol, + EndpointURL: strings.TrimSpace(state.values["endpoint"]), + ApiKey: strings.TrimSpace(state.values["apiKey"]), + Model: strings.TrimSpace(state.values["model"]), + Stream: state.stream, + Thinking: state.thinking, + Report: state.report, + PromptMode: promptModeOptions[state.promptModeIndex], + } + + switch input.PromptMode { + case promptModeFile: + input.PromptFile = strings.TrimSpace(state.values["prompt_value"]) + case promptModeGenerated: + length, err := strconv.Atoi(strings.TrimSpace(state.values["prompt_value"])) + if err != nil { + return types.TaskDefinition{}, fmt.Errorf("invalid prompt length: %w", err) + } + input.PromptLength = length + default: + input.PromptText = state.values["prompt_value"] + } + + if state.mode == modeTurbo { + initConcurrency, err := strconv.Atoi(strings.TrimSpace(state.values["turbo_init"])) + if err != nil { + return types.TaskDefinition{}, fmt.Errorf("invalid turbo init concurrency: %w", err) + } + maxConcurrency, err := strconv.Atoi(strings.TrimSpace(state.values["turbo_max"])) + if err != nil { + return types.TaskDefinition{}, fmt.Errorf("invalid turbo max concurrency: %w", err) + } + stepSize, err := strconv.Atoi(strings.TrimSpace(state.values["turbo_step"])) + if err != nil { + return types.TaskDefinition{}, fmt.Errorf("invalid turbo step size: %w", err) + } + levelRequests, err := strconv.Atoi(strings.TrimSpace(state.values["turbo_level_requests"])) + if err != nil { + return types.TaskDefinition{}, fmt.Errorf("invalid turbo level requests: %w", err) + } + minSuccessRate, err := strconv.ParseFloat(strings.TrimSpace(state.values["turbo_min_success"]), 64) + if err != nil { + return types.TaskDefinition{}, fmt.Errorf("invalid turbo min success rate: %w", err) + } + maxLatency, err := time.ParseDuration(strings.TrimSpace(state.values["turbo_max_latency"])) + if err != nil { + return types.TaskDefinition{}, fmt.Errorf("invalid turbo max latency: %w", err) + } + input.Turbo = true + input.Count = levelRequests + input.Concurrency = initConcurrency + input.TurboConfig = types.TurboConfig{ + InitConcurrency: initConcurrency, + MaxConcurrency: maxConcurrency, + StepSize: stepSize, + LevelRequests: levelRequests, + MinSuccessRate: minSuccessRate, + MaxLatency: maxLatency, + } + } else { + concurrency, err := strconv.Atoi(strings.TrimSpace(state.values["concurrency"])) + if err != nil { + return types.TaskDefinition{}, fmt.Errorf("invalid concurrency: %w", err) + } + count, err := strconv.Atoi(strings.TrimSpace(state.values["count"])) + if err != nil { + return types.TaskDefinition{}, fmt.Errorf("invalid count: %w", err) + } + timeout, err := time.ParseDuration(strings.TrimSpace(state.values["timeout"])) + if err != nil { + return types.TaskDefinition{}, fmt.Errorf("invalid timeout: %w", err) + } + input.Concurrency = concurrency + input.Count = count + input.Timeout = timeout + } + + validatedInput, err := task.HydrateInput(input) + if err != nil { + return types.TaskDefinition{}, err + } + validatedInput.PromptSource = nil + + now := time.Now() + createdAt := state.createdAt + if createdAt.IsZero() { + createdAt = now + } + + return types.TaskDefinition{ + ID: state.editingTaskID, + Name: strings.TrimSpace(state.values["name"]), + Input: validatedInput, + CreatedAt: createdAt, + UpdatedAt: now, + LastRunAt: state.lastRunAt, + LastRunSummary: state.lastRunSummary, + }, nil +} + +func (m *Model) saveWizard() error { + taskDef, err := buildTaskDefinition(m.wizard) + if err != nil { + return err + } + if taskDef.ID == "" { + taskDef.ID = fmt.Sprintf("task_%d", time.Now().UnixNano()) + } + m.store.Upsert(taskDef) + if err := m.store.Save(); err != nil { + return err + } + m.tasks = m.store.Tasks + for i, item := range m.tasks { + if item.ID == taskDef.ID { + m.selected = i + break + } + } + m.status = "任务已保存" + m.wizard = nil + m.view = viewTaskList + return nil +} + +func (m *Model) reloadHistoryForSelectedTask() { + taskDef, ok := m.currentTask() + if !ok { + m.history = nil + return + } + history, err := task.LoadHistory(taskDef.ID, 5) + if err != nil { + m.err = err + m.history = nil + return + } + m.history = history +} + +func (m *Model) startTaskRun(taskDef types.TaskDefinition) { + input, err := task.HydrateInput(taskDef.Input) + if err != nil { + m.err = err + return + } + m.runningTask = &taskDef + m.runStartedAt = time.Now() + m.progress = types.StatsData{} + m.runResult = nil + m.turboResult = nil + m.view = viewDashboard + + if input.Turbo { + engine := turbo.New(turbo.DefaultRunnerFactory(taskDef.ID)) + m.activeTurbo = engine + go func() { + result, err := engine.Run(input) + if err != nil { + m.program.Send(asyncErrorMsg{err: err}) + return + } + m.program.Send(turboCompleteMsg{taskID: taskDef.ID, result: result}) + }() + return + } + + runnerInstance, err := runner.NewRunner(taskDef.ID, input) + if err != nil { + m.err = err + return + } + m.activeRunner = runnerInstance + go func() { + result, err := runnerInstance.RunWithProgress(func(stats types.StatsData) { + m.program.Send(progressMsg{stats: stats}) + }) + if err != nil { + m.program.Send(asyncErrorMsg{err: err}) + return + } + paths, err := generateReports(result, input.Report) + if err != nil { + m.program.Send(asyncErrorMsg{err: err}) + return + } + m.program.Send(runCompleteMsg{taskID: taskDef.ID, result: result, reportPaths: paths}) + }() +} + +func generateReports(result *types.ReportData, enabled bool) ([]string, error) { + if !enabled || result == nil { + return nil, nil + } + manager := report.NewReportManager() + return manager.GenerateReports([]types.ReportData{*result}, []string{"json", "csv"}) +} + +func (m *Model) persistStandardRun(taskID string, result *types.ReportData, reportPaths []string) { + taskDef, ok := m.store.Get(taskID) + if !ok { + return + } + finishedAt := time.Now() + summary := &types.TaskRunSummary{ + RunID: fmt.Sprintf("run_%d", finishedAt.UnixNano()), + TaskID: taskID, + Mode: modeStandard, + Status: "completed", + Protocol: result.Protocol, + Model: result.Model, + StartedAt: m.runStartedAt, + FinishedAt: finishedAt, + SuccessRate: result.SuccessRate, + AvgTTFT: result.AvgTTFT, + AvgTPS: result.AvgTPS, + CacheHitRate: result.AvgCacheHitRate * 100, + } + for _, path := range reportPaths { + switch filepath.Ext(path) { + case ".json": + summary.ReportJSONPath = path + case ".csv": + summary.ReportCSVPath = path + } + } + taskDef.LastRunAt = &finishedAt + taskDef.LastRunSummary = summary + m.store.Upsert(taskDef) + _ = m.store.Save() + _ = task.AppendRun(taskID, *summary) + m.tasks = m.store.Tasks + m.reloadHistoryForSelectedTask() +} + +func (m *Model) persistTurboRun(taskID string, result *types.TurboResult) { + taskDef, ok := m.store.Get(taskID) + if !ok { + return + } + finishedAt := time.Now() + latestSuccessRate := 0.0 + latestCacheHitRate := 0.0 + if len(result.Levels) > 0 { + lastLevel := result.Levels[len(result.Levels)-1] + latestSuccessRate = lastLevel.SuccessRate * 100 + latestCacheHitRate = lastLevel.CacheHitRate * 100 + } + summary := &types.TaskRunSummary{ + RunID: fmt.Sprintf("run_%d", finishedAt.UnixNano()), + TaskID: taskID, + Mode: modeTurbo, + Status: result.StopReason, + Protocol: result.Protocol, + Model: result.Model, + StartedAt: m.runStartedAt, + FinishedAt: finishedAt, + SuccessRate: latestSuccessRate, + AvgTPS: result.PeakTPS, + CacheHitRate: latestCacheHitRate, + MaxStableConcurrency: result.MaxStableConcurrency, + } + taskDef.LastRunAt = &finishedAt + taskDef.LastRunSummary = summary + m.store.Upsert(taskDef) + _ = m.store.Save() + _ = task.AppendRun(taskID, *summary) + m.tasks = m.store.Tasks + m.reloadHistoryForSelectedTask() +} + +func promptSummary(input types.Input) string { + switch input.PromptMode { + case promptModeFile: + return input.PromptFile + case promptModeGenerated: + return fmt.Sprintf("长度 %d", input.PromptLength) + default: + if len(input.PromptText) > 48 { + return input.PromptText[:48] + "..." + } + return input.PromptText + } +} diff --git a/internal/tui/model_test.go b/internal/tui/model_test.go new file mode 100644 index 0000000..df03c10 --- /dev/null +++ b/internal/tui/model_test.go @@ -0,0 +1,73 @@ +package tui + +import ( + "testing" + "time" + + "github.com/yinxulai/ait/internal/config" + "github.com/yinxulai/ait/internal/types" +) + +func TestBuildTaskDefinitionStandardMode(t *testing.T) { + state := newWizardState(nil, viewTaskList, &config.Config{DefaultProtocol: types.ProtocolOpenAIResponses}) + state.values["name"] = "nightly-openai" + state.values["endpoint"] = "https://api.openai.com/v1/responses" + state.values["apiKey"] = "sk-test" + state.values["model"] = "gpt-4.1" + state.values["concurrency"] = "8" + state.values["count"] = "120" + state.values["timeout"] = "45s" + state.values["prompt_value"] = "hello" + + taskDef, err := buildTaskDefinition(state) + if err != nil { + t.Fatalf("buildTaskDefinition() returned unexpected error: %v", err) + } + if taskDef.Input.Protocol != types.ProtocolOpenAIResponses { + t.Fatalf("expected protocol %s, got %s", types.ProtocolOpenAIResponses, taskDef.Input.Protocol) + } + if taskDef.Input.EndpointURL != "https://api.openai.com/v1/responses" { + t.Fatalf("unexpected endpoint: %s", taskDef.Input.EndpointURL) + } + if taskDef.Input.Concurrency != 8 || taskDef.Input.Count != 120 || taskDef.Input.Timeout != 45*time.Second { + t.Fatalf("unexpected standard input fields: %+v", taskDef.Input) + } + if taskDef.Input.PromptMode != promptModeText || taskDef.Input.PromptText != "hello" { + t.Fatalf("unexpected prompt fields: %+v", taskDef.Input) + } +} + +func TestBuildTaskDefinitionTurboMode(t *testing.T) { + state := newWizardState(nil, viewTaskList, &config.Config{}) + state.mode = modeTurbo + state.protocolIndex = 2 + state.promptModeIndex = 2 + state.values["name"] = "turbo-anthropic" + state.values["endpoint"] = "https://api.anthropic.com/v1/messages" + state.values["apiKey"] = "sk-ant" + state.values["model"] = "claude-3-7-sonnet" + state.values["turbo_init"] = "1" + state.values["turbo_max"] = "12" + state.values["turbo_step"] = "2" + state.values["turbo_level_requests"] = "20" + state.values["turbo_min_success"] = "0.92" + state.values["turbo_max_latency"] = "6s" + state.values["prompt_value"] = "256" + + taskDef, err := buildTaskDefinition(state) + if err != nil { + t.Fatalf("buildTaskDefinition() returned unexpected error: %v", err) + } + if !taskDef.Input.Turbo { + t.Fatal("expected Turbo to be enabled") + } + if taskDef.Input.TurboConfig.MaxConcurrency != 12 || taskDef.Input.TurboConfig.MaxLatency != 6*time.Second { + t.Fatalf("unexpected turbo config: %+v", taskDef.Input.TurboConfig) + } + if taskDef.Input.PromptMode != promptModeGenerated || taskDef.Input.PromptLength != 256 { + t.Fatalf("unexpected generated prompt config: %+v", taskDef.Input) + } + if taskDef.Input.Protocol != types.ProtocolAnthropicMessages { + t.Fatalf("expected anthropic protocol, got %s", taskDef.Input.Protocol) + } +} diff --git a/internal/tui/styles.go b/internal/tui/styles.go new file mode 100644 index 0000000..9a69807 --- /dev/null +++ b/internal/tui/styles.go @@ -0,0 +1,32 @@ +package tui + +import "github.com/charmbracelet/lipgloss" + +type styles struct { + title lipgloss.Style + subtitle lipgloss.Style + panel lipgloss.Style + selected lipgloss.Style + muted lipgloss.Style + error lipgloss.Style + ok lipgloss.Style + key lipgloss.Style + label lipgloss.Style + value lipgloss.Style +} + +func newStyles() styles { + border := lipgloss.RoundedBorder() + return styles{ + title: lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("212")).Padding(0, 1), + subtitle: lipgloss.NewStyle().Foreground(lipgloss.Color("110")).Padding(0, 1), + panel: lipgloss.NewStyle().Border(border).BorderForeground(lipgloss.Color("62")).Padding(1, 2), + selected: lipgloss.NewStyle().Foreground(lipgloss.Color("230")).Background(lipgloss.Color("62")).Bold(true), + muted: lipgloss.NewStyle().Foreground(lipgloss.Color("245")), + error: lipgloss.NewStyle().Foreground(lipgloss.Color("203")).Bold(true), + ok: lipgloss.NewStyle().Foreground(lipgloss.Color("85")).Bold(true), + key: lipgloss.NewStyle().Foreground(lipgloss.Color("214")).Bold(true), + label: lipgloss.NewStyle().Foreground(lipgloss.Color("151")).Bold(true), + value: lipgloss.NewStyle().Foreground(lipgloss.Color("255")), + } +} diff --git a/internal/turbo/engine.go b/internal/turbo/engine.go new file mode 100644 index 0000000..e13bc15 --- /dev/null +++ b/internal/turbo/engine.go @@ -0,0 +1,195 @@ +package turbo + +import ( + "fmt" + "sync" + "time" + + runnerpkg "github.com/yinxulai/ait/internal/runner" + "github.com/yinxulai/ait/internal/types" +) + +const ( + StopReasonLowSuccessRate = "low_success_rate" + StopReasonHighLatency = "high_latency" + StopReasonMaxConcurrency = "max_concurrency" + StopReasonManual = "manual" +) + +type LevelRunner interface { + Run() (*types.ReportData, error) + Stop() +} + +type RunnerFactory func(input types.Input) (LevelRunner, error) + +type Engine struct { + runnerFactory RunnerFactory + now func() time.Time + mu sync.Mutex + currentRunner LevelRunner + stopCh chan struct{} + stopOnce sync.Once +} + +func New(factory RunnerFactory) *Engine { + return &Engine{ + runnerFactory: factory, + now: time.Now, + stopCh: make(chan struct{}), + } +} + +func DefaultRunnerFactory(taskID string) RunnerFactory { + return func(input types.Input) (LevelRunner, error) { + return runnerpkg.NewRunner(taskID, input) + } +} + +func (e *Engine) Stop() { + e.stopOnce.Do(func() { + close(e.stopCh) + }) + e.mu.Lock() + runner := e.currentRunner + e.mu.Unlock() + if runner != nil { + runner.Stop() + } +} + +func (e *Engine) Run(input types.Input) (*types.TurboResult, error) { + if e.runnerFactory == nil { + return nil, fmt.Errorf("turbo runnerFactory is required") + } + + cfg := normalizeConfig(input.TurboConfig, input.Count) + startedAt := e.now() + result := &types.TurboResult{ + Config: cfg, + Levels: []types.TurboLevelResult{}, + Model: input.Model, + Protocol: input.NormalizedProtocol(), + EndpointURL: input.ResolvedEndpointURL(), + Timestamp: startedAt.Format(time.RFC3339), + } + + for concurrency := cfg.InitConcurrency; concurrency <= cfg.MaxConcurrency; concurrency += cfg.StepSize { + select { + case <-e.stopCh: + result.StopReason = StopReasonManual + result.ProbeDuration = time.Since(startedAt) + return result, nil + default: + } + + levelInput := input + levelInput.Turbo = false + levelInput.Concurrency = concurrency + levelInput.Count = cfg.LevelRequests + + levelRunner, err := e.runnerFactory(levelInput) + if err != nil { + return nil, err + } + + e.mu.Lock() + e.currentRunner = levelRunner + e.mu.Unlock() + + report, err := levelRunner.Run() + + e.mu.Lock() + e.currentRunner = nil + e.mu.Unlock() + + if err != nil { + return nil, err + } + + level := buildLevelResult(report, concurrency) + result.Levels = append(result.Levels, level) + + select { + case <-e.stopCh: + result.StopReason = StopReasonManual + result.ProbeDuration = time.Since(startedAt) + return result, nil + default: + } + + if level.SuccessRate < cfg.MinSuccessRate { + result.Levels[len(result.Levels)-1].Stable = false + result.Levels[len(result.Levels)-1].StopReason = StopReasonLowSuccessRate + result.StopReason = StopReasonLowSuccessRate + break + } + if level.AvgTotalTime > cfg.MaxLatency { + result.Levels[len(result.Levels)-1].Stable = false + result.Levels[len(result.Levels)-1].StopReason = StopReasonHighLatency + result.StopReason = StopReasonHighLatency + break + } + + result.MaxStableConcurrency = concurrency + if level.AvgTPS > result.PeakTPS { + result.PeakTPS = level.AvgTPS + } + + if concurrency+cfg.StepSize > cfg.MaxConcurrency { + result.StopReason = StopReasonMaxConcurrency + } + } + + if result.StopReason == "" { + result.StopReason = StopReasonMaxConcurrency + } + result.ProbeDuration = time.Since(startedAt) + return result, nil +} + +func buildLevelResult(report *types.ReportData, concurrency int) types.TurboLevelResult { + successCount := int(float64(report.TotalRequests) * report.SuccessRate / 100) + return types.TurboLevelResult{ + Concurrency: concurrency, + TotalRequests: report.TotalRequests, + SuccessCount: successCount, + SuccessRate: report.SuccessRate / 100, + AvgTPS: report.AvgTPS, + PeakTPS: report.MaxTPS, + AvgTTFT: report.AvgTTFT, + CacheHitRate: report.AvgCacheHitRate, + AvgTotalTime: report.AvgTotalTime, + StdDevTPS: report.StdDevTPS, + Stable: true, + } +} + +func normalizeConfig(cfg types.TurboConfig, fallbackLevelRequests int) types.TurboConfig { + if cfg.InitConcurrency <= 0 { + cfg.InitConcurrency = 1 + } + if cfg.MaxConcurrency <= 0 { + cfg.MaxConcurrency = 50 + } + if cfg.MaxConcurrency < cfg.InitConcurrency { + cfg.MaxConcurrency = cfg.InitConcurrency + } + if cfg.StepSize <= 0 { + cfg.StepSize = 2 + } + if cfg.LevelRequests <= 0 { + if fallbackLevelRequests > 0 { + cfg.LevelRequests = fallbackLevelRequests + } else { + cfg.LevelRequests = 30 + } + } + if cfg.MinSuccessRate <= 0 { + cfg.MinSuccessRate = 0.9 + } + if cfg.MaxLatency <= 0 { + cfg.MaxLatency = 10 * time.Second + } + return cfg +} diff --git a/internal/turbo/engine_test.go b/internal/turbo/engine_test.go new file mode 100644 index 0000000..52c9ae8 --- /dev/null +++ b/internal/turbo/engine_test.go @@ -0,0 +1,147 @@ +package turbo + +import ( + "sync/atomic" + "testing" + "time" + + "github.com/yinxulai/ait/internal/types" +) + +type fakeRunner struct { + report *types.ReportData + err error + stopCalled *atomic.Bool + blockUntil <-chan struct{} +} + +func (f *fakeRunner) Run() (*types.ReportData, error) { + if f.blockUntil != nil { + <-f.blockUntil + } + return f.report, f.err +} + +func (f *fakeRunner) Stop() { + if f.stopCalled != nil { + f.stopCalled.Store(true) + } +} + +func TestEngineRunStopsOnLowSuccessRate(t *testing.T) { + levels := map[int]*types.ReportData{ + 1: {TotalRequests: 10, SuccessRate: 100, AvgTPS: 10, MaxTPS: 12, AvgTTFT: 50 * time.Millisecond, AvgTotalTime: 200 * time.Millisecond}, + 2: {TotalRequests: 10, SuccessRate: 95, AvgTPS: 20, MaxTPS: 22, AvgTTFT: 60 * time.Millisecond, AvgTotalTime: 300 * time.Millisecond}, + 3: {TotalRequests: 10, SuccessRate: 80, AvgTPS: 18, MaxTPS: 20, AvgTTFT: 80 * time.Millisecond, AvgTotalTime: 400 * time.Millisecond}, + } + engine := New(func(input types.Input) (LevelRunner, error) { + return &fakeRunner{report: levels[input.Concurrency]}, nil + }) + + result, err := engine.Run(types.Input{ + Protocol: types.ProtocolOpenAIResponses, + EndpointURL: "https://api.openai.com/v1/responses", + Model: "gpt-4.1", + Count: 30, + TurboConfig: types.TurboConfig{InitConcurrency: 1, MaxConcurrency: 5, StepSize: 1, LevelRequests: 10, MinSuccessRate: 0.9, MaxLatency: 10 * time.Second}, + }) + if err != nil { + t.Fatalf("Run() returned unexpected error: %v", err) + } + if len(result.Levels) != 3 { + t.Fatalf("expected 3 levels, got %d", len(result.Levels)) + } + if result.MaxStableConcurrency != 2 { + t.Fatalf("expected MaxStableConcurrency 2, got %d", result.MaxStableConcurrency) + } + if result.StopReason != StopReasonLowSuccessRate { + t.Fatalf("expected stop reason %s, got %s", StopReasonLowSuccessRate, result.StopReason) + } + if result.PeakTPS != 20 { + t.Fatalf("expected PeakTPS 20, got %f", result.PeakTPS) + } + if result.Levels[2].Stable { + t.Fatal("expected last level to be marked unstable") + } + if result.Levels[2].StopReason != StopReasonLowSuccessRate { + t.Fatalf("expected level stop reason %s, got %s", StopReasonLowSuccessRate, result.Levels[2].StopReason) + } +} + +func TestEngineRunStopsOnHighLatency(t *testing.T) { + engine := New(func(input types.Input) (LevelRunner, error) { + report := &types.ReportData{TotalRequests: 10, SuccessRate: 100, AvgTPS: 10, MaxTPS: 15, AvgTTFT: 80 * time.Millisecond, AvgTotalTime: time.Duration(input.Concurrency) * time.Second} + return &fakeRunner{report: report}, nil + }) + + result, err := engine.Run(types.Input{ + Protocol: types.ProtocolAnthropicMessages, + EndpointURL: "https://api.anthropic.com/v1/messages", + Model: "claude-3-7-sonnet", + Count: 20, + TurboConfig: types.TurboConfig{InitConcurrency: 1, MaxConcurrency: 5, StepSize: 1, LevelRequests: 10, MinSuccessRate: 0.9, MaxLatency: 2 * time.Second}, + }) + if err != nil { + t.Fatalf("Run() returned unexpected error: %v", err) + } + if result.StopReason != StopReasonHighLatency { + t.Fatalf("expected stop reason %s, got %s", StopReasonHighLatency, result.StopReason) + } + if result.MaxStableConcurrency != 2 { + t.Fatalf("expected MaxStableConcurrency 2, got %d", result.MaxStableConcurrency) + } +} + +func TestEngineStopPropagatesToActiveRunner(t *testing.T) { + stopCalled := &atomic.Bool{} + blocker := make(chan struct{}) + engine := New(func(input types.Input) (LevelRunner, error) { + return &fakeRunner{ + report: &types.ReportData{TotalRequests: 1, SuccessRate: 100, AvgTPS: 1, MaxTPS: 1, AvgTotalTime: 10 * time.Millisecond}, + stopCalled: stopCalled, + blockUntil: blocker, + }, nil + }) + + resultCh := make(chan *types.TurboResult, 1) + errCh := make(chan error, 1) + go func() { + result, err := engine.Run(types.Input{ + Protocol: types.ProtocolOpenAICompletions, + EndpointURL: "https://api.openai.com/v1/chat/completions", + Model: "gpt-4.1-mini", + Count: 10, + TurboConfig: types.TurboConfig{InitConcurrency: 1, MaxConcurrency: 3, StepSize: 1, LevelRequests: 1, MinSuccessRate: 0.9, MaxLatency: time.Second}, + }) + resultCh <- result + errCh <- err + }() + + time.Sleep(30 * time.Millisecond) + engine.Stop() + close(blocker) + + result := <-resultCh + if err := <-errCh; err != nil { + t.Fatalf("Run() returned unexpected error: %v", err) + } + if !stopCalled.Load() { + t.Fatal("expected active runner Stop() to be called") + } + if result.StopReason != StopReasonManual { + t.Fatalf("expected stop reason %s, got %s", StopReasonManual, result.StopReason) + } +} + +func TestNormalizeConfigUsesDefaults(t *testing.T) { + cfg := normalizeConfig(types.TurboConfig{}, 12) + if cfg.InitConcurrency != 1 || cfg.MaxConcurrency != 50 || cfg.StepSize != 2 { + t.Fatalf("unexpected concurrency defaults: %+v", cfg) + } + if cfg.LevelRequests != 12 { + t.Fatalf("expected fallback LevelRequests 12, got %d", cfg.LevelRequests) + } + if cfg.MinSuccessRate != 0.9 || cfg.MaxLatency != 10*time.Second { + t.Fatalf("unexpected threshold defaults: %+v", cfg) + } +} diff --git a/internal/types/types.go b/internal/types/types.go index e4d56da..a4d7836 100644 --- a/internal/types/types.go +++ b/internal/types/types.go @@ -2,9 +2,80 @@ package types import ( "encoding/json" + "strings" "time" ) +const ( + ProtocolOpenAICompletions = "openai-completions" + ProtocolOpenAIResponses = "openai-responses" + ProtocolAnthropicMessages = "anthropic-messages" +) + +func NormalizeProtocol(protocol string) string { + switch strings.ToLower(strings.TrimSpace(protocol)) { + case "", "openai", ProtocolOpenAICompletions: + return ProtocolOpenAICompletions + case ProtocolOpenAIResponses: + return ProtocolOpenAIResponses + case "anthropic", ProtocolAnthropicMessages: + return ProtocolAnthropicMessages + default: + return strings.TrimSpace(protocol) + } +} + +func DefaultEndpointURL(protocol string) string { + switch NormalizeProtocol(protocol) { + case ProtocolOpenAICompletions: + return "https://api.openai.com/v1/chat/completions" + case ProtocolOpenAIResponses: + return "https://api.openai.com/v1/responses" + case ProtocolAnthropicMessages: + return "https://api.anthropic.com/v1/messages" + default: + return "" + } +} + +func ResolveEndpointURL(protocol, endpointURL, baseURL string) string { + resolved := strings.TrimSpace(endpointURL) + if resolved != "" { + return resolved + } + + resolved = strings.TrimRight(strings.TrimSpace(baseURL), "/") + if resolved == "" { + return DefaultEndpointURL(protocol) + } + + switch NormalizeProtocol(protocol) { + case ProtocolOpenAICompletions: + if strings.HasSuffix(resolved, "/chat/completions") { + return resolved + } + if strings.HasSuffix(resolved, "/v1") { + return resolved + "/chat/completions" + } + return resolved + "/v1/chat/completions" + case ProtocolOpenAIResponses: + if strings.HasSuffix(resolved, "/responses") { + return resolved + } + if strings.HasSuffix(resolved, "/v1") { + return resolved + "/responses" + } + return resolved + "/v1/responses" + case ProtocolAnthropicMessages: + if strings.HasSuffix(resolved, "/v1/messages") { + return resolved + } + return resolved + "/v1/messages" + default: + return resolved + } +} + // PromptSource 需要前向声明,实际定义在 prompt 包中 type PromptSource interface { GetRandomContent() string @@ -14,18 +85,33 @@ type PromptSource interface { // Input 测试配置信息 - 统一的配置结构 type Input struct { - Protocol string - BaseUrl string - ApiKey string - Model string // 多个模型列表 - Concurrency int - Count int - Stream bool - Thinking bool // 是否开启 thinking 模式(仅支持 OpenAI 协议) - PromptSource PromptSource // 改为使用PromptSource而不是简单字符串 - Report bool // 是否生成报告文件 - Timeout time.Duration // 请求超时时间 - Log bool // 是否开启详细日志记录 + Protocol string `json:"protocol"` + EndpointURL string `json:"endpoint_url,omitempty"` + BaseUrl string `json:"base_url,omitempty"` + ApiKey string `json:"api_key,omitempty"` + Model string `json:"model"` + Concurrency int `json:"concurrency,omitempty"` + Count int `json:"count,omitempty"` + Stream bool `json:"stream,omitempty"` + Thinking bool `json:"thinking,omitempty"` // 是否开启 thinking 模式(仅支持 OpenAI 协议) + Turbo bool `json:"turbo,omitempty"` // 是否启用 Turbo 模式 + TurboConfig TurboConfig `json:"turbo_config,omitempty"` // Turbo 模式配置 + PromptMode string `json:"prompt_mode,omitempty"` + PromptText string `json:"prompt_text,omitempty"` + PromptFile string `json:"prompt_file,omitempty"` + PromptLength int `json:"prompt_length,omitempty"` + PromptSource PromptSource `json:"-"` // 运行态字段,不直接持久化 + Report bool `json:"report,omitempty"` // 是否生成报告文件 + Timeout time.Duration `json:"timeout,omitempty"` // 请求超时时间 + Log bool `json:"log,omitempty"` // 是否开启详细日志记录 +} + +func (i Input) NormalizedProtocol() string { + return NormalizeProtocol(i.Protocol) +} + +func (i Input) ResolvedEndpointURL() string { + return ResolveEndpointURL(i.Protocol, i.EndpointURL, i.BaseUrl) } // StatsData 实时测试统计数据 - runner 内部使用的统计结构 @@ -45,9 +131,11 @@ type StatsData struct { TLSHandshakeTimes []time.Duration // 所有TLS握手时间 // 服务性能指标 - 原始数据收集(与 ReportData 命名对齐) - InputTokenCounts []int // 所有 prompt/input token 数量 - OutputTokenCounts []int // 所有 completion/output token 数量 (用于TPS计算) + InputTokenCounts []int // 所有 prompt/input token 数量 + CachedInputTokenCounts []int // 所有缓存命中的输入 token 数量 + OutputTokenCounts []int // 所有 completion/output token 数量 (用于TPS计算) ThinkingTokenCounts []int // 所有思考/推理 token 数量 + CacheHitRates []float64 // 所有请求的缓存命中率 // 错误信息 ErrorMessages []string // 所有错误信息 @@ -72,6 +160,7 @@ type ReportData struct { Timestamp string `json:"timestamp"` // 测试时间戳 Protocol string `json:"protocol"` // 协议类型 Model string `json:"model"` // 模型名称 + EndpointURL string `json:"endpoint_url,omitempty"` // 完整接口地址 BaseUrl string `json:"base_url"` // 基础URL // 时间性能指标 - 统计结果 @@ -101,12 +190,18 @@ type ReportData struct { AvgInputTokenCount int `json:"avg_input_token_count"` // 平均输入token数量 MinInputTokenCount int `json:"min_input_token_count"` // 最小输入token数量 MaxInputTokenCount int `json:"max_input_token_count"` // 最大输入token数量 + AvgCachedInputTokenCount int `json:"avg_cached_input_token_count"` // 平均缓存命中的输入 token 数量 + MinCachedInputTokenCount int `json:"min_cached_input_token_count"` // 最小缓存命中的输入 token 数量 + MaxCachedInputTokenCount int `json:"max_cached_input_token_count"` // 最大缓存命中的输入 token 数量 AvgOutputTokenCount int `json:"avg_output_token_count"` // 平均输出token数量 MinOutputTokenCount int `json:"min_output_token_count"` // 最小输出token数量 MaxOutputTokenCount int `json:"max_output_token_count"` // 最大输出token数量 AvgThinkingTokenCount int `json:"avg_thinking_token_count"` // 平均思考token数量 MinThinkingTokenCount int `json:"min_thinking_token_count"` // 最小思考token数量 MaxThinkingTokenCount int `json:"max_thinking_token_count"` // 最大思考token数量 + AvgCacheHitRate float64 `json:"avg_cache_hit_rate"` // 平均缓存命中率 + MinCacheHitRate float64 `json:"min_cache_hit_rate"` // 最小缓存命中率 + MaxCacheHitRate float64 `json:"max_cache_hit_rate"` // 最大缓存命中率 AvgTPS float64 `json:"avg_tps"` // 平均输出 TPS (仅输出 tokens per second) MinTPS float64 `json:"min_tps"` // 最小输出 TPS MaxTPS float64 `json:"max_tps"` // 最大输出 TPS @@ -121,8 +216,10 @@ type ReportData struct { StdDevTTFT time.Duration `json:"stddev_ttft"` // TTFT 标准差 StdDevTPOT time.Duration `json:"stddev_tpot"` // TPOT 标准差 StdDevInputTokenCount float64 `json:"stddev_input_token_count"` // 输入 Token 数标准差 + StdDevCachedInputTokenCount float64 `json:"stddev_cached_input_token_count"` // 缓存命中输入 Token 数标准差 StdDevOutputTokenCount float64 `json:"stddev_output_token_count"` // 输出 Token 数标准差 StdDevThinkingTokenCount float64 `json:"stddev_thinking_token_count"` // 思考 Token 数标准差 + StdDevCacheHitRate float64 `json:"stddev_cache_hit_rate"` // 缓存命中率标准差 StdDevTPS float64 `json:"stddev_tps"` // 输出 TPS 标准差 StdDevTotalThroughputTPS float64 `json:"stddev_total_throughput_tps"` // 吞吐 TPS 标准差 @@ -131,6 +228,72 @@ type ReportData struct { SuccessRate float64 `json:"success_rate"` // 成功率 (%) } +type TaskDefinition struct { + ID string `json:"id"` + Name string `json:"name"` + Input Input `json:"input"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + LastRunAt *time.Time `json:"last_run_at,omitempty"` + LastRunSummary *TaskRunSummary `json:"last_run_summary,omitempty"` +} + +type TaskRunSummary struct { + RunID string `json:"run_id"` + TaskID string `json:"task_id"` + Mode string `json:"mode"` + Status string `json:"status"` + Protocol string `json:"protocol"` + Model string `json:"model"` + StartedAt time.Time `json:"started_at"` + FinishedAt time.Time `json:"finished_at"` + SuccessRate float64 `json:"success_rate"` + AvgTTFT time.Duration `json:"avg_ttft"` + AvgTPS float64 `json:"avg_tps"` + CacheHitRate float64 `json:"cache_hit_rate"` + MaxStableConcurrency int `json:"max_stable_concurrency,omitempty"` + ReportJSONPath string `json:"report_json_path,omitempty"` + ReportCSVPath string `json:"report_csv_path,omitempty"` + ErrorSummary string `json:"error_summary,omitempty"` +} + +type TurboConfig struct { + InitConcurrency int `json:"init_concurrency"` + MaxConcurrency int `json:"max_concurrency"` + StepSize int `json:"step_size"` + LevelRequests int `json:"level_requests"` + MinSuccessRate float64 `json:"min_success_rate"` + MaxLatency time.Duration `json:"max_latency"` +} + +type TurboLevelResult struct { + Concurrency int `json:"concurrency"` + TotalRequests int `json:"total_requests"` + SuccessCount int `json:"success_count"` + SuccessRate float64 `json:"success_rate"` + AvgTPS float64 `json:"avg_tps"` + PeakTPS float64 `json:"peak_tps"` + AvgTTFT time.Duration `json:"avg_ttft"` + CacheHitRate float64 `json:"cache_hit_rate"` + AvgTotalTime time.Duration `json:"avg_total_time"` + StdDevTPS float64 `json:"stddev_tps"` + Stable bool `json:"stable"` + StopReason string `json:"stop_reason,omitempty"` +} + +type TurboResult struct { + Config TurboConfig `json:"config"` + Levels []TurboLevelResult `json:"levels"` + MaxStableConcurrency int `json:"max_stable_concurrency"` + PeakTPS float64 `json:"peak_tps"` + StopReason string `json:"stop_reason"` + ProbeDuration time.Duration `json:"probe_duration"` + Model string `json:"model"` + Protocol string `json:"protocol"` + EndpointURL string `json:"endpoint_url"` + Timestamp string `json:"timestamp"` +} + // MarshalJSON 自定义 JSON 序列化,将 time.Duration 转换为字符串 func (r *ReportData) MarshalJSON() ([]byte, error) { // 自定义序列化,所有 time.Duration 字段转为字符串 From b5e0b0730140006258eba80801970437c3e2ad1d Mon Sep 17 00:00:00 2001 From: Alain Date: Sun, 17 May 2026 00:16:43 +0800 Subject: [PATCH 03/52] Enhance TUI styles with a comprehensive color palette and new style definitions - Introduced a new color palette inspired by the Lip Gloss demo aesthetic. - Expanded the styles struct to include additional styles for headers, footers, table rows, badges, and more. - Updated the newStyles function to initialize the new styles with appropriate colors and formatting. - Improved visual hierarchy and consistency across the TUI components. --- docs/design.md | 1267 ++++++++++++++++++++++--------------- internal/store/history.go | 49 ++ internal/store/store.go | 65 ++ internal/store/task.go | 103 +++ internal/tui/messages.go | 4 + internal/tui/model.go | 1223 +++++++++++++++++++++++++++++------ internal/tui/styles.go | 194 +++++- 7 files changed, 2193 insertions(+), 712 deletions(-) create mode 100644 internal/store/history.go create mode 100644 internal/store/store.go create mode 100644 internal/store/task.go diff --git a/docs/design.md b/docs/design.md index 9309afa..bb3a9eb 100644 --- a/docs/design.md +++ b/docs/design.md @@ -25,13 +25,15 @@ 将现有"执行完即退出"的单次命令行模式升级为**全屏交互式终端界面**,并以任务管理作为主入口: -- 首页展示任务列表,可直接选择已有任务再次执行 -- 无需记忆所有参数,通过向导页创建或编辑任务 -- 任务详情页集中展示配置摘要、最近结果与运行记录 -- 测试运行时实时展示指标面板(成功率、TPS、TTFT、缓存命中率、并发状态) -- 请求日志滚动查看 -- 结果页支持键盘操作(生成报告、返回任务详情、再次运行等) -- 支持通过 CLI 参数生成临时任务草稿并进入 TUI 继续操作 +- `ait` 不带参数直接进入 TUI 任务中心;带完整参数时自动创建任务并立即启动,直接进入运行仪表盘(v2.0 不再提供独立的 `tpg` 子命令) +- 首页展示任务列表,可直接选择已有任务再次执行,运行中任务用 `◉` 实时标注进度 +- 无需记忆所有参数,通过弹窗向导创建或编辑任务 +- 任务详情页集中展示配置摘要和最近运行记录 +- 运行仪表盘:任务配置概览、实时指标汇总、进度、请求列表(每行一条请求含完整指标) +- 请求列表中可选中单条,按 `[Enter]` 进入请求详情页,展示原始输入/输出及全量性能数据 +- 支持多个任务同时在后台运行;任务间(尤其是网络指标)可能存在干扰,启动时给出提示 +- 仪表盘按 `[b]` 后台运行并返回任务列表;运行结束后自动跳转回任务详情页 +- 任务详情页兼做结果页:最近一次运行指标高亮展开,支持生成报告和复制摘要 ### 1.2 任务管理 @@ -39,9 +41,10 @@ - 每个任务只绑定一个模型,保存协议、完整接口地址、模型、Prompt、标准模式或 Turbo 参数 - 协议值细化为 `openai-completions`、`openai-responses`、`anthropic-messages` -- 首页以任务列表形式呈现,支持新建、编辑、删除、复制、搜索和直接运行 -- 任务详情页展示最近一次运行摘要和最近运行记录 +- 首页以任务列表形式呈现,支持新建、编辑、删除、复制和直接运行 +- 任务详情页展示配置摘要和最近运行记录 - 多模型回归通过多个任务组织,而不是在单个任务内批量执行 +- 支持同时运行多个任务;系统会提示任务间存在潜在干扰,尤其是 DNS/TCP/TLS 等网络指标 - 每次运行都会沉淀为任务记录,便于后续直接选择再次测试或对比回归结果 ### 1.3 Turbo 模式 @@ -106,159 +109,263 @@ main() ## 3. 新架构设计 -### 3.1 整体模块图 +### 3.1 SC 分层概览 + +整体采用 **Server-Client** 架构:所有业务能力集中在 Server 层,TUI 和未来的 Web UI 作为 Client 通过统一接口调用。当前是同进程函数调用(无网络开销),但清晰的接口边界使未来接入 Web UI 时无需改动任何业务逻辑。 ``` -┌─────────────────────────────────────────────────────────────────┐ -│ cmd/ait/main.go ─ 入口 & 模式路由 │ -│ │ -│ ┌─ 无必填参数 ──→ TUI 任务中心(任务列表) │ -│ └─ 有完整参数 ──→ TUI 任务详情(生成临时任务草稿) │ -└───────────────┬─────────────────────────────────────────────────┘ - │ - ┌───────────▼───────────────────────────────────────────────┐ - │ internal/tui/ (NEW) │ - │ │ - │ model.go ─ BubbleTea 根模型 + 状态机 │ - │ messages.go ─ 所有 Msg 类型定义 │ - │ styles.go ─ lipgloss 样式常量 │ - │ │ - │ tasklist/ ─ 任务列表页 │ - │ taskdetail/ ─ 任务详情页 │ - │ wizard/ ─ 新建 / 编辑任务向导页 │ - │ dashboard/ ─ 运行中仪表盘页 │ - │ result/ ─ 结果展示页 │ - │ turbo/ ─ Turbo 仪表盘页 │ - └───────────────────────────────────────────────────────────┘ - │ program.Send(msg) - │ ↑ 从任意 goroutine 安全推送 - ┌───────────▼──────────────────────────────┐ - │ internal/runner/runner.go(已有,扩展) │ - │ RunWithCallback(cb RequestDoneCallback) │ - └──────────────────────────────────────────┘ - │ - ┌───────────▼──────────────────────────────┐ - │ internal/turbo/ (NEW) │ - │ Runner ─ 并发爬坡调度器 │ - │ Strategy ─ 步进 & 终止策略 │ - └──────────────────────────────────────────┘ - │ - ┌───────────▼──────────────────────────────┐ - │ internal/task/ (NEW) │ - │ Store ─ 任务 CRUD / 搜索 / 排序 │ - │ History ─ 任务运行记录与最近结果摘要 │ - └──────────────────────────────────────────┘ - │ - ┌───────────▼──────────────────────────────┐ - │ internal/client/ (已有,不变) │ - │ OpenAI Completions / Responses / │ - │ Anthropic Messages HTTP 客户端 │ - └──────────────────────────────────────────┘ - - ┌─────────────────────────────────────────────────────────┐ - │ internal/config/ (NEW) │ - │ 从 ~/.ait/config.json 加载 / 保存全局偏好 │ - └─────────────────────────────────────────────────────────┘ - - ┌─────────────────────────────────────────────────────────┐ - │ internal/report/ (已有,扩展) │ - │ 新增 turbo_renderer.go ─ Turbo 爬坡报告 │ - └─────────────────────────────────────────────────────────┘ +┌──────────────────────────────────────────────────────────────┐ +│ cmd/ait/main.go ─ 入口 │ +│ 解析 flag → 创建 Server → 启动 TUI Client │ +└──────────────────────┬───────────────────────────────────────┘ + │ + ╔═══════════════════▼════════════════════════════════════╗ + ║ internal/server/ ─ SERVICE LAYER ║ + ║ ║ + ║ 对外暴露 Server 接口(任务管理 / 运行管理 / 事件订阅) ║ + ║ 编排下层执行/持久化/报告模块,不感知任何 UI 细节 ║ + ╚═══════╤══════════════════╤══════════════════════════════╝ + │ uses │ uses + ┌────────▼────────┐ ┌──────▼───────────┐ + │ 执行层 │ │ 持久化 & 工具层 │ + │ runner/ │ │ task/ │ + │ turbo/ │ │ config/ │ + │ client/ │ │ report/ │ + │ prompt/ │ │ logger/ ... │ + └─────────────────┘ └──────────────────┘ + + CLIENT LAYER(调用 Server 接口,不直接依赖下层): + ┌─────────────────────────────────────────────────────┐ + │ internal/tui/ ─ TUI Client(当前) │ + │ BubbleTea 状态机 + 页面渲染 │ + └─────────────────────────────────────────────────────┘ + ┌─────────────────────────────────────────────────────┐ + │ internal/webui/ ─ Web UI Bridge(未来) │ + │ HTTP / WebSocket 桥接,前端通过浏览器访问 │ + └─────────────────────────────────────────────────────┘ ``` -### 3.2 目录结构变化 +### 3.2 Server 接口定义 -```diff - cmd/ - ait/ -- ait.go ← 原来全部逻辑 -+ main.go ← 入口:模式检测 + 任务中心路由 -+ flags.go ← 所有 flag 定义 -+ run_tui.go ← TUI 模式启动 / 临时任务草稿注入 - - internal/ -+ tui/ -+ model.go ← 根 BubbleTea Model + 状态机 -+ messages.go ← 所有 Msg 类型 -+ styles.go ← lipgloss 样式常量 -+ tasklist/ -+ model.go ← 任务列表状态 -+ view.go ← 任务列表 UI 渲染 -+ taskdetail/ -+ model.go ← 任务详情状态 -+ view.go ← 任务详情 UI 渲染 -+ wizard/ -+ model.go ← 新建 / 编辑任务向导状态 -+ view.go ← 向导 UI 渲染 -+ dashboard/ -+ model.go ← 运行仪表盘状态 -+ view.go ← 仪表盘 UI 渲染 -+ result/ -+ model.go ← 结果页状态 -+ view.go ← 结果页 UI 渲染 -+ turbo/ -+ model.go ← Turbo 仪表盘状态 -+ view.go ← Turbo 仪表盘 UI 渲染 - -+ turbo/ -+ runner.go ← 并发爬坡调度器 -+ strategy.go ← 步进 & 终止策略 -+ result.go ← TurboResult、LevelResult 类型 - -+ config/ -+ config.go ← ~/.ait/config.json 全局配置读写 - -+ task/ -+ store.go ← ~/.ait/tasks.json 读写 + CRUD -+ history.go ← ~/.ait/history/.json 读写 - - runner/ - runner.go ← 扩展:增加 Stop() 方法 + 每请求完成回调 - - types/ - types.go ← 扩展:增加 TurboConfig、TurboResult - - report/ - report.go ← 已有 -+ turbo_renderer.go ← Turbo 爬坡报告渲染 +Server 层对外暴露一个统一的 Go 接口,TUI 和 Web UI 只依赖该接口,不直接依赖任何下层包: + +```go +// internal/server/server.go + +// Server 是 AIT 的核心服务接口。 +// TUI、Web UI 等所有前端仅依赖此接口,不直接依赖 runner/task 等底层包。 +type Server interface { + // 任务管理 + ListTasks() ([]types.Task, error) + GetTask(id string) (*types.Task, error) + CreateTask(cfg types.TaskConfig) (*types.Task, error) + UpdateTask(id string, cfg types.TaskConfig) error + DeleteTask(id string) error + CopyTask(id string) (*types.Task, error) + + // 运行管理 + StartRun(taskID string) (RunID, error) + StopRun(runID RunID) error + GetRunState(runID RunID) (*RunState, error) + + // 事件订阅(解耦 UI 刷新,替代直接回调) + // 返回事件 channel、取消订阅函数、错误 + Subscribe(runID RunID) (<-chan Event, CancelFunc, error) + + // 历史 & 报告 + GetHistory(taskID string) ([]RunSummary, error) + GenerateReport(runID RunID, format ReportFormat) (path string, err error) + + // 生命周期 + Shutdown() error +} + +// Event 是 Server 推送给订阅者的运行事件 +type Event struct { + RunID RunID + Kind EventKind // RequestDone | ProgressTick | LevelDone | RunComplete | RunFailed + Payload any // *RequestMetrics | *ProgressSnapshot | *LevelResult | *RunResult +} + +type CancelFunc func() ``` -### 3.3 关键设计原则 +### 3.3 各层职责边界 -**原则 1:Runner 不感知 TUI** +| 层 | 包 | 职责 | +|----|-----|------| +| **入口层** | `cmd/ait` | flag 解析;创建 Server;启动 TUI Client | +| **Server 层** | `internal/server` | 暴露业务 API;编排下层;管理运行状态;分发 Event | +| **执行层** | `internal/runner` `internal/turbo` | 并发请求执行;回调推送指标;**不感知 UI** | +| **协议层** | `internal/client` | OpenAI / Anthropic HTTP 客户端;不感知上层 | +| **持久化层** | `internal/task` `internal/config` | 任务 / 历史 / 配置的 JSON 文件读写 | +| **渲染层** | `internal/report` | JSON / CSV / Turbo 报告渲染;纯函数,无副作用 | +| **工具层** | `internal/prompt` `internal/network` `internal/logger` `internal/upload` | 公共工具,无业务依赖 | +| **TUI Client** | `internal/tui` | BubbleTea 状态机;**只依赖 server.Server 接口**;渲染终端 UI | +| **Web Client** _(Future)_ | `internal/webui` | HTTP/WS 桥接;**只依赖 server.Server 接口**;提供 Web API | -Runner 通过回调函数推送进度,不直接依赖 BubbleTea。TUI 模式下由外层把 `program.Send(msg)` 包进回调: +### 3.4 目录结构 -```go -// runner 接口不变,增加每请求完成的细粒度回调 -type RequestDoneCallback func(metrics *client.ResponseMetrics, index int, err error) +``` +cmd/ + ait/ + main.go ← 入口:创建 server.New() + 启动 TUI + +internal/ + server/ ← SERVICE LAYER + server.go ← Server 接口 + 实现 (New / Shutdown) + task.go ← 任务 CRUD 方法实现 + run.go ← 运行启动 / 停止 / 状态管理 + event.go ← Event / EventKind / RunState 类型 + types.go ← RunID / RunSummary / ReportFormat 等 Server 层类型 + + tui/ ← TUI CLIENT + client.go ← 持有 server.Server;提供 tea.Cmd 包装(异步调用 server) + model.go ← 根 BubbleTea Model + 全局状态机 + messages.go ← 所有 tea.Msg 类型 + styles.go ← lipgloss 样式常量 + pages/ + tasklist.go ← 任务列表页渲染 + 按键处理 + taskdetail.go ← 任务详情页 + wizard.go ← 新建 / 编辑弹窗向导(overlay,覆盖任务列表) + dashboard.go ← 标准模式仪表盘 + turbodash.go ← Turbo 仪表盘 + reqdetail.go ← 请求详情页 + contextbar.go ← Context Bar 组件(条件渲染) + + runner/ + runner.go ← 并发请求执行(RunWithCallback / Stop) + + turbo/ + engine.go ← 并发爬坡调度(Run / Stop) + strategy.go ← 步进 & 终止策略 + types.go ← TurboResult / LevelResult + + client/ + client.go ← AI 客户端接口定义 + openai.go ← OpenAI Completions / Responses + anthropic.go ← Anthropic Messages + + store/ ← 统一持久化层(独立包,可被多个模块引用) + store.go ← 泛型 JSON 文件读写基类(Load / Save / 文件锁) + task.go ← TaskStore:~/.ait/tasks.json CRUD + history.go ← HistoryStore:~/.ait/history/.json 读写 + + task/ + task.go ← 任务纯业务逻辑(不持有文件 I/O) + + config/ + config.go ← ~/.ait/config.json 全局配置(使用 store 基类) + + report/ + report.go + csv_renderer.go + json_renderer.go + turbo_renderer.go ← (新增) Turbo 爬坡报告 + + types/ + types.go ← 跨层共享领域类型(Task / TaskConfig / Input / ResponseMetrics ...) + + prompt/ ← 字符串 / 文件 / 长度 Prompt 生成 + network/ ← IP 工具 + logger/ ← 请求日志 + upload/ ← 匿名数据上传 +``` -// TUI 模式下的回调实现 -cb := func(m *client.ResponseMetrics, idx int, err error) { - program.Send(tui.RequestDoneMsg{Metrics: m, Index: idx, Err: err}) +### 3.5 TUI Client 与 Server 交互示意 + +``` +TUI Model (tea.Update) server.Server 底层执行模块 +─────────────────── ───────────── ──────────── +[用户按 r 运行任务] + │ + ├─ client.StartRunCmd(taskID) + │ └─ server.StartRun(taskID) + │ └─ 创建 RunState + │ └─ go runner.RunWithCallback(cb) ─→ runner/ + │ └─ cb 内部: eventBus.Publish(Event{RequestDone}) + │ └─ 返回 runID + │ + ├─ client.SubscribeCmd(runID) → tea.Cmd + │ └─ server.Subscribe(runID) → eventCh + │ └─ tea.Cmd: 持续从 eventCh 读事件 → tea.Msg + │ +[Event: RequestDone] + → RequestDoneMsg → Update() → 追加请求行 → View() + +[Event: RunComplete] + → RunCompleteMsg → Update() → 切换到任务详情页(最近运行展开) + +[用户按 b 后台运行] + │ cancelFunc() ← 停止接收事件,但运行仍在 Server 继续 + │ 返回任务列表 + │ 任务列表中◉ 标记:定时 server.GetRunState(runID) 轮询刷新进度 + +[用户重新进入仪表盘] + ├─ server.GetRunState(runID) ← 恢复当前快照(已完成请求列表) + └─ server.Subscribe(runID) ← 重新订阅,接收后续事件 +``` + +### 3.6 Web UI 接入路径(未来) + +新增 `internal/webui/` 包,直接复用同一个 `server.Server` 实例,无需修改 Server 层或任何下层模块: + +```go +// internal/webui/handler.go (示意) +func (h *Handler) startRun(w http.ResponseWriter, r *http.Request) { + runID, _ := h.server.StartRun(r.PathValue("taskID")) + eventCh, cancel, _ := h.server.Subscribe(runID) + defer cancel() + // 通过 SSE 或 WebSocket 将 Event 推给浏览器 + for event := range eventCh { + writeSSE(w, event) + } } -runner.RunWithCallback(cb) ``` -**原则 2:TUI 是纯状态机** +### 3.7 关键设计原则 + +**原则 1:TUI / Web UI 只依赖 server.Server 接口,不直接 import runner / task / report 等包。** + +**原则 2:Server 层只依赖执行层和持久化层,不 import 任何 UI 包(tui / webui)。** -`tui.Model` 通过消息驱动状态转换,不直接调用任何 I/O 函数,所有副作用都封装在 `tea.Cmd` 中,方便测试。 +**原则 3:执行层(runner / turbo)通过回调/channel 推送进度,不感知 UI,不 import tea 或 http。** -**原则 3:任务是一等领域对象** +**原则 4:TUI Model 是纯状态机。** 所有副作用(调用 server、读文件)封装在 `tea.Cmd` 中,`Update()` 只做状态转换,方便单元测试。 -Runner 消费的是一次运行所需的 `Input`,但 UI 和持久化围绕 `TaskDefinition` 展开。列表、详情、编辑、重跑、历史记录都基于任务对象组织,而不是把一次性的 flag 输入直接暴露给用户。 +**原则 5:一个任务只测一个模型。** 任务是最小回归单元,多模型对比通过创建多个任务实现。 + +### 3.8 迁移自现有架构的变化摘要 + +```diff + cmd/ait/ +- ait.go ← 含全部逻辑 ++ main.go ← 仅做:flag解析 + server.New() + tui启动 -**原则 4:一个任务只测一个模型** ++ internal/store/ ← 新增:独立持久化层(store.go / task.go / history.go) ++ internal/server/ ← 新增:SERVICE LAYER(核心新增) ++ server.go / task.go / run.go / event.go / types.go -任务是最小回归单元。`TaskDefinition` 只保存一个 `Model`,这样任务详情、运行记录、Turbo 极限和结果对比都能稳定映射到单一模型;若要覆盖多个模型,应创建多个任务分别执行。 + internal/tui/ ← 已有,重构: ++ client.go ← 新增:持有 server.Server,包装 tea.Cmd + model.go ← 改造:不再直接调用 runner/task,改调 server ++ pages/ ← 新增:各页面拆分为独立文件 -### 3.4 任务生命周期 + internal/runner/ + runner.go ← 扩展:增加 Stop() + RunWithCallback 稳定化 -1. 用户从任务列表进入“新建任务”向导,保存后写入 `~/.ait/tasks.json` -2. 用户在任务详情页查看配置摘要、最近结果和最近运行记录 -3. 用户从任务详情页启动标准模式或 Turbo 模式测试 -4. 测试完成后写入 `~/.ait/history/.json`,并回写任务的最近运行摘要 -5. 用户后续可直接在任务列表或任务详情页再次运行,无需重新输入参数 + internal/turbo/ ++ engine.go / strategy.go / types.go ← 新增 + + internal/task/ ++ task.go ← 任务纯业务逻辑(不持有文件 I/O) + + internal/report/ ++ turbo_renderer.go ← 新增 + +- internal/display/ ← 废弃(功能被 tui/ 和 server/ 替代) +- cmd/tpg/ ← 废弃(功能合并进 tui 向导) +``` --- @@ -277,57 +384,70 @@ Runner 消费的是一次运行所需的 `Input`,但 UI 和持久化围绕 `Ta ``` 启动无参数 ───────────────────────────────→ TaskList -启动带完整参数 ─→ 生成临时任务草稿 ───────→ TaskDetail +启动带完整参数 ─→ 创建任务 + 自动 StartRun ─→ Running / TurboRunning ┌─────────────┐ │ TaskList │ │ 任务列表页 │ └──────┬──────┘ - [a 新建] │ │ [Enter] - │ │ - ▼ ▼ - ┌─────────────┐ - │ Wizard │ - │ 新建/编辑任务 │ - └──────┬──────┘ - [保存任务] │ - ▼ + [a 新建/e 编辑] │ │ [Enter] + ╔══════▼══╗ │ │ ← Wizard 弹窗 overlay(不切换页面) + ║ Wizard ║ │ │ + ║ 弹窗向导 ║ │ │ + ╚══════╤══╝ │ │ + [保存] │ │ │ + └─────┘ │ + ▼ ┌─────────────┐ │ TaskDetail │ │ 任务详情页 │ └──────┬──────┘ - [Enter / r] │ │ [e 编辑] + [Enter / r] │ │ [e 编辑 → 弹窗] │ └──────────────┐ │ │ [标准模式] ▼ │ ┌─────────────┐ │ │ Running │ │ │ 标准运行中 │ │ - └──────┬──────┘ │ - [完成/s] │ │ - ▼ │ - ┌─────────────┐ │ - │ Completed │ │ - │ 标准结果页 │ │ - └──────┬──────┘ │ - [b 返回详情] │ │ + └──────┬──┬───┘ │ + [完成/s 停止] │ │ [b/Esc 后台] │ + │ └──→ TaskList │ + ▼ (◉ 标记) │ └────────────────┘ + (完成/停止后直接返回 TaskDetail,最近运行展开) [Turbo 模式] ▼ ┌─────────────┐ │TurboRunning │ │ Turbo 爬坡中 │ - └──────┬──────┘ - [完成/s] │ - ▼ + └──────┬──┬───┘ + [完成/s 停止] │ │ [b/Esc 后台] + │ └──→ TaskList + ▼ (◉ 标记) + └──────────────→ TaskDetail + (完成/停止后直接返回 TaskDetail,最近运行展开) + + [请求详情] 在 Running/TurboRunning 请求列表中选中后 ┌─────────────┐ - │TurboCompleted│ - │ Turbo 结果页 │ + │RequestDetail│ + │ 请求详情页 │ └──────┬──────┘ - [b 返回详情] │ - └──────────────→ TaskDetail + [b/Esc 返回] │ + └──────→ Running / TurboRunning ``` +**多任务并发规则:** + +- 支持多个任务同时在后台运行,不限数量 +- 启动第二个任务时弹出提示:"当前已有 N 个任务正在运行,多任务并行可能影响网络指标(DNS/TCP/TLS),`[y]` 继续 `[n]` 取消" +- 任务列表中所有运行中任务都带 `◉` 标记和实时进度 +- 任务完成后自动更新对应任务的状态和历史记录,无论当前处于哪个页面 + +**后台运行规则:** + +- 在仪表盘(Running / TurboRunning)按 `[b]` 或 `[Esc]` 可返回任务列表,测试继续在后台执行 +- 任务列表中正在运行的任务行首显示 `◉` 标记,对其按 `[Enter]` 可随时重新进入仪表盘 + ### 4.3 页面设计 --- @@ -335,284 +455,274 @@ Runner 消费的是一次运行所需的 `Input`,但 UI 和持久化围绕 `Ta #### 页面 1:任务列表首页 ``` -╔══ AIT 任务中心 ─────────────────────────────────────────────══╗ -║ 已保存任务: 12 最近运行: 2026-05-16 09:42 [/] 搜索任务 ║ -╠══════════════════════╦═══════════════════════════════════════════╣ -║ 任务列表 ║ 快捷操作 / 最近摘要 ║ -║ ║ ║ -║ ▶ nightly-openai ║ [a] 新建任务 ║ -║ 标准模式 · gpt-4o ║ [Enter] 查看详情 ║ -║ 并发 10 / 请求 200 ║ [r] 直接运行选中任务 ║ -║ 上次: 98.5% · 12m 前 ║ [e] 编辑 [d] 删除 [y] 复制任务 ║ -║ ║ ║ -║ turbo-anthropic ║ 最近执行 ║ -║ Turbo · claude-3-7 ║ nightly-openai ✓ 98.5% 245ms ║ -║ 1→50 +2 / 每级 30 ║ turbo-anthropic ★ 稳定并发 8 ║ -║ 上次: 峰值 TPS 245.3 ║ smoke-regression ✗ timeout ×2 ║ -║ ║ ║ -║ smoke-regression ║ 提示:支持按任务名、协议、模型、模式过滤║ -║ 标准模式 · gpt-4o-mini ║ ║ -║ 并发 2 / 请求 20 ║ ║ -║ 从未运行 ║ ║ -╠══════════════════════╩═══════════════════════════════════════════╣ -║ [↑↓] 选择 [Enter] 详情 [a] 新建 [r] 运行 [q] 退出 ║ -╚══════════════════════════════════════════════════════════════════╝ +╔══ AIT 任务中心 ──────────────────────────────────────══════╗ +║ ◆ AIT 已保存任务: 3 最近运行: 2026-05-16 09:42 ║ +╠══════════════════════════════════════════════════════════════╣ +║ 任务名称 模式 协议 上次结果║ +║ ──────────────────────────────────────────────────────────║ +║ ▶ ◉ nightly-openai 标准 responses ✓ 98.5% ║ +║ gpt-4o 并发10 请求200 ◉ 47/100 成功率 98.0% ║ +║ ║ +║ turbo-anthropic Turbo messages ★ 并发8 ║ +║ claude-3-7 1→50 步进+2 上次: 峰值 TPS 245.3 ║ +║ ║ +║ smoke-regression 标准 completions 从未运行 ║ +║ gpt-4o-mini 并发2 请求20 ║ +║ ║ +╠══════════════════════════════════════════════════════════════╣ +║ [Enter] 详情/仪表盘 [a] 新建 [e] 编辑 [d] 删除 [r] 运行 ║ ← context bar +╠══════════════════════════════════════════════════════════════╣ +║ [↑↓] 选择 [y] 复制 [q] 退出 ◆ AIT v0.1 ║ +╚══════════════════════════════════════════════════════════════╝ ``` +> **说明:** 任务名前的 `◉` 表示该任务正在后台运行,子行显示实时进度。Context bar(倒数第三行)根据当前选中任务动态调整可用操作;若选中的是运行中任务,`[Enter]` 进入仪表盘而非详情页。 + --- #### 页面 2:任务详情页 +*(兼做结果页——运行结束后自动跳转至此,最近一次运行展开展示)* + ``` -╔══ AIT 任务详情 ─ nightly-openai ────────────────────────────══╗ -║ 任务 ID: task_01 更新: 2026-05-16 09:30 最近运行: 12m 前 ║ -╠══════════════════════╦═══════════════════════════════════════════╣ -║ 配置摘要 ║ 最近一次结果 ║ -║ ║ ║ -║ 协议 openai-responses ║ 状态 ✓ 完成 ║ -║ 接口地址 https://api.openai.com/v1/responses ║ -║ 模型 gpt-4o ║ avg TTFT 245ms ║ -║ 模式 标准模式 ║ avg TPS 124.3 tok/s ║ -║ 并发 10 ║ 缓存命中率 42.0% ║ -║ 请求数 200 ║ 总耗时 20.4s ║ -║ Prompt 你好,介绍一下你自己。║ 报告 ait-report-...json ║ -╠══════════════════════╩═══════════════════════════════════════════╣ -║ 最近运行记录 ║ -║ 2026-05-16 09:30 ✓ 98.5% TTFT 245ms Cache 42% 20.4s ║ -║ 2026-05-15 23:10 ✓ 99.0% TTFT 231ms Cache 38% 19.8s ║ -║ 2026-05-15 21:42 ✗ timeout ×2 Cache 12% 31.2s ║ -╠══════════════════════════════════════════════════════════════════╣ -║ [Enter] 运行 [e] 编辑 [h] 完整历史 [d] 删除 [b] 返回 ║ -╚══════════════════════════════════════════════════════════════════╝ +╔══ AIT 任务详情 ─ nightly-openai ──────────────────────════╗ +║ ◆ AIT 任务 ID: task_01 更新: 2026-05-16 09:30 刚刚 ║ +╠══════════════════════════════════════════════════════════════╣ +║ 配置摘要 ║ +║ 协议 openai-responses 接口 https://api.openai.com/... ║ +║ 模型 gpt-4o 模式 标准模式 并发 10 请求 200 ║ +║ 超时 30s 流式 开启 Prompt 你好,介绍一下你自己。 ║ +╠══════════════════════════════════════════════════════════════╣ +║ 最近运行 ▼ 2026-05-16 09:30 ✓ 完成 100 请求 耗时 20.4s ║ ← 展开行 +║ ────────────────────────────────────────────────────────── ║ +║ 指标 最小值 平均值 标准差 最大值 ║ +║ 总耗时 0.82s 1.24s ±0.31s 3.12s ║ +║ TTFT 71ms 245ms ±89ms 812ms ║ +║ 输出 TPS 89.2 124.3 ±21.4 198.5 ║ +║ 缓存命中率 0.0% 42.0% ±18.5% 100.0% ║ +║ 输入 Token 42 64 ±12 98 ║ +║ 输出 Token 78 128 ±32 195 ║ +║ 错误 context deadline exceeded (timeout) × 2 ║ +╠══════════════════════════════════════════════════════════════╣ +║ 历史运行记录 ║ +║ 时间 模式 成功率 TTFT TPS Cache ║ +║ ────────────────────────────────────────────────────────── ║ +║ 2026-05-15 23:10 标准 99.0% 231ms 128.1 38.0% ║ +║ 2026-05-15 21:42 标准 87.0% timeout ×2 12.0% ║ +╠══════════════════════════════════════════════════════════════╣ +║ [Enter/r] 运行 [r] 生成报告 [c] 复制摘要 [e] 编辑 ║ ← context bar +╠══════════════════════════════════════════════════════════════╣ +║ [b/Esc] 返回列表 ◆ AIT v0.1 ║ +╚══════════════════════════════════════════════════════════════╝ ``` +> **说明:** 运行结束(或 `[s]` 停止)后自动跳转至此,最近一次运行结果默认展开;历史记录折叠为摘要行。`[r] 生成报告` 和 `[c] 复制摘要` 仅在有运行记录时出现在 Context bar。 + --- -#### 页面 3:向导 - 新建任务(Step 1/3) +#### 页面 3:弹窗向导 — 新建任务(Step 1/3 · 基本信息) + +*(overlay 覆盖任务列表,背景列表只读;编辑任务时同样覆盖任务详情页)* ``` -╔══════════════════════════════════════════════════════════════════╗ -║ ██████╗ ██╗████████╗ AI 模型性能测试工具 v2.0 ║ -║ ██╔══██╗██║╚══██╔══╝ https://github.com/yinxulai/ait ║ -║ ███████║██║ ██║ ║ -║ ██╔══██║██║ ██║ 向导 1/3 · 新建任务 ║ -║ ██║ ██║██║ ██║ ║ -║ ╚═╝ ╚═╝╚═╝ ╚═╝ ║ -╠══════════════════════════════════════════════════════════════════╣ -║ ║ -║ 任务名称 nightly-openai ║ -║ ────────────────────────────────────────── ║ -║ ║ -║ 协议类型 > openai-responses ║ -║ ○ openai-completions ║ -║ ● openai-responses ○ anthropic-messages ║ -║ ║ -║ 接口地址 https://api.openai.com/v1/responses ║ -║ ────────────────────────────────────────── ║ -║ 提示:填写完整接口地址,而不是 base URL ║ -║ ║ -║ API 密钥 sk-•••••••••••••••••••••••••••••• ║ -║ ────────────────────────────────────────── ║ -║ ║ -║ 测试模型 gpt-4o ║ -║ ────────────────────────────────────────── ║ -║ 提示:每个任务仅允许选择一个模型 ║ -║ ║ -╠══════════════════════════════════════════════════════════════════╣ -║ [Tab] 下一项 [↑↓] 切换协议 [Enter] 下一步 [Esc] 返回 ║ -╚══════════════════════════════════════════════════════════════════╝ +╔══ AIT 任务中心 ─────────────────────────────════╗ +║ ◆ AIT (列表背景,只读暗化) ║ +║ ┌────────────── 新建任务 1/3 · 基本信息 ──────────────┐ ║ +║ │ │ ║ +║ │ 任务名称 ________________________ │ ║ +║ │ nightly-openai │ ║ +║ │ │ ║ +║ │ 协议类型 ● openai-responses │ ║ +║ │ ○ openai-completions │ ║ +║ │ ○ anthropic-messages │ ║ +║ │ │ ║ +║ │ 接口地址 ________________________ │ ║ +║ │ https://api.openai.com/... │ ║ +║ │ 提示:填写完整接口地址,而非 base URL │ ║ +║ │ │ ║ +║ │ API 密钥 ________________________ │ ║ +║ │ sk-•••••••••••••••• │ ║ +║ │ │ ║ +║ │ 测试模型 ________________________ │ ║ +║ │ gpt-4o │ ║ +║ │ 提示:每个任务仅允许选择一个模型 │ ║ +║ │ │ ║ +║ ├───────────────────────────────────────────────────────┤ ║ +║ │ [Tab] 下一项 [↑↓] 切换协议 [Enter] 下一步 [Esc] 取消 │ ║ +║ └───────────────────────────────────────────────────────┘ ║ +╚══════════════════════════════════════════════════════════════╝ ``` --- -#### 页面 4:向导 - 测试参数(Step 2/3) +#### 页面 4:弹窗向导 — 新建任务(Step 2/3 · 测试参数) ``` -╔══════════════════════════════════════════════════════════════════╗ -║ AIT v2.0 向导 2/3 · 任务参数 ║ -╠══════════════════════════════════════════════════════════════════╣ -║ ║ -║ 测试模式 ○ 标准模式 ● Turbo 模式 ║ -║ [←→ 切换] ║ -║ ║ -║ ── 标准模式参数 ────────────────────────────────────────── ║ -║ 并发数 [ 5 ] 请求总数 [ 100 ] ║ -║ 超时时间 [ 300s ] 流式模式 [✓ 开启] ║ -║ ║ -║ ── Turbo 模式参数 ──────────────────────────────────────── ║ -║ 初始并发 [ 1 ] 最大并发 [ 50 ] ║ -║ 步进值 [ 2 ] 每级请求数 [ 30 ] ║ -║ 停止条件 成功率低于 [ 90% ] 或 延迟超过 [ 10s ] ║ -║ ║ -║ ── Prompt 配置 ─────────────────────────────────────────── ║ -║ 输入方式 ● 直接输入 ○ 文件 ○ 按长度生成 ║ -║ 内容 你好,介绍一下你自己。 ║ -║ ────────────────────────────────────────── ║ -║ ║ -║ 运行后记录 [✓ 保存运行记录到任务历史] ║ -║ ║ -╠══════════════════════════════════════════════════════════════════╣ -║ [Tab] 下一项 [←→] 切换模式 [Enter] 下一步 [Esc] 返回 ║ -╚══════════════════════════════════════════════════════════════════╝ +╔══ AIT 任务中心 ─────────────────────────────════╗ +║ ◆ AIT (列表背景,只读暗化) ║ +║ ┌────────────── 新建任务 2/3 · 测试参数 ──────────────┐ ║ +║ │ │ ║ +║ │ 测试模式 ○ 标准模式 ● Turbo 模式 │ ║ +║ │ [←→ 切换] │ ║ +║ │ ── 标准模式参数 ─────────────────────────────── │ ║ +║ │ 并发数 [ 5 ] 请求总数 [ 100 ] │ ║ +║ │ 超时时间 [ 30s ] 流式模式 [✓ 开启] │ ║ +║ │ ── Turbo 模式参数 ───────────────────────────── │ ║ +║ │ 初始并发 [ 1 ] 最大并发 [ 50 ] │ ║ +║ │ 步进值 [ 2 ] 每级请求数 [ 30 ] │ ║ +║ │ 停止条件 成功率低于 [ 90% ] 或 延迟 > [ 10s ] │ ║ +║ │ ── Prompt 配置 ──────────────────────────────── │ ║ +║ │ 输入方式 ● 直接输入 ○ 文件 ○ 按长度生成 │ ║ +║ │ 内容 你好,介绍一下你自己。 │ ║ +║ │ ───────────────────────────────── │ ║ +║ │ │ ║ +║ ├───────────────────────────────────────────────────────┤ ║ +║ │ [Tab] 下一项 [←→] 切换模式 [Enter] 下一步 [Esc] 返回 │ ║ +║ └───────────────────────────────────────────────────────┘ ║ +╚══════════════════════════════════════════════════════════════╝ ``` --- -#### 页面 5:向导 - 确认(Step 3/3) +#### 页面 5:弹窗向导 — 新建任务(Step 3/3 · 确认保存) ``` -╔══════════════════════════════════════════════════════════════════╗ -║ AIT v2.0 向导 3/3 · 保存任务 ║ -╠══════════════════════════════════════════════════════════════════╣ -║ ║ -║ 🆔 任务 ID a3f2-8b1c-... ║ -║ 🏷️ 任务名称 nightly-openai ║ -║ 🔗 协议 openai-responses ║ -║ 🌐 接口地址 https://api.openai.com/v1/responses ║ -║ 🔑 API 密钥 sk-****...**** ║ -║ 🤖 测试模型 gpt-4o ║ -║ 🚀 测试模式 Turbo 模式 ║ -║ ⚡ 并发爬坡 1 → 50 步进 +2 每级 30 请求 ║ -║ 🛑 停止条件 成功率 < 90% 或 延迟 > 10s ║ -║ 🌊 流式模式 开启 ║ -║ 📝 Prompt 你好,介绍一下你自己。 (长度: 12) ║ -║ ║ -║ 💾 保存任务到 ~/.ait/tasks.json [✓] ║ -║ 📝 创建后自动写入运行历史索引 [✓] ║ -║ ║ -║ ┌────────────────────────────────────────────────────────┐ ║ -║ │ ▶ 保存任务 │ ║ -║ └────────────────────────────────────────────────────────┘ ║ -║ ║ -╠══════════════════════════════════════════════════════════════════╣ -║ [Enter] 保存任务 [r] 保存并运行 [Esc] 返回修改 ║ -╚══════════════════════════════════════════════════════════════════╝ +╔══ AIT 任务中心 ─────────────────────────────════╗ +║ ◆ AIT (列表背景,只读暗化) ║ +║ ┌────────────── 新建任务 3/3 · 确认保存 ──────────────┐ ║ +║ │ │ ║ +║ │ 任务名称 nightly-openai │ ║ +║ │ 协议 openai-responses │ ║ +║ │ 接口地址 https://api.openai.com/... │ ║ +║ │ API 密钥 sk-****...**** │ ║ +║ │ 测试模型 gpt-4o │ ║ +║ │ 测试模式 Turbo 模式 │ ║ +║ │ 并发爬坡 1 → 50 步进 +2 每级 30 请求 │ ║ +║ │ 停止条件 成功率 < 90% 或 延迟 > 10s │ ║ +║ │ 流式模式 开启 │ ║ +║ │ Prompt 你好,介绍一下你自己。 (长度: 12) │ ║ +║ │ │ ║ +║ │ 保存任务到 ~/.ait/tasks.json [✓] │ ║ +║ │ │ ║ +║ │ ▶ 保存任务 │ ║ +║ │ │ ║ +║ ├───────────────────────────────────────────────────────┤ ║ +║ │ [Enter] 保存任务 [r] 保存并运行 [Esc] 返回修改 │ ║ +║ └───────────────────────────────────────────────────────┘ ║ +╚══════════════════════════════════════════════════════════════╝ ``` +> **说明:** 弹窗向导不切换页面,以 overlay 方式覆盖当前页面(任务列表或任务详情)。新建 / 编辑同一套 overlay,编辑时内容预填。 +> +> **自动运行规则:** `[Enter]` 保存任务时,若当前**没有任何运行中的任务**,自动调用 `StartRun` 并进入仪表盘;若已有任务运行,则仅保存并返回任务列表(不弹干扰提示)。`[r]` 保存并运行时**无论是否有其他任务运行,始终启动**(与多任务并发规则一致,启动前弹干扰风险提示)。 + --- #### 页面 6:标准模式运行仪表盘 ``` -╔══ AIT 正在测试 ─ gpt-4o ────────────────────────────────────════╗ -║ 任务: nightly-openai 协议: openai-responses 并发: 5 请求: 100║ -╠══════════════════════╦═══════════════════════════════════════════╣ -║ 进度 ║ 实时指标 ║ -║ ║ ║ -║ 完成 ████████░░ 47 ║ 成功率 ██████████████████░░░ 98.0% ║ -║ 失败 ░░░░░░░░░░ 2 ║ ║ -║ 总计 100 ║ avg TPS 124.3 tok/s ║ -║ ║ avg TTFT 245ms ║ -║ ──────────────────── ║ 缓存命中率 42.0% ║ -║ 已用时 12.4s ║ avg 总耗时 1.24s ║ -║ 预计剩余 ~8.2s ║ 并发槽 [●●●●●] 5/5 活跃 ║ -║ ║ ║ -╠══════════════════════╩═══════════════════════════════════════════╣ -║ 请求日志 [l 展开] ║ -║ ✓ #48 245ms TTFT:82ms cache:100% 128tok 12.3tok/s ║ -║ ✗ #47 timeout (30.0s) ║ -║ ✓ #46 312ms TTFT:95ms cache:25% 96tok 9.8tok/s ║ -║ ✓ #45 198ms TTFT:71ms cache:0% 145tok 14.2tok/s ║ -╠══════════════════════════════════════════════════════════════════╣ -║ [p] 暂停 [s] 停止 [l] 切换日志详情 [r] 提前报告 [q] 退出 ║ -╚══════════════════════════════════════════════════════════════════╝ +╔══ AIT 正在测试 ─ nightly-openai ──────────────────────════╗ +║ ◆ AIT gpt-4o · openai-responses · 并发: 5 · 请求: 100 ║ +╠══════════════════════════╦═════════════════════════════════╣ +║ 任务参数 ║ 实时指标 ║ +║ ║ ║ +║ 协议 responses ║ 成功率 98.0% ║ +║ 模型 gpt-4o ║ avg TPS 124.3 tok/s ║ +║ 并发 5 请求 100 ║ avg TTFT 245ms ║ +║ 超时 30s 流式 开启 ║ 缓存命中 42.0% ║ +║ ║ avg 总耗时 1.24s ║ +║ ║ 成功: 45 失败: 2 ║ +╠══════════════════════════╩═════════════════════════════════╣ +║ 进度 ████████░░ 47 / 100 已用: 12.4s 剩余: ~8.2s ║ +╠═════════════════════════════════════════════════════════════╣ +║ 请求列表 ║ +║ # 状态 总耗时 TTFT Cache 输出Token TPS ║ +║ ──────────────────────────────────────────────────────── ║ +║ ▶ #48 ✓ 245ms 82ms 100% 128tok 12.3/s ║ +║ #47 ✗ timeout(30.0s) ║ +║ #46 ✓ 312ms 95ms 25% 96tok 9.8/s ║ +║ #45 ✓ 198ms 71ms 0% 145tok 14.2/s ║ +║ #44 ✓ 271ms 103ms 50% 112tok 11.1/s ║ +╠═════════════════════════════════════════════════════════════╣ +║ [Enter] 查看请求详情 [↑↓] 选择请求 [s] 停止 ║ ← context bar +╠═════════════════════════════════════════════════════════════╣ +║ [s] 停止 [b] 后台运行 [r] 提前报告 [q] 退出 ║ +╚═════════════════════════════════════════════════════════════╝ ``` -#### 页面 7:Turbo 模式运行仪表盘 - -``` -╔══ AIT Turbo 模式 ─ gpt-4o ──────────────────────────────────════╗ -║ 任务: turbo-anthropic 协议: anthropic-messages ║ -║ 爬坡: 1→50 步进: +2/级 每级: 30 请求 ║ -╠══════════════════════╦═══════════════════════════════════════════╣ -║ 爬坡曲线 (TPS) ║ 当前级别 [并发 = 8] ║ -║ ║ ║ -║ 250┤ ╭──● ║ 成功率 █████████████████░░ 96.0% ║ -║ 200┤ ╭────╯ ║ TPS 245.3 tok/s ║ -║ 150┤ ╭──╯ ║ TTFT 312ms ║ -║ 100┤─╯ ║ Cache 44.0% ║ -║ 50┤ ║ 总耗时 1.51s ║ -║ └──┬──┬──┬──┬──→ ║ 本级完成 28 / 30 ║ -║ 1 2 4 6 8 ║ 状态 🟢 稳定,继续探测... ║ -╠══════════════════════╩═══════════════════════════════════════════╣ -║ 并发 成功率 TPS TTFT Cache 总耗时 状态 ║ -║ ────────────────────────────────────────────────────────────── ║ -║ 1 100.0% 31.2 89ms 0.0% 0.82s ✓ 稳定 ║ -║ 2 100.0% 62.5 91ms 18.0% 0.84s ✓ 稳定 ║ -║ 4 99.0% 121.3 98ms 26.0% 0.91s ✓ 稳定 ║ -║ 6 98.0% 178.4 124ms 33.0% 1.08s ✓ 稳定 ║ -║ ▶ 8 96.0% 245.3 312ms 44.0% 1.51s 🔄 探测中 ║ -╠══════════════════════════════════════════════════════════════════╣ -║ [s] 停止 [m] 手动标记为极限 [r] 提前生成报告 [q] 退出 ║ -╚══════════════════════════════════════════════════════════════════╝ -``` +> **说明:** 上方左右分栏:左侧展示任务参数,右侧展示实时指标(任务完成后变为结果指标);进度条独立一行;请求列表按完成时间倒序滚动(最新在上方),可用 `[↑↓]` 选中一行,`[Enter]` 进入请求详情页。Context bar 根据是否有选中请求动态显示可用操作。 --- -#### 页面 8:标准模式结果页 +#### 页面 7:Turbo 模式运行仪表盘 ``` -╔══ AIT 测试完成 ─ gpt-4o ─────────────────────────────────────════╗ -║ 任务: nightly-openai 协议: openai-responses ║ -║ 耗时: 20.4s 成功率: 98.0% 总请求: 100 ║ -╠══════════════════════════════════════════════════════════════════╣ -║ 指标 最小值 平均值 标准差 最大值 ║ -║ ────────────────────────────────────────────────────────────── ║ -║ 总耗时 0.82s 1.24s ±0.31s 3.12s ║ -║ TTFT 71ms 245ms ±89ms 812ms ║ -║ TPOT 12ms 18ms ±4ms 45ms ║ -║ 输出 TPS 89.2 124.3 ±21.4 198.5 ║ -║ 吞吐 TPS 102.1 148.7 ±25.2 231.4 ║ -║ 缓存命中率 0.0% 42.0% ±18.5% 100.0% ║ -║ 输入 Token 42 64 ±12 98 ║ -║ 输出 Token 78 128 ±32 195 ║ -║ DNS 时间 1.2ms 3.4ms 12.1ms ║ -║ TCP 连接时间 2.1ms 4.8ms 9.3ms ║ -║ TLS 握手时间 8.4ms 12.3ms 28.7ms ║ -╠══════════════════════════════════════════════════════════════════╣ -║ 错误摘要 (2 个错误,占 2.0%) ║ -║ context deadline exceeded (timeout) × 2 ║ -╠══════════════════════════════════════════════════════════════════╣ -║ 任务记录已更新:最近运行摘要 + 历史索引 ║ -╠══════════════════════════════════════════════════════════════════╣ -║ [r] 生成报告 [c] 复制摘要 [b] 返回任务详情 [q] 退出 ║ -╚══════════════════════════════════════════════════════════════════╝ +╔══ AIT Turbo 探测 ─ turbo-anthropic ───────────────────════╗ +║ ◆ AIT claude-3-7 · anthropic-messages · 1→50 步进+2 ║ +╠══════════════════════════╦═════════════════════════════════╣ +║ 任务参数 ║ 当前级别实时指标 [并发 = 8] ║ +║ ║ ║ +║ 协议 messages ║ 成功率 96.0% ║ +║ 模型 claude-3-7 ║ TPS 245.3 tok/s ║ +║ 爬坡 1→50 步进+2 ║ TTFT 312ms ║ +║ 每级 30 请求 ║ Cache 44.0% ║ +║ 停止 成功率 < 90% ║ 总耗时 1.51s ║ +║ 延迟 > 10s ║ 已用 38.2s ║ +╠══════════════════════════╩═════════════════════════════════╣ +║ 进度 ████████░░ 28/30 当前并发 8 总进度: 已完成 4/~25 级║ +╠═════════════════════════════════════════════════════════════╣ +║ 级别列表 ║ +║ 并发 成功率 TPS TTFT Cache 总耗时 结论 ║ +║ ──────────────────────────────────────────────────────── ║ +║ 1 100.0% 31.2 89ms 0.0% 0.82s ✓ 稳定 ║ +║ 2 100.0% 62.5 91ms 18.0% 0.84s ✓ 稳定 ║ +║ 4 99.0% 121.3 98ms 26.0% 0.91s ✓ 稳定 ║ +║ 6 98.0% 178.4 124ms 33.0% 1.08s ✓ 稳定 ║ +║ ▶ 8 96.0% 245.3 312ms 44.0% 1.51s 🔄 进行中 ║ +╠═════════════════════════════════════════════════════════════╣ +║ [Enter] 查看该级别请求列表 [↑↓] 选择 ║ ← context bar +╠═════════════════════════════════════════════════════════════╣ +║ [s] 停止 [b] 后台运行 [m] 标记极限 [r] 提前报告 [q] 退出║ +╚═════════════════════════════════════════════════════════════╝ ``` +> **说明:** 上方左右分栏:左侧展示任务参数,右侧展示当前级别实时指标(并发标号显示在标题);进度条同时显示当前级别进度和总体级别进度;级别列表可用 `[↑↓]` 选中已完成的级别,`[Enter]` 查看该级别的请求列表(复用标准仪表盘的请求列表布局,状态为只读)。 + --- -#### 页面 9:Turbo 模式结果页 +#### 页面 8:请求详情页 ``` -╔══ AIT Turbo 完成 ─ gpt-4o ──────────────────────────────────════╗ -║ 任务: turbo-anthropic 协议: anthropic-messages ║ -║ 🏆 最大稳定并发: 8 峰值 TPS: 245.3 tok/s 探测耗时: 52s ║ -╠══════════════════════════════════════════════════════════════════╣ -║ TPS 爬坡曲线 成功率曲线 ║ -║ ║ -║ 300┤ ╭─●最大稳定 245.3 100%┤████████████ ║ -║ 200┤ ╭────╯ ╲降级 ║ ████░░ ║ -║ 100┤ ╭────╯ ╲ 95%┤ ░░░ ← 阈值 ║ -║ 0┤───╯ ● 90%└──────────────→ ║ -║ └──┬──┬──┬──┬──┬──→ 1 2 4 6 8 10 ║ -║ 1 2 4 6 8 10 并发数 ║ -║ ║ -╠══════════════════════════════════════════════════════════════════╣ -║ 并发 成功率 TPS TTFT Cache 总耗时 结论 ║ -║ ────────────────────────────────────────────────────────────── ║ -║ 1 100.0% 31.2 89ms 0.0% 0.82s ✓ 稳定 ║ -║ 2 100.0% 62.5 91ms 18.0% 0.84s ✓ 稳定 ║ -║ 4 99.0% 121.3 98ms 26.0% 0.91s ✓ 稳定 ║ -║ 6 98.0% 178.4 124ms 33.0% 1.08s ✓ 稳定 ║ -║ ★ 8 96.0% 245.3 312ms 44.0% 1.51s ✓ 最大稳定 ║ -║ 10 84.0% 198.1 892ms 12.0% 4.23s ✗ 降级 ║ -╠══════════════════════════════════════════════════════════════════╣ -║ 任务记录已更新:最近运行摘要 + 历史索引 ║ -╠══════════════════════════════════════════════════════════════════╣ -║ [r] 生成报告 [d] 详细数据 [b] 返回任务详情 [q] 退出 ║ -╚══════════════════════════════════════════════════════════════════╝ +╔══ AIT 请求详情 - nightly-openai #48 ─────────────────════╗ +║ ◆ AIT 任务: nightly-openai 请求 #48 / 100 ✓ 成功 ║ +╠══════════════════════════╦═════════════════════════════════╣ +║ 性能指标 ║ 网络指标 ║ +║ ║ ║ +║ 状态 ✓ 成功 ║ DNS 1.2ms ║ +║ 总耗时 245ms ║ TCP 连接 2.1ms ║ +║ TTFT 82ms ║ TLS 握手 8.4ms ║ +║ 输出TPS 12.3 tok/s ║ ║ +║ 输入Token 64 ║ ║ +║ 输出Token 128 ║ ║ +║ 缓存命中 100% ║ ║ +╠══════════════════════════╩═════════════════════════════════╣ +║ 输入 (Prompt) ║ +║ ──────────────────────────────────────────────────────── ║ +║ 你好,介绍一下你自己。 ║ +╠═════════════════════════════════════════════════════════════╣ +║ 输出 (Response) ║ +║ ──────────────────────────────────────────────────────── ║ +║ 你好!我是 Claude,一个由 Anthropic 开发的 AI 助手。我可以 ║ +║ 帮助你解答问题、分析文本、编写代码等各种任务。请告诉我你 ║ +║ 需要什么帮助! ║ +║ (↑↓ 滚动查看完整内容) ║ +╠═════════════════════════════════════════════════════════════╣ +║ [b/Esc] 返回仪表盘 [↑↓] 滚动 [←→] 上/下一条请求 ║ +╚═════════════════════════════════════════════════════════════╝ ``` +> **说明:** 支持 `[←→]` 切换前后请求,无需每次返回仪表盘。输入/输出区域均可独立滚动,内容较长时显示滚动提示。 + --- ### 4.4 键盘交互规范 @@ -620,34 +730,70 @@ Runner 消费的是一次运行所需的 `Input`,但 UI 和持久化围绕 `Ta | 按键 | 适用页面 | 功能 | |------|----------|------| | `a` | 任务列表 | 新建任务 | -| `/` | 任务列表 | 搜索 / 过滤任务 | -| `Enter` | 任务列表 | 查看任务详情 | -| `r` | 任务列表 / 任务详情 | 直接运行当前任务 | -| `e` | 任务详情 | 编辑当前任务 | -| `d` | 任务详情 | 删除当前任务 | +| `Enter` | 任务列表(普通任务) | 查看任务详情 | +| `Enter` | 任务列表(运行中任务) | 重新进入仪表盘 | +| `r` | 任务列表 / 任务详情 | 运行当前任务(有其他任务运行时提示干扰风险) | +| `e` | 任务列表 / 任务详情 | 编辑当前任务 | +| `d` | 任务列表 / 任务详情 | 删除当前任务 | | `y` | 任务列表 / 任务详情 | 复制当前任务 | -| `h` | 任务详情 | 打开完整运行历史 | -| `b` | 任务详情 / 结果页 | 返回上一级 | +| `b` / `Esc` | 任务详情 | 返回任务列表 | +| `b` / `Esc` | 仪表盘(标准/Turbo) | **后台运行**,返回任务列表(测试继续进行) | +| `b` / `Esc` | 请求详情页 | 返回仪表盘 | +| `Enter` | 仪表盘请求列表 | 进入请求详情页 | +| `↑` / `↓` | 仪表盘请求列表 / 请求详情 | 选择请求 / 滚动内容 | +| `←` / `→` | 请求详情页 | 切换上/下一条请求 | | `Tab` / `Shift+Tab` | 向导 | 在输入项间切换焦点 | -| `↑` / `↓` | 任务列表、向导、结果表格 | 上下选择 | +| `↑` / `↓` | 任务列表、向导 | 上下选择 | | `←` / `→` | 向导模式选择 | 切换选项 | | `Enter` | 向导 | 确认 / 下一步 / 保存 | | `Esc` | 所有页 | 返回上一步 / 取消 | -| `p` | Running | 暂停/恢复 | -| `s` | Running / Turbo | 停止测试 | -| `l` | Dashboard | 切换日志详情展开/折叠 | -| `r` | Running / 结果页 | 生成报告文件 | -| `m` | Turbo Running | 手动标记当前并发为最大稳定并发并停止 | -| `c` | 结果页 | 复制摘要到剪贴板 | +| `s` | 仪表盘(标准/Turbo) | 停止测试 | +| `r` | 仪表盘 / 任务详情 | 生成报告文件 | +| `m` | Turbo 仪表盘 | 手动标记当前并发为最大稳定并发并停止 | +| `c` | 任务详情(有运行记录时) | 复制最近运行摘要到剪贴板 | | `q` / `Ctrl+C` | 所有页 | 退出程序 | --- -### 4.5 布局响应式策略 +### 4.5 Context Bar 规范 + +**Context Bar** 是紧贴 Footer 上方的一行动态提示区域,**仅当当前页面/选中项有可用操作时才显示**;若无选中项或无可执行操作,该行完全不渲染(不占空间)。 + +**格式:** + +``` +[key] 操作描述 [key] 操作描述 [key] 操作描述 +``` + +**各页面 Context Bar 内容:** + +| 页面 / 场景 | Context Bar 示例 | +|------------|------------------| +| 任务列表(选中普通任务) | `[Enter] 查看详情 [r] 运行 [e] 编辑 [d] 删除 [y] 复制` | +| 任务列表(选中运行中任务) | `[Enter] 进入仪表盘 [s] 停止 [y] 复制` | +| 任务列表(无任务) | 不显示 | +| 任务详情(无运行记录) | `[Enter/r] 运行 [e] 编辑 [y] 复制 [d] 删除` | +| 任务详情(有运行记录) | `[r] 生成报告 [c] 复制摘要 [Enter/r] 再次运行 [e] 编辑` | +| 仪表盘(未选中请求) | `[s] 停止 [b] 后台运行 [r] 提前报告` | +| 仪表盘(选中请求) | `[Enter] 查看请求详情 [↑↓] 选择请求 [s] 停止` | +| Turbo 仪表盘(未选中级别) | `[s] 停止 [b] 后台运行 [m] 标记极限` | +| Turbo 仪表盘(选中已完成级别) | `[Enter] 查看该级别请求列表 [↑↓] 选择 [s] 停止` | +| 请求详情页 | `[b/Esc] 返回仪表盘 [←→] 上/下一条请求` | + +**规则:** +- Context Bar 使用与 Footer 相同的暗色调,但前景色略亮(用于区分层级) +- 仅展示**当前状态下可执行**的操作(例如:仅在请求列表选中行时才显示 `[Enter] 查看详情`) +- Context Bar 不替代 Footer——Footer 始终展示全局快捷键(`[q]` 退出等) + +--- + +### 4.6 布局响应式策略 -- 终端宽度 `< 80` 列:任务列表与任务详情折叠为单列,摘要面板移动到下方 -- 终端宽度 `≥ 80` 列:任务列表、任务详情和运行页都采用双栏布局 -- 终端高度不足时,历史记录区或日志区自动收缩,至少保留 3 行内容 +- 终端宽度 `< 80` 列:所有页面折叠为单列;仪表盘将指标区与进度条叠放为两行 +- 终端宽度 `≥ 80` 列:仪表盘实时指标全宽展示,进度条独立一行;任务列表与任务详情均为全宽单列 +- 终端高度不足时,请求列表区域自动收缩(最少保留 3 行);输入/输出内容区优先滚动而非截断 +- Context Bar 若无内容则不渲染,不影响其他区域高度分配 +- 仪表盘页面进度条独立占一行,位于实时指标区域与请求列表之间 --- @@ -747,7 +893,7 @@ ait --protocol=openai-responses --endpoint=https://api.openai.com/v1/responses - **与任务管理的关系:** - `--task=` 直接加载已保存任务并进入详情页或直接运行 -- 通过完整 CLI 参数启动时,AIT 会先生成一个未保存的临时任务草稿,用户可选择保存后复用 +- 通过完整 CLI 参数启动时,AIT 自动创建任务并立即调用 `StartRun`,直接进入 Running / TurboRunning 仪表盘 - Turbo 运行完成后,其结果会自动追加到对应任务的历史记录中 ### 5.4 Turbo 报告格式 @@ -898,7 +1044,7 @@ type TurboResult struct { } ``` -现有指标结构也需要补充缓存命中率字段,用于 dashboard、结果页和报告渲染: +现有指标结构也需要补充缓存命中率字段,用于 dashboard、任务详情页和报告渲染: ```go type ResponseMetrics struct { @@ -935,85 +1081,173 @@ type Input struct { } ``` -### 6.5 TUI 消息类型 +### 6.5 Server 层类型 + +```go +// internal/server/types.go + +// RunID 是一次运行的唯一标识 +type RunID string + +// ReportFormat 报告格式 +type ReportFormat string + +const ( + ReportFormatJSON ReportFormat = "json" + ReportFormatCSV ReportFormat = "csv" +) + +// RunStatus 运行状态枚举 +type RunStatus string + +const ( + RunStatusRunning RunStatus = "running" + RunStatusCompleted RunStatus = "completed" + RunStatusFailed RunStatus = "failed" + RunStatusStopped RunStatus = "stopped" +) + +// RunState 一次运行的当前状态快照(用于 GetRunState / 后台轮询) +type RunState struct { + RunID RunID + TaskID string + Status RunStatus + Mode string // "standard" | "turbo" + StartedAt time.Time + FinishedAt *time.Time + // 标准模式 + TotalReqs int + DoneReqs int + SuccessReqs int + FailedReqs int + Requests []*RequestMetrics // 已完成请求列表(供重新进入仪表盘恢复) + // Turbo 模式 + Levels []types.TurboLevelResult + CurrentLevel int + // 聚合实时指标(标准模式 / Turbo 当前级别) + AvgTPS float64 + AvgTTFT time.Duration + SuccessRate float64 + CacheHitRate float64 + // 最终结果(完成后填充) + StandardResult *types.ReportData + TurboResult *types.TurboResult + ErrorMsg string +} + +// RequestMetrics 单条请求的指标快照(供请求详情页展示) +type RequestMetrics struct { + Index int + Success bool + TotalTime time.Duration + TTFT time.Duration + TPS float64 + PromptTokens int + CompletionTokens int + CachedTokens int + CacheHitRate float64 + DNSTime time.Duration + ConnectTime time.Duration + TLSTime time.Duration + TargetIP string + ErrorMessage string + PromptText string // 原始输入(供请求详情页展示) + ResponseText string // 原始输出 +} + +// RunSummary 历史记录列表项 +type RunSummary struct { + RunID RunID + StartedAt time.Time + Status RunStatus + Mode string + SuccessRate float64 + AvgTTFT time.Duration + AvgTPS float64 + CacheHitRate float64 + MaxStableConcurrency int + ReportPath string +} + +// EventKind 事件类型枚举 +type EventKind string + +const ( + EventRequestDone EventKind = "request_done" // 单条请求完成 + EventProgressTick EventKind = "progress_tick" // 500ms 聚合快照 + EventLevelDone EventKind = "level_done" // Turbo 一级完成 + EventRunComplete EventKind = "run_complete" // 全部完成 + EventRunFailed EventKind = "run_failed" // 运行出错 +) +``` + +### 6.6 TUI 消息类型 + +TUI 层的 `tea.Msg` 类型由 `tui/client.go` 包装 `server.Server` 调用后产生,不直接暴露 server 内部类型: ```go // internal/tui/messages.go // TasksLoadedMsg 任务列表加载完成 type TasksLoadedMsg struct { - Tasks []types.TaskDefinition + Tasks []types.Task } // TaskSavedMsg 任务保存完成 type TaskSavedMsg struct { - Task types.TaskDefinition -} - -// TaskHistoryLoadedMsg 任务运行记录加载完成 -type TaskHistoryLoadedMsg struct { - TaskID string - History []types.TaskRunSummary + Task types.Task } -// RequestDoneMsg 单个请求完成 -type RequestDoneMsg struct { - Metrics *client.ResponseMetrics - Index int - Err error +// HistoryLoadedMsg 运行历史加载完成 +type HistoryLoadedMsg struct { + TaskID string + History []server.RunSummary } -// AllRequestsDoneMsg 所有请求完成 -type AllRequestsDoneMsg struct { - Result *types.ReportData - Errors []string +// RunStartedMsg 运行启动成功,获得 RunID +type RunStartedMsg struct { + RunID server.RunID + TaskID string } -// TurboLevelStartMsg Turbo 新一级开始 -type TurboLevelStartMsg struct { - Concurrency int - LevelIndex int +// ServerEventMsg 从 server.Subscribe 接收到的事件(统一包装) +type ServerEventMsg struct { + Event server.Event } -// TurboLevelDoneMsg Turbo 一级完成 -type TurboLevelDoneMsg struct { - Level types.TurboLevelResult - LevelIndex int +// RunStateMsg server.GetRunState 的轮询结果(后台模式重新进入仪表盘时使用) +type RunStateMsg struct { + State *server.RunState } -// TurboDoneMsg Turbo 全部完成 -type TurboDoneMsg struct { - Result *types.TurboResult +// ReportGeneratedMsg 报告生成完成 +type ReportGeneratedMsg struct { + Path string } -// ProgressTickMsg 定时刷新实时指标 -type ProgressTickMsg struct { - Stats types.StatsData -} - -// ErrorMsg 运行时错误 +// ErrorMsg 操作出错 type ErrorMsg struct { Err error } ``` -### 6.6 Runner 接口扩展 +### 6.7 Runner 接口扩展 ```go -// internal/runner/runner.go 新增 +// internal/runner/runner.go — Server 层内部使用,TUI 不直接调用 -// RequestDoneCallback 每个请求完成后的回调(细粒度,供 TUI 使用) -type RequestDoneCallback func(metrics *client.ResponseMetrics, index int, err error) +// RequestDoneCallback 每个请求完成后的回调(由 server/run.go 包装为 Event) +type RequestDoneCallback func(metrics *ResponseMetrics, index int, err error) // RunWithCallback 运行测试,每个请求完成后调用 cb(线程安全) -// 同时保留原有的 RunWithProgress,供 Legacy 模式使用 func (r *Runner) RunWithCallback(cb RequestDoneCallback) (*types.ReportData, error) // Stop 异步停止正在进行的测试 func (r *Runner) Stop() ``` +``` -### 6.7 任务与全局配置持久化 +### 6.8 任务与全局配置持久化 ```go // internal/config/config.go @@ -1027,18 +1261,29 @@ type Config struct { func Load() (*Config, error) // 从 ~/.ait/config.json 加载 func (c *Config) Save() error // 保存到 ~/.ait/config.json -// internal/task/store.go +// internal/store/store.go +// 泛型基类,提供 JSON 文件安全读写 + +type JSONStore[T any] struct { + path string +} + +func NewJSONStore[T any](path string) *JSONStore[T] +func (s *JSONStore[T]) Load() (T, error) +func (s *JSONStore[T]) Save(v T) error // 写入前加文件锁 + +// internal/store/task.go type TaskStore struct { Tasks []types.TaskDefinition `json:"tasks"` } -func LoadTasks() (*TaskStore, error) // 从 ~/.ait/tasks.json 加载 -func (s *TaskStore) Save() error // 保存到 ~/.ait/tasks.json -func (s *TaskStore) Upsert(task types.TaskDefinition) // 新建或更新任务 +func LoadTasks() (*TaskStore, error) // 从 ~/.ait/tasks.json 加载 +func (s *TaskStore) Save() error // 保存到 ~/.ait/tasks.json +func (s *TaskStore) Upsert(task types.TaskDefinition) // 新建或更新任务 func (s *TaskStore) Delete(taskID string) error -// internal/task/history.go +// internal/store/history.go func AppendRun(taskID string, run types.TaskRunSummary) error func LoadHistory(taskID string, limit int) ([]types.TaskRunSummary, error) @@ -1048,61 +1293,71 @@ func LoadHistory(taskID string, limit int) ([]types.TaskRunSummary, error) ## 7. 开发计划 -### Phase 1 — 任务中心与 TUI 基础框架(优先) - -**目标:** 先建立任务管理主流程,再用 BubbleTea 替换现有的进度条 + 静态表格输出 - -**任务清单:** - -- [ ] 引入依赖:`charm.land/bubbletea/v2`、`bubbles`、`lipgloss` -- [ ] 实现 `internal/tui/` 基础骨架(model、messages、styles) -- [ ] 实现任务列表页(tasklist):选择 / 搜索 / 删除 / 复制 -- [ ] 实现任务详情页(taskdetail):配置摘要 + 最近记录 + 直接运行 -- [ ] 实现向导页(wizard):三步创建 / 编辑任务 -- [ ] 实现仪表盘页(dashboard):进度 + 实时指标双栏 -- [ ] 实现结果页(result):完整指标表格 + 键盘操作 -- [ ] 协议枚举细化:`openai-completions`、`openai-responses`、`anthropic-messages` -- [ ] 扩展指标采集与渲染:缓存命中率(dashboard / result / report) -- [ ] `cmd/ait/main.go` 模式检测路由(无参数 → 任务列表,有参数 → 临时任务草稿) -- [ ] 实现任务持久化(`tasks.json` + `history/*.json`) -- [ ] Runner 增加 `Stop()` 方法和 `RunWithCallback` 接口 -- [ ] 全局配置持久化(默认协议、最后选择任务、密钥保存策略) -- [ ] 结果页回写任务最近运行摘要 +> **约定:** 严格遵守 SC 分层原则——先建好 Server 接口,再实现 TUI Client;所有 UI 层代码仅 import `internal/server`,不直接 import `runner` / `task` / `report` 等下层包。 + +### Phase 1 — Server 层 + TUI 基础框架 + +**目标:** 建立 SC 架构骨架,跑通"创建任务 → 运行 → 看到进度 → 查看结果"主流程 + +**Step 1:Server 层(先行)** + +- [ ] `internal/types/types.go`:补充 `Task`、`TaskConfig`、`TurboConfig`、`ReportData` 等领域类型 +- [ ] `internal/store/store.go`:泛型 `JSONStore[T]` 基类(文件锁 + Load/Save) +- [ ] `internal/store/task.go`:`TaskStore`,`~/.ait/tasks.json` CRUD +- [ ] `internal/store/history.go`:`HistoryStore`,`~/.ait/history/.json` 读写 +- [ ] `internal/config/config.go`:`~/.ait/config.json` 全局配置(复用 `store.JSONStore`) +- [ ] `internal/runner/runner.go`:增加 `Stop()` + 稳定 `RunWithCallback` +- [ ] `internal/server/server.go`:定义 `Server` 接口 + `New()` 构造函数 +- [ ] `internal/server/task.go`:实现任务 CRUD 方法(调用 task.Store) +- [ ] `internal/server/run.go`:实现 `StartRun / StopRun / GetRunState`(调用 runner) +- [ ] `internal/server/event.go`:`Event / EventKind / RunState / RunID` 类型 + 内部 eventBus +- [ ] Server 单元测试:任务 CRUD、运行状态机、事件分发 + +**Step 2:TUI Client(依赖 Server 接口完成后)** + +- [ ] `internal/tui/client.go`:持有 `server.Server`,封装 `tea.Cmd` 异步调用 +- [ ] `internal/tui/model.go`:根 BubbleTea Model + 全局状态机(只依赖 `client.go`) +- [ ] `internal/tui/messages.go`:所有 `tea.Msg` 类型 +- [ ] `internal/tui/styles.go`:lipgloss 样式常量 +- [ ] `internal/tui/pages/contextbar.go`:Context Bar 组件(条件渲染) +- [ ] `internal/tui/pages/tasklist.go`:任务列表页(含 ◉ 运行状态展示) +- [ ] `internal/tui/pages/taskdetail.go`:任务详情页 +- [ ] `internal/tui/pages/wizard.go`:三步弹窗向导(overlay) +- [ ] `internal/tui/pages/dashboard.go`:标准模式仪表盘(请求列表 + 实时指标) +- [ ] `internal/tui/pages/reqdetail.go`:请求详情页(含原始输入/输出) +- [ ] `cmd/ait/main.go`:`server.New()` → 启动 TUI +- [ ] 协议枚举:`openai-completions`、`openai-responses`、`anthropic-messages` - [ ] 响应式布局(终端宽度自适应) -- [ ] `internal/display/` 模块退役,由 TUI 全面接管输出 +- [ ] `internal/display/` 退役,由 TUI 全面接管输出 --- ### Phase 2 — Turbo 模式 -**目标:** 将并发爬坡能力完整融入任务体系 - -**任务清单:** +**目标:** 将并发爬坡能力完整融入 SC 架构 -- [ ] 实现 `internal/turbo/runner.go`:封装爬坡调度逻辑 -- [ ] 实现 `internal/turbo/strategy.go`:步进 & 终止策略 -- [ ] `types.TurboConfig`、`TurboLevelResult`、`TurboResult` 数据结构 -- [ ] TUI Turbo 仪表盘页(折线图 + 爬坡表格) -- [ ] `internal/report/turbo_renderer.go`:Turbo CSV/JSON 报告 -- [ ] Turbo 结果写回任务最近摘要和运行记录 -- [ ] 新增 CLI 参数:`--turbo`、`--turbo-*` 系列 +- [ ] `internal/turbo/engine.go`:爬坡调度(`Run / Stop`) +- [ ] `internal/turbo/strategy.go`:步进 & 终止策略 +- [ ] `internal/turbo/types.go`:`TurboResult / LevelResult` +- [ ] `internal/server/run.go`:扩展 `StartRun` 支持 Turbo 模式(发布 `EventLevelDone`) +- [ ] `internal/tui/pages/turbodash.go`:Turbo 仪表盘页(级别列表 + 当前级别指标) +- [ ] `internal/tui/pages/taskdetail.go`:扩展支持 Turbo 运行结果展开(爬坡表格 + ASCII 曲线) +- [ ] `internal/report/turbo_renderer.go`:Turbo CSV/JSON 报告渲染 +- [ ] Turbo 运行历史写回任务摘要 --- -### Phase 3 — 增强 - -**目标:** 细节打磨与扩展 +### Phase 3 — 增强 & Web UI 接入准备 -**任务清单:** +**目标:** 细节打磨,为 Web UI 预留接入点 -- [ ] 多任务 Turbo 对比(并排爬坡曲线) -- [ ] 任务收藏和快速筛选视图 -- [ ] 任务复制、模板化创建和批量导入 -- [ ] 运行记录对比视图(同一任务不同 run 对比) -- [ ] 结果页 `c` 键复制摘要到剪贴板 -- [ ] `ntcharts` 折线图替换 ASCII 折线图 +- [ ] `internal/webui/`(骨架):HTTP handler 接收请求 → 调用 `server.Server` → SSE/WS 推送 Event +- [ ] 多任务并发干扰提示完善 +- [ ] 运行记录对比视图(同一任务不同 run) +- [ ] 任务详情页 `[c]` 复制最近运行摘要到剪贴板 +- [ ] `ntcharts` 折线图替换 ASCII 折线图(Turbo 曲线) - [ ] 终端尺寸变化自适应重绘 -- [ ] 完善单元测试(TUI model 测试、turbo strategy 测试) +- [ ] 完善单元测试(TUI model 测试、server 集成测试、turbo strategy 测试) --- diff --git a/internal/store/history.go b/internal/store/history.go new file mode 100644 index 0000000..8b30cd6 --- /dev/null +++ b/internal/store/history.go @@ -0,0 +1,49 @@ +package store + +import "github.com/yinxulai/ait/internal/types" + +// HistoryStore 管理单个任务的运行历史文件(~/.ait/history/.json)。 +// 每个任务对应独立的 HistoryStore 实例和独立的文件。 +type HistoryStore struct { + store *JSONStore[[]types.TaskRunSummary] +} + +// NewHistoryStore 创建持久化到 path 的 HistoryStore。 +func NewHistoryStore(path string) *HistoryStore { + return &HistoryStore{store: NewJSONStore[[]types.TaskRunSummary](path)} +} + +// Append 追加一条运行摘要到历史文件。 +func (s *HistoryStore) Append(run types.TaskRunSummary) error { + runs, err := s.store.Load() + if err != nil { + return err + } + if runs == nil { + runs = []types.TaskRunSummary{} + } + runs = append(runs, run) + return s.store.Save(runs) +} + +// Load 返回运行历史,最新的排在前面。limit <= 0 表示不限制条数。 +func (s *HistoryStore) Load(limit int) ([]types.TaskRunSummary, error) { + runs, err := s.store.Load() + if err != nil { + return nil, err + } + if runs == nil { + return []types.TaskRunSummary{}, nil + } + + // 反转(最新在前) + reversed := make([]types.TaskRunSummary, len(runs)) + for i, r := range runs { + reversed[len(runs)-1-i] = r + } + + if limit > 0 && len(reversed) > limit { + reversed = reversed[:limit] + } + return reversed, nil +} diff --git a/internal/store/store.go b/internal/store/store.go new file mode 100644 index 0000000..6bc44fc --- /dev/null +++ b/internal/store/store.go @@ -0,0 +1,65 @@ +package store + +import ( + "encoding/json" + "errors" + "os" + "path/filepath" + "sync" +) + +// JSONStore 是泛型 JSON 文件持久化基类。 +// 内置进程级互斥锁,防止同一进程内并发读写;文件操作通过原子写入保证安全。 +type JSONStore[T any] struct { + path string + mu sync.Mutex +} + +// NewJSONStore 创建指向指定路径的 JSONStore。 +func NewJSONStore[T any](path string) *JSONStore[T] { + return &JSONStore[T]{path: path} +} + +// Load 从文件读取并反序列化为 T。文件不存在时返回零值(无错误)。 +func (s *JSONStore[T]) Load() (T, error) { + s.mu.Lock() + defer s.mu.Unlock() + + var zero T + data, err := os.ReadFile(s.path) + if errors.Is(err, os.ErrNotExist) { + return zero, nil + } + if err != nil { + return zero, err + } + + var v T + if err := json.Unmarshal(data, &v); err != nil { + return zero, err + } + return v, nil +} + +// Save 将 v 序列化为 JSON 并写入文件,写入前自动创建父目录。 +// 使用"写临时文件后重命名"的原子写入方式,避免文件写到一半导致损坏。 +func (s *JSONStore[T]) Save(v T) error { + s.mu.Lock() + defer s.mu.Unlock() + + if err := os.MkdirAll(filepath.Dir(s.path), 0o755); err != nil { + return err + } + + data, err := json.MarshalIndent(v, "", " ") + if err != nil { + return err + } + + // 原子写入:先写临时文件,再重命名 + tmp := s.path + ".tmp" + if err := os.WriteFile(tmp, data, 0o644); err != nil { + return err + } + return os.Rename(tmp, s.path) +} diff --git a/internal/store/task.go b/internal/store/task.go new file mode 100644 index 0000000..0e4e5ec --- /dev/null +++ b/internal/store/task.go @@ -0,0 +1,103 @@ +package store + +import ( + "fmt" + "time" + + "github.com/yinxulai/ait/internal/types" +) + +type taskStoreData struct { + Tasks []types.TaskDefinition `json:"tasks"` +} + +// TaskStore 管理 ~/.ait/tasks.json 的任务列表持久化。 +type TaskStore struct { + store *JSONStore[taskStoreData] + data taskStoreData +} + +// NewTaskStore 创建持久化到 path 的 TaskStore(调用方需先调用 Load)。 +func NewTaskStore(path string) *TaskStore { + return &TaskStore{store: NewJSONStore[taskStoreData](path)} +} + +// Load 从磁盘加载任务列表,文件不存在时初始化为空列表。 +func (s *TaskStore) Load() error { + data, err := s.store.Load() + if err != nil { + return err + } + if data.Tasks == nil { + data.Tasks = []types.TaskDefinition{} + } + s.data = data + return nil +} + +// Save 将当前内存中的任务列表持久化到磁盘。 +func (s *TaskStore) Save() error { + return s.store.Save(s.data) +} + +// All 返回所有任务的副本,最近更新的排在前面。 +func (s *TaskStore) All() []types.TaskDefinition { + result := make([]types.TaskDefinition, len(s.data.Tasks)) + copy(result, s.data.Tasks) + return result +} + +// Get 按 ID 查找任务,返回副本。 +func (s *TaskStore) Get(id string) (types.TaskDefinition, bool) { + for _, t := range s.data.Tasks { + if t.ID == id { + return t, true + } + } + return types.TaskDefinition{}, false +} + +// Upsert 新建或更新任务。 +// - 若 task.ID 为空,自动生成唯一 ID。 +// - 更新时将任务移至列表头部(最近活跃排序)。 +func (s *TaskStore) Upsert(task types.TaskDefinition) { + now := time.Now() + if task.ID == "" { + task.ID = fmt.Sprintf("task_%d", now.UnixNano()) + } + + for i, existing := range s.data.Tasks { + if existing.ID != task.ID { + continue + } + if task.CreatedAt.IsZero() { + task.CreatedAt = existing.CreatedAt + } + task.UpdatedAt = now + // 移至列表头部 + tasks := make([]types.TaskDefinition, 0, len(s.data.Tasks)) + tasks = append(tasks, task) + tasks = append(tasks, s.data.Tasks[:i]...) + tasks = append(tasks, s.data.Tasks[i+1:]...) + s.data.Tasks = tasks + return + } + + // 新增 + if task.CreatedAt.IsZero() { + task.CreatedAt = now + } + task.UpdatedAt = now + s.data.Tasks = append([]types.TaskDefinition{task}, s.data.Tasks...) +} + +// Delete 按 ID 删除任务,任务不存在时返回错误。 +func (s *TaskStore) Delete(id string) error { + for i, t := range s.data.Tasks { + if t.ID == id { + s.data.Tasks = append(s.data.Tasks[:i], s.data.Tasks[i+1:]...) + return nil + } + } + return fmt.Errorf("task %q not found", id) +} diff --git a/internal/tui/messages.go b/internal/tui/messages.go index ad776f9..41106b9 100644 --- a/internal/tui/messages.go +++ b/internal/tui/messages.go @@ -20,3 +20,7 @@ type turboCompleteMsg struct { type asyncErrorMsg struct { err error } + +type requestLogMsg struct { + entry string +} diff --git a/internal/tui/model.go b/internal/tui/model.go index d9cb4f4..6194176 100644 --- a/internal/tui/model.go +++ b/internal/tui/model.go @@ -66,7 +66,8 @@ type wizardState struct { lastRunAt *time.Time lastRunSummary *types.TaskRunSummary fromView viewState - current int + step int // 0=基本信息 1=测试参数 2=确认保存 + fieldIndex int // active field within current step input textinput.Model values map[string]string protocolIndex int @@ -90,14 +91,16 @@ type Model struct { height int status string err error - program *tea.Program - runningTask *types.TaskDefinition - runStartedAt time.Time + program *tea.Program + runningTask *types.TaskDefinition + runningTaskID string + runStartedAt time.Time progress types.StatsData runResult *types.ReportData turboResult *types.TurboResult activeRunner *runner.Runner activeTurbo *turbo.Engine + requestLog []string } func NewModel(store *task.TaskStore, cfg *config.Config) *Model { @@ -139,21 +142,34 @@ func (m *Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { case progressMsg: m.progress = msg.stats return m, nil + case requestLogMsg: + m.requestLog = append(m.requestLog, msg.entry) + if len(m.requestLog) > 60 { + m.requestLog = m.requestLog[len(m.requestLog)-60:] + } + return m, nil case runCompleteMsg: m.activeRunner = nil + m.runningTaskID = "" m.runResult = msg.result - m.view = viewResult + if m.view == viewDashboard { + m.view = viewResult + } m.status = fmt.Sprintf("标准模式完成,共 %d 请求", msg.result.TotalRequests) m.persistStandardRun(msg.taskID, msg.result, msg.reportPaths) return m, nil case turboCompleteMsg: m.activeTurbo = nil + m.runningTaskID = "" m.turboResult = msg.result - m.view = viewTurboResult + if m.view == viewDashboard { + m.view = viewTurboResult + } m.status = fmt.Sprintf("Turbo 完成,最大稳定并发 %d", msg.result.MaxStableConcurrency) m.persistTurboRun(msg.taskID, msg.result) return m, nil case asyncErrorMsg: + m.runningTaskID = "" m.err = msg.err m.status = msg.err.Error() return m, nil @@ -230,13 +246,21 @@ func (m *Model) handleTaskListKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) { m.status = "任务已删除" } case "enter": - if _, ok := m.currentTask(); ok { - m.reloadHistoryForSelectedTask() - m.view = viewTaskDetail + if taskDef, ok := m.currentTask(); ok { + if taskDef.ID == m.runningTaskID { + m.view = viewDashboard + } else { + m.reloadHistoryForSelectedTask() + m.view = viewTaskDetail + } } case "r": if taskDef, ok := m.currentTask(); ok { - m.startTaskRun(taskDef) + if m.runningTaskID != "" { + m.status = "已有任务正在运行中,请等待完成或进入仪表盘停止" + } else { + m.startTaskRun(taskDef) + } } case "q": return m, tea.Quit @@ -272,39 +296,92 @@ func (m *Model) handleTaskDetailKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) { } m.view = viewTaskList case "enter", "r": - m.startTaskRun(taskDef) + if m.runningTaskID != "" && m.runningTaskID != taskDef.ID { + m.status = "已有任务正在运行中" + } else { + m.startTaskRun(taskDef) + if m.runningTaskID == taskDef.ID { + m.view = viewDashboard + } + } } return m, nil } func (m *Model) handleWizardKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) { - field := m.currentWizardField() + if m.wizard == nil { + return m, nil + } + + // Step 2 (confirm): only action keys, no text input + if m.wizard.step == 2 { + switch msg.String() { + case "esc": + m.wizard.step = 1 + m.wizard.fieldIndex = len(m.wizardStepFields(1)) - 1 + m.refreshWizardInput() + case "enter": + if err := m.saveWizard(); err != nil { + m.err = err + m.status = err.Error() + } + case "r": + if err := m.saveWizard(); err != nil { + m.err = err + m.status = err.Error() + return m, nil + } + if taskDef, ok := m.currentTask(); ok { + m.startTaskRun(taskDef) + if m.runningTaskID == taskDef.ID { + m.view = viewDashboard + } + } + } + return m, nil + } + + fields := m.wizardStepFields(m.wizard.step) + field := fields[m.wizard.fieldIndex] + switch msg.String() { case "esc": - m.view = m.wizard.fromView - m.wizard = nil + if m.wizard.step > 0 { + m.wizard.step-- + m.wizard.fieldIndex = len(m.wizardStepFields(m.wizard.step)) - 1 + m.refreshWizardInput() + } else { + m.view = m.wizard.fromView + m.wizard = nil + } return m, nil - case "tab", "enter": + case "tab", "down", "j": if field.kind == fieldText { m.wizard.values[field.key] = m.wizard.input.Value() } - if m.wizard.current == len(m.wizardFields())-1 { - if err := m.saveWizard(); err != nil { - m.err = err - m.status = err.Error() + m.advanceWizardField(1) + return m, nil + case "enter": + if field.kind == fieldText { + m.wizard.values[field.key] = m.wizard.input.Value() + } + if m.wizard.fieldIndex == len(fields)-1 { + m.wizard.step++ + m.wizard.fieldIndex = 0 + if m.wizard.step < 2 { + m.refreshWizardInput() } - return m, nil + } else { + m.wizard.fieldIndex++ + m.refreshWizardInput() } - m.wizard.current++ - m.refreshWizardInput() return m, nil - case "shift+tab", "up": - if m.wizard.current > 0 { + case "shift+tab", "up", "k": + if field.kind == fieldText { m.wizard.values[field.key] = m.wizard.input.Value() - m.wizard.current-- - m.refreshWizardInput() } + m.advanceWizardField(-1) return m, nil case "left", "h": m.cycleWizardField(-1) @@ -324,13 +401,18 @@ func (m *Model) handleWizardKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) { func (m *Model) handleDashboardKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) { switch msg.String() { - case "s", "q", "esc": + case "s": if m.activeRunner != nil { m.activeRunner.Stop() } if m.activeTurbo != nil { m.activeTurbo.Stop() } + case "b", "esc": + // 返回列表,任务继续在后台运行 + m.view = viewTaskList + case "q": + return m, tea.Quit } return m, nil } @@ -353,180 +435,908 @@ func (m *Model) View() string { } func (m *Model) renderTaskList() string { - var rows []string - for i, taskDef := range m.tasks { - mode := modeStandard - if taskDef.Input.Turbo { - mode = modeTurbo + if m.width == 0 { + return "加载中..." + } + lastRunStr := "" + for _, t := range m.tasks { + if t.LastRunAt != nil { + lastRunStr = "最近: " + timeAgo(*t.LastRunAt) + break } - summary := "从未运行" - if taskDef.LastRunSummary != nil { - summary = fmt.Sprintf("上次 %.1f%% · %.1f tok/s", taskDef.LastRunSummary.SuccessRate, taskDef.LastRunSummary.AvgTPS) + } + header := m.renderHeader( + "AIT 任务中心", + fmt.Sprintf("已保存任务: %d %s", len(m.tasks), lastRunStr), + ) + footer := m.renderFooter( + "[↑↓] 选择", "[Enter] 详情", "[a] 新建", "[r] 运行", + "[e] 编辑", "[d] 删除", "[y] 复制", "[q] 退出", + ) + contentH := m.height - 2 + if contentH < 4 { + contentH = 4 + } + panelH := contentH - 2 + leftW := (m.width - 4) * 57 / 100 + rightW := m.width - 4 - leftW + leftContent := m.buildTaskListLeft(panelH, leftW) + rightContent := m.buildTaskListRight(panelH) + mid := m.dualColumnLayout(leftContent, rightContent, leftW, rightW, panelH) + return lipgloss.JoinVertical(lipgloss.Left, header, mid, footer) +} + +func (m *Model) buildTaskListLeft(maxH, width int) string { + var lines []string + lines = append(lines, m.styles.tableHead.Render( + fmt.Sprintf(" %-28s %-9s %-14s %s", "任务名称", "模式", "协议", "上次结果"), + )) + lines = append(lines, m.styles.muted.Render(strings.Repeat("─", width))) + if len(m.tasks) == 0 { + lines = append(lines, "") + lines = append(lines, m.styles.muted.Render(" 暂无任务 按 [a] 新建")) + return strings.Join(lines, "\n") + } + for i, t := range m.tasks { + if len(lines) >= maxH-1 { + break } - line := fmt.Sprintf("%s %s %s %s", taskDef.Name, taskDef.Input.Model, mode, summary) + // Mode: color-coded tag text with manual padding to 9 visual columns + var modeRendered string + if t.Input.Turbo { + modeRendered = m.styles.tagTurbo.Render("Turbo") + } else { + modeRendered = m.styles.tagStd.Render("标准") + } + modePad := 9 - lipgloss.Width(modeRendered) + if modePad < 0 { + modePad = 0 + } + modeCol := modeRendered + strings.Repeat(" ", modePad) + proto := shortProtocol(t.Input.NormalizedProtocol()) + lastResult := m.styles.muted.Render("从未运行") + if t.LastRunSummary != nil { + pct := t.LastRunSummary.SuccessRate + if pct >= 99 { + lastResult = m.styles.ok.Render(fmt.Sprintf("%.1f%%", pct)) + } else if pct >= 90 { + lastResult = m.styles.metricVal.Render(fmt.Sprintf("%.1f%%", pct)) + } else { + lastResult = m.styles.errStyle.Render(fmt.Sprintf("%.1f%%", pct)) + } + } + nameStr := truncate(t.Name, 28) + // Build row from parts so ANSI in modeCol doesn't break alignment + nameCol := fmt.Sprintf("%-28s ", nameStr) + protoCol := fmt.Sprintf("%-14s ", proto) + mainRow := " " + nameCol + modeCol + " " + protoCol + lastResult if i == m.selected { - line = m.styles.selected.Render("▶ " + line) + plainRow := " " + nameCol + fmt.Sprintf("%-9s ", func() string { + if t.Input.Turbo { + return "Turbo" + } + return "标准" + }()) + protoCol + lastResult + lines = append(lines, m.styles.tableRowSel.Width(width).Render("▶"+plainRow[1:])) } else { - line = " " + line + lines = append(lines, mainRow) } - rows = append(rows, line) - } - if len(rows) == 0 { - rows = append(rows, m.styles.muted.Render("暂无任务,按 a 新建")) + var sub string + if t.Input.Turbo { + tc := t.Input.TurboConfig + sub = fmt.Sprintf(" %s %d→%d +%d 每级%d", + truncate(t.Input.Model, 18), + tc.InitConcurrency, tc.MaxConcurrency, tc.StepSize, tc.LevelRequests) + } else { + sub = fmt.Sprintf(" %s 并发%d/请求%d", + truncate(t.Input.Model, 20), t.Input.Concurrency, t.Input.Count) + } + if i == m.selected { + lines = append(lines, m.styles.tableRowSel.Width(width).Render(sub)) + } else { + lines = append(lines, m.styles.muted.Render(sub)) + } + lines = append(lines, "") } + return strings.Join(lines, "\n") +} - content := []string{ - m.styles.title.Render("AIT 任务中心"), - m.styles.subtitle.Render(fmt.Sprintf("已保存任务: %d", len(m.tasks))), - m.styles.panel.Render(strings.Join(rows, "\n")), - m.footer("[↑↓] 选择", "[Enter] 详情", "[a] 新建", "[r] 运行", "[e] 编辑", "[d] 删除", "[q] 退出"), +func (m *Model) buildTaskListRight(maxH int) string { + var lines []string + lines = append(lines, m.styles.sectionHead.Render("快捷操作")) + lines = append(lines, "") + lines = append(lines, " "+m.styles.key.Render("[a]")+" 新建任务") + lines = append(lines, " "+m.styles.key.Render("[Enter]")+" 查看详情") + lines = append(lines, " "+m.styles.key.Render("[r]")+" 直接运行选中任务") + lines = append(lines, " "+m.styles.key.Render("[e]")+" 编辑 "+m.styles.key.Render("[d]")+" 删除 "+m.styles.key.Render("[y]")+" 复制") + lines = append(lines, "") + lines = append(lines, m.styles.muted.Render(strings.Repeat("─", 28))) + lines = append(lines, "") + lines = append(lines, m.styles.sectionHead.Render("最近执行")) + lines = append(lines, "") + count := 0 + for _, t := range m.tasks { + if t.LastRunSummary == nil { + continue + } + s := t.LastRunSummary + statusIcon := m.styles.ok.Render("✓") + if s.SuccessRate < 90 { + statusIcon = m.styles.errStyle.Render("✗") + } + lines = append(lines, fmt.Sprintf(" %s %-16s %.1f%% %.0f tok/s", + statusIcon, truncate(t.Name, 16), s.SuccessRate, s.AvgTPS)) + count++ + if count >= 5 || len(lines) >= maxH-2 { + break + } + } + if count == 0 { + lines = append(lines, m.styles.muted.Render(" 暂无记录")) } if m.status != "" { - content = append(content, m.styles.muted.Render(m.status)) + lines = append(lines, "") + lines = append(lines, m.styles.muted.Render(m.status)) } - return strings.Join(content, "\n") + if m.err != nil { + lines = append(lines, m.styles.errStyle.Render("错误: "+m.err.Error())) + } + return strings.Join(lines, "\n") } func (m *Model) renderTaskDetail() string { + if m.width == 0 { + return "加载中..." + } taskDef, ok := m.currentTask() if !ok { - return m.styles.error.Render("任务不存在") + return m.styles.errStyle.Render("任务不存在") + } + updatedStr := "" + if !taskDef.UpdatedAt.IsZero() { + updatedStr = "更新: " + taskDef.UpdatedAt.Format("01-02 15:04") } - lastRun := "从未运行" + lastRunStr := "从未运行" if taskDef.LastRunAt != nil { - lastRun = taskDef.LastRunAt.Format(time.RFC3339) + lastRunStr = "上次: " + timeAgo(*taskDef.LastRunAt) } - mode := modeStandard - if taskDef.Input.Turbo { - mode = modeTurbo + header := m.renderHeader( + "AIT 任务详情 — "+truncate(taskDef.Name, 24), + updatedStr+" "+lastRunStr, + ) + footer := m.renderFooter("[Enter/r] 运行", "[e] 编辑", "[d] 删除", "[b] 返回") + contentH := m.height - 2 + histH := 9 + topH := contentH - histH + if topH < 6 { + topH = 6 + } + panelH := topH - 2 + leftW := (m.width - 4) * 57 / 100 + rightW := m.width - 4 - leftW + leftContent := m.buildDetailLeft(taskDef, panelH, leftW) + rightContent := m.buildDetailRight(taskDef) + top := m.dualColumnLayout(leftContent, rightContent, leftW, rightW, panelH) + histPanelH := histH - 2 + histContent := m.buildHistoryContent(histPanelH, m.width-4) + histPanel := lipgloss.NewStyle(). + Width(m.width - 2).Height(histPanelH). + Border(lipgloss.RoundedBorder()). + BorderForeground(colorPurple). + Render(histContent) + return lipgloss.JoinVertical(lipgloss.Left, header, top, histPanel, footer) +} + +func (m *Model) buildDetailLeft(t types.TaskDefinition, h, w int) string { + var lines []string + lines = append(lines, m.styles.sectionHead.Render("配置摘要")) + lines = append(lines, "") + maxURLLen := w - 14 + if maxURLLen < 20 { + maxURLLen = 20 + } + rows := [][2]string{ + {"协议", t.Input.NormalizedProtocol()}, + {"接口地址", truncate(t.Input.ResolvedEndpointURL(), maxURLLen)}, + {"模型", t.Input.Model}, + } + if t.Input.Turbo { + tc := t.Input.TurboConfig + rows = append(rows, + [2]string{"模式", "Turbo 模式"}, + [2]string{"爬坡", fmt.Sprintf("%d → %d 步进+%d 每级%d", + tc.InitConcurrency, tc.MaxConcurrency, tc.StepSize, tc.LevelRequests)}, + [2]string{"停止条件", fmt.Sprintf("成功率<%.0f%% 或延迟>%s", + tc.MinSuccessRate*100, tc.MaxLatency)}, + ) + } else { + rows = append(rows, + [2]string{"模式", "标准模式"}, + [2]string{"并发", fmt.Sprintf("%d", t.Input.Concurrency)}, + [2]string{"请求数", fmt.Sprintf("%d", t.Input.Count)}, + [2]string{"超时", t.Input.Timeout.String()}, + ) } - left := []string{ - fmt.Sprintf("名称: %s", taskDef.Name), - fmt.Sprintf("协议: %s", taskDef.Input.NormalizedProtocol()), - fmt.Sprintf("接口: %s", taskDef.Input.ResolvedEndpointURL()), - fmt.Sprintf("模型: %s", taskDef.Input.Model), - fmt.Sprintf("模式: %s", mode), - fmt.Sprintf("Prompt: %s", promptSummary(taskDef.Input)), - fmt.Sprintf("最近运行: %s", lastRun), + rows = append(rows, + [2]string{"流式", boolLabel(t.Input.Stream)}, + [2]string{"Prompt", promptSummary(t.Input)}, + ) + for _, row := range rows { + lines = append(lines, fmt.Sprintf(" %s %s", + m.styles.label.Render(fmt.Sprintf("%-8s", row[0])), + m.styles.value.Render(row[1]))) } + return strings.Join(lines, "\n") +} - historyLines := []string{m.styles.label.Render("最近运行记录")} +func (m *Model) buildDetailRight(t types.TaskDefinition) string { + var lines []string + lines = append(lines, m.styles.sectionHead.Render("最近一次结果")) + lines = append(lines, "") + if t.LastRunSummary == nil { + lines = append(lines, m.styles.muted.Render(" 从未运行")) + return strings.Join(lines, "\n") + } + s := t.LastRunSummary + statusStr := m.styles.ok.Render("✓ 完成") + if s.SuccessRate < 90 { + statusStr = m.styles.errStyle.Render("✗ 异常") + } + rows := [][2]string{ + {"状态", statusStr}, + {"成功率", fmt.Sprintf("%.1f%%", s.SuccessRate)}, + {"avg TTFT", s.AvgTTFT.Truncate(time.Millisecond).String()}, + {"avg TPS", fmt.Sprintf("%.1f tok/s", s.AvgTPS)}, + {"缓存命中", fmt.Sprintf("%.1f%%", s.CacheHitRate)}, + } + if s.MaxStableConcurrency > 0 { + rows = append(rows, [2]string{"最大稳定并发", fmt.Sprintf("%d", s.MaxStableConcurrency)}) + } + for _, row := range rows { + lines = append(lines, fmt.Sprintf(" %s %s", + m.styles.label.Render(fmt.Sprintf("%-10s", row[0])), + row[1])) + } + return strings.Join(lines, "\n") +} + +func (m *Model) buildHistoryContent(maxH, width int) string { + var lines []string + lines = append(lines, m.styles.sectionHead.Render("最近运行记录")+" "+ + m.styles.tableHead.Render(fmt.Sprintf("%-19s %-6s %-8s %-12s %-10s %-8s", + "时间", "模式", "成功率", "TTFT", "TPS", "Cache"))) + lines = append(lines, m.styles.muted.Render(strings.Repeat("─", width-2))) if len(m.history) == 0 { - historyLines = append(historyLines, m.styles.muted.Render("暂无历史")) + lines = append(lines, m.styles.muted.Render(" 暂无历史记录")) + return strings.Join(lines, "\n") + } + for _, run := range m.history { + if len(lines) >= maxH { + break + } + status := m.styles.ok.Render("✓") + if run.SuccessRate < 90 { + status = m.styles.errStyle.Render("✗") + } + modeShort := run.Mode + if len(modeShort) > 5 { + modeShort = modeShort[:5] + } + lines = append(lines, fmt.Sprintf(" %s %-19s %-6s %-8.1f%% %-12s %-10.1f %-8.1f%%", + status, + run.FinishedAt.Format("2006-01-02 15:04"), + modeShort, + run.SuccessRate, + run.AvgTTFT.Truncate(time.Millisecond), + run.AvgTPS, + run.CacheHitRate)) + } + return strings.Join(lines, "\n") +} + +func (m *Model) renderWizard() string { + if m.width == 0 || m.wizard == nil { + return "加载中..." + } + stepTitles := []string{"1/3 · 基本信息", "2/3 · 测试参数", "3/3 · 确认保存"} + header := m.renderHeader("AIT 任务向导", "步骤 "+stepTitles[m.wizard.step]) + var footer string + if m.wizard.step < 2 { + footer = m.renderFooter("[Tab/↓] 下一项", "[↑] 上一项", "[←→] 切换选项", "[Enter] 下一步", "[Esc] 返回") } else { - for _, item := range m.history { - historyLines = append(historyLines, fmt.Sprintf("%s %s %.1f%% %.1f tok/s cache %.1f%%", item.FinishedAt.Format("2006-01-02 15:04:05"), item.Mode, item.SuccessRate, item.AvgTPS, item.CacheHitRate)) + footer = m.renderFooter("[Enter] 保存任务", "[r] 保存并运行", "[Esc] 返回修改") + } + contentH := m.height - 2 + dialogW := m.width - 6 + if dialogW > 78 { + dialogW = 78 + } + if dialogW < 40 { + dialogW = 40 + } + dialogContentW := dialogW - 6 // -2 border -4 padding + var content string + switch m.wizard.step { + case 0: + content = m.renderWizardStep0(dialogContentW) + case 1: + content = m.renderWizardStep1(dialogContentW) + case 2: + content = m.renderWizardStep2(dialogContentW) + } + dialog := m.styles.dialog.Width(dialogContentW).Render(content) + dialogH := lipgloss.Height(dialog) + padTop := (contentH - dialogH) / 2 + if padTop < 0 { + padTop = 0 + } + centeredDialog := lipgloss.Place(m.width, contentH, + lipgloss.Center, lipgloss.Top, + strings.Repeat("\n", padTop)+dialog) + return lipgloss.JoinVertical(lipgloss.Left, header, centeredDialog, footer) +} + +func (m *Model) renderWizardStep0(w int) string { + fields := m.wizardStepFields(0) + var lines []string + // Step indicator: ● ○ ○ + lines = append(lines, m.styles.stepActive.Render("●")+" "+ + m.styles.stepTodo.Render("○")+" "+ + m.styles.stepTodo.Render("○")+" "+ + m.styles.sectionHead.Render("基本信息")) + lines = append(lines, "") + for i, field := range fields { + active := i == m.wizard.fieldIndex + lines = append(lines, m.renderWizardField(field, active)) + if field.key == "protocol" { + for pi, p := range protocolOptions { + bullet := " ○ " + if pi == m.wizard.protocolIndex { + bullet = " " + m.styles.ok.Render("●") + " " + } + lines = append(lines, " "+bullet+p) + } } + lines = append(lines, "") } - - return strings.Join([]string{ - m.styles.title.Render("AIT 任务详情"), - m.styles.panel.Render(strings.Join(left, "\n")), - m.styles.panel.Render(strings.Join(historyLines, "\n")), - m.footer("[Enter] 运行", "[e] 编辑", "[d] 删除", "[b] 返回"), - }, "\n") + return strings.Join(lines, "\n") } -func (m *Model) renderWizard() string { - fields := m.wizardFields() - field := fields[m.wizard.current] +func (m *Model) renderWizardStep1(w int) string { + fields := m.wizardStepFields(1) var lines []string - for i, f := range fields { - marker := " " - if i == m.wizard.current { - marker = "▶ " + // Step indicator: ✓ ● ○ + lines = append(lines, m.styles.stepDone.Render("✓")+" "+ + m.styles.stepActive.Render("●")+" "+ + m.styles.stepTodo.Render("○")+" "+ + m.styles.sectionHead.Render("测试参数")) + lines = append(lines, "") + for i, field := range fields { + active := i == m.wizard.fieldIndex + lines = append(lines, m.renderWizardField(field, active)) + if field.key == "mode" { + opts := []string{modeStandard, modeTurbo} + labels := []string{"标准模式", "Turbo 模式"} + for oi, opt := range opts { + bullet := " ○ " + if opt == m.wizard.mode { + bullet = " " + m.styles.ok.Render("●") + " " + } + lines = append(lines, " "+bullet+labels[oi]) + } + } + if field.key == "prompt_mode" { + pmLabels := []string{"直接输入", "文件路径", "按长度生成"} + for pi, pl := range pmLabels { + bullet := " ○ " + if pi == m.wizard.promptModeIndex { + bullet = " " + m.styles.ok.Render("●") + " " + } + lines = append(lines, " "+bullet+pl) + } } - lines = append(lines, marker+fmt.Sprintf("%s: %s", f.label, m.displayWizardValue(f))) + lines = append(lines, "") } + return strings.Join(lines, "\n") +} - editor := "" - if field.kind == fieldText { - editor = m.styles.panel.Render(m.wizard.input.View()) +func (m *Model) renderWizardStep2(w int) string { + var lines []string + // Step indicator: ✓ ✓ ● + lines = append(lines, m.styles.stepDone.Render("✓")+" "+ + m.styles.stepDone.Render("✓")+" "+ + m.styles.stepActive.Render("●")+" "+ + m.styles.sectionHead.Render("确认保存")) + lines = append(lines, "") + d, err := buildTaskDefinition(m.wizard) + if err != nil { + lines = append(lines, m.styles.errStyle.Render("配置有误: "+err.Error())) + lines = append(lines, "") + lines = append(lines, m.styles.muted.Render("按 [Esc] 返回修改")) + return strings.Join(lines, "\n") + } + rows := [][2]string{ + {"任务名称", d.Name}, + {"协议", d.Input.NormalizedProtocol()}, + {"接口地址", truncate(d.Input.ResolvedEndpointURL(), w-16)}, + {"API 密钥", maskAPIKey(d.Input.ApiKey)}, + {"测试模型", d.Input.Model}, + } + if d.Input.Turbo { + tc := d.Input.TurboConfig + rows = append(rows, + [2]string{"测试模式", "Turbo 模式"}, + [2]string{"并发爬坡", fmt.Sprintf("%d → %d 步进+%d 每级%d", + tc.InitConcurrency, tc.MaxConcurrency, tc.StepSize, tc.LevelRequests)}, + [2]string{"停止条件", fmt.Sprintf("成功率<%.0f%% 或延迟>%s", + tc.MinSuccessRate*100, tc.MaxLatency)}, + ) } else { - editor = m.styles.panel.Render(m.displayWizardValue(field)) + rows = append(rows, + [2]string{"测试模式", "标准模式"}, + [2]string{"并发/请求", fmt.Sprintf("%d / %d", d.Input.Concurrency, d.Input.Count)}, + [2]string{"超时", d.Input.Timeout.String()}, + ) } + rows = append(rows, + [2]string{"流式", boolLabel(d.Input.Stream)}, + [2]string{"Prompt", promptSummary(d.Input)}, + ) + for _, row := range rows { + lines = append(lines, fmt.Sprintf(" %s %s", + m.styles.label.Render(fmt.Sprintf("%-10s", row[0])), + row[1])) + } + lines = append(lines, "") + lines = append(lines, m.styles.ok.Render(" ▶ 按 [Enter] 保存,[r] 保存并立即运行")) + return strings.Join(lines, "\n") +} - return strings.Join([]string{ - m.styles.title.Render("AIT 任务向导"), - m.styles.subtitle.Render(fmt.Sprintf("步骤 %d/%d", m.wizard.current+1, len(fields))), - m.styles.panel.Render(strings.Join(lines, "\n")), - editor, - m.footer("[Enter/Tab] 下一项或保存", "[←→/Space] 切换选项", "[Esc] 取消"), - }, "\n") +func (m *Model) renderWizardField(field wizardField, active bool) string { + var val string + if field.kind == fieldText && active { + val = m.wizard.input.View() + } else { + val = m.displayWizardValue(field) + } + labelStr := fmt.Sprintf("%-12s", field.label) + if active { + return m.styles.cursor.Render("▶") + " " + + m.styles.fieldActive.Render(labelStr) + " " + val + } + return " " + m.styles.fieldIdle.Render(labelStr) + " " + m.styles.muted.Render(val) } func (m *Model) renderDashboard() string { - title := "AIT 正在运行" - if m.runningTask != nil && m.runningTask.Input.Turbo { - title = "AIT Turbo 正在探测" + if m.width == 0 { + return "加载中..." + } + taskName, protocol, modelName := "", "", "" + isTurbo := false + totalReqs, concurrency := 0, 0 + if m.runningTask != nil { + taskName = m.runningTask.Name + protocol = shortProtocol(m.runningTask.Input.NormalizedProtocol()) + modelName = m.runningTask.Input.Model + isTurbo = m.runningTask.Input.Turbo + totalReqs = m.runningTask.Input.Count + concurrency = m.runningTask.Input.Concurrency + if isTurbo { + totalReqs = m.runningTask.Input.TurboConfig.LevelRequests + concurrency = m.runningTask.Input.TurboConfig.InitConcurrency + } } - stats := []string{ - fmt.Sprintf("完成: %d", m.progress.CompletedCount), - fmt.Sprintf("失败: %d", m.progress.FailedCount), - fmt.Sprintf("运行时长: %s", m.progress.ElapsedTime.Truncate(100*time.Millisecond)), + title := "AIT 正在测试 — " + modelName + if isTurbo { + title = "AIT Turbo 探测 — " + modelName + } + header := m.renderHeader(title, + fmt.Sprintf("任务: %s 协议: %s", truncate(taskName, 20), protocol)) + footer := m.renderFooter("[s] 停止", "[q] 退出") + contentH := m.height - 2 + logH := 7 + topH := contentH - logH + if topH < 6 { + topH = 6 + } + panelH := topH - 2 + leftW := (m.width - 4) * 50 / 100 + rightW := m.width - 4 - leftW + var leftContent, rightContent string + if isTurbo { + leftContent = m.buildTurboDashLeft(panelH) + rightContent = m.buildTurboDashRight(panelH) + } else { + leftContent = m.buildStdDashLeft(panelH, totalReqs, concurrency) + rightContent = m.buildStdDashRight(panelH) + } + top := m.dualColumnLayout(leftContent, rightContent, leftW, rightW, panelH) + logPanelH := logH - 2 + logContent := m.buildLogPanel(logPanelH, m.width-4) + logPanel := lipgloss.NewStyle(). + Width(m.width - 2).Height(logPanelH). + Border(lipgloss.RoundedBorder()). + BorderForeground(colorPurple). + Render(logContent) + return lipgloss.JoinVertical(lipgloss.Left, header, top, logPanel, footer) +} + +func (m *Model) buildStdDashLeft(h, total, concurrency int) string { + p := m.progress + completed := p.CompletedCount + failed := p.FailedCount + elapsed := time.Duration(0) + if !p.StartTime.IsZero() { + elapsed = time.Since(p.StartTime) + } + var estRemaining string + if completed > 0 && total > completed && elapsed > 0 { + rate := float64(completed) / elapsed.Seconds() + remaining := float64(total-completed) / rate + estRemaining = "~" + time.Duration(remaining*float64(time.Second)).Truncate(time.Second).String() + } + barW := 20 + var lines []string + lines = append(lines, m.styles.sectionHead.Render("进度")) + lines = append(lines, "") + lines = append(lines, fmt.Sprintf(" %s %s %d", + m.styles.label.Render("完成"), progressBar(completed, total, barW), completed)) + lines = append(lines, fmt.Sprintf(" %s %s %d", + m.styles.errStyle.Render("失败"), progressBarRed(failed, total, barW), failed)) + lines = append(lines, fmt.Sprintf(" %s %s %d", + m.styles.muted.Render("总计"), progressBar(total, total, barW), total)) + lines = append(lines, "") + lines = append(lines, fmt.Sprintf(" %-10s %s", + m.styles.label.Render("已用时"), + elapsed.Truncate(100*time.Millisecond))) + if estRemaining != "" { + lines = append(lines, fmt.Sprintf(" %-10s %s", + m.styles.label.Render("预计剩余"), + estRemaining)) + } + lines = append(lines, fmt.Sprintf(" %-10s %d 活跃", + m.styles.label.Render("并发槽"), + concurrency)) + return strings.Join(lines, "\n") +} + +func (m *Model) buildStdDashRight(h int) string { + p := m.progress + var lines []string + lines = append(lines, m.styles.sectionHead.Render("实时指标")) + lines = append(lines, "") + successRate := 0.0 + if p.CompletedCount > 0 { + successRate = float64(p.CompletedCount-p.FailedCount) / float64(p.CompletedCount) * 100 + } + srBar := progressBar(int(successRate), 100, 16) + lines = append(lines, fmt.Sprintf(" 成功率 %s %.1f%%", srBar, successRate)) + lines = append(lines, "") + avgTPS := 0.0 + if len(p.OutputTokenCounts) > 0 && len(p.TotalTimes) > 0 { + totalTokens := 0 + for _, tok := range p.OutputTokenCounts { + totalTokens += tok + } + totalTimeS := 0.0 + for _, d := range p.TotalTimes { + totalTimeS += d.Seconds() + } + if totalTimeS > 0 { + avgTPS = float64(totalTokens) / totalTimeS + } } - if len(m.progress.CacheHitRates) > 0 { - stats = append(stats, fmt.Sprintf("最近缓存命中率: %.1f%%", m.progress.CacheHitRates[len(m.progress.CacheHitRates)-1]*100)) + avgTTFT := time.Duration(0) + if len(p.TTFTs) > 0 { + sum := time.Duration(0) + for _, d := range p.TTFTs { + sum += d + } + avgTTFT = sum / time.Duration(len(p.TTFTs)) } - return strings.Join([]string{ - m.styles.title.Render(title), - m.styles.panel.Render(strings.Join(stats, "\n")), - m.footer("[s] 停止"), - }, "\n") + avgTotal := time.Duration(0) + if len(p.TotalTimes) > 0 { + sum := time.Duration(0) + for _, d := range p.TotalTimes { + sum += d + } + avgTotal = sum / time.Duration(len(p.TotalTimes)) + } + avgCache := 0.0 + if len(p.CacheHitRates) > 0 { + sum := 0.0 + for _, r := range p.CacheHitRates { + sum += r + } + avgCache = sum / float64(len(p.CacheHitRates)) * 100 + } + rows := [][2]string{ + {"avg TPS", fmt.Sprintf("%.1f tok/s", avgTPS)}, + {"avg TTFT", avgTTFT.Truncate(time.Millisecond).String()}, + {"缓存命中率", fmt.Sprintf("%.1f%%", avgCache)}, + {"avg 总耗时", avgTotal.Truncate(time.Millisecond).String()}, + } + for _, row := range rows { + lines = append(lines, fmt.Sprintf(" %-12s %s", + m.styles.label.Render(row[0]), + m.styles.metricVal.Render(row[1]))) + } + return strings.Join(lines, "\n") +} + +func (m *Model) buildTurboDashLeft(h int) string { + var lines []string + lines = append(lines, m.styles.sectionHead.Render("Turbo 探测中")) + lines = append(lines, "") + elapsed := time.Since(m.runStartedAt) + lines = append(lines, fmt.Sprintf(" %s %s", + m.styles.label.Render("已用时"), + elapsed.Truncate(time.Second))) + lines = append(lines, "") + lines = append(lines, m.styles.muted.Render(" 正在逐级探测最大稳定并发...")) + lines = append(lines, m.styles.muted.Render(" 完成后将自动显示结果")) + return strings.Join(lines, "\n") +} + +func (m *Model) buildTurboDashRight(h int) string { + var lines []string + lines = append(lines, m.styles.sectionHead.Render("探测状态")) + lines = append(lines, "") + lines = append(lines, " "+m.styles.ok.Render("●")+" 测试运行中") + lines = append(lines, m.styles.muted.Render(" 等待完成...")) + return strings.Join(lines, "\n") +} + +func (m *Model) buildLogPanel(maxH, width int) string { + var lines []string + lines = append(lines, m.styles.sectionHead.Render("请求日志")) + if len(m.requestLog) == 0 { + lines = append(lines, m.styles.muted.Render(" 等待请求完成...")) + return strings.Join(lines, "\n") + } + start := 0 + if len(m.requestLog) > maxH-1 { + start = len(m.requestLog) - (maxH - 1) + } + for _, entry := range m.requestLog[start:] { + // Color log entries based on their leading status marker + if strings.HasPrefix(entry, "✓") || strings.HasPrefix(entry, "✔") { + lines = append(lines, " "+m.styles.logOk.Render(entry)) + } else if strings.HasPrefix(entry, "✗") || strings.HasPrefix(entry, "✘") || strings.HasPrefix(entry, "ERR") { + lines = append(lines, " "+m.styles.logErr.Render(entry)) + } else { + lines = append(lines, " "+m.styles.muted.Render(entry)) + } + } + return strings.Join(lines, "\n") } func (m *Model) renderResult() string { - if m.runResult == nil { - return m.styles.error.Render("结果为空") + if m.width == 0 { + return "加载中..." } - result := m.runResult - lines := []string{ - fmt.Sprintf("协议: %s", result.Protocol), - fmt.Sprintf("接口: %s", result.EndpointURL), - fmt.Sprintf("成功率: %.1f%%", result.SuccessRate), - fmt.Sprintf("平均 TTFT: %s", result.AvgTTFT), - fmt.Sprintf("平均 TPS: %.2f", result.AvgTPS), - fmt.Sprintf("缓存命中率: %.1f%%", result.AvgCacheHitRate*100), - fmt.Sprintf("平均总耗时: %s", result.AvgTotalTime), + header := m.renderHeader("AIT 测试完成", "标准模式结果") + footer := m.renderFooter("[b/Esc] 返回详情") + if m.runResult == nil { + return lipgloss.JoinVertical(lipgloss.Left, header, + m.styles.errStyle.Render("结果为空"), footer) } - return strings.Join([]string{ - m.styles.title.Render("AIT 标准模式结果"), - m.styles.panel.Render(strings.Join(lines, "\n")), - m.footer("[b] 返回详情"), - }, "\n") + r := m.runResult + panelW := m.width - 4 + panelH := m.height - 4 + var lines []string + lines = append(lines, m.styles.sectionHead.Render(fmt.Sprintf("任务完成 — %s", r.Model))) + lines = append(lines, "") + rows := [][2]string{ + {"协议", r.Protocol}, + {"接口地址", truncate(r.EndpointURL, panelW-16)}, + {"模型", r.Model}, + {"成功率", fmt.Sprintf("%.1f%%", r.SuccessRate)}, + {"总请求数", fmt.Sprintf("%d", r.TotalRequests)}, + {"avg TTFT", r.AvgTTFT.Truncate(time.Millisecond).String()}, + {"avg TPS", fmt.Sprintf("%.2f tok/s", r.AvgTPS)}, + {"缓存命中率", fmt.Sprintf("%.1f%%", r.AvgCacheHitRate*100)}, + {"avg 总耗时", r.AvgTotalTime.Truncate(time.Millisecond).String()}, + {"总测试时长", r.TotalTime.Truncate(time.Second).String()}, + } + for _, row := range rows { + lines = append(lines, fmt.Sprintf(" %s %s", + m.styles.label.Render(fmt.Sprintf("%-12s", row[0])), + m.styles.value.Render(row[1]))) + } + content := strings.Join(lines, "\n") + panel := lipgloss.NewStyle(). + Width(panelW).Height(panelH). + Border(lipgloss.RoundedBorder()). + BorderForeground(colorPurple). + Render(content) + return lipgloss.JoinVertical(lipgloss.Left, header, panel, footer) } func (m *Model) renderTurboResult() string { - if m.turboResult == nil { - return m.styles.error.Render("Turbo 结果为空") + if m.width == 0 { + return "加载中..." } - lines := []string{ - fmt.Sprintf("协议: %s", m.turboResult.Protocol), - fmt.Sprintf("接口: %s", m.turboResult.EndpointURL), - fmt.Sprintf("最大稳定并发: %d", m.turboResult.MaxStableConcurrency), - fmt.Sprintf("峰值平均 TPS: %.2f", m.turboResult.PeakTPS), - fmt.Sprintf("停止原因: %s", m.turboResult.StopReason), + header := m.renderHeader("AIT Turbo 完成", "Turbo 模式结果") + footer := m.renderFooter("[b/Esc] 返回详情") + if m.turboResult == nil { + return lipgloss.JoinVertical(lipgloss.Left, header, + m.styles.errStyle.Render("Turbo 结果为空"), footer) } - for _, level := range m.turboResult.Levels { - status := "✓" + r := m.turboResult + panelW := m.width - 4 + panelH := m.height - 4 + var lines []string + lines = append(lines, m.styles.sectionHead.Render(fmt.Sprintf( + "Turbo 完成 — %s 最大稳定并发: %d 峰值 TPS: %.1f", + r.Model, r.MaxStableConcurrency, r.PeakTPS))) + lines = append(lines, "") + lines = append(lines, m.styles.tableHead.Render(fmt.Sprintf( + " %-6s %-8s %-10s %-10s %-8s %-8s %s", + "并发", "成功率", "TPS", "TTFT", "Cache", "总耗时", "状态"))) + lines = append(lines, m.styles.muted.Render(strings.Repeat("─", panelW-4))) + for _, level := range r.Levels { + status := m.styles.ok.Render("✓ 稳定") if !level.Stable { - status = "✗" + status = m.styles.errStyle.Render("✗ 不稳定") + } + marker := " " + if level.Concurrency == r.MaxStableConcurrency { + marker = m.styles.cursor.Render("▶ ") } - lines = append(lines, fmt.Sprintf("%s 并发 %d 成功率 %.1f%% avgTPS %.2f cache %.1f%%", status, level.Concurrency, level.SuccessRate*100, level.AvgTPS, level.CacheHitRate*100)) + lines = append(lines, fmt.Sprintf("%s%-6d %-8.1f%% %-10.1f %-10s %-8.1f%% %-8s %s", + marker, + level.Concurrency, + level.SuccessRate*100, + level.AvgTPS, + level.AvgTTFT.Truncate(time.Millisecond), + level.CacheHitRate*100, + level.AvgTotalTime.Truncate(time.Millisecond), + status)) + } + lines = append(lines, "") + lines = append(lines, m.styles.muted.Render(" 停止原因: "+r.StopReason)) + content := strings.Join(lines, "\n") + panel := lipgloss.NewStyle(). + Width(panelW).Height(panelH). + Border(lipgloss.RoundedBorder()). + BorderForeground(colorPurple). + Render(content) + return lipgloss.JoinVertical(lipgloss.Left, header, panel, footer) +} + +func (m *Model) renderHeader(left, right string) string { + if m.width == 0 { + return "" + } + // Each part gets the same header background so the bar spans the full width. + leftStyled := lipgloss.NewStyle(). + Background(colorHeaderBg).Bold(true).Foreground(colorPink). + Render(" ◆ " + left) + rightStyled := lipgloss.NewStyle(). + Background(colorHeaderBg).Foreground(colorHeaderFg). + Render(right + " ") + lw := lipgloss.Width(leftStyled) + rw := lipgloss.Width(rightStyled) + gap := m.width - lw - rw + if gap < 0 { + gap = 0 + } + spacer := lipgloss.NewStyle().Background(colorHeaderBg).Render(strings.Repeat(" ", gap)) + return leftStyled + spacer + rightStyled +} + +func (m *Model) renderFooter(hints ...string) string { + if m.width == 0 { + return "" + } + // Left: colored AIT brand badge + leftBadge := lipgloss.NewStyle(). + Background(colorPurple).Foreground(colorWhite).Bold(true). + Render(" ◆ AIT ") + // Right: dim version badge + rightBadge := lipgloss.NewStyle(). + Background(colorHeaderBg).Foreground(colorHeaderFg). + Render(" v0.1 ") + // Middle: key hints in pink on footer bg + var parts []string + for _, h := range hints { + parts = append(parts, lipgloss.NewStyle().Foreground(colorPink).Render(h)) + } + hintsStr := " " + strings.Join(parts, " ") + lw := lipgloss.Width(leftBadge) + rw := lipgloss.Width(rightBadge) + hw := lipgloss.Width(hintsStr) + gap := m.width - lw - rw - hw + if gap < 0 { + gap = 0 + } + middle := lipgloss.NewStyle(). + Background(colorFooterBg).Foreground(colorMuted). + Render(hintsStr + strings.Repeat(" ", gap)) + return leftBadge + middle + rightBadge +} + +func (m *Model) dualColumnLayout(leftContent, rightContent string, leftW, rightW, h int) string { + bc := colorPurple + leftPane := lipgloss.NewStyle(). + Width(leftW).Height(h). + Border(lipgloss.RoundedBorder()). + BorderForeground(bc). + Render(leftContent) + rightPane := lipgloss.NewStyle(). + Width(rightW).Height(h). + Border(lipgloss.RoundedBorder()). + BorderForeground(bc). + Render(rightContent) + return lipgloss.JoinHorizontal(lipgloss.Top, leftPane, rightPane) +} + +func progressBar(current, total, width int) string { + if total <= 0 || width <= 0 { + return lipgloss.NewStyle().Foreground(lipgloss.Color("238")).Render(strings.Repeat("░", width)) + } + filled := current * width / total + if filled > width { + filled = width + } + bar := lipgloss.NewStyle().Foreground(colorGreen).Render(strings.Repeat("█", filled)) + empty := lipgloss.NewStyle().Foreground(lipgloss.Color("238")).Render(strings.Repeat("░", width-filled)) + return bar + empty +} + +// progressBarRed renders a red-tinted progress bar for failure/error metrics. +func progressBarRed(current, total, width int) string { + if total <= 0 || width <= 0 { + return lipgloss.NewStyle().Foreground(lipgloss.Color("238")).Render(strings.Repeat("░", width)) + } + filled := current * width / total + if filled > width { + filled = width + } + bar := lipgloss.NewStyle().Foreground(colorRed).Render(strings.Repeat("█", filled)) + empty := lipgloss.NewStyle().Foreground(lipgloss.Color("238")).Render(strings.Repeat("░", width-filled)) + return bar + empty +} + +func truncate(s string, n int) string { + if n <= 0 || len(s) <= n { + return s } - return strings.Join([]string{ - m.styles.title.Render("AIT Turbo 结果"), - m.styles.panel.Render(strings.Join(lines, "\n")), - m.footer("[b] 返回详情"), - }, "\n") + if n <= 3 { + return s[:n] + } + return s[:n-3] + "..." +} + +func timeAgo(t time.Time) string { + d := time.Since(t) + if d < time.Minute { + return fmt.Sprintf("%ds 前", int(d.Seconds())) + } + if d < time.Hour { + return fmt.Sprintf("%dm 前", int(d.Minutes())) + } + if d < 24*time.Hour { + return fmt.Sprintf("%dh 前", int(d.Hours())) + } + return t.Format("01-02 15:04") +} + +func shortProtocol(p string) string { + p = strings.ReplaceAll(p, "openai-", "") + p = strings.ReplaceAll(p, "anthropic-", "") + return p } -func (m *Model) footer(parts ...string) string { - styled := make([]string, 0, len(parts)) - for _, part := range parts { - styled = append(styled, m.styles.key.Render(part+" ")) +func maskAPIKey(key string) string { + if len(key) == 0 { + return "(空)" + } + if len(key) <= 8 { + return strings.Repeat("•", len(key)) } - return lipgloss.JoinHorizontal(lipgloss.Left, styled...) + return key[:4] + strings.Repeat("•", len(key)-8) + key[len(key)-4:] } func (m *Model) currentTask() (types.TaskDefinition, bool) { @@ -626,39 +1436,74 @@ func protocolIndex(protocol string) int { return 0 } -func (m *Model) wizardFields() []wizardField { - fields := []wizardField{ - {key: "name", label: "任务名称", kind: fieldText}, - {key: "protocol", label: "协议类型", kind: fieldSelect}, - {key: "endpoint", label: "完整接口地址", kind: fieldText}, - {key: "apiKey", label: "API 密钥", kind: fieldText}, - {key: "model", label: "测试模型", kind: fieldText}, - {key: "mode", label: "运行模式", kind: fieldSelect}, - } - if m.wizard.mode == modeTurbo { - fields = append(fields, - wizardField{key: "turbo_init", label: "初始并发", kind: fieldText}, - wizardField{key: "turbo_max", label: "最大并发", kind: fieldText}, - wizardField{key: "turbo_step", label: "步进值", kind: fieldText}, - wizardField{key: "turbo_level_requests", label: "每级请求数", kind: fieldText}, - wizardField{key: "turbo_min_success", label: "最小成功率", kind: fieldText}, - wizardField{key: "turbo_max_latency", label: "最大平均延迟", kind: fieldText}, - ) - } else { +func (m *Model) wizardStepFields(step int) []wizardField { + switch step { + case 0: + return []wizardField{ + {key: "name", label: "任务名称", kind: fieldText}, + {key: "protocol", label: "协议类型", kind: fieldSelect}, + {key: "endpoint", label: "完整接口地址", kind: fieldText}, + {key: "apiKey", label: "API 密钥", kind: fieldText}, + {key: "model", label: "测试模型", kind: fieldText}, + } + case 1: + fields := []wizardField{ + {key: "mode", label: "运行模式", kind: fieldSelect}, + } + if m.wizard.mode == modeTurbo { + fields = append(fields, + wizardField{key: "turbo_init", label: "初始并发", kind: fieldText}, + wizardField{key: "turbo_max", label: "最大并发", kind: fieldText}, + wizardField{key: "turbo_step", label: "步进值", kind: fieldText}, + wizardField{key: "turbo_level_requests", label: "每级请求数", kind: fieldText}, + wizardField{key: "turbo_min_success", label: "最小成功率", kind: fieldText}, + wizardField{key: "turbo_max_latency", label: "最大平均延迟", kind: fieldText}, + ) + } else { + fields = append(fields, + wizardField{key: "concurrency", label: "并发数", kind: fieldText}, + wizardField{key: "count", label: "请求总数", kind: fieldText}, + wizardField{key: "timeout", label: "超时时间", kind: fieldText}, + ) + } fields = append(fields, - wizardField{key: "concurrency", label: "并发数", kind: fieldText}, - wizardField{key: "count", label: "请求总数", kind: fieldText}, - wizardField{key: "timeout", label: "超时时间", kind: fieldText}, + wizardField{key: "stream", label: "流式模式", kind: fieldToggle}, + wizardField{key: "thinking", label: "Thinking 模式", kind: fieldToggle}, + wizardField{key: "report", label: "生成报告", kind: fieldToggle}, + wizardField{key: "prompt_mode", label: "Prompt 方式", kind: fieldSelect}, + wizardField{key: "prompt_value", label: promptValueLabel(m.wizard.promptModeIndex), kind: fieldText}, ) + return fields + default: + return nil } - fields = append(fields, - wizardField{key: "stream", label: "流式模式", kind: fieldToggle}, - wizardField{key: "thinking", label: "Thinking 模式", kind: fieldToggle}, - wizardField{key: "report", label: "生成报告", kind: fieldToggle}, - wizardField{key: "prompt_mode", label: "Prompt 输入方式", kind: fieldSelect}, - wizardField{key: "prompt_value", label: promptValueLabel(m.wizard.promptModeIndex), kind: fieldText}, - ) - return fields +} + +func (m *Model) advanceWizardField(delta int) { + if m.wizard == nil { + return + } + fields := m.wizardStepFields(m.wizard.step) + next := m.wizard.fieldIndex + delta + if next < 0 { + if m.wizard.step > 0 { + m.wizard.step-- + prevFields := m.wizardStepFields(m.wizard.step) + m.wizard.fieldIndex = len(prevFields) - 1 + m.refreshWizardInput() + } + return + } + if next >= len(fields) { + m.wizard.step++ + m.wizard.fieldIndex = 0 + if m.wizard.step < 2 { + m.refreshWizardInput() + } + return + } + m.wizard.fieldIndex = next + m.refreshWizardInput() } func promptValueLabel(promptModeIndex int) string { @@ -673,7 +1518,14 @@ func promptValueLabel(promptModeIndex int) string { } func (m *Model) currentWizardField() wizardField { - return m.wizardFields()[m.wizard.current] + if m.wizard == nil { + return wizardField{} + } + fields := m.wizardStepFields(m.wizard.step) + if len(fields) == 0 || m.wizard.fieldIndex >= len(fields) { + return wizardField{} + } + return fields[m.wizard.fieldIndex] } func (m *Model) refreshWizardInput() { @@ -696,6 +1548,9 @@ func (m *Model) refreshWizardInput() { } func (m *Model) cycleWizardField(delta int) { + if m.wizard == nil { + return + } field := m.currentWizardField() switch field.key { case "protocol": @@ -718,7 +1573,11 @@ func (m *Model) cycleWizardField(delta int) { default: return } - m.wizard.current = min(m.wizard.current, len(m.wizardFields())-1) + // Clamp fieldIndex in case field count changed (e.g. mode switch) + fields := m.wizardStepFields(m.wizard.step) + if m.wizard.fieldIndex >= len(fields) && len(fields) > 0 { + m.wizard.fieldIndex = len(fields) - 1 + } m.refreshWizardInput() } @@ -758,13 +1617,6 @@ func wrapIndex(index, length int) int { return index % length } -func min(a, b int) int { - if a < b { - return a - } - return b -} - func buildTaskDefinition(state *wizardState) (types.TaskDefinition, error) { protocol := protocolOptions[state.protocolIndex] input := types.Input{ @@ -887,9 +1739,10 @@ func (m *Model) saveWizard() error { break } } + m.reloadHistoryForSelectedTask() m.status = "任务已保存" m.wizard = nil - m.view = viewTaskList + m.view = viewTaskDetail return nil } diff --git a/internal/tui/styles.go b/internal/tui/styles.go index 9a69807..f0f61e0 100644 --- a/internal/tui/styles.go +++ b/internal/tui/styles.go @@ -2,31 +2,183 @@ package tui import "github.com/charmbracelet/lipgloss" +// Color palette — inspired by the Lip Gloss demo aesthetic: +// electric-purple header, vivid hot-pink brand, deep-plum panels, aqua accents. +const ( + colorHeaderBg = lipgloss.Color("57") // electric indigo — header background + colorFooterBg = lipgloss.Color("235") // near-black footer background + colorPink = lipgloss.Color("205") // vivid hot pink/magenta — brand primary + colorCyan = lipgloss.Color("86") // bright aquamarine — table headers + colorPurple = lipgloss.Color("99") // medium violet — badge, border + colorPurpleDim = lipgloss.Color("60") // slate purple — selected row bg + colorPanelBg = lipgloss.Color("55") // deep plum — content card / panel bg + colorGold = lipgloss.Color("214") // amber — status badge + colorGreen = lipgloss.Color("78") // vivid spring green — ok/success + colorRed = lipgloss.Color("204") // vivid rose-red — error/fail + colorYellow = lipgloss.Color("221") // warm yellow — metric values + colorTeal = lipgloss.Color("111") // periwinkle-teal — labels + colorWhite = lipgloss.Color("255") // bright white + colorMuted = lipgloss.Color("245") // muted gray + colorDimBorder = lipgloss.Color("238") // dim border gray + colorFieldBg = lipgloss.Color("55") // deep plum — active field bg + colorDark = lipgloss.Color("235") // near-black text on light badge + colorHeaderFg = lipgloss.Color("212") // light pink — header right text +) + type styles struct { - title lipgloss.Style - subtitle lipgloss.Style - panel lipgloss.Style - selected lipgloss.Style - muted lipgloss.Style - error lipgloss.Style - ok lipgloss.Style - key lipgloss.Style - label lipgloss.Style - value lipgloss.Style + header lipgloss.Style + footer lipgloss.Style + sectionHead lipgloss.Style + tableHead lipgloss.Style + tableRow lipgloss.Style + tableRowSel lipgloss.Style + label lipgloss.Style + value lipgloss.Style + muted lipgloss.Style + ok lipgloss.Style + errStyle lipgloss.Style + key lipgloss.Style + metricVal lipgloss.Style + dialog lipgloss.Style + fieldActive lipgloss.Style + fieldIdle lipgloss.Style + cursor lipgloss.Style + // Badge styles + badge lipgloss.Style // AIT brand badge (purple) + badgeAlt lipgloss.Style // alternate badge (gold) + tagTurbo lipgloss.Style // "TURBO" mode inline tag + tagStd lipgloss.Style // "标准" mode inline tag + // Log entry markers + logOk lipgloss.Style + logErr lipgloss.Style + // Wizard step indicators + stepDone lipgloss.Style + stepActive lipgloss.Style + stepTodo lipgloss.Style + // Primary action button + btnPrimary lipgloss.Style + // Divider / decorative line + divider lipgloss.Style + // Content panel (deep-plum background card, like the demo's purple paragraphs) + panel lipgloss.Style } func newStyles() styles { - border := lipgloss.RoundedBorder() return styles{ - title: lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("212")).Padding(0, 1), - subtitle: lipgloss.NewStyle().Foreground(lipgloss.Color("110")).Padding(0, 1), - panel: lipgloss.NewStyle().Border(border).BorderForeground(lipgloss.Color("62")).Padding(1, 2), - selected: lipgloss.NewStyle().Foreground(lipgloss.Color("230")).Background(lipgloss.Color("62")).Bold(true), - muted: lipgloss.NewStyle().Foreground(lipgloss.Color("245")), - error: lipgloss.NewStyle().Foreground(lipgloss.Color("203")).Bold(true), - ok: lipgloss.NewStyle().Foreground(lipgloss.Color("85")).Bold(true), - key: lipgloss.NewStyle().Foreground(lipgloss.Color("214")).Bold(true), - label: lipgloss.NewStyle().Foreground(lipgloss.Color("151")).Bold(true), - value: lipgloss.NewStyle().Foreground(lipgloss.Color("255")), + // Header: deep indigo-purple background, white foreground + header: lipgloss.NewStyle(). + Background(colorHeaderBg). + Foreground(colorWhite), + // Footer: dark near-black background + footer: lipgloss.NewStyle(). + Background(colorFooterBg). + Foreground(colorMuted), + // Section headings: hot pink, bold + sectionHead: lipgloss.NewStyle(). + Foreground(colorPink). + Bold(true), + // Table column headers: vivid cyan, bold + tableHead: lipgloss.NewStyle(). + Foreground(colorCyan). + Bold(true), + // Normal table rows: bright white + tableRow: lipgloss.NewStyle(). + Foreground(colorWhite), + // Selected table row: dim-purple bg, white text, bold + tableRowSel: lipgloss.NewStyle(). + Background(colorPurpleDim). + Foreground(colorWhite). + Bold(true), + // Property labels: soft teal, bold + label: lipgloss.NewStyle(). + Foreground(colorTeal). + Bold(true), + // Property values: bright white + value: lipgloss.NewStyle(). + Foreground(colorWhite), + // Secondary/muted text: gray + muted: lipgloss.NewStyle(). + Foreground(colorMuted), + // Success indicator: bright green, bold + ok: lipgloss.NewStyle(). + Foreground(colorGreen). + Bold(true), + // Error indicator: red, bold + errStyle: lipgloss.NewStyle(). + Foreground(colorRed). + Bold(true), + // Keyboard shortcut keys: hot pink, bold + key: lipgloss.NewStyle(). + Foreground(colorPink). + Bold(true), + // Metric numeric values: yellow, bold + metricVal: lipgloss.NewStyle(). + Foreground(colorYellow). + Bold(true), + // Dialog/modal box: rounded border in hot pink, padded + dialog: lipgloss.NewStyle(). + Border(lipgloss.RoundedBorder()). + BorderForeground(colorPink). + Padding(1, 2), + // Active wizard field: dark purple-blue bg, white text + fieldActive: lipgloss.NewStyle(). + Background(colorFieldBg). + Foreground(colorWhite), + // Idle wizard field: muted gray text + fieldIdle: lipgloss.NewStyle(). + Foreground(colorMuted), + // Cursor/selection arrow: hot pink, bold + cursor: lipgloss.NewStyle(). + Foreground(colorPink). + Bold(true), + // AIT brand badge: purple bg, white text, padded + badge: lipgloss.NewStyle(). + Background(colorPurple). + Foreground(colorWhite). + Bold(true). + Padding(0, 1), + // Alternate badge: gold bg, dark text + badgeAlt: lipgloss.NewStyle(). + Background(colorGold). + Foreground(colorDark). + Bold(true). + Padding(0, 1), + // TURBO mode tag: pink bg, dark text + tagTurbo: lipgloss.NewStyle(). + Foreground(colorPink). + Bold(true), + // Standard mode tag: cyan text + tagStd: lipgloss.NewStyle(). + Foreground(colorCyan), + // Log ok: green + logOk: lipgloss.NewStyle(). + Foreground(colorGreen), + // Log error: red + logErr: lipgloss.NewStyle(). + Foreground(colorRed), + // Wizard step done: green checkmark + stepDone: lipgloss.NewStyle(). + Foreground(colorGreen), + // Wizard step active: pink, bold + stepActive: lipgloss.NewStyle(). + Foreground(colorPink). + Bold(true), + // Wizard step todo: dim + stepTodo: lipgloss.NewStyle(). + Foreground(colorMuted), + // Primary action button: hot-pink bg, dark text, padded + btnPrimary: lipgloss.NewStyle(). + Background(colorPink). + Foreground(colorDark). + Bold(true). + Padding(0, 2), + // Divider line: dim gray + divider: lipgloss.NewStyle(). + Foreground(colorDimBorder), + // Content panel: deep-plum background, white text, padded (like demo's purple paragraphs) + panel: lipgloss.NewStyle(). + Background(colorPanelBg). + Foreground(colorWhite). + Padding(1, 2), } } From 3ed37aedfe5dbde4bad6bdd5d0dbe654df53eba0 Mon Sep 17 00:00:00 2001 From: Alain Date: Sun, 17 May 2026 00:23:17 +0800 Subject: [PATCH 04/52] =?UTF-8?q?feat:=20=E5=AE=9E=E7=8E=B0=E4=BA=8B?= =?UTF-8?q?=E4=BB=B6=E6=80=BB=E7=BA=BF=E5=92=8C=E4=BB=BB=E5=8A=A1=E7=AE=A1?= =?UTF-8?q?=E7=90=86=E5=8A=9F=E8=83=BD=EF=BC=8C=E6=94=AF=E6=8C=81=E4=BB=BB?= =?UTF-8?q?=E5=8A=A1=E7=9A=84=E5=88=9B=E5=BB=BA=E3=80=81=E6=9B=B4=E6=96=B0?= =?UTF-8?q?=E3=80=81=E5=88=A0=E9=99=A4=E5=8F=8A=E8=BF=90=E8=A1=8C=E7=8A=B6?= =?UTF-8?q?=E6=80=81=E7=AE=A1=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/server/event.go | 76 +++++++ internal/server/run.go | 424 ++++++++++++++++++++++++++++++++++++++ internal/server/server.go | 82 ++++++++ internal/server/task.go | 103 +++++++++ internal/server/types.go | 123 +++++++++++ 5 files changed, 808 insertions(+) create mode 100644 internal/server/event.go create mode 100644 internal/server/run.go create mode 100644 internal/server/server.go create mode 100644 internal/server/task.go create mode 100644 internal/server/types.go diff --git a/internal/server/event.go b/internal/server/event.go new file mode 100644 index 0000000..8376dce --- /dev/null +++ b/internal/server/event.go @@ -0,0 +1,76 @@ +package server + +import "sync" + +// subscriber 持有单个订阅者的带缓冲事件通道。 +type subscriber struct { + ch chan Event +} + +// eventBus 按 RunID 分组管理订阅者,负责事件的发布与通道生命周期管理。 +type eventBus struct { + mu sync.Mutex + subscribers map[RunID][]*subscriber +} + +func newEventBus() *eventBus { + return &eventBus{ + subscribers: make(map[RunID][]*subscriber), + } +} + +// Subscribe 注册对指定 RunID 的订阅,返回只读事件通道和取消函数。 +// 取消函数调用后通道被关闭,range 循环自然退出。 +// 通道容量为 64;若消费者处理过慢,后续事件将被丢弃(非阻塞发布)。 +func (b *eventBus) Subscribe(runID RunID) (<-chan Event, CancelFunc) { + b.mu.Lock() + defer b.mu.Unlock() + + sub := &subscriber{ch: make(chan Event, 64)} + b.subscribers[runID] = append(b.subscribers[runID], sub) + + cancel := func() { + b.mu.Lock() + defer b.mu.Unlock() + b.removeLocked(runID, sub) + } + return sub.ch, cancel +} + +// Publish 向该 RunID 的所有订阅者非阻塞地投递事件。 +func (b *eventBus) Publish(event Event) { + b.mu.Lock() + defer b.mu.Unlock() + + for _, sub := range b.subscribers[event.RunID] { + select { + case sub.ch <- event: + default: + // 消费者过慢时丢弃,避免阻塞发布方 + } + } +} + +// CloseRun 关闭该 RunID 下所有订阅通道并清理条目。 +// 必须在该 RunID 的最后一个 Publish 调用之后执行,以确保不丢失末尾事件。 +func (b *eventBus) CloseRun(runID RunID) { + b.mu.Lock() + defer b.mu.Unlock() + + for _, sub := range b.subscribers[runID] { + close(sub.ch) + } + delete(b.subscribers, runID) +} + +// removeLocked 从订阅列表中移除 sub 并关闭其通道(已持锁时调用)。 +func (b *eventBus) removeLocked(runID RunID, sub *subscriber) { + subs := b.subscribers[runID] + for i, s := range subs { + if s == sub { + b.subscribers[runID] = append(subs[:i], subs[i+1:]...) + close(sub.ch) + return + } + } +} diff --git a/internal/server/run.go b/internal/server/run.go new file mode 100644 index 0000000..e28c809 --- /dev/null +++ b/internal/server/run.go @@ -0,0 +1,424 @@ +package server + +import ( + "fmt" + "path/filepath" + "sync" + "time" + + "github.com/yinxulai/ait/internal/client" + "github.com/yinxulai/ait/internal/report" + "github.com/yinxulai/ait/internal/runner" + "github.com/yinxulai/ait/internal/store" + "github.com/yinxulai/ait/internal/task" + "github.com/yinxulai/ait/internal/turbo" + "github.com/yinxulai/ait/internal/types" +) + +// activeRun 持有一次正在执行的运行的全部运行时状态。 +type activeRun struct { + mu sync.RWMutex + state *RunState + rnr *runner.Runner // standard 模式使用 + turboEngine *turbo.Engine // turbo 模式使用 + // 用于计算实时均值 + tpsSum float64 + ttftSum time.Duration + cacheSum float64 + doneCount int // 与 state.DoneReqs 保持同步,方便不加锁时计算 +} + +// snapshotState 返回 state 的深度拷贝(调用方须已持有 activeRun.mu 读锁)。 +func (ar *activeRun) snapshotState() *RunState { + s := ar.state + snap := *s + // 深拷贝切片 + if len(s.Requests) > 0 { + snap.Requests = make([]*RequestMetrics, len(s.Requests)) + copy(snap.Requests, s.Requests) + } + if len(s.Levels) > 0 { + snap.Levels = make([]types.TurboLevelResult, len(s.Levels)) + copy(snap.Levels, s.Levels) + } + return &snap +} + +// mapRequestMetrics 将 client.ResponseMetrics 映射到 server.RequestMetrics。 +func mapRequestMetrics(m *client.ResponseMetrics, idx int, err error) *RequestMetrics { + rm := &RequestMetrics{Index: idx} + if m == nil { + rm.Success = false + if err != nil { + rm.ErrorMessage = err.Error() + } + return rm + } + + rm.Success = m.ErrorMessage == "" && err == nil + rm.TotalTime = m.TotalTime + rm.TTFT = m.TimeToFirstToken + rm.PromptTokens = m.PromptTokens + rm.CompletionTokens = m.CompletionTokens + rm.CachedTokens = m.CachedInputTokens + rm.DNSTime = m.DNSTime + rm.ConnectTime = m.ConnectTime + rm.TLSTime = m.TLSHandshakeTime + rm.TargetIP = m.TargetIP + rm.ErrorMessage = m.ErrorMessage + if err != nil && rm.ErrorMessage == "" { + rm.ErrorMessage = err.Error() + } + + if m.TotalTime > 0 && m.CompletionTokens > 0 { + rm.TPS = float64(m.CompletionTokens) / m.TotalTime.Seconds() + } + if m.PromptTokens > 0 { + rm.CacheHitRate = float64(m.CachedInputTokens) / float64(m.PromptTokens) + } + return rm +} + +// historyPath 返回指定任务的历史文件路径。 +func historyPath(historyDir, taskID string) string { + return filepath.Join(historyDir, taskID+".json") +} + +// StartRun 启动一次新的运行,立即返回 RunID。 +func (s *serverImpl) StartRun(taskID string) (RunID, error) { + s.mu.RLock() + taskDef, ok := s.taskStore.Get(taskID) + historyDir := s.historyDir + s.mu.RUnlock() + + if !ok { + return "", fmt.Errorf("task %q not found", taskID) + } + + // 解析 PromptSource(将 PromptText/PromptFile 转换为可调用的 PromptSource) + hydratedInput, err := task.HydrateInput(taskDef.Input) + if err != nil { + return "", fmt.Errorf("hydrate input: %w", err) + } + + runID := RunID(fmt.Sprintf("run_%d", time.Now().UnixNano())) + now := time.Now() + + mode := "standard" + if hydratedInput.Turbo { + mode = "turbo" + } + + state := &RunState{ + RunID: runID, + TaskID: taskID, + Status: RunStatusRunning, + Mode: mode, + StartedAt: now, + TotalReqs: hydratedInput.Count, + Requests: make([]*RequestMetrics, hydratedInput.Count), + } + + ar := &activeRun{state: state} + + s.mu.Lock() + s.activeRuns[runID] = ar + s.mu.Unlock() + + if hydratedInput.Turbo { + go s.runTurbo(ar, runID, taskDef, hydratedInput, historyDir) + } else { + go s.runStandard(ar, runID, taskDef, hydratedInput, historyDir) + } + + return runID, nil +} + +// runStandard 在 goroutine 中执行标准运行。 +func (s *serverImpl) runStandard(ar *activeRun, runID RunID, taskDef types.TaskDefinition, input types.Input, historyDir string) { + rnr, err := runner.NewRunner(taskDef.ID, input) + if err != nil { + s.failRun(ar, runID, taskDef, historyDir, err) + return + } + + ar.mu.Lock() + ar.rnr = rnr + ar.mu.Unlock() + + reportData, err := rnr.RunWithCallback(func(metrics *client.ResponseMetrics, idx int, cbErr error) { + rm := mapRequestMetrics(metrics, idx, cbErr) + + ar.mu.Lock() + if idx < len(ar.state.Requests) { + ar.state.Requests[idx] = rm + } + ar.state.DoneReqs++ + if rm.Success { + ar.state.SuccessReqs++ + ar.tpsSum += rm.TPS + ar.ttftSum += rm.TTFT + ar.cacheSum += rm.CacheHitRate + } else { + ar.state.FailedReqs++ + } + successCount := ar.state.SuccessReqs + done := ar.state.DoneReqs + // 更新实时均值 + if successCount > 0 { + ar.state.AvgTPS = ar.tpsSum / float64(successCount) + ar.state.AvgTTFT = ar.ttftSum / time.Duration(successCount) + ar.state.CacheHitRate = ar.cacheSum / float64(successCount) + } + if done > 0 { + ar.state.SuccessRate = float64(successCount) / float64(done) * 100 + } + ar.mu.Unlock() + + s.bus.Publish(Event{RunID: runID, Kind: EventRequestDone, Payload: rm}) + }) + + if err != nil { + s.failRun(ar, runID, taskDef, historyDir, err) + return + } + + s.completeStandardRun(ar, runID, taskDef, historyDir, reportData) +} + +// runTurbo 在 goroutine 中执行 Turbo 运行。 +func (s *serverImpl) runTurbo(ar *activeRun, runID RunID, taskDef types.TaskDefinition, input types.Input, historyDir string) { + engine := turbo.New(turbo.DefaultRunnerFactory(taskDef.ID)) + + ar.mu.Lock() + ar.turboEngine = engine + ar.mu.Unlock() + + turboResult, err := engine.Run(input) + if err != nil { + s.failRun(ar, runID, taskDef, historyDir, err) + return + } + + s.completeTurboRun(ar, runID, taskDef, historyDir, turboResult) +} + +// completeStandardRun 处理标准运行成功完成的后续工作。 +func (s *serverImpl) completeStandardRun(ar *activeRun, runID RunID, taskDef types.TaskDefinition, historyDir string, data *types.ReportData) { + finishedAt := time.Now() + + ar.mu.Lock() + ar.state.Status = RunStatusCompleted + ar.state.FinishedAt = &finishedAt + ar.state.StandardResult = data + if data != nil { + ar.state.AvgTPS = data.AvgTPS + ar.state.AvgTTFT = data.AvgTTFT + ar.state.SuccessRate = data.SuccessRate + ar.state.CacheHitRate = data.AvgCacheHitRate + } + snap := ar.snapshotState() + ar.mu.Unlock() + + s.bus.Publish(Event{RunID: runID, Kind: EventRunComplete, Payload: snap}) + s.bus.CloseRun(runID) + + summary := types.TaskRunSummary{ + RunID: string(runID), + TaskID: taskDef.ID, + Mode: "standard", + Status: string(RunStatusCompleted), + Protocol: taskDef.Input.NormalizedProtocol(), + Model: taskDef.Input.Model, + StartedAt: snap.StartedAt, + FinishedAt: finishedAt, + SuccessRate: snap.SuccessRate, + AvgTTFT: snap.AvgTTFT, + AvgTPS: snap.AvgTPS, + CacheHitRate: snap.CacheHitRate, + } + + s.persistRunResult(taskDef.ID, historyDir, summary) +} + +// completeTurboRun 处理 Turbo 运行成功完成的后续工作。 +func (s *serverImpl) completeTurboRun(ar *activeRun, runID RunID, taskDef types.TaskDefinition, historyDir string, result *types.TurboResult) { + finishedAt := time.Now() + + ar.mu.Lock() + ar.state.Status = RunStatusCompleted + ar.state.FinishedAt = &finishedAt + ar.state.TurboResult = result + if result != nil { + ar.state.Levels = result.Levels + ar.state.CurrentLevel = result.MaxStableConcurrency + } + snap := ar.snapshotState() + ar.mu.Unlock() + + s.bus.Publish(Event{RunID: runID, Kind: EventRunComplete, Payload: snap}) + s.bus.CloseRun(runID) + + var maxStable int + var peakTPS float64 + if result != nil { + maxStable = result.MaxStableConcurrency + peakTPS = result.PeakTPS + } + + summary := types.TaskRunSummary{ + RunID: string(runID), + TaskID: taskDef.ID, + Mode: "turbo", + Status: string(RunStatusCompleted), + Protocol: taskDef.Input.NormalizedProtocol(), + Model: taskDef.Input.Model, + StartedAt: snap.StartedAt, + FinishedAt: finishedAt, + MaxStableConcurrency: maxStable, + AvgTPS: peakTPS, + } + + s.persistRunResult(taskDef.ID, historyDir, summary) +} + +// failRun 处理运行失败的后续工作。 +func (s *serverImpl) failRun(ar *activeRun, runID RunID, taskDef types.TaskDefinition, historyDir string, runErr error) { + finishedAt := time.Now() + + ar.mu.Lock() + ar.state.Status = RunStatusFailed + ar.state.FinishedAt = &finishedAt + ar.state.ErrorMsg = runErr.Error() + snap := ar.snapshotState() + ar.mu.Unlock() + + s.bus.Publish(Event{RunID: runID, Kind: EventRunFailed, Payload: runErr}) + s.bus.CloseRun(runID) + + summary := types.TaskRunSummary{ + RunID: string(runID), + TaskID: taskDef.ID, + Mode: ar.state.Mode, + Status: string(RunStatusFailed), + Protocol: taskDef.Input.NormalizedProtocol(), + Model: taskDef.Input.Model, + StartedAt: snap.StartedAt, + FinishedAt: finishedAt, + ErrorSummary: runErr.Error(), + } + + s.persistRunResult(taskDef.ID, historyDir, summary) +} + +// persistRunResult 将运行摘要写入历史文件,并更新任务的 LastRunAt/LastRunSummary。 +func (s *serverImpl) persistRunResult(taskID, historyDir string, summary types.TaskRunSummary) { + // 写历史文件 + hs := store.NewHistoryStore(historyPath(historyDir, taskID)) + _ = hs.Append(summary) // 历史记录失败不影响主流程 + + // 更新任务的最后运行时间和摘要 + s.mu.Lock() + defer s.mu.Unlock() + existing, ok := s.taskStore.Get(taskID) + if ok { + existing.LastRunAt = &summary.FinishedAt + existing.LastRunSummary = &summary + s.taskStore.Upsert(existing) + _ = s.taskStore.Save() + } +} + +// StopRun 请求停止指定运行。 +func (s *serverImpl) StopRun(runID RunID) error { + s.mu.RLock() + ar, ok := s.activeRuns[runID] + s.mu.RUnlock() + + if !ok { + return fmt.Errorf("run %q not found or already finished", runID) + } + + ar.mu.RLock() + rnr := ar.rnr + eng := ar.turboEngine + ar.mu.RUnlock() + + if rnr != nil { + rnr.Stop() + } + if eng != nil { + eng.Stop() + } + return nil +} + +// GetRunState 返回指定运行的当前状态快照。 +func (s *serverImpl) GetRunState(runID RunID) (*RunState, bool) { + s.mu.RLock() + ar, ok := s.activeRuns[runID] + s.mu.RUnlock() + + if !ok { + return nil, false + } + + ar.mu.RLock() + snap := ar.snapshotState() + ar.mu.RUnlock() + return snap, true +} + +// Subscribe 订阅指定运行的事件流。 +func (s *serverImpl) Subscribe(runID RunID) (<-chan Event, CancelFunc) { + return s.bus.Subscribe(runID) +} + +// GetHistory 返回任务的历史运行摘要,最新在前。 +func (s *serverImpl) GetHistory(taskID string, limit int) ([]types.TaskRunSummary, error) { + s.mu.RLock() + historyDir := s.historyDir + s.mu.RUnlock() + + hs := store.NewHistoryStore(historyPath(historyDir, taskID)) + return hs.Load(limit) +} + +// GenerateReport 为已完成的标准运行生成报告文件。 +func (s *serverImpl) GenerateReport(runID RunID, format ReportFormat) (string, error) { + s.mu.RLock() + ar, ok := s.activeRuns[runID] + s.mu.RUnlock() + + if !ok { + return "", fmt.Errorf("run %q not found", runID) + } + + ar.mu.RLock() + status := ar.state.Status + mode := ar.state.Mode + standardResult := ar.state.StandardResult + ar.mu.RUnlock() + + if status == RunStatusRunning { + return "", fmt.Errorf("run %q is still in progress", runID) + } + + if mode == "turbo" { + return "", fmt.Errorf("report generation for turbo runs is not yet supported") + } + + if standardResult == nil { + return "", fmt.Errorf("no result data available for run %q", runID) + } + + rm := report.NewReportManager() + paths, err := rm.GenerateReports([]types.ReportData{*standardResult}, []string{string(format)}) + if err != nil { + return "", fmt.Errorf("generate report: %w", err) + } + if len(paths) == 0 { + return "", fmt.Errorf("no report file generated") + } + return paths[0], nil +} diff --git a/internal/server/server.go b/internal/server/server.go new file mode 100644 index 0000000..ddd0d83 --- /dev/null +++ b/internal/server/server.go @@ -0,0 +1,82 @@ +package server + +import ( + "sync" + + "github.com/yinxulai/ait/internal/config" + "github.com/yinxulai/ait/internal/store" + "github.com/yinxulai/ait/internal/types" +) + +// Server 是业务逻辑层的统一入口,TUI 层通过此接口与业务交互。 +// 所有方法均为线程安全。 +type Server interface { + // --- 任务 CRUD --- + ListTasks() []types.TaskDefinition + GetTask(id string) (types.TaskDefinition, bool) + CreateTask(cfg TaskConfig) (types.TaskDefinition, error) + UpdateTask(id string, cfg TaskConfig) (types.TaskDefinition, error) + DeleteTask(id string) error + CopyTask(id string) (types.TaskDefinition, error) + + // --- 运行管理 --- + + // StartRun 根据任务配置启动一次运行,立即返回 RunID。 + // 运行在后台 goroutine 中执行,进度通过 Subscribe 获取。 + StartRun(taskID string) (RunID, error) + + // StopRun 请求停止指定运行(软停止,等待当前批次完成)。 + StopRun(runID RunID) error + + // GetRunState 返回指定运行的当前状态快照(线程安全的深度拷贝)。 + GetRunState(runID RunID) (*RunState, bool) + + // Subscribe 订阅指定运行的事件流。返回只读通道和取消函数。 + // 通道在运行结束后自动关闭,调用方可 range 消费。 + Subscribe(runID RunID) (<-chan Event, CancelFunc) + + // GetHistory 返回任务的运行历史,最新在前。limit<=0 表示不限条数。 + GetHistory(taskID string, limit int) ([]types.TaskRunSummary, error) + + // GenerateReport 为已完成的运行生成报告文件,返回文件路径。 + GenerateReport(runID RunID, format ReportFormat) (string, error) +} + +// serverImpl 是 Server 的具体实现。 +type serverImpl struct { + mu sync.RWMutex + taskStore *store.TaskStore + bus *eventBus + activeRuns map[RunID]*activeRun + historyDir string +} + +// New 创建并初始化 Server 实例。 +// 会自动加载 ~/.ait/tasks.json;historyDir 用于存放每个任务的运行历史文件。 +func New() (Server, error) { + if _, err := config.EnsureAppDir(); err != nil { + return nil, err + } + + tasksPath, err := config.TasksPath() + if err != nil { + return nil, err + } + + historyDir, err := config.HistoryDir() + if err != nil { + return nil, err + } + + ts := store.NewTaskStore(tasksPath) + if err := ts.Load(); err != nil { + return nil, err + } + + return &serverImpl{ + taskStore: ts, + bus: newEventBus(), + activeRuns: make(map[RunID]*activeRun), + historyDir: historyDir, + }, nil +} diff --git a/internal/server/task.go b/internal/server/task.go new file mode 100644 index 0000000..3069bba --- /dev/null +++ b/internal/server/task.go @@ -0,0 +1,103 @@ +package server + +import ( + "fmt" + + "github.com/yinxulai/ait/internal/types" +) + +// ListTasks 返回所有任务(最近更新排在前面)。 +func (s *serverImpl) ListTasks() []types.TaskDefinition { + s.mu.RLock() + defer s.mu.RUnlock() + return s.taskStore.All() +} + +// GetTask 按 ID 查找任务。 +func (s *serverImpl) GetTask(id string) (types.TaskDefinition, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + return s.taskStore.Get(id) +} + +// CreateTask 新建任务并持久化。 +func (s *serverImpl) CreateTask(cfg TaskConfig) (types.TaskDefinition, error) { + s.mu.Lock() + defer s.mu.Unlock() + + task := types.TaskDefinition{ + Name: cfg.Name, + Input: cfg.Input, + } + s.taskStore.Upsert(task) + if err := s.taskStore.Save(); err != nil { + return types.TaskDefinition{}, fmt.Errorf("save tasks: %w", err) + } + + // 返回已生成 ID 和时间戳的最新状态 + all := s.taskStore.All() + if len(all) > 0 { + return all[0], nil + } + return task, nil +} + +// UpdateTask 更新指定任务,任务不存在时返回错误。 +func (s *serverImpl) UpdateTask(id string, cfg TaskConfig) (types.TaskDefinition, error) { + s.mu.Lock() + defer s.mu.Unlock() + + existing, ok := s.taskStore.Get(id) + if !ok { + return types.TaskDefinition{}, fmt.Errorf("task %q not found", id) + } + + existing.Name = cfg.Name + existing.Input = cfg.Input + s.taskStore.Upsert(existing) + + if err := s.taskStore.Save(); err != nil { + return types.TaskDefinition{}, fmt.Errorf("save tasks: %w", err) + } + + updated, _ := s.taskStore.Get(id) + return updated, nil +} + +// DeleteTask 删除指定任务,任务不存在时返回错误。 +func (s *serverImpl) DeleteTask(id string) error { + s.mu.Lock() + defer s.mu.Unlock() + + if err := s.taskStore.Delete(id); err != nil { + return err + } + return s.taskStore.Save() +} + +// CopyTask 复制指定任务(ID 和时间戳重置,名称加 " (copy)" 后缀)。 +func (s *serverImpl) CopyTask(id string) (types.TaskDefinition, error) { + s.mu.Lock() + defer s.mu.Unlock() + + src, ok := s.taskStore.Get(id) + if !ok { + return types.TaskDefinition{}, fmt.Errorf("task %q not found", id) + } + + copied := types.TaskDefinition{ + Name: src.Name + " (copy)", + Input: src.Input, + } + s.taskStore.Upsert(copied) + + if err := s.taskStore.Save(); err != nil { + return types.TaskDefinition{}, fmt.Errorf("save tasks: %w", err) + } + + all := s.taskStore.All() + if len(all) > 0 { + return all[0], nil + } + return copied, nil +} diff --git a/internal/server/types.go b/internal/server/types.go new file mode 100644 index 0000000..09564e8 --- /dev/null +++ b/internal/server/types.go @@ -0,0 +1,123 @@ +package server + +import ( + "time" + + "github.com/yinxulai/ait/internal/types" +) + +// RunID 唯一标识一次运行(全局唯一,格式 run_)。 +type RunID string + +// CancelFunc 取消订阅用的函数,调用后关闭对应的事件通道。 +type CancelFunc func() + +// ReportFormat 报告文件格式。 +type ReportFormat string + +const ( + ReportFormatJSON ReportFormat = "json" + ReportFormatCSV ReportFormat = "csv" +) + +// TaskConfig 新建/更新任务时提交的可变配置。 +// ID、时间戳等元数据由 Server 自动管理。 +type TaskConfig struct { + Name string + Input types.Input +} + +// RunStatus 运行的生命周期状态。 +type RunStatus string + +const ( + RunStatusRunning RunStatus = "running" + RunStatusCompleted RunStatus = "completed" + RunStatusFailed RunStatus = "failed" + RunStatusStopped RunStatus = "stopped" +) + +// RunState 一次运行的完整快照,由 GetRunState 返回。 +// 字段为只读快照,不持有锁,TUI 层可安全读取。 +type RunState struct { + RunID RunID + TaskID string + Status RunStatus + Mode string // "standard" | "turbo" + StartedAt time.Time + FinishedAt *time.Time + + // 进度计数 + TotalReqs int + DoneReqs int + SuccessReqs int + FailedReqs int + + // 聚合指标(实时更新) + AvgTPS float64 + AvgTTFT time.Duration + SuccessRate float64 + CacheHitRate float64 + + // 详细请求列表(按 index 排序) + Requests []*RequestMetrics + + // Turbo 专用 + Levels []types.TurboLevelResult + CurrentLevel int + + // 最终结果(运行结束后填充) + StandardResult *types.ReportData + TurboResult *types.TurboResult + + ErrorMsg string +} + +// RequestMetrics 单次请求的详细指标,供请求列表页展示。 +type RequestMetrics struct { + Index int + Success bool + TotalTime time.Duration + TTFT time.Duration + TPS float64 + PromptTokens int + CompletionTokens int + CachedTokens int + CacheHitRate float64 + DNSTime time.Duration + ConnectTime time.Duration + TLSTime time.Duration + TargetIP string + ErrorMessage string + // 以下字段当前为空,待 client.ResponseMetrics 支持后填充 + PromptText string + ResponseText string +} + +// EventKind 事件类型枚举。 +type EventKind string + +const ( + // EventRequestDone 单个请求完成(含成功/失败)。 + EventRequestDone EventKind = "request_done" + // EventProgressTick 定时聚合快照(约 500ms 发一次)。 + EventProgressTick EventKind = "progress_tick" + // EventLevelDone Turbo 模式下一个并发级别探测完成。 + EventLevelDone EventKind = "level_done" + // EventRunComplete 运行正常结束。 + EventRunComplete EventKind = "run_complete" + // EventRunFailed 运行异常中止。 + EventRunFailed EventKind = "run_failed" +) + +// Event 是推送给 TUI 层的通知。Payload 类型随 Kind 不同: +// - EventRequestDone → *RequestMetrics +// - EventProgressTick → *RunState(快照) +// - EventLevelDone → types.TurboLevelResult +// - EventRunComplete → *RunState(最终快照) +// - EventRunFailed → error +type Event struct { + RunID RunID + Kind EventKind + Payload any +} From cf1bece32179844a86caa35a5770c9283332409d Mon Sep 17 00:00:00 2001 From: Alain Date: Sun, 17 May 2026 00:44:02 +0800 Subject: [PATCH 05/52] feat: add task detail and list pages with wizard for task creation and editing - Implemented task detail page to display task configurations and run history. - Created task list page to manage tasks with options for creating, editing, copying, and deleting tasks. - Developed a wizard interface for creating and editing tasks, including input validation and dynamic field handling. - Enhanced user experience with contextual navigation and feedback on task operations. --- cmd/ait/ait.go | 509 ++------ cmd/ait/ait_test.go | 1802 ++------------------------- internal/tui/client.go | 147 +++ internal/tui/contextbar.go | 95 ++ internal/tui/messages.go | 57 +- internal/tui/model.go | 2004 ++++--------------------------- internal/tui/model_test.go | 179 ++- internal/tui/page_dashboard.go | 317 +++++ internal/tui/page_reqdetail.go | 182 +++ internal/tui/page_taskdetail.go | 196 +++ internal/tui/page_tasklist.go | 412 +++++++ internal/tui/page_wizard.go | 569 +++++++++ 12 files changed, 2485 insertions(+), 3984 deletions(-) create mode 100644 internal/tui/client.go create mode 100644 internal/tui/contextbar.go create mode 100644 internal/tui/page_dashboard.go create mode 100644 internal/tui/page_reqdetail.go create mode 100644 internal/tui/page_taskdetail.go create mode 100644 internal/tui/page_tasklist.go create mode 100644 internal/tui/page_wizard.go diff --git a/cmd/ait/ait.go b/cmd/ait/ait.go index c0acabc..2d57758 100644 --- a/cmd/ait/ait.go +++ b/cmd/ait/ait.go @@ -1,357 +1,43 @@ package main import ( - "crypto/rand" "flag" "fmt" - "io" "os" "strings" "time" - "github.com/yinxulai/ait/internal/display" - "github.com/yinxulai/ait/internal/prompt" - "github.com/yinxulai/ait/internal/report" - "github.com/yinxulai/ait/internal/runner" + "github.com/yinxulai/ait/internal/server" "github.com/yinxulai/ait/internal/tui" "github.com/yinxulai/ait/internal/types" ) -// 版本信息,通过 ldflags 在构建时注入 +// 版本信息,通过 ldflags 在构建时注入。 var ( Version = "dev" GitCommit = "unknown" BuildTime = "unknown" ) -func generateTaskID() string { - bytes := make([]byte, 16) - rand.Read(bytes) - - // 设置版本 (4) 和变体位 - bytes[6] = (bytes[6] & 0x0f) | 0x40 // Version 4 - bytes[8] = (bytes[8] & 0x3f) | 0x80 // Variant 10 - - return fmt.Sprintf("%08x-%04x-%04x-%04x-%012x", - bytes[0:4], bytes[4:6], bytes[6:8], bytes[8:10], bytes[10:16]) -} - -// readPromptFromStdin 从标准输入读取 prompt 内容 -func readPromptFromStdin() (string, error) { - // 检查是否有标准输入数据 - stat, err := os.Stdin.Stat() - if err != nil { - return "", err - } - - // 如果没有管道输入,返回空字符串 - if stat.Mode()&os.ModeCharDevice != 0 { - return "", nil - } - - // 读取标准输入的所有内容 - content, err := io.ReadAll(os.Stdin) - if err != nil { - return "", err - } - - return strings.TrimSpace(string(content)), nil -} - -// resolvePrompt 解析最终的 prompt 内容 -// 优先级:1. prompt-length 参数 > 2. prompt-file 参数 > 3. prompt 参数 > 4. 管道输入 > 5. 默认值 -func resolvePrompt(promptLengthSpecified bool, promptLength int, promptSpecified bool, flagPrompt string, promptFileSpecified bool, flagPromptFile string) (*prompt.PromptSource, error) { - // 1. 如果用户指定了 --prompt-length 参数,优先使用长度生成 - if promptLengthSpecified && promptLength > 0 { - return prompt.LoadPromptByLength(promptLength) - } - - // 2. 如果用户指定了 --prompt-file 参数,使用文件 - if promptFileSpecified { - return prompt.LoadPromptsFromFile(flagPromptFile) - } - - // 3. 如果用户明确指定了 --prompt 参数,则使用它 - if promptSpecified { - return prompt.LoadPrompts(flagPrompt) - } - - // 4. 检查是否有管道输入 - stdinPrompt, err := readPromptFromStdin() - if err == nil && stdinPrompt != "" { - return prompt.LoadPrompts(stdinPrompt) - } - - // 5. 使用默认值 - return prompt.LoadPrompts(flagPrompt) -} - -// detectProviderFromEnv 根据环境变量自动检测 provider -func detectProviderFromEnv() string { - // 优先检查 OpenAI 环境变量 - if os.Getenv("OPENAI_API_KEY") != "" || os.Getenv("OPENAI_BASE_URL") != "" { - return "openai" - } - // 其次检查 Anthropic 环境变量 - if os.Getenv("ANTHROPIC_API_KEY") != "" || os.Getenv("ANTHROPIC_BASE_URL") != "" { - return "anthropic" - } - // 默认返回 openai - return "openai" -} - -// loadEnvForProvider 根据 provider 加载对应的环境变量 -func loadEnvForProvider(provider string) (baseUrl, apiKey string) { - switch provider { - case "openai": - return os.Getenv("OPENAI_BASE_URL"), os.Getenv("OPENAI_API_KEY") - case "anthropic": - return os.Getenv("ANTHROPIC_BASE_URL"), os.Getenv("ANTHROPIC_API_KEY") - default: - return "", "" - } -} - -// validateRequiredParams 验证必需的参数 -func validateRequiredParams(models, baseUrl, apiKey, protocol string) error { - if models == "" { - return fmt.Errorf("models 参数必填,请通过 -models 参数指定") - } - - if baseUrl == "" || apiKey == "" { - return fmt.Errorf("baseUrl 和 apikey 参数必填,对于 %s 协议,你也可以设置相应的环境变量", protocol) - } - - return nil -} - -// parseModelList 解析模型列表字符串 -func parseModelList(models string) []string { - if models == "" { - return nil - } - - modelList := strings.Split(models, ",") - for i, m := range modelList { - modelList[i] = strings.TrimSpace(m) - } - return modelList -} - -// resolveConfigValues 解析并合并配置值 -func resolveConfigValues(protocol, baseUrl, apiKey string) (string, string, string) { - finalProtocol := protocol - finalBaseUrl := baseUrl - finalApiKey := apiKey - - // 如果未指定 protocol,根据环境变量自动推断 - if finalProtocol == "" { - finalProtocol = detectProviderFromEnv() - } - - // 根据 protocol 加载对应的环境变量 - if finalBaseUrl == "" || finalApiKey == "" { - envBaseUrl, envApiKey := loadEnvForProvider(finalProtocol) - if finalBaseUrl == "" { - finalBaseUrl = envBaseUrl - } - if finalApiKey == "" { - finalApiKey = envApiKey - } - } - - return finalProtocol, finalBaseUrl, finalApiKey -} - -// printErrorMessages 打印错误消息并提供环境变量设置建议 -func printErrorMessages(protocol string) { - fmt.Println("baseUrl 和 apikey 参数必填") - fmt.Printf("对于 %s 协议,你也可以设置以下环境变量:\n", protocol) - - switch protocol { - case "openai": - fmt.Println(" OPENAI_BASE_URL - OpenAI API 基础 URL") - fmt.Println(" OPENAI_API_KEY - OpenAI API 密钥") - case "anthropic": - fmt.Println(" ANTHROPIC_BASE_URL - Anthropic API 基础 URL") - fmt.Println(" ANTHROPIC_API_KEY - Anthropic API 密钥") - } -} - -// createRunnerConfig 创建runner配置 -func createRunnerConfig(protocol, baseUrl, apiKey, model string, promptSource *prompt.PromptSource, concurrency, count, timeout int, stream, report, log, thinking bool) types.Input { - return types.Input{ - Protocol: protocol, - BaseUrl: baseUrl, - ApiKey: apiKey, - Model: model, - Concurrency: concurrency, - Count: count, - PromptSource: promptSource, - Stream: stream, - Report: report, - Timeout: time.Duration(timeout) * time.Second, - Log: log, - Thinking: thinking, - } -} - -// processModelExecution 处理单个模型的执行逻辑 -func processModelExecution(taskID string, modelName string, config types.Input, displayer *display.Displayer, completedRequests, totalRequests int) (*types.ReportData, []string, error) { - runnerInstance, err := runner.NewRunner(taskID, config) - if err != nil { - return nil, nil, fmt.Errorf("创建测试执行器失败: %v", err) - } - - // 用于收集当前模型的错误信息 - var currentModelErrors []string - - // 执行测试,使用回调函数来更新显示 - result, err := runnerInstance.RunWithProgress(func(sd types.StatsData) { - // 计算当前总完成数:之前模型的完成数 + 当前模型的完成数 - currentCompleted := completedRequests + sd.CompletedCount + sd.FailedCount - - // 计算百分比 - percent := float64(currentCompleted) / float64(totalRequests) * 100.0 - - // 类型断言来调用UpdateProgress方法 - displayer.UpdateProgress(percent) - - // 保存最新的错误信息(覆盖之前的,确保获取最完整的错误列表) - currentModelErrors = make([]string, len(sd.ErrorMessages)) - copy(currentModelErrors, sd.ErrorMessages) - }) - if err != nil { - return nil, nil, err - } - - return result, currentModelErrors, nil -} - -// collectErrorsWithContext 收集带有模型上下文的错误信息 -func collectErrorsWithContext(modelName string, modelErrors []string) []string { - var errors []string - for _, errorMsg := range modelErrors { - if errorMsg != "" { - // 为错误信息添加模型上下文 - errorWithContext := fmt.Sprintf("[%s] %s", modelName, errorMsg) - errors = append(errors, errorWithContext) - } - } - return errors -} - -// fillResultMetadata 填充结果元数据 -func fillResultMetadata(results []*types.ReportData, modelList []string, baseUrl, protocol string) { - for i, result := range results { - result.Model = modelList[i] - result.BaseUrl = baseUrl - result.Protocol = protocol - result.Timestamp = time.Now().Format(time.RFC3339) - } -} - -func convertErrorsToPointers(errors []string) []*string { - errorPtrs := make([]*string, len(errors)) - for i := range errors { - errorPtrs[i] = &errors[i] - } - return errorPtrs -} - -// generateReportsIfEnabled 如果启用了报告功能,则生成报告 -func generateReportsIfEnabled(reportFlag bool, results []*types.ReportData) error { - if !reportFlag || len(results) == 0 { - return nil - } - - // 转换为 ReportData 切片 - reportDataList := make([]types.ReportData, len(results)) - for i, result := range results { - reportDataList[i] = *result - } - - // 使用 ReportManager 生成汇总报告 - manager := report.NewReportManager() - filePaths, err := manager.GenerateReports(reportDataList, []string{"json", "csv"}) - if err != nil { - return fmt.Errorf("生成汇总报告失败: %v", err) - } - - fmt.Printf("\n汇总报告已生成:\n") - for _, filePath := range filePaths { - fmt.Printf(" - %s\n", filePath) - } - return nil -} - -// executeModelsTestSuite 执行多个模型的测试套件 -func executeModelsTestSuite(taskID string, modelList []string, finalProtocol, finalBaseUrl, finalApiKey string, promptSource *prompt.PromptSource, concurrency, count, timeout int, stream, reportFlag, log, thinking bool, displayer *display.Displayer) ([]*types.ReportData, []string, error) { - // 用于收集所有错误信息 - var allErrors []string - - // 用于汇总所有模型的测试结果 - var allResults []*types.ReportData - - // 循环处理每个模型 - totalRequests := count * len(modelList) - - // 初始化总进度条 - displayer.InitProgress(totalRequests, fmt.Sprintf("🚀 测试进度 (%d 个模型)", len(modelList))) - - completedRequests := 0 - - for _, modelName := range modelList { - config := createRunnerConfig(finalProtocol, finalBaseUrl, finalApiKey, modelName, promptSource, concurrency, count, timeout, stream, reportFlag, log, thinking) - - result, currentModelErrors, err := processModelExecution(taskID, modelName, config, displayer, completedRequests, totalRequests) - if err != nil { - fmt.Printf("模型 %s 执行失败: %v\n", modelName, err) - continue - } - - // 处理当前模型的错误信息 - modelErrors := collectErrorsWithContext(modelName, currentModelErrors) - allErrors = append(allErrors, modelErrors...) - - // 更新已完成的请求数(当前模型的所有请求都已完成) - completedRequests += config.Count - - // 保存结果用于汇总 - allResults = append(allResults, result) - } - - // 完成进度条 - displayer.FinishProgress() - - // 为所有结果填充模型名称元数据 - fillResultMetadata(allResults, modelList, finalBaseUrl, finalProtocol) - - return allResults, allErrors, nil -} - func main() { - taskID := generateTaskID() + // ── flags ──────────────────────────────────────────────────────────────── versionFlag := flag.Bool("version", false, "显示版本信息") - interactiveFlag := flag.Bool("interactive", false, "启动交互式 TUI") - baseUrl := flag.String("baseUrl", "", "服务地址") - apiKey := flag.String("apiKey", "", "API 密钥") - count := flag.Int("count", 10, "请求总数") - model := flag.String("model", "", "模型名称(单个模型)") - models := flag.String("models", "", "模型名称,支持多个模型用,(逗号)分割") - protocol := flag.String("protocol", "", "协议类型: openai 或 anthropic") - prompt := flag.String("prompt", "你好,介绍一下你自己。", "测试用 prompt 内容。未指定时支持管道输入") - promptFile := flag.String("prompt-file", "", "从文件读取 prompt。支持单文件路径或通配符 (如: prompts/*.txt)") - promptLength := flag.Int("prompt-length", 0, "生成指定长度的测试 prompt(字符数)。优先级高于其他 prompt 参数") - stream := flag.Bool("stream", true, "是否开启流模式") - concurrency := flag.Int("concurrency", 3, "并发数") - reportFlag := flag.Bool("report", false, "是否生成报告文件") - timeout := flag.Int("timeout", 300, "请求超时时间(秒)") - logFlag := flag.Bool("log", false, "是否开启详细日志记录") - thinking := flag.Bool("thinking", false, "是否开启 thinking 模式") + baseURL := flag.String("baseUrl", "", "服务基础地址(可选,留空使用协议默认地址)") + apiKey := flag.String("apiKey", "", "API 密钥") + model := flag.String("model", "", "模型名称") + protocol := flag.String("protocol", "", "协议类型: openai / anthropic") + promptText := flag.String("prompt", "", "Prompt 文本(可选)") + promptFile := flag.String("prompt-file", "", "从文件读取 Prompt") + promptLen := flag.Int("prompt-length", 0, "生成指定长度的测试 Prompt(字符数)") + stream := flag.Bool("stream", true, "是否开启流式输出") + thinking := flag.Bool("thinking", false, "是否开启 Thinking 模式") + concurrency := flag.Int("concurrency", 10, "并发数") + count := flag.Int("count", 100, "请求总数") + timeout := flag.Int("timeout", 300, "请求超时时间(秒)") + turboFlag := flag.Bool("turbo", false, "是否启用 Turbo 并发探测模式") flag.Parse() - // 如果指定了 --version,显示版本信息后退出 + // ── 版本输出 ────────────────────────────────────────────────────────────── if *versionFlag { fmt.Printf("ait version %s\n", Version) fmt.Printf("Git Commit: %s\n", GitCommit) @@ -359,114 +45,91 @@ func main() { os.Exit(0) } - if *interactiveFlag { - if err := tui.Run(); err != nil { - fmt.Printf("启动交互式 TUI 失败: %v\n", err) - os.Exit(1) - } - return + // ── 创建 Server ─────────────────────────────────────────────────────────── + srv, err := server.New() + if err != nil { + fmt.Fprintf(os.Stderr, "初始化 Server 失败: %v\n", err) + os.Exit(1) } - // 合并 --model 和 --models 参数 - finalModels := *models + // ── 若提供了足够参数则预建任务并自动启动 ──────────────────────────────────── if *model != "" { - if finalModels != "" { - fmt.Println("错误:不能同时使用 --model 和 --models 参数") + finalProtocol, finalBaseURL, finalAPIKey := resolveConfig(*protocol, *baseURL, *apiKey) + if finalAPIKey == "" { + fmt.Fprintln(os.Stderr, "错误: 缺少 API Key(-apiKey 或环境变量)") os.Exit(1) } - finalModels = *model - } - // 解析和验证配置 - finalProtocol, finalBaseUrl, finalApiKey := resolveConfigValues(*protocol, *baseUrl, *apiKey) - - // 验证必需参数 - if err := validateRequiredParams(finalModels, finalBaseUrl, finalApiKey, finalProtocol); err != nil { - if finalModels == "" { - fmt.Println("model/models 参数必填,请通过 --model 或 --models 参数指定") - fmt.Println("--model: 指定单个模型,例如:--model gpt-3.5-turbo") - fmt.Println("--models: 支持多个模型,用逗号分割,例如:--models gpt-3.5-turbo,gpt-4") - } else { - printErrorMessages(finalProtocol) + inp := types.Input{ + Protocol: finalProtocol, + BaseUrl: finalBaseURL, + ApiKey: finalAPIKey, + Model: *model, + Stream: *stream, + Thinking: *thinking, + Concurrency: *concurrency, + Count: *count, + Turbo: *turboFlag, + Timeout: time.Duration(*timeout) * time.Second, } - os.Exit(1) - } - // 解析模型列表 - modelList := parseModelList(finalModels) - - // 检查用户是否明确指定了 --prompt、--prompt-file 和 --prompt-length 参数 - promptSpecified := false - promptFileSpecified := false - promptLengthSpecified := false - flag.Visit(func(f *flag.Flag) { - if f.Name == "prompt" { - promptSpecified = true - } - if f.Name == "prompt-file" { - promptFileSpecified = true + // Prompt 配置 + switch { + case *promptLen > 0: + inp.PromptMode = "generated" + inp.PromptLength = *promptLen + case *promptFile != "": + inp.PromptMode = "file" + inp.PromptFile = *promptFile + case *promptText != "": + inp.PromptMode = "text" + inp.PromptText = *promptText + default: + inp.PromptMode = "text" + inp.PromptText = "你好,介绍一下你自己。" } - if f.Name == "prompt-length" { - promptLengthSpecified = true - } - }) - // 解析最终的 prompt,优先级:prompt-length > prompt-file > prompt > 管道输入 > 默认值 - promptSource, err := resolvePrompt(promptLengthSpecified, *promptLength, promptSpecified, *prompt, promptFileSpecified, *promptFile) - if err != nil { - fmt.Printf("解析 prompt 失败: %v\n", err) - os.Exit(1) + taskName := fmt.Sprintf("%s@%s", *model, strings.TrimRight(finalBaseURL, "/")) + _, err := srv.CreateTask(server.TaskConfig{Name: taskName, Input: inp}) + if err != nil { + fmt.Fprintf(os.Stderr, "创建任务失败: %v\n", err) + os.Exit(1) + } } - displayer := display.New() - - // 显示欢迎信息 - displayer.ShowWelcome(Version) - - displayer.ShowInput(&display.Input{ - TaskId: taskID, - Protocol: finalProtocol, - BaseUrl: finalBaseUrl, - ApiKey: finalApiKey, - Models: modelList, - Concurrency: *concurrency, - Count: *count, - Stream: *stream, - Thinking: *thinking, - PromptText: promptSource.DisplayText, - PromptShouldTruncate: promptSource.ShouldTruncate, - IsFile: promptSource.IsFile, - Report: *reportFlag, - Timeout: *timeout, - }) - - // 执行多个模型的测试套件 - allResults, allErrors, err := executeModelsTestSuite( - taskID, modelList, finalProtocol, finalBaseUrl, finalApiKey, promptSource, - *concurrency, *count, *timeout, *stream, *reportFlag, *logFlag, *thinking, displayer, - ) - if err != nil { - fmt.Printf("执行测试套件失败: %v\n", err) + // ── 启动 TUI ────────────────────────────────────────────────────────────── + if err := tui.Run(srv); err != nil { + fmt.Fprintf(os.Stderr, "TUI 启动失败: %v\n", err) os.Exit(1) } +} - // 显示错误报告(如果有错误的话) - if len(allErrors) > 0 { - errorPtrs := convertErrorsToPointers(allErrors) - displayer.ShowErrorsReport(errorPtrs) - } - - // 根据模型数量显示相应的报告 - if len(modelList) == 1 { - displayer.ShowSignalReport(allResults[0]) +// resolveConfig 合并命令行参数与环境变量。 +func resolveConfig(protocol, baseURL, apiKey string) (string, string, string) { + if protocol == "" { + if os.Getenv("OPENAI_API_KEY") != "" || os.Getenv("OPENAI_BASE_URL") != "" { + protocol = "openai" + } else if os.Getenv("ANTHROPIC_API_KEY") != "" || os.Getenv("ANTHROPIC_BASE_URL") != "" { + protocol = "anthropic" + } else { + protocol = "openai" + } } - - if len(modelList) > 1 { - displayer.ShowMultiReport(allResults) + if baseURL == "" { + switch protocol { + case "anthropic": + baseURL = os.Getenv("ANTHROPIC_BASE_URL") + default: + baseURL = os.Getenv("OPENAI_BASE_URL") + } } - - // 生成报告文件(如果启用) - if err := generateReportsIfEnabled(*reportFlag, allResults); err != nil { - fmt.Printf("报告生成失败: %v\n", err) + if apiKey == "" { + switch protocol { + case "anthropic": + apiKey = os.Getenv("ANTHROPIC_API_KEY") + default: + apiKey = os.Getenv("OPENAI_API_KEY") + } } + return protocol, baseURL, apiKey } diff --git a/cmd/ait/ait_test.go b/cmd/ait/ait_test.go index 6cfaf41..22c92b0 100644 --- a/cmd/ait/ait_test.go +++ b/cmd/ait/ait_test.go @@ -1,370 +1,149 @@ package main import ( - "flag" - "fmt" "os" "strings" "testing" - "time" - - "github.com/yinxulai/ait/internal/display" - "github.com/yinxulai/ait/internal/prompt" - "github.com/yinxulai/ait/internal/types" ) -// createTestPromptSource 创建测试用的 PromptSource -func createTestPromptSource(promptText string) *prompt.PromptSource { - source, _ := prompt.LoadPrompts(promptText) - return source -} - -// MockRunner 模拟 runner 以便测试 main 函数逻辑 -type MockRunner struct { - result *types.ReportData - input types.Input - err error -} - -func NewMockRunner(config types.Input) (*MockRunner, error) { - if config.Model == "invalid-model" { - return nil, fmt.Errorf("invalid model: %s", config.Model) - } - - return &MockRunner{ - input: config, - result: &types.ReportData{ - TotalRequests: config.Count, - Concurrency: config.Concurrency, - IsStream: config.Stream, - IsThinking: config.Thinking, - TotalTime: 1500 * time.Millisecond, - Timestamp: time.Now().Format(time.RFC3339), - Protocol: config.Protocol, - Model: config.Model, - BaseUrl: config.BaseUrl, - AvgTotalTime: 150 * time.Millisecond, - MinTotalTime: 100 * time.Millisecond, - MaxTotalTime: 200 * time.Millisecond, - AvgTTFT: 50 * time.Millisecond, - MinTTFT: 30 * time.Millisecond, - MaxTTFT: 70 * time.Millisecond, - AvgTPOT: 25 * time.Millisecond, - MinTPOT: 20 * time.Millisecond, - MaxTPOT: 30 * time.Millisecond, - AvgInputTokenCount: 50, - MinInputTokenCount: 40, - MaxInputTokenCount: 60, - AvgOutputTokenCount: 100, - MinOutputTokenCount: 80, - MaxOutputTokenCount: 120, - AvgThinkingTokenCount: 40, - MinThinkingTokenCount: 30, - MaxThinkingTokenCount: 50, - AvgTPS: 200.0, - MinTPS: 150.0, - MaxTPS: 250.0, - ErrorRate: 0.0, - SuccessRate: 100.0, - }, - }, nil -} - -func (m *MockRunner) RunWithProgress(callback func(types.StatsData)) (*types.ReportData, error) { - if m.err != nil { - return nil, m.err - } - - // 模拟进度回调 - for i := 0; i <= m.input.Count; i++ { - callback(types.StatsData{ - CompletedCount: i, - FailedCount: 0, - ErrorMessages: []string{}, - }) +// ─── resolveConfig ──────────────────────────────────────────────────────────── + +// clearEnv 清除所有 provider 环境变量,返回 restore 函数。 +func clearEnv(t *testing.T) func() { + t.Helper() + saved := map[string]string{} + keys := []string{"OPENAI_API_KEY", "OPENAI_BASE_URL", "ANTHROPIC_API_KEY", "ANTHROPIC_BASE_URL"} + for _, k := range keys { + saved[k] = os.Getenv(k) + os.Unsetenv(k) + } + return func() { + for k, v := range saved { + if v == "" { + os.Unsetenv(k) + } else { + os.Setenv(k, v) + } + } } - - return m.result, nil -} - -func (m *MockRunner) Run() (*types.ReportData, error) { - return m.result, m.err -} - -// MockDisplay 模拟 display 组件以便测试 -type MockDisplay struct{} - -func (md *MockDisplay) Init(total int) error { - return nil -} - -func (md *MockDisplay) Update(current int) error { - return nil } -func (md *MockDisplay) Finish() error { - return nil -} - -func (md *MockDisplay) ShowResults(data interface{}) error { - return nil -} - -func TestDetectProtocolFromEnv(t *testing.T) { - // 保存原始环境变量 - originalOpenAIKey := os.Getenv("OPENAI_API_KEY") - originalOpenAIURL := os.Getenv("OPENAI_BASE_URL") - originalAnthropicKey := os.Getenv("ANTHROPIC_API_KEY") - originalAnthropicURL := os.Getenv("ANTHROPIC_BASE_URL") - - // 确保测试后恢复原始环境变量 - defer func() { - os.Setenv("OPENAI_API_KEY", originalOpenAIKey) - os.Setenv("OPENAI_BASE_URL", originalOpenAIURL) - os.Setenv("ANTHROPIC_API_KEY", originalAnthropicKey) - os.Setenv("ANTHROPIC_BASE_URL", originalAnthropicURL) - }() - +func TestResolveConfig_ProtocolInference(t *testing.T) { tests := []struct { - name string - openaiKey string - openaiURL string - anthropicKey string - anthropicURL string - expectedProtocol string + name string + envVars map[string]string + wantProt string }{ { - name: "OpenAI API key set", - openaiKey: "test-openai-key", - openaiURL: "", - anthropicKey: "", - anthropicURL: "", - expectedProtocol: "openai", + name: "OpenAI API key → openai", + envVars: map[string]string{"OPENAI_API_KEY": "sk-test"}, + wantProt: "openai", }, { - name: "OpenAI base URL set", - openaiKey: "", - openaiURL: "https://api.openai.com", - anthropicKey: "", - anthropicURL: "", - expectedProtocol: "openai", + name: "OpenAI base URL → openai", + envVars: map[string]string{"OPENAI_BASE_URL": "https://api.openai.com"}, + wantProt: "openai", }, { - name: "Anthropic API key set", - openaiKey: "", - openaiURL: "", - anthropicKey: "test-anthropic-key", - anthropicURL: "", - expectedProtocol: "anthropic", + name: "Anthropic key → anthropic", + envVars: map[string]string{"ANTHROPIC_API_KEY": "sk-ant"}, + wantProt: "anthropic", }, { - name: "Anthropic base URL set", - openaiKey: "", - openaiURL: "", - anthropicKey: "", - anthropicURL: "https://api.anthropic.com", - expectedProtocol: "anthropic", + name: "Anthropic URL → anthropic", + envVars: map[string]string{"ANTHROPIC_BASE_URL": "https://api.anthropic.com"}, + wantProt: "anthropic", }, { - name: "Both providers set - OpenAI takes priority", - openaiKey: "test-openai-key", - openaiURL: "", - anthropicKey: "test-anthropic-key", - anthropicURL: "", - expectedProtocol: "openai", + name: "Both set → openai wins", + envVars: map[string]string{"OPENAI_API_KEY": "sk-test", "ANTHROPIC_API_KEY": "sk-ant"}, + wantProt: "openai", }, { - name: "No environment variables set - defaults to openai", - openaiKey: "", - openaiURL: "", - anthropicKey: "", - anthropicURL: "", - expectedProtocol: "openai", + name: "No env vars → default openai", + envVars: map[string]string{}, + wantProt: "openai", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - // 清理所有相关环境变量 - os.Unsetenv("OPENAI_API_KEY") - os.Unsetenv("OPENAI_BASE_URL") - os.Unsetenv("ANTHROPIC_API_KEY") - os.Unsetenv("ANTHROPIC_BASE_URL") - - // 设置测试环境变量 - if tt.openaiKey != "" { - os.Setenv("OPENAI_API_KEY", tt.openaiKey) + restore := clearEnv(t) + defer restore() + for k, v := range tt.envVars { + os.Setenv(k, v) } - if tt.openaiURL != "" { - os.Setenv("OPENAI_BASE_URL", tt.openaiURL) - } - if tt.anthropicKey != "" { - os.Setenv("ANTHROPIC_API_KEY", tt.anthropicKey) - } - if tt.anthropicURL != "" { - os.Setenv("ANTHROPIC_BASE_URL", tt.anthropicURL) - } - - got := detectProviderFromEnv() - if got != tt.expectedProtocol { - t.Errorf("detectProviderFromEnv() = %v, want %v", got, tt.expectedProtocol) + prot, _, _ := resolveConfig("", "", "") + if prot != tt.wantProt { + t.Errorf("protocol = %q, want %q", prot, tt.wantProt) } }) } } -func TestLoadEnvForProtocol(t *testing.T) { - // 保存原始环境变量 - originalOpenAIKey := os.Getenv("OPENAI_API_KEY") - originalOpenAIURL := os.Getenv("OPENAI_BASE_URL") - originalAnthropicKey := os.Getenv("ANTHROPIC_API_KEY") - originalAnthropicURL := os.Getenv("ANTHROPIC_BASE_URL") - - // 确保测试后恢复原始环境变量 - defer func() { - os.Setenv("OPENAI_API_KEY", originalOpenAIKey) - os.Setenv("OPENAI_BASE_URL", originalOpenAIURL) - os.Setenv("ANTHROPIC_API_KEY", originalAnthropicKey) - os.Setenv("ANTHROPIC_BASE_URL", originalAnthropicURL) - }() - +func TestResolveConfig_KeyAndURL(t *testing.T) { tests := []struct { - name string - protocol string - envVars map[string]string - expectedURL string - expectedKey string + name string + protocol string + envVars map[string]string + wantURL string + wantKey string }{ { - name: "OpenAI protocol with environment variables", + name: "openai env vars resolved", protocol: "openai", - envVars: map[string]string{ - "OPENAI_BASE_URL": "https://api.openai.com", - "OPENAI_API_KEY": "test-openai-key", - }, - expectedURL: "https://api.openai.com", - expectedKey: "test-openai-key", + envVars: map[string]string{"OPENAI_BASE_URL": "https://api.openai.com", "OPENAI_API_KEY": "sk-openai"}, + wantURL: "https://api.openai.com", + wantKey: "sk-openai", }, { - name: "Anthropic protocol with environment variables", + name: "anthropic env vars resolved", protocol: "anthropic", - envVars: map[string]string{ - "ANTHROPIC_BASE_URL": "https://api.anthropic.com", - "ANTHROPIC_API_KEY": "test-anthropic-key", - }, - expectedURL: "https://api.anthropic.com", - expectedKey: "test-anthropic-key", - }, - { - name: "OpenAI protocol without environment variables", - protocol: "openai", - envVars: map[string]string{}, - expectedURL: "", - expectedKey: "", + envVars: map[string]string{"ANTHROPIC_BASE_URL": "https://api.anthropic.com", "ANTHROPIC_API_KEY": "sk-ant"}, + wantURL: "https://api.anthropic.com", + wantKey: "sk-ant", }, { - name: "Anthropic protocol without environment variables", - protocol: "anthropic", - envVars: map[string]string{}, - expectedURL: "", - expectedKey: "", - }, - { - name: "Unknown protocol", - protocol: "unknown", - envVars: map[string]string{}, - expectedURL: "", - expectedKey: "", - }, - { - name: "Only OpenAI URL set", + name: "explicit args override env", protocol: "openai", - envVars: map[string]string{ - "OPENAI_BASE_URL": "https://custom.openai.com", - }, - expectedURL: "https://custom.openai.com", - expectedKey: "", + envVars: map[string]string{"OPENAI_BASE_URL": "https://env.url", "OPENAI_API_KEY": "env-key"}, + wantURL: "https://explicit.url", + wantKey: "explicit-key", }, { - name: "Only Anthropic key set", - protocol: "anthropic", - envVars: map[string]string{ - "ANTHROPIC_API_KEY": "test-key-only", - }, - expectedURL: "", - expectedKey: "test-key-only", + name: "unknown protocol - no env", + protocol: "other", + envVars: map[string]string{}, + wantURL: "", + wantKey: "", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - // 清除所有相关环境变量 - os.Unsetenv("OPENAI_API_KEY") - os.Unsetenv("OPENAI_BASE_URL") - os.Unsetenv("ANTHROPIC_API_KEY") - os.Unsetenv("ANTHROPIC_BASE_URL") - - // 设置测试环境变量 - for key, value := range tt.envVars { - os.Setenv(key, value) + restore := clearEnv(t) + defer restore() + for k, v := range tt.envVars { + os.Setenv(k, v) } - - baseUrl, apiKey := loadEnvForProvider(tt.protocol) - if baseUrl != tt.expectedURL { - t.Errorf("loadEnvForProtocol(%v) baseUrl = %v, want %v", tt.protocol, baseUrl, tt.expectedURL) + var prot, url, key string + if tt.name == "explicit args override env" { + prot, url, key = resolveConfig(tt.protocol, "https://explicit.url", "explicit-key") + } else { + prot, url, key = resolveConfig(tt.protocol, "", "") } - if apiKey != tt.expectedKey { - t.Errorf("loadEnvForProtocol(%v) apiKey = %v, want %v", tt.protocol, apiKey, tt.expectedKey) + _ = prot + if url != tt.wantURL { + t.Errorf("url = %q, want %q", url, tt.wantURL) + } + if key != tt.wantKey { + t.Errorf("key = %q, want %q", key, tt.wantKey) } }) } } -func TestFlagDefinitions(t *testing.T) { - // 重置 flag 状态,避免冲突 - flag.CommandLine = flag.NewFlagSet("test", flag.ContinueOnError) - - // 模拟定义 flags(这部分通常在 main 中) - baseUrl := flag.String("baseUrl", "", "服务地址") - apikey := flag.String("apikey", "", "API 密钥") - model := flag.String("model", "", "模型名称") - provider := flag.String("protocol", "", "协议类型: openai 或 anthropic") - concurrency := flag.Int("concurrency", 3, "并发数") - count := flag.Int("count", 10, "请求总数") - prompt := flag.String("prompt", "你好,介绍一下你自己。", "测试用 prompt") - stream := flag.Bool("stream", true, "是否开启流模式") - reportFlag := flag.Bool("report", false, "是否生成报告文件") - - // 测试默认值 - if *provider != "" { - t.Errorf("Expected default protocol '', got '%s'", *provider) - } - - if *concurrency != 3 { - t.Errorf("Expected default concurrency 3, got %d", *concurrency) - } - - if *count != 10 { - t.Errorf("Expected default count 10, got %d", *count) - } - - if *stream != true { - t.Errorf("Expected default stream true, got %t", *stream) - } - - if *reportFlag != false { - t.Errorf("Expected default report false, got %t", *reportFlag) - } - - if *prompt != "你好,介绍一下你自己。" { - t.Errorf("Expected default prompt '你好,介绍一下你自己。', got '%s'", *prompt) - } - - // 测试 flag 是否正确定义 - if baseUrl == nil || apikey == nil || model == nil || prompt == nil { - t.Error("Required flags should be defined") - } -} +// ─── ParseModels (inline logic, no helper needed) ───────────────────────────── func TestParseModels(t *testing.T) { tests := []struct { @@ -372,1410 +151,27 @@ func TestParseModels(t *testing.T) { input string expected []string }{ - { - name: "Single model", - input: "gpt-3.5-turbo", - expected: []string{"gpt-3.5-turbo"}, - }, - { - name: "Multiple models", - input: "gpt-3.5-turbo,gpt-4,claude-3", - expected: []string{"gpt-3.5-turbo", "gpt-4", "claude-3"}, - }, - { - name: "Models with spaces", - input: "gpt-3.5-turbo, gpt-4 , claude-3", - expected: []string{"gpt-3.5-turbo", "gpt-4", "claude-3"}, - }, - { - name: "Empty model in list", - input: "gpt-3.5-turbo,,gpt-4", - expected: []string{"gpt-3.5-turbo", "", "gpt-4"}, - }, - { - name: "Single model with spaces", - input: " gpt-3.5-turbo ", - expected: []string{"gpt-3.5-turbo"}, - }, + {"single model", "gpt-3.5-turbo", []string{"gpt-3.5-turbo"}}, + {"multiple models", "gpt-3.5-turbo,gpt-4,claude-3", []string{"gpt-3.5-turbo", "gpt-4", "claude-3"}}, + {"models with spaces", "gpt-3.5-turbo, gpt-4 , claude-3", []string{"gpt-3.5-turbo", "gpt-4", "claude-3"}}, + {"single model with spaces", " gpt-3.5-turbo ", []string{"gpt-3.5-turbo"}}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - // 模拟 main 函数中的模型解析逻辑 - modelList := strings.Split(tt.input, ",") - for i, m := range modelList { - modelList[i] = strings.TrimSpace(m) + parts := strings.Split(tt.input, ",") + for i, p := range parts { + parts[i] = strings.TrimSpace(p) } - - if len(modelList) != len(tt.expected) { - t.Errorf("Expected %d models, got %d", len(tt.expected), len(modelList)) + if len(parts) != len(tt.expected) { + t.Errorf("got %d models, want %d", len(parts), len(tt.expected)) return } - - for i, expected := range tt.expected { - if modelList[i] != expected { - t.Errorf("Model[%d]: expected '%s', got '%s'", i, expected, modelList[i]) - } - } - }) - } -} - -func TestInputConfig(t *testing.T) { - tests := []struct { - name string - protocol string - baseUrl string - apiKey string - model string - input types.Input - }{ - { - name: "OpenAI configuration", - protocol: "openai", - baseUrl: "https://api.openai.com", - apiKey: "test-openai-key", - model: "gpt-3.5-turbo", - input: types.Input{ - Protocol: "openai", - BaseUrl: "https://api.openai.com", - ApiKey: "test-openai-key", - Model: "gpt-3.5-turbo", - Concurrency: 3, - Count: 10, - PromptSource: createTestPromptSource("你好,介绍一下你自己。"), - Stream: true, - Report: false, - Timeout: 30 * time.Second, - }, - }, - { - name: "Anthropic configuration", - protocol: "anthropic", - baseUrl: "https://api.anthropic.com", - apiKey: "test-anthropic-key", - model: "claude-3", - input: types.Input{ - Protocol: "anthropic", - BaseUrl: "https://api.anthropic.com", - ApiKey: "test-anthropic-key", - Model: "claude-3", - Concurrency: 5, - Count: 20, - PromptSource: createTestPromptSource("Test prompt"), - Stream: false, - Report: true, - Timeout: 60 * time.Second, - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // 模拟创建 Input 配置的过程 - input := types.Input{ - Protocol: tt.protocol, - BaseUrl: tt.baseUrl, - ApiKey: tt.apiKey, - Model: tt.model, - Concurrency: tt.input.Concurrency, - Count: tt.input.Count, - PromptSource: tt.input.PromptSource, - Stream: tt.input.Stream, - Report: tt.input.Report, - Timeout: tt.input.Timeout, - } - - // 验证配置字段 - if input.Protocol != tt.input.Protocol { - t.Errorf("Protocol: expected %s, got %s", tt.input.Protocol, input.Protocol) - } - if input.BaseUrl != tt.input.BaseUrl { - t.Errorf("BaseUrl: expected %s, got %s", tt.input.BaseUrl, input.BaseUrl) - } - if input.ApiKey != tt.input.ApiKey { - t.Errorf("ApiKey: expected %s, got %s", tt.input.ApiKey, input.ApiKey) - } - if input.Model != tt.input.Model { - t.Errorf("Model: expected %s, got %s", tt.input.Model, input.Model) - } - if input.Concurrency != tt.input.Concurrency { - t.Errorf("Concurrency: expected %d, got %d", tt.input.Concurrency, input.Concurrency) - } - if input.Count != tt.input.Count { - t.Errorf("Count: expected %d, got %d", tt.input.Count, input.Count) - } - if input.Stream != tt.input.Stream { - t.Errorf("Stream: expected %t, got %t", tt.input.Stream, input.Stream) - } - if input.Report != tt.input.Report { - t.Errorf("Report: expected %t, got %t", tt.input.Report, input.Report) - } - if input.Timeout != tt.input.Timeout { - t.Errorf("Timeout: expected %v, got %v", tt.input.Timeout, input.Timeout) - } - }) - } -} - -func TestParameterValidation(t *testing.T) { - tests := []struct { - name string - models string - baseUrl string - apiKey string - shouldError bool - errorDesc string - }{ - { - name: "Valid parameters", - models: "gpt-3.5-turbo", - baseUrl: "https://api.openai.com", - apiKey: "test-key", - shouldError: false, - }, - { - name: "Empty models", - models: "", - baseUrl: "https://api.openai.com", - apiKey: "test-key", - shouldError: true, - errorDesc: "models parameter is required", - }, - { - name: "Empty baseUrl", - models: "gpt-3.5-turbo", - baseUrl: "", - apiKey: "test-key", - shouldError: true, - errorDesc: "baseUrl is required", - }, - { - name: "Empty apiKey", - models: "gpt-3.5-turbo", - baseUrl: "https://api.openai.com", - apiKey: "", - shouldError: true, - errorDesc: "apiKey is required", - }, - { - name: "Both baseUrl and apiKey empty", - models: "gpt-3.5-turbo", - baseUrl: "", - apiKey: "", - shouldError: true, - errorDesc: "both baseUrl and apiKey are required", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // 模拟主函数中的参数验证逻辑 - hasError := false - - // 检查 models 参数 - if tt.models == "" { - hasError = true - } - - // 检查 baseUrl 和 apiKey 参数 - if tt.baseUrl == "" || tt.apiKey == "" { - hasError = true - } - - if hasError != tt.shouldError { - t.Errorf("Expected error: %t, got: %t. Description: %s", tt.shouldError, hasError, tt.errorDesc) - } - }) - } -} - -func TestEnvironmentVariableOverride(t *testing.T) { - // 保存原始环境变量 - originalOpenAIKey := os.Getenv("OPENAI_API_KEY") - originalOpenAIURL := os.Getenv("OPENAI_BASE_URL") - originalAnthropicKey := os.Getenv("ANTHROPIC_API_KEY") - originalAnthropicURL := os.Getenv("ANTHROPIC_BASE_URL") - - // 确保测试后恢复原始环境变量 - defer func() { - os.Setenv("OPENAI_API_KEY", originalOpenAIKey) - os.Setenv("OPENAI_BASE_URL", originalOpenAIURL) - os.Setenv("ANTHROPIC_API_KEY", originalAnthropicKey) - os.Setenv("ANTHROPIC_BASE_URL", originalAnthropicURL) - }() - - tests := []struct { - name string - protocol string - cmdBaseUrl string - cmdApiKey string - envVars map[string]string - expectedUrl string - expectedKey string - }{ - { - name: "Command line takes priority over env vars", - protocol: "openai", - cmdBaseUrl: "https://cmd.api.com", - cmdApiKey: "cmd-key", - envVars: map[string]string{ - "OPENAI_BASE_URL": "https://env.api.com", - "OPENAI_API_KEY": "env-key", - }, - expectedUrl: "https://cmd.api.com", - expectedKey: "cmd-key", - }, - { - name: "Env vars used when command line empty", - protocol: "openai", - cmdBaseUrl: "", - cmdApiKey: "", - envVars: map[string]string{ - "OPENAI_BASE_URL": "https://env.api.com", - "OPENAI_API_KEY": "env-key", - }, - expectedUrl: "https://env.api.com", - expectedKey: "env-key", - }, - { - name: "Mixed: command line URL, env key", - protocol: "anthropic", - cmdBaseUrl: "https://cmd.anthropic.com", - cmdApiKey: "", - envVars: map[string]string{ - "ANTHROPIC_BASE_URL": "https://env.anthropic.com", - "ANTHROPIC_API_KEY": "env-anthropic-key", - }, - expectedUrl: "https://cmd.anthropic.com", - expectedKey: "env-anthropic-key", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // 清除所有环境变量 - os.Unsetenv("OPENAI_API_KEY") - os.Unsetenv("OPENAI_BASE_URL") - os.Unsetenv("ANTHROPIC_API_KEY") - os.Unsetenv("ANTHROPIC_BASE_URL") - - // 设置测试环境变量 - for key, value := range tt.envVars { - os.Setenv(key, value) - } - - // 模拟 main 函数中的逻辑 - finalProtocol := tt.protocol - finalBaseUrl := tt.cmdBaseUrl - finalApiKey := tt.cmdApiKey - - // 根据 protocol 加载对应的环境变量 - if finalBaseUrl == "" || finalApiKey == "" { - envBaseUrl, envApiKey := loadEnvForProvider(finalProtocol) - if finalBaseUrl == "" { - finalBaseUrl = envBaseUrl - } - if finalApiKey == "" { - finalApiKey = envApiKey - } - } - - if finalBaseUrl != tt.expectedUrl { - t.Errorf("Expected baseUrl %s, got %s", tt.expectedUrl, finalBaseUrl) - } - if finalApiKey != tt.expectedKey { - t.Errorf("Expected apiKey %s, got %s", tt.expectedKey, finalApiKey) - } - }) - } -} - -func TestModelListProcessing(t *testing.T) { - tests := []struct { - name string - modelsParam string - expectedCount int - expectedModels []string - shouldError bool - }{ - { - name: "Valid single model", - modelsParam: "gpt-3.5-turbo", - expectedCount: 1, - expectedModels: []string{"gpt-3.5-turbo"}, - shouldError: false, - }, - { - name: "Valid multiple models", - modelsParam: "gpt-3.5-turbo,gpt-4,claude-3", - expectedCount: 3, - expectedModels: []string{"gpt-3.5-turbo", "gpt-4", "claude-3"}, - shouldError: false, - }, - { - name: "Models with various spacing", - modelsParam: " gpt-3.5-turbo , gpt-4, claude-3 ", - expectedCount: 3, - expectedModels: []string{"gpt-3.5-turbo", "gpt-4", "claude-3"}, - shouldError: false, - }, - { - name: "Empty models parameter", - modelsParam: "", - shouldError: true, - }, - { - name: "Models with empty entries", - modelsParam: "gpt-3.5-turbo,,gpt-4", - expectedCount: 3, - expectedModels: []string{"gpt-3.5-turbo", "", "gpt-4"}, - shouldError: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // 模拟 main 函数中的模型处理逻辑 - if tt.modelsParam == "" && tt.shouldError { - // 验证空参数应该导致错误 - return // 这在实际代码中会导致 os.Exit(1) - } - - // 解析多个模型 - modelList := strings.Split(tt.modelsParam, ",") - for i, m := range modelList { - modelList[i] = strings.TrimSpace(m) - } - - if !tt.shouldError { - if len(modelList) != tt.expectedCount { - t.Errorf("Expected %d models, got %d", tt.expectedCount, len(modelList)) - } - - for i, expectedModel := range tt.expectedModels { - if i < len(modelList) && modelList[i] != expectedModel { - t.Errorf("Model[%d]: expected '%s', got '%s'", i, expectedModel, modelList[i]) - } + for i, want := range tt.expected { + if parts[i] != want { + t.Errorf("[%d] got %q, want %q", i, parts[i], want) } } }) } } - -func TestConfigCreation(t *testing.T) { - tests := []struct { - name string - input struct { - protocol string - baseUrl string - apiKey string - model string - concurrency int - count int - prompt string - stream bool - report bool - timeout int - } - expected types.Input - }{ - { - name: "Complete OpenAI configuration", - input: struct { - protocol string - baseUrl string - apiKey string - model string - concurrency int - count int - prompt string - stream bool - report bool - timeout int - }{ - protocol: "openai", - baseUrl: "https://api.openai.com/v1", - apiKey: "sk-test123", - model: "gpt-3.5-turbo", - concurrency: 5, - count: 20, - prompt: "Hello, world!", - stream: true, - report: false, - timeout: 30, - }, - expected: types.Input{ - Protocol: "openai", - BaseUrl: "https://api.openai.com/v1", - ApiKey: "sk-test123", - Model: "gpt-3.5-turbo", - Concurrency: 5, - Count: 20, - PromptSource: createTestPromptSource("Hello, world!"), - Stream: true, - Report: false, - Timeout: 30 * time.Second, - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // 模拟 main 函数中创建配置的过程 - input := types.Input{ - Protocol: tt.input.protocol, - BaseUrl: tt.input.baseUrl, - ApiKey: tt.input.apiKey, - Model: tt.input.model, - Concurrency: tt.input.concurrency, - Count: tt.input.count, - PromptSource: createTestPromptSource(tt.input.prompt), - Stream: tt.input.stream, - Report: tt.input.report, - Timeout: time.Duration(tt.input.timeout) * time.Second, - } - - // 验证所有字段 - if input.Protocol != tt.expected.Protocol { - t.Errorf("Protocol: expected %s, got %s", tt.expected.Protocol, input.Protocol) - } - if input.BaseUrl != tt.expected.BaseUrl { - t.Errorf("BaseUrl: expected %s, got %s", tt.expected.BaseUrl, input.BaseUrl) - } - if input.Model != tt.expected.Model { - t.Errorf("Model: expected %s, got %s", tt.expected.Model, input.Model) - } - if input.Timeout != tt.expected.Timeout { - t.Errorf("Timeout: expected %v, got %v", tt.expected.Timeout, input.Timeout) - } - }) - } -} - -func TestEdgeCases(t *testing.T) { - tests := []struct { - name string - protocol string - expected string - }{ - { - name: "Empty protocol falls back to env detection", - protocol: "", - expected: "openai", // default when no env vars set - }, - { - name: "Openai protocol", - protocol: "openai", - expected: "openai", - }, - { - name: "Anthropic protocol", - protocol: "anthropic", - expected: "anthropic", - }, - { - name: "Unknown protocol", - protocol: "unknown", - expected: "unknown", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // 清除环境变量以确保一致的测试环境 - os.Unsetenv("OPENAI_API_KEY") - os.Unsetenv("OPENAI_BASE_URL") - os.Unsetenv("ANTHROPIC_API_KEY") - os.Unsetenv("ANTHROPIC_BASE_URL") - - // 模拟main函数中的协议处理逻辑 - finalProtocol := tt.protocol - if finalProtocol == "" { - finalProtocol = detectProviderFromEnv() - } - - if finalProtocol != tt.expected { - t.Errorf("Expected protocol %s, got %s", tt.expected, finalProtocol) - } - }) - } -} - -func TestValidateParameters(t *testing.T) { - tests := []struct { - name string - models string - baseUrl string - apiKey string - expectExit bool - description string - }{ - { - name: "All parameters valid", - models: "gpt-3.5-turbo", - baseUrl: "https://api.openai.com", - apiKey: "sk-test123", - expectExit: false, - description: "Valid configuration should not exit", - }, - { - name: "Empty models triggers exit", - models: "", - baseUrl: "https://api.openai.com", - apiKey: "sk-test123", - expectExit: true, - description: "Empty models should trigger exit", - }, - { - name: "Empty baseUrl triggers exit", - models: "gpt-3.5-turbo", - baseUrl: "", - apiKey: "sk-test123", - expectExit: true, - description: "Empty baseUrl should trigger exit", - }, - { - name: "Empty apiKey triggers exit", - models: "gpt-3.5-turbo", - baseUrl: "https://api.openai.com", - apiKey: "", - expectExit: true, - description: "Empty apiKey should trigger exit", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // 模拟参数验证逻辑 - shouldExit := false - - // models 参数检查 - if tt.models == "" { - shouldExit = true - } - - // baseUrl 和 apikey 检查 - if tt.baseUrl == "" || tt.apiKey == "" { - shouldExit = true - } - - if shouldExit != tt.expectExit { - t.Errorf("%s: expected exit=%t, got exit=%t", tt.description, tt.expectExit, shouldExit) - } - }) - } -} - -func TestTimeoutHandling(t *testing.T) { - tests := []struct { - name string - timeoutSeconds int - expectedTimeout time.Duration - }{ - { - name: "Default timeout", - timeoutSeconds: 30, - expectedTimeout: 30 * time.Second, - }, - { - name: "Custom timeout", - timeoutSeconds: 60, - expectedTimeout: 60 * time.Second, - }, - { - name: "Short timeout", - timeoutSeconds: 5, - expectedTimeout: 5 * time.Second, - }, - { - name: "Zero timeout", - timeoutSeconds: 0, - expectedTimeout: 0 * time.Second, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // 模拟timeout转换逻辑 - timeout := time.Duration(tt.timeoutSeconds) * time.Second - - if timeout != tt.expectedTimeout { - t.Errorf("Expected timeout %v, got %v", tt.expectedTimeout, timeout) - } - }) - } -} - -func TestValidateRequiredParams(t *testing.T) { - tests := []struct { - name string - models string - baseUrl string - apiKey string - protocol string - expectError bool - errorMsg string - }{ - { - name: "All params valid", - models: "gpt-3.5-turbo", - baseUrl: "https://api.openai.com", - apiKey: "sk-test123", - protocol: "openai", - expectError: false, - }, - { - name: "Empty models", - models: "", - baseUrl: "https://api.openai.com", - apiKey: "sk-test123", - protocol: "openai", - expectError: true, - errorMsg: "models 参数必填", - }, - { - name: "Empty baseUrl", - models: "gpt-3.5-turbo", - baseUrl: "", - apiKey: "sk-test123", - protocol: "openai", - expectError: true, - errorMsg: "baseUrl 和 apikey 参数必填", - }, - { - name: "Empty apiKey", - models: "gpt-3.5-turbo", - baseUrl: "https://api.openai.com", - apiKey: "", - protocol: "openai", - expectError: true, - errorMsg: "baseUrl 和 apikey 参数必填", - }, - { - name: "Both baseUrl and apiKey empty", - models: "gpt-3.5-turbo", - baseUrl: "", - apiKey: "", - protocol: "anthropic", - expectError: true, - errorMsg: "baseUrl 和 apikey 参数必填", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := validateRequiredParams(tt.models, tt.baseUrl, tt.apiKey, tt.protocol) - - if tt.expectError { - if err == nil { - t.Errorf("Expected error containing '%s', but got no error", tt.errorMsg) - } else if !strings.Contains(err.Error(), tt.errorMsg) { - t.Errorf("Expected error containing '%s', got '%s'", tt.errorMsg, err.Error()) - } - } else { - if err != nil { - t.Errorf("Expected no error, got: %v", err) - } - } - }) - } -} - -func TestParseModelList(t *testing.T) { - tests := []struct { - name string - input string - expected []string - }{ - { - name: "Empty string", - input: "", - expected: nil, - }, - { - name: "Single model", - input: "gpt-3.5-turbo", - expected: []string{"gpt-3.5-turbo"}, - }, - { - name: "Multiple models", - input: "gpt-3.5-turbo,gpt-4,claude-3", - expected: []string{"gpt-3.5-turbo", "gpt-4", "claude-3"}, - }, - { - name: "Models with spaces", - input: " gpt-3.5-turbo , gpt-4 , claude-3 ", - expected: []string{"gpt-3.5-turbo", "gpt-4", "claude-3"}, - }, - { - name: "Models with empty entries", - input: "gpt-3.5-turbo,,gpt-4", - expected: []string{"gpt-3.5-turbo", "", "gpt-4"}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := parseModelList(tt.input) - - if len(result) != len(tt.expected) { - t.Errorf("Expected %d models, got %d", len(tt.expected), len(result)) - return - } - - for i, expected := range tt.expected { - if result[i] != expected { - t.Errorf("Model[%d]: expected '%s', got '%s'", i, expected, result[i]) - } - } - }) - } -} - -func TestResolveConfigValues(t *testing.T) { - // 保存原始环境变量 - originalOpenAIKey := os.Getenv("OPENAI_API_KEY") - originalOpenAIURL := os.Getenv("OPENAI_BASE_URL") - originalAnthropicKey := os.Getenv("ANTHROPIC_API_KEY") - originalAnthropicURL := os.Getenv("ANTHROPIC_BASE_URL") - - // 确保测试后恢复原始环境变量 - defer func() { - os.Setenv("OPENAI_API_KEY", originalOpenAIKey) - os.Setenv("OPENAI_BASE_URL", originalOpenAIURL) - os.Setenv("ANTHROPIC_API_KEY", originalAnthropicKey) - os.Setenv("ANTHROPIC_BASE_URL", originalAnthropicURL) - }() - - tests := []struct { - name string - inputProtocol string - inputBaseUrl string - inputApiKey string - envVars map[string]string - expectedProtocol string - expectedBaseUrl string - expectedApiKey string - }{ - { - name: "All command line params provided", - inputProtocol: "openai", - inputBaseUrl: "https://cmd.api.com", - inputApiKey: "cmd-key", - envVars: map[string]string{}, - expectedProtocol: "openai", - expectedBaseUrl: "https://cmd.api.com", - expectedApiKey: "cmd-key", - }, - { - name: "Empty protocol, auto-detect from env", - inputProtocol: "", - inputBaseUrl: "https://cmd.api.com", - inputApiKey: "cmd-key", - envVars: map[string]string{ - "OPENAI_API_KEY": "env-key", - }, - expectedProtocol: "openai", - expectedBaseUrl: "https://cmd.api.com", - expectedApiKey: "cmd-key", - }, - { - name: "Missing baseUrl, get from env", - inputProtocol: "openai", - inputBaseUrl: "", - inputApiKey: "cmd-key", - envVars: map[string]string{ - "OPENAI_BASE_URL": "https://env.api.com", - }, - expectedProtocol: "openai", - expectedBaseUrl: "https://env.api.com", - expectedApiKey: "cmd-key", - }, - { - name: "Missing apiKey, get from env", - inputProtocol: "anthropic", - inputBaseUrl: "https://cmd.api.com", - inputApiKey: "", - envVars: map[string]string{ - "ANTHROPIC_API_KEY": "env-key", - }, - expectedProtocol: "anthropic", - expectedBaseUrl: "https://cmd.api.com", - expectedApiKey: "env-key", - }, - { - name: "All from env vars", - inputProtocol: "", - inputBaseUrl: "", - inputApiKey: "", - envVars: map[string]string{ - "ANTHROPIC_BASE_URL": "https://env.anthropic.com", - "ANTHROPIC_API_KEY": "env-ant-key", - }, - expectedProtocol: "anthropic", - expectedBaseUrl: "https://env.anthropic.com", - expectedApiKey: "env-ant-key", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // 清除所有环境变量 - os.Unsetenv("OPENAI_API_KEY") - os.Unsetenv("OPENAI_BASE_URL") - os.Unsetenv("ANTHROPIC_API_KEY") - os.Unsetenv("ANTHROPIC_BASE_URL") - - // 设置测试环境变量 - for key, value := range tt.envVars { - os.Setenv(key, value) - } - - protocol, baseUrl, apiKey := resolveConfigValues(tt.inputProtocol, tt.inputBaseUrl, tt.inputApiKey) - - if protocol != tt.expectedProtocol { - t.Errorf("Protocol: expected %s, got %s", tt.expectedProtocol, protocol) - } - if baseUrl != tt.expectedBaseUrl { - t.Errorf("BaseUrl: expected %s, got %s", tt.expectedBaseUrl, baseUrl) - } - if apiKey != tt.expectedApiKey { - t.Errorf("ApiKey: expected %s, got %s", tt.expectedApiKey, apiKey) - } - }) - } -} - -func TestPrintErrorMessages(t *testing.T) { - tests := []struct { - name string - protocol string - expected []string - }{ - { - name: "OpenAI protocol", - protocol: "openai", - expected: []string{ - " OPENAI_BASE_URL - OpenAI API 基础 URL", - " OPENAI_API_KEY - OpenAI API 密钥", - }, - }, - { - name: "Anthropic protocol", - protocol: "anthropic", - expected: []string{ - " ANTHROPIC_BASE_URL - Anthropic API 基础 URL", - " ANTHROPIC_API_KEY - Anthropic API 密钥", - }, - }, - { - name: "Unknown protocol", - protocol: "unknown", - expected: []string{}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // 由于printErrorMessages函数直接输出到stdout,我们只能测试它不会panic - // 在实际应用中,可以考虑将其重构为返回字符串的函数 - defer func() { - if r := recover(); r != nil { - t.Errorf("printErrorMessages panicked: %v", r) - } - }() - - printErrorMessages(tt.protocol) - }) - } -} - -func TestCreateRunnerConfig(t *testing.T) { - tests := []struct { - name string - protocol string - baseUrl string - apiKey string - model string - prompt string - concurrency int - count int - timeout int - stream bool - report bool - expected types.Input - }{ - { - name: "Complete OpenAI config", - protocol: "openai", - baseUrl: "https://api.openai.com", - apiKey: "sk-test123", - model: "gpt-3.5-turbo", - prompt: "Hello world", - concurrency: 5, - count: 20, - timeout: 30, - stream: true, - report: false, - expected: types.Input{ - Protocol: "openai", - BaseUrl: "https://api.openai.com", - ApiKey: "sk-test123", - Model: "gpt-3.5-turbo", - PromptSource: createTestPromptSource("Hello world"), - Concurrency: 5, - Count: 20, - Timeout: 30 * time.Second, - Stream: true, - Report: false, - }, - }, - { - name: "Anthropic config with defaults", - protocol: "anthropic", - baseUrl: "https://api.anthropic.com", - apiKey: "ant-test456", - model: "claude-3", - prompt: "Test prompt", - concurrency: 3, - count: 10, - timeout: 60, - stream: false, - report: true, - expected: types.Input{ - Protocol: "anthropic", - BaseUrl: "https://api.anthropic.com", - ApiKey: "ant-test456", - Model: "claude-3", - PromptSource: createTestPromptSource("Test prompt"), - Concurrency: 3, - Count: 10, - Timeout: 60 * time.Second, - Stream: false, - Report: true, - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := createRunnerConfig( - tt.protocol, tt.baseUrl, tt.apiKey, tt.model, createTestPromptSource(tt.prompt), - tt.concurrency, tt.count, tt.timeout, tt.stream, tt.report, - false, false, - ) - - if result.Protocol != tt.expected.Protocol { - t.Errorf("Protocol: expected %s, got %s", tt.expected.Protocol, result.Protocol) - } - if result.BaseUrl != tt.expected.BaseUrl { - t.Errorf("BaseUrl: expected %s, got %s", tt.expected.BaseUrl, result.BaseUrl) - } - if result.ApiKey != tt.expected.ApiKey { - t.Errorf("ApiKey: expected %s, got %s", tt.expected.ApiKey, result.ApiKey) - } - if result.Model != tt.expected.Model { - t.Errorf("Model: expected %s, got %s", tt.expected.Model, result.Model) - } - if result.PromptSource.GetRandomContent() != tt.expected.PromptSource.GetRandomContent() { - t.Errorf("PromptSource content: expected %s, got %s", tt.expected.PromptSource.GetRandomContent(), result.PromptSource.GetRandomContent()) - } - if result.Concurrency != tt.expected.Concurrency { - t.Errorf("Concurrency: expected %d, got %d", tt.expected.Concurrency, result.Concurrency) - } - if result.Count != tt.expected.Count { - t.Errorf("Count: expected %d, got %d", tt.expected.Count, result.Count) - } - if result.Timeout != tt.expected.Timeout { - t.Errorf("Timeout: expected %v, got %v", tt.expected.Timeout, result.Timeout) - } - if result.Stream != tt.expected.Stream { - t.Errorf("Stream: expected %t, got %t", tt.expected.Stream, result.Stream) - } - if result.Report != tt.expected.Report { - t.Errorf("Report: expected %t, got %t", tt.expected.Report, result.Report) - } - }) - } -} - -func TestCollectErrorsWithContext(t *testing.T) { - tests := []struct { - name string - modelName string - modelErrors []string - expected []string - }{ - { - name: "No errors", - modelName: "gpt-3.5-turbo", - modelErrors: []string{}, - expected: []string{}, - }, - { - name: "Single error", - modelName: "gpt-4", - modelErrors: []string{"Connection timeout"}, - expected: []string{"[gpt-4] Connection timeout"}, - }, - { - name: "Multiple errors", - modelName: "claude-3", - modelErrors: []string{"Rate limit exceeded", "API key invalid"}, - expected: []string{"[claude-3] Rate limit exceeded", "[claude-3] API key invalid"}, - }, - { - name: "Mixed errors with empty strings", - modelName: "gpt-3.5-turbo", - modelErrors: []string{"Error 1", "", "Error 2"}, - expected: []string{"[gpt-3.5-turbo] Error 1", "[gpt-3.5-turbo] Error 2"}, - }, - { - name: "Only empty errors", - modelName: "test-model", - modelErrors: []string{"", "", ""}, - expected: []string{}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := collectErrorsWithContext(tt.modelName, tt.modelErrors) - - if len(result) != len(tt.expected) { - t.Errorf("Expected %d errors, got %d", len(tt.expected), len(result)) - return - } - - for i, expected := range tt.expected { - if result[i] != expected { - t.Errorf("Error[%d]: expected '%s', got '%s'", i, expected, result[i]) - } - } - }) - } -} - -func TestFillResultMetadata(t *testing.T) { - // 创建测试数据 - modelList := []string{"gpt-3.5-turbo", "gpt-4"} - baseUrl := "https://api.openai.com" - protocol := "openai" - - results := []*types.ReportData{ - { - TotalRequests: 10, - Concurrency: 3, - }, - { - TotalRequests: 20, - Concurrency: 5, - }, - } - - // 调用被测试函数 - fillResultMetadata(results, modelList, baseUrl, protocol) - - // 验证结果 - for i, result := range results { - if result.Model != modelList[i] { - t.Errorf("Result[%d] Model: expected %s, got %s", i, modelList[i], result.Model) - } - if result.BaseUrl != baseUrl { - t.Errorf("Result[%d] BaseUrl: expected %s, got %s", i, baseUrl, result.BaseUrl) - } - if result.Protocol != protocol { - t.Errorf("Result[%d] Protocol: expected %s, got %s", i, protocol, result.Protocol) - } - if result.Timestamp == "" { - t.Errorf("Result[%d] Timestamp should not be empty", i) - } - // 验证时间戳格式是否为RFC3339 - if _, err := time.Parse(time.RFC3339, result.Timestamp); err != nil { - t.Errorf("Result[%d] Timestamp format invalid: %v", i, err) - } - } -} - -func TestConvertErrorsToPointers(t *testing.T) { - tests := []struct { - name string - errors []string - expected int - }{ - { - name: "Empty slice", - errors: []string{}, - expected: 0, - }, - { - name: "Single error", - errors: []string{"Error 1"}, - expected: 1, - }, - { - name: "Multiple errors", - errors: []string{"Error 1", "Error 2", "Error 3"}, - expected: 3, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := convertErrorsToPointers(tt.errors) - - if len(result) != tt.expected { - t.Errorf("Expected %d pointers, got %d", tt.expected, len(result)) - return - } - - // 验证指针指向正确的值 - for i, errorPtr := range result { - if errorPtr == nil { - t.Errorf("Pointer[%d] should not be nil", i) - continue - } - if *errorPtr != tt.errors[i] { - t.Errorf("Pointer[%d]: expected '%s', got '%s'", i, tt.errors[i], *errorPtr) - } - } - }) - } -} - -func TestGenerateReportsIfEnabled(t *testing.T) { - tests := []struct { - name string - reportFlag bool - results []*types.ReportData - expectCall bool - }{ - { - name: "Report disabled", - reportFlag: false, - results: []*types.ReportData{{}}, - expectCall: false, - }, - { - name: "No results", - reportFlag: true, - results: []*types.ReportData{}, - expectCall: false, - }, - { - name: "Report enabled with results", - reportFlag: true, - results: []*types.ReportData{ - { - TotalRequests: 10, - Concurrency: 3, - }, - }, - expectCall: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // 由于这个函数调用外部依赖(report.NewReportManager), - // 在实际测试中可能需要依赖注入或模拟 - // 这里我们只测试函数不会panic - defer func() { - if r := recover(); r != nil { - t.Errorf("generateReportsIfEnabled panicked: %v", r) - } - }() - - err := generateReportsIfEnabled(tt.reportFlag, tt.results) - - // 如果不应该调用,检查是否返回nil error(因为是空操作) - if !tt.expectCall && err != nil { - t.Errorf("Expected no error when report disabled or no results, got: %v", err) - } - }) - } -} - -func TestProcessModelExecution(t *testing.T) { - tests := []struct { - name string - modelName string - input types.Input - displayer *display.Displayer - completedRequests int - totalRequests int - expectedResult bool - }{ - { - name: "Successful execution", - modelName: "gpt-3.5-turbo", - input: types.Input{ - PromptSource: createTestPromptSource("test prompt"), - Protocol: "openai", - BaseUrl: "https://api.openai.com", - ApiKey: "test-key", - Timeout: 30, - Count: 1, - Concurrency: 1, - }, - displayer: display.New(), - completedRequests: 0, - totalRequests: 1, - expectedResult: true, - }, - { - name: "Another model execution", - modelName: "gpt-4", - input: types.Input{ - PromptSource: createTestPromptSource("test prompt"), - Protocol: "openai", - BaseUrl: "https://api.openai.com", - ApiKey: "test-key", - Timeout: 30, - Count: 1, - Concurrency: 1, - }, - displayer: display.New(), - completedRequests: 0, - totalRequests: 1, - expectedResult: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // 设置模型配置 - tt.input.Model = tt.modelName - - result, errorMessages, err := processModelExecution( - "test-task-id", - tt.modelName, - tt.input, - tt.displayer, - tt.completedRequests, - tt.totalRequests, - ) - - // 检查基本返回值 - if tt.expectedResult { - if result == nil && err != nil { - t.Errorf("Expected successful result, but got error: %v", err) - } - } - - // 验证 errorMessages 不是 nil - if errorMessages == nil { - t.Errorf("errorMessages should not be nil") - } - - // 确保没有 panic,这是最基本的要求 - }) - } -} - -func TestExecuteModelsTestSuite(t *testing.T) { - tests := []struct { - name string - modelList []string - protocol string - baseUrl string - apiKey string - prompt string - concurrency int - count int - timeout int - stream bool - reportFlag bool - displayer *display.Displayer - expectedLen int - expectError bool - }{ - { - name: "Single model execution", - modelList: []string{"gpt-3.5-turbo"}, - protocol: "openai", - baseUrl: "https://api.openai.com", - apiKey: "test-key", - prompt: "test prompt", - concurrency: 1, - count: 1, - timeout: 30, - stream: false, - reportFlag: false, - displayer: display.New(), - expectedLen: 1, - expectError: false, - }, - { - name: "Multiple models execution", - modelList: []string{"gpt-3.5-turbo", "gpt-4"}, - protocol: "openai", - baseUrl: "https://api.openai.com", - apiKey: "test-key", - prompt: "test prompt", - concurrency: 1, - count: 1, - timeout: 30, - stream: false, - reportFlag: false, - displayer: display.New(), - expectedLen: 2, - expectError: false, - }, - { - name: "Empty models list", - modelList: []string{}, - protocol: "openai", - baseUrl: "https://api.openai.com", - apiKey: "test-key", - prompt: "test prompt", - concurrency: 1, - count: 1, - timeout: 30, - stream: false, - reportFlag: false, - displayer: display.New(), - expectedLen: 0, - expectError: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - results, errorMessages, err := executeModelsTestSuite( - "test-task-id", - tt.modelList, - tt.protocol, - tt.baseUrl, - tt.apiKey, - createTestPromptSource(tt.prompt), - tt.concurrency, - tt.count, - tt.timeout, - tt.stream, - tt.reportFlag, - false, - false, - tt.displayer, - ) - - if tt.expectError { - if err == nil { - t.Errorf("Expected error, but got none") - } - } else { - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - } - - if len(results) != tt.expectedLen { - t.Errorf("Expected %d results, got %d", tt.expectedLen, len(results)) - } - - // All tests should complete without panicking - _ = errorMessages // Use the variable to avoid unused variable warning - }) - } -} diff --git a/internal/tui/client.go b/internal/tui/client.go new file mode 100644 index 0000000..9ddf935 --- /dev/null +++ b/internal/tui/client.go @@ -0,0 +1,147 @@ +package tui + +import ( + "fmt" + + tea "github.com/charmbracelet/bubbletea" + "github.com/yinxulai/ait/internal/server" +) + +// Client 持有 server.Server,为 TUI 层提供 tea.Cmd 包装的异步调用。 +// TUI Model 通过 Client 与 Server 交互,不直接 import runner/task/turbo 等下层包。 +type Client struct { + srv server.Server +} + +// NewClient 创建 Client 实例。 +func NewClient(srv server.Server) *Client { + return &Client{srv: srv} +} + +// ─── 任务 CRUD ──────────────────────────────────────────────────────────────── + +// LoadTasksCmd 异步加载任务列表。 +func (c *Client) LoadTasksCmd() tea.Cmd { + return func() tea.Msg { + return TasksLoadedMsg{Tasks: c.srv.ListTasks()} + } +} + +// CreateTaskCmd 异步新建任务,autoStart 表示成功后是否自动触发运行。 +func (c *Client) CreateTaskCmd(cfg server.TaskConfig, autoStart bool) tea.Cmd { + return func() tea.Msg { + task, err := c.srv.CreateTask(cfg) + if err != nil { + return ErrorMsg{Err: fmt.Errorf("创建任务失败: %w", err)} + } + return TaskSavedMsg{Task: task, AutoStart: autoStart} + } +} + +// UpdateTaskCmd 异步更新任务。 +func (c *Client) UpdateTaskCmd(id string, cfg server.TaskConfig) tea.Cmd { + return func() tea.Msg { + task, err := c.srv.UpdateTask(id, cfg) + if err != nil { + return ErrorMsg{Err: fmt.Errorf("更新任务失败: %w", err)} + } + return TaskSavedMsg{Task: task, AutoStart: false} + } +} + +// DeleteTaskCmd 异步删除任务。 +func (c *Client) DeleteTaskCmd(id string) tea.Cmd { + return func() tea.Msg { + if err := c.srv.DeleteTask(id); err != nil { + return ErrorMsg{Err: fmt.Errorf("删除任务失败: %w", err)} + } + return TaskDeletedMsg{TaskID: id} + } +} + +// CopyTaskCmd 异步复制任务。 +func (c *Client) CopyTaskCmd(id string) tea.Cmd { + return func() tea.Msg { + task, err := c.srv.CopyTask(id) + if err != nil { + return ErrorMsg{Err: fmt.Errorf("复制任务失败: %w", err)} + } + return TaskSavedMsg{Task: task, AutoStart: false} + } +} + +// ─── 运行管理 ───────────────────────────────────────────────────────────────── + +// StartRunCmd 异步启动运行。 +func (c *Client) StartRunCmd(taskID string) tea.Cmd { + return func() tea.Msg { + runID, err := c.srv.StartRun(taskID) + if err != nil { + return ErrorMsg{Err: fmt.Errorf("启动运行失败: %w", err)} + } + return RunStartedMsg{RunID: runID, TaskID: taskID} + } +} + +// StopRunCmd 异步停止运行(fire-and-forget,忽略错误)。 +func (c *Client) StopRunCmd(runID server.RunID) tea.Cmd { + return func() tea.Msg { + _ = c.srv.StopRun(runID) + return nil + } +} + +// SubscribeCmd 订阅 runID 的事件流,返回用于首次等待的 Cmd 和 CancelFunc。 +// 调用方应将 ch 存储在 dashboardState 中,每次收到 ServerEventMsg 后 +// 再次调用 WaitEventCmd(ch) 继续监听。 +func (c *Client) SubscribeCmd(runID server.RunID) (<-chan server.Event, server.CancelFunc, tea.Cmd) { + ch, cancel := c.srv.Subscribe(runID) + return ch, cancel, WaitEventCmd(ch) +} + +// WaitEventCmd 等待事件通道的下一条事件。 +// 通道关闭时返回 nil(Update 中检测 nil 即可停止循环)。 +func WaitEventCmd(ch <-chan server.Event) tea.Cmd { + return func() tea.Msg { + event, ok := <-ch + if !ok { + return nil + } + return ServerEventMsg{Event: event} + } +} + +// ─── 历史 & 报告 ────────────────────────────────────────────────────────────── + +// LoadHistoryCmd 异步加载指定任务的运行历史,limit<=0 表示不限条数。 +func (c *Client) LoadHistoryCmd(taskID string, limit int) tea.Cmd { + return func() tea.Msg { + history, err := c.srv.GetHistory(taskID, limit) + if err != nil { + return ErrorMsg{Err: fmt.Errorf("加载历史失败: %w", err)} + } + return HistoryLoadedMsg{TaskID: taskID, History: history} + } +} + +// GetRunStateCmd 异步获取运行状态快照(后台模式重入仪表盘时使用)。 +func (c *Client) GetRunStateCmd(runID server.RunID) tea.Cmd { + return func() tea.Msg { + state, ok := c.srv.GetRunState(runID) + if !ok { + return nil + } + return RunStateMsg{State: state} + } +} + +// GenerateReportCmd 异步生成报告文件。 +func (c *Client) GenerateReportCmd(runID server.RunID, format server.ReportFormat) tea.Cmd { + return func() tea.Msg { + path, err := c.srv.GenerateReport(runID, format) + if err != nil { + return ErrorMsg{Err: fmt.Errorf("生成报告失败: %w", err)} + } + return ReportGeneratedMsg{RunID: runID, Path: path} + } +} diff --git a/internal/tui/contextbar.go b/internal/tui/contextbar.go new file mode 100644 index 0000000..10d84de --- /dev/null +++ b/internal/tui/contextbar.go @@ -0,0 +1,95 @@ +package tui + +import ( + "fmt" + "strings" +) + +// contextBarItem 描述 Context Bar 中的一个可用操作。 +type contextBarItem struct { + key string + desc string +} + +// renderContextBar 渲染 Context Bar:紧贴 Footer 上方的动态操作提示行。 +// 若 items 为空则返回空字符串(不占空间)。 +func (m *Model) renderContextBar(items []contextBarItem) string { + if len(items) == 0 { + return "" + } + var parts []string + for _, item := range items { + parts = append(parts, fmt.Sprintf("%s %s", + m.styles.key.Render("["+item.key+"]"), + m.styles.muted.Render(item.desc), + )) + } + bar := " " + strings.Join(parts, " ") + barW := m.width + if barW < 1 { + barW = 80 + } + return m.styles.footer.Width(barW).Render(bar) +} + +// contextBarItems_taskList 返回任务列表页的 Context Bar 内容。 +func contextBarItems_taskList(isRunning bool) []contextBarItem { + if isRunning { + return []contextBarItem{ + {"Enter", "进入仪表盘"}, + {"s", "停止"}, + {"y", "复制"}, + } + } + return []contextBarItem{ + {"Enter", "查看详情"}, + {"r", "运行"}, + {"e", "编辑"}, + {"d", "删除"}, + {"y", "复制"}, + } +} + +// contextBarItems_taskDetail 返回任务详情页的 Context Bar 内容。 +func contextBarItems_taskDetail(hasHistory bool) []contextBarItem { + if hasHistory { + return []contextBarItem{ + {"r", "生成报告"}, + {"c", "复制摘要"}, + {"Enter", "再次运行"}, + {"e", "编辑"}, + } + } + return []contextBarItem{ + {"Enter", "运行"}, + {"e", "编辑"}, + {"y", "复制"}, + {"d", "删除"}, + } +} + +// contextBarItems_dashboard_nosel 仪表盘无选中请求。 +func contextBarItems_dashboard_nosel() []contextBarItem { + return []contextBarItem{ + {"s", "停止"}, + {"b", "后台运行"}, + {"r", "提前报告"}, + } +} + +// contextBarItems_dashboard_sel 仪表盘有选中请求。 +func contextBarItems_dashboard_sel() []contextBarItem { + return []contextBarItem{ + {"Enter", "查看请求详情"}, + {"↑↓", "选择请求"}, + {"s", "停止"}, + } +} + +// contextBarItems_reqdetail 请求详情页。 +func contextBarItems_reqdetail() []contextBarItem { + return []contextBarItem{ + {"b/Esc", "返回仪表盘"}, + {"←→", "上/下一条请求"}, + } +} diff --git a/internal/tui/messages.go b/internal/tui/messages.go index 41106b9..7321d63 100644 --- a/internal/tui/messages.go +++ b/internal/tui/messages.go @@ -1,26 +1,55 @@ package tui -import "github.com/yinxulai/ait/internal/types" +import ( + "github.com/yinxulai/ait/internal/server" + "github.com/yinxulai/ait/internal/types" +) -type progressMsg struct { - stats types.StatsData +// TasksLoadedMsg 任务列表加载完成(初始化或刷新后)。 +type TasksLoadedMsg struct { + Tasks []types.TaskDefinition } -type runCompleteMsg struct { - taskID string - result *types.ReportData - reportPaths []string +// TaskSavedMsg 新建或更新任务完成。 +type TaskSavedMsg struct { + Task types.TaskDefinition + AutoStart bool // 是否自动启动运行(由调用方决定) } -type turboCompleteMsg struct { - taskID string - result *types.TurboResult +// TaskDeletedMsg 任务删除完成。 +type TaskDeletedMsg struct { + TaskID string } -type asyncErrorMsg struct { - err error +// HistoryLoadedMsg 任务历史记录加载完成。 +type HistoryLoadedMsg struct { + TaskID string + History []types.TaskRunSummary } -type requestLogMsg struct { - entry string +// RunStartedMsg 运行成功启动,携带 RunID 供后续订阅事件使用。 +type RunStartedMsg struct { + RunID server.RunID + TaskID string +} + +// ServerEventMsg 封装从 server.Subscribe 获取的事件,由 waitEventCmd 产生。 +type ServerEventMsg struct { + Event server.Event +} + +// RunStateMsg server.GetRunState 的轮询结果(用于后台运行恢复仪表盘)。 +type RunStateMsg struct { + State *server.RunState +} + +// ReportGeneratedMsg 报告文件生成完成。 +type ReportGeneratedMsg struct { + RunID server.RunID + Path string +} + +// ErrorMsg 通用异步错误,显示在状态栏。 +type ErrorMsg struct { + Err error } diff --git a/internal/tui/model.go b/internal/tui/model.go index 6194176..7367f10 100644 --- a/internal/tui/model.go +++ b/internal/tui/model.go @@ -1,1907 +1,323 @@ +// Package tui implements the interactive terminal UI for AIT. +// The TUI is built with BubbleTea; all server interactions go through Client. package tui import ( "fmt" - "path/filepath" - "strconv" "strings" - "time" - "github.com/charmbracelet/bubbles/textinput" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" - "github.com/yinxulai/ait/internal/config" - "github.com/yinxulai/ait/internal/report" - "github.com/yinxulai/ait/internal/runner" - "github.com/yinxulai/ait/internal/task" - "github.com/yinxulai/ait/internal/turbo" - "github.com/yinxulai/ait/internal/types" + "github.com/yinxulai/ait/internal/server" ) -type viewState string - -const ( - viewTaskList viewState = "task-list" - viewTaskDetail viewState = "task-detail" - viewWizard viewState = "wizard" - viewDashboard viewState = "dashboard" - viewResult viewState = "result" - viewTurboResult viewState = "turbo-result" -) - -const ( - modeStandard = "standard" - modeTurbo = "turbo" - - promptModeText = "text" - promptModeFile = "file" - promptModeGenerated = "generated" -) - -var protocolOptions = []string{ - types.ProtocolOpenAICompletions, - types.ProtocolOpenAIResponses, - types.ProtocolAnthropicMessages, -} - -var promptModeOptions = []string{promptModeText, promptModeFile, promptModeGenerated} +// ─── 视图状态 ───────────────────────────────────────────────────────────────── -type fieldKind int +type viewState string const ( - fieldText fieldKind = iota - fieldSelect - fieldToggle + viewTaskList viewState = "task-list" + viewTaskDetail viewState = "task-detail" + viewWizard viewState = "wizard" + viewDashboard viewState = "dashboard" + viewReqDetail viewState = "req-detail" ) -type wizardField struct { - key string - label string - kind fieldKind -} - -type wizardState struct { - editingTaskID string - createdAt time.Time - lastRunAt *time.Time - lastRunSummary *types.TaskRunSummary - fromView viewState - step int // 0=基本信息 1=测试参数 2=确认保存 - fieldIndex int // active field within current step - input textinput.Model - values map[string]string - protocolIndex int - mode string - promptModeIndex int - stream bool - thinking bool - report bool -} +// ─── 根 Model ───────────────────────────────────────────────────────────────── +// Model 是 BubbleTea 的根状态机。 +// 所有 Server 交互均通过 Client 发出 tea.Cmd;Model 不直接 import runner/task/turbo。 type Model struct { - styles styles - store *task.TaskStore - config *config.Config - tasks []types.TaskDefinition - history []types.TaskRunSummary - selected int - view viewState - wizard *wizardState - width int - height int - status string - err error - program *tea.Program - runningTask *types.TaskDefinition - runningTaskID string - runStartedAt time.Time - progress types.StatsData - runResult *types.ReportData - turboResult *types.TurboResult - activeRunner *runner.Runner - activeTurbo *turbo.Engine - requestLog []string -} - -func NewModel(store *task.TaskStore, cfg *config.Config) *Model { + client *Client + styles styles + width int + height int + view viewState + status string + err error + + // 页面局部状态 + taskList taskListState + hist *historyState // 任务详情页的历史 + wizard *wizardState // nil = 向导未打开 + dash *dashboardState // nil = 无活跃运行 + reqDetail *reqDetailState // nil = 不在请求详情页 +} + +// NewModel 创建 Model。srv 不能为 nil。 +func NewModel(srv server.Server) *Model { return &Model{ + client: NewClient(srv), styles: newStyles(), - store: store, - config: cfg, - tasks: store.Tasks, view: viewTaskList, + taskList: taskListState{selected: 0}, } } -func Run() error { - store, err := task.LoadTasks() - if err != nil { - return err - } - cfg, err := config.Load() - if err != nil { - return err - } - model := NewModel(store, cfg) - program := tea.NewProgram(model, tea.WithAltScreen()) - model.program = program - _, err = program.Run() +// Run 启动 BubbleTea 全屏程序。是此包的主要外部入口。 +func Run(srv server.Server) error { + m := NewModel(srv) + p := tea.NewProgram(m, tea.WithAltScreen()) + _, err := p.Run() return err } +// ─── BubbleTea 接口 ─────────────────────────────────────────────────────────── + func (m *Model) Init() tea.Cmd { - return nil + return m.client.LoadTasksCmd() } func (m *Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { switch msg := msg.(type) { + + // ── 窗口尺寸 ── case tea.WindowSizeMsg: m.width = msg.Width m.height = msg.Height return m, nil - case progressMsg: - m.progress = msg.stats - return m, nil - case requestLogMsg: - m.requestLog = append(m.requestLog, msg.entry) - if len(m.requestLog) > 60 { - m.requestLog = m.requestLog[len(m.requestLog)-60:] - } - return m, nil - case runCompleteMsg: - m.activeRunner = nil - m.runningTaskID = "" - m.runResult = msg.result - if m.view == viewDashboard { - m.view = viewResult - } - m.status = fmt.Sprintf("标准模式完成,共 %d 请求", msg.result.TotalRequests) - m.persistStandardRun(msg.taskID, msg.result, msg.reportPaths) - return m, nil - case turboCompleteMsg: - m.activeTurbo = nil - m.runningTaskID = "" - m.turboResult = msg.result - if m.view == viewDashboard { - m.view = viewTurboResult - } - m.status = fmt.Sprintf("Turbo 完成,最大稳定并发 %d", msg.result.MaxStableConcurrency) - m.persistTurboRun(msg.taskID, msg.result) - return m, nil - case asyncErrorMsg: - m.runningTaskID = "" - m.err = msg.err - m.status = msg.err.Error() - return m, nil + + // ── 键盘 ── case tea.KeyMsg: return m.handleKey(msg) - } - - return m, nil -} - -func (m *Model) handleKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) { - if msg.String() == "ctrl+c" { - return m, tea.Quit - } - - switch m.view { - case viewTaskList: - return m.handleTaskListKey(msg) - case viewTaskDetail: - return m.handleTaskDetailKey(msg) - case viewWizard: - return m.handleWizardKey(msg) - case viewDashboard: - return m.handleDashboardKey(msg) - case viewResult, viewTurboResult: - if msg.String() == "b" || msg.String() == "esc" || msg.String() == "enter" { - m.reloadHistoryForSelectedTask() - m.view = viewTaskDetail - return m, nil - } - } - return m, nil -} - -func (m *Model) handleTaskListKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) { - switch msg.String() { - case "up", "k": - if m.selected > 0 { - m.selected-- - } - case "down", "j": - if m.selected < len(m.tasks)-1 { - m.selected++ - } - case "a": - m.openWizard(nil) - case "e": - if taskDef, ok := m.currentTask(); ok { - copyTask := taskDef - m.openWizard(©Task) - } - case "y": - if taskDef, ok := m.currentTask(); ok { - copyTask := taskDef - copyTask.ID = "" - copyTask.Name = taskDef.Name + "-copy" - m.openWizard(©Task) - } - case "d": - if taskDef, ok := m.currentTask(); ok { - if err := m.store.Delete(taskDef.ID); err != nil { - m.err = err - break - } - if err := m.store.Save(); err != nil { - m.err = err - break - } - m.tasks = m.store.Tasks - if m.selected >= len(m.tasks) && m.selected > 0 { - m.selected-- - } - m.status = "任务已删除" + // ── 任务列表加载完成 ── + case TasksLoadedMsg: + m.taskList.tasks = msg.Tasks + // 调整选中项不越界 + if m.taskList.selected >= len(msg.Tasks) { + m.taskList.selected = max(len(msg.Tasks)-1, 0) } - case "enter": - if taskDef, ok := m.currentTask(); ok { - if taskDef.ID == m.runningTaskID { - m.view = viewDashboard - } else { - m.reloadHistoryForSelectedTask() - m.view = viewTaskDetail - } - } - case "r": - if taskDef, ok := m.currentTask(); ok { - if m.runningTaskID != "" { - m.status = "已有任务正在运行中,请等待完成或进入仪表盘停止" - } else { - m.startTaskRun(taskDef) - } - } - case "q": - return m, tea.Quit - } - return m, nil -} - -func (m *Model) handleTaskDetailKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) { - taskDef, ok := m.currentTask() - if !ok { - m.view = viewTaskList + m.status = "" + m.err = nil return m, nil - } - switch msg.String() { - case "b", "esc": - m.view = viewTaskList - case "e": - copyTask := taskDef - m.openWizard(©Task) - case "d": - if err := m.store.Delete(taskDef.ID); err != nil { - m.err = err - break - } - if err := m.store.Save(); err != nil { - m.err = err - break - } - m.tasks = m.store.Tasks - if m.selected >= len(m.tasks) && m.selected > 0 { - m.selected-- - } - m.view = viewTaskList - case "enter", "r": - if m.runningTaskID != "" && m.runningTaskID != taskDef.ID { - m.status = "已有任务正在运行中" - } else { - m.startTaskRun(taskDef) - if m.runningTaskID == taskDef.ID { - m.view = viewDashboard - } + // ── 任务保存完成(新建或更新) ── + case TaskSavedMsg: + m.status = fmt.Sprintf("任务 %q 已保存", msg.Task.Name) + // 若 AutoStart 且无活跃运行,立刻发起运行 + if msg.AutoStart && (m.dash == nil || !m.dash.isRunning()) { + return m, tea.Batch( + m.client.LoadTasksCmd(), + m.client.StartRunCmd(msg.Task.ID), + ) } - } + return m, m.client.LoadTasksCmd() - return m, nil -} + // ── 任务删除完成 ── + case TaskDeletedMsg: + m.status = "任务已删除" + m.view = viewTaskList + return m, m.client.LoadTasksCmd() -func (m *Model) handleWizardKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) { - if m.wizard == nil { + // ── 历史加载完成 ── + case HistoryLoadedMsg: + m.hist = &historyState{taskID: msg.TaskID, history: msg.History} return m, nil - } - // Step 2 (confirm): only action keys, no text input - if m.wizard.step == 2 { - switch msg.String() { - case "esc": - m.wizard.step = 1 - m.wizard.fieldIndex = len(m.wizardStepFields(1)) - 1 - m.refreshWizardInput() - case "enter": - if err := m.saveWizard(); err != nil { - m.err = err - m.status = err.Error() - } - case "r": - if err := m.saveWizard(); err != nil { - m.err = err - m.status = err.Error() - return m, nil - } - if taskDef, ok := m.currentTask(); ok { - m.startTaskRun(taskDef) - if m.runningTaskID == taskDef.ID { - m.view = viewDashboard - } - } + // ── 运行启动 ── + case RunStartedMsg: + ch, cancel, firstCmd := m.client.SubscribeCmd(msg.RunID) + m.dash = &dashboardState{ + runID: msg.RunID, + taskID: msg.TaskID, + eventCh: ch, + cancelFn: cancel, + reqSel: -1, + } + m.view = viewDashboard + m.status = "" + return m, firstCmd + + // ── Server 事件(来自运行中订阅) ── + case ServerEventMsg: + return m.handleServerEvent(msg) + + // ── 运行状态快照(重入仪表盘时) ── + case RunStateMsg: + if m.dash != nil && msg.State != nil && m.dash.runID == msg.State.RunID { + m.dash.runState = msg.State } return m, nil - } - fields := m.wizardStepFields(m.wizard.step) - field := fields[m.wizard.fieldIndex] - - switch msg.String() { - case "esc": - if m.wizard.step > 0 { - m.wizard.step-- - m.wizard.fieldIndex = len(m.wizardStepFields(m.wizard.step)) - 1 - m.refreshWizardInput() - } else { - m.view = m.wizard.fromView - m.wizard = nil - } + // ── 报告生成完成 ── + case ReportGeneratedMsg: + m.status = fmt.Sprintf("报告已生成: %s", msg.Path) return m, nil - case "tab", "down", "j": - if field.kind == fieldText { - m.wizard.values[field.key] = m.wizard.input.Value() - } - m.advanceWizardField(1) - return m, nil - case "enter": - if field.kind == fieldText { - m.wizard.values[field.key] = m.wizard.input.Value() - } - if m.wizard.fieldIndex == len(fields)-1 { - m.wizard.step++ - m.wizard.fieldIndex = 0 - if m.wizard.step < 2 { - m.refreshWizardInput() - } - } else { - m.wizard.fieldIndex++ - m.refreshWizardInput() - } - return m, nil - case "shift+tab", "up", "k": - if field.kind == fieldText { - m.wizard.values[field.key] = m.wizard.input.Value() - } - m.advanceWizardField(-1) - return m, nil - case "left", "h": - m.cycleWizardField(-1) - return m, nil - case "right", "l", "space": - m.cycleWizardField(1) + + // ── 错误 ── + case ErrorMsg: + m.err = msg.Err return m, nil } - if field.kind == fieldText { - var cmd tea.Cmd - m.wizard.input, cmd = m.wizard.input.Update(msg) - m.wizard.values[field.key] = m.wizard.input.Value() - return m, cmd - } - return m, nil -} -func (m *Model) handleDashboardKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) { - switch msg.String() { - case "s": - if m.activeRunner != nil { - m.activeRunner.Stop() - } - if m.activeTurbo != nil { - m.activeTurbo.Stop() - } - case "b", "esc": - // 返回列表,任务继续在后台运行 - m.view = viewTaskList - case "q": - return m, tea.Quit - } return m, nil } func (m *Model) View() string { switch m.view { + case viewTaskList: + return m.renderTaskList() case viewTaskDetail: return m.renderTaskDetail() case viewWizard: return m.renderWizard() case viewDashboard: return m.renderDashboard() - case viewResult: - return m.renderResult() - case viewTurboResult: - return m.renderTurboResult() - default: - return m.renderTaskList() - } -} - -func (m *Model) renderTaskList() string { - if m.width == 0 { - return "加载中..." - } - lastRunStr := "" - for _, t := range m.tasks { - if t.LastRunAt != nil { - lastRunStr = "最近: " + timeAgo(*t.LastRunAt) - break - } - } - header := m.renderHeader( - "AIT 任务中心", - fmt.Sprintf("已保存任务: %d %s", len(m.tasks), lastRunStr), - ) - footer := m.renderFooter( - "[↑↓] 选择", "[Enter] 详情", "[a] 新建", "[r] 运行", - "[e] 编辑", "[d] 删除", "[y] 复制", "[q] 退出", - ) - contentH := m.height - 2 - if contentH < 4 { - contentH = 4 - } - panelH := contentH - 2 - leftW := (m.width - 4) * 57 / 100 - rightW := m.width - 4 - leftW - leftContent := m.buildTaskListLeft(panelH, leftW) - rightContent := m.buildTaskListRight(panelH) - mid := m.dualColumnLayout(leftContent, rightContent, leftW, rightW, panelH) - return lipgloss.JoinVertical(lipgloss.Left, header, mid, footer) -} - -func (m *Model) buildTaskListLeft(maxH, width int) string { - var lines []string - lines = append(lines, m.styles.tableHead.Render( - fmt.Sprintf(" %-28s %-9s %-14s %s", "任务名称", "模式", "协议", "上次结果"), - )) - lines = append(lines, m.styles.muted.Render(strings.Repeat("─", width))) - if len(m.tasks) == 0 { - lines = append(lines, "") - lines = append(lines, m.styles.muted.Render(" 暂无任务 按 [a] 新建")) - return strings.Join(lines, "\n") - } - for i, t := range m.tasks { - if len(lines) >= maxH-1 { - break - } - // Mode: color-coded tag text with manual padding to 9 visual columns - var modeRendered string - if t.Input.Turbo { - modeRendered = m.styles.tagTurbo.Render("Turbo") - } else { - modeRendered = m.styles.tagStd.Render("标准") - } - modePad := 9 - lipgloss.Width(modeRendered) - if modePad < 0 { - modePad = 0 - } - modeCol := modeRendered + strings.Repeat(" ", modePad) - proto := shortProtocol(t.Input.NormalizedProtocol()) - lastResult := m.styles.muted.Render("从未运行") - if t.LastRunSummary != nil { - pct := t.LastRunSummary.SuccessRate - if pct >= 99 { - lastResult = m.styles.ok.Render(fmt.Sprintf("%.1f%%", pct)) - } else if pct >= 90 { - lastResult = m.styles.metricVal.Render(fmt.Sprintf("%.1f%%", pct)) - } else { - lastResult = m.styles.errStyle.Render(fmt.Sprintf("%.1f%%", pct)) - } - } - nameStr := truncate(t.Name, 28) - // Build row from parts so ANSI in modeCol doesn't break alignment - nameCol := fmt.Sprintf("%-28s ", nameStr) - protoCol := fmt.Sprintf("%-14s ", proto) - mainRow := " " + nameCol + modeCol + " " + protoCol + lastResult - if i == m.selected { - plainRow := " " + nameCol + fmt.Sprintf("%-9s ", func() string { - if t.Input.Turbo { - return "Turbo" - } - return "标准" - }()) + protoCol + lastResult - lines = append(lines, m.styles.tableRowSel.Width(width).Render("▶"+plainRow[1:])) - } else { - lines = append(lines, mainRow) - } - var sub string - if t.Input.Turbo { - tc := t.Input.TurboConfig - sub = fmt.Sprintf(" %s %d→%d +%d 每级%d", - truncate(t.Input.Model, 18), - tc.InitConcurrency, tc.MaxConcurrency, tc.StepSize, tc.LevelRequests) - } else { - sub = fmt.Sprintf(" %s 并发%d/请求%d", - truncate(t.Input.Model, 20), t.Input.Concurrency, t.Input.Count) - } - if i == m.selected { - lines = append(lines, m.styles.tableRowSel.Width(width).Render(sub)) - } else { - lines = append(lines, m.styles.muted.Render(sub)) - } - lines = append(lines, "") - } - return strings.Join(lines, "\n") -} - -func (m *Model) buildTaskListRight(maxH int) string { - var lines []string - lines = append(lines, m.styles.sectionHead.Render("快捷操作")) - lines = append(lines, "") - lines = append(lines, " "+m.styles.key.Render("[a]")+" 新建任务") - lines = append(lines, " "+m.styles.key.Render("[Enter]")+" 查看详情") - lines = append(lines, " "+m.styles.key.Render("[r]")+" 直接运行选中任务") - lines = append(lines, " "+m.styles.key.Render("[e]")+" 编辑 "+m.styles.key.Render("[d]")+" 删除 "+m.styles.key.Render("[y]")+" 复制") - lines = append(lines, "") - lines = append(lines, m.styles.muted.Render(strings.Repeat("─", 28))) - lines = append(lines, "") - lines = append(lines, m.styles.sectionHead.Render("最近执行")) - lines = append(lines, "") - count := 0 - for _, t := range m.tasks { - if t.LastRunSummary == nil { - continue - } - s := t.LastRunSummary - statusIcon := m.styles.ok.Render("✓") - if s.SuccessRate < 90 { - statusIcon = m.styles.errStyle.Render("✗") - } - lines = append(lines, fmt.Sprintf(" %s %-16s %.1f%% %.0f tok/s", - statusIcon, truncate(t.Name, 16), s.SuccessRate, s.AvgTPS)) - count++ - if count >= 5 || len(lines) >= maxH-2 { - break - } - } - if count == 0 { - lines = append(lines, m.styles.muted.Render(" 暂无记录")) - } - if m.status != "" { - lines = append(lines, "") - lines = append(lines, m.styles.muted.Render(m.status)) - } - if m.err != nil { - lines = append(lines, m.styles.errStyle.Render("错误: "+m.err.Error())) - } - return strings.Join(lines, "\n") -} - -func (m *Model) renderTaskDetail() string { - if m.width == 0 { - return "加载中..." - } - taskDef, ok := m.currentTask() - if !ok { - return m.styles.errStyle.Render("任务不存在") - } - updatedStr := "" - if !taskDef.UpdatedAt.IsZero() { - updatedStr = "更新: " + taskDef.UpdatedAt.Format("01-02 15:04") - } - lastRunStr := "从未运行" - if taskDef.LastRunAt != nil { - lastRunStr = "上次: " + timeAgo(*taskDef.LastRunAt) - } - header := m.renderHeader( - "AIT 任务详情 — "+truncate(taskDef.Name, 24), - updatedStr+" "+lastRunStr, - ) - footer := m.renderFooter("[Enter/r] 运行", "[e] 编辑", "[d] 删除", "[b] 返回") - contentH := m.height - 2 - histH := 9 - topH := contentH - histH - if topH < 6 { - topH = 6 - } - panelH := topH - 2 - leftW := (m.width - 4) * 57 / 100 - rightW := m.width - 4 - leftW - leftContent := m.buildDetailLeft(taskDef, panelH, leftW) - rightContent := m.buildDetailRight(taskDef) - top := m.dualColumnLayout(leftContent, rightContent, leftW, rightW, panelH) - histPanelH := histH - 2 - histContent := m.buildHistoryContent(histPanelH, m.width-4) - histPanel := lipgloss.NewStyle(). - Width(m.width - 2).Height(histPanelH). - Border(lipgloss.RoundedBorder()). - BorderForeground(colorPurple). - Render(histContent) - return lipgloss.JoinVertical(lipgloss.Left, header, top, histPanel, footer) -} - -func (m *Model) buildDetailLeft(t types.TaskDefinition, h, w int) string { - var lines []string - lines = append(lines, m.styles.sectionHead.Render("配置摘要")) - lines = append(lines, "") - maxURLLen := w - 14 - if maxURLLen < 20 { - maxURLLen = 20 - } - rows := [][2]string{ - {"协议", t.Input.NormalizedProtocol()}, - {"接口地址", truncate(t.Input.ResolvedEndpointURL(), maxURLLen)}, - {"模型", t.Input.Model}, - } - if t.Input.Turbo { - tc := t.Input.TurboConfig - rows = append(rows, - [2]string{"模式", "Turbo 模式"}, - [2]string{"爬坡", fmt.Sprintf("%d → %d 步进+%d 每级%d", - tc.InitConcurrency, tc.MaxConcurrency, tc.StepSize, tc.LevelRequests)}, - [2]string{"停止条件", fmt.Sprintf("成功率<%.0f%% 或延迟>%s", - tc.MinSuccessRate*100, tc.MaxLatency)}, - ) - } else { - rows = append(rows, - [2]string{"模式", "标准模式"}, - [2]string{"并发", fmt.Sprintf("%d", t.Input.Concurrency)}, - [2]string{"请求数", fmt.Sprintf("%d", t.Input.Count)}, - [2]string{"超时", t.Input.Timeout.String()}, - ) - } - rows = append(rows, - [2]string{"流式", boolLabel(t.Input.Stream)}, - [2]string{"Prompt", promptSummary(t.Input)}, - ) - for _, row := range rows { - lines = append(lines, fmt.Sprintf(" %s %s", - m.styles.label.Render(fmt.Sprintf("%-8s", row[0])), - m.styles.value.Render(row[1]))) - } - return strings.Join(lines, "\n") -} - -func (m *Model) buildDetailRight(t types.TaskDefinition) string { - var lines []string - lines = append(lines, m.styles.sectionHead.Render("最近一次结果")) - lines = append(lines, "") - if t.LastRunSummary == nil { - lines = append(lines, m.styles.muted.Render(" 从未运行")) - return strings.Join(lines, "\n") - } - s := t.LastRunSummary - statusStr := m.styles.ok.Render("✓ 完成") - if s.SuccessRate < 90 { - statusStr = m.styles.errStyle.Render("✗ 异常") - } - rows := [][2]string{ - {"状态", statusStr}, - {"成功率", fmt.Sprintf("%.1f%%", s.SuccessRate)}, - {"avg TTFT", s.AvgTTFT.Truncate(time.Millisecond).String()}, - {"avg TPS", fmt.Sprintf("%.1f tok/s", s.AvgTPS)}, - {"缓存命中", fmt.Sprintf("%.1f%%", s.CacheHitRate)}, - } - if s.MaxStableConcurrency > 0 { - rows = append(rows, [2]string{"最大稳定并发", fmt.Sprintf("%d", s.MaxStableConcurrency)}) - } - for _, row := range rows { - lines = append(lines, fmt.Sprintf(" %s %s", - m.styles.label.Render(fmt.Sprintf("%-10s", row[0])), - row[1])) - } - return strings.Join(lines, "\n") -} - -func (m *Model) buildHistoryContent(maxH, width int) string { - var lines []string - lines = append(lines, m.styles.sectionHead.Render("最近运行记录")+" "+ - m.styles.tableHead.Render(fmt.Sprintf("%-19s %-6s %-8s %-12s %-10s %-8s", - "时间", "模式", "成功率", "TTFT", "TPS", "Cache"))) - lines = append(lines, m.styles.muted.Render(strings.Repeat("─", width-2))) - if len(m.history) == 0 { - lines = append(lines, m.styles.muted.Render(" 暂无历史记录")) - return strings.Join(lines, "\n") - } - for _, run := range m.history { - if len(lines) >= maxH { - break - } - status := m.styles.ok.Render("✓") - if run.SuccessRate < 90 { - status = m.styles.errStyle.Render("✗") - } - modeShort := run.Mode - if len(modeShort) > 5 { - modeShort = modeShort[:5] - } - lines = append(lines, fmt.Sprintf(" %s %-19s %-6s %-8.1f%% %-12s %-10.1f %-8.1f%%", - status, - run.FinishedAt.Format("2006-01-02 15:04"), - modeShort, - run.SuccessRate, - run.AvgTTFT.Truncate(time.Millisecond), - run.AvgTPS, - run.CacheHitRate)) - } - return strings.Join(lines, "\n") -} - -func (m *Model) renderWizard() string { - if m.width == 0 || m.wizard == nil { - return "加载中..." - } - stepTitles := []string{"1/3 · 基本信息", "2/3 · 测试参数", "3/3 · 确认保存"} - header := m.renderHeader("AIT 任务向导", "步骤 "+stepTitles[m.wizard.step]) - var footer string - if m.wizard.step < 2 { - footer = m.renderFooter("[Tab/↓] 下一项", "[↑] 上一项", "[←→] 切换选项", "[Enter] 下一步", "[Esc] 返回") - } else { - footer = m.renderFooter("[Enter] 保存任务", "[r] 保存并运行", "[Esc] 返回修改") - } - contentH := m.height - 2 - dialogW := m.width - 6 - if dialogW > 78 { - dialogW = 78 - } - if dialogW < 40 { - dialogW = 40 - } - dialogContentW := dialogW - 6 // -2 border -4 padding - var content string - switch m.wizard.step { - case 0: - content = m.renderWizardStep0(dialogContentW) - case 1: - content = m.renderWizardStep1(dialogContentW) - case 2: - content = m.renderWizardStep2(dialogContentW) - } - dialog := m.styles.dialog.Width(dialogContentW).Render(content) - dialogH := lipgloss.Height(dialog) - padTop := (contentH - dialogH) / 2 - if padTop < 0 { - padTop = 0 - } - centeredDialog := lipgloss.Place(m.width, contentH, - lipgloss.Center, lipgloss.Top, - strings.Repeat("\n", padTop)+dialog) - return lipgloss.JoinVertical(lipgloss.Left, header, centeredDialog, footer) -} - -func (m *Model) renderWizardStep0(w int) string { - fields := m.wizardStepFields(0) - var lines []string - // Step indicator: ● ○ ○ - lines = append(lines, m.styles.stepActive.Render("●")+" "+ - m.styles.stepTodo.Render("○")+" "+ - m.styles.stepTodo.Render("○")+" "+ - m.styles.sectionHead.Render("基本信息")) - lines = append(lines, "") - for i, field := range fields { - active := i == m.wizard.fieldIndex - lines = append(lines, m.renderWizardField(field, active)) - if field.key == "protocol" { - for pi, p := range protocolOptions { - bullet := " ○ " - if pi == m.wizard.protocolIndex { - bullet = " " + m.styles.ok.Render("●") + " " - } - lines = append(lines, " "+bullet+p) - } - } - lines = append(lines, "") - } - return strings.Join(lines, "\n") -} - -func (m *Model) renderWizardStep1(w int) string { - fields := m.wizardStepFields(1) - var lines []string - // Step indicator: ✓ ● ○ - lines = append(lines, m.styles.stepDone.Render("✓")+" "+ - m.styles.stepActive.Render("●")+" "+ - m.styles.stepTodo.Render("○")+" "+ - m.styles.sectionHead.Render("测试参数")) - lines = append(lines, "") - for i, field := range fields { - active := i == m.wizard.fieldIndex - lines = append(lines, m.renderWizardField(field, active)) - if field.key == "mode" { - opts := []string{modeStandard, modeTurbo} - labels := []string{"标准模式", "Turbo 模式"} - for oi, opt := range opts { - bullet := " ○ " - if opt == m.wizard.mode { - bullet = " " + m.styles.ok.Render("●") + " " - } - lines = append(lines, " "+bullet+labels[oi]) - } - } - if field.key == "prompt_mode" { - pmLabels := []string{"直接输入", "文件路径", "按长度生成"} - for pi, pl := range pmLabels { - bullet := " ○ " - if pi == m.wizard.promptModeIndex { - bullet = " " + m.styles.ok.Render("●") + " " - } - lines = append(lines, " "+bullet+pl) - } - } - lines = append(lines, "") - } - return strings.Join(lines, "\n") -} - -func (m *Model) renderWizardStep2(w int) string { - var lines []string - // Step indicator: ✓ ✓ ● - lines = append(lines, m.styles.stepDone.Render("✓")+" "+ - m.styles.stepDone.Render("✓")+" "+ - m.styles.stepActive.Render("●")+" "+ - m.styles.sectionHead.Render("确认保存")) - lines = append(lines, "") - d, err := buildTaskDefinition(m.wizard) - if err != nil { - lines = append(lines, m.styles.errStyle.Render("配置有误: "+err.Error())) - lines = append(lines, "") - lines = append(lines, m.styles.muted.Render("按 [Esc] 返回修改")) - return strings.Join(lines, "\n") - } - rows := [][2]string{ - {"任务名称", d.Name}, - {"协议", d.Input.NormalizedProtocol()}, - {"接口地址", truncate(d.Input.ResolvedEndpointURL(), w-16)}, - {"API 密钥", maskAPIKey(d.Input.ApiKey)}, - {"测试模型", d.Input.Model}, - } - if d.Input.Turbo { - tc := d.Input.TurboConfig - rows = append(rows, - [2]string{"测试模式", "Turbo 模式"}, - [2]string{"并发爬坡", fmt.Sprintf("%d → %d 步进+%d 每级%d", - tc.InitConcurrency, tc.MaxConcurrency, tc.StepSize, tc.LevelRequests)}, - [2]string{"停止条件", fmt.Sprintf("成功率<%.0f%% 或延迟>%s", - tc.MinSuccessRate*100, tc.MaxLatency)}, - ) - } else { - rows = append(rows, - [2]string{"测试模式", "标准模式"}, - [2]string{"并发/请求", fmt.Sprintf("%d / %d", d.Input.Concurrency, d.Input.Count)}, - [2]string{"超时", d.Input.Timeout.String()}, - ) - } - rows = append(rows, - [2]string{"流式", boolLabel(d.Input.Stream)}, - [2]string{"Prompt", promptSummary(d.Input)}, - ) - for _, row := range rows { - lines = append(lines, fmt.Sprintf(" %s %s", - m.styles.label.Render(fmt.Sprintf("%-10s", row[0])), - row[1])) + case viewReqDetail: + return m.renderReqDetail() } - lines = append(lines, "") - lines = append(lines, m.styles.ok.Render(" ▶ 按 [Enter] 保存,[r] 保存并立即运行")) - return strings.Join(lines, "\n") + return "未知视图" } -func (m *Model) renderWizardField(field wizardField, active bool) string { - var val string - if field.kind == fieldText && active { - val = m.wizard.input.View() - } else { - val = m.displayWizardValue(field) - } - labelStr := fmt.Sprintf("%-12s", field.label) - if active { - return m.styles.cursor.Render("▶") + " " + - m.styles.fieldActive.Render(labelStr) + " " + val - } - return " " + m.styles.fieldIdle.Render(labelStr) + " " + m.styles.muted.Render(val) -} - -func (m *Model) renderDashboard() string { - if m.width == 0 { - return "加载中..." - } - taskName, protocol, modelName := "", "", "" - isTurbo := false - totalReqs, concurrency := 0, 0 - if m.runningTask != nil { - taskName = m.runningTask.Name - protocol = shortProtocol(m.runningTask.Input.NormalizedProtocol()) - modelName = m.runningTask.Input.Model - isTurbo = m.runningTask.Input.Turbo - totalReqs = m.runningTask.Input.Count - concurrency = m.runningTask.Input.Concurrency - if isTurbo { - totalReqs = m.runningTask.Input.TurboConfig.LevelRequests - concurrency = m.runningTask.Input.TurboConfig.InitConcurrency - } - } - title := "AIT 正在测试 — " + modelName - if isTurbo { - title = "AIT Turbo 探测 — " + modelName - } - header := m.renderHeader(title, - fmt.Sprintf("任务: %s 协议: %s", truncate(taskName, 20), protocol)) - footer := m.renderFooter("[s] 停止", "[q] 退出") - contentH := m.height - 2 - logH := 7 - topH := contentH - logH - if topH < 6 { - topH = 6 - } - panelH := topH - 2 - leftW := (m.width - 4) * 50 / 100 - rightW := m.width - 4 - leftW - var leftContent, rightContent string - if isTurbo { - leftContent = m.buildTurboDashLeft(panelH) - rightContent = m.buildTurboDashRight(panelH) - } else { - leftContent = m.buildStdDashLeft(panelH, totalReqs, concurrency) - rightContent = m.buildStdDashRight(panelH) - } - top := m.dualColumnLayout(leftContent, rightContent, leftW, rightW, panelH) - logPanelH := logH - 2 - logContent := m.buildLogPanel(logPanelH, m.width-4) - logPanel := lipgloss.NewStyle(). - Width(m.width - 2).Height(logPanelH). - Border(lipgloss.RoundedBorder()). - BorderForeground(colorPurple). - Render(logContent) - return lipgloss.JoinVertical(lipgloss.Left, header, top, logPanel, footer) -} +// ─── 键盘分发 ───────────────────────────────────────────────────────────────── -func (m *Model) buildStdDashLeft(h, total, concurrency int) string { - p := m.progress - completed := p.CompletedCount - failed := p.FailedCount - elapsed := time.Duration(0) - if !p.StartTime.IsZero() { - elapsed = time.Since(p.StartTime) - } - var estRemaining string - if completed > 0 && total > completed && elapsed > 0 { - rate := float64(completed) / elapsed.Seconds() - remaining := float64(total-completed) / rate - estRemaining = "~" + time.Duration(remaining*float64(time.Second)).Truncate(time.Second).String() - } - barW := 20 - var lines []string - lines = append(lines, m.styles.sectionHead.Render("进度")) - lines = append(lines, "") - lines = append(lines, fmt.Sprintf(" %s %s %d", - m.styles.label.Render("完成"), progressBar(completed, total, barW), completed)) - lines = append(lines, fmt.Sprintf(" %s %s %d", - m.styles.errStyle.Render("失败"), progressBarRed(failed, total, barW), failed)) - lines = append(lines, fmt.Sprintf(" %s %s %d", - m.styles.muted.Render("总计"), progressBar(total, total, barW), total)) - lines = append(lines, "") - lines = append(lines, fmt.Sprintf(" %-10s %s", - m.styles.label.Render("已用时"), - elapsed.Truncate(100*time.Millisecond))) - if estRemaining != "" { - lines = append(lines, fmt.Sprintf(" %-10s %s", - m.styles.label.Render("预计剩余"), - estRemaining)) - } - lines = append(lines, fmt.Sprintf(" %-10s %d 活跃", - m.styles.label.Render("并发槽"), - concurrency)) - return strings.Join(lines, "\n") -} - -func (m *Model) buildStdDashRight(h int) string { - p := m.progress - var lines []string - lines = append(lines, m.styles.sectionHead.Render("实时指标")) - lines = append(lines, "") - successRate := 0.0 - if p.CompletedCount > 0 { - successRate = float64(p.CompletedCount-p.FailedCount) / float64(p.CompletedCount) * 100 - } - srBar := progressBar(int(successRate), 100, 16) - lines = append(lines, fmt.Sprintf(" 成功率 %s %.1f%%", srBar, successRate)) - lines = append(lines, "") - avgTPS := 0.0 - if len(p.OutputTokenCounts) > 0 && len(p.TotalTimes) > 0 { - totalTokens := 0 - for _, tok := range p.OutputTokenCounts { - totalTokens += tok - } - totalTimeS := 0.0 - for _, d := range p.TotalTimes { - totalTimeS += d.Seconds() - } - if totalTimeS > 0 { - avgTPS = float64(totalTokens) / totalTimeS - } - } - avgTTFT := time.Duration(0) - if len(p.TTFTs) > 0 { - sum := time.Duration(0) - for _, d := range p.TTFTs { - sum += d - } - avgTTFT = sum / time.Duration(len(p.TTFTs)) - } - avgTotal := time.Duration(0) - if len(p.TotalTimes) > 0 { - sum := time.Duration(0) - for _, d := range p.TotalTimes { - sum += d - } - avgTotal = sum / time.Duration(len(p.TotalTimes)) - } - avgCache := 0.0 - if len(p.CacheHitRates) > 0 { - sum := 0.0 - for _, r := range p.CacheHitRates { - sum += r - } - avgCache = sum / float64(len(p.CacheHitRates)) * 100 - } - rows := [][2]string{ - {"avg TPS", fmt.Sprintf("%.1f tok/s", avgTPS)}, - {"avg TTFT", avgTTFT.Truncate(time.Millisecond).String()}, - {"缓存命中率", fmt.Sprintf("%.1f%%", avgCache)}, - {"avg 总耗时", avgTotal.Truncate(time.Millisecond).String()}, - } - for _, row := range rows { - lines = append(lines, fmt.Sprintf(" %-12s %s", - m.styles.label.Render(row[0]), - m.styles.metricVal.Render(row[1]))) +func (m *Model) handleKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) { + switch m.view { + case viewTaskList: + return m.handleTaskListKey(msg) + case viewTaskDetail: + return m, m.handleTaskDetailKey(msg) + case viewWizard: + return m.handleWizardKey(msg) + case viewDashboard: + return m.handleDashboardKey(msg) + case viewReqDetail: + return m.handleReqDetailKey(msg) } - return strings.Join(lines, "\n") -} - -func (m *Model) buildTurboDashLeft(h int) string { - var lines []string - lines = append(lines, m.styles.sectionHead.Render("Turbo 探测中")) - lines = append(lines, "") - elapsed := time.Since(m.runStartedAt) - lines = append(lines, fmt.Sprintf(" %s %s", - m.styles.label.Render("已用时"), - elapsed.Truncate(time.Second))) - lines = append(lines, "") - lines = append(lines, m.styles.muted.Render(" 正在逐级探测最大稳定并发...")) - lines = append(lines, m.styles.muted.Render(" 完成后将自动显示结果")) - return strings.Join(lines, "\n") -} - -func (m *Model) buildTurboDashRight(h int) string { - var lines []string - lines = append(lines, m.styles.sectionHead.Render("探测状态")) - lines = append(lines, "") - lines = append(lines, " "+m.styles.ok.Render("●")+" 测试运行中") - lines = append(lines, m.styles.muted.Render(" 等待完成...")) - return strings.Join(lines, "\n") + return m, nil } -func (m *Model) buildLogPanel(maxH, width int) string { - var lines []string - lines = append(lines, m.styles.sectionHead.Render("请求日志")) - if len(m.requestLog) == 0 { - lines = append(lines, m.styles.muted.Render(" 等待请求完成...")) - return strings.Join(lines, "\n") - } - start := 0 - if len(m.requestLog) > maxH-1 { - start = len(m.requestLog) - (maxH - 1) - } - for _, entry := range m.requestLog[start:] { - // Color log entries based on their leading status marker - if strings.HasPrefix(entry, "✓") || strings.HasPrefix(entry, "✔") { - lines = append(lines, " "+m.styles.logOk.Render(entry)) - } else if strings.HasPrefix(entry, "✗") || strings.HasPrefix(entry, "✘") || strings.HasPrefix(entry, "ERR") { - lines = append(lines, " "+m.styles.logErr.Render(entry)) - } else { - lines = append(lines, " "+m.styles.muted.Render(entry)) - } - } - return strings.Join(lines, "\n") -} +// ─── Server 事件处理 ────────────────────────────────────────────────────────── -func (m *Model) renderResult() string { - if m.width == 0 { - return "加载中..." - } - header := m.renderHeader("AIT 测试完成", "标准模式结果") - footer := m.renderFooter("[b/Esc] 返回详情") - if m.runResult == nil { - return lipgloss.JoinVertical(lipgloss.Left, header, - m.styles.errStyle.Render("结果为空"), footer) - } - r := m.runResult - panelW := m.width - 4 - panelH := m.height - 4 - var lines []string - lines = append(lines, m.styles.sectionHead.Render(fmt.Sprintf("任务完成 — %s", r.Model))) - lines = append(lines, "") - rows := [][2]string{ - {"协议", r.Protocol}, - {"接口地址", truncate(r.EndpointURL, panelW-16)}, - {"模型", r.Model}, - {"成功率", fmt.Sprintf("%.1f%%", r.SuccessRate)}, - {"总请求数", fmt.Sprintf("%d", r.TotalRequests)}, - {"avg TTFT", r.AvgTTFT.Truncate(time.Millisecond).String()}, - {"avg TPS", fmt.Sprintf("%.2f tok/s", r.AvgTPS)}, - {"缓存命中率", fmt.Sprintf("%.1f%%", r.AvgCacheHitRate*100)}, - {"avg 总耗时", r.AvgTotalTime.Truncate(time.Millisecond).String()}, - {"总测试时长", r.TotalTime.Truncate(time.Second).String()}, - } - for _, row := range rows { - lines = append(lines, fmt.Sprintf(" %s %s", - m.styles.label.Render(fmt.Sprintf("%-12s", row[0])), - m.styles.value.Render(row[1]))) +func (m *Model) handleServerEvent(msg ServerEventMsg) (tea.Model, tea.Cmd) { + if m.dash == nil { + return m, nil } - content := strings.Join(lines, "\n") - panel := lipgloss.NewStyle(). - Width(panelW).Height(panelH). - Border(lipgloss.RoundedBorder()). - BorderForeground(colorPurple). - Render(content) - return lipgloss.JoinVertical(lipgloss.Left, header, panel, footer) -} + e := msg.Event -func (m *Model) renderTurboResult() string { - if m.width == 0 { - return "加载中..." - } - header := m.renderHeader("AIT Turbo 完成", "Turbo 模式结果") - footer := m.renderFooter("[b/Esc] 返回详情") - if m.turboResult == nil { - return lipgloss.JoinVertical(lipgloss.Left, header, - m.styles.errStyle.Render("Turbo 结果为空"), footer) - } - r := m.turboResult - panelW := m.width - 4 - panelH := m.height - 4 - var lines []string - lines = append(lines, m.styles.sectionHead.Render(fmt.Sprintf( - "Turbo 完成 — %s 最大稳定并发: %d 峰值 TPS: %.1f", - r.Model, r.MaxStableConcurrency, r.PeakTPS))) - lines = append(lines, "") - lines = append(lines, m.styles.tableHead.Render(fmt.Sprintf( - " %-6s %-8s %-10s %-10s %-8s %-8s %s", - "并发", "成功率", "TPS", "TTFT", "Cache", "总耗时", "状态"))) - lines = append(lines, m.styles.muted.Render(strings.Repeat("─", panelW-4))) - for _, level := range r.Levels { - status := m.styles.ok.Render("✓ 稳定") - if !level.Stable { - status = m.styles.errStyle.Render("✗ 不稳定") - } - marker := " " - if level.Concurrency == r.MaxStableConcurrency { - marker = m.styles.cursor.Render("▶ ") + switch e.Kind { + case server.EventProgressTick: + if rs, ok := e.Payload.(*server.RunState); ok { + m.dash.runState = rs } - lines = append(lines, fmt.Sprintf("%s%-6d %-8.1f%% %-10.1f %-10s %-8.1f%% %-8s %s", - marker, - level.Concurrency, - level.SuccessRate*100, - level.AvgTPS, - level.AvgTTFT.Truncate(time.Millisecond), - level.CacheHitRate*100, - level.AvgTotalTime.Truncate(time.Millisecond), - status)) - } - lines = append(lines, "") - lines = append(lines, m.styles.muted.Render(" 停止原因: "+r.StopReason)) - content := strings.Join(lines, "\n") - panel := lipgloss.NewStyle(). - Width(panelW).Height(panelH). - Border(lipgloss.RoundedBorder()). - BorderForeground(colorPurple). - Render(content) - return lipgloss.JoinVertical(lipgloss.Left, header, panel, footer) -} - -func (m *Model) renderHeader(left, right string) string { - if m.width == 0 { - return "" - } - // Each part gets the same header background so the bar spans the full width. - leftStyled := lipgloss.NewStyle(). - Background(colorHeaderBg).Bold(true).Foreground(colorPink). - Render(" ◆ " + left) - rightStyled := lipgloss.NewStyle(). - Background(colorHeaderBg).Foreground(colorHeaderFg). - Render(right + " ") - lw := lipgloss.Width(leftStyled) - rw := lipgloss.Width(rightStyled) - gap := m.width - lw - rw - if gap < 0 { - gap = 0 - } - spacer := lipgloss.NewStyle().Background(colorHeaderBg).Render(strings.Repeat(" ", gap)) - return leftStyled + spacer + rightStyled -} - -func (m *Model) renderFooter(hints ...string) string { - if m.width == 0 { - return "" - } - // Left: colored AIT brand badge - leftBadge := lipgloss.NewStyle(). - Background(colorPurple).Foreground(colorWhite).Bold(true). - Render(" ◆ AIT ") - // Right: dim version badge - rightBadge := lipgloss.NewStyle(). - Background(colorHeaderBg).Foreground(colorHeaderFg). - Render(" v0.1 ") - // Middle: key hints in pink on footer bg - var parts []string - for _, h := range hints { - parts = append(parts, lipgloss.NewStyle().Foreground(colorPink).Render(h)) - } - hintsStr := " " + strings.Join(parts, " ") - lw := lipgloss.Width(leftBadge) - rw := lipgloss.Width(rightBadge) - hw := lipgloss.Width(hintsStr) - gap := m.width - lw - rw - hw - if gap < 0 { - gap = 0 - } - middle := lipgloss.NewStyle(). - Background(colorFooterBg).Foreground(colorMuted). - Render(hintsStr + strings.Repeat(" ", gap)) - return leftBadge + middle + rightBadge -} - -func (m *Model) dualColumnLayout(leftContent, rightContent string, leftW, rightW, h int) string { - bc := colorPurple - leftPane := lipgloss.NewStyle(). - Width(leftW).Height(h). - Border(lipgloss.RoundedBorder()). - BorderForeground(bc). - Render(leftContent) - rightPane := lipgloss.NewStyle(). - Width(rightW).Height(h). - Border(lipgloss.RoundedBorder()). - BorderForeground(bc). - Render(rightContent) - return lipgloss.JoinHorizontal(lipgloss.Top, leftPane, rightPane) -} - -func progressBar(current, total, width int) string { - if total <= 0 || width <= 0 { - return lipgloss.NewStyle().Foreground(lipgloss.Color("238")).Render(strings.Repeat("░", width)) - } - filled := current * width / total - if filled > width { - filled = width - } - bar := lipgloss.NewStyle().Foreground(colorGreen).Render(strings.Repeat("█", filled)) - empty := lipgloss.NewStyle().Foreground(lipgloss.Color("238")).Render(strings.Repeat("░", width-filled)) - return bar + empty -} - -// progressBarRed renders a red-tinted progress bar for failure/error metrics. -func progressBarRed(current, total, width int) string { - if total <= 0 || width <= 0 { - return lipgloss.NewStyle().Foreground(lipgloss.Color("238")).Render(strings.Repeat("░", width)) - } - filled := current * width / total - if filled > width { - filled = width - } - bar := lipgloss.NewStyle().Foreground(colorRed).Render(strings.Repeat("█", filled)) - empty := lipgloss.NewStyle().Foreground(lipgloss.Color("238")).Render(strings.Repeat("░", width-filled)) - return bar + empty -} - -func truncate(s string, n int) string { - if n <= 0 || len(s) <= n { - return s - } - if n <= 3 { - return s[:n] - } - return s[:n-3] + "..." -} - -func timeAgo(t time.Time) string { - d := time.Since(t) - if d < time.Minute { - return fmt.Sprintf("%ds 前", int(d.Seconds())) - } - if d < time.Hour { - return fmt.Sprintf("%dm 前", int(d.Minutes())) - } - if d < 24*time.Hour { - return fmt.Sprintf("%dh 前", int(d.Hours())) - } - return t.Format("01-02 15:04") -} - -func shortProtocol(p string) string { - p = strings.ReplaceAll(p, "openai-", "") - p = strings.ReplaceAll(p, "anthropic-", "") - return p -} - -func maskAPIKey(key string) string { - if len(key) == 0 { - return "(空)" - } - if len(key) <= 8 { - return strings.Repeat("•", len(key)) - } - return key[:4] + strings.Repeat("•", len(key)-8) + key[len(key)-4:] -} - -func (m *Model) currentTask() (types.TaskDefinition, bool) { - if len(m.tasks) == 0 || m.selected < 0 || m.selected >= len(m.tasks) { - return types.TaskDefinition{}, false - } - return m.tasks[m.selected], true -} -func (m *Model) openWizard(existing *types.TaskDefinition) { - state := newWizardState(existing, m.view, m.config) - m.wizard = state - m.view = viewWizard - m.refreshWizardInput() -} - -func newWizardState(existing *types.TaskDefinition, from viewState, cfg *config.Config) *wizardState { - input := textinput.New() - input.Width = 72 - input.Prompt = "" - values := map[string]string{ - "name": "", - "endpoint": "", - "apiKey": "", - "model": "", - "concurrency": "5", - "count": "100", - "timeout": "30s", - "turbo_init": "1", - "turbo_max": "50", - "turbo_step": "2", - "turbo_level_requests": "30", - "turbo_min_success": "0.9", - "turbo_max_latency": "10s", - "prompt_value": "你好,介绍一下你自己。", - } - state := &wizardState{ - fromView: from, - input: input, - values: values, - protocolIndex: protocolIndex(cfg.DefaultProtocol), - mode: modeStandard, - promptModeIndex: 0, - stream: true, - thinking: false, - report: true, - } - if existing != nil { - state.editingTaskID = existing.ID - state.createdAt = existing.CreatedAt - state.lastRunAt = existing.LastRunAt - state.lastRunSummary = existing.LastRunSummary - state.values["name"] = existing.Name - state.values["endpoint"] = existing.Input.ResolvedEndpointURL() - state.values["apiKey"] = existing.Input.ApiKey - state.values["model"] = existing.Input.Model - state.protocolIndex = protocolIndex(existing.Input.NormalizedProtocol()) - state.stream = existing.Input.Stream - state.thinking = existing.Input.Thinking - state.report = existing.Input.Report - if existing.Input.Turbo { - state.mode = modeTurbo - state.values["turbo_init"] = strconv.Itoa(existing.Input.TurboConfig.InitConcurrency) - state.values["turbo_max"] = strconv.Itoa(existing.Input.TurboConfig.MaxConcurrency) - state.values["turbo_step"] = strconv.Itoa(existing.Input.TurboConfig.StepSize) - state.values["turbo_level_requests"] = strconv.Itoa(existing.Input.TurboConfig.LevelRequests) - state.values["turbo_min_success"] = strconv.FormatFloat(existing.Input.TurboConfig.MinSuccessRate, 'f', -1, 64) - state.values["turbo_max_latency"] = existing.Input.TurboConfig.MaxLatency.String() - } else { - state.values["concurrency"] = strconv.Itoa(existing.Input.Concurrency) - state.values["count"] = strconv.Itoa(existing.Input.Count) - if existing.Input.Timeout > 0 { - state.values["timeout"] = existing.Input.Timeout.String() - } - } - switch existing.Input.PromptMode { - case promptModeFile: - state.promptModeIndex = 1 - state.values["prompt_value"] = existing.Input.PromptFile - case promptModeGenerated: - state.promptModeIndex = 2 - state.values["prompt_value"] = strconv.Itoa(existing.Input.PromptLength) - default: - state.promptModeIndex = 0 - state.values["prompt_value"] = existing.Input.PromptText + case server.EventRequestDone: + if rs, ok := e.Payload.(*server.RunState); ok { + m.dash.runState = rs } - } - return state -} -func protocolIndex(protocol string) int { - for i, item := range protocolOptions { - if item == types.NormalizeProtocol(protocol) { - return i + case server.EventLevelDone: + if rs, ok := e.Payload.(*server.RunState); ok { + m.dash.runState = rs } - } - return 0 -} -func (m *Model) wizardStepFields(step int) []wizardField { - switch step { - case 0: - return []wizardField{ - {key: "name", label: "任务名称", kind: fieldText}, - {key: "protocol", label: "协议类型", kind: fieldSelect}, - {key: "endpoint", label: "完整接口地址", kind: fieldText}, - {key: "apiKey", label: "API 密钥", kind: fieldText}, - {key: "model", label: "测试模型", kind: fieldText}, - } - case 1: - fields := []wizardField{ - {key: "mode", label: "运行模式", kind: fieldSelect}, + case server.EventRunComplete: + if rs, ok := e.Payload.(*server.RunState); ok { + m.dash.runState = rs } - if m.wizard.mode == modeTurbo { - fields = append(fields, - wizardField{key: "turbo_init", label: "初始并发", kind: fieldText}, - wizardField{key: "turbo_max", label: "最大并发", kind: fieldText}, - wizardField{key: "turbo_step", label: "步进值", kind: fieldText}, - wizardField{key: "turbo_level_requests", label: "每级请求数", kind: fieldText}, - wizardField{key: "turbo_min_success", label: "最小成功率", kind: fieldText}, - wizardField{key: "turbo_max_latency", label: "最大平均延迟", kind: fieldText}, - ) - } else { - fields = append(fields, - wizardField{key: "concurrency", label: "并发数", kind: fieldText}, - wizardField{key: "count", label: "请求总数", kind: fieldText}, - wizardField{key: "timeout", label: "超时时间", kind: fieldText}, - ) - } - fields = append(fields, - wizardField{key: "stream", label: "流式模式", kind: fieldToggle}, - wizardField{key: "thinking", label: "Thinking 模式", kind: fieldToggle}, - wizardField{key: "report", label: "生成报告", kind: fieldToggle}, - wizardField{key: "prompt_mode", label: "Prompt 方式", kind: fieldSelect}, - wizardField{key: "prompt_value", label: promptValueLabel(m.wizard.promptModeIndex), kind: fieldText}, + // 运行结束后保留 dash 供用户查阅,切换到详情页 + m.view = viewTaskDetail + return m, tea.Batch( + m.client.LoadTasksCmd(), + m.client.LoadHistoryCmd(m.dash.taskID, 10), ) - return fields - default: - return nil - } -} - -func (m *Model) advanceWizardField(delta int) { - if m.wizard == nil { - return - } - fields := m.wizardStepFields(m.wizard.step) - next := m.wizard.fieldIndex + delta - if next < 0 { - if m.wizard.step > 0 { - m.wizard.step-- - prevFields := m.wizardStepFields(m.wizard.step) - m.wizard.fieldIndex = len(prevFields) - 1 - m.refreshWizardInput() - } - return - } - if next >= len(fields) { - m.wizard.step++ - m.wizard.fieldIndex = 0 - if m.wizard.step < 2 { - m.refreshWizardInput() - } - return - } - m.wizard.fieldIndex = next - m.refreshWizardInput() -} - -func promptValueLabel(promptModeIndex int) string { - switch promptModeOptions[promptModeIndex] { - case promptModeFile: - return "Prompt 文件路径" - case promptModeGenerated: - return "Prompt 生成长度" - default: - return "Prompt 文本" - } -} - -func (m *Model) currentWizardField() wizardField { - if m.wizard == nil { - return wizardField{} - } - fields := m.wizardStepFields(m.wizard.step) - if len(fields) == 0 || m.wizard.fieldIndex >= len(fields) { - return wizardField{} - } - return fields[m.wizard.fieldIndex] -} -func (m *Model) refreshWizardInput() { - field := m.currentWizardField() - m.wizard.input.Blur() - m.wizard.input.Focus() - m.wizard.input.EchoMode = textinput.EchoNormal - if field.key == "apiKey" { - m.wizard.input.EchoMode = textinput.EchoPassword - } - m.wizard.input.SetValue(m.wizard.values[field.key]) - if field.key == "prompt_value" { - m.wizard.input.Placeholder = promptValueLabel(m.wizard.promptModeIndex) - } else { - m.wizard.input.Placeholder = field.label - } - if field.kind != fieldText { - m.wizard.input.SetValue("") - } -} - -func (m *Model) cycleWizardField(delta int) { - if m.wizard == nil { - return - } - field := m.currentWizardField() - switch field.key { - case "protocol": - m.wizard.protocolIndex = wrapIndex(m.wizard.protocolIndex+delta, len(protocolOptions)) - case "mode": - if m.wizard.mode == modeStandard { - m.wizard.mode = modeTurbo - } else { - m.wizard.mode = modeStandard + case server.EventRunFailed: + if rs, ok := e.Payload.(*server.RunState); ok { + m.dash.runState = rs } - case "prompt_mode": - m.wizard.promptModeIndex = wrapIndex(m.wizard.promptModeIndex+delta, len(promptModeOptions)) - m.wizard.values["prompt_value"] = "" - case "stream": - m.wizard.stream = !m.wizard.stream - case "thinking": - m.wizard.thinking = !m.wizard.thinking - case "report": - m.wizard.report = !m.wizard.report - default: - return - } - // Clamp fieldIndex in case field count changed (e.g. mode switch) - fields := m.wizardStepFields(m.wizard.step) - if m.wizard.fieldIndex >= len(fields) && len(fields) > 0 { - m.wizard.fieldIndex = len(fields) - 1 - } - m.refreshWizardInput() -} - -func (m *Model) displayWizardValue(field wizardField) string { - switch field.key { - case "protocol": - return protocolOptions[m.wizard.protocolIndex] - case "mode": - return m.wizard.mode - case "stream": - return boolLabel(m.wizard.stream) - case "thinking": - return boolLabel(m.wizard.thinking) - case "report": - return boolLabel(m.wizard.report) - case "prompt_mode": - return promptModeOptions[m.wizard.promptModeIndex] - default: - return m.wizard.values[field.key] - } -} - -func boolLabel(v bool) string { - if v { - return "开启" + m.err = fmt.Errorf("运行失败: %s", m.dash.runState.ErrorMsg) + m.view = viewTaskDetail + return m, tea.Batch( + m.client.LoadTasksCmd(), + m.client.LoadHistoryCmd(m.dash.taskID, 10), + ) } - return "关闭" -} -func wrapIndex(index, length int) int { - if length == 0 { - return 0 + // 若 eventCh 还在,继续等待下一条事件 + if m.dash.eventCh != nil { + return m, WaitEventCmd(m.dash.eventCh) } - for index < 0 { - index += length - } - return index % length + return m, nil } -func buildTaskDefinition(state *wizardState) (types.TaskDefinition, error) { - protocol := protocolOptions[state.protocolIndex] - input := types.Input{ - Protocol: protocol, - EndpointURL: strings.TrimSpace(state.values["endpoint"]), - ApiKey: strings.TrimSpace(state.values["apiKey"]), - Model: strings.TrimSpace(state.values["model"]), - Stream: state.stream, - Thinking: state.thinking, - Report: state.report, - PromptMode: promptModeOptions[state.promptModeIndex], - } - - switch input.PromptMode { - case promptModeFile: - input.PromptFile = strings.TrimSpace(state.values["prompt_value"]) - case promptModeGenerated: - length, err := strconv.Atoi(strings.TrimSpace(state.values["prompt_value"])) - if err != nil { - return types.TaskDefinition{}, fmt.Errorf("invalid prompt length: %w", err) - } - input.PromptLength = length - default: - input.PromptText = state.values["prompt_value"] - } +// ─── 共享渲染工具 ───────────────────────────────────────────────────────────── - if state.mode == modeTurbo { - initConcurrency, err := strconv.Atoi(strings.TrimSpace(state.values["turbo_init"])) - if err != nil { - return types.TaskDefinition{}, fmt.Errorf("invalid turbo init concurrency: %w", err) - } - maxConcurrency, err := strconv.Atoi(strings.TrimSpace(state.values["turbo_max"])) - if err != nil { - return types.TaskDefinition{}, fmt.Errorf("invalid turbo max concurrency: %w", err) - } - stepSize, err := strconv.Atoi(strings.TrimSpace(state.values["turbo_step"])) - if err != nil { - return types.TaskDefinition{}, fmt.Errorf("invalid turbo step size: %w", err) - } - levelRequests, err := strconv.Atoi(strings.TrimSpace(state.values["turbo_level_requests"])) - if err != nil { - return types.TaskDefinition{}, fmt.Errorf("invalid turbo level requests: %w", err) - } - minSuccessRate, err := strconv.ParseFloat(strings.TrimSpace(state.values["turbo_min_success"]), 64) - if err != nil { - return types.TaskDefinition{}, fmt.Errorf("invalid turbo min success rate: %w", err) - } - maxLatency, err := time.ParseDuration(strings.TrimSpace(state.values["turbo_max_latency"])) - if err != nil { - return types.TaskDefinition{}, fmt.Errorf("invalid turbo max latency: %w", err) - } - input.Turbo = true - input.Count = levelRequests - input.Concurrency = initConcurrency - input.TurboConfig = types.TurboConfig{ - InitConcurrency: initConcurrency, - MaxConcurrency: maxConcurrency, - StepSize: stepSize, - LevelRequests: levelRequests, - MinSuccessRate: minSuccessRate, - MaxLatency: maxLatency, - } - } else { - concurrency, err := strconv.Atoi(strings.TrimSpace(state.values["concurrency"])) - if err != nil { - return types.TaskDefinition{}, fmt.Errorf("invalid concurrency: %w", err) - } - count, err := strconv.Atoi(strings.TrimSpace(state.values["count"])) - if err != nil { - return types.TaskDefinition{}, fmt.Errorf("invalid count: %w", err) - } - timeout, err := time.ParseDuration(strings.TrimSpace(state.values["timeout"])) - if err != nil { - return types.TaskDefinition{}, fmt.Errorf("invalid timeout: %w", err) - } - input.Concurrency = concurrency - input.Count = count - input.Timeout = timeout +// renderHeader 渲染顶部状态栏(全宽,左侧标题 + 右侧信息)。 +func (m *Model) renderHeader(title, right string) string { + w := m.width + if w < 1 { + w = 80 } - - validatedInput, err := task.HydrateInput(input) - if err != nil { - return types.TaskDefinition{}, err + titleW := lipgloss.Width(title) + rightW := lipgloss.Width(right) + pad := w - titleW - rightW - 2 + if pad < 1 { + pad = 1 } - validatedInput.PromptSource = nil - - now := time.Now() - createdAt := state.createdAt - if createdAt.IsZero() { - createdAt = now + line := " " + title + strings.Repeat(" ", pad) + right + " " + // 截断 + if lipgloss.Width(line) > w { + line = line[:w] } - - return types.TaskDefinition{ - ID: state.editingTaskID, - Name: strings.TrimSpace(state.values["name"]), - Input: validatedInput, - CreatedAt: createdAt, - UpdatedAt: now, - LastRunAt: state.lastRunAt, - LastRunSummary: state.lastRunSummary, - }, nil + return m.styles.header.Width(w).Render(line) } -func (m *Model) saveWizard() error { - taskDef, err := buildTaskDefinition(m.wizard) - if err != nil { - return err +// renderFooter 渲染底部状态栏(全宽)。 +func (m *Model) renderFooter(parts ...string) string { + w := m.width + if w < 1 { + w = 80 } - if taskDef.ID == "" { - taskDef.ID = fmt.Sprintf("task_%d", time.Now().UnixNano()) - } - m.store.Upsert(taskDef) - if err := m.store.Save(); err != nil { - return err - } - m.tasks = m.store.Tasks - for i, item := range m.tasks { - if item.ID == taskDef.ID { - m.selected = i - break + var visible []string + for _, p := range parts { + if p != "" { + visible = append(visible, p) } } - m.reloadHistoryForSelectedTask() - m.status = "任务已保存" - m.wizard = nil - m.view = viewTaskDetail - return nil + line := " " + strings.Join(visible, " │ ") + return m.styles.footer.Width(w).Render(line) } -func (m *Model) reloadHistoryForSelectedTask() { - taskDef, ok := m.currentTask() - if !ok { - m.history = nil - return - } - history, err := task.LoadHistory(taskDef.ID, 5) - if err != nil { - m.err = err - m.history = nil - return - } - m.history = history -} +// dualColumnLayout 将左右内容放入双列布局,高度限制为 maxH。 +func (m *Model) dualColumnLayout(left, right string, leftW, rightW, maxH int) string { + leftLines := strings.Split(left, "\n") + rightLines := strings.Split(right, "\n") -func (m *Model) startTaskRun(taskDef types.TaskDefinition) { - input, err := task.HydrateInput(taskDef.Input) - if err != nil { - m.err = err - return + // 裁剪至 maxH + if len(leftLines) > maxH { + leftLines = leftLines[:maxH] } - m.runningTask = &taskDef - m.runStartedAt = time.Now() - m.progress = types.StatsData{} - m.runResult = nil - m.turboResult = nil - m.view = viewDashboard - - if input.Turbo { - engine := turbo.New(turbo.DefaultRunnerFactory(taskDef.ID)) - m.activeTurbo = engine - go func() { - result, err := engine.Run(input) - if err != nil { - m.program.Send(asyncErrorMsg{err: err}) - return - } - m.program.Send(turboCompleteMsg{taskID: taskDef.ID, result: result}) - }() - return + if len(rightLines) > maxH { + rightLines = rightLines[:maxH] } - - runnerInstance, err := runner.NewRunner(taskDef.ID, input) - if err != nil { - m.err = err - return + // 补齐行数 + for len(leftLines) < maxH { + leftLines = append(leftLines, "") } - m.activeRunner = runnerInstance - go func() { - result, err := runnerInstance.RunWithProgress(func(stats types.StatsData) { - m.program.Send(progressMsg{stats: stats}) - }) - if err != nil { - m.program.Send(asyncErrorMsg{err: err}) - return - } - paths, err := generateReports(result, input.Report) - if err != nil { - m.program.Send(asyncErrorMsg{err: err}) - return - } - m.program.Send(runCompleteMsg{taskID: taskDef.ID, result: result, reportPaths: paths}) - }() -} - -func generateReports(result *types.ReportData, enabled bool) ([]string, error) { - if !enabled || result == nil { - return nil, nil + for len(rightLines) < maxH { + rightLines = append(rightLines, "") } - manager := report.NewReportManager() - return manager.GenerateReports([]types.ReportData{*result}, []string{"json", "csv"}) -} -func (m *Model) persistStandardRun(taskID string, result *types.ReportData, reportPaths []string) { - taskDef, ok := m.store.Get(taskID) - if !ok { - return - } - finishedAt := time.Now() - summary := &types.TaskRunSummary{ - RunID: fmt.Sprintf("run_%d", finishedAt.UnixNano()), - TaskID: taskID, - Mode: modeStandard, - Status: "completed", - Protocol: result.Protocol, - Model: result.Model, - StartedAt: m.runStartedAt, - FinishedAt: finishedAt, - SuccessRate: result.SuccessRate, - AvgTTFT: result.AvgTTFT, - AvgTPS: result.AvgTPS, - CacheHitRate: result.AvgCacheHitRate * 100, - } - for _, path := range reportPaths { - switch filepath.Ext(path) { - case ".json": - summary.ReportJSONPath = path - case ".csv": - summary.ReportCSVPath = path + var rows []string + for i := 0; i < maxH; i++ { + lLine := leftLines[i] + rLine := rightLines[i] + lW := lipgloss.Width(lLine) + if lW < leftW { + lLine += strings.Repeat(" ", leftW-lW) } + rows = append(rows, lLine+" "+rLine) } - taskDef.LastRunAt = &finishedAt - taskDef.LastRunSummary = summary - m.store.Upsert(taskDef) - _ = m.store.Save() - _ = task.AppendRun(taskID, *summary) - m.tasks = m.store.Tasks - m.reloadHistoryForSelectedTask() + return strings.Join(rows, "\n") } -func (m *Model) persistTurboRun(taskID string, result *types.TurboResult) { - taskDef, ok := m.store.Get(taskID) - if !ok { - return - } - finishedAt := time.Now() - latestSuccessRate := 0.0 - latestCacheHitRate := 0.0 - if len(result.Levels) > 0 { - lastLevel := result.Levels[len(result.Levels)-1] - latestSuccessRate = lastLevel.SuccessRate * 100 - latestCacheHitRate = lastLevel.CacheHitRate * 100 - } - summary := &types.TaskRunSummary{ - RunID: fmt.Sprintf("run_%d", finishedAt.UnixNano()), - TaskID: taskID, - Mode: modeTurbo, - Status: result.StopReason, - Protocol: result.Protocol, - Model: result.Model, - StartedAt: m.runStartedAt, - FinishedAt: finishedAt, - SuccessRate: latestSuccessRate, - AvgTPS: result.PeakTPS, - CacheHitRate: latestCacheHitRate, - MaxStableConcurrency: result.MaxStableConcurrency, - } - taskDef.LastRunAt = &finishedAt - taskDef.LastRunSummary = summary - m.store.Upsert(taskDef) - _ = m.store.Save() - _ = task.AppendRun(taskID, *summary) - m.tasks = m.store.Tasks - m.reloadHistoryForSelectedTask() -} +// ─── 工具 ───────────────────────────────────────────────────────────────────── -func promptSummary(input types.Input) string { - switch input.PromptMode { - case promptModeFile: - return input.PromptFile - case promptModeGenerated: - return fmt.Sprintf("长度 %d", input.PromptLength) - default: - if len(input.PromptText) > 48 { - return input.PromptText[:48] + "..." - } - return input.PromptText +func max(a, b int) int { + if a > b { + return a } + return b } diff --git a/internal/tui/model_test.go b/internal/tui/model_test.go index df03c10..a03ffea 100644 --- a/internal/tui/model_test.go +++ b/internal/tui/model_test.go @@ -2,72 +2,151 @@ package tui import ( "testing" - "time" - "github.com/yinxulai/ait/internal/config" + "github.com/yinxulai/ait/internal/server" "github.com/yinxulai/ait/internal/types" ) -func TestBuildTaskDefinitionStandardMode(t *testing.T) { - state := newWizardState(nil, viewTaskList, &config.Config{DefaultProtocol: types.ProtocolOpenAIResponses}) - state.values["name"] = "nightly-openai" - state.values["endpoint"] = "https://api.openai.com/v1/responses" - state.values["apiKey"] = "sk-test" - state.values["model"] = "gpt-4.1" - state.values["concurrency"] = "8" - state.values["count"] = "120" - state.values["timeout"] = "45s" - state.values["prompt_value"] = "hello" +// stubServer 是 server.Server 的测试桩,所有方法都返回零值。 +type stubServer struct{} - taskDef, err := buildTaskDefinition(state) - if err != nil { - t.Fatalf("buildTaskDefinition() returned unexpected error: %v", err) +func (s *stubServer) ListTasks() []types.TaskDefinition { return nil } +func (s *stubServer) GetTask(id string) (types.TaskDefinition, bool) { return types.TaskDefinition{}, false } +func (s *stubServer) CreateTask(cfg server.TaskConfig) (types.TaskDefinition, error) { return types.TaskDefinition{}, nil } +func (s *stubServer) UpdateTask(id string, cfg server.TaskConfig) (types.TaskDefinition, error) { + return types.TaskDefinition{}, nil +} +func (s *stubServer) DeleteTask(id string) error { return nil } +func (s *stubServer) CopyTask(id string) (types.TaskDefinition, error) { return types.TaskDefinition{}, nil } +func (s *stubServer) StartRun(taskID string) (server.RunID, error) { return "", nil } +func (s *stubServer) StopRun(runID server.RunID) error { return nil } +func (s *stubServer) GetRunState(runID server.RunID) (*server.RunState, bool) { return nil, false } +func (s *stubServer) Subscribe(runID server.RunID) (<-chan server.Event, server.CancelFunc) { + ch := make(chan server.Event) + close(ch) + return ch, func() {} +} +func (s *stubServer) GetHistory(taskID string, limit int) ([]types.TaskRunSummary, error) { return nil, nil } +func (s *stubServer) GenerateReport(runID server.RunID, fmt server.ReportFormat) (string, error) { + return "", nil +} + +// ─── NewModel ───────────────────────────────────────────────────────────────── + +func TestNewModel_InitialState(t *testing.T) { + m := NewModel(&stubServer{}) + if m == nil { + t.Fatal("NewModel returned nil") + } + if m.view != viewTaskList { + t.Errorf("initial view = %q, want %q", m.view, viewTaskList) } - if taskDef.Input.Protocol != types.ProtocolOpenAIResponses { - t.Fatalf("expected protocol %s, got %s", types.ProtocolOpenAIResponses, taskDef.Input.Protocol) +} + +// ─── Wizard: openWizard + buildTaskInput ────────────────────────────────────── + +func TestOpenWizard_NewTask_Defaults(t *testing.T) { + m := NewModel(&stubServer{}) + m.openWizard(nil) + if m.wizard == nil { + t.Fatal("wizard should not be nil after openWizard") } - if taskDef.Input.EndpointURL != "https://api.openai.com/v1/responses" { - t.Fatalf("unexpected endpoint: %s", taskDef.Input.EndpointURL) + if m.wizard.editingID != "" { + t.Errorf("new task wizard should have empty editingID, got %q", m.wizard.editingID) } - if taskDef.Input.Concurrency != 8 || taskDef.Input.Count != 120 || taskDef.Input.Timeout != 45*time.Second { - t.Fatalf("unexpected standard input fields: %+v", taskDef.Input) + if m.wizard.concurrency <= 0 { + t.Errorf("default concurrency should be positive, got %d", m.wizard.concurrency) } - if taskDef.Input.PromptMode != promptModeText || taskDef.Input.PromptText != "hello" { - t.Fatalf("unexpected prompt fields: %+v", taskDef.Input) + if m.wizard.promptMode != promptModeText { + t.Errorf("default promptMode = %q, want %q", m.wizard.promptMode, promptModeText) } } -func TestBuildTaskDefinitionTurboMode(t *testing.T) { - state := newWizardState(nil, viewTaskList, &config.Config{}) - state.mode = modeTurbo - state.protocolIndex = 2 - state.promptModeIndex = 2 - state.values["name"] = "turbo-anthropic" - state.values["endpoint"] = "https://api.anthropic.com/v1/messages" - state.values["apiKey"] = "sk-ant" - state.values["model"] = "claude-3-7-sonnet" - state.values["turbo_init"] = "1" - state.values["turbo_max"] = "12" - state.values["turbo_step"] = "2" - state.values["turbo_level_requests"] = "20" - state.values["turbo_min_success"] = "0.92" - state.values["turbo_max_latency"] = "6s" - state.values["prompt_value"] = "256" +func TestOpenWizard_EditTask_Populate(t *testing.T) { + m := NewModel(&stubServer{}) + task := types.TaskDefinition{ + ID: "task-123", + Name: "my-task", + Input: types.Input{ + Model: "gpt-4", + Protocol: types.ProtocolOpenAICompletions, + ApiKey: "sk-test", + Concurrency: 5, + Count: 50, + PromptMode: promptModeText, + PromptText: "hello", + }, + } + m.openWizard(&task) + if m.wizard == nil { + t.Fatal("wizard should not be nil") + } + if m.wizard.editingID != "task-123" { + t.Errorf("editingID = %q, want %q", m.wizard.editingID, "task-123") + } + if m.wizard.model != "gpt-4" { + t.Errorf("model = %q, want %q", m.wizard.model, "gpt-4") + } + if m.wizard.concurrency != 5 { + t.Errorf("concurrency = %d, want 5", m.wizard.concurrency) + } +} - taskDef, err := buildTaskDefinition(state) - if err != nil { - t.Fatalf("buildTaskDefinition() returned unexpected error: %v", err) +func TestBuildTaskInput_Standard(t *testing.T) { + m := NewModel(&stubServer{}) + m.openWizard(nil) + wz := m.wizard + wz.model = "gpt-4.1" + wz.apiKey = "sk-test" + wz.concurrency = 8 + wz.count = 120 + wz.promptMode = promptModeText + wz.promptText = "hello" + + inp := m.buildTaskInput() + if inp.Model != "gpt-4.1" { + t.Errorf("model = %q, want gpt-4.1", inp.Model) + } + if inp.Concurrency != 8 { + t.Errorf("concurrency = %d, want 8", inp.Concurrency) + } + if inp.Count != 120 { + t.Errorf("count = %d, want 120", inp.Count) } - if !taskDef.Input.Turbo { - t.Fatal("expected Turbo to be enabled") + if inp.PromptMode != promptModeText || inp.PromptText != "hello" { + t.Errorf("unexpected prompt config: mode=%q text=%q", inp.PromptMode, inp.PromptText) + } + if inp.Turbo { + t.Error("turbo should be false in standard mode") + } +} + +func TestBuildTaskInput_Turbo(t *testing.T) { + m := NewModel(&stubServer{}) + m.openWizard(nil) + wz := m.wizard + wz.model = "claude-3-7-sonnet" + wz.apiKey = "sk-ant" + wz.protocol = types.ProtocolAnthropicMessages + wz.turbo = true + wz.initConcurrency = 1 + wz.maxConcurrency = 12 + wz.stepSize = 2 + wz.levelRequests = 20 + wz.promptMode = promptModeGenerated + wz.promptLength = 256 + + inp := m.buildTaskInput() + if !inp.Turbo { + t.Error("expected Turbo=true") } - if taskDef.Input.TurboConfig.MaxConcurrency != 12 || taskDef.Input.TurboConfig.MaxLatency != 6*time.Second { - t.Fatalf("unexpected turbo config: %+v", taskDef.Input.TurboConfig) + if inp.TurboConfig.MaxConcurrency != 12 { + t.Errorf("MaxConcurrency = %d, want 12", inp.TurboConfig.MaxConcurrency) } - if taskDef.Input.PromptMode != promptModeGenerated || taskDef.Input.PromptLength != 256 { - t.Fatalf("unexpected generated prompt config: %+v", taskDef.Input) + if inp.PromptMode != promptModeGenerated || inp.PromptLength != 256 { + t.Errorf("unexpected prompt config: mode=%q len=%d", inp.PromptMode, inp.PromptLength) } - if taskDef.Input.Protocol != types.ProtocolAnthropicMessages { - t.Fatalf("expected anthropic protocol, got %s", taskDef.Input.Protocol) + if inp.Protocol != types.ProtocolAnthropicMessages { + t.Errorf("protocol = %q, want anthropic-messages", inp.Protocol) } } diff --git a/internal/tui/page_dashboard.go b/internal/tui/page_dashboard.go new file mode 100644 index 0000000..4f982c1 --- /dev/null +++ b/internal/tui/page_dashboard.go @@ -0,0 +1,317 @@ +package tui + +import ( + "fmt" + "strings" + + tea "github.com/charmbracelet/bubbletea" + "github.com/yinxulai/ait/internal/server" +) + +// dashboardState 仪表盘页的局部状态。 +type dashboardState struct { + runID server.RunID + taskID string + eventCh <-chan server.Event // nil 表示已后台/已结束 + cancelFn server.CancelFunc + runState *server.RunState + reqSel int // 选中请求的 index(-1 = 无选中) + reqOff int // 请求列表滚动偏移 +} + +// isRunning 判断仪表盘内的运行是否仍在进行中。 +func (d *dashboardState) isRunning() bool { + if d == nil || d.runState == nil { + return false + } + return d.runState.Status == server.RunStatusRunning +} + +// ─── 按键处理 ───────────────────────────────────────────────────────────────── + +func (m *Model) handleDashboardKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) { + if m.dash == nil { + m.view = viewTaskList + return m, nil + } + d := m.dash + var reqs []*server.RequestMetrics + if d.runState != nil { + reqs = d.runState.Requests + } + + switch msg.String() { + case "up", "k": + if d.reqSel > 0 { + d.reqSel-- + } else if len(reqs) > 0 { + d.reqSel = len(reqs) - 1 + } + m.adjustReqOffset() + + case "down", "j": + if d.reqSel < len(reqs)-1 { + d.reqSel++ + } else { + d.reqSel = 0 + } + m.adjustReqOffset() + + case "enter": + if d.reqSel >= 0 && d.reqSel < len(reqs) { + m.reqDetail = &reqDetailState{ + runID: d.runID, + requests: reqs, + index: d.reqSel, + } + m.view = viewReqDetail + } + + case "s": + if d.isRunning() { + return m, m.client.StopRunCmd(d.runID) + } + + case "b": + // 后台运行:取消订阅,返回任务列表,保留 dash 状态 + if d.cancelFn != nil { + d.cancelFn() + } + d.eventCh = nil + d.cancelFn = nil + m.view = viewTaskList + m.status = fmt.Sprintf("运行 %s 已转入后台", d.runID) + + case "r": + // 生成报告 + if d.runState != nil && !d.isRunning() { + return m, m.client.GenerateReportCmd(d.runID, server.ReportFormatJSON) + } + + case "left", "esc": + if !d.isRunning() { + // 运行已结束,直接返回任务详情 + if d.cancelFn != nil { + d.cancelFn() + } + m.dash = nil + m.view = viewTaskDetail + } + + case "q": + return m, tea.Quit + } + + return m, nil +} + +// adjustReqOffset 根据 reqSel 调整列表的可见窗口。 +func (m *Model) adjustReqOffset() { + if m.dash == nil { + return + } + visH := m.height - 10 + if visH < 5 { + visH = 5 + } + sel := m.dash.reqSel + off := m.dash.reqOff + if sel < off { + off = sel + } else if sel >= off+visH { + off = sel - visH + 1 + } + m.dash.reqOff = off +} + +// ─── 渲染 ───────────────────────────────────────────────────────────────────── + +func (m *Model) renderDashboard() string { + if m.dash == nil || m.width == 0 { + return "加载中..." + } + d := m.dash + rs := d.runState + + statusStr := "等待中" + if rs != nil { + switch rs.Status { + case server.RunStatusRunning: + statusStr = m.styles.ok.Render("运行中") + case server.RunStatusCompleted: + statusStr = m.styles.ok.Render("已完成") + case server.RunStatusFailed: + statusStr = m.styles.errStyle.Render("失败") + case server.RunStatusStopped: + statusStr = m.styles.muted.Render("已停止") + } + } + header := m.renderHeader("AIT 仪表盘", statusStr) + + var cbItems []contextBarItem + if d.reqSel >= 0 { + cbItems = contextBarItems_dashboard_sel() + } else { + cbItems = contextBarItems_dashboard_nosel() + } + contextBar := m.renderContextBar(cbItems) + + var footerRight string + if rs != nil && rs.TotalReqs > 0 { + footerRight = fmt.Sprintf("%d/%d 请求", rs.DoneReqs, rs.TotalReqs) + } + footer := m.renderFooter("[s] 停止", "[b] 后台", "[r] 报告", footerRight) + + cbH := 0 + if contextBar != "" { + cbH = 1 + } + contentH := m.height - 1 - cbH - 1 + if contentH < 4 { + contentH = 4 + } + + leftW := (m.width - 4) * 55 / 100 + rightW := m.width - 4 - leftW + + leftContent := m.buildDashLeft(contentH, leftW) + rightContent := m.buildDashRight(contentH) + mid := m.dualColumnLayout(leftContent, rightContent, leftW, rightW, contentH) + + parts := []string{header, mid} + if contextBar != "" { + parts = append(parts, contextBar) + } + parts = append(parts, footer) + return strings.Join(parts, "\n") +} + +func (m *Model) buildDashLeft(maxH, width int) string { + d := m.dash + rs := d.runState + var lines []string + + // ── 汇总指标 ── + lines = append(lines, m.styles.sectionHead.Render("实时指标")) + lines = append(lines, "") + + if rs == nil { + lines = append(lines, m.styles.muted.Render(" 等待数据...")) + return strings.Join(lines, "\n") + } + + // 进度条 + pbW := width - 20 + if pbW < 10 { + pbW = 10 + } + pct := 0.0 + if rs.TotalReqs > 0 { + pct = float64(rs.DoneReqs) * 100 / float64(rs.TotalReqs) + } + pb := progressBar(rs.DoneReqs, rs.TotalReqs, pbW) + lines = append(lines, fmt.Sprintf(" 进度 %s %5.1f%%", pb, pct)) + lines = append(lines, "") + + lines = append(lines, row(m, "总请求数 ", fmt.Sprintf("%d", rs.TotalReqs))) + lines = append(lines, row(m, "已完成 ", fmt.Sprintf("%d", rs.DoneReqs))) + lines = append(lines, row(m, "成功 ", m.styles.ok.Render(fmt.Sprintf("%d", rs.SuccessReqs)))) + lines = append(lines, row(m, "失败 ", m.styles.errStyle.Render(fmt.Sprintf("%d", rs.FailedReqs)))) + lines = append(lines, row(m, "成功率 ", fmt.Sprintf("%.2f%%", rs.SuccessRate))) + lines = append(lines, "") + lines = append(lines, m.styles.sectionHead.Render("性能指标")) + lines = append(lines, "") + lines = append(lines, row(m, "平均 TPS ", m.styles.metricVal.Render(fmt.Sprintf("%.2f tok/s", rs.AvgTPS)))) + lines = append(lines, row(m, "平均 TTFT ", m.styles.metricVal.Render(fmt.Sprintf("%.0f ms", float64(rs.AvgTTFT.Milliseconds()))))) + lines = append(lines, row(m, "缓存命中率 ", fmt.Sprintf("%.2f%%", rs.CacheHitRate))) + + if rs.Mode == "turbo" && len(rs.Levels) > 0 { + lines = append(lines, "") + lines = append(lines, m.styles.sectionHead.Render("Turbo 并发探测")) + lines = append(lines, "") + for i, lv := range rs.Levels { + sel := "" + if i == rs.CurrentLevel { + sel = m.styles.ok.Render("▶") + } else { + sel = " " + } + stableStr := m.styles.muted.Render("探测中") + if lv.Stable { + stableStr = m.styles.ok.Render("稳定") + } else if lv.StopReason != "" { + stableStr = m.styles.errStyle.Render("停止") + } + lines = append(lines, fmt.Sprintf("%s 并发%3d TPS %5.1f 成功率 %5.1f%% %s", + sel, lv.Concurrency, lv.AvgTPS, lv.SuccessRate, stableStr)) + } + } + + if rs.ErrorMsg != "" { + lines = append(lines, "") + lines = append(lines, m.styles.errStyle.Render("错误: "+truncate(rs.ErrorMsg, width-10))) + } + + return strings.Join(lines, "\n") +} + +func (m *Model) buildDashRight(maxH int) string { + d := m.dash + rs := d.runState + var lines []string + + lines = append(lines, m.styles.sectionHead.Render("请求列表")) + lines = append(lines, "") + lines = append(lines, m.styles.tableHead.Render( + fmt.Sprintf(" %-4s %-5s %8s %8s %8s %-6s", "#", "状态", "耗时", "TTFT", "TPS", "Token"), + )) + lines = append(lines, m.styles.muted.Render(strings.Repeat("─", 56))) + + if rs == nil || len(rs.Requests) == 0 { + lines = append(lines, m.styles.muted.Render(" 暂无请求...")) + return strings.Join(lines, "\n") + } + + visH := maxH - len(lines) - 1 + if visH < 1 { + visH = 1 + } + + start := d.reqOff + if start < 0 { + start = 0 + } + end := start + visH + if end > len(rs.Requests) { + end = len(rs.Requests) + } + + for i := start; i < end; i++ { + r := rs.Requests[i] + statusIcon := m.styles.ok.Render("✓") + if !r.Success { + statusIcon = m.styles.errStyle.Render("✗") + } + line := fmt.Sprintf(" %3d %s %7dms %7dms %7.1f %-6d", + r.Index+1, + statusIcon, + r.TotalTime.Milliseconds(), + r.TTFT.Milliseconds(), + r.TPS, + r.CompletionTokens, + ) + if i == d.reqSel { + lines = append(lines, m.styles.tableRowSel.Render(line)) + } else { + lines = append(lines, line) + } + } + + // 滚动提示 + if len(rs.Requests) > visH { + lines = append(lines, m.styles.muted.Render( + fmt.Sprintf(" %d/%d 请求 [↑↓] 滚动", len(rs.Requests), len(rs.Requests)))) + } + + return strings.Join(lines, "\n") +} diff --git a/internal/tui/page_reqdetail.go b/internal/tui/page_reqdetail.go new file mode 100644 index 0000000..199852b --- /dev/null +++ b/internal/tui/page_reqdetail.go @@ -0,0 +1,182 @@ +package tui + +import ( + "fmt" + "strings" + + tea "github.com/charmbracelet/bubbletea" + "github.com/yinxulai/ait/internal/server" +) + +// reqDetailState 请求详情页的状态。 +type reqDetailState struct { + runID server.RunID + requests []*server.RequestMetrics + index int +} + +// ─── 按键处理 ───────────────────────────────────────────────────────────────── + +func (m *Model) handleReqDetailKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) { + rd := m.reqDetail + if rd == nil { + m.view = viewDashboard + return m, nil + } + + switch msg.String() { + case "left", "h": + if rd.index > 0 { + rd.index-- + } else { + rd.index = len(rd.requests) - 1 + } + case "right", "l": + if rd.index < len(rd.requests)-1 { + rd.index++ + } else { + rd.index = 0 + } + case "b", "esc", "backspace": + m.view = viewDashboard + case "q": + return m, tea.Quit + } + return m, nil +} + +// ─── 渲染 ───────────────────────────────────────────────────────────────────── + +func (m *Model) renderReqDetail() string { + rd := m.reqDetail + if rd == nil || m.width == 0 { + return "加载中..." + } + if len(rd.requests) == 0 { + return "无请求数据" + } + + idx := rd.index + if idx < 0 { + idx = 0 + } + if idx >= len(rd.requests) { + idx = len(rd.requests) - 1 + } + r := rd.requests[idx] + + header := m.renderHeader( + fmt.Sprintf("AIT 请求详情 #%d / %d", idx+1, len(rd.requests)), + statusStr(m, r), + ) + contextBar := m.renderContextBar(contextBarItems_reqdetail()) + footer := m.renderFooter("[←→] 切换", "[Esc] 返回仪表盘", "", "◆ AIT") + + cbH := 0 + if contextBar != "" { + cbH = 1 + } + contentH := m.height - 1 - cbH - 1 + if contentH < 4 { + contentH = 4 + } + + leftW := (m.width - 4) * 50 / 100 + rightW := m.width - 4 - leftW + + leftContent := m.buildReqLeft(r, contentH, leftW) + rightContent := m.buildReqRight(r, contentH) + mid := m.dualColumnLayout(leftContent, rightContent, leftW, rightW, contentH) + + parts := []string{header, mid} + if contextBar != "" { + parts = append(parts, contextBar) + } + parts = append(parts, footer) + return strings.Join(parts, "\n") +} + +func statusStr(m *Model, r *server.RequestMetrics) string { + if r.Success { + return m.styles.ok.Render("成功") + } + return m.styles.errStyle.Render("失败") +} + +func (m *Model) buildReqLeft(r *server.RequestMetrics, maxH, width int) string { + var lines []string + lines = append(lines, m.styles.sectionHead.Render("时间指标")) + lines = append(lines, "") + lines = append(lines, row(m, "总耗时 ", fmt.Sprintf("%d ms", r.TotalTime.Milliseconds()))) + lines = append(lines, row(m, "TTFT ", fmt.Sprintf("%d ms", r.TTFT.Milliseconds()))) + lines = append(lines, row(m, "TPS ", m.styles.metricVal.Render(fmt.Sprintf("%.2f tok/s", r.TPS)))) + lines = append(lines, "") + lines = append(lines, m.styles.sectionHead.Render("Token 统计")) + lines = append(lines, "") + lines = append(lines, row(m, "Prompt Tok ", fmt.Sprintf("%d", r.PromptTokens))) + lines = append(lines, row(m, "Output Tok ", fmt.Sprintf("%d", r.CompletionTokens))) + lines = append(lines, row(m, "缓存命中 ", fmt.Sprintf("%d tok (%.1f%%)", r.CachedTokens, r.CacheHitRate*100))) + lines = append(lines, "") + + if r.ErrorMessage != "" { + lines = append(lines, m.styles.sectionHead.Render("错误信息")) + lines = append(lines, "") + for _, part := range wrapText(r.ErrorMessage, width-4) { + lines = append(lines, m.styles.errStyle.Render(" "+part)) + } + } + + if r.PromptText != "" { + lines = append(lines, "") + lines = append(lines, m.styles.sectionHead.Render("Prompt")) + lines = append(lines, "") + for _, part := range wrapText(r.PromptText, width-4) { + if len(lines) >= maxH-1 { + break + } + lines = append(lines, " "+part) + } + } + return strings.Join(lines, "\n") +} + +func (m *Model) buildReqRight(r *server.RequestMetrics, maxH int) string { + var lines []string + lines = append(lines, m.styles.sectionHead.Render("网络指标")) + lines = append(lines, "") + lines = append(lines, row(m, "目标 IP ", r.TargetIP)) + lines = append(lines, row(m, "DNS 解析 ", fmt.Sprintf("%d ms", r.DNSTime.Milliseconds()))) + lines = append(lines, row(m, "TCP 连接 ", fmt.Sprintf("%d ms", r.ConnectTime.Milliseconds()))) + lines = append(lines, row(m, "TLS 握手 ", fmt.Sprintf("%d ms", r.TLSTime.Milliseconds()))) + lines = append(lines, "") + + if r.ResponseText != "" { + lines = append(lines, m.styles.sectionHead.Render("Response")) + lines = append(lines, "") + for _, part := range wrapText(r.ResponseText, 40) { + if len(lines) >= maxH-1 { + break + } + lines = append(lines, " "+part) + } + } + return strings.Join(lines, "\n") +} + +// wrapText 按宽度折行(简单按字节宽度,不处理 CJK)。 +func wrapText(s string, width int) []string { + if width <= 0 { + return []string{s} + } + var result []string + runes := []rune(s) + for len(runes) > 0 { + end := width + if end > len(runes) { + end = len(runes) + } + result = append(result, string(runes[:end])) + runes = runes[end:] + } + return result +} diff --git a/internal/tui/page_taskdetail.go b/internal/tui/page_taskdetail.go new file mode 100644 index 0000000..1720277 --- /dev/null +++ b/internal/tui/page_taskdetail.go @@ -0,0 +1,196 @@ +package tui + +import ( + "fmt" + "strings" + "time" + + tea "github.com/charmbracelet/bubbletea" + "github.com/yinxulai/ait/internal/types" +) + +// historyState 任务详情页的历史数据。 +type historyState struct { + taskID string + history []types.TaskRunSummary +} + +// ─── 渲染 ───────────────────────────────────────────────────────────────────── + +func (m *Model) renderTaskDetail() string { + if m.width == 0 { + return "加载中..." + } + + task, ok := m.taskList.currentTask() + if !ok { + return "请先选择任务" + } + + header := m.renderHeader( + "AIT 任务详情", + task.Name, + ) + + var cbItems []contextBarItem + if m.hist != nil { + cbItems = contextBarItems_taskDetail(len(m.hist.history) > 0) + } else { + cbItems = contextBarItems_taskDetail(false) + } + contextBar := m.renderContextBar(cbItems) + footer := m.renderFooter("[←/Esc] 返回", "[r] 运行", "[e] 编辑", "◆ AIT") + + cbH := 0 + if contextBar != "" { + cbH = 1 + } + contentH := m.height - 1 - cbH - 1 + if contentH < 4 { + contentH = 4 + } + + leftW := (m.width - 4) * 55 / 100 + rightW := m.width - 4 - leftW + + leftContent := m.buildDetailLeft(task, contentH, leftW) + rightContent := m.buildDetailRight(contentH) + mid := m.dualColumnLayout(leftContent, rightContent, leftW, rightW, contentH) + + parts := []string{header, mid} + if contextBar != "" { + parts = append(parts, contextBar) + } + parts = append(parts, footer) + return strings.Join(parts, "\n") +} + +func (m *Model) buildDetailLeft(task types.TaskDefinition, maxH, width int) string { + inp := task.Input + var lines []string + + lines = append(lines, m.styles.sectionHead.Render("基本配置")) + lines = append(lines, "") + lines = append(lines, row(m, "名称 ", task.Name)) + lines = append(lines, row(m, "创建时间 ", task.CreatedAt.Format("2006-01-02 15:04:05"))) + lines = append(lines, row(m, "更新时间 ", task.UpdatedAt.Format("2006-01-02 15:04:05"))) + if task.LastRunAt != nil { + lines = append(lines, row(m, "上次运行 ", timeAgo(*task.LastRunAt))) + } + lines = append(lines, "") + lines = append(lines, m.styles.sectionHead.Render("测试参数")) + lines = append(lines, "") + lines = append(lines, row(m, "协议 ", shortProtocol(inp.NormalizedProtocol()))) + lines = append(lines, row(m, "接口地址 ", truncate(inp.ResolvedEndpointURL(), width-20))) + lines = append(lines, row(m, "API Key ", maskAPIKey(inp.ApiKey))) + lines = append(lines, row(m, "模型 ", inp.Model)) + + modeStr := "标准" + if inp.Turbo { + modeStr = "Turbo (并发探测)" + } + lines = append(lines, row(m, "测试模式 ", modeStr)) + + if inp.Turbo { + tc := inp.TurboConfig + lines = append(lines, row(m, "初始并发 ", fmt.Sprintf("%d", tc.InitConcurrency))) + lines = append(lines, row(m, "最大并发 ", fmt.Sprintf("%d", tc.MaxConcurrency))) + lines = append(lines, row(m, "步进大小 ", fmt.Sprintf("+%d", tc.StepSize))) + lines = append(lines, row(m, "每级请求 ", fmt.Sprintf("%d", tc.LevelRequests))) + lines = append(lines, row(m, "最低成功率", fmt.Sprintf("%.0f%%", tc.MinSuccessRate))) + } else { + lines = append(lines, row(m, "并发数 ", fmt.Sprintf("%d", inp.Concurrency))) + lines = append(lines, row(m, "请求总数 ", fmt.Sprintf("%d", inp.Count))) + } + + lines = append(lines, row(m, "流式输出 ", boolLabel(inp.Stream))) + lines = append(lines, row(m, "Thinking ", boolLabel(inp.Thinking))) + lines = append(lines, "") + lines = append(lines, m.styles.sectionHead.Render("Prompt 配置")) + lines = append(lines, "") + lines = append(lines, row(m, "模式 ", inp.PromptMode)) + lines = append(lines, row(m, "内容 ", truncate(promptSummary(inp), width-20))) + + if m.status != "" { + lines = append(lines, "") + lines = append(lines, m.styles.muted.Render(m.status)) + } + if m.err != nil { + lines = append(lines, m.styles.errStyle.Render("错误: "+m.err.Error())) + } + + return strings.Join(lines, "\n") +} + +func (m *Model) buildDetailRight(maxH int) string { + var lines []string + lines = append(lines, m.styles.sectionHead.Render("运行历史")) + lines = append(lines, "") + + if m.hist == nil || len(m.hist.history) == 0 { + lines = append(lines, m.styles.muted.Render(" 暂无历史记录")) + lines = append(lines, "") + lines = append(lines, m.styles.muted.Render(" 按 [Enter] 或 [r] 开始第一次运行")) + return strings.Join(lines, "\n") + } + + for i, run := range m.hist.history { + if len(lines) >= maxH-2 { + break + } + statusIcon := m.styles.ok.Render("✓") + if run.Status != "completed" { + statusIcon = m.styles.errStyle.Render("✗") + } + elapsed := run.FinishedAt.Sub(run.StartedAt) + lines = append(lines, fmt.Sprintf("%s #%d %s", + statusIcon, i+1, timeAgo(run.StartedAt))) + lines = append(lines, fmt.Sprintf(" 成功率 %.1f%% TTFT %.0fms TPS %.1f", + run.SuccessRate, float64(run.AvgTTFT.Milliseconds()), run.AvgTPS)) + lines = append(lines, fmt.Sprintf(" 耗时 %s 模式 %s", + fmtDuration(elapsed), run.Mode)) + if run.ErrorSummary != "" { + lines = append(lines, m.styles.errStyle.Render(" "+truncate(run.ErrorSummary, 36))) + } + if run.ReportJSONPath != "" { + lines = append(lines, m.styles.muted.Render(" 报告: "+truncate(run.ReportJSONPath, 32))) + } + lines = append(lines, "") + } + + return strings.Join(lines, "\n") +} + +// ─── 按键处理 ───────────────────────────────────────────────────────────────── + +func (m *Model) handleTaskDetailKey(msg interface{ String() string }) tea.Cmd { + switch msg.String() { + case "left", "esc", "b": + m.view = viewTaskList + return nil + + case "enter", "r": + if t, ok := m.taskList.currentTask(); ok { + return m.startRunIfAllowed(t.ID, false) + } + } + return nil +} + +// ─── helpers ────────────────────────────────────────────────────────────────── + +func row(m *Model, label, value string) string { + return m.styles.label.Render(label) + " " + m.styles.value.Render(value) +} + +func fmtDuration(d time.Duration) string { + ms := d.Milliseconds() + if ms < 1000 { + return fmt.Sprintf("%dms", ms) + } + s := float64(ms) / 1000 + if s < 60 { + return fmt.Sprintf("%.1fs", s) + } + return fmt.Sprintf("%.0fm%.0fs", s/60, float64(int64(s)%60)) +} diff --git a/internal/tui/page_tasklist.go b/internal/tui/page_tasklist.go new file mode 100644 index 0000000..6ee8fc2 --- /dev/null +++ b/internal/tui/page_tasklist.go @@ -0,0 +1,412 @@ +package tui + +import ( + "fmt" + "strings" + "time" + + tea "github.com/charmbracelet/bubbletea" + "github.com/charmbracelet/lipgloss" + "github.com/yinxulai/ait/internal/types" +) + +// taskListState 任务列表页的局部状态。 +type taskListState struct { + tasks []types.TaskDefinition + selected int +} + +// currentTask 返回当前选中的任务。 +func (s *taskListState) currentTask() (types.TaskDefinition, bool) { + if len(s.tasks) == 0 || s.selected < 0 || s.selected >= len(s.tasks) { + return types.TaskDefinition{}, false + } + return s.tasks[s.selected], true +} + +// ─── 按键处理 ───────────────────────────────────────────────────────────────── + +func (m *Model) handleTaskListKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) { + s := &m.taskList + + switch msg.String() { + case "up", "k": + if s.selected > 0 { + s.selected-- + } + case "down", "j": + if s.selected < len(s.tasks)-1 { + s.selected++ + } + + case "a": + // 新建任务 — 打开向导 + m.openWizard(nil) + + case "e": + if t, ok := s.currentTask(); ok { + m.openWizard(&t) + } + + case "y": + // 复制任务 + if t, ok := s.currentTask(); ok { + return m, m.client.CopyTaskCmd(t.ID) + } + + case "d": + // 删除任务 + if t, ok := s.currentTask(); ok { + return m, m.client.DeleteTaskCmd(t.ID) + } + + case "enter": + if t, ok := s.currentTask(); ok { + // 如果是运行中任务,进入仪表盘 + if m.dash != nil && m.dash.runID != "" && m.dash.taskID == t.ID { + m.view = viewDashboard + return m, nil + } + // 否则进入任务详情并加载历史 + m.view = viewTaskDetail + return m, m.client.LoadHistoryCmd(t.ID, 10) + } + + case "r": + if t, ok := s.currentTask(); ok { + return m, m.startRunIfAllowed(t.ID, false) + } + + case "s": + // 停止当前运行中的任务(若选中的是运行中任务) + if t, ok := s.currentTask(); ok { + if m.dash != nil && m.dash.taskID == t.ID { + return m, m.client.StopRunCmd(m.dash.runID) + } + } + + case "q": + return m, tea.Quit + } + + return m, nil +} + +// ─── 渲染 ───────────────────────────────────────────────────────────────────── + +func (m *Model) renderTaskList() string { + if m.width == 0 { + return "加载中..." + } + s := &m.taskList + + lastRunStr := "" + for _, t := range s.tasks { + if t.LastRunAt != nil { + lastRunStr = "最近: " + timeAgo(*t.LastRunAt) + break + } + } + header := m.renderHeader( + "AIT 任务中心", + fmt.Sprintf("已保存任务: %d %s", len(s.tasks), lastRunStr), + ) + + // 决定 context bar 内容 + var cbItems []contextBarItem + if t, ok := s.currentTask(); ok { + isRunning := m.dash != nil && m.dash.taskID == t.ID && m.dash.isRunning() + cbItems = contextBarItems_taskList(isRunning) + } + contextBar := m.renderContextBar(cbItems) + footer := m.renderFooter("[↑↓] 选择", "[a] 新建", "[q] 退出", "◆ AIT v0.1") + + // 内容区高度 = 总高 - header(1) - contextbar - footer(1) + cbH := 0 + if contextBar != "" { + cbH = 1 + } + contentH := m.height - 1 - cbH - 1 + if contentH < 4 { + contentH = 4 + } + + leftW := (m.width - 4) * 65 / 100 + rightW := m.width - 4 - leftW + + leftContent := m.buildTaskListTable(contentH, leftW) + rightContent := m.buildTaskListSidebar(contentH) + mid := m.dualColumnLayout(leftContent, rightContent, leftW, rightW, contentH) + + parts := []string{header, mid} + if contextBar != "" { + parts = append(parts, contextBar) + } + parts = append(parts, footer) + return strings.Join(parts, "\n") +} + +func (m *Model) buildTaskListTable(maxH, width int) string { + s := &m.taskList + var lines []string + + lines = append(lines, m.styles.tableHead.Render( + fmt.Sprintf(" %-28s %-9s %-14s %s", "任务名称", "模式", "协议", "上次结果"), + )) + lines = append(lines, m.styles.muted.Render(strings.Repeat("─", width))) + + if len(s.tasks) == 0 { + lines = append(lines, "") + lines = append(lines, m.styles.muted.Render(" 暂无任务 按 [a] 新建")) + return strings.Join(lines, "\n") + } + + for i, t := range s.tasks { + if len(lines) >= maxH-1 { + break + } + + // 运行中标记 + runIndicator := " " + if m.dash != nil && m.dash.taskID == t.ID && m.dash.isRunning() { + runIndicator = m.styles.ok.Render("◉") + } + + // 模式列(手动对齐 9 列宽) + var modeRendered string + if t.Input.Turbo { + modeRendered = m.styles.tagTurbo.Render("Turbo") + } else { + modeRendered = m.styles.tagStd.Render("标准") + } + modePad := 9 - lipgloss.Width(modeRendered) + if modePad < 0 { + modePad = 0 + } + modeCol := modeRendered + strings.Repeat(" ", modePad) + + proto := shortProtocol(t.Input.NormalizedProtocol()) + lastResult := m.styles.muted.Render("从未运行") + if t.LastRunSummary != nil { + pct := t.LastRunSummary.SuccessRate + if pct >= 99 { + lastResult = m.styles.ok.Render(fmt.Sprintf("✓ %.1f%%", pct)) + } else if pct >= 90 { + lastResult = m.styles.metricVal.Render(fmt.Sprintf("%.1f%%", pct)) + } else { + lastResult = m.styles.errStyle.Render(fmt.Sprintf("✗ %.1f%%", pct)) + } + } + + nameStr := truncate(t.Name, 27) + nameCol := fmt.Sprintf("%-27s ", nameStr) + protoCol := fmt.Sprintf("%-14s ", proto) + + mainRow := runIndicator + " " + nameCol + modeCol + " " + protoCol + lastResult + if i == s.selected { + // 选中行:纯文本 + tableRowSel 背景 + plainMode := "标准" + if t.Input.Turbo { + plainMode = "Turbo" + } + plainRow := " ▶ " + nameCol + fmt.Sprintf("%-9s ", plainMode) + protoCol + lastResult + lines = append(lines, m.styles.tableRowSel.Width(width).Render(plainRow)) + } else { + lines = append(lines, mainRow) + } + + // 二级子行:配置摘要 + var sub string + if m.dash != nil && m.dash.taskID == t.ID && m.dash.isRunning() { + rs := m.dash.runState + if rs != nil { + sub = fmt.Sprintf(" %s ◉ %d/%d 成功率 %.1f%%", + truncate(t.Input.Model, 18), rs.DoneReqs, rs.TotalReqs, rs.SuccessRate) + } + } + if sub == "" { + if t.Input.Turbo { + tc := t.Input.TurboConfig + sub = fmt.Sprintf(" %s %d→%d 步进+%d 每级%d", + truncate(t.Input.Model, 18), + tc.InitConcurrency, tc.MaxConcurrency, tc.StepSize, tc.LevelRequests) + } else { + sub = fmt.Sprintf(" %s 并发%d/请求%d", + truncate(t.Input.Model, 20), t.Input.Concurrency, t.Input.Count) + } + } + if i == s.selected { + lines = append(lines, m.styles.tableRowSel.Width(width).Render(sub)) + } else { + lines = append(lines, m.styles.muted.Render(sub)) + } + lines = append(lines, "") + } + return strings.Join(lines, "\n") +} + +func (m *Model) buildTaskListSidebar(maxH int) string { + var lines []string + lines = append(lines, m.styles.sectionHead.Render("快捷操作")) + lines = append(lines, "") + lines = append(lines, " "+m.styles.key.Render("[a]")+" 新建任务") + lines = append(lines, " "+m.styles.key.Render("[Enter]")+" 查看详情 / 进仪表盘") + lines = append(lines, " "+m.styles.key.Render("[r]")+" 运行选中任务") + lines = append(lines, " "+m.styles.key.Render("[e]")+" 编辑 "+ + m.styles.key.Render("[d]")+" 删除 "+ + m.styles.key.Render("[y]")+" 复制") + lines = append(lines, "") + lines = append(lines, m.styles.muted.Render(strings.Repeat("─", 28))) + lines = append(lines, "") + lines = append(lines, m.styles.sectionHead.Render("最近执行")) + lines = append(lines, "") + + count := 0 + for _, t := range m.taskList.tasks { + if t.LastRunSummary == nil { + continue + } + s := t.LastRunSummary + icon := m.styles.ok.Render("✓") + if s.SuccessRate < 90 { + icon = m.styles.errStyle.Render("✗") + } + lines = append(lines, fmt.Sprintf(" %s %-16s %.1f%% %.0f tok/s", + icon, truncate(t.Name, 16), s.SuccessRate, s.AvgTPS)) + count++ + if count >= 5 || len(lines) >= maxH-2 { + break + } + } + if count == 0 { + lines = append(lines, m.styles.muted.Render(" 暂无记录")) + } + + if m.status != "" { + lines = append(lines, "") + lines = append(lines, m.styles.muted.Render(m.status)) + } + if m.err != nil { + lines = append(lines, m.styles.errStyle.Render("错误: "+m.err.Error())) + } + return strings.Join(lines, "\n") +} + +// startRunIfAllowed 根据是否已有运行中任务决定是否启动新运行。 +// forceStart=true 表示无论是否有其他任务都启动(用于向导 [r] 保存并运行)。 +func (m *Model) startRunIfAllowed(taskID string, forceStart bool) tea.Cmd { + if !forceStart && m.dash != nil && m.dash.isRunning() { + m.status = fmt.Sprintf("已有任务 %q 在运行中,多任务并行可能影响网络指标", + m.dash.taskID) + return nil + } + return m.client.StartRunCmd(taskID) +} + +// ─── 共享渲染工具 ───────────────────────────────────────────────────────────── + +// 这些函数被多个 page_*.go 使用,统一放在此文件。 + +func progressBar(current, total, width int) string { + if total <= 0 || width <= 0 { + return lipgloss.NewStyle().Foreground(lipgloss.Color("238")).Render(strings.Repeat("░", width)) + } + filled := current * width / total + if filled > width { + filled = width + } + bar := lipgloss.NewStyle().Foreground(colorGreen).Render(strings.Repeat("█", filled)) + empty := lipgloss.NewStyle().Foreground(lipgloss.Color("238")).Render(strings.Repeat("░", width-filled)) + return bar + empty +} + +func progressBarRed(current, total, width int) string { + if total <= 0 || width <= 0 { + return lipgloss.NewStyle().Foreground(lipgloss.Color("238")).Render(strings.Repeat("░", width)) + } + filled := current * width / total + if filled > width { + filled = width + } + bar := lipgloss.NewStyle().Foreground(colorRed).Render(strings.Repeat("█", filled)) + empty := lipgloss.NewStyle().Foreground(lipgloss.Color("238")).Render(strings.Repeat("░", width-filled)) + return bar + empty +} + +// isRunningTask 判断任务是否当前正在运行。 +func (m *Model) isRunningTask(taskID string) bool { + return m.dash != nil && m.dash.taskID == taskID && m.dash.isRunning() +} + +// ─── 工具函数 ───────────────────────────────────────────────────────────────── + +func truncate(s string, n int) string { + if n <= 0 || len([]rune(s)) <= n { + return s + } + r := []rune(s) + if n <= 3 { + return string(r[:n]) + } + return string(r[:n-3]) + "..." +} + +func timeAgo(t time.Time) string { + d := time.Since(t) + switch { + case d < time.Minute: + return fmt.Sprintf("%ds 前", int(d.Seconds())) + case d < time.Hour: + return fmt.Sprintf("%dm 前", int(d.Minutes())) + case d < 24*time.Hour: + return fmt.Sprintf("%dh 前", int(d.Hours())) + } + return t.Format("01-02 15:04") +} + +func shortProtocol(p string) string { + p = strings.ReplaceAll(p, "openai-", "") + p = strings.ReplaceAll(p, "anthropic-", "") + return p +} + +func boolLabel(v bool) string { + if v { + return "开启" + } + return "关闭" +} + +func promptSummary(input types.Input) string { + switch input.PromptMode { + case promptModeFile: + return input.PromptFile + case promptModeGenerated: + return fmt.Sprintf("长度 %d", input.PromptLength) + default: + if len([]rune(input.PromptText)) > 48 { + return string([]rune(input.PromptText)[:45]) + "..." + } + return input.PromptText + } +} + +func maskAPIKey(key string) string { + if len(key) == 0 { + return "(空)" + } + if len(key) <= 8 { + return strings.Repeat("•", len(key)) + } + return key[:4] + strings.Repeat("•", len(key)-8) + key[len(key)-4:] +} + +func wrapIndex(index, length int) int { + if length == 0 { + return 0 + } + for index < 0 { + index += length + } + return index % length +} diff --git a/internal/tui/page_wizard.go b/internal/tui/page_wizard.go new file mode 100644 index 0000000..dd8c0d4 --- /dev/null +++ b/internal/tui/page_wizard.go @@ -0,0 +1,569 @@ +package tui + +import ( + "fmt" + "strconv" + "strings" + + tea "github.com/charmbracelet/bubbletea" + "github.com/charmbracelet/lipgloss" + "github.com/yinxulai/ait/internal/server" + "github.com/yinxulai/ait/internal/types" +) + +// ─── Prompt 模式常量 ────────────────────────────────────────────────────────── + +const ( + promptModeText = "text" + promptModeFile = "file" + promptModeGenerated = "generated" +) + +// ─── 向导状态 ───────────────────────────────────────────────────────────────── + +type wizardStep int + +const ( + wizardStepBasic wizardStep = 0 // 基础配置(名称、模式、协议) + wizardStepEndpoint wizardStep = 1 // 接口配置(URL、APIKey、模型) + wizardStepPrompt wizardStep = 2 // Prompt 配置(模式、内容、并发参数) +) + +// wizardState 向导的完整状态。 +type wizardState struct { + step wizardStep + editingID string // 非空表示编辑模式,存放被编辑任务的 ID + + // Step 0 + name string + turbo bool + protocol string // types.Protocol* 常量 + + // Step 1 + endpointURL string + apiKey string + model string + stream bool + thinking bool + + // Step 2 — Standard + concurrency int + count int + + // Step 2 — Turbo + initConcurrency int + maxConcurrency int + stepSize int + levelRequests int + minSuccessRate float64 + + // Prompt + promptMode string + promptText string + promptFile string + promptLength int + + // 当前活跃字段索引 + fieldIndex int +} + +// openWizard 打开向导。task==nil 表示新建,非 nil 表示编辑。 +func (m *Model) openWizard(task *types.TaskDefinition) { + if task == nil { + m.wizard = &wizardState{ + step: wizardStepBasic, + protocol: types.ProtocolOpenAICompletions, + concurrency: 10, + count: 100, + initConcurrency: 1, + maxConcurrency: 50, + stepSize: 5, + levelRequests: 20, + minSuccessRate: 95, + promptMode: promptModeText, + } + } else { + inp := task.Input + tc := inp.TurboConfig + m.wizard = &wizardState{ + step: wizardStepBasic, + editingID: task.ID, + name: task.Name, + turbo: inp.Turbo, + protocol: types.NormalizeProtocol(inp.Protocol), + endpointURL: inp.EndpointURL, + apiKey: inp.ApiKey, + model: inp.Model, + stream: inp.Stream, + thinking: inp.Thinking, + concurrency: inp.Concurrency, + count: inp.Count, + initConcurrency: tc.InitConcurrency, + maxConcurrency: tc.MaxConcurrency, + stepSize: tc.StepSize, + levelRequests: tc.LevelRequests, + minSuccessRate: tc.MinSuccessRate, + promptMode: inp.PromptMode, + promptText: inp.PromptText, + promptFile: inp.PromptFile, + promptLength: inp.PromptLength, + } + if m.wizard.promptMode == "" { + m.wizard.promptMode = promptModeText + } + if m.wizard.concurrency == 0 { + m.wizard.concurrency = 10 + } + if m.wizard.count == 0 { + m.wizard.count = 100 + } + } + m.view = viewWizard +} + +// ─── 按键处理 ───────────────────────────────────────────────────────────────── + +func (m *Model) handleWizardKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) { + wz := m.wizard + if wz == nil { + m.view = viewTaskList + return m, nil + } + + fields := m.wizardFields() + maxField := len(fields) - 1 + + switch msg.String() { + case "esc": + m.wizard = nil + m.view = viewTaskList + return m, nil + + case "tab", "down", "j": + if wz.fieldIndex < maxField { + wz.fieldIndex++ + } + + case "shift+tab", "up", "k": + if wz.fieldIndex > 0 { + wz.fieldIndex-- + } + + case "left", "right": + // 布尔/枚举切换 + m.wizardToggleField(fields, wz.fieldIndex, msg.String() == "right") + + case "enter": + // 如果在最后一个字段,或者按下 Enter 且是最后步骤,保存并运行 + if int(wz.step) == 2 && wz.fieldIndex == maxField { + return m, m.saveWizard(true) + } + // 否则 Next / 保存 + if wz.fieldIndex == maxField { + wz.step++ + wz.fieldIndex = 0 + } else { + wz.fieldIndex++ + } + + case "ctrl+s": + return m, m.saveWizard(false) + + case "ctrl+enter": + if int(wz.step) == 2 { + return m, m.saveWizard(true) + } + + case "backspace": + m.wizardBackspace(fields, wz.fieldIndex) + + default: + // 字符输入 + if len(msg.Runes) > 0 { + m.wizardInput(fields, wz.fieldIndex, string(msg.Runes)) + } + } + + return m, nil +} + +// ─── 字段定义 ───────────────────────────────────────────────────────────────── + +type fieldKind int + +const ( + fieldText fieldKind = iota // 自由文本输入 + fieldNumber // 数字 + fieldBool // 布尔开关 + fieldEnum // 枚举循环 + fieldAction // 动作按钮(保存/运行) +) + +type wizardField struct { + kind fieldKind + label string + getValue func(wz *wizardState) string + setValue func(wz *wizardState, s string) + options []string // 仅 fieldEnum 使用 +} + +// wizardFields 根据当前步骤和 turbo 模式动态返回字段列表。 +func (m *Model) wizardFields() []wizardField { + wz := m.wizard + if wz == nil { + return nil + } + switch wz.step { + case wizardStepBasic: + return []wizardField{ + {kind: fieldText, label: "名称", + getValue: func(wz *wizardState) string { return wz.name }, + setValue: func(wz *wizardState, s string) { wz.name = s }}, + {kind: fieldBool, label: "Turbo 模式", + getValue: func(wz *wizardState) string { return boolLabel(wz.turbo) }, + setValue: func(wz *wizardState, s string) { wz.turbo = (s == "true") }}, + {kind: fieldEnum, label: "协议", + options: []string{ + types.ProtocolOpenAICompletions, + types.ProtocolOpenAIResponses, + types.ProtocolAnthropicMessages, + }, + getValue: func(wz *wizardState) string { return wz.protocol }, + setValue: func(wz *wizardState, s string) { wz.protocol = s }}, + } + + case wizardStepEndpoint: + return []wizardField{ + {kind: fieldText, label: "接口地址 (可选)", + getValue: func(wz *wizardState) string { return wz.endpointURL }, + setValue: func(wz *wizardState, s string) { wz.endpointURL = s }}, + {kind: fieldText, label: "API Key", + getValue: func(wz *wizardState) string { return wz.apiKey }, + setValue: func(wz *wizardState, s string) { wz.apiKey = s }}, + {kind: fieldText, label: "模型", + getValue: func(wz *wizardState) string { return wz.model }, + setValue: func(wz *wizardState, s string) { wz.model = s }}, + {kind: fieldBool, label: "流式输出", + getValue: func(wz *wizardState) string { return boolLabel(wz.stream) }, + setValue: func(wz *wizardState, s string) { wz.stream = (s == "true") }}, + {kind: fieldBool, label: "Thinking 模式", + getValue: func(wz *wizardState) string { return boolLabel(wz.thinking) }, + setValue: func(wz *wizardState, s string) { wz.thinking = (s == "true") }}, + } + + case wizardStepPrompt: + base := []wizardField{ + {kind: fieldEnum, label: "Prompt 模式", + options: []string{promptModeText, promptModeFile, promptModeGenerated}, + getValue: func(wz *wizardState) string { return wz.promptMode }, + setValue: func(wz *wizardState, s string) { wz.promptMode = s }}, + } + switch wz.promptMode { + case promptModeFile: + base = append(base, wizardField{kind: fieldText, label: "文件路径", + getValue: func(wz *wizardState) string { return wz.promptFile }, + setValue: func(wz *wizardState, s string) { wz.promptFile = s }}) + case promptModeGenerated: + base = append(base, wizardField{kind: fieldNumber, label: "生成长度", + getValue: func(wz *wizardState) string { return fmt.Sprintf("%d", wz.promptLength) }, + setValue: func(wz *wizardState, s string) { + if n, err := strconv.Atoi(s); err == nil { + wz.promptLength = n + } + }}) + default: // text + base = append(base, wizardField{kind: fieldText, label: "Prompt 文本", + getValue: func(wz *wizardState) string { return wz.promptText }, + setValue: func(wz *wizardState, s string) { wz.promptText = s }}) + } + + if wz.turbo { + base = append(base, + wizardField{kind: fieldNumber, label: "初始并发", + getValue: func(wz *wizardState) string { return fmt.Sprintf("%d", wz.initConcurrency) }, + setValue: func(wz *wizardState, s string) { + if n, err := strconv.Atoi(s); err == nil && n > 0 { + wz.initConcurrency = n + } + }}, + wizardField{kind: fieldNumber, label: "最大并发", + getValue: func(wz *wizardState) string { return fmt.Sprintf("%d", wz.maxConcurrency) }, + setValue: func(wz *wizardState, s string) { + if n, err := strconv.Atoi(s); err == nil && n > 0 { + wz.maxConcurrency = n + } + }}, + wizardField{kind: fieldNumber, label: "步进大小", + getValue: func(wz *wizardState) string { return fmt.Sprintf("%d", wz.stepSize) }, + setValue: func(wz *wizardState, s string) { + if n, err := strconv.Atoi(s); err == nil && n > 0 { + wz.stepSize = n + } + }}, + wizardField{kind: fieldNumber, label: "每级请求数", + getValue: func(wz *wizardState) string { return fmt.Sprintf("%d", wz.levelRequests) }, + setValue: func(wz *wizardState, s string) { + if n, err := strconv.Atoi(s); err == nil && n > 0 { + wz.levelRequests = n + } + }}, + ) + } else { + base = append(base, + wizardField{kind: fieldNumber, label: "并发数", + getValue: func(wz *wizardState) string { return fmt.Sprintf("%d", wz.concurrency) }, + setValue: func(wz *wizardState, s string) { + if n, err := strconv.Atoi(s); err == nil && n > 0 { + wz.concurrency = n + } + }}, + wizardField{kind: fieldNumber, label: "请求总数", + getValue: func(wz *wizardState) string { return fmt.Sprintf("%d", wz.count) }, + setValue: func(wz *wizardState, s string) { + if n, err := strconv.Atoi(s); err == nil && n > 0 { + wz.count = n + } + }}, + ) + } + return base + } + return nil +} + +// wizardToggleField 处理 ←→ 按键的枚举/布尔切换。 +func (m *Model) wizardToggleField(fields []wizardField, idx int, forward bool) { + wz := m.wizard + if idx >= len(fields) { + return + } + f := fields[idx] + switch f.kind { + case fieldBool: + cur := f.getValue(wz) == "开启" + if forward { + f.setValue(wz, boolVal(!cur)) + } else { + f.setValue(wz, boolVal(cur)) + } + case fieldEnum: + cur := f.getValue(wz) + i := 0 + for j, o := range f.options { + if o == cur { + i = j + break + } + } + if forward { + i = (i + 1) % len(f.options) + } else { + i = (i - 1 + len(f.options)) % len(f.options) + } + f.setValue(wz, f.options[i]) + } +} + +func boolVal(v bool) string { + if v { + return "true" + } + return "false" +} + +// wizardInput 追加字符到文本/数字字段。 +func (m *Model) wizardInput(fields []wizardField, idx int, s string) { + wz := m.wizard + if idx >= len(fields) { + return + } + f := fields[idx] + if f.kind == fieldText { + cur := f.getValue(wz) + f.setValue(wz, cur+s) + } else if f.kind == fieldNumber { + cur := f.getValue(wz) + if len(s) == 1 && s[0] >= '0' && s[0] <= '9' { + f.setValue(wz, cur+s) + } + } +} + +// wizardBackspace 删除文本/数字字段末尾字符。 +func (m *Model) wizardBackspace(fields []wizardField, idx int) { + wz := m.wizard + if idx >= len(fields) { + return + } + f := fields[idx] + if f.kind == fieldText || f.kind == fieldNumber { + cur := []rune(f.getValue(wz)) + if len(cur) > 0 { + f.setValue(wz, string(cur[:len(cur)-1])) + } + } +} + +// ─── 保存 ───────────────────────────────────────────────────────────────────── + +// buildTaskInput 从 wizardState 构建 types.Input。 +func (m *Model) buildTaskInput() types.Input { + wz := m.wizard + inp := types.Input{ + Protocol: wz.protocol, + EndpointURL: wz.endpointURL, + ApiKey: wz.apiKey, + Model: wz.model, + Stream: wz.stream, + Thinking: wz.thinking, + Turbo: wz.turbo, + PromptMode: wz.promptMode, + PromptText: wz.promptText, + PromptFile: wz.promptFile, + PromptLength: wz.promptLength, + } + if wz.turbo { + inp.TurboConfig = types.TurboConfig{ + InitConcurrency: wz.initConcurrency, + MaxConcurrency: wz.maxConcurrency, + StepSize: wz.stepSize, + LevelRequests: wz.levelRequests, + MinSuccessRate: wz.minSuccessRate, + } + } else { + inp.Concurrency = wz.concurrency + inp.Count = wz.count + } + return inp +} + +// saveWizard 保存或创建任务。autoStart=true 表示成功后立刻运行。 +func (m *Model) saveWizard(autoStart bool) tea.Cmd { + wz := m.wizard + if wz == nil { + return nil + } + cfg := server.TaskConfig{ + Name: wz.name, + Input: m.buildTaskInput(), + } + m.wizard = nil + m.view = viewTaskList + + if wz.editingID != "" { + return m.client.UpdateTaskCmd(wz.editingID, cfg) + } + return m.client.CreateTaskCmd(cfg, autoStart) +} + +// ─── 渲染 ───────────────────────────────────────────────────────────────────── + +func (m *Model) renderWizard() string { + wz := m.wizard + if wz == nil { + return "" + } + + title := "新建任务" + if wz.editingID != "" { + title = "编辑任务" + } + + steps := []string{"基础配置", "接口配置", "参数配置"} + var stepParts []string + for i, s := range steps { + switch { + case i < int(wz.step): + stepParts = append(stepParts, m.styles.stepDone.Render("✓ "+s)) + case i == int(wz.step): + stepParts = append(stepParts, m.styles.stepActive.Render("▶ "+s)) + default: + stepParts = append(stepParts, m.styles.stepTodo.Render("○ "+s)) + } + } + stepLine := strings.Join(stepParts, " ") + + fields := m.wizardFields() + var fieldLines []string + for i, f := range fields { + label := fmt.Sprintf("%-16s", f.label) + val := f.getValue(wz) + if f.kind == fieldBool { + if wz.turbo && f.label == "Turbo 模式" { + val = m.styles.ok.Render("开启") + } else if !wz.turbo && f.label == "Turbo 模式" { + val = m.styles.muted.Render("关闭") + } + } + // mask API key display + if f.label == "API Key" && val != "" { + val = maskAPIKey(val) + } + + var line string + if i == wz.fieldIndex { + cursor := m.styles.cursor.Render("▶") + labelS := m.styles.sectionHead.Render(label) + valS := m.styles.fieldActive.Render(" " + val + " " + m.styles.cursor.Render("_")) + line = cursor + " " + labelS + " " + valS + } else { + labelS := m.styles.label.Render(label) + valS := m.styles.value.Render(val) + line = " " + labelS + " " + valS + } + fieldLines = append(fieldLines, line) + } + + // 底部操作提示 + var hints []string + hints = append(hints, m.styles.key.Render("[↑↓/Tab]")+" 切换字段") + hints = append(hints, m.styles.key.Render("[←→]")+" 切换选项") + hints = append(hints, m.styles.key.Render("[Ctrl+S]")+" 保存") + if int(wz.step) == 2 { + hints = append(hints, m.styles.key.Render("[Ctrl+Enter]")+" 保存并运行") + } else { + hints = append(hints, m.styles.key.Render("[Enter]")+" 下一步") + } + hints = append(hints, m.styles.key.Render("[Esc]")+" 取消") + hintLine := strings.Join(hints, " ") + + dialogW := m.width * 70 / 100 + if dialogW < 60 { + dialogW = 60 + } + if dialogW > m.width-4 { + dialogW = m.width - 4 + } + + inner := fmt.Sprintf("%s\n\n%s\n\n%s\n\n%s", + m.styles.sectionHead.Render(title), + stepLine, + strings.Join(fieldLines, "\n"), + hintLine, + ) + + dialog := m.styles.dialog.Width(dialogW).Render(inner) + totalH := m.height + dialogRendered := dialog + + // 垂直居中 + dialogH := strings.Count(dialogRendered, "\n") + 1 + padTop := (totalH - dialogH) / 2 + if padTop < 0 { + padTop = 0 + } + topPad := strings.Repeat("\n", padTop) + + // 水平居中(外层宽度补齐) + dialogLineW := lipgloss.Width(strings.Split(dialogRendered, "\n")[0]) + leftPad := (m.width - dialogLineW) / 2 + if leftPad < 0 { + leftPad = 0 + } + paddedLines := make([]string, 0) + for _, l := range strings.Split(dialogRendered, "\n") { + paddedLines = append(paddedLines, strings.Repeat(" ", leftPad)+l) + } + + return topPad + strings.Join(paddedLines, "\n") +} From 8ce078cdf09ab9a689b8d5597dfd4ce814bf6201 Mon Sep 17 00:00:00 2001 From: Alain Date: Sun, 17 May 2026 01:23:15 +0800 Subject: [PATCH 06/52] feat: add task wizard for creating and editing tasks with user-friendly interface - Implemented a multi-step wizard for task creation and editing in the TUI. - Introduced fields for basic information, testing parameters, and confirmation. - Added functionality to handle user input and navigation through the wizard. - Removed old styles file as styles are now integrated within the wizard implementation. --- internal/tui/contextbar.go | 95 ---- internal/tui/model.go | 411 ++++++++++++----- internal/tui/model_test.go | 81 ++-- internal/tui/page_dashboard.go | 317 -------------- internal/tui/page_reqdetail.go | 182 -------- internal/tui/page_taskdetail.go | 196 --------- internal/tui/page_tasklist.go | 412 ------------------ internal/tui/page_wizard.go | 569 ------------------------ internal/tui/pages/contextbar.go | 116 +++++ internal/tui/pages/dashboard.go | 379 ++++++++++++++++ internal/tui/pages/helpers.go | 279 ++++++++++++ internal/tui/pages/nav.go | 53 +++ internal/tui/pages/reqdetail.go | 272 ++++++++++++ internal/tui/pages/styles.go | 130 ++++++ internal/tui/pages/taskdetail.go | 291 +++++++++++++ internal/tui/pages/tasklist.go | 348 +++++++++++++++ internal/tui/pages/turbodash.go | 341 +++++++++++++++ internal/tui/pages/wizard.go | 727 +++++++++++++++++++++++++++++++ internal/tui/styles.go | 184 -------- 19 files changed, 3269 insertions(+), 2114 deletions(-) delete mode 100644 internal/tui/contextbar.go delete mode 100644 internal/tui/page_dashboard.go delete mode 100644 internal/tui/page_reqdetail.go delete mode 100644 internal/tui/page_taskdetail.go delete mode 100644 internal/tui/page_tasklist.go delete mode 100644 internal/tui/page_wizard.go create mode 100644 internal/tui/pages/contextbar.go create mode 100644 internal/tui/pages/dashboard.go create mode 100644 internal/tui/pages/helpers.go create mode 100644 internal/tui/pages/nav.go create mode 100644 internal/tui/pages/reqdetail.go create mode 100644 internal/tui/pages/styles.go create mode 100644 internal/tui/pages/taskdetail.go create mode 100644 internal/tui/pages/tasklist.go create mode 100644 internal/tui/pages/turbodash.go create mode 100644 internal/tui/pages/wizard.go delete mode 100644 internal/tui/styles.go diff --git a/internal/tui/contextbar.go b/internal/tui/contextbar.go deleted file mode 100644 index 10d84de..0000000 --- a/internal/tui/contextbar.go +++ /dev/null @@ -1,95 +0,0 @@ -package tui - -import ( - "fmt" - "strings" -) - -// contextBarItem 描述 Context Bar 中的一个可用操作。 -type contextBarItem struct { - key string - desc string -} - -// renderContextBar 渲染 Context Bar:紧贴 Footer 上方的动态操作提示行。 -// 若 items 为空则返回空字符串(不占空间)。 -func (m *Model) renderContextBar(items []contextBarItem) string { - if len(items) == 0 { - return "" - } - var parts []string - for _, item := range items { - parts = append(parts, fmt.Sprintf("%s %s", - m.styles.key.Render("["+item.key+"]"), - m.styles.muted.Render(item.desc), - )) - } - bar := " " + strings.Join(parts, " ") - barW := m.width - if barW < 1 { - barW = 80 - } - return m.styles.footer.Width(barW).Render(bar) -} - -// contextBarItems_taskList 返回任务列表页的 Context Bar 内容。 -func contextBarItems_taskList(isRunning bool) []contextBarItem { - if isRunning { - return []contextBarItem{ - {"Enter", "进入仪表盘"}, - {"s", "停止"}, - {"y", "复制"}, - } - } - return []contextBarItem{ - {"Enter", "查看详情"}, - {"r", "运行"}, - {"e", "编辑"}, - {"d", "删除"}, - {"y", "复制"}, - } -} - -// contextBarItems_taskDetail 返回任务详情页的 Context Bar 内容。 -func contextBarItems_taskDetail(hasHistory bool) []contextBarItem { - if hasHistory { - return []contextBarItem{ - {"r", "生成报告"}, - {"c", "复制摘要"}, - {"Enter", "再次运行"}, - {"e", "编辑"}, - } - } - return []contextBarItem{ - {"Enter", "运行"}, - {"e", "编辑"}, - {"y", "复制"}, - {"d", "删除"}, - } -} - -// contextBarItems_dashboard_nosel 仪表盘无选中请求。 -func contextBarItems_dashboard_nosel() []contextBarItem { - return []contextBarItem{ - {"s", "停止"}, - {"b", "后台运行"}, - {"r", "提前报告"}, - } -} - -// contextBarItems_dashboard_sel 仪表盘有选中请求。 -func contextBarItems_dashboard_sel() []contextBarItem { - return []contextBarItem{ - {"Enter", "查看请求详情"}, - {"↑↓", "选择请求"}, - {"s", "停止"}, - } -} - -// contextBarItems_reqdetail 请求详情页。 -func contextBarItems_reqdetail() []contextBarItem { - return []contextBarItem{ - {"b/Esc", "返回仪表盘"}, - {"←→", "上/下一条请求"}, - } -} diff --git a/internal/tui/model.go b/internal/tui/model.go index 7367f10..8129f7d 100644 --- a/internal/tui/model.go +++ b/internal/tui/model.go @@ -4,11 +4,11 @@ package tui import ( "fmt" - "strings" tea "github.com/charmbracelet/bubbletea" - "github.com/charmbracelet/lipgloss" "github.com/yinxulai/ait/internal/server" + "github.com/yinxulai/ait/internal/tui/pages" + "github.com/yinxulai/ait/internal/types" ) // ─── 视图状态 ───────────────────────────────────────────────────────────────── @@ -20,37 +20,40 @@ const ( viewTaskDetail viewState = "task-detail" viewWizard viewState = "wizard" viewDashboard viewState = "dashboard" + viewTurboDash viewState = "turbo-dash" viewReqDetail viewState = "req-detail" ) // ─── 根 Model ───────────────────────────────────────────────────────────────── // Model 是 BubbleTea 的根状态机。 -// 所有 Server 交互均通过 Client 发出 tea.Cmd;Model 不直接 import runner/task/turbo。 +// 所有 Server 交互均通过 Client 发出 tea.Cmd;Model 不直接 import runner/task/turbo 等下层包。 type Model struct { - client *Client - styles styles - width int - height int - view viewState - status string - err error - - // 页面局部状态 - taskList taskListState - hist *historyState // 任务详情页的历史 - wizard *wizardState // nil = 向导未打开 - dash *dashboardState // nil = 无活跃运行 - reqDetail *reqDetailState // nil = 不在请求详情页 + client *Client + styles pages.Styles + width int + height int + view viewState + prevView viewState // 向导叠加时记录背景视图 + status string + err error + + // 页面局部状态(由 pages 包管理) + taskList *pages.TaskListState + detail *pages.TaskDetailState + wizard *pages.WizardState + dash *pages.DashboardState + turboDash *pages.TurboDashState + reqDetail *pages.ReqDetailState } // NewModel 创建 Model。srv 不能为 nil。 func NewModel(srv server.Server) *Model { return &Model{ - client: NewClient(srv), - styles: newStyles(), - view: viewTaskList, - taskList: taskListState{selected: 0}, + client: NewClient(srv), + styles: pages.NewStyles(), + view: viewTaskList, + taskList: pages.NewTaskListState(), } } @@ -83,10 +86,12 @@ func (m *Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { // ── 任务列表加载完成 ── case TasksLoadedMsg: - m.taskList.tasks = msg.Tasks - // 调整选中项不越界 - if m.taskList.selected >= len(msg.Tasks) { - m.taskList.selected = max(len(msg.Tasks)-1, 0) + if m.taskList == nil { + m.taskList = pages.NewTaskListState() + } + m.taskList.Tasks = msg.Tasks + if m.taskList.Selected >= len(msg.Tasks) { + m.taskList.Selected = max(len(msg.Tasks)-1, 0) } m.status = "" m.err = nil @@ -95,8 +100,7 @@ func (m *Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { // ── 任务保存完成(新建或更新) ── case TaskSavedMsg: m.status = fmt.Sprintf("任务 %q 已保存", msg.Task.Name) - // 若 AutoStart 且无活跃运行,立刻发起运行 - if msg.AutoStart && (m.dash == nil || !m.dash.isRunning()) { + if msg.AutoStart && (m.dash == nil || !m.dash.IsRunning()) { return m, tea.Batch( m.client.LoadTasksCmd(), m.client.StartRunCmd(msg.Task.ID), @@ -112,20 +116,30 @@ func (m *Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { // ── 历史加载完成 ── case HistoryLoadedMsg: - m.hist = &historyState{taskID: msg.TaskID, history: msg.History} + if m.detail != nil && m.detail.Task.ID == msg.TaskID { + autoExpand := m.view == viewTaskDetail && len(msg.History) > 0 + m.detail = pages.UpdateTaskDetailHistory(m.detail, msg.History, autoExpand) + } return m, nil // ── 运行启动 ── case RunStartedMsg: ch, cancel, firstCmd := m.client.SubscribeCmd(msg.RunID) - m.dash = &dashboardState{ - runID: msg.RunID, - taskID: msg.TaskID, - eventCh: ch, - cancelFn: cancel, - reqSel: -1, + taskMode := m.getTaskMode(msg.TaskID) + if taskMode == "turbo" { + m.turboDash = pages.NewTurboDashState(msg.RunID, msg.TaskID) + m.turboDash.EventCh = ch + m.turboDash.CancelFn = cancel + m.view = viewTurboDash + } else { + m.dash = pages.NewDashboardState(msg.RunID, msg.TaskID) + m.dash.EventCh = ch + m.dash.CancelFn = cancel + m.view = viewDashboard + } + if m.taskList != nil { + m.taskList.ActiveRuns[msg.TaskID] = nil } - m.view = viewDashboard m.status = "" return m, firstCmd @@ -135,8 +149,11 @@ func (m *Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { // ── 运行状态快照(重入仪表盘时) ── case RunStateMsg: - if m.dash != nil && msg.State != nil && m.dash.runID == msg.State.RunID { - m.dash.runState = msg.State + if m.dash != nil && msg.State != nil && m.dash.RunID == msg.State.RunID { + m.dash.RunState = msg.State + } + if m.turboDash != nil && msg.State != nil && m.turboDash.RunID == msg.State.RunID { + m.turboDash.RunState = msg.State } return m, nil @@ -157,15 +174,18 @@ func (m *Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { func (m *Model) View() string { switch m.view { case viewTaskList: - return m.renderTaskList() + return pages.RenderTaskList(m.taskList, m.styles, m.width, m.height) case viewTaskDetail: - return m.renderTaskDetail() + return pages.RenderTaskDetail(m.detail, m.styles, m.width, m.height) case viewWizard: - return m.renderWizard() + bg := m.renderBgForWizard() + return pages.RenderWizard(m.wizard, bg, m.styles, m.width, m.height) case viewDashboard: - return m.renderDashboard() + return pages.RenderDashboard(m.dash, m.dashTaskName(), m.styles, m.width, m.height) + case viewTurboDash: + return pages.RenderTurboDash(m.turboDash, m.turboDashTaskName(), m.styles, m.width, m.height) case viewReqDetail: - return m.renderReqDetail() + return pages.RenderReqDetail(m.reqDetail, m.reqDetailTaskName(), m.styles, m.width, m.height) } return "未知视图" } @@ -175,142 +195,293 @@ func (m *Model) View() string { func (m *Model) handleKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) { switch m.view { case viewTaskList: - return m.handleTaskListKey(msg) + newState, cmd, nav := pages.HandleTaskListKey(m.taskList, msg, m.client) + m.taskList = newState + navCmd := m.handleNav(nav) + return m, tea.Batch(cmd, navCmd) + case viewTaskDetail: - return m, m.handleTaskDetailKey(msg) + newState, cmd, nav := pages.HandleTaskDetailKey(m.detail, msg, m.client) + m.detail = newState + navCmd := m.handleNav(nav) + return m, tea.Batch(cmd, navCmd) + case viewWizard: - return m.handleWizardKey(msg) + newState, cmd, nav := pages.HandleWizardKey(m.wizard, msg, m.client) + m.wizard = newState + navCmd := m.handleNav(nav) + return m, tea.Batch(cmd, navCmd) + case viewDashboard: - return m.handleDashboardKey(msg) + newState, cmd, nav := pages.HandleDashboardKey(m.dash, msg, m.client) + m.dash = newState + navCmd := m.handleNav(nav) + return m, tea.Batch(cmd, navCmd) + + case viewTurboDash: + newState, cmd, nav := pages.HandleTurboDashKey(m.turboDash, msg, m.client) + m.turboDash = newState + navCmd := m.handleNav(nav) + return m, tea.Batch(cmd, navCmd) + case viewReqDetail: - return m.handleReqDetailKey(msg) + newState, nav := pages.HandleReqDetailKey(m.reqDetail, msg) + m.reqDetail = newState + return m, m.handleNav(nav) } + return m, nil } +// ─── 导航处理 ───────────────────────────────────────────────────────────────── + +func (m *Model) handleNav(nav pages.NavAction) tea.Cmd { + switch nav.To { + case pages.NavNone: + return nil + + case pages.NavTaskList: + m.view = viewTaskList + return m.client.LoadTasksCmd() + + case pages.NavTaskDetail: + task := m.findTask(nav.TaskID) + if task != nil { + m.detail = pages.NewTaskDetailState(*task) + } else if m.detail == nil { + return nil + } + m.view = viewTaskDetail + if m.detail != nil { + return m.client.LoadHistoryCmd(m.detail.Task.ID, 10) + } + return nil + + case pages.NavWizard: + if nav.EditTask != nil { + m.wizard = pages.NewWizardStateEdit(nav.EditTask) + } else { + m.wizard = pages.NewWizardState() + } + m.prevView = m.view + m.view = viewWizard + return nil + + case pages.NavDashboard: + if m.dash != nil { + m.view = viewDashboard + } + return nil + + case pages.NavTurboDash: + if m.turboDash != nil { + m.view = viewTurboDash + } + return nil + + case pages.NavReqDetail: + reqs := m.collectRequests() + m.reqDetail = pages.NewReqDetailState(m.currentRunID(), reqs, nav.ReqIndex) + m.view = viewReqDetail + return nil + + case pages.NavQuit: + return tea.Quit + } + return nil +} + // ─── Server 事件处理 ────────────────────────────────────────────────────────── func (m *Model) handleServerEvent(msg ServerEventMsg) (tea.Model, tea.Cmd) { - if m.dash == nil { + e := msg.Event + + isDash := m.dash != nil && m.dash.RunID == e.RunID + isTurbo := m.turboDash != nil && m.turboDash.RunID == e.RunID + + if !isDash && !isTurbo { return m, nil } - e := msg.Event switch e.Kind { - case server.EventProgressTick: + case server.EventProgressTick, server.EventRequestDone, server.EventLevelDone: if rs, ok := e.Payload.(*server.RunState); ok { - m.dash.runState = rs + if isDash { + m.dash.RunState = rs + } else { + m.turboDash.RunState = rs + } + m.injectRunState(rs) } - case server.EventRequestDone: + case server.EventRunComplete: if rs, ok := e.Payload.(*server.RunState); ok { - m.dash.runState = rs + if isDash { + m.dash.RunState = rs + } else { + m.turboDash.RunState = rs + } } - - case server.EventLevelDone: - if rs, ok := e.Payload.(*server.RunState); ok { - m.dash.runState = rs + taskID := m.currentRunTaskID(isDash) + if m.taskList != nil { + delete(m.taskList.ActiveRuns, taskID) } - - case server.EventRunComplete: - if rs, ok := e.Payload.(*server.RunState); ok { - m.dash.runState = rs + task := m.findTask(taskID) + if task != nil { + m.detail = pages.NewTaskDetailState(*task) } - // 运行结束后保留 dash 供用户查阅,切换到详情页 m.view = viewTaskDetail return m, tea.Batch( m.client.LoadTasksCmd(), - m.client.LoadHistoryCmd(m.dash.taskID, 10), + m.client.LoadHistoryCmd(taskID, 10), ) case server.EventRunFailed: + var errorMsg string if rs, ok := e.Payload.(*server.RunState); ok { - m.dash.runState = rs + if isDash { + m.dash.RunState = rs + } else { + m.turboDash.RunState = rs + } + errorMsg = rs.ErrorMsg + } + if errorMsg == "" { + errorMsg = "运行异常终止" + } + m.err = fmt.Errorf("运行失败: %s", errorMsg) + taskID := m.currentRunTaskID(isDash) + if m.taskList != nil { + delete(m.taskList.ActiveRuns, taskID) + } + task := m.findTask(taskID) + if task != nil { + m.detail = pages.NewTaskDetailState(*task) } - m.err = fmt.Errorf("运行失败: %s", m.dash.runState.ErrorMsg) m.view = viewTaskDetail return m, tea.Batch( m.client.LoadTasksCmd(), - m.client.LoadHistoryCmd(m.dash.taskID, 10), + m.client.LoadHistoryCmd(taskID, 10), ) } - // 若 eventCh 还在,继续等待下一条事件 - if m.dash.eventCh != nil { - return m, WaitEventCmd(m.dash.eventCh) + // 继续等待下一条事件 + var ch <-chan server.Event + if isDash && m.dash.EventCh != nil { + ch = m.dash.EventCh + } else if isTurbo && m.turboDash.EventCh != nil { + ch = m.turboDash.EventCh + } + if ch != nil { + return m, WaitEventCmd(ch) } return m, nil } -// ─── 共享渲染工具 ───────────────────────────────────────────────────────────── +// ─── 辅助方法 ───────────────────────────────────────────────────────────────── -// renderHeader 渲染顶部状态栏(全宽,左侧标题 + 右侧信息)。 -func (m *Model) renderHeader(title, right string) string { - w := m.width - if w < 1 { - w = 80 +func (m *Model) getTaskMode(taskID string) string { + t := m.findTask(taskID) + if t != nil && t.Input.Turbo { + return "turbo" } - titleW := lipgloss.Width(title) - rightW := lipgloss.Width(right) - pad := w - titleW - rightW - 2 - if pad < 1 { - pad = 1 + return "standard" +} + +func (m *Model) findTask(taskID string) *types.TaskDefinition { + if m.taskList == nil { + return nil } - line := " " + title + strings.Repeat(" ", pad) + right + " " - // 截断 - if lipgloss.Width(line) > w { - line = line[:w] + for i := range m.taskList.Tasks { + if m.taskList.Tasks[i].ID == taskID { + return &m.taskList.Tasks[i] + } } - return m.styles.header.Width(w).Render(line) + return nil } -// renderFooter 渲染底部状态栏(全宽)。 -func (m *Model) renderFooter(parts ...string) string { - w := m.width - if w < 1 { - w = 80 +func (m *Model) injectRunState(rs *server.RunState) { + if m.taskList == nil || rs == nil { + return } - var visible []string - for _, p := range parts { - if p != "" { - visible = append(visible, p) - } + if rs.Status == server.RunStatusRunning { + m.taskList.ActiveRuns[rs.TaskID] = rs + } else { + delete(m.taskList.ActiveRuns, rs.TaskID) } - line := " " + strings.Join(visible, " │ ") - return m.styles.footer.Width(w).Render(line) } -// dualColumnLayout 将左右内容放入双列布局,高度限制为 maxH。 -func (m *Model) dualColumnLayout(left, right string, leftW, rightW, maxH int) string { - leftLines := strings.Split(left, "\n") - rightLines := strings.Split(right, "\n") +func (m *Model) renderBgForWizard() string { + if m.prevView == viewTaskDetail { + return pages.RenderTaskDetail(m.detail, m.styles, m.width, m.height) + } + return pages.RenderTaskList(m.taskList, m.styles, m.width, m.height) +} - // 裁剪至 maxH - if len(leftLines) > maxH { - leftLines = leftLines[:maxH] +func (m *Model) dashTaskName() string { + if m.dash == nil { + return "─" } - if len(rightLines) > maxH { - rightLines = rightLines[:maxH] + t := m.findTask(m.dash.TaskID) + if t != nil { + return t.Name } - // 补齐行数 - for len(leftLines) < maxH { - leftLines = append(leftLines, "") + return m.dash.TaskID +} + +func (m *Model) turboDashTaskName() string { + if m.turboDash == nil { + return "─" } - for len(rightLines) < maxH { - rightLines = append(rightLines, "") + t := m.findTask(m.turboDash.TaskID) + if t != nil { + return t.Name } + return m.turboDash.TaskID +} - var rows []string - for i := 0; i < maxH; i++ { - lLine := leftLines[i] - rLine := rightLines[i] - lW := lipgloss.Width(lLine) - if lW < leftW { - lLine += strings.Repeat(" ", leftW-lW) +func (m *Model) reqDetailTaskName() string { + if m.dash != nil { + if t := m.findTask(m.dash.TaskID); t != nil { + return t.Name + } + } + if m.turboDash != nil { + if t := m.findTask(m.turboDash.TaskID); t != nil { + return t.Name } - rows = append(rows, lLine+" "+rLine) } - return strings.Join(rows, "\n") + return "─" +} + +func (m *Model) currentRunID() server.RunID { + if m.dash != nil { + return m.dash.RunID + } + if m.turboDash != nil { + return m.turboDash.RunID + } + return "" +} + +func (m *Model) currentRunTaskID(isDash bool) string { + if isDash && m.dash != nil { + return m.dash.TaskID + } + if m.turboDash != nil { + return m.turboDash.TaskID + } + return "" +} + +func (m *Model) collectRequests() []*server.RequestMetrics { + if m.dash != nil && m.dash.RunState != nil { + return m.dash.RunState.Requests + } + if m.turboDash != nil && m.turboDash.RunState != nil { + return m.turboDash.RunState.Requests + } + return nil } // ─── 工具 ───────────────────────────────────────────────────────────────────── diff --git a/internal/tui/model_test.go b/internal/tui/model_test.go index a03ffea..03ee1b5 100644 --- a/internal/tui/model_test.go +++ b/internal/tui/model_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/yinxulai/ait/internal/server" + "github.com/yinxulai/ait/internal/tui/pages" "github.com/yinxulai/ait/internal/types" ) @@ -43,22 +44,22 @@ func TestNewModel_InitialState(t *testing.T) { } } -// ─── Wizard: openWizard + buildTaskInput ────────────────────────────────────── +// ─── Wizard: NewWizardState + BuildTaskConfig ────────────────────────────────── func TestOpenWizard_NewTask_Defaults(t *testing.T) { m := NewModel(&stubServer{}) - m.openWizard(nil) + m.wizard = pages.NewWizardState() if m.wizard == nil { - t.Fatal("wizard should not be nil after openWizard") + t.Fatal("wizard should not be nil after NewWizardState") } - if m.wizard.editingID != "" { - t.Errorf("new task wizard should have empty editingID, got %q", m.wizard.editingID) + if m.wizard.EditingID != "" { + t.Errorf("new task wizard should have empty EditingID, got %q", m.wizard.EditingID) } - if m.wizard.concurrency <= 0 { - t.Errorf("default concurrency should be positive, got %d", m.wizard.concurrency) + if m.wizard.Concurrency <= 0 { + t.Errorf("default concurrency should be positive, got %d", m.wizard.Concurrency) } - if m.wizard.promptMode != promptModeText { - t.Errorf("default promptMode = %q, want %q", m.wizard.promptMode, promptModeText) + if m.wizard.PromptMode != pages.PromptModeText { + t.Errorf("default PromptMode = %q, want %q", m.wizard.PromptMode, pages.PromptModeText) } } @@ -73,37 +74,38 @@ func TestOpenWizard_EditTask_Populate(t *testing.T) { ApiKey: "sk-test", Concurrency: 5, Count: 50, - PromptMode: promptModeText, + PromptMode: pages.PromptModeText, PromptText: "hello", }, } - m.openWizard(&task) + m.wizard = pages.NewWizardStateEdit(&task) if m.wizard == nil { t.Fatal("wizard should not be nil") } - if m.wizard.editingID != "task-123" { - t.Errorf("editingID = %q, want %q", m.wizard.editingID, "task-123") + if m.wizard.EditingID != "task-123" { + t.Errorf("EditingID = %q, want %q", m.wizard.EditingID, "task-123") } - if m.wizard.model != "gpt-4" { - t.Errorf("model = %q, want %q", m.wizard.model, "gpt-4") + if m.wizard.Model != "gpt-4" { + t.Errorf("Model = %q, want %q", m.wizard.Model, "gpt-4") } - if m.wizard.concurrency != 5 { - t.Errorf("concurrency = %d, want 5", m.wizard.concurrency) + if m.wizard.Concurrency != 5 { + t.Errorf("Concurrency = %d, want 5", m.wizard.Concurrency) } } func TestBuildTaskInput_Standard(t *testing.T) { m := NewModel(&stubServer{}) - m.openWizard(nil) + m.wizard = pages.NewWizardState() wz := m.wizard - wz.model = "gpt-4.1" - wz.apiKey = "sk-test" - wz.concurrency = 8 - wz.count = 120 - wz.promptMode = promptModeText - wz.promptText = "hello" + wz.Model = "gpt-4.1" + wz.APIKey = "sk-test" + wz.Concurrency = 8 + wz.Count = 120 + wz.PromptMode = pages.PromptModeText + wz.PromptText = "hello" - inp := m.buildTaskInput() + cfg := wz.BuildTaskConfig() + inp := cfg.Input if inp.Model != "gpt-4.1" { t.Errorf("model = %q, want gpt-4.1", inp.Model) } @@ -113,7 +115,7 @@ func TestBuildTaskInput_Standard(t *testing.T) { if inp.Count != 120 { t.Errorf("count = %d, want 120", inp.Count) } - if inp.PromptMode != promptModeText || inp.PromptText != "hello" { + if inp.PromptMode != pages.PromptModeText || inp.PromptText != "hello" { t.Errorf("unexpected prompt config: mode=%q text=%q", inp.PromptMode, inp.PromptText) } if inp.Turbo { @@ -123,27 +125,28 @@ func TestBuildTaskInput_Standard(t *testing.T) { func TestBuildTaskInput_Turbo(t *testing.T) { m := NewModel(&stubServer{}) - m.openWizard(nil) + m.wizard = pages.NewWizardState() wz := m.wizard - wz.model = "claude-3-7-sonnet" - wz.apiKey = "sk-ant" - wz.protocol = types.ProtocolAnthropicMessages - wz.turbo = true - wz.initConcurrency = 1 - wz.maxConcurrency = 12 - wz.stepSize = 2 - wz.levelRequests = 20 - wz.promptMode = promptModeGenerated - wz.promptLength = 256 + wz.Model = "claude-3-7-sonnet" + wz.APIKey = "sk-ant" + wz.Protocol = types.ProtocolAnthropicMessages + wz.Turbo = true + wz.InitConcurrency = 1 + wz.MaxConcurrency = 12 + wz.StepSize = 2 + wz.LevelRequests = 20 + wz.PromptMode = pages.PromptModeGenerated + wz.PromptLength = 256 - inp := m.buildTaskInput() + cfg := wz.BuildTaskConfig() + inp := cfg.Input if !inp.Turbo { t.Error("expected Turbo=true") } if inp.TurboConfig.MaxConcurrency != 12 { t.Errorf("MaxConcurrency = %d, want 12", inp.TurboConfig.MaxConcurrency) } - if inp.PromptMode != promptModeGenerated || inp.PromptLength != 256 { + if inp.PromptMode != pages.PromptModeGenerated || inp.PromptLength != 256 { t.Errorf("unexpected prompt config: mode=%q len=%d", inp.PromptMode, inp.PromptLength) } if inp.Protocol != types.ProtocolAnthropicMessages { diff --git a/internal/tui/page_dashboard.go b/internal/tui/page_dashboard.go deleted file mode 100644 index 4f982c1..0000000 --- a/internal/tui/page_dashboard.go +++ /dev/null @@ -1,317 +0,0 @@ -package tui - -import ( - "fmt" - "strings" - - tea "github.com/charmbracelet/bubbletea" - "github.com/yinxulai/ait/internal/server" -) - -// dashboardState 仪表盘页的局部状态。 -type dashboardState struct { - runID server.RunID - taskID string - eventCh <-chan server.Event // nil 表示已后台/已结束 - cancelFn server.CancelFunc - runState *server.RunState - reqSel int // 选中请求的 index(-1 = 无选中) - reqOff int // 请求列表滚动偏移 -} - -// isRunning 判断仪表盘内的运行是否仍在进行中。 -func (d *dashboardState) isRunning() bool { - if d == nil || d.runState == nil { - return false - } - return d.runState.Status == server.RunStatusRunning -} - -// ─── 按键处理 ───────────────────────────────────────────────────────────────── - -func (m *Model) handleDashboardKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) { - if m.dash == nil { - m.view = viewTaskList - return m, nil - } - d := m.dash - var reqs []*server.RequestMetrics - if d.runState != nil { - reqs = d.runState.Requests - } - - switch msg.String() { - case "up", "k": - if d.reqSel > 0 { - d.reqSel-- - } else if len(reqs) > 0 { - d.reqSel = len(reqs) - 1 - } - m.adjustReqOffset() - - case "down", "j": - if d.reqSel < len(reqs)-1 { - d.reqSel++ - } else { - d.reqSel = 0 - } - m.adjustReqOffset() - - case "enter": - if d.reqSel >= 0 && d.reqSel < len(reqs) { - m.reqDetail = &reqDetailState{ - runID: d.runID, - requests: reqs, - index: d.reqSel, - } - m.view = viewReqDetail - } - - case "s": - if d.isRunning() { - return m, m.client.StopRunCmd(d.runID) - } - - case "b": - // 后台运行:取消订阅,返回任务列表,保留 dash 状态 - if d.cancelFn != nil { - d.cancelFn() - } - d.eventCh = nil - d.cancelFn = nil - m.view = viewTaskList - m.status = fmt.Sprintf("运行 %s 已转入后台", d.runID) - - case "r": - // 生成报告 - if d.runState != nil && !d.isRunning() { - return m, m.client.GenerateReportCmd(d.runID, server.ReportFormatJSON) - } - - case "left", "esc": - if !d.isRunning() { - // 运行已结束,直接返回任务详情 - if d.cancelFn != nil { - d.cancelFn() - } - m.dash = nil - m.view = viewTaskDetail - } - - case "q": - return m, tea.Quit - } - - return m, nil -} - -// adjustReqOffset 根据 reqSel 调整列表的可见窗口。 -func (m *Model) adjustReqOffset() { - if m.dash == nil { - return - } - visH := m.height - 10 - if visH < 5 { - visH = 5 - } - sel := m.dash.reqSel - off := m.dash.reqOff - if sel < off { - off = sel - } else if sel >= off+visH { - off = sel - visH + 1 - } - m.dash.reqOff = off -} - -// ─── 渲染 ───────────────────────────────────────────────────────────────────── - -func (m *Model) renderDashboard() string { - if m.dash == nil || m.width == 0 { - return "加载中..." - } - d := m.dash - rs := d.runState - - statusStr := "等待中" - if rs != nil { - switch rs.Status { - case server.RunStatusRunning: - statusStr = m.styles.ok.Render("运行中") - case server.RunStatusCompleted: - statusStr = m.styles.ok.Render("已完成") - case server.RunStatusFailed: - statusStr = m.styles.errStyle.Render("失败") - case server.RunStatusStopped: - statusStr = m.styles.muted.Render("已停止") - } - } - header := m.renderHeader("AIT 仪表盘", statusStr) - - var cbItems []contextBarItem - if d.reqSel >= 0 { - cbItems = contextBarItems_dashboard_sel() - } else { - cbItems = contextBarItems_dashboard_nosel() - } - contextBar := m.renderContextBar(cbItems) - - var footerRight string - if rs != nil && rs.TotalReqs > 0 { - footerRight = fmt.Sprintf("%d/%d 请求", rs.DoneReqs, rs.TotalReqs) - } - footer := m.renderFooter("[s] 停止", "[b] 后台", "[r] 报告", footerRight) - - cbH := 0 - if contextBar != "" { - cbH = 1 - } - contentH := m.height - 1 - cbH - 1 - if contentH < 4 { - contentH = 4 - } - - leftW := (m.width - 4) * 55 / 100 - rightW := m.width - 4 - leftW - - leftContent := m.buildDashLeft(contentH, leftW) - rightContent := m.buildDashRight(contentH) - mid := m.dualColumnLayout(leftContent, rightContent, leftW, rightW, contentH) - - parts := []string{header, mid} - if contextBar != "" { - parts = append(parts, contextBar) - } - parts = append(parts, footer) - return strings.Join(parts, "\n") -} - -func (m *Model) buildDashLeft(maxH, width int) string { - d := m.dash - rs := d.runState - var lines []string - - // ── 汇总指标 ── - lines = append(lines, m.styles.sectionHead.Render("实时指标")) - lines = append(lines, "") - - if rs == nil { - lines = append(lines, m.styles.muted.Render(" 等待数据...")) - return strings.Join(lines, "\n") - } - - // 进度条 - pbW := width - 20 - if pbW < 10 { - pbW = 10 - } - pct := 0.0 - if rs.TotalReqs > 0 { - pct = float64(rs.DoneReqs) * 100 / float64(rs.TotalReqs) - } - pb := progressBar(rs.DoneReqs, rs.TotalReqs, pbW) - lines = append(lines, fmt.Sprintf(" 进度 %s %5.1f%%", pb, pct)) - lines = append(lines, "") - - lines = append(lines, row(m, "总请求数 ", fmt.Sprintf("%d", rs.TotalReqs))) - lines = append(lines, row(m, "已完成 ", fmt.Sprintf("%d", rs.DoneReqs))) - lines = append(lines, row(m, "成功 ", m.styles.ok.Render(fmt.Sprintf("%d", rs.SuccessReqs)))) - lines = append(lines, row(m, "失败 ", m.styles.errStyle.Render(fmt.Sprintf("%d", rs.FailedReqs)))) - lines = append(lines, row(m, "成功率 ", fmt.Sprintf("%.2f%%", rs.SuccessRate))) - lines = append(lines, "") - lines = append(lines, m.styles.sectionHead.Render("性能指标")) - lines = append(lines, "") - lines = append(lines, row(m, "平均 TPS ", m.styles.metricVal.Render(fmt.Sprintf("%.2f tok/s", rs.AvgTPS)))) - lines = append(lines, row(m, "平均 TTFT ", m.styles.metricVal.Render(fmt.Sprintf("%.0f ms", float64(rs.AvgTTFT.Milliseconds()))))) - lines = append(lines, row(m, "缓存命中率 ", fmt.Sprintf("%.2f%%", rs.CacheHitRate))) - - if rs.Mode == "turbo" && len(rs.Levels) > 0 { - lines = append(lines, "") - lines = append(lines, m.styles.sectionHead.Render("Turbo 并发探测")) - lines = append(lines, "") - for i, lv := range rs.Levels { - sel := "" - if i == rs.CurrentLevel { - sel = m.styles.ok.Render("▶") - } else { - sel = " " - } - stableStr := m.styles.muted.Render("探测中") - if lv.Stable { - stableStr = m.styles.ok.Render("稳定") - } else if lv.StopReason != "" { - stableStr = m.styles.errStyle.Render("停止") - } - lines = append(lines, fmt.Sprintf("%s 并发%3d TPS %5.1f 成功率 %5.1f%% %s", - sel, lv.Concurrency, lv.AvgTPS, lv.SuccessRate, stableStr)) - } - } - - if rs.ErrorMsg != "" { - lines = append(lines, "") - lines = append(lines, m.styles.errStyle.Render("错误: "+truncate(rs.ErrorMsg, width-10))) - } - - return strings.Join(lines, "\n") -} - -func (m *Model) buildDashRight(maxH int) string { - d := m.dash - rs := d.runState - var lines []string - - lines = append(lines, m.styles.sectionHead.Render("请求列表")) - lines = append(lines, "") - lines = append(lines, m.styles.tableHead.Render( - fmt.Sprintf(" %-4s %-5s %8s %8s %8s %-6s", "#", "状态", "耗时", "TTFT", "TPS", "Token"), - )) - lines = append(lines, m.styles.muted.Render(strings.Repeat("─", 56))) - - if rs == nil || len(rs.Requests) == 0 { - lines = append(lines, m.styles.muted.Render(" 暂无请求...")) - return strings.Join(lines, "\n") - } - - visH := maxH - len(lines) - 1 - if visH < 1 { - visH = 1 - } - - start := d.reqOff - if start < 0 { - start = 0 - } - end := start + visH - if end > len(rs.Requests) { - end = len(rs.Requests) - } - - for i := start; i < end; i++ { - r := rs.Requests[i] - statusIcon := m.styles.ok.Render("✓") - if !r.Success { - statusIcon = m.styles.errStyle.Render("✗") - } - line := fmt.Sprintf(" %3d %s %7dms %7dms %7.1f %-6d", - r.Index+1, - statusIcon, - r.TotalTime.Milliseconds(), - r.TTFT.Milliseconds(), - r.TPS, - r.CompletionTokens, - ) - if i == d.reqSel { - lines = append(lines, m.styles.tableRowSel.Render(line)) - } else { - lines = append(lines, line) - } - } - - // 滚动提示 - if len(rs.Requests) > visH { - lines = append(lines, m.styles.muted.Render( - fmt.Sprintf(" %d/%d 请求 [↑↓] 滚动", len(rs.Requests), len(rs.Requests)))) - } - - return strings.Join(lines, "\n") -} diff --git a/internal/tui/page_reqdetail.go b/internal/tui/page_reqdetail.go deleted file mode 100644 index 199852b..0000000 --- a/internal/tui/page_reqdetail.go +++ /dev/null @@ -1,182 +0,0 @@ -package tui - -import ( - "fmt" - "strings" - - tea "github.com/charmbracelet/bubbletea" - "github.com/yinxulai/ait/internal/server" -) - -// reqDetailState 请求详情页的状态。 -type reqDetailState struct { - runID server.RunID - requests []*server.RequestMetrics - index int -} - -// ─── 按键处理 ───────────────────────────────────────────────────────────────── - -func (m *Model) handleReqDetailKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) { - rd := m.reqDetail - if rd == nil { - m.view = viewDashboard - return m, nil - } - - switch msg.String() { - case "left", "h": - if rd.index > 0 { - rd.index-- - } else { - rd.index = len(rd.requests) - 1 - } - case "right", "l": - if rd.index < len(rd.requests)-1 { - rd.index++ - } else { - rd.index = 0 - } - case "b", "esc", "backspace": - m.view = viewDashboard - case "q": - return m, tea.Quit - } - return m, nil -} - -// ─── 渲染 ───────────────────────────────────────────────────────────────────── - -func (m *Model) renderReqDetail() string { - rd := m.reqDetail - if rd == nil || m.width == 0 { - return "加载中..." - } - if len(rd.requests) == 0 { - return "无请求数据" - } - - idx := rd.index - if idx < 0 { - idx = 0 - } - if idx >= len(rd.requests) { - idx = len(rd.requests) - 1 - } - r := rd.requests[idx] - - header := m.renderHeader( - fmt.Sprintf("AIT 请求详情 #%d / %d", idx+1, len(rd.requests)), - statusStr(m, r), - ) - contextBar := m.renderContextBar(contextBarItems_reqdetail()) - footer := m.renderFooter("[←→] 切换", "[Esc] 返回仪表盘", "", "◆ AIT") - - cbH := 0 - if contextBar != "" { - cbH = 1 - } - contentH := m.height - 1 - cbH - 1 - if contentH < 4 { - contentH = 4 - } - - leftW := (m.width - 4) * 50 / 100 - rightW := m.width - 4 - leftW - - leftContent := m.buildReqLeft(r, contentH, leftW) - rightContent := m.buildReqRight(r, contentH) - mid := m.dualColumnLayout(leftContent, rightContent, leftW, rightW, contentH) - - parts := []string{header, mid} - if contextBar != "" { - parts = append(parts, contextBar) - } - parts = append(parts, footer) - return strings.Join(parts, "\n") -} - -func statusStr(m *Model, r *server.RequestMetrics) string { - if r.Success { - return m.styles.ok.Render("成功") - } - return m.styles.errStyle.Render("失败") -} - -func (m *Model) buildReqLeft(r *server.RequestMetrics, maxH, width int) string { - var lines []string - lines = append(lines, m.styles.sectionHead.Render("时间指标")) - lines = append(lines, "") - lines = append(lines, row(m, "总耗时 ", fmt.Sprintf("%d ms", r.TotalTime.Milliseconds()))) - lines = append(lines, row(m, "TTFT ", fmt.Sprintf("%d ms", r.TTFT.Milliseconds()))) - lines = append(lines, row(m, "TPS ", m.styles.metricVal.Render(fmt.Sprintf("%.2f tok/s", r.TPS)))) - lines = append(lines, "") - lines = append(lines, m.styles.sectionHead.Render("Token 统计")) - lines = append(lines, "") - lines = append(lines, row(m, "Prompt Tok ", fmt.Sprintf("%d", r.PromptTokens))) - lines = append(lines, row(m, "Output Tok ", fmt.Sprintf("%d", r.CompletionTokens))) - lines = append(lines, row(m, "缓存命中 ", fmt.Sprintf("%d tok (%.1f%%)", r.CachedTokens, r.CacheHitRate*100))) - lines = append(lines, "") - - if r.ErrorMessage != "" { - lines = append(lines, m.styles.sectionHead.Render("错误信息")) - lines = append(lines, "") - for _, part := range wrapText(r.ErrorMessage, width-4) { - lines = append(lines, m.styles.errStyle.Render(" "+part)) - } - } - - if r.PromptText != "" { - lines = append(lines, "") - lines = append(lines, m.styles.sectionHead.Render("Prompt")) - lines = append(lines, "") - for _, part := range wrapText(r.PromptText, width-4) { - if len(lines) >= maxH-1 { - break - } - lines = append(lines, " "+part) - } - } - return strings.Join(lines, "\n") -} - -func (m *Model) buildReqRight(r *server.RequestMetrics, maxH int) string { - var lines []string - lines = append(lines, m.styles.sectionHead.Render("网络指标")) - lines = append(lines, "") - lines = append(lines, row(m, "目标 IP ", r.TargetIP)) - lines = append(lines, row(m, "DNS 解析 ", fmt.Sprintf("%d ms", r.DNSTime.Milliseconds()))) - lines = append(lines, row(m, "TCP 连接 ", fmt.Sprintf("%d ms", r.ConnectTime.Milliseconds()))) - lines = append(lines, row(m, "TLS 握手 ", fmt.Sprintf("%d ms", r.TLSTime.Milliseconds()))) - lines = append(lines, "") - - if r.ResponseText != "" { - lines = append(lines, m.styles.sectionHead.Render("Response")) - lines = append(lines, "") - for _, part := range wrapText(r.ResponseText, 40) { - if len(lines) >= maxH-1 { - break - } - lines = append(lines, " "+part) - } - } - return strings.Join(lines, "\n") -} - -// wrapText 按宽度折行(简单按字节宽度,不处理 CJK)。 -func wrapText(s string, width int) []string { - if width <= 0 { - return []string{s} - } - var result []string - runes := []rune(s) - for len(runes) > 0 { - end := width - if end > len(runes) { - end = len(runes) - } - result = append(result, string(runes[:end])) - runes = runes[end:] - } - return result -} diff --git a/internal/tui/page_taskdetail.go b/internal/tui/page_taskdetail.go deleted file mode 100644 index 1720277..0000000 --- a/internal/tui/page_taskdetail.go +++ /dev/null @@ -1,196 +0,0 @@ -package tui - -import ( - "fmt" - "strings" - "time" - - tea "github.com/charmbracelet/bubbletea" - "github.com/yinxulai/ait/internal/types" -) - -// historyState 任务详情页的历史数据。 -type historyState struct { - taskID string - history []types.TaskRunSummary -} - -// ─── 渲染 ───────────────────────────────────────────────────────────────────── - -func (m *Model) renderTaskDetail() string { - if m.width == 0 { - return "加载中..." - } - - task, ok := m.taskList.currentTask() - if !ok { - return "请先选择任务" - } - - header := m.renderHeader( - "AIT 任务详情", - task.Name, - ) - - var cbItems []contextBarItem - if m.hist != nil { - cbItems = contextBarItems_taskDetail(len(m.hist.history) > 0) - } else { - cbItems = contextBarItems_taskDetail(false) - } - contextBar := m.renderContextBar(cbItems) - footer := m.renderFooter("[←/Esc] 返回", "[r] 运行", "[e] 编辑", "◆ AIT") - - cbH := 0 - if contextBar != "" { - cbH = 1 - } - contentH := m.height - 1 - cbH - 1 - if contentH < 4 { - contentH = 4 - } - - leftW := (m.width - 4) * 55 / 100 - rightW := m.width - 4 - leftW - - leftContent := m.buildDetailLeft(task, contentH, leftW) - rightContent := m.buildDetailRight(contentH) - mid := m.dualColumnLayout(leftContent, rightContent, leftW, rightW, contentH) - - parts := []string{header, mid} - if contextBar != "" { - parts = append(parts, contextBar) - } - parts = append(parts, footer) - return strings.Join(parts, "\n") -} - -func (m *Model) buildDetailLeft(task types.TaskDefinition, maxH, width int) string { - inp := task.Input - var lines []string - - lines = append(lines, m.styles.sectionHead.Render("基本配置")) - lines = append(lines, "") - lines = append(lines, row(m, "名称 ", task.Name)) - lines = append(lines, row(m, "创建时间 ", task.CreatedAt.Format("2006-01-02 15:04:05"))) - lines = append(lines, row(m, "更新时间 ", task.UpdatedAt.Format("2006-01-02 15:04:05"))) - if task.LastRunAt != nil { - lines = append(lines, row(m, "上次运行 ", timeAgo(*task.LastRunAt))) - } - lines = append(lines, "") - lines = append(lines, m.styles.sectionHead.Render("测试参数")) - lines = append(lines, "") - lines = append(lines, row(m, "协议 ", shortProtocol(inp.NormalizedProtocol()))) - lines = append(lines, row(m, "接口地址 ", truncate(inp.ResolvedEndpointURL(), width-20))) - lines = append(lines, row(m, "API Key ", maskAPIKey(inp.ApiKey))) - lines = append(lines, row(m, "模型 ", inp.Model)) - - modeStr := "标准" - if inp.Turbo { - modeStr = "Turbo (并发探测)" - } - lines = append(lines, row(m, "测试模式 ", modeStr)) - - if inp.Turbo { - tc := inp.TurboConfig - lines = append(lines, row(m, "初始并发 ", fmt.Sprintf("%d", tc.InitConcurrency))) - lines = append(lines, row(m, "最大并发 ", fmt.Sprintf("%d", tc.MaxConcurrency))) - lines = append(lines, row(m, "步进大小 ", fmt.Sprintf("+%d", tc.StepSize))) - lines = append(lines, row(m, "每级请求 ", fmt.Sprintf("%d", tc.LevelRequests))) - lines = append(lines, row(m, "最低成功率", fmt.Sprintf("%.0f%%", tc.MinSuccessRate))) - } else { - lines = append(lines, row(m, "并发数 ", fmt.Sprintf("%d", inp.Concurrency))) - lines = append(lines, row(m, "请求总数 ", fmt.Sprintf("%d", inp.Count))) - } - - lines = append(lines, row(m, "流式输出 ", boolLabel(inp.Stream))) - lines = append(lines, row(m, "Thinking ", boolLabel(inp.Thinking))) - lines = append(lines, "") - lines = append(lines, m.styles.sectionHead.Render("Prompt 配置")) - lines = append(lines, "") - lines = append(lines, row(m, "模式 ", inp.PromptMode)) - lines = append(lines, row(m, "内容 ", truncate(promptSummary(inp), width-20))) - - if m.status != "" { - lines = append(lines, "") - lines = append(lines, m.styles.muted.Render(m.status)) - } - if m.err != nil { - lines = append(lines, m.styles.errStyle.Render("错误: "+m.err.Error())) - } - - return strings.Join(lines, "\n") -} - -func (m *Model) buildDetailRight(maxH int) string { - var lines []string - lines = append(lines, m.styles.sectionHead.Render("运行历史")) - lines = append(lines, "") - - if m.hist == nil || len(m.hist.history) == 0 { - lines = append(lines, m.styles.muted.Render(" 暂无历史记录")) - lines = append(lines, "") - lines = append(lines, m.styles.muted.Render(" 按 [Enter] 或 [r] 开始第一次运行")) - return strings.Join(lines, "\n") - } - - for i, run := range m.hist.history { - if len(lines) >= maxH-2 { - break - } - statusIcon := m.styles.ok.Render("✓") - if run.Status != "completed" { - statusIcon = m.styles.errStyle.Render("✗") - } - elapsed := run.FinishedAt.Sub(run.StartedAt) - lines = append(lines, fmt.Sprintf("%s #%d %s", - statusIcon, i+1, timeAgo(run.StartedAt))) - lines = append(lines, fmt.Sprintf(" 成功率 %.1f%% TTFT %.0fms TPS %.1f", - run.SuccessRate, float64(run.AvgTTFT.Milliseconds()), run.AvgTPS)) - lines = append(lines, fmt.Sprintf(" 耗时 %s 模式 %s", - fmtDuration(elapsed), run.Mode)) - if run.ErrorSummary != "" { - lines = append(lines, m.styles.errStyle.Render(" "+truncate(run.ErrorSummary, 36))) - } - if run.ReportJSONPath != "" { - lines = append(lines, m.styles.muted.Render(" 报告: "+truncate(run.ReportJSONPath, 32))) - } - lines = append(lines, "") - } - - return strings.Join(lines, "\n") -} - -// ─── 按键处理 ───────────────────────────────────────────────────────────────── - -func (m *Model) handleTaskDetailKey(msg interface{ String() string }) tea.Cmd { - switch msg.String() { - case "left", "esc", "b": - m.view = viewTaskList - return nil - - case "enter", "r": - if t, ok := m.taskList.currentTask(); ok { - return m.startRunIfAllowed(t.ID, false) - } - } - return nil -} - -// ─── helpers ────────────────────────────────────────────────────────────────── - -func row(m *Model, label, value string) string { - return m.styles.label.Render(label) + " " + m.styles.value.Render(value) -} - -func fmtDuration(d time.Duration) string { - ms := d.Milliseconds() - if ms < 1000 { - return fmt.Sprintf("%dms", ms) - } - s := float64(ms) / 1000 - if s < 60 { - return fmt.Sprintf("%.1fs", s) - } - return fmt.Sprintf("%.0fm%.0fs", s/60, float64(int64(s)%60)) -} diff --git a/internal/tui/page_tasklist.go b/internal/tui/page_tasklist.go deleted file mode 100644 index 6ee8fc2..0000000 --- a/internal/tui/page_tasklist.go +++ /dev/null @@ -1,412 +0,0 @@ -package tui - -import ( - "fmt" - "strings" - "time" - - tea "github.com/charmbracelet/bubbletea" - "github.com/charmbracelet/lipgloss" - "github.com/yinxulai/ait/internal/types" -) - -// taskListState 任务列表页的局部状态。 -type taskListState struct { - tasks []types.TaskDefinition - selected int -} - -// currentTask 返回当前选中的任务。 -func (s *taskListState) currentTask() (types.TaskDefinition, bool) { - if len(s.tasks) == 0 || s.selected < 0 || s.selected >= len(s.tasks) { - return types.TaskDefinition{}, false - } - return s.tasks[s.selected], true -} - -// ─── 按键处理 ───────────────────────────────────────────────────────────────── - -func (m *Model) handleTaskListKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) { - s := &m.taskList - - switch msg.String() { - case "up", "k": - if s.selected > 0 { - s.selected-- - } - case "down", "j": - if s.selected < len(s.tasks)-1 { - s.selected++ - } - - case "a": - // 新建任务 — 打开向导 - m.openWizard(nil) - - case "e": - if t, ok := s.currentTask(); ok { - m.openWizard(&t) - } - - case "y": - // 复制任务 - if t, ok := s.currentTask(); ok { - return m, m.client.CopyTaskCmd(t.ID) - } - - case "d": - // 删除任务 - if t, ok := s.currentTask(); ok { - return m, m.client.DeleteTaskCmd(t.ID) - } - - case "enter": - if t, ok := s.currentTask(); ok { - // 如果是运行中任务,进入仪表盘 - if m.dash != nil && m.dash.runID != "" && m.dash.taskID == t.ID { - m.view = viewDashboard - return m, nil - } - // 否则进入任务详情并加载历史 - m.view = viewTaskDetail - return m, m.client.LoadHistoryCmd(t.ID, 10) - } - - case "r": - if t, ok := s.currentTask(); ok { - return m, m.startRunIfAllowed(t.ID, false) - } - - case "s": - // 停止当前运行中的任务(若选中的是运行中任务) - if t, ok := s.currentTask(); ok { - if m.dash != nil && m.dash.taskID == t.ID { - return m, m.client.StopRunCmd(m.dash.runID) - } - } - - case "q": - return m, tea.Quit - } - - return m, nil -} - -// ─── 渲染 ───────────────────────────────────────────────────────────────────── - -func (m *Model) renderTaskList() string { - if m.width == 0 { - return "加载中..." - } - s := &m.taskList - - lastRunStr := "" - for _, t := range s.tasks { - if t.LastRunAt != nil { - lastRunStr = "最近: " + timeAgo(*t.LastRunAt) - break - } - } - header := m.renderHeader( - "AIT 任务中心", - fmt.Sprintf("已保存任务: %d %s", len(s.tasks), lastRunStr), - ) - - // 决定 context bar 内容 - var cbItems []contextBarItem - if t, ok := s.currentTask(); ok { - isRunning := m.dash != nil && m.dash.taskID == t.ID && m.dash.isRunning() - cbItems = contextBarItems_taskList(isRunning) - } - contextBar := m.renderContextBar(cbItems) - footer := m.renderFooter("[↑↓] 选择", "[a] 新建", "[q] 退出", "◆ AIT v0.1") - - // 内容区高度 = 总高 - header(1) - contextbar - footer(1) - cbH := 0 - if contextBar != "" { - cbH = 1 - } - contentH := m.height - 1 - cbH - 1 - if contentH < 4 { - contentH = 4 - } - - leftW := (m.width - 4) * 65 / 100 - rightW := m.width - 4 - leftW - - leftContent := m.buildTaskListTable(contentH, leftW) - rightContent := m.buildTaskListSidebar(contentH) - mid := m.dualColumnLayout(leftContent, rightContent, leftW, rightW, contentH) - - parts := []string{header, mid} - if contextBar != "" { - parts = append(parts, contextBar) - } - parts = append(parts, footer) - return strings.Join(parts, "\n") -} - -func (m *Model) buildTaskListTable(maxH, width int) string { - s := &m.taskList - var lines []string - - lines = append(lines, m.styles.tableHead.Render( - fmt.Sprintf(" %-28s %-9s %-14s %s", "任务名称", "模式", "协议", "上次结果"), - )) - lines = append(lines, m.styles.muted.Render(strings.Repeat("─", width))) - - if len(s.tasks) == 0 { - lines = append(lines, "") - lines = append(lines, m.styles.muted.Render(" 暂无任务 按 [a] 新建")) - return strings.Join(lines, "\n") - } - - for i, t := range s.tasks { - if len(lines) >= maxH-1 { - break - } - - // 运行中标记 - runIndicator := " " - if m.dash != nil && m.dash.taskID == t.ID && m.dash.isRunning() { - runIndicator = m.styles.ok.Render("◉") - } - - // 模式列(手动对齐 9 列宽) - var modeRendered string - if t.Input.Turbo { - modeRendered = m.styles.tagTurbo.Render("Turbo") - } else { - modeRendered = m.styles.tagStd.Render("标准") - } - modePad := 9 - lipgloss.Width(modeRendered) - if modePad < 0 { - modePad = 0 - } - modeCol := modeRendered + strings.Repeat(" ", modePad) - - proto := shortProtocol(t.Input.NormalizedProtocol()) - lastResult := m.styles.muted.Render("从未运行") - if t.LastRunSummary != nil { - pct := t.LastRunSummary.SuccessRate - if pct >= 99 { - lastResult = m.styles.ok.Render(fmt.Sprintf("✓ %.1f%%", pct)) - } else if pct >= 90 { - lastResult = m.styles.metricVal.Render(fmt.Sprintf("%.1f%%", pct)) - } else { - lastResult = m.styles.errStyle.Render(fmt.Sprintf("✗ %.1f%%", pct)) - } - } - - nameStr := truncate(t.Name, 27) - nameCol := fmt.Sprintf("%-27s ", nameStr) - protoCol := fmt.Sprintf("%-14s ", proto) - - mainRow := runIndicator + " " + nameCol + modeCol + " " + protoCol + lastResult - if i == s.selected { - // 选中行:纯文本 + tableRowSel 背景 - plainMode := "标准" - if t.Input.Turbo { - plainMode = "Turbo" - } - plainRow := " ▶ " + nameCol + fmt.Sprintf("%-9s ", plainMode) + protoCol + lastResult - lines = append(lines, m.styles.tableRowSel.Width(width).Render(plainRow)) - } else { - lines = append(lines, mainRow) - } - - // 二级子行:配置摘要 - var sub string - if m.dash != nil && m.dash.taskID == t.ID && m.dash.isRunning() { - rs := m.dash.runState - if rs != nil { - sub = fmt.Sprintf(" %s ◉ %d/%d 成功率 %.1f%%", - truncate(t.Input.Model, 18), rs.DoneReqs, rs.TotalReqs, rs.SuccessRate) - } - } - if sub == "" { - if t.Input.Turbo { - tc := t.Input.TurboConfig - sub = fmt.Sprintf(" %s %d→%d 步进+%d 每级%d", - truncate(t.Input.Model, 18), - tc.InitConcurrency, tc.MaxConcurrency, tc.StepSize, tc.LevelRequests) - } else { - sub = fmt.Sprintf(" %s 并发%d/请求%d", - truncate(t.Input.Model, 20), t.Input.Concurrency, t.Input.Count) - } - } - if i == s.selected { - lines = append(lines, m.styles.tableRowSel.Width(width).Render(sub)) - } else { - lines = append(lines, m.styles.muted.Render(sub)) - } - lines = append(lines, "") - } - return strings.Join(lines, "\n") -} - -func (m *Model) buildTaskListSidebar(maxH int) string { - var lines []string - lines = append(lines, m.styles.sectionHead.Render("快捷操作")) - lines = append(lines, "") - lines = append(lines, " "+m.styles.key.Render("[a]")+" 新建任务") - lines = append(lines, " "+m.styles.key.Render("[Enter]")+" 查看详情 / 进仪表盘") - lines = append(lines, " "+m.styles.key.Render("[r]")+" 运行选中任务") - lines = append(lines, " "+m.styles.key.Render("[e]")+" 编辑 "+ - m.styles.key.Render("[d]")+" 删除 "+ - m.styles.key.Render("[y]")+" 复制") - lines = append(lines, "") - lines = append(lines, m.styles.muted.Render(strings.Repeat("─", 28))) - lines = append(lines, "") - lines = append(lines, m.styles.sectionHead.Render("最近执行")) - lines = append(lines, "") - - count := 0 - for _, t := range m.taskList.tasks { - if t.LastRunSummary == nil { - continue - } - s := t.LastRunSummary - icon := m.styles.ok.Render("✓") - if s.SuccessRate < 90 { - icon = m.styles.errStyle.Render("✗") - } - lines = append(lines, fmt.Sprintf(" %s %-16s %.1f%% %.0f tok/s", - icon, truncate(t.Name, 16), s.SuccessRate, s.AvgTPS)) - count++ - if count >= 5 || len(lines) >= maxH-2 { - break - } - } - if count == 0 { - lines = append(lines, m.styles.muted.Render(" 暂无记录")) - } - - if m.status != "" { - lines = append(lines, "") - lines = append(lines, m.styles.muted.Render(m.status)) - } - if m.err != nil { - lines = append(lines, m.styles.errStyle.Render("错误: "+m.err.Error())) - } - return strings.Join(lines, "\n") -} - -// startRunIfAllowed 根据是否已有运行中任务决定是否启动新运行。 -// forceStart=true 表示无论是否有其他任务都启动(用于向导 [r] 保存并运行)。 -func (m *Model) startRunIfAllowed(taskID string, forceStart bool) tea.Cmd { - if !forceStart && m.dash != nil && m.dash.isRunning() { - m.status = fmt.Sprintf("已有任务 %q 在运行中,多任务并行可能影响网络指标", - m.dash.taskID) - return nil - } - return m.client.StartRunCmd(taskID) -} - -// ─── 共享渲染工具 ───────────────────────────────────────────────────────────── - -// 这些函数被多个 page_*.go 使用,统一放在此文件。 - -func progressBar(current, total, width int) string { - if total <= 0 || width <= 0 { - return lipgloss.NewStyle().Foreground(lipgloss.Color("238")).Render(strings.Repeat("░", width)) - } - filled := current * width / total - if filled > width { - filled = width - } - bar := lipgloss.NewStyle().Foreground(colorGreen).Render(strings.Repeat("█", filled)) - empty := lipgloss.NewStyle().Foreground(lipgloss.Color("238")).Render(strings.Repeat("░", width-filled)) - return bar + empty -} - -func progressBarRed(current, total, width int) string { - if total <= 0 || width <= 0 { - return lipgloss.NewStyle().Foreground(lipgloss.Color("238")).Render(strings.Repeat("░", width)) - } - filled := current * width / total - if filled > width { - filled = width - } - bar := lipgloss.NewStyle().Foreground(colorRed).Render(strings.Repeat("█", filled)) - empty := lipgloss.NewStyle().Foreground(lipgloss.Color("238")).Render(strings.Repeat("░", width-filled)) - return bar + empty -} - -// isRunningTask 判断任务是否当前正在运行。 -func (m *Model) isRunningTask(taskID string) bool { - return m.dash != nil && m.dash.taskID == taskID && m.dash.isRunning() -} - -// ─── 工具函数 ───────────────────────────────────────────────────────────────── - -func truncate(s string, n int) string { - if n <= 0 || len([]rune(s)) <= n { - return s - } - r := []rune(s) - if n <= 3 { - return string(r[:n]) - } - return string(r[:n-3]) + "..." -} - -func timeAgo(t time.Time) string { - d := time.Since(t) - switch { - case d < time.Minute: - return fmt.Sprintf("%ds 前", int(d.Seconds())) - case d < time.Hour: - return fmt.Sprintf("%dm 前", int(d.Minutes())) - case d < 24*time.Hour: - return fmt.Sprintf("%dh 前", int(d.Hours())) - } - return t.Format("01-02 15:04") -} - -func shortProtocol(p string) string { - p = strings.ReplaceAll(p, "openai-", "") - p = strings.ReplaceAll(p, "anthropic-", "") - return p -} - -func boolLabel(v bool) string { - if v { - return "开启" - } - return "关闭" -} - -func promptSummary(input types.Input) string { - switch input.PromptMode { - case promptModeFile: - return input.PromptFile - case promptModeGenerated: - return fmt.Sprintf("长度 %d", input.PromptLength) - default: - if len([]rune(input.PromptText)) > 48 { - return string([]rune(input.PromptText)[:45]) + "..." - } - return input.PromptText - } -} - -func maskAPIKey(key string) string { - if len(key) == 0 { - return "(空)" - } - if len(key) <= 8 { - return strings.Repeat("•", len(key)) - } - return key[:4] + strings.Repeat("•", len(key)-8) + key[len(key)-4:] -} - -func wrapIndex(index, length int) int { - if length == 0 { - return 0 - } - for index < 0 { - index += length - } - return index % length -} diff --git a/internal/tui/page_wizard.go b/internal/tui/page_wizard.go deleted file mode 100644 index dd8c0d4..0000000 --- a/internal/tui/page_wizard.go +++ /dev/null @@ -1,569 +0,0 @@ -package tui - -import ( - "fmt" - "strconv" - "strings" - - tea "github.com/charmbracelet/bubbletea" - "github.com/charmbracelet/lipgloss" - "github.com/yinxulai/ait/internal/server" - "github.com/yinxulai/ait/internal/types" -) - -// ─── Prompt 模式常量 ────────────────────────────────────────────────────────── - -const ( - promptModeText = "text" - promptModeFile = "file" - promptModeGenerated = "generated" -) - -// ─── 向导状态 ───────────────────────────────────────────────────────────────── - -type wizardStep int - -const ( - wizardStepBasic wizardStep = 0 // 基础配置(名称、模式、协议) - wizardStepEndpoint wizardStep = 1 // 接口配置(URL、APIKey、模型) - wizardStepPrompt wizardStep = 2 // Prompt 配置(模式、内容、并发参数) -) - -// wizardState 向导的完整状态。 -type wizardState struct { - step wizardStep - editingID string // 非空表示编辑模式,存放被编辑任务的 ID - - // Step 0 - name string - turbo bool - protocol string // types.Protocol* 常量 - - // Step 1 - endpointURL string - apiKey string - model string - stream bool - thinking bool - - // Step 2 — Standard - concurrency int - count int - - // Step 2 — Turbo - initConcurrency int - maxConcurrency int - stepSize int - levelRequests int - minSuccessRate float64 - - // Prompt - promptMode string - promptText string - promptFile string - promptLength int - - // 当前活跃字段索引 - fieldIndex int -} - -// openWizard 打开向导。task==nil 表示新建,非 nil 表示编辑。 -func (m *Model) openWizard(task *types.TaskDefinition) { - if task == nil { - m.wizard = &wizardState{ - step: wizardStepBasic, - protocol: types.ProtocolOpenAICompletions, - concurrency: 10, - count: 100, - initConcurrency: 1, - maxConcurrency: 50, - stepSize: 5, - levelRequests: 20, - minSuccessRate: 95, - promptMode: promptModeText, - } - } else { - inp := task.Input - tc := inp.TurboConfig - m.wizard = &wizardState{ - step: wizardStepBasic, - editingID: task.ID, - name: task.Name, - turbo: inp.Turbo, - protocol: types.NormalizeProtocol(inp.Protocol), - endpointURL: inp.EndpointURL, - apiKey: inp.ApiKey, - model: inp.Model, - stream: inp.Stream, - thinking: inp.Thinking, - concurrency: inp.Concurrency, - count: inp.Count, - initConcurrency: tc.InitConcurrency, - maxConcurrency: tc.MaxConcurrency, - stepSize: tc.StepSize, - levelRequests: tc.LevelRequests, - minSuccessRate: tc.MinSuccessRate, - promptMode: inp.PromptMode, - promptText: inp.PromptText, - promptFile: inp.PromptFile, - promptLength: inp.PromptLength, - } - if m.wizard.promptMode == "" { - m.wizard.promptMode = promptModeText - } - if m.wizard.concurrency == 0 { - m.wizard.concurrency = 10 - } - if m.wizard.count == 0 { - m.wizard.count = 100 - } - } - m.view = viewWizard -} - -// ─── 按键处理 ───────────────────────────────────────────────────────────────── - -func (m *Model) handleWizardKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) { - wz := m.wizard - if wz == nil { - m.view = viewTaskList - return m, nil - } - - fields := m.wizardFields() - maxField := len(fields) - 1 - - switch msg.String() { - case "esc": - m.wizard = nil - m.view = viewTaskList - return m, nil - - case "tab", "down", "j": - if wz.fieldIndex < maxField { - wz.fieldIndex++ - } - - case "shift+tab", "up", "k": - if wz.fieldIndex > 0 { - wz.fieldIndex-- - } - - case "left", "right": - // 布尔/枚举切换 - m.wizardToggleField(fields, wz.fieldIndex, msg.String() == "right") - - case "enter": - // 如果在最后一个字段,或者按下 Enter 且是最后步骤,保存并运行 - if int(wz.step) == 2 && wz.fieldIndex == maxField { - return m, m.saveWizard(true) - } - // 否则 Next / 保存 - if wz.fieldIndex == maxField { - wz.step++ - wz.fieldIndex = 0 - } else { - wz.fieldIndex++ - } - - case "ctrl+s": - return m, m.saveWizard(false) - - case "ctrl+enter": - if int(wz.step) == 2 { - return m, m.saveWizard(true) - } - - case "backspace": - m.wizardBackspace(fields, wz.fieldIndex) - - default: - // 字符输入 - if len(msg.Runes) > 0 { - m.wizardInput(fields, wz.fieldIndex, string(msg.Runes)) - } - } - - return m, nil -} - -// ─── 字段定义 ───────────────────────────────────────────────────────────────── - -type fieldKind int - -const ( - fieldText fieldKind = iota // 自由文本输入 - fieldNumber // 数字 - fieldBool // 布尔开关 - fieldEnum // 枚举循环 - fieldAction // 动作按钮(保存/运行) -) - -type wizardField struct { - kind fieldKind - label string - getValue func(wz *wizardState) string - setValue func(wz *wizardState, s string) - options []string // 仅 fieldEnum 使用 -} - -// wizardFields 根据当前步骤和 turbo 模式动态返回字段列表。 -func (m *Model) wizardFields() []wizardField { - wz := m.wizard - if wz == nil { - return nil - } - switch wz.step { - case wizardStepBasic: - return []wizardField{ - {kind: fieldText, label: "名称", - getValue: func(wz *wizardState) string { return wz.name }, - setValue: func(wz *wizardState, s string) { wz.name = s }}, - {kind: fieldBool, label: "Turbo 模式", - getValue: func(wz *wizardState) string { return boolLabel(wz.turbo) }, - setValue: func(wz *wizardState, s string) { wz.turbo = (s == "true") }}, - {kind: fieldEnum, label: "协议", - options: []string{ - types.ProtocolOpenAICompletions, - types.ProtocolOpenAIResponses, - types.ProtocolAnthropicMessages, - }, - getValue: func(wz *wizardState) string { return wz.protocol }, - setValue: func(wz *wizardState, s string) { wz.protocol = s }}, - } - - case wizardStepEndpoint: - return []wizardField{ - {kind: fieldText, label: "接口地址 (可选)", - getValue: func(wz *wizardState) string { return wz.endpointURL }, - setValue: func(wz *wizardState, s string) { wz.endpointURL = s }}, - {kind: fieldText, label: "API Key", - getValue: func(wz *wizardState) string { return wz.apiKey }, - setValue: func(wz *wizardState, s string) { wz.apiKey = s }}, - {kind: fieldText, label: "模型", - getValue: func(wz *wizardState) string { return wz.model }, - setValue: func(wz *wizardState, s string) { wz.model = s }}, - {kind: fieldBool, label: "流式输出", - getValue: func(wz *wizardState) string { return boolLabel(wz.stream) }, - setValue: func(wz *wizardState, s string) { wz.stream = (s == "true") }}, - {kind: fieldBool, label: "Thinking 模式", - getValue: func(wz *wizardState) string { return boolLabel(wz.thinking) }, - setValue: func(wz *wizardState, s string) { wz.thinking = (s == "true") }}, - } - - case wizardStepPrompt: - base := []wizardField{ - {kind: fieldEnum, label: "Prompt 模式", - options: []string{promptModeText, promptModeFile, promptModeGenerated}, - getValue: func(wz *wizardState) string { return wz.promptMode }, - setValue: func(wz *wizardState, s string) { wz.promptMode = s }}, - } - switch wz.promptMode { - case promptModeFile: - base = append(base, wizardField{kind: fieldText, label: "文件路径", - getValue: func(wz *wizardState) string { return wz.promptFile }, - setValue: func(wz *wizardState, s string) { wz.promptFile = s }}) - case promptModeGenerated: - base = append(base, wizardField{kind: fieldNumber, label: "生成长度", - getValue: func(wz *wizardState) string { return fmt.Sprintf("%d", wz.promptLength) }, - setValue: func(wz *wizardState, s string) { - if n, err := strconv.Atoi(s); err == nil { - wz.promptLength = n - } - }}) - default: // text - base = append(base, wizardField{kind: fieldText, label: "Prompt 文本", - getValue: func(wz *wizardState) string { return wz.promptText }, - setValue: func(wz *wizardState, s string) { wz.promptText = s }}) - } - - if wz.turbo { - base = append(base, - wizardField{kind: fieldNumber, label: "初始并发", - getValue: func(wz *wizardState) string { return fmt.Sprintf("%d", wz.initConcurrency) }, - setValue: func(wz *wizardState, s string) { - if n, err := strconv.Atoi(s); err == nil && n > 0 { - wz.initConcurrency = n - } - }}, - wizardField{kind: fieldNumber, label: "最大并发", - getValue: func(wz *wizardState) string { return fmt.Sprintf("%d", wz.maxConcurrency) }, - setValue: func(wz *wizardState, s string) { - if n, err := strconv.Atoi(s); err == nil && n > 0 { - wz.maxConcurrency = n - } - }}, - wizardField{kind: fieldNumber, label: "步进大小", - getValue: func(wz *wizardState) string { return fmt.Sprintf("%d", wz.stepSize) }, - setValue: func(wz *wizardState, s string) { - if n, err := strconv.Atoi(s); err == nil && n > 0 { - wz.stepSize = n - } - }}, - wizardField{kind: fieldNumber, label: "每级请求数", - getValue: func(wz *wizardState) string { return fmt.Sprintf("%d", wz.levelRequests) }, - setValue: func(wz *wizardState, s string) { - if n, err := strconv.Atoi(s); err == nil && n > 0 { - wz.levelRequests = n - } - }}, - ) - } else { - base = append(base, - wizardField{kind: fieldNumber, label: "并发数", - getValue: func(wz *wizardState) string { return fmt.Sprintf("%d", wz.concurrency) }, - setValue: func(wz *wizardState, s string) { - if n, err := strconv.Atoi(s); err == nil && n > 0 { - wz.concurrency = n - } - }}, - wizardField{kind: fieldNumber, label: "请求总数", - getValue: func(wz *wizardState) string { return fmt.Sprintf("%d", wz.count) }, - setValue: func(wz *wizardState, s string) { - if n, err := strconv.Atoi(s); err == nil && n > 0 { - wz.count = n - } - }}, - ) - } - return base - } - return nil -} - -// wizardToggleField 处理 ←→ 按键的枚举/布尔切换。 -func (m *Model) wizardToggleField(fields []wizardField, idx int, forward bool) { - wz := m.wizard - if idx >= len(fields) { - return - } - f := fields[idx] - switch f.kind { - case fieldBool: - cur := f.getValue(wz) == "开启" - if forward { - f.setValue(wz, boolVal(!cur)) - } else { - f.setValue(wz, boolVal(cur)) - } - case fieldEnum: - cur := f.getValue(wz) - i := 0 - for j, o := range f.options { - if o == cur { - i = j - break - } - } - if forward { - i = (i + 1) % len(f.options) - } else { - i = (i - 1 + len(f.options)) % len(f.options) - } - f.setValue(wz, f.options[i]) - } -} - -func boolVal(v bool) string { - if v { - return "true" - } - return "false" -} - -// wizardInput 追加字符到文本/数字字段。 -func (m *Model) wizardInput(fields []wizardField, idx int, s string) { - wz := m.wizard - if idx >= len(fields) { - return - } - f := fields[idx] - if f.kind == fieldText { - cur := f.getValue(wz) - f.setValue(wz, cur+s) - } else if f.kind == fieldNumber { - cur := f.getValue(wz) - if len(s) == 1 && s[0] >= '0' && s[0] <= '9' { - f.setValue(wz, cur+s) - } - } -} - -// wizardBackspace 删除文本/数字字段末尾字符。 -func (m *Model) wizardBackspace(fields []wizardField, idx int) { - wz := m.wizard - if idx >= len(fields) { - return - } - f := fields[idx] - if f.kind == fieldText || f.kind == fieldNumber { - cur := []rune(f.getValue(wz)) - if len(cur) > 0 { - f.setValue(wz, string(cur[:len(cur)-1])) - } - } -} - -// ─── 保存 ───────────────────────────────────────────────────────────────────── - -// buildTaskInput 从 wizardState 构建 types.Input。 -func (m *Model) buildTaskInput() types.Input { - wz := m.wizard - inp := types.Input{ - Protocol: wz.protocol, - EndpointURL: wz.endpointURL, - ApiKey: wz.apiKey, - Model: wz.model, - Stream: wz.stream, - Thinking: wz.thinking, - Turbo: wz.turbo, - PromptMode: wz.promptMode, - PromptText: wz.promptText, - PromptFile: wz.promptFile, - PromptLength: wz.promptLength, - } - if wz.turbo { - inp.TurboConfig = types.TurboConfig{ - InitConcurrency: wz.initConcurrency, - MaxConcurrency: wz.maxConcurrency, - StepSize: wz.stepSize, - LevelRequests: wz.levelRequests, - MinSuccessRate: wz.minSuccessRate, - } - } else { - inp.Concurrency = wz.concurrency - inp.Count = wz.count - } - return inp -} - -// saveWizard 保存或创建任务。autoStart=true 表示成功后立刻运行。 -func (m *Model) saveWizard(autoStart bool) tea.Cmd { - wz := m.wizard - if wz == nil { - return nil - } - cfg := server.TaskConfig{ - Name: wz.name, - Input: m.buildTaskInput(), - } - m.wizard = nil - m.view = viewTaskList - - if wz.editingID != "" { - return m.client.UpdateTaskCmd(wz.editingID, cfg) - } - return m.client.CreateTaskCmd(cfg, autoStart) -} - -// ─── 渲染 ───────────────────────────────────────────────────────────────────── - -func (m *Model) renderWizard() string { - wz := m.wizard - if wz == nil { - return "" - } - - title := "新建任务" - if wz.editingID != "" { - title = "编辑任务" - } - - steps := []string{"基础配置", "接口配置", "参数配置"} - var stepParts []string - for i, s := range steps { - switch { - case i < int(wz.step): - stepParts = append(stepParts, m.styles.stepDone.Render("✓ "+s)) - case i == int(wz.step): - stepParts = append(stepParts, m.styles.stepActive.Render("▶ "+s)) - default: - stepParts = append(stepParts, m.styles.stepTodo.Render("○ "+s)) - } - } - stepLine := strings.Join(stepParts, " ") - - fields := m.wizardFields() - var fieldLines []string - for i, f := range fields { - label := fmt.Sprintf("%-16s", f.label) - val := f.getValue(wz) - if f.kind == fieldBool { - if wz.turbo && f.label == "Turbo 模式" { - val = m.styles.ok.Render("开启") - } else if !wz.turbo && f.label == "Turbo 模式" { - val = m.styles.muted.Render("关闭") - } - } - // mask API key display - if f.label == "API Key" && val != "" { - val = maskAPIKey(val) - } - - var line string - if i == wz.fieldIndex { - cursor := m.styles.cursor.Render("▶") - labelS := m.styles.sectionHead.Render(label) - valS := m.styles.fieldActive.Render(" " + val + " " + m.styles.cursor.Render("_")) - line = cursor + " " + labelS + " " + valS - } else { - labelS := m.styles.label.Render(label) - valS := m.styles.value.Render(val) - line = " " + labelS + " " + valS - } - fieldLines = append(fieldLines, line) - } - - // 底部操作提示 - var hints []string - hints = append(hints, m.styles.key.Render("[↑↓/Tab]")+" 切换字段") - hints = append(hints, m.styles.key.Render("[←→]")+" 切换选项") - hints = append(hints, m.styles.key.Render("[Ctrl+S]")+" 保存") - if int(wz.step) == 2 { - hints = append(hints, m.styles.key.Render("[Ctrl+Enter]")+" 保存并运行") - } else { - hints = append(hints, m.styles.key.Render("[Enter]")+" 下一步") - } - hints = append(hints, m.styles.key.Render("[Esc]")+" 取消") - hintLine := strings.Join(hints, " ") - - dialogW := m.width * 70 / 100 - if dialogW < 60 { - dialogW = 60 - } - if dialogW > m.width-4 { - dialogW = m.width - 4 - } - - inner := fmt.Sprintf("%s\n\n%s\n\n%s\n\n%s", - m.styles.sectionHead.Render(title), - stepLine, - strings.Join(fieldLines, "\n"), - hintLine, - ) - - dialog := m.styles.dialog.Width(dialogW).Render(inner) - totalH := m.height - dialogRendered := dialog - - // 垂直居中 - dialogH := strings.Count(dialogRendered, "\n") + 1 - padTop := (totalH - dialogH) / 2 - if padTop < 0 { - padTop = 0 - } - topPad := strings.Repeat("\n", padTop) - - // 水平居中(外层宽度补齐) - dialogLineW := lipgloss.Width(strings.Split(dialogRendered, "\n")[0]) - leftPad := (m.width - dialogLineW) / 2 - if leftPad < 0 { - leftPad = 0 - } - paddedLines := make([]string, 0) - for _, l := range strings.Split(dialogRendered, "\n") { - paddedLines = append(paddedLines, strings.Repeat(" ", leftPad)+l) - } - - return topPad + strings.Join(paddedLines, "\n") -} diff --git a/internal/tui/pages/contextbar.go b/internal/tui/pages/contextbar.go new file mode 100644 index 0000000..5b4e82c --- /dev/null +++ b/internal/tui/pages/contextbar.go @@ -0,0 +1,116 @@ +package pages + +import ( + "strings" + + "github.com/charmbracelet/lipgloss" +) + +// ContextBarItem 是 Context Bar 中的一个可操作项。 +type ContextBarItem struct { + Key string // 如 "Enter"、"r"、"↑↓" + Desc string // 操作描述 +} + +// RenderContextBar 渲染 Context Bar。 +// 若 items 为空则返回空字符串(不占空间)。 +func RenderContextBar(st Styles, width int, items []ContextBarItem) string { + if len(items) == 0 { + return "" + } + var parts []string + for _, item := range items { + parts = append(parts, "["+item.Key+"] "+item.Desc) + } + line := " " + strings.Join(parts, " ") + if lipgloss.Width(line) > width { + line = truncate(line, width) + } + return st.CtxBar.Width(width).Render(line) +} + +// ─── 各页面 Context Bar 内容定义 ───────────────────────────────────────────── + +// CtxBar_TaskList_Normal 普通任务选中时的 Context Bar。 +func CtxBar_TaskList_Normal() []ContextBarItem { + return []ContextBarItem{ + {Key: "Enter", Desc: "查看详情"}, + {Key: "r", Desc: "运行"}, + {Key: "e", Desc: "编辑"}, + {Key: "d", Desc: "删除"}, + {Key: "y", Desc: "复制"}, + } +} + +// CtxBar_TaskList_Running 运行中任务选中时的 Context Bar。 +func CtxBar_TaskList_Running() []ContextBarItem { + return []ContextBarItem{ + {Key: "Enter", Desc: "进入仪表盘"}, + {Key: "s", Desc: "停止"}, + {Key: "y", Desc: "复制"}, + } +} + +// CtxBar_TaskDetail_NoHistory 任务详情页,无运行记录时。 +func CtxBar_TaskDetail_NoHistory() []ContextBarItem { + return []ContextBarItem{ + {Key: "Enter/r", Desc: "运行"}, + {Key: "e", Desc: "编辑"}, + {Key: "y", Desc: "复制"}, + {Key: "d", Desc: "删除"}, + } +} + +// CtxBar_TaskDetail_HasHistory 任务详情页,有运行记录时。 +func CtxBar_TaskDetail_HasHistory() []ContextBarItem { + return []ContextBarItem{ + {Key: "r", Desc: "生成报告"}, + {Key: "c", Desc: "复制摘要"}, + {Key: "Enter/r", Desc: "再次运行"}, + {Key: "e", Desc: "编辑"}, + } +} + +// CtxBar_Dashboard_NoSel 标准仪表盘,无选中请求时。 +func CtxBar_Dashboard_NoSel() []ContextBarItem { + return []ContextBarItem{ + {Key: "s", Desc: "停止"}, + {Key: "b", Desc: "后台运行"}, + {Key: "r", Desc: "提前报告"}, + } +} + +// CtxBar_Dashboard_Sel 标准仪表盘,已选中请求时。 +func CtxBar_Dashboard_Sel() []ContextBarItem { + return []ContextBarItem{ + {Key: "Enter", Desc: "查看请求详情"}, + {Key: "↑↓", Desc: "选择请求"}, + {Key: "s", Desc: "停止"}, + } +} + +// CtxBar_TurboDash_NoSel Turbo 仪表盘,无选中级别时。 +func CtxBar_TurboDash_NoSel() []ContextBarItem { + return []ContextBarItem{ + {Key: "s", Desc: "停止"}, + {Key: "b", Desc: "后台运行"}, + {Key: "m", Desc: "标记极限"}, + } +} + +// CtxBar_TurboDash_Sel Turbo 仪表盘,已选中已完成级别时。 +func CtxBar_TurboDash_Sel() []ContextBarItem { + return []ContextBarItem{ + {Key: "Enter", Desc: "查看该级别请求列表"}, + {Key: "↑↓", Desc: "选择"}, + {Key: "s", Desc: "停止"}, + } +} + +// CtxBar_ReqDetail 请求详情页。 +func CtxBar_ReqDetail() []ContextBarItem { + return []ContextBarItem{ + {Key: "b/Esc", Desc: "返回仪表盘"}, + {Key: "←→", Desc: "上/下一条请求"}, + } +} diff --git a/internal/tui/pages/dashboard.go b/internal/tui/pages/dashboard.go new file mode 100644 index 0000000..af6547d --- /dev/null +++ b/internal/tui/pages/dashboard.go @@ -0,0 +1,379 @@ +package pages + +import ( + "fmt" + "strings" + + tea "github.com/charmbracelet/bubbletea" + "github.com/charmbracelet/lipgloss" + "github.com/yinxulai/ait/internal/server" +) + +// DashboardState 标准模式运行仪表盘页状态。 +type DashboardState struct { + RunID server.RunID + TaskID string + EventCh <-chan server.Event // nil = 已后台或已结束 + CancelFn server.CancelFunc + RunState *server.RunState + ReqSel int // 选中请求索引(-1 = 无选中) + ReqOff int // 滚动偏移 +} + +// NewDashboardState 创建仪表盘状态。 +func NewDashboardState(runID server.RunID, taskID string) *DashboardState { + return &DashboardState{ + RunID: runID, + TaskID: taskID, + ReqSel: -1, + } +} + +// IsRunning 判断运行是否仍在进行。 +func (d *DashboardState) IsRunning() bool { + if d == nil || d.RunState == nil { + return false + } + return d.RunState.Status == server.RunStatusRunning +} + +// AdjustReqOffset 根据 ReqSel 调整列表可见窗口。 +func (d *DashboardState) AdjustReqOffset(visH int) { + if d == nil { + return + } + if visH < 3 { + visH = 3 + } + sel := d.ReqSel + off := d.ReqOff + if sel < 0 { + return + } + if sel < off { + off = sel + } else if sel >= off+visH { + off = sel - visH + 1 + } + d.ReqOff = off +} + +// HandleDashboardKey 处理仪表盘页按键。 +func HandleDashboardKey(d *DashboardState, msg tea.KeyMsg, client Client) (*DashboardState, tea.Cmd, NavAction) { + nav := NavAction{} + if d == nil { + return d, nil, NavAction{To: NavTaskList} + } + + var reqs []*server.RequestMetrics + if d.RunState != nil { + reqs = d.RunState.Requests + } + + switch msg.String() { + case "up", "k": + if len(reqs) == 0 { + break + } + if d.ReqSel <= 0 { + d.ReqSel = len(reqs) - 1 + } else { + d.ReqSel-- + } + d.AdjustReqOffset(10) + + case "down", "j": + if len(reqs) == 0 { + break + } + if d.ReqSel < len(reqs)-1 { + d.ReqSel++ + } else { + d.ReqSel = 0 + } + d.AdjustReqOffset(10) + + case "enter": + if d.ReqSel >= 0 && d.ReqSel < len(reqs) { + nav = NavAction{To: NavReqDetail, ReqIndex: d.ReqSel} + } + + case "s": + if d.IsRunning() { + return d, client.StopRunCmd(d.RunID), nav + } + + case "b", "esc": + if d.CancelFn != nil { + d.CancelFn() + } + d.EventCh = nil + d.CancelFn = nil + nav = NavAction{To: NavTaskList} + + case "r": + if d.RunState != nil && !d.IsRunning() { + return d, client.GenerateReportCmd(d.RunID, server.ReportFormatJSON), nav + } + + case "q", "ctrl+c": + nav = NavAction{To: NavQuit} + } + + return d, nil, nav +} + +// RenderDashboard 渲染标准模式运行仪表盘。 +// +// 设计稿布局: +// +// ╔══ AIT 正在测试 ─ task-name ══════════╗ +// ║ ◆ AIT model · protocol · 并发: N · 请求: N ║ +// ╠══════════════════╦══════════════════════╣ +// ║ 任务参数 ║ 实时指标 ║ +// ║ ... ║ ... ║ +// ╠══════════════════╩══════════════════════╣ +// ║ 进度 ████░░ N/N 已用: Xs 剩余: ~Xs ║ +// ╠═════════════════════════════════════════╣ +// ║ 请求列表 ║ +// ║ # 状态 总耗时 TTFT Cache Token TPS║ +// ║ ────────────────────────────── ║ +// ║ ... ║ +// ╠═════════════════════════════════════════╣ +// ║ [Enter] 查看请求 [↑↓] 选择 [s] 停止 ║ ← context bar +// ╠═════════════════════════════════════════╣ +// ║ [s] 停止 [b] 后台运行 [r] 报告 [q] 退出 ║ +// ╚═════════════════════════════════════════╝ +func RenderDashboard(d *DashboardState, taskName string, st Styles, width, height int) string { + if d == nil || width == 0 { + return "加载中..." + } + rs := d.RunState + + // ── Header ── + statusStr := "等待中" + if rs != nil { + switch rs.Status { + case server.RunStatusRunning: + statusStr = st.Ok.Render("运行中") + case server.RunStatusCompleted: + statusStr = st.Ok.Render("已完成") + case server.RunStatusFailed: + statusStr = st.ErrStyle.Render("失败") + case server.RunStatusStopped: + statusStr = st.Muted.Render("已停止") + } + } + + subtitle := "─" + if rs != nil { + subtitle = fmt.Sprintf("◆ AIT %s · %s · 并发: %d · 请求: %d", + "─", "─", 0, rs.TotalReqs) + } + + header := renderHeader(st, width, + "AIT 正在测试 ─ "+truncate(taskName, 25), + statusStr, + subtitle, + "", + ) + + // ── Context Bar ── + var cbItems []ContextBarItem + if d.ReqSel >= 0 && rs != nil && d.ReqSel < len(rs.Requests) { + cbItems = CtxBar_Dashboard_Sel() + } else { + cbItems = CtxBar_Dashboard_NoSel() + } + ctxBar := RenderContextBar(st, width, cbItems) + + // ── Footer ── + footer := renderFooter(st, width, "[s] 停止", "[b] 后台运行", "[r] 提前报告", "[q] 退出") + + // ── 计算高度 ── + headerH := 2 + ctxBarH := 0 + if ctxBar != "" { + ctxBarH = 1 + } + footerH := 1 + splitH := 9 // 上方双栏区域高度 + progressH := 1 // 进度条行高 + divH := 3 // 分隔线总行数(3条分隔线各占1行) + reqListH := height - headerH - ctxBarH - footerH - splitH - progressH - divH + if reqListH < 3 { + reqListH = 3 + } + + // ── 双栏(任务参数 ║ 实时指标)── + leftW := (width - 2) * 45 / 100 + rightW := width - 2 - leftW - 1 // -1 for separator │ + leftContent := buildDashParamsPanel(d, rs, st, splitH-1, leftW) + rightContent := buildDashMetricsPanel(rs, st, splitH-1, rightW) + splitDiv := dividerLine(st, width) + split := dualColumnLayout(st, leftContent, rightContent, leftW, rightW, splitH) + + // ── 进度条 ── + progressLine := buildProgressLine(rs, st, width) + + // ── 请求列表 ── + reqDiv := dividerLine(st, width) + reqList := buildRequestList(d, rs, st, width, reqListH) + + parts := []string{header, splitDiv, split, splitDiv, progressLine, reqDiv, reqList} + if ctxBar != "" { + parts = append(parts, ctxBar) + } + parts = append(parts, footer) + return strings.Join(parts, "\n") +} + +// buildDashParamsPanel 构建左侧任务参数面板。 +func buildDashParamsPanel(d *DashboardState, rs *server.RunState, st Styles, maxH, width int) string { + var lines []string + lines = append(lines, " "+st.SectionHead.Render("任务参数")) + lines = append(lines, "") + + if rs == nil { + lines = append(lines, " "+st.Muted.Render("等待数据...")) + } else { + // 参数从 RunState 读取(实际可从 task 传入,此处用 RunState 已知信息展示) + lines = append(lines, " "+labelValue(st, "进度", fmt.Sprintf("%d/%d", rs.DoneReqs, rs.TotalReqs))) + lines = append(lines, " "+labelValue(st, "成功", fmt.Sprintf("%d", rs.SuccessReqs))) + lines = append(lines, " "+labelValue(st, "失败", fmt.Sprintf("%d", rs.FailedReqs))) + } + + for len(lines) < maxH { + lines = append(lines, "") + } + return strings.Join(lines[:maxH], "\n") +} + +// buildDashMetricsPanel 构建右侧实时指标面板。 +func buildDashMetricsPanel(rs *server.RunState, st Styles, maxH, width int) string { + var lines []string + lines = append(lines, " "+st.SectionHead.Render("实时指标")) + lines = append(lines, "") + + if rs == nil { + lines = append(lines, " "+st.Muted.Render("等待数据...")) + } else { + lines = append(lines, " "+labelValue(st, "成功率 ", + st.MetricVal.Render(fmt.Sprintf("%.1f%%", rs.SuccessRate*100)))) + lines = append(lines, " "+labelValue(st, "avg TPS ", + st.MetricVal.Render(fmt.Sprintf("%.1f tok/s", rs.AvgTPS)))) + lines = append(lines, " "+labelValue(st, "avg TTFT", + st.MetricVal.Render(fmtDuration(rs.AvgTTFT)))) + lines = append(lines, " "+labelValue(st, "缓存命中", + st.MetricVal.Render(fmt.Sprintf("%.1f%%", rs.CacheHitRate*100)))) + lines = append(lines, " "+st.Muted.Render(fmt.Sprintf(" 成功: %d 失败: %d", rs.SuccessReqs, rs.FailedReqs))) + } + + for len(lines) < maxH { + lines = append(lines, "") + } + return strings.Join(lines[:maxH], "\n") +} + +// buildProgressLine 构建进度条行。 +func buildProgressLine(rs *server.RunState, st Styles, width int) string { + if rs == nil { + return " 进度 " + st.Muted.Render("等待中...") + } + total := rs.TotalReqs + done := rs.DoneReqs + var ratio float64 + if total > 0 { + ratio = float64(done) / float64(total) + } + barW := 20 + bar := progressBar(ratio, barW) + barRendered := st.Ok.Render(strings.Repeat("█", int(ratio*float64(barW)))) + + st.Muted.Render(strings.Repeat("░", barW-int(ratio*float64(barW)))) + + elapsed := "" + if !rs.StartedAt.IsZero() { + // elapsed time display + elapsed = "─" + } + + line := fmt.Sprintf(" 进度 %s %d / %d %s", + barRendered, done, total, elapsed) + _ = bar + if lipgloss.Width(line) > width { + line = truncate(line, width) + } + return line +} + +// buildRequestList 构建请求列表区域。 +func buildRequestList(d *DashboardState, rs *server.RunState, st Styles, width, maxH int) string { + var lines []string + lines = append(lines, " "+st.SectionHead.Render("请求列表")) + + if rs == nil || len(rs.Requests) == 0 { + lines = append(lines, " "+st.Muted.Render("等待请求...")) + for len(lines) < maxH { + lines = append(lines, "") + } + return strings.Join(lines, "\n") + } + + // 表头 + lines = append(lines, " "+st.TableHead.Render( + padRight("#", 6)+padRight("状态", 6)+padRight("总耗时", 10)+ + padRight("TTFT", 10)+padRight("Cache", 8)+padRight("输出Token", 10)+"TPS")) + lines = append(lines, " "+st.Divider.Render(strings.Repeat("─", width-2))) + + reqs := rs.Requests + // 倒序展示(最新在上方) + off := d.ReqOff + for i := len(reqs) - 1 - off; i >= 0; i-- { + if len(lines) >= maxH { + break + } + r := reqs[i] + isSel := i == d.ReqSel + + statusStr := st.Ok.Render("✓") + if !r.Success { + statusStr = st.ErrStyle.Render("✗") + } + totalTime := fmtDuration(r.TotalTime) + if !r.Success && r.ErrorMessage != "" { + totalTime = st.ErrStyle.Render("timeout") + } + ttft := fmtDuration(r.TTFT) + cache := fmt.Sprintf("%.0f%%", r.CacheHitRate*100) + tok := fmt.Sprintf("%dtok", r.CompletionTokens) + tps := fmt.Sprintf("%.1f/s", r.TPS) + + row := fmt.Sprintf(" %s %s %s %s %s %s %s", + padRight(fmt.Sprintf("#%d", r.Index+1), 5), + statusStr, + padRight(totalTime, 9), + padRight(ttft, 9), + padRight(cache, 7), + padRight(tok, 9), + tps, + ) + + var rendered string + cursorStr := " " + if isSel { + cursorStr = "▶ " + } + if isSel { + rendered = st.TableRowSel.Render(cursorStr+row) + + strings.Repeat(" ", max(0, width-lipgloss.Width(cursorStr+row)-2)) + } else { + rendered = " " + st.TableRow.Render(row) + } + lines = append(lines, rendered) + } + + for len(lines) < maxH { + lines = append(lines, "") + } + return strings.Join(lines[:maxH], "\n") +} diff --git a/internal/tui/pages/helpers.go b/internal/tui/pages/helpers.go new file mode 100644 index 0000000..da6157f --- /dev/null +++ b/internal/tui/pages/helpers.go @@ -0,0 +1,279 @@ +package pages + +import ( + "fmt" + "strings" + "time" + "unicode/utf8" + + "github.com/charmbracelet/lipgloss" +) + +// ─── 文本工具 ───────────────────────────────────────────────────────────────── + +// truncate 截断字符串(按可见列宽),超出部分显示 "…"。 +func truncate(s string, maxW int) string { + if maxW <= 0 { + return "" + } + w := lipgloss.Width(s) + if w <= maxW { + return s + } + // 按 rune 截断 + runes := []rune(s) + total := 0 + for i, r := range runes { + rw := utf8.RuneLen(r) + if rw < 1 { + rw = 1 + } + if total+rw > maxW-1 { + return string(runes[:i]) + "…" + } + total += rw + } + return s +} + +// padRight 右侧补空格至 width(按可见列宽)。 +func padRight(s string, width int) string { + w := lipgloss.Width(s) + if w >= width { + return s + } + return s + strings.Repeat(" ", width-w) +} + +// wrapText 将文本按 maxW 宽度折行,返回行切片。 +func wrapText(s string, maxW int) []string { + if maxW <= 0 { + return []string{s} + } + var result []string + for _, line := range strings.Split(s, "\n") { + runes := []rune(line) + if len(runes) == 0 { + result = append(result, "") + continue + } + for len(runes) > 0 { + end := maxW + if end > len(runes) { + end = len(runes) + } + result = append(result, string(runes[:end])) + runes = runes[end:] + } + } + return result +} + +// dividerLine 生成全宽水平分隔线。 +func dividerLine(st Styles, width int) string { + if width <= 0 { + return "" + } + return st.Divider.Render(strings.Repeat("─", width)) +} + +// ─── 时间格式化 ─────────────────────────────────────────────────────────────── + +// timeAgo 将时间转换为"N分钟前"/"刚刚"等人性化描述。 +func timeAgo(t time.Time) string { + d := time.Since(t) + switch { + case d < time.Minute: + return "刚刚" + case d < time.Hour: + return fmt.Sprintf("%d 分钟前", int(d.Minutes())) + case d < 24*time.Hour: + return fmt.Sprintf("%d 小时前", int(d.Hours())) + default: + return t.Format("2006-01-02 15:04") + } +} + +// fmtDuration 格式化 Duration 为简短字符串(ms/s/min)。 +func fmtDuration(d time.Duration) string { + ms := d.Milliseconds() + if ms == 0 { + return "0ms" + } + if ms < 1000 { + return fmt.Sprintf("%dms", ms) + } + s := float64(ms) / 1000 + if s < 60 { + return fmt.Sprintf("%.1fs", s) + } + return fmt.Sprintf("%.0fm%.0fs", s/60, float64(int64(s)%60)) +} + +// ─── 布局工具 ───────────────────────────────────────────────────────────────── + +// renderHeader 渲染顶部双行标题栏。 +// 第一行:titleLeft(左)/ titleRight(右),紫色背景加粗 +// 第二行:infoLeft(左)/ infoRight(右),较暗色背景 +func renderHeader(st Styles, width int, titleLeft, titleRight, infoLeft, infoRight string) string { + w := width + if w < 1 { + w = 80 + } + // 第一行 + tl := " " + titleLeft + tr := titleRight + " " + tlW := lipgloss.Width(tl) + trW := lipgloss.Width(tr) + pad1 := w - tlW - trW + if pad1 < 0 { + pad1 = 0 + } + line1 := tl + strings.Repeat(" ", pad1) + tr + + // 第二行 + il := " " + infoLeft + ir := infoRight + " " + ilW := lipgloss.Width(il) + irW := lipgloss.Width(ir) + pad2 := w - ilW - irW + if pad2 < 0 { + pad2 = 0 + } + line2 := il + strings.Repeat(" ", pad2) + ir + + return st.Header.Width(w).Render(line1) + "\n" + + st.HeaderInfo.Width(w).Render(line2) +} + +// renderFooter 渲染底部状态栏(单行,深色背景)。 +func renderFooter(st Styles, width int, parts ...string) string { + w := width + if w < 1 { + w = 80 + } + var visible []string + for _, p := range parts { + if p != "" { + visible = append(visible, p) + } + } + line := " " + strings.Join(visible, " ") + return st.Footer.Width(w).Render(line) +} + +// dualColumnLayout 将左右两段文本排列为双栏,高度限定为 maxH。 +// 中间用竖线 │ 隔开。 +func dualColumnLayout(st Styles, left, right string, leftW, rightW, maxH int) string { + leftLines := strings.Split(left, "\n") + rightLines := strings.Split(right, "\n") + + if len(leftLines) > maxH { + leftLines = leftLines[:maxH] + } + if len(rightLines) > maxH { + rightLines = rightLines[:maxH] + } + for len(leftLines) < maxH { + leftLines = append(leftLines, "") + } + for len(rightLines) < maxH { + rightLines = append(rightLines, "") + } + + sep := st.Divider.Render("│") + var rows []string + for i := 0; i < maxH; i++ { + lLine := leftLines[i] + rLine := rightLines[i] + lW := lipgloss.Width(lLine) + if lW < leftW { + lLine += strings.Repeat(" ", leftW-lW) + } + rows = append(rows, lLine+sep+rLine) + } + return strings.Join(rows, "\n") +} + +// progressBar 生成进度条字符串(filled=已完成比例 0.0-1.0)。 +func progressBar(filled float64, width int) string { + if width <= 0 { + return "" + } + if filled < 0 { + filled = 0 + } + if filled > 1 { + filled = 1 + } + doneW := int(float64(width) * filled) + emptyW := width - doneW + done := strings.Repeat("█", doneW) + empty := strings.Repeat("░", emptyW) + return done + empty +} + +// wrapIndex 循环索引(保证 0 ≤ result < count)。 +func wrapIndex(idx, count int) int { + if count <= 0 { + return 0 + } + return ((idx % count) + count) % count +} + +// ─── 数据格式化 ─────────────────────────────────────────────────────────────── + +// maskAPIKey 遮蔽 API Key,只展示前 4 位和后 4 位。 +func maskAPIKey(key string) string { + r := []rune(key) + if len(r) <= 8 { + return strings.Repeat("•", len(r)) + } + return string(r[:4]) + "••••••••" + string(r[len(r)-4:]) +} + +// shortProtocol 将协议名缩短为仪表盘友好的短名。 +func shortProtocol(p string) string { + switch p { + case "openai-completions": + return "completions" + case "openai-responses": + return "responses" + case "anthropic-messages": + return "messages" + default: + return p + } +} + +// promptSummary 返回 Prompt 的简短摘要文本。 +func promptSummary(promptMode, promptText, promptFile string, promptLength int) string { + switch promptMode { + case "file": + return "文件: " + promptFile + case "generated": + return fmt.Sprintf("生成 %d 字符", promptLength) + default: + if promptText != "" { + r := []rune(promptText) + if len(r) > 20 { + return string(r[:20]) + "…" + } + return promptText + } + return "(未设置)" + } +} + +// boolLabel 将 bool 值转换为"开启"/"关闭"。 +func boolLabel(b bool) string { + if b { + return "开启" + } + return "关闭" +} + +// labelValue 渲染一个 label:value 对。 +func labelValue(st Styles, label, value string) string { + return st.Label.Render(label) + " " + st.Value.Render(value) +} diff --git a/internal/tui/pages/nav.go b/internal/tui/pages/nav.go new file mode 100644 index 0000000..765dcf7 --- /dev/null +++ b/internal/tui/pages/nav.go @@ -0,0 +1,53 @@ +// Package pages 包含 TUI 各页面的渲染与按键处理逻辑。 +// 本包只依赖 server 和 types 包,不 import 父包 tui,避免循环依赖。 +// 页面状态结构体、渲染函数、按键处理函数均定义于此。 +package pages + +import ( + tea "github.com/charmbracelet/bubbletea" + "github.com/yinxulai/ait/internal/server" + "github.com/yinxulai/ait/internal/types" +) + +// NavTarget 导航目标枚举。 +type NavTarget int + +const ( + NavNone NavTarget = iota + NavTaskList // 返回任务列表 + NavTaskDetail // 进入任务详情(需 TaskID) + NavWizard // 打开向导(EditTask == nil 为新建) + NavDashboard // 进入仪表盘(需 RunID + TaskID) + NavTurboDash // 进入 Turbo 仪表盘(需 RunID + TaskID) + NavReqDetail // 进入请求详情(需 ReqIndex) + NavQuit // 退出程序 +) + +// NavAction 页面处理函数的导航意图,由 root Model 统一处理。 +type NavAction struct { + To NavTarget + TaskID string + RunID server.RunID + ReqIndex int + EditTask *types.TaskDefinition // 向导编辑模式时非空;nil 表示新建 +} + +// Client 定义 pages 包对外依赖的操作集合。 +// tui.Client 实现此接口(Go duck typing)。 +type Client interface { + // 任务 CRUD + LoadTasksCmd() tea.Cmd + CreateTaskCmd(cfg server.TaskConfig, autoStart bool) tea.Cmd + UpdateTaskCmd(id string, cfg server.TaskConfig) tea.Cmd + DeleteTaskCmd(id string) tea.Cmd + CopyTaskCmd(id string) tea.Cmd + + // 运行管理 + StartRunCmd(taskID string) tea.Cmd + StopRunCmd(runID server.RunID) tea.Cmd + + // 历史 & 报告 + LoadHistoryCmd(taskID string, limit int) tea.Cmd + GetRunStateCmd(runID server.RunID) tea.Cmd + GenerateReportCmd(runID server.RunID, format server.ReportFormat) tea.Cmd +} diff --git a/internal/tui/pages/reqdetail.go b/internal/tui/pages/reqdetail.go new file mode 100644 index 0000000..11bf222 --- /dev/null +++ b/internal/tui/pages/reqdetail.go @@ -0,0 +1,272 @@ +package pages + +import ( + "fmt" + "strings" + + tea "github.com/charmbracelet/bubbletea" + "github.com/yinxulai/ait/internal/server" +) + +// ReqDetailState 请求详情页状态。 +type ReqDetailState struct { + RunID server.RunID + Requests []*server.RequestMetrics + Index int // 当前查看的请求索引 + ScrollY int // 输出区域滚动偏移 +} + +// NewReqDetailState 创建请求详情状态。 +func NewReqDetailState(runID server.RunID, reqs []*server.RequestMetrics, index int) *ReqDetailState { + return &ReqDetailState{ + RunID: runID, + Requests: reqs, + Index: index, + } +} + +// HandleReqDetailKey 处理请求详情页按键。 +func HandleReqDetailKey(s *ReqDetailState, msg tea.KeyMsg) (*ReqDetailState, NavAction) { + nav := NavAction{} + if s == nil { + return s, NavAction{To: NavDashboard} + } + + switch msg.String() { + case "left", "h": + if s.Index > 0 { + s.Index-- + } else { + s.Index = len(s.Requests) - 1 + } + s.ScrollY = 0 + + case "right", "l": + if s.Index < len(s.Requests)-1 { + s.Index++ + } else { + s.Index = 0 + } + s.ScrollY = 0 + + case "up", "k": + if s.ScrollY > 0 { + s.ScrollY-- + } + + case "down", "j": + s.ScrollY++ + + case "b", "esc", "backspace": + nav = NavAction{To: NavDashboard} + + case "q", "ctrl+c": + nav = NavAction{To: NavQuit} + } + + return s, nav +} + +// RenderReqDetail 渲染请求详情页。 +// +// 设计稿布局: +// +// ╔══ AIT 请求详情 - task-name #N ════════╗ +// ║ ◆ AIT 任务: name 请求 N/total ✓ 成功 ║ +// ╠══════════════════╦══════════════════════╣ +// ║ 性能指标 ║ 网络指标 ║ +// ║ 状态 ✓ 成功 ║ DNS 1.2ms ║ +// ║ 总耗时 245ms ║ TCP 连接 2.1ms ║ +// ║ TTFT 82ms ║ TLS 握手 8.4ms ║ +// ║ TPS 12.3/s ║ ║ +// ║ 输入Token 64 ║ ║ +// ║ 输出Token 128 ║ ║ +// ║ 缓存命中 100% ║ ║ +// ╠══════════════════╩══════════════════════╣ +// ║ 输入 (Prompt) ║ +// ║ ────────────────────────────────── ║ +// ║ 你好,介绍一下你自己。 ║ +// ╠═════════════════════════════════════════╣ +// ║ 输出 (Response) ║ +// ║ ────────────────────────────────── ║ +// ║ 你好!我是 Claude... ║ +// ║ (↑↓ 滚动查看完整内容) ║ +// ╠═════════════════════════════════════════╣ +// ║ [b/Esc] 返回仪表盘 [↑↓] 滚动 [←→] 上/下一条 ║ +// ╚═════════════════════════════════════════╝ +func RenderReqDetail(s *ReqDetailState, taskName string, st Styles, width, height int) string { + if s == nil || width == 0 { + return "加载中..." + } + if len(s.Requests) == 0 { + return "无请求数据" + } + + idx := s.Index + if idx < 0 { + idx = 0 + } + if idx >= len(s.Requests) { + idx = len(s.Requests) - 1 + } + r := s.Requests[idx] + + // ── Header ── + statusStr := st.Ok.Render("✓ 成功") + if !r.Success { + statusStr = st.ErrStyle.Render("✗ 失败") + } + header := renderHeader(st, width, + fmt.Sprintf("AIT 请求详情 - %s #%d", truncate(taskName, 20), idx+1), + statusStr, + fmt.Sprintf("◆ AIT 任务: %s 请求 %d / %d", + truncate(taskName, 20), idx+1, len(s.Requests)), + "", + ) + + // ── Context Bar ── + ctxBar := RenderContextBar(st, width, CtxBar_ReqDetail()) + + // ── Footer ── + footer := renderFooter(st, width, "[b/Esc] 返回仪表盘", "[↑↓] 滚动", "[←→] 上/下一条请求") + + // ── 计算高度 ── + headerH := 2 + ctxBarH := 0 + if ctxBar != "" { + ctxBarH = 1 + } + footerH := 1 + splitH := 9 + inputH := 5 + outputH := height - headerH - ctxBarH - footerH - splitH - inputH - 3 // -3 for dividers + if outputH < 4 { + outputH = 4 + } + + // ── 双栏(性能指标 ║ 网络指标)── + leftW := (width - 2) * 50 / 100 + rightW := width - 2 - leftW - 1 + leftContent := buildReqPerfPanel(r, st, splitH-1, leftW) + rightContent := buildReqNetworkPanel(r, st, splitH-1, rightW) + splitDiv := dividerLine(st, width) + split := dualColumnLayout(st, leftContent, rightContent, leftW, rightW, splitH) + + // ── 输入区 ── + inputSection := buildInputSection(r, st, width, inputH) + + // ── 输出区 ── + outputSection := buildOutputSection(r, s.ScrollY, st, width, outputH) + + parts := []string{header, splitDiv, split, splitDiv, inputSection, splitDiv, outputSection} + if ctxBar != "" { + parts = append(parts, ctxBar) + } + parts = append(parts, footer) + return strings.Join(parts, "\n") +} + +// buildReqPerfPanel 构建请求左侧性能指标面板。 +func buildReqPerfPanel(r *server.RequestMetrics, st Styles, maxH, width int) string { + var lines []string + lines = append(lines, " "+st.SectionHead.Render("性能指标")) + lines = append(lines, "") + + statusStr := st.Ok.Render("✓ 成功") + if !r.Success { + statusStr = st.ErrStyle.Render("✗ 失败") + } + lines = append(lines, " "+labelValue(st, "状态 ", statusStr)) + + if r.Success { + lines = append(lines, " "+labelValue(st, "总耗时 ", st.MetricVal.Render(fmtDuration(r.TotalTime)))) + lines = append(lines, " "+labelValue(st, "TTFT ", st.MetricVal.Render(fmtDuration(r.TTFT)))) + lines = append(lines, " "+labelValue(st, "输出TPS ", st.MetricVal.Render(fmt.Sprintf("%.1f tok/s", r.TPS)))) + lines = append(lines, " "+labelValue(st, "输入Token", fmt.Sprintf("%d", r.PromptTokens))) + lines = append(lines, " "+labelValue(st, "输出Token", fmt.Sprintf("%d", r.CompletionTokens))) + lines = append(lines, " "+labelValue(st, "缓存命中", fmt.Sprintf("%d tok (%.1f%%)", r.CachedTokens, r.CacheHitRate*100))) + } else { + if r.ErrorMessage != "" { + lines = append(lines, " "+st.ErrStyle.Render("错误: "+truncate(r.ErrorMessage, width-8))) + } + } + + for len(lines) < maxH { + lines = append(lines, "") + } + return strings.Join(lines[:maxH], "\n") +} + +// buildReqNetworkPanel 构建请求右侧网络指标面板。 +func buildReqNetworkPanel(r *server.RequestMetrics, st Styles, maxH, width int) string { + var lines []string + lines = append(lines, " "+st.SectionHead.Render("网络指标")) + lines = append(lines, "") + lines = append(lines, " "+labelValue(st, "DNS ", fmtDuration(r.DNSTime))) + lines = append(lines, " "+labelValue(st, "TCP 连接 ", fmtDuration(r.ConnectTime))) + lines = append(lines, " "+labelValue(st, "TLS 握手 ", fmtDuration(r.TLSTime))) + if r.TargetIP != "" { + lines = append(lines, " "+labelValue(st, "目标 IP ", r.TargetIP)) + } + + for len(lines) < maxH { + lines = append(lines, "") + } + return strings.Join(lines[:maxH], "\n") +} + +// buildInputSection 构建输入 (Prompt) 区域。 +func buildInputSection(r *server.RequestMetrics, st Styles, width, maxH int) string { + var lines []string + lines = append(lines, " "+st.SectionHead.Render("输入 (Prompt)")) + lines = append(lines, " "+dividerLine(st, width-2)) + + if r.PromptText == "" { + lines = append(lines, " "+st.Muted.Render("(未记录)")) + } else { + for _, l := range wrapText(r.PromptText, width-3) { + if len(lines) >= maxH-1 { + break + } + lines = append(lines, " "+l) + } + } + + for len(lines) < maxH { + lines = append(lines, "") + } + return strings.Join(lines[:maxH], "\n") +} + +// buildOutputSection 构建输出 (Response) 区域。 +func buildOutputSection(r *server.RequestMetrics, scrollY int, st Styles, width, maxH int) string { + var lines []string + lines = append(lines, " "+st.SectionHead.Render("输出 (Response)")) + lines = append(lines, " "+dividerLine(st, width-2)) + + if r.ResponseText == "" { + lines = append(lines, " "+st.Muted.Render("(未记录)")) + } else { + allLines := wrapText(r.ResponseText, width-3) + if scrollY >= len(allLines) { + scrollY = len(allLines) - 1 + } + if scrollY < 0 { + scrollY = 0 + } + for _, l := range allLines[scrollY:] { + if len(lines) >= maxH-1 { + break + } + lines = append(lines, " "+l) + } + if len(allLines) > maxH-3 { + lines = append(lines, " "+st.Muted.Render("(↑↓ 滚动查看完整内容)")) + } + } + + for len(lines) < maxH { + lines = append(lines, "") + } + return strings.Join(lines[:maxH], "\n") +} diff --git a/internal/tui/pages/styles.go b/internal/tui/pages/styles.go new file mode 100644 index 0000000..9639911 --- /dev/null +++ b/internal/tui/pages/styles.go @@ -0,0 +1,130 @@ +package pages + +import "github.com/charmbracelet/lipgloss" + +// Color palette +const ( + colorHeaderBg = lipgloss.Color("57") // electric indigo — header background + colorFooterBg = lipgloss.Color("235") // near-black footer background + colorCtxBarBg = lipgloss.Color("237") // slightly lighter than footer — context bar + colorPink = lipgloss.Color("205") // vivid hot pink/magenta — brand primary + colorCyan = lipgloss.Color("86") // bright aquamarine — table headers + colorPurple = lipgloss.Color("99") // medium violet + colorPurpleDim = lipgloss.Color("60") // slate purple — selected row bg + colorGreen = lipgloss.Color("78") // vivid spring green — ok/success + colorRed = lipgloss.Color("204") // vivid rose-red — error/fail + colorYellow = lipgloss.Color("221") // warm yellow — metric values + colorTeal = lipgloss.Color("111") // periwinkle-teal — labels + colorWhite = lipgloss.Color("255") // bright white + colorMuted = lipgloss.Color("245") // muted gray + colorGold = lipgloss.Color("214") // amber + colorHeaderFg = lipgloss.Color("212") // light pink — header right text + colorDivider = lipgloss.Color("238") // dim border gray +) + +// Styles 汇聚所有 TUI 样式,由 NewStyles() 初始化。 +type Styles struct { + Header lipgloss.Style + HeaderInfo lipgloss.Style + Footer lipgloss.Style + CtxBar lipgloss.Style + SectionHead lipgloss.Style + TableHead lipgloss.Style + TableRow lipgloss.Style + TableRowSel lipgloss.Style + Label lipgloss.Style + Value lipgloss.Style + Muted lipgloss.Style + Ok lipgloss.Style + ErrStyle lipgloss.Style + Key lipgloss.Style + MetricVal lipgloss.Style + Dialog lipgloss.Style + FieldActive lipgloss.Style + FieldIdle lipgloss.Style + Cursor lipgloss.Style + TagTurbo lipgloss.Style + TagStd lipgloss.Style + BtnPrimary lipgloss.Style + Divider lipgloss.Style +} + +// NewStyles 创建并返回默认样式集合。 +func NewStyles() Styles { + return Styles{ + Header: lipgloss.NewStyle(). + Background(colorHeaderBg). + Foreground(colorWhite). + Bold(true), + HeaderInfo: lipgloss.NewStyle(). + Background(colorHeaderBg). + Foreground(colorHeaderFg), + Footer: lipgloss.NewStyle(). + Background(colorFooterBg). + Foreground(colorMuted), + CtxBar: lipgloss.NewStyle(). + Background(colorCtxBarBg). + Foreground(colorWhite), + SectionHead: lipgloss.NewStyle(). + Foreground(colorPink). + Bold(true), + TableHead: lipgloss.NewStyle(). + Foreground(colorCyan). + Bold(true), + TableRow: lipgloss.NewStyle(). + Foreground(colorWhite), + TableRowSel: lipgloss.NewStyle(). + Background(colorPurpleDim). + Foreground(colorWhite). + Bold(true), + Label: lipgloss.NewStyle(). + Foreground(colorTeal). + Bold(true), + Value: lipgloss.NewStyle(). + Foreground(colorWhite), + Muted: lipgloss.NewStyle(). + Foreground(colorMuted), + Ok: lipgloss.NewStyle(). + Foreground(colorGreen). + Bold(true), + ErrStyle: lipgloss.NewStyle(). + Foreground(colorRed), + Key: lipgloss.NewStyle(). + Foreground(colorGold). + Bold(true), + MetricVal: lipgloss.NewStyle(). + Foreground(colorYellow). + Bold(true), + Dialog: lipgloss.NewStyle(). + Border(lipgloss.RoundedBorder()). + BorderForeground(colorPurple). + Padding(0, 1), + FieldActive: lipgloss.NewStyle(). + Background(lipgloss.Color("55")). + Foreground(colorWhite). + Padding(0, 1), + FieldIdle: lipgloss.NewStyle(). + Border(lipgloss.NormalBorder()). + BorderForeground(colorDivider). + Padding(0, 1), + Cursor: lipgloss.NewStyle(). + Foreground(colorPink). + Bold(true), + TagTurbo: lipgloss.NewStyle(). + Background(colorGold). + Foreground(colorDivider). + Bold(true). + Padding(0, 1), + TagStd: lipgloss.NewStyle(). + Background(colorPurple). + Foreground(colorWhite). + Padding(0, 1), + BtnPrimary: lipgloss.NewStyle(). + Background(colorPink). + Foreground(colorWhite). + Bold(true). + Padding(0, 2), + Divider: lipgloss.NewStyle(). + Foreground(colorDivider), + } +} diff --git a/internal/tui/pages/taskdetail.go b/internal/tui/pages/taskdetail.go new file mode 100644 index 0000000..56fd76b --- /dev/null +++ b/internal/tui/pages/taskdetail.go @@ -0,0 +1,291 @@ +package pages + +import ( + "fmt" + "strings" + + tea "github.com/charmbracelet/bubbletea" + "github.com/yinxulai/ait/internal/types" +) + +// TaskDetailState 任务详情页状态。 +type TaskDetailState struct { + Task types.TaskDefinition + History []types.TaskRunSummary + // LatestExpanded 控制最近一次运行是否展开(运行结束后自动置 true) + LatestExpanded bool +} + +// NewTaskDetailState 创建初始任务详情状态。 +func NewTaskDetailState(task types.TaskDefinition) *TaskDetailState { + return &TaskDetailState{Task: task} +} + +// HandleTaskDetailKey 处理任务详情页按键。 +func HandleTaskDetailKey(s *TaskDetailState, msg tea.KeyMsg, client Client) (*TaskDetailState, tea.Cmd, NavAction) { + nav := NavAction{} + switch msg.String() { + case "left", "esc", "b": + nav = NavAction{To: NavTaskList} + + case "enter", "r": + return s, client.StartRunCmd(s.Task.ID), nav + + case "e": + t := s.Task + nav = NavAction{To: NavWizard, EditTask: &t} + + case "y": + return s, client.CopyTaskCmd(s.Task.ID), nav + + case "d": + return s, client.DeleteTaskCmd(s.Task.ID), nav + + case "q", "ctrl+c": + nav = NavAction{To: NavQuit} + } + return s, nil, nav +} + +// RenderTaskDetail 渲染任务详情页。 +// +// 设计稿布局(全宽单列): +// +// ╔══ AIT 任务详情 ─ name ══════════════╗ +// ║ ◆ AIT 任务 ID: xxx 更新: xxx 刚刚 ║ +// ╠══════════════════════════════════════╣ +// ║ 配置摘要 ║ +// ║ 协议 xxx 接口 xxx ║ +// ║ 模型 xxx 模式 xxx 并发 N 请求 N ║ +// ║ 超时 xxx 流式 开启 Prompt xxx ║ +// ╠══════════════════════════════════════╣ +// ║ 最近运行 ▼ 2026-05-16 ✓ 完成 100请求 ║ +// ║ ── 指标表格 ──────────────────────── ║ +// ╠══════════════════════════════════════╣ +// ║ 历史运行记录 ║ +// ║ ── 历史列表 ─────────────────────── ║ +// ╠══════════════════════════════════════╣ +// ║ [r] 生成报告 [c] 复制摘要 ... ║ ← context bar +// ╠══════════════════════════════════════╣ +// ║ [b/Esc] 返回列表 ◆ AIT v0.1 ║ +// ╚══════════════════════════════════════╝ +func RenderTaskDetail(s *TaskDetailState, st Styles, width, height int) string { + if width == 0 { + return "加载中..." + } + t := s.Task + inp := t.Input + + // ── Header ── + updatedStr := timeAgo(t.UpdatedAt) + header := renderHeader(st, width, + "AIT 任务详情 ─ "+truncate(t.Name, 30), + "", + fmt.Sprintf("◆ AIT 任务 ID: %s 更新: %s %s", + truncate(t.ID, 10), t.UpdatedAt.Format("2006-01-02 15:04"), updatedStr), + "", + ) + + // ── Context Bar ── + hasHistory := len(s.History) > 0 + var cbItems []ContextBarItem + if hasHistory { + cbItems = CtxBar_TaskDetail_HasHistory() + } else { + cbItems = CtxBar_TaskDetail_NoHistory() + } + ctxBar := RenderContextBar(st, width, cbItems) + + // ── Footer ── + footer := renderFooter(st, width, "[b/Esc] 返回列表", "[r] 运行", "[e] 编辑", "◆ AIT v0.1") + + // ── 内容区高度 ── + headerH := 2 + ctxBarH := 0 + if ctxBar != "" { + ctxBarH = 1 + } + footerH := 1 + contentH := height - headerH - ctxBarH - footerH + if contentH < 6 { + contentH = 6 + } + + // ── 内容构建 ── + content := buildTaskDetailContent(s, st, t, inp, width, contentH) + + parts := []string{header, content} + if ctxBar != "" { + parts = append(parts, ctxBar) + } + parts = append(parts, footer) + return strings.Join(parts, "\n") +} + +// buildTaskDetailContent 构建任务详情内容区。 +func buildTaskDetailContent(s *TaskDetailState, st Styles, t types.TaskDefinition, inp types.Input, width, maxH int) string { + innerW := width - 2 + if innerW < 10 { + innerW = 10 + } + + var lines []string + + // ─── 配置摘要 ───────────────────────────────────────────── + lines = append(lines, " "+st.SectionHead.Render("配置摘要")) + lines = append(lines, " "+dividerLine(st, innerW-2)) + + // 行1:协议 + 接口 + proto := inp.NormalizedProtocol() + endpoint := truncate(inp.ResolvedEndpointURL(), innerW-30) + lines = append(lines, " "+ + st.Label.Render("协议")+" "+st.Value.Render(proto)+ + " "+st.Label.Render("接口")+" "+st.Value.Render(endpoint)) + + // 行2:模型 + 模式 + 并发 + 请求 + modeStr := "标准模式" + if inp.Turbo { + modeStr = "Turbo 模式" + } + if inp.Turbo { + tc := inp.TurboConfig + lines = append(lines, " "+ + st.Label.Render("模型")+" "+st.Value.Render(inp.Model)+ + " "+st.Label.Render("模式")+" "+st.Value.Render(modeStr)+ + " "+st.Label.Render("并发爬坡")+" "+ + st.Value.Render(fmt.Sprintf("%d → %d 步进+%d 每级%d请求", + tc.InitConcurrency, tc.MaxConcurrency, tc.StepSize, tc.LevelRequests))) + } else { + lines = append(lines, " "+ + st.Label.Render("模型")+" "+st.Value.Render(inp.Model)+ + " "+st.Label.Render("模式")+" "+st.Value.Render(modeStr)+ + " "+st.Label.Render("并发")+" "+st.Value.Render(fmt.Sprintf("%d", inp.Concurrency))+ + " "+st.Label.Render("请求")+" "+st.Value.Render(fmt.Sprintf("%d", inp.Count))) + } + + // 行3:超时 + 流式 + Prompt + prompt := promptSummary(inp.PromptMode, inp.PromptText, inp.PromptFile, inp.PromptLength) + lines = append(lines, " "+ + st.Label.Render("超时")+" "+st.Value.Render(fmtDuration(inp.Timeout))+ + " "+st.Label.Render("流式")+" "+st.Value.Render(boolLabel(inp.Stream))+ + " "+st.Label.Render("Prompt")+" "+st.Value.Render(truncate(prompt, innerW-50))) + + lines = append(lines, "") + + // ─── 最近运行 ────────────────────────────────────────────── + if len(s.History) > 0 { + latest := s.History[0] + statusStr := "✓ 完成" + if latest.Status != "completed" { + statusStr = "✗ " + latest.Status + } + elapsed := latest.FinishedAt.Sub(latest.StartedAt) + expandMark := "▼" + if !s.LatestExpanded { + expandMark = "▶" + } + lines = append(lines, fmt.Sprintf(" %s %s 最近运行 %s %s %d 请求 耗时 %s", + st.SectionHead.Render("最近运行"), + st.Ok.Render(expandMark), + latest.StartedAt.Format("2006-01-02 15:04"), + st.Ok.Render(statusStr), + 0, // 请求总数(运行摘要中需要补充该字段,暂用 0) + fmtDuration(elapsed), + )) + lines = append(lines, " "+dividerLine(st, innerW-2)) + + if s.LatestExpanded && len(lines) < maxH-10 { + // 指标表格 + lines = append(lines, " "+st.TableHead.Render( + padRight("指标", 16)+padRight("最小值", 10)+padRight("平均值", 10)+padRight("标准差", 10)+"最大值")) + lines = append(lines, " "+st.Divider.Render(strings.Repeat("─", innerW-2))) + + if len(lines) < maxH { + lines = append(lines, buildMetricRow(st, "TTFT", + fmtDuration(latest.AvgTTFT), fmtDuration(latest.AvgTTFT), "─", "─")) + } + if len(lines) < maxH { + lines = append(lines, buildMetricRow(st, "输出 TPS", + "─", fmt.Sprintf("%.1f", latest.AvgTPS), "─", "─")) + } + if len(lines) < maxH { + lines = append(lines, buildMetricRow(st, "成功率", + "─", fmt.Sprintf("%.1f%%", latest.SuccessRate*100), "─", "─")) + } + if latest.CacheHitRate > 0 && len(lines) < maxH { + lines = append(lines, buildMetricRow(st, "缓存命中率", + "─", fmt.Sprintf("%.1f%%", latest.CacheHitRate*100), "─", "─")) + } + if latest.ErrorSummary != "" && len(lines) < maxH { + lines = append(lines, " "+st.ErrStyle.Render("错误 "+truncate(latest.ErrorSummary, innerW-10))) + } + } + lines = append(lines, "") + } + + // ─── 历史运行记录 ────────────────────────────────────────── + if len(lines) < maxH-4 { + lines = append(lines, " "+st.SectionHead.Render("历史运行记录")) + lines = append(lines, " "+st.TableHead.Render( + padRight("时间", 20)+padRight("模式", 8)+padRight("成功率", 8)+ + padRight("TTFT", 10)+padRight("TPS", 10)+"Cache")) + lines = append(lines, " "+st.Divider.Render(strings.Repeat("─", innerW-2))) + + for _, run := range s.History { + if len(lines) >= maxH-1 { + break + } + statusIcon := st.Ok.Render("✓") + if run.Status != "completed" { + statusIcon = st.ErrStyle.Render("✗") + } + modeShort := "标准" + if run.Mode == "turbo" { + modeShort = "Turbo" + } + cacheStr := "─" + if run.CacheHitRate > 0 { + cacheStr = fmt.Sprintf("%.1f%%", run.CacheHitRate*100) + } + row := fmt.Sprintf(" %s %s %s %s %s %s %s", + statusIcon, + run.StartedAt.Format("2006-01-02 15:04"), + padRight(modeShort, 6), + padRight(fmt.Sprintf("%.1f%%", run.SuccessRate*100), 7), + padRight(fmtDuration(run.AvgTTFT), 9), + padRight(fmt.Sprintf("%.1f", run.AvgTPS), 9), + cacheStr, + ) + lines = append(lines, " "+st.TableRow.Render(row)) + } + } + + // 补齐剩余高度 + for len(lines) < maxH { + lines = append(lines, "") + } + + return strings.Join(lines, "\n") +} + +// buildMetricRow 构建指标表格一行。 +func buildMetricRow(st Styles, name, minV, avgV, stdV, maxV string) string { + return " " + st.Label.Render(padRight(name, 16)) + + st.Value.Render(padRight(minV, 10)) + + st.MetricVal.Render(padRight(avgV, 10)) + + st.Muted.Render(padRight(stdV, 10)) + + st.Value.Render(maxV) +} + +// TaskDetailFromMsg 从消息中提取 TaskDetailState 的帮助函数, +// 供 model.go 在 HistoryLoadedMsg 处理时使用。 +func UpdateTaskDetailHistory(s *TaskDetailState, history []types.TaskRunSummary, autoExpand bool) *TaskDetailState { + if s == nil { + return s + } + s.History = history + if autoExpand && len(history) > 0 { + s.LatestExpanded = true + } + return s +} diff --git a/internal/tui/pages/tasklist.go b/internal/tui/pages/tasklist.go new file mode 100644 index 0000000..5956e07 --- /dev/null +++ b/internal/tui/pages/tasklist.go @@ -0,0 +1,348 @@ +package pages + +import ( + "fmt" + "strings" + "time" + + tea "github.com/charmbracelet/bubbletea" + "github.com/charmbracelet/lipgloss" + "github.com/yinxulai/ait/internal/server" + "github.com/yinxulai/ait/internal/types" +) + +// TaskListState 任务列表页状态。 +type TaskListState struct { + Tasks []types.TaskDefinition + Selected int + // 运行中任务的进度(runID -> RunState 快照,由 Model 注入) + ActiveRuns map[string]*server.RunState // taskID -> RunState +} + +// NewTaskListState 创建初始任务列表状态。 +func NewTaskListState() *TaskListState { + return &TaskListState{ + ActiveRuns: make(map[string]*server.RunState), + } +} + +// CurrentTask 返回当前选中的任务。 +func (s *TaskListState) CurrentTask() (types.TaskDefinition, bool) { + if len(s.Tasks) == 0 || s.Selected < 0 || s.Selected >= len(s.Tasks) { + return types.TaskDefinition{}, false + } + return s.Tasks[s.Selected], true +} + +// IsTaskRunning 判断某任务是否正在运行。 +func (s *TaskListState) IsTaskRunning(taskID string) bool { + if rs, ok := s.ActiveRuns[taskID]; ok { + return rs != nil && rs.Status == server.RunStatusRunning + } + return false +} + +// latestRunAt 返回任务列表中最新一次运行时间(用于 header 显示)。 +func (s *TaskListState) latestRunAt() *time.Time { + var latest *time.Time + for _, t := range s.Tasks { + if t.LastRunAt != nil { + if latest == nil || t.LastRunAt.After(*latest) { + latest = t.LastRunAt + } + } + } + return latest +} + +// HandleTaskListKey 处理任务列表页的按键,返回 tea.Cmd 和导航意图。 +func HandleTaskListKey(s *TaskListState, msg tea.KeyMsg, client Client) (*TaskListState, tea.Cmd, NavAction) { + nav := NavAction{} + + switch msg.String() { + case "up", "k": + if s.Selected > 0 { + s.Selected-- + } + case "down", "j": + if s.Selected < len(s.Tasks)-1 { + s.Selected++ + } + + case "a": + nav = NavAction{To: NavWizard, EditTask: nil} + + case "e": + if t, ok := s.CurrentTask(); ok { + nav = NavAction{To: NavWizard, EditTask: &t} + } + + case "y": + if t, ok := s.CurrentTask(); ok { + return s, client.CopyTaskCmd(t.ID), nav + } + + case "d": + if t, ok := s.CurrentTask(); ok { + return s, client.DeleteTaskCmd(t.ID), nav + } + + case "enter": + if t, ok := s.CurrentTask(); ok { + if s.IsTaskRunning(t.ID) { + if rs, ok := s.ActiveRuns[t.ID]; ok { + nav = NavAction{To: NavDashboard, TaskID: t.ID, RunID: rs.RunID} + } + } else { + nav = NavAction{To: NavTaskDetail, TaskID: t.ID} + } + } + + case "r": + if t, ok := s.CurrentTask(); ok { + return s, client.StartRunCmd(t.ID), nav + } + + case "s": + if t, ok := s.CurrentTask(); ok { + if rs, ok := s.ActiveRuns[t.ID]; ok && rs != nil { + return s, client.StopRunCmd(rs.RunID), nav + } + } + + case "q", "ctrl+c": + nav = NavAction{To: NavQuit} + } + + return s, nil, nav +} + +// RenderTaskList 渲染任务列表页。 +// +// 设计稿布局: +// +// ╔══ AIT 任务中心 ══════════════╗ +// ║ ◆ AIT 已保存任务: N 最近运行: xxx ║ +// ╠══════════════════════════════╣ +// ║ 任务名称 模式 协议 上次结果 ║ +// ║ ─────────────────────────── ║ +// ║ ▶ ◉ name 标准 responses ✓ 98.5% ║ +// ║ model 并发10 请求200 ◉ 47/100 ║ +// ║ ║ +// ╠══════════════════════════════╣ +// ║ [Enter] 详情 [a] 新建 ... ║ ← context bar +// ╠══════════════════════════════╣ +// ║ [↑↓] 选择 [q] 退出 ◆ AIT ║ +// ╚══════════════════════════════╝ +func RenderTaskList(s *TaskListState, st Styles, width, height int) string { + if width == 0 { + return "加载中..." + } + + // ── Header ── + lastRunStr := "" + if lt := s.latestRunAt(); lt != nil { + lastRunStr = "最近运行: " + lt.Format("2006-01-02 15:04") + } + header := renderHeader(st, width, + "AIT 任务中心", + "", + fmt.Sprintf("◆ AIT 已保存任务: %d %s", len(s.Tasks), lastRunStr), + "", + ) + + // ── Context Bar ── + var cbItems []ContextBarItem + if t, ok := s.CurrentTask(); ok { + if s.IsTaskRunning(t.ID) { + cbItems = CtxBar_TaskList_Running() + } else { + cbItems = CtxBar_TaskList_Normal() + } + } + ctxBar := RenderContextBar(st, width, cbItems) + + // ── Footer ── + footer := renderFooter(st, width, "[↑↓] 选择", "[a] 新建", "[y] 复制", "[q] 退出", "◆ AIT v0.1") + + // ── 可用内容高度 ── + headerH := 2 + ctxBarH := 0 + if ctxBar != "" { + ctxBarH = 1 + } + footerH := 1 + contentH := height - headerH - ctxBarH - footerH + if contentH < 4 { + contentH = 4 + } + + // ── 内容区 ── + content := buildTaskListContent(s, st, width, contentH) + + parts := []string{header, content} + if ctxBar != "" { + parts = append(parts, ctxBar) + } + parts = append(parts, footer) + return strings.Join(parts, "\n") +} + +// buildTaskListContent 构建任务列表内容区(含表头 + 任务条目)。 +func buildTaskListContent(s *TaskListState, st Styles, width, maxH int) string { + innerW := width - 2 + if innerW < 20 { + innerW = 20 + } + + var lines []string + + // 表头行 + nameW := 28 + modeW := 8 + protoW := 14 + resultW := innerW - nameW - modeW - protoW - 4 + if resultW < 10 { + resultW = 10 + } + + header := st.TableHead.Render( + " " + padRight("任务名称", nameW) + + padRight("模式", modeW) + + padRight("协议", protoW) + + "上次结果", + ) + lines = append(lines, header) + lines = append(lines, " "+st.Divider.Render(strings.Repeat("─", innerW-1))) + + if len(s.Tasks) == 0 { + lines = append(lines, "") + lines = append(lines, " "+st.Muted.Render("暂无任务 按 [a] 新建第一个任务")) + return strings.Join(lines, "\n") + } + + for i, t := range s.Tasks { + if len(lines) >= maxH { + break + } + + isRunning := s.IsTaskRunning(t.ID) + isSel := i == s.Selected + rs := s.ActiveRuns[t.ID] + + // ── 指示符和运行中标记 ── + cursor := " " + if isSel { + cursor = "▶ " + } + runMark := " " + if isRunning { + runMark = st.Ok.Render("◉") + " " + } + prefix := cursor + runMark + + // ── 模式标签 ── + var modeTag string + if t.Input.Turbo { + modeTag = st.TagTurbo.Render("Turbo") + } else { + modeTag = st.TagStd.Render("标准 ") + } + modeTagW := lipgloss.Width(modeTag) + modePad := modeW - modeTagW + if modePad < 0 { + modePad = 0 + } + modeCol := modeTag + strings.Repeat(" ", modePad) + + // ── 协议 ── + proto := padRight(shortProtocol(t.Input.NormalizedProtocol()), protoW) + + // ── 上次结果 ── + lastResult := st.Muted.Render("从未运行") + if t.LastRunSummary != nil { + pct := t.LastRunSummary.SuccessRate + if t.Input.Turbo { + if t.LastRunSummary.MaxStableConcurrency > 0 { + lastResult = st.Ok.Render(fmt.Sprintf("★ 并发%d", t.LastRunSummary.MaxStableConcurrency)) + } + } else { + switch { + case pct >= 99: + lastResult = st.Ok.Render(fmt.Sprintf("✓ %.1f%%", pct)) + case pct >= 90: + lastResult = st.MetricVal.Render(fmt.Sprintf("%.1f%%", pct)) + default: + lastResult = st.ErrStyle.Render(fmt.Sprintf("✗ %.1f%%", pct)) + } + } + } + + // ── 任务名称(裁剪)── + name := truncate(t.Name, nameW) + namePad := nameW - lipgloss.Width(name) + if namePad < 0 { + namePad = 0 + } + nameCol := name + strings.Repeat(" ", namePad) + + // ── 第一行 ── + prefixW := lipgloss.Width(prefix) + row1Content := nameCol + modeCol + proto + lastResult + var row1 string + if isSel { + row1 = st.TableRowSel.Render(prefix+row1Content) + strings.Repeat(" ", max(0, width-prefixW-lipgloss.Width(row1Content)-2)) + } else { + row1 = " " + runMark + row1Content + } + lines = append(lines, row1) + + // ── 第二行(模型 + 参数 + 实时进度)── + if len(lines) < maxH { + indent := " " // 5 空格缩进(对齐任务名) + var params string + if t.Input.Turbo { + tc := t.Input.TurboConfig + params = fmt.Sprintf("%s %d→%d 步进+%d", + truncate(t.Input.Model, 12), + tc.InitConcurrency, tc.MaxConcurrency, tc.StepSize) + if t.LastRunSummary != nil { + params += fmt.Sprintf(" 上次: 峰值 TPS %.1f", t.LastRunSummary.AvgTPS) + } + } else { + params = fmt.Sprintf("%s 并发%d 请求%d", + truncate(t.Input.Model, 12), + t.Input.Concurrency, t.Input.Count) + } + + // 实时进度 + if isRunning && rs != nil { + prog := fmt.Sprintf(" %s %d/%d 成功率 %.1f%%", + st.Ok.Render("◉"), rs.DoneReqs, rs.TotalReqs, rs.SuccessRate*100) + params += prog + } + + row2 := indent + st.Muted.Render(truncate(params, width-7)) + lines = append(lines, row2) + } + + // ── 空行分隔 ── + if i < len(s.Tasks)-1 && len(lines) < maxH-1 { + lines = append(lines, "") + } + } + + // 补齐剩余行 + for len(lines) < maxH { + lines = append(lines, "") + } + + return strings.Join(lines, "\n") +} + +func max(a, b int) int { + if a > b { + return a + } + return b +} diff --git a/internal/tui/pages/turbodash.go b/internal/tui/pages/turbodash.go new file mode 100644 index 0000000..4f87a08 --- /dev/null +++ b/internal/tui/pages/turbodash.go @@ -0,0 +1,341 @@ +package pages + +import ( + "fmt" + "strings" + + tea "github.com/charmbracelet/bubbletea" + "github.com/yinxulai/ait/internal/server" + "github.com/yinxulai/ait/internal/types" +) + +// TurboDashState Turbo 模式仪表盘页状态。 +type TurboDashState struct { + RunID server.RunID + TaskID string + EventCh <-chan server.Event + CancelFn server.CancelFunc + RunState *server.RunState + LevelSel int // 选中的级别索引(-1 = 无选中) +} + +// NewTurboDashState 创建 Turbo 仪表盘初始状态。 +func NewTurboDashState(runID server.RunID, taskID string) *TurboDashState { + return &TurboDashState{ + RunID: runID, + TaskID: taskID, + LevelSel: -1, + } +} + +// IsRunning 判断是否仍在运行。 +func (d *TurboDashState) IsRunning() bool { + if d == nil || d.RunState == nil { + return false + } + return d.RunState.Status == server.RunStatusRunning +} + +// HandleTurboDashKey 处理 Turbo 仪表盘按键。 +func HandleTurboDashKey(d *TurboDashState, msg tea.KeyMsg, client Client) (*TurboDashState, tea.Cmd, NavAction) { + nav := NavAction{} + if d == nil { + return d, nil, NavAction{To: NavTaskList} + } + + var levels []types.TurboLevelResult + if d.RunState != nil { + levels = d.RunState.Levels + } + + switch msg.String() { + case "up", "k": + if len(levels) == 0 { + break + } + if d.LevelSel <= 0 { + d.LevelSel = len(levels) - 1 + } else { + d.LevelSel-- + } + + case "down", "j": + if len(levels) == 0 { + break + } + if d.LevelSel < len(levels)-1 { + d.LevelSel++ + } else { + d.LevelSel = 0 + } + + case "enter": + // 进入该级别的请求列表(使用标准仪表盘的请求详情,此处导航到 ReqDetail) + if d.LevelSel >= 0 && d.LevelSel < len(levels) { + nav = NavAction{To: NavReqDetail, ReqIndex: 0} + } + + case "s": + if d.IsRunning() { + return d, client.StopRunCmd(d.RunID), nav + } + + case "m": + // 手动标记极限并停止 + if d.IsRunning() { + return d, client.StopRunCmd(d.RunID), nav + } + + case "b", "esc": + if d.CancelFn != nil { + d.CancelFn() + } + d.EventCh = nil + d.CancelFn = nil + nav = NavAction{To: NavTaskList} + + case "r": + if d.RunState != nil && !d.IsRunning() { + return d, client.GenerateReportCmd(d.RunID, server.ReportFormatJSON), nav + } + + case "q", "ctrl+c": + nav = NavAction{To: NavQuit} + } + + return d, nil, nav +} + +// RenderTurboDash 渲染 Turbo 模式运行仪表盘。 +// +// 设计稿布局: +// +// ╔══ AIT Turbo 探测 ─ task-name ══════════╗ +// ║ ◆ AIT model · protocol · 1→50 步进+2 ║ +// ╠══════════════════╦══════════════════════╣ +// ║ 任务参数 ║ 当前级别实时指标 [并发=N] ║ +// ║ ... ║ ... ║ +// ╠══════════════════╩══════════════════════╣ +// ║ 进度 ████░░ N/30 当前并发 N 总进度: 已完成N/~N级 ║ +// ╠═════════════════════════════════════════╣ +// ║ 级别列表 ║ +// ║ 并发 成功率 TPS TTFT Cache 总耗时 结论 ║ +// ║ ... ║ +// ╠═════════════════════════════════════════╣ +// ║ context bar ║ +// ╠═════════════════════════════════════════╣ +// ║ [s] 停止 [b] 后台 [m] 标记极限 [q] 退出 ║ +// ╚═════════════════════════════════════════╝ +func RenderTurboDash(d *TurboDashState, taskName string, st Styles, width, height int) string { + if d == nil || width == 0 { + return "加载中..." + } + rs := d.RunState + + // ── Header ── + statusStr := "探测中" + if rs != nil && rs.Status != server.RunStatusRunning { + statusStr = st.Muted.Render(string(rs.Status)) + } else { + statusStr = st.Ok.Render("探测中") + } + + subtitle := "─" + if rs != nil && len(rs.Levels) > 0 { + curLevel := rs.CurrentLevel + subtitle = fmt.Sprintf("◆ AIT %s · 当前并发: %d 已完成 %d 级", + "─", curLevel, len(rs.Levels)) + } + + header := renderHeader(st, width, + "AIT Turbo 探测 ─ "+truncate(taskName, 22), + statusStr, + subtitle, + "", + ) + + // ── Context Bar ── + var cbItems []ContextBarItem + if d.LevelSel >= 0 && rs != nil && d.LevelSel < len(rs.Levels) { + cbItems = CtxBar_TurboDash_Sel() + } else { + cbItems = CtxBar_TurboDash_NoSel() + } + ctxBar := RenderContextBar(st, width, cbItems) + + // ── Footer ── + footer := renderFooter(st, width, "[s] 停止", "[b] 后台运行", "[m] 标记极限", "[r] 提前报告", "[q] 退出") + + // ── 计算高度 ── + headerH := 2 + ctxBarH := 0 + if ctxBar != "" { + ctxBarH = 1 + } + footerH := 1 + splitH := 9 + progressH := 1 + divH := 3 + levelListH := height - headerH - ctxBarH - footerH - splitH - progressH - divH + if levelListH < 3 { + levelListH = 3 + } + + // ── 双栏(任务参数 ║ 当前级别指标)── + leftW := (width - 2) * 45 / 100 + rightW := width - 2 - leftW - 1 + leftContent := buildTurboDashParams(rs, st, splitH-1, leftW) + rightContent := buildTurboDashMetrics(rs, st, splitH-1, rightW) + splitDiv := dividerLine(st, width) + split := dualColumnLayout(st, leftContent, rightContent, leftW, rightW, splitH) + + // ── 进度条 ── + progressLine := buildTurboProgressLine(rs, st, width) + + // ── 级别列表 ── + levelDiv := dividerLine(st, width) + levelList := buildLevelList(d, rs, st, width, levelListH) + + parts := []string{header, splitDiv, split, splitDiv, progressLine, levelDiv, levelList} + if ctxBar != "" { + parts = append(parts, ctxBar) + } + parts = append(parts, footer) + return strings.Join(parts, "\n") +} + +// buildTurboDashParams 构建 Turbo 仪表盘左侧任务参数面板。 +func buildTurboDashParams(rs *server.RunState, st Styles, maxH, width int) string { + var lines []string + lines = append(lines, " "+st.SectionHead.Render("任务参数")) + lines = append(lines, "") + + if rs == nil { + lines = append(lines, " "+st.Muted.Render("等待数据...")) + } else { + if rs.TurboResult != nil { + tc := rs.TurboResult.Config + lines = append(lines, " "+labelValue(st, "爬坡 ", fmt.Sprintf("%d→%d 步进+%d", tc.InitConcurrency, tc.MaxConcurrency, tc.StepSize))) + lines = append(lines, " "+labelValue(st, "每级 ", fmt.Sprintf("%d 请求", tc.LevelRequests))) + lines = append(lines, " "+labelValue(st, "停止 ", fmt.Sprintf("成功率 < %.0f%%", tc.MinSuccessRate*100))) + } + } + + for len(lines) < maxH { + lines = append(lines, "") + } + return strings.Join(lines[:maxH], "\n") +} + +// buildTurboDashMetrics 构建 Turbo 仪表盘右侧当前级别实时指标面板。 +func buildTurboDashMetrics(rs *server.RunState, st Styles, maxH, width int) string { + var lines []string + + curLevel := 0 + if rs != nil { + curLevel = rs.CurrentLevel + } + lines = append(lines, " "+st.SectionHead.Render(fmt.Sprintf("当前级别实时指标 [并发 = %d]", curLevel))) + lines = append(lines, "") + + if rs == nil { + lines = append(lines, " "+st.Muted.Render("等待数据...")) + } else { + lines = append(lines, " "+labelValue(st, "成功率 ", st.MetricVal.Render(fmt.Sprintf("%.1f%%", rs.SuccessRate*100)))) + lines = append(lines, " "+labelValue(st, "TPS ", st.MetricVal.Render(fmt.Sprintf("%.1f tok/s", rs.AvgTPS)))) + lines = append(lines, " "+labelValue(st, "TTFT ", st.MetricVal.Render(fmtDuration(rs.AvgTTFT)))) + lines = append(lines, " "+labelValue(st, "Cache ", st.MetricVal.Render(fmt.Sprintf("%.1f%%", rs.CacheHitRate*100)))) + } + + for len(lines) < maxH { + lines = append(lines, "") + } + return strings.Join(lines[:maxH], "\n") +} + +// buildTurboProgressLine 构建 Turbo 模式进度条行。 +func buildTurboProgressLine(rs *server.RunState, st Styles, width int) string { + if rs == nil { + return " 进度 " + st.Muted.Render("等待中...") + } + total := rs.TotalReqs + done := rs.DoneReqs + var ratio float64 + if total > 0 { + ratio = float64(done) / float64(total) + } + barW := 15 + barRendered := st.Ok.Render(strings.Repeat("█", int(ratio*float64(barW)))) + + st.Muted.Render(strings.Repeat("░", barW-int(ratio*float64(barW)))) + + levelTotal := len(rs.Levels) + line := fmt.Sprintf(" 进度 %s %d/%d 当前并发 %d 总进度: 已完成 %d/~? 级", + barRendered, done, total, rs.CurrentLevel, levelTotal) + return line +} + +// buildLevelList 构建 Turbo 级别列表区域。 +func buildLevelList(d *TurboDashState, rs *server.RunState, st Styles, width, maxH int) string { + var lines []string + lines = append(lines, " "+st.SectionHead.Render("级别列表")) + + if rs == nil || len(rs.Levels) == 0 { + lines = append(lines, " "+st.Muted.Render("等待第一个级别完成...")) + for len(lines) < maxH { + lines = append(lines, "") + } + return strings.Join(lines, "\n") + } + + // 表头 + lines = append(lines, " "+st.TableHead.Render( + padRight("并发", 6)+padRight("成功率", 8)+padRight("TPS", 10)+ + padRight("TTFT", 10)+padRight("Cache", 8)+padRight("总耗时", 9)+"结论")) + lines = append(lines, " "+st.Divider.Render(strings.Repeat("─", width-2))) + + for i, lv := range rs.Levels { + if len(lines) >= maxH { + break + } + isSel := i == d.LevelSel + + conclusion := st.Ok.Render("✓ 稳定") + if !lv.Stable { + conclusion = st.ErrStyle.Render("✗ 降级") + } + // 当前进行中的级别 + isCurrent := (i == len(rs.Levels)-1) && rs.Status == server.RunStatusRunning + if isCurrent { + conclusion = st.MetricVal.Render("🔄 进行中") + } + + row := fmt.Sprintf(" %s%s%s%s%s%s%s", + padRight(fmt.Sprintf("%d", lv.Concurrency), 6), + padRight(fmt.Sprintf("%.1f%%", lv.SuccessRate*100), 8), + padRight(fmt.Sprintf("%.1f", lv.AvgTPS), 10), + padRight(fmtDuration(lv.AvgTTFT), 10), + padRight(fmt.Sprintf("%.1f%%", lv.CacheHitRate*100), 8), + padRight(fmtDuration(lv.AvgTotalTime), 9), + conclusion, + ) + + cursorStr := " " + if isSel { + cursorStr = "▶ " + } + + var rendered string + if isSel { + rendered = st.TableRowSel.Render(cursorStr+row) + + strings.Repeat(" ", max(0, width-len([]rune(cursorStr+row))-2)) + } else { + rendered = " " + st.TableRow.Render(row) + } + lines = append(lines, rendered) + } + + for len(lines) < maxH { + lines = append(lines, "") + } + return strings.Join(lines[:maxH], "\n") +} diff --git a/internal/tui/pages/wizard.go b/internal/tui/pages/wizard.go new file mode 100644 index 0000000..99f9422 --- /dev/null +++ b/internal/tui/pages/wizard.go @@ -0,0 +1,727 @@ +package pages + +import ( + "fmt" + "strconv" + "strings" + + tea "github.com/charmbracelet/bubbletea" + "github.com/charmbracelet/lipgloss" + "github.com/yinxulai/ait/internal/server" + "github.com/yinxulai/ait/internal/types" +) + +// Prompt 模式常量 +const ( + PromptModeText = "text" + PromptModeFile = "file" + PromptModeGenerated = "generated" +) + +// wizardStep 步骤枚举 +type wizardStep int + +const ( + wizardStep1 wizardStep = 0 // Step 1/3 · 基本信息 + wizardStep2 wizardStep = 1 // Step 2/3 · 测试参数 + wizardStep3 wizardStep = 2 // Step 3/3 · 确认保存 +) + +// WizardState 向导的完整状态。 +type WizardState struct { + Step wizardStep + EditingID string // 非空 = 编辑模式 + + // Step 1: 基本信息 + Name string + Protocol string // types.Protocol* 常量 + EndpointURL string + APIKey string + Model string + + // Step 2: 测试参数 + Turbo bool + Stream bool + + // 标准模式参数 + Concurrency int + Count int + Timeout int // 秒 + + // Turbo 模式参数 + InitConcurrency int + MaxConcurrency int + StepSize int + LevelRequests int + MinSuccessRate float64 // 百分比,如 90 + + // Prompt 配置 + PromptMode string + PromptText string + PromptFile string + PromptLength int + + // 当前活跃字段索引(Tab 切换) + FieldIndex int +} + +// NewWizardState 创建新建任务向导状态(使用默认值)。 +func NewWizardState() *WizardState { + return &WizardState{ + Step: wizardStep1, + Protocol: types.ProtocolOpenAICompletions, + Concurrency: 10, + Count: 100, + Timeout: 30, + InitConcurrency: 1, + MaxConcurrency: 50, + StepSize: 2, + LevelRequests: 30, + MinSuccessRate: 90, + PromptMode: PromptModeText, + } +} + +// NewWizardStateEdit 创建编辑任务向导状态(预填任务数据)。 +func NewWizardStateEdit(t *types.TaskDefinition) *WizardState { + if t == nil { + return NewWizardState() + } + inp := t.Input + tc := inp.TurboConfig + wz := &WizardState{ + Step: wizardStep1, + EditingID: t.ID, + Name: t.Name, + Protocol: types.NormalizeProtocol(inp.Protocol), + EndpointURL: inp.EndpointURL, + APIKey: inp.ApiKey, + Model: inp.Model, + Turbo: inp.Turbo, + Stream: inp.Stream, + Concurrency: inp.Concurrency, + Count: inp.Count, + Timeout: int(inp.Timeout.Seconds()), + InitConcurrency: tc.InitConcurrency, + MaxConcurrency: tc.MaxConcurrency, + StepSize: tc.StepSize, + LevelRequests: tc.LevelRequests, + MinSuccessRate: tc.MinSuccessRate * 100, // 转为百分比 + PromptMode: inp.PromptMode, + PromptText: inp.PromptText, + PromptFile: inp.PromptFile, + PromptLength: inp.PromptLength, + } + if wz.PromptMode == "" { + wz.PromptMode = PromptModeText + } + if wz.Concurrency == 0 { + wz.Concurrency = 10 + } + if wz.Count == 0 { + wz.Count = 100 + } + if wz.Timeout == 0 { + wz.Timeout = 30 + } + if wz.MinSuccessRate == 0 { + wz.MinSuccessRate = 90 + } + if wz.StepSize == 0 { + wz.StepSize = 2 + } + if wz.LevelRequests == 0 { + wz.LevelRequests = 30 + } + if wz.MaxConcurrency == 0 { + wz.MaxConcurrency = 50 + } + return wz +} + +// BuildTaskConfig 将向导状态转换为 server.TaskConfig。 +func (wz *WizardState) BuildTaskConfig() server.TaskConfig { + turboRate := wz.MinSuccessRate / 100 // 转回小数 + if turboRate <= 0 { + turboRate = 0.9 + } + return server.TaskConfig{ + Name: wz.Name, + Input: types.Input{ + Protocol: wz.Protocol, + EndpointURL: wz.EndpointURL, + ApiKey: wz.APIKey, + Model: wz.Model, + Concurrency: wz.Concurrency, + Count: wz.Count, + Stream: wz.Stream, + Turbo: wz.Turbo, + TurboConfig: types.TurboConfig{ + InitConcurrency: wz.InitConcurrency, + MaxConcurrency: wz.MaxConcurrency, + StepSize: wz.StepSize, + LevelRequests: wz.LevelRequests, + MinSuccessRate: turboRate, + }, + PromptMode: wz.PromptMode, + PromptText: wz.PromptText, + PromptFile: wz.PromptFile, + PromptLength: wz.PromptLength, + }, + } +} + +// fieldDef 向导字段定义 +type fieldDef struct { + kind fieldKind + label string + // 获取当前值(字符串形式) + get func(wz *WizardState) string + // 设置文本值 + set func(wz *WizardState, v string) + // 枚举/布尔切换 + toggle func(wz *WizardState, forward bool) +} + +type fieldKind int + +const ( + fieldText fieldKind = iota // 自由文本输入 + fieldNumber // 数字 + fieldBool // 布尔开关 + fieldEnum // 枚举循环 +) + +// step1Fields 返回步骤1的字段列表。 +func step1Fields() []fieldDef { + protocols := []string{ + types.ProtocolOpenAICompletions, + types.ProtocolOpenAIResponses, + types.ProtocolAnthropicMessages, + } + return []fieldDef{ + { + kind: fieldText, label: "任务名称", + get: func(wz *WizardState) string { return wz.Name }, + set: func(wz *WizardState, v string) { wz.Name = v }, + }, + { + kind: fieldEnum, label: "协议类型", + get: func(wz *WizardState) string { return wz.Protocol }, + toggle: func(wz *WizardState, forward bool) { + idx := 0 + for i, p := range protocols { + if p == wz.Protocol { + idx = i + break + } + } + if forward { + idx = (idx + 1) % len(protocols) + } else { + idx = (idx - 1 + len(protocols)) % len(protocols) + } + wz.Protocol = protocols[idx] + // 清空 endpoint,使其跟随协议默认值 + wz.EndpointURL = "" + }, + }, + { + kind: fieldText, label: "接口地址", + get: func(wz *WizardState) string { + if wz.EndpointURL != "" { + return wz.EndpointURL + } + return types.DefaultEndpointURL(wz.Protocol) + }, + set: func(wz *WizardState, v string) { wz.EndpointURL = v }, + }, + { + kind: fieldText, label: "API 密钥", + get: func(wz *WizardState) string { return wz.APIKey }, + set: func(wz *WizardState, v string) { wz.APIKey = v }, + }, + { + kind: fieldText, label: "测试模型", + get: func(wz *WizardState) string { return wz.Model }, + set: func(wz *WizardState, v string) { wz.Model = v }, + }, + } +} + +// step2Fields 返回步骤2的字段列表(根据 Turbo 模式动态变化)。 +func step2Fields(turbo bool) []fieldDef { + fields := []fieldDef{ + { + kind: fieldBool, label: "测试模式", + get: func(wz *WizardState) string { + if wz.Turbo { + return "Turbo 模式" + } + return "标准模式" + }, + toggle: func(wz *WizardState, _ bool) { wz.Turbo = !wz.Turbo }, + }, + } + + if !turbo { + fields = append(fields, + fieldDef{ + kind: fieldNumber, label: "并发数", + get: func(wz *WizardState) string { return strconv.Itoa(wz.Concurrency) }, + set: func(wz *WizardState, v string) { + if n, err := strconv.Atoi(v); err == nil && n > 0 { + wz.Concurrency = n + } + }, + }, + fieldDef{ + kind: fieldNumber, label: "请求总数", + get: func(wz *WizardState) string { return strconv.Itoa(wz.Count) }, + set: func(wz *WizardState, v string) { + if n, err := strconv.Atoi(v); err == nil && n > 0 { + wz.Count = n + } + }, + }, + fieldDef{ + kind: fieldNumber, label: "超时(秒)", + get: func(wz *WizardState) string { return strconv.Itoa(wz.Timeout) }, + set: func(wz *WizardState, v string) { + if n, err := strconv.Atoi(v); err == nil && n > 0 { + wz.Timeout = n + } + }, + }, + fieldDef{ + kind: fieldBool, label: "流式模式", + get: func(wz *WizardState) string { return boolLabel(wz.Stream) }, + toggle: func(wz *WizardState, _ bool) { wz.Stream = !wz.Stream }, + }, + ) + } else { + fields = append(fields, + fieldDef{ + kind: fieldNumber, label: "初始并发", + get: func(wz *WizardState) string { return strconv.Itoa(wz.InitConcurrency) }, + set: func(wz *WizardState, v string) { + if n, err := strconv.Atoi(v); err == nil && n > 0 { + wz.InitConcurrency = n + } + }, + }, + fieldDef{ + kind: fieldNumber, label: "最大并发", + get: func(wz *WizardState) string { return strconv.Itoa(wz.MaxConcurrency) }, + set: func(wz *WizardState, v string) { + if n, err := strconv.Atoi(v); err == nil && n > 0 { + wz.MaxConcurrency = n + } + }, + }, + fieldDef{ + kind: fieldNumber, label: "步进值", + get: func(wz *WizardState) string { return strconv.Itoa(wz.StepSize) }, + set: func(wz *WizardState, v string) { + if n, err := strconv.Atoi(v); err == nil && n > 0 { + wz.StepSize = n + } + }, + }, + fieldDef{ + kind: fieldNumber, label: "每级请求数", + get: func(wz *WizardState) string { return strconv.Itoa(wz.LevelRequests) }, + set: func(wz *WizardState, v string) { + if n, err := strconv.Atoi(v); err == nil && n > 0 { + wz.LevelRequests = n + } + }, + }, + fieldDef{ + kind: fieldNumber, label: "最低成功率%", + get: func(wz *WizardState) string { return fmt.Sprintf("%.0f", wz.MinSuccessRate) }, + set: func(wz *WizardState, v string) { + if f, err := strconv.ParseFloat(v, 64); err == nil && f > 0 && f <= 100 { + wz.MinSuccessRate = f + } + }, + }, + ) + } + + // Prompt 字段(共用) + promptModes := []string{PromptModeText, PromptModeFile, PromptModeGenerated} + fields = append(fields, + fieldDef{ + kind: fieldEnum, label: "输入方式", + get: func(wz *WizardState) string { + switch wz.PromptMode { + case PromptModeFile: + return "文件" + case PromptModeGenerated: + return "按长度生成" + default: + return "直接输入" + } + }, + toggle: func(wz *WizardState, forward bool) { + idx := 0 + for i, m := range promptModes { + if m == wz.PromptMode { + idx = i + break + } + } + if forward { + idx = (idx + 1) % len(promptModes) + } else { + idx = (idx - 1 + len(promptModes)) % len(promptModes) + } + wz.PromptMode = promptModes[idx] + }, + }, + ) + + // 根据 prompt 模式添加对应字段(在渲染时动态决定) + fields = append(fields, + fieldDef{ + kind: fieldText, label: "内容", + get: func(wz *WizardState) string { + switch wz.PromptMode { + case PromptModeFile: + return wz.PromptFile + case PromptModeGenerated: + return strconv.Itoa(wz.PromptLength) + default: + return wz.PromptText + } + }, + set: func(wz *WizardState, v string) { + switch wz.PromptMode { + case PromptModeFile: + wz.PromptFile = v + case PromptModeGenerated: + if n, err := strconv.Atoi(v); err == nil && n > 0 { + wz.PromptLength = n + } + default: + wz.PromptText = v + } + }, + }, + ) + return fields +} + +// HandleWizardKey 处理向导按键。 +func HandleWizardKey(wz *WizardState, msg tea.KeyMsg, client Client) (*WizardState, tea.Cmd, NavAction) { + nav := NavAction{} + if wz == nil { + return wz, nil, NavAction{To: NavTaskList} + } + + // 当前步骤的字段列表 + var fields []fieldDef + switch wz.Step { + case wizardStep1: + fields = step1Fields() + case wizardStep2: + fields = step2Fields(wz.Turbo) + case wizardStep3: + // Step 3 只有两个动作:保存、保存并运行 + switch msg.String() { + case "esc": + wz.Step = wizardStep2 + wz.FieldIndex = 0 + case "enter": + // 保存任务 + cfg := wz.BuildTaskConfig() + var cmd tea.Cmd + if wz.EditingID != "" { + cmd = client.UpdateTaskCmd(wz.EditingID, cfg) + } else { + cmd = client.CreateTaskCmd(cfg, true) // autoStart + } + nav = NavAction{To: NavTaskList} + return wz, cmd, nav + case "r": + // 保存并运行(强制启动,忽略干扰检测) + cfg := wz.BuildTaskConfig() + var cmd tea.Cmd + if wz.EditingID != "" { + cmd = client.UpdateTaskCmd(wz.EditingID, cfg) + } else { + cmd = client.CreateTaskCmd(cfg, true) + } + nav = NavAction{To: NavTaskList} + return wz, cmd, nav + case "q", "ctrl+c": + nav = NavAction{To: NavQuit} + } + return wz, nil, nav + } + + maxField := len(fields) - 1 + + switch msg.String() { + case "esc": + if wz.Step == wizardStep1 { + nav = NavAction{To: NavTaskList} + } else { + wz.Step-- + wz.FieldIndex = 0 + } + + case "tab", "down", "j": + if wz.FieldIndex < maxField { + wz.FieldIndex++ + } + + case "shift+tab", "up", "k": + if wz.FieldIndex > 0 { + wz.FieldIndex-- + } + + case "left": + if wz.FieldIndex < len(fields) { + f := fields[wz.FieldIndex] + if f.toggle != nil { + f.toggle(wz, false) + // 如果切换了 turbo 模式,重置 fieldIndex + if f.label == "测试模式" { + wz.FieldIndex = 0 + } + } + } + + case "right": + if wz.FieldIndex < len(fields) { + f := fields[wz.FieldIndex] + if f.toggle != nil { + f.toggle(wz, true) + if f.label == "测试模式" { + wz.FieldIndex = 0 + } + } + } + + case "enter": + if wz.FieldIndex == maxField && int(wz.Step) < 2 { + wz.Step++ + wz.FieldIndex = 0 + } else if wz.FieldIndex < maxField { + wz.FieldIndex++ + } + + case "backspace": + if wz.FieldIndex < len(fields) { + f := fields[wz.FieldIndex] + if f.set != nil && f.kind == fieldText { + v := f.get(wz) + r := []rune(v) + if len(r) > 0 { + f.set(wz, string(r[:len(r)-1])) + } + } + } + + case "q", "ctrl+c": + nav = NavAction{To: NavQuit} + + default: + // 字符输入 + if len(msg.Runes) > 0 && wz.FieldIndex < len(fields) { + f := fields[wz.FieldIndex] + if f.set != nil && (f.kind == fieldText || f.kind == fieldNumber) { + f.set(wz, f.get(wz)+string(msg.Runes)) + } + } + } + + return wz, nil, nav +} + +// RenderWizard 渲染三步弹窗向导(overlay 覆盖在后台页面上)。 +func RenderWizard(wz *WizardState, bgView string, st Styles, width, height int) string { + if wz == nil { + return bgView + } + + // 暗化背景 + bgLines := strings.Split(bgView, "\n") + for i, line := range bgLines { + bgLines[i] = st.Muted.Render(line) + } + + // 弹窗尺寸 + dialogW := width - 8 + if dialogW > 72 { + dialogW = 72 + } + if dialogW < 40 { + dialogW = 40 + } + + var dialogLines []string + + stepTitles := []string{"1/3 · 基本信息", "2/3 · 测试参数", "3/3 · 确认保存"} + stepTitle := stepTitles[int(wz.Step)] + isEdit := wz.EditingID != "" + action := "新建任务" + if isEdit { + action = "编辑任务" + } + dialogLines = append(dialogLines, st.SectionHead.Render(fmt.Sprintf(" %s %s", action, stepTitle))) + dialogLines = append(dialogLines, "") + + switch wz.Step { + case wizardStep1: + fields := step1Fields() + for i, f := range fields { + dialogLines = append(dialogLines, renderWizardField(st, f, wz, i == wz.FieldIndex, dialogW-4)) + dialogLines = append(dialogLines, "") + } + dialogLines = append(dialogLines, dividerLine(st, dialogW-4)) + hintStyle := st.Muted + dialogLines = append(dialogLines, hintStyle.Render(" [Tab] 下一项 [↑↓] 切换协议 [Enter] 下一步 [Esc] 取消")) + + case wizardStep2: + fields := step2Fields(wz.Turbo) + for i, f := range fields { + dialogLines = append(dialogLines, renderWizardField(st, f, wz, i == wz.FieldIndex, dialogW-4)) + dialogLines = append(dialogLines, "") + } + dialogLines = append(dialogLines, dividerLine(st, dialogW-4)) + dialogLines = append(dialogLines, st.Muted.Render(" [Tab] 下一项 [←→] 切换模式 [Enter] 下一步 [Esc] 返回")) + + case wizardStep3: + dialogLines = append(dialogLines, renderStep3Summary(wz, st, dialogW-4)...) + dialogLines = append(dialogLines, "") + dialogLines = append(dialogLines, dividerLine(st, dialogW-4)) + dialogLines = append(dialogLines, st.Muted.Render(" [Enter] 保存任务 [r] 保存并运行 [Esc] 返回修改")) + } + + // 构建弹窗框 + innerLines := dialogLines + boxedLines := make([]string, len(innerLines)) + for i, l := range innerLines { + lW := lipgloss.Width(l) + pad := dialogW - 4 - lW + if pad < 0 { + pad = 0 + } + boxedLines[i] = " " + l + strings.Repeat(" ", pad) + } + + // 用 lipgloss rounded border 包裹 + inner := strings.Join(boxedLines, "\n") + box := st.Dialog.Width(dialogW).Render(inner) + + // 将弹窗叠加在背景中间 + boxLines := strings.Split(box, "\n") + startRow := (height - len(boxLines)) / 2 + if startRow < 0 { + startRow = 0 + } + startCol := (width - dialogW) / 2 + if startCol < 0 { + startCol = 0 + } + + for i, boxLine := range boxLines { + row := startRow + i + if row >= len(bgLines) { + bgLines = append(bgLines, strings.Repeat(" ", width)) + } + bgLine := []rune(bgLines[row]) + boxRunes := []rune(boxLine) + // 替换对应列 + for j, r := range boxRunes { + col := startCol + j + if col < len(bgLine) { + bgLine[col] = r + } + } + bgLines[row] = string(bgLine) + } + + return strings.Join(bgLines, "\n") +} + +// renderWizardField 渲染向导的一个字段行。 +func renderWizardField(st Styles, f fieldDef, wz *WizardState, active bool, maxW int) string { + label := padRight(f.label, 12) + var valueStr string + + if f.get != nil { + valueStr = f.get(wz) + } + + // API key 遮蔽显示 + if f.label == "API 密钥" && valueStr != "" { + valueStr = maskAPIKey(valueStr) + } + + var renderedValue string + if active { + if f.kind == fieldEnum || f.kind == fieldBool { + renderedValue = st.Ok.Render("● " + valueStr) + } else { + renderedValue = st.FieldActive.Width(maxW - 14).Render(valueStr + "█") // 光标 + } + } else { + if f.kind == fieldEnum || f.kind == fieldBool { + renderedValue = st.Muted.Render("○ " + valueStr) + } else { + renderedValue = st.FieldIdle.Width(maxW - 14).Render(valueStr) + } + } + + return " " + st.Label.Render(label) + " " + renderedValue +} + +// renderStep3Summary 渲染步骤3的确认内容。 +func renderStep3Summary(wz *WizardState, st Styles, innerW int) []string { + var lines []string + addRow := func(label, value string) { + lines = append(lines, " "+st.Label.Render(padRight(label, 12))+" "+st.Value.Render(value)) + } + + addRow("任务名称", wz.Name) + addRow("协议", wz.Protocol) + endpointDisplay := wz.EndpointURL + if endpointDisplay == "" { + endpointDisplay = types.DefaultEndpointURL(wz.Protocol) + } + addRow("接口地址", truncate(endpointDisplay, innerW-20)) + addRow("API 密钥", maskAPIKey(wz.APIKey)) + addRow("测试模型", wz.Model) + + if wz.Turbo { + addRow("测试模式", "Turbo 模式") + addRow("并发爬坡", fmt.Sprintf("%d → %d 步进 +%d 每级 %d 请求", + wz.InitConcurrency, wz.MaxConcurrency, wz.StepSize, wz.LevelRequests)) + addRow("停止条件", fmt.Sprintf("成功率 < %.0f%%", wz.MinSuccessRate)) + } else { + addRow("测试模式", "标准模式") + addRow("并发数", strconv.Itoa(wz.Concurrency)) + addRow("请求总数", strconv.Itoa(wz.Count)) + addRow("超时", fmt.Sprintf("%ds", wz.Timeout)) + addRow("流式模式", boolLabel(wz.Stream)) + } + + promptDesc := wz.PromptText + if wz.PromptMode == PromptModeFile { + promptDesc = "文件: " + wz.PromptFile + } else if wz.PromptMode == PromptModeGenerated { + promptDesc = fmt.Sprintf("生成 %d 字符", wz.PromptLength) + } + addRow("Prompt", truncate(promptDesc, innerW-20)+fmt.Sprintf(" (长度: %d)", len([]rune(wz.PromptText)))) + + lines = append(lines, "") + lines = append(lines, " "+st.Muted.Render("保存任务到 ~/.ait/tasks.json [✓]")) + lines = append(lines, "") + lines = append(lines, " "+st.BtnPrimary.Render("▶ 保存任务")) + + return lines +} diff --git a/internal/tui/styles.go b/internal/tui/styles.go deleted file mode 100644 index f0f61e0..0000000 --- a/internal/tui/styles.go +++ /dev/null @@ -1,184 +0,0 @@ -package tui - -import "github.com/charmbracelet/lipgloss" - -// Color palette — inspired by the Lip Gloss demo aesthetic: -// electric-purple header, vivid hot-pink brand, deep-plum panels, aqua accents. -const ( - colorHeaderBg = lipgloss.Color("57") // electric indigo — header background - colorFooterBg = lipgloss.Color("235") // near-black footer background - colorPink = lipgloss.Color("205") // vivid hot pink/magenta — brand primary - colorCyan = lipgloss.Color("86") // bright aquamarine — table headers - colorPurple = lipgloss.Color("99") // medium violet — badge, border - colorPurpleDim = lipgloss.Color("60") // slate purple — selected row bg - colorPanelBg = lipgloss.Color("55") // deep plum — content card / panel bg - colorGold = lipgloss.Color("214") // amber — status badge - colorGreen = lipgloss.Color("78") // vivid spring green — ok/success - colorRed = lipgloss.Color("204") // vivid rose-red — error/fail - colorYellow = lipgloss.Color("221") // warm yellow — metric values - colorTeal = lipgloss.Color("111") // periwinkle-teal — labels - colorWhite = lipgloss.Color("255") // bright white - colorMuted = lipgloss.Color("245") // muted gray - colorDimBorder = lipgloss.Color("238") // dim border gray - colorFieldBg = lipgloss.Color("55") // deep plum — active field bg - colorDark = lipgloss.Color("235") // near-black text on light badge - colorHeaderFg = lipgloss.Color("212") // light pink — header right text -) - -type styles struct { - header lipgloss.Style - footer lipgloss.Style - sectionHead lipgloss.Style - tableHead lipgloss.Style - tableRow lipgloss.Style - tableRowSel lipgloss.Style - label lipgloss.Style - value lipgloss.Style - muted lipgloss.Style - ok lipgloss.Style - errStyle lipgloss.Style - key lipgloss.Style - metricVal lipgloss.Style - dialog lipgloss.Style - fieldActive lipgloss.Style - fieldIdle lipgloss.Style - cursor lipgloss.Style - // Badge styles - badge lipgloss.Style // AIT brand badge (purple) - badgeAlt lipgloss.Style // alternate badge (gold) - tagTurbo lipgloss.Style // "TURBO" mode inline tag - tagStd lipgloss.Style // "标准" mode inline tag - // Log entry markers - logOk lipgloss.Style - logErr lipgloss.Style - // Wizard step indicators - stepDone lipgloss.Style - stepActive lipgloss.Style - stepTodo lipgloss.Style - // Primary action button - btnPrimary lipgloss.Style - // Divider / decorative line - divider lipgloss.Style - // Content panel (deep-plum background card, like the demo's purple paragraphs) - panel lipgloss.Style -} - -func newStyles() styles { - return styles{ - // Header: deep indigo-purple background, white foreground - header: lipgloss.NewStyle(). - Background(colorHeaderBg). - Foreground(colorWhite), - // Footer: dark near-black background - footer: lipgloss.NewStyle(). - Background(colorFooterBg). - Foreground(colorMuted), - // Section headings: hot pink, bold - sectionHead: lipgloss.NewStyle(). - Foreground(colorPink). - Bold(true), - // Table column headers: vivid cyan, bold - tableHead: lipgloss.NewStyle(). - Foreground(colorCyan). - Bold(true), - // Normal table rows: bright white - tableRow: lipgloss.NewStyle(). - Foreground(colorWhite), - // Selected table row: dim-purple bg, white text, bold - tableRowSel: lipgloss.NewStyle(). - Background(colorPurpleDim). - Foreground(colorWhite). - Bold(true), - // Property labels: soft teal, bold - label: lipgloss.NewStyle(). - Foreground(colorTeal). - Bold(true), - // Property values: bright white - value: lipgloss.NewStyle(). - Foreground(colorWhite), - // Secondary/muted text: gray - muted: lipgloss.NewStyle(). - Foreground(colorMuted), - // Success indicator: bright green, bold - ok: lipgloss.NewStyle(). - Foreground(colorGreen). - Bold(true), - // Error indicator: red, bold - errStyle: lipgloss.NewStyle(). - Foreground(colorRed). - Bold(true), - // Keyboard shortcut keys: hot pink, bold - key: lipgloss.NewStyle(). - Foreground(colorPink). - Bold(true), - // Metric numeric values: yellow, bold - metricVal: lipgloss.NewStyle(). - Foreground(colorYellow). - Bold(true), - // Dialog/modal box: rounded border in hot pink, padded - dialog: lipgloss.NewStyle(). - Border(lipgloss.RoundedBorder()). - BorderForeground(colorPink). - Padding(1, 2), - // Active wizard field: dark purple-blue bg, white text - fieldActive: lipgloss.NewStyle(). - Background(colorFieldBg). - Foreground(colorWhite), - // Idle wizard field: muted gray text - fieldIdle: lipgloss.NewStyle(). - Foreground(colorMuted), - // Cursor/selection arrow: hot pink, bold - cursor: lipgloss.NewStyle(). - Foreground(colorPink). - Bold(true), - // AIT brand badge: purple bg, white text, padded - badge: lipgloss.NewStyle(). - Background(colorPurple). - Foreground(colorWhite). - Bold(true). - Padding(0, 1), - // Alternate badge: gold bg, dark text - badgeAlt: lipgloss.NewStyle(). - Background(colorGold). - Foreground(colorDark). - Bold(true). - Padding(0, 1), - // TURBO mode tag: pink bg, dark text - tagTurbo: lipgloss.NewStyle(). - Foreground(colorPink). - Bold(true), - // Standard mode tag: cyan text - tagStd: lipgloss.NewStyle(). - Foreground(colorCyan), - // Log ok: green - logOk: lipgloss.NewStyle(). - Foreground(colorGreen), - // Log error: red - logErr: lipgloss.NewStyle(). - Foreground(colorRed), - // Wizard step done: green checkmark - stepDone: lipgloss.NewStyle(). - Foreground(colorGreen), - // Wizard step active: pink, bold - stepActive: lipgloss.NewStyle(). - Foreground(colorPink). - Bold(true), - // Wizard step todo: dim - stepTodo: lipgloss.NewStyle(). - Foreground(colorMuted), - // Primary action button: hot-pink bg, dark text, padded - btnPrimary: lipgloss.NewStyle(). - Background(colorPink). - Foreground(colorDark). - Bold(true). - Padding(0, 2), - // Divider line: dim gray - divider: lipgloss.NewStyle(). - Foreground(colorDimBorder), - // Content panel: deep-plum background, white text, padded (like demo's purple paragraphs) - panel: lipgloss.NewStyle(). - Background(colorPanelBg). - Foreground(colorWhite). - Padding(1, 2), - } -} From 3e1eeaf0aefe7d90b8df4502dfc620cf9d261adf Mon Sep 17 00:00:00 2001 From: Alain Date: Sun, 17 May 2026 01:33:32 +0800 Subject: [PATCH 07/52] refactor: remove display package and related tests - Deleted the entire display package, including the main display logic and test cases. - Removed dependencies on tablewriter and progressbar libraries. - Updated the TUI model and pages to ensure proper rendering and layout adjustments. - Introduced a new wrapPanel function to handle panel rendering with borders. - Adjusted various TUI pages to utilize the new panel wrapping for consistent styling. --- Makefile | 29 ++- cmd/tpg/tpg.go | 352 -------------------------- cmd/tpg/tpg_test.go | 421 ------------------------------- internal/display/display.go | 333 ------------------------ internal/display/display_test.go | 118 --------- internal/tui/model.go | 30 ++- internal/tui/pages/dashboard.go | 36 +-- internal/tui/pages/helpers.go | 8 + internal/tui/pages/reqdetail.go | 31 ++- internal/tui/pages/styles.go | 8 + internal/tui/pages/taskdetail.go | 7 +- internal/tui/pages/tasklist.go | 7 +- internal/tui/pages/turbodash.go | 35 +-- 13 files changed, 122 insertions(+), 1293 deletions(-) delete mode 100644 cmd/tpg/tpg.go delete mode 100644 cmd/tpg/tpg_test.go delete mode 100644 internal/display/display.go delete mode 100644 internal/display/display_test.go diff --git a/Makefile b/Makefile index a9a1c47..5403282 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ # 项目配置 -BINARIES=ait tpg +BINARY=ait BIN_DIR=bin # Go 相关变量 @@ -10,7 +10,10 @@ GOTEST=$(GOCMD) test GOMOD=$(GOCMD) mod # 构建标志 -LDFLAGS=-ldflags "-w -s" +VERSION ?= $(shell git describe --tags --always --dirty 2>/dev/null || echo "dev") +GIT_COMMIT ?= $(shell git rev-parse --short HEAD 2>/dev/null || echo "unknown") +BUILD_TIME ?= $(shell date -u +"%Y-%m-%dT%H:%M:%SZ") +LDFLAGS=-ldflags "-w -s -X main.Version=$(VERSION) -X main.GitCommit=$(GIT_COMMIT) -X main.BuildTime=$(BUILD_TIME)" BUILD_FLAGS=-trimpath $(LDFLAGS) ## help: 显示此帮助信息 @@ -19,15 +22,25 @@ help: @echo 'Usage:' @sed -n 's/^##//p' ${MAKEFILE_LIST} | column -t -s ':' | sed -e 's/^/ /' -## build: 构建所有二进制文件 +## build: 构建当前平台二进制 .PHONY: build build: - @echo "正在构建所有二进制文件..." + @echo "正在构建 $(BINARY)..." @mkdir -p $(BIN_DIR) - @for binary in $(BINARIES); do \ - echo "构建 $$binary..."; \ - $(GOBUILD) $(BUILD_FLAGS) -o $(BIN_DIR)/$$binary ./cmd/$$binary/; \ - done + $(GOBUILD) $(BUILD_FLAGS) -o $(BIN_DIR)/$(BINARY) ./cmd/$(BINARY)/ + +## build-all: 交叉编译所有平台 +.PHONY: build-all +build-all: + @echo "正在交叉编译所有平台..." + @mkdir -p $(BIN_DIR) + GOOS=linux GOARCH=amd64 $(GOBUILD) $(BUILD_FLAGS) -o $(BIN_DIR)/$(BINARY)-linux-amd64 ./cmd/$(BINARY)/ + GOOS=linux GOARCH=arm64 $(GOBUILD) $(BUILD_FLAGS) -o $(BIN_DIR)/$(BINARY)-linux-arm64 ./cmd/$(BINARY)/ + GOOS=linux GOARCH=386 $(GOBUILD) $(BUILD_FLAGS) -o $(BIN_DIR)/$(BINARY)-linux-386 ./cmd/$(BINARY)/ + GOOS=linux GOARCH=arm $(GOBUILD) $(BUILD_FLAGS) -o $(BIN_DIR)/$(BINARY)-linux-arm ./cmd/$(BINARY)/ + GOOS=darwin GOARCH=amd64 $(GOBUILD) $(BUILD_FLAGS) -o $(BIN_DIR)/$(BINARY)-darwin-amd64 ./cmd/$(BINARY)/ + GOOS=darwin GOARCH=arm64 $(GOBUILD) $(BUILD_FLAGS) -o $(BIN_DIR)/$(BINARY)-darwin-arm64 ./cmd/$(BINARY)/ + GOOS=windows GOARCH=amd64 $(GOBUILD) $(BUILD_FLAGS) -o $(BIN_DIR)/$(BINARY)-windows-amd64.exe ./cmd/$(BINARY)/ ## test: 运行所有测试 .PHONY: test diff --git a/cmd/tpg/tpg.go b/cmd/tpg/tpg.go deleted file mode 100644 index 35dfebf..0000000 --- a/cmd/tpg/tpg.go +++ /dev/null @@ -1,352 +0,0 @@ -package main - -import ( - "crypto/rand" - "flag" - "fmt" - mathRand "math/rand" - "os" - "strings" - "time" - - "github.com/yinxulai/ait/internal/display" -) - -// 版本信息,通过 ldflags 在构建时注入 -var ( - Version = "dev" - GitCommit = "unknown" - BuildTime = "unknown" -) - -var ( - // 使用全局随机数生成器 - rng = mathRand.New(mathRand.NewSource(time.Now().UnixNano())) - - sentences = []string{ - "The quick brown fox jumps over the lazy dog", - "Pack my box with five dozen liquor jugs", - "Generate a creative story about artificial intelligence", - "Explain quantum computing in simple terms", - "Write a poem about the changing seasons", - "Analyze the impact of social media on modern society", - "Describe your ideal futuristic cityscape", - "Compare and contrast machine learning algorithms", - "Provide a detailed recipe for chocolate chip cookies", - "Discuss the philosophical implications of consciousness", - "Create a business plan for a tech startup", - "Outline the key elements of effective leadership", - "Describe how blockchain technology works", - "Write a short script for a science fiction movie", - "Explain the theory of relativity using analogies", - "Discuss the pros and cons of renewable energy", - "Generate ideas for reducing carbon footprint", - "Analyze the theme of love in classical literature", - "Create a marketing strategy for a new product", - "Explain the water cycle to a 5-year-old", - "Describe the process of cellular respiration", - "Write a persuasive essay about education reform", - "Discuss the cultural significance of ancient mythology", - "Propose a solution to traffic congestion in cities", - "Explain how vaccines work to fight diseases", - // 中文句子 - "请解释人工智能的基本原理和应用场景", - "描述一下你理想中的智能城市是什么样子", - "分析社交媒体对现代社会的积极和消极影响", - "写一首关于四季变化的现代诗", - "解释区块链技术的工作原理和优势", - "设计一个可持续发展的商业模式", - "讨论教育改革的必要性和实施策略", - "分析古典文学中爱情主题的表现手法", - "提出解决城市交通拥堵的创新方案", - "解释疫苗如何帮助人体抵抗疾病", - // 日文句子 - "人工知能の基本的な仕組みと応用について説明してください", - "理想的な未来都市の姿を描写してください", - "ソーシャルメディアが現代社会に与える影響を分析してください", - "四季の移り変わりについて詩を書いてください", - "ブロックチェーン技術の仕組みと利点を説明してください", - "持続可能なビジネスモデルを設計してください", - "教育改革の必要性と実施戦略について論じてください", - "古典文学における愛のテーマの表現手法を分析してください", - "都市の交通渋滞を解決する革新的な方案を提案してください", - "ワクチンが病気と闘うメカニズムを説明してください", - // 韩文句子 - "인공지능의 기본 원리와 응용 분야에 대해 설명해주세요", - "당신이 생각하는 이상적인 미래 도시의 모습을 묘사해주세요", - "소셜미디어가 현대사회에 미치는 긍정적, 부정적 영향을 분석해주세요", - "계절의 변화에 대한 시를 써주세요", - "블록체인 기술의 작동 원리와 장점을 설명해주세요", - "지속가능한 비즈니스 모델을 설계해주세요", - "교육 개혁의 필요성과 실행 전략에 대해 논의해주세요", - "고전문학에서 사랑 주제의 표현 기법을 분석해주세요", - "도시 교통 체증을 해결할 혁신적인 방안을 제안해주세요", - "백신이 질병과 싸우는 메커니즘을 설명해주세요", - // 法文句子 - "Expliquez les principes fondamentaux de l'intelligence artificielle", - "Décrivez votre vision de la ville intelligente idéale", - "Analysez l'impact des réseaux sociaux sur la société moderne", - "Écrivez un poème sur les changements de saisons", - "Expliquez le fonctionnement de la technologie blockchain", - "Concevez un modèle commercial durable", - "Discutez de la nécessité de réformer l'éducation", - "Analysez le thème de l'amour dans la littérature classique", - "Proposez des solutions innovantes aux embouteillages urbains", - "Expliquez comment les vaccins combattent les maladies", - // 德文句子 - "Erklären Sie die Grundprinzipien der künstlichen Intelligenz", - "Beschreiben Sie Ihre Vision einer idealen Smart City", - "Analysieren Sie die Auswirkungen sozialer Medien auf die Gesellschaft", - "Schreiben Sie ein Gedicht über den Wechsel der Jahreszeiten", - "Erklären Sie die Funktionsweise der Blockchain-Technologie", - "Entwickeln Sie ein nachhaltiges Geschäftsmodell", - "Diskutieren Sie die Notwendigkeit von Bildungsreformen", - "Analysieren Sie das Liebesthema in der klassischen Literatur", - "Schlagen Sie innovative Lösungen für städtische Verkehrsstaus vor", - "Erklären Sie, wie Impfstoffe Krankheiten bekämpfen", - // 西班牙文句子 - "Explique los principios básicos de la inteligencia artificial", - "Describa su visión de la ciudad inteligente ideal", - "Analice el impacto de las redes sociales en la sociedad moderna", - "Escriba un poema sobre los cambios estacionales", - "Explique cómo funciona la tecnología blockchain", - "Diseñe un modelo de negocio sostenible", - "Discuta la necesidad de reformas educativas", - "Analice el tema del amor en la literatura clásica", - "Proponga soluciones innovadoras para la congestión del tráfico urbano", - "Explique cómo las vacunas combaten las enfermedades", - // 俄文句子 - "Объясните основные принципы искусственного интеллекта", - "Опишите вашу концепцию идеального умного города", - "Проанализируйте влияние социальных сетей на современное общество", - "Напишите стихотворение о смене времен года", - "Объясните принципы работы технологии блокчейн", - "Разработайте устойчивую бизнес-модель", - "Обсудите необходимость реформы образования", - "Проанализируйте тему любви в классической литературе", - "Предложите инновационные решения городских пробок", - "Объясните, как вакцины борются с болезнями", - // 阿拉伯文句子 - "اشرح المبادئ الأساسية للذكاء الاصطناعي", - "صف رؤيتك للمدينة الذكية المثالية", - "حلل تأثير وسائل التواصل الاجتماعي على المجتمع الحديث", - "اكتب قصيدة عن تغير الفصول", - "اشرح كيفية عمل تقنية البلوك تشين", - "صمم نموذج أعمال مستدام", - "ناقش ضرورة إصلاح التعليم", - "حلل موضوع الحب في الأدب الكلاسيكي", - "اقترح حلول مبتكرة لازدحام المرور في المدن", - "اشرح كيف تحارب اللقاحات الأمراض", - } -) - -// Template 模板结构 -type Template struct { - Content string - Variables map[string]string -} - -// applyTemplate 应用模板,替换占位符 -func (t *Template) applyTemplate(content string, index int, timestamp time.Time) string { - result := t.Content - - // 替换基本占位符 - result = strings.ReplaceAll(result, "{{content}}", content) - result = strings.ReplaceAll(result, "{{index}}", fmt.Sprintf("%d", index)) - result = strings.ReplaceAll(result, "{{timestamp}}", timestamp.Format(time.RFC3339)) - - // 替换自定义变量 - for key, value := range t.Variables { - placeholder := fmt.Sprintf("{{%s}}", key) - result = strings.ReplaceAll(result, placeholder, value) - } - - return result -} - -// generateTaskID 生成任务ID(参考 ait.go) -func generateTaskID() string { - bytes := make([]byte, 16) - rand.Read(bytes) - - // 设置版本 (4) 和变体位 - bytes[6] = (bytes[6] & 0x0f) | 0x40 // Version 4 - bytes[8] = (bytes[8] & 0x3f) | 0x80 // Variant 10 - - return fmt.Sprintf("%08x-%04x-%04x-%04x-%012x", - bytes[0:4], bytes[4:6], bytes[6:8], bytes[8:10], bytes[10:16]) -} - -// validateParams 验证参数 -func validateParams(count, length int, outputDir string) error { - if count <= 0 { - return fmt.Errorf("count 必须大于 0") - } - - if length <= 0 { - return fmt.Errorf("length 必须大于 0") - } - - if outputDir == "" { - return fmt.Errorf("输出目录不能为空") - } - - return nil -} - -// printHelp 打印帮助信息 -func printHelp() { - fmt.Println("TPG - 测试 Prompt 生成器") - fmt.Println("") - fmt.Println("用法:") - fmt.Println(" tpg [选项]") - fmt.Println("") - fmt.Println("选项:") - fmt.Println(" -count 生成的 prompt 数量 (默认: 10)") - fmt.Println(" -length 每个 prompt 的近似长度 (默认: 50)") - fmt.Println(" -output 输出目录 (默认: output)") - fmt.Println(" -template 模板字符串,支持占位符: {{content}}, {{index}}, {{timestamp}}") - fmt.Println(" -help 显示此帮助信息") - fmt.Println("") - fmt.Println("模板占位符说明:") - fmt.Println(" {{content}} - 生成的prompt内容") - fmt.Println(" {{index}} - prompt序号 (从1开始)") - fmt.Println(" {{timestamp}} - 当前时间戳 (RFC3339格式)") - fmt.Println("") - fmt.Println("示例:") - fmt.Println(" tpg -count=20 -length=100") - fmt.Println(" tpg -output=./test-prompts") - fmt.Println(" tpg -template=\"请回答以下问题: {{content}}\"") -} - -// generateRandomText 生成指定长度的随机文本 -func generateRandomText(desiredLength int) string { - var result strings.Builder - var selectedSentences []string - - // 选择句子直到累计长度达到或接近期望长度 - totalLength := 0 - for totalLength < desiredLength { - sentence := sentences[rng.Intn(len(sentences))] - // 如果加上这个句子会超出期望长度太多,且已经有句子了,就停止 - if totalLength > 0 && totalLength + len(sentence) + 1 > desiredLength * 2 { - break - } - selectedSentences = append(selectedSentences, sentence) - totalLength += len(sentence) + 1 // +1 for space - } - - // 拼接句子 - for i, sentence := range selectedSentences { - result.WriteString(sentence) - if i < len(selectedSentences) - 1 { - result.WriteString(" ") - } - } - - return result.String() -} - -// writePromptFile 写入 prompt 文件 -func writePromptFile(prompt, filename string, template *Template, index int) error { - // 如果有模板,应用模板 - finalContent := prompt - if template != nil { - finalContent = template.applyTemplate(prompt, index, time.Now()) - } - - if err := os.WriteFile(filename, []byte(finalContent), 0644); err != nil { - return fmt.Errorf("写入文件 %s 失败: %v", filename, err) - } - return nil -} - -func main() { - // 定义命令行参数 - count := flag.Int("count", 10, "生成的 prompt 数量") - length := flag.Int("length", 50, "每个 prompt 的近似长度") - outputDir := flag.String("output", "prompts", "输出目录") - templateStr := flag.String("template", "", "模板字符串,支持占位符") - help := flag.Bool("help", false, "显示帮助信息") - - flag.Parse() - - // 显示帮助信息 - if *help { - printHelp() - return - } - - // 验证参数 - if err := validateParams(*count, *length, *outputDir); err != nil { - fmt.Printf("%s错误: %s%s\n", display.ColorRed, err.Error(), display.ColorReset) - fmt.Println("使用 -help 查看帮助信息") - os.Exit(1) - } - - // 处理模板 - var template *Template - - if *templateStr != "" { - template = &Template{ - Content: *templateStr, - Variables: make(map[string]string), - } - } - - // 创建输出目录 - if err := os.MkdirAll(*outputDir, os.ModePerm); err != nil { - fmt.Printf("%s错误: 创建输出目录失败: %v%s\n", display.ColorRed, err, display.ColorReset) - os.Exit(1) - } - - // 生成任务ID(可用于日志等) - taskID := generateTaskID() - _ = taskID // 当前版本暂不使用,但保留接口 - - // 显示欢迎信息和配置 - displayer := display.New() - displayer.ShowWelcome(Version) - - fmt.Printf("%s=== TPG 配置信息 ===%s\n", display.ColorBlue, display.ColorReset) - fmt.Printf("数量: %d\n", *count) - fmt.Printf("长度: %d 字符\n", *length) - fmt.Printf("输出目录: %s\n", *outputDir) - if template != nil { - fmt.Printf("使用模板: 是\n") - } else { - fmt.Printf("使用模板: 否\n") - } - fmt.Println() - - // 生成 prompts - successCount := 0 - var errors []string - - for i := 0; i < *count; i++ { - prompt := generateRandomText(*length) - filename := fmt.Sprintf("%s/prompt_%d.txt", *outputDir, i+1) - - if err := writePromptFile(prompt, filename, template, i+1); err != nil { - errors = append(errors, err.Error()) - continue - } - - successCount++ - fmt.Printf("%s✓%s 生成文件: %s\n", display.ColorGreen, display.ColorReset, filename) - } - - // 显示结果统计 - fmt.Println() - fmt.Printf("%s=== 生成结果 ===%s\n", display.ColorBlue, display.ColorReset) - fmt.Printf("%s成功生成: %d/%d%s\n", display.ColorGreen, successCount, *count, display.ColorReset) - - if len(errors) > 0 { - fmt.Printf("%s失败: %d%s\n", display.ColorRed, len(errors), display.ColorReset) - fmt.Printf("%s错误详情:%s\n", display.ColorRed, display.ColorReset) - for _, err := range errors { - fmt.Printf(" - %s\n", err) - } - os.Exit(1) - } -} diff --git a/cmd/tpg/tpg_test.go b/cmd/tpg/tpg_test.go deleted file mode 100644 index 17d11c5..0000000 --- a/cmd/tpg/tpg_test.go +++ /dev/null @@ -1,421 +0,0 @@ -package main - -import ( - "os" - "path/filepath" - "strings" - "testing" - "time" -) - -// TestTemplate_applyTemplate 测试模板应用功能 -func TestTemplate_applyTemplate(t *testing.T) { - tests := []struct { - name string - template *Template - content string - index int - timestamp time.Time - expected string - }{ - { - name: "基本占位符替换", - template: &Template{ - Content: "内容: {{content}}, 序号: {{index}}, 时间: {{timestamp}}", - Variables: make(map[string]string), - }, - content: "测试内容", - index: 1, - timestamp: time.Date(2023, 1, 1, 12, 0, 0, 0, time.UTC), - expected: "内容: 测试内容, 序号: 1, 时间: 2023-01-01T12:00:00Z", - }, - { - name: "自定义变量替换", - template: &Template{ - Content: "{{content}} - {{custom1}} - {{custom2}}", - Variables: map[string]string{ - "custom1": "变量1", - "custom2": "变量2", - }, - }, - content: "主要内容", - index: 2, - timestamp: time.Date(2023, 1, 1, 12, 0, 0, 0, time.UTC), - expected: "主要内容 - 变量1 - 变量2", - }, - { - name: "无占位符模板", - template: &Template{ - Content: "固定内容", - Variables: make(map[string]string), - }, - content: "测试", - index: 1, - timestamp: time.Date(2023, 1, 1, 12, 0, 0, 0, time.UTC), - expected: "固定内容", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := tt.template.applyTemplate(tt.content, tt.index, tt.timestamp) - if result != tt.expected { - t.Errorf("applyTemplate() = %v, want %v", result, tt.expected) - } - }) - } -} - -// TestGenerateTaskID 测试任务ID生成 -func TestGenerateTaskID(t *testing.T) { - // 生成多个ID确保它们不重复 - ids := make(map[string]bool) - for i := 0; i < 100; i++ { - id := generateTaskID() - - // 检查格式 (8-4-4-4-12) - parts := strings.Split(id, "-") - if len(parts) != 5 { - t.Errorf("generateTaskID() format error, got %d parts, want 5", len(parts)) - } - - // 检查每部分长度 - expectedLengths := []int{8, 4, 4, 4, 12} - for j, part := range parts { - if len(part) != expectedLengths[j] { - t.Errorf("generateTaskID() part %d length error, got %d, want %d", j, len(part), expectedLengths[j]) - } - } - - // 检查重复 - if ids[id] { - t.Errorf("generateTaskID() generated duplicate ID: %s", id) - } - ids[id] = true - } -} - -// TestValidateParams 测试参数验证 -func TestValidateParams(t *testing.T) { - tests := []struct { - name string - count int - length int - outputDir string - wantErr bool - errMsg string - }{ - { - name: "有效参数", - count: 10, - length: 50, - outputDir: "output", - wantErr: false, - }, - { - name: "count为0", - count: 0, - length: 50, - outputDir: "output", - wantErr: true, - errMsg: "count 必须大于 0", - }, - { - name: "count为负数", - count: -1, - length: 50, - outputDir: "output", - wantErr: true, - errMsg: "count 必须大于 0", - }, - { - name: "length为0", - count: 10, - length: 0, - outputDir: "output", - wantErr: true, - errMsg: "length 必须大于 0", - }, - { - name: "length为负数", - count: 10, - length: -1, - outputDir: "output", - wantErr: true, - errMsg: "length 必须大于 0", - }, - { - name: "outputDir为空", - count: 10, - length: 50, - outputDir: "", - wantErr: true, - errMsg: "输出目录不能为空", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := validateParams(tt.count, tt.length, tt.outputDir) - if tt.wantErr { - if err == nil { - t.Errorf("validateParams() error = nil, wantErr %v", tt.wantErr) - return - } - if err.Error() != tt.errMsg { - t.Errorf("validateParams() error = %v, want %v", err.Error(), tt.errMsg) - } - } else { - if err != nil { - t.Errorf("validateParams() error = %v, wantErr %v", err, tt.wantErr) - } - } - }) - } -} - -// TestGenerateRandomText 测试随机文本生成 -func TestGenerateRandomText(t *testing.T) { - tests := []struct { - name string - desiredLength int - minLength int // 最小期望长度 - maxLength int // 最大期望长度 - }{ - { - name: "短文本", - desiredLength: 20, - minLength: 5, - maxLength: 80, // 增加上限以适应句子组合的特性 - }, - { - name: "中等文本", - desiredLength: 100, - minLength: 30, - maxLength: 300, // 增加上限 - }, - { - name: "长文本", - desiredLength: 500, - minLength: 200, - maxLength: 1500, // 增加上限 - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := generateRandomText(tt.desiredLength) - - // 检查结果不为空 - if result == "" { - t.Error("generateRandomText() returned empty string") - } - - // 检查长度在合理范围内 - if len(result) < tt.minLength || len(result) > tt.maxLength { - t.Errorf("generateRandomText() length = %d, want between %d and %d", - len(result), tt.minLength, tt.maxLength) - } - - // 检查结果包含预期的句子 - containsKnownSentence := false - for _, sentence := range sentences { - if strings.Contains(result, sentence) { - containsKnownSentence = true - break - } - } - if !containsKnownSentence { - t.Error("generateRandomText() result doesn't contain any known sentence") - } - }) - } -} - -// TestWritePromptFile 测试文件写入功能 -func TestWritePromptFile(t *testing.T) { - // 创建临时目录 - tempDir, err := os.MkdirTemp("", "tpg_test") - if err != nil { - t.Fatalf("Failed to create temp dir: %v", err) - } - defer os.RemoveAll(tempDir) - - tests := []struct { - name string - prompt string - filename string - template *Template - index int - wantErr bool - }{ - { - name: "基本文件写入", - prompt: "测试内容", - filename: filepath.Join(tempDir, "test1.txt"), - template: nil, - index: 1, - wantErr: false, - }, - { - name: "使用模板写入", - prompt: "测试内容", - filename: filepath.Join(tempDir, "test2.txt"), - template: &Template{ - Content: "序号{{index}}: {{content}}", - Variables: make(map[string]string), - }, - index: 2, - wantErr: false, - }, - { - name: "无效路径", - prompt: "测试内容", - filename: "/invalid/path/test.txt", - template: nil, - index: 1, - wantErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := writePromptFile(tt.prompt, tt.filename, tt.template, tt.index) - - if tt.wantErr { - if err == nil { - t.Error("writePromptFile() error = nil, wantErr true") - } - return - } - - if err != nil { - t.Errorf("writePromptFile() error = %v, wantErr false", err) - return - } - - // 验证文件是否创建 - if _, err := os.Stat(tt.filename); os.IsNotExist(err) { - t.Error("writePromptFile() file was not created") - return - } - - // 读取文件内容验证 - content, err := os.ReadFile(tt.filename) - if err != nil { - t.Errorf("Failed to read written file: %v", err) - return - } - - expectedContent := tt.prompt - if tt.template != nil { - expectedContent = tt.template.applyTemplate(tt.prompt, tt.index, time.Now()) - } - - contentStr := string(content) - if tt.template == nil && contentStr != expectedContent { - t.Errorf("File content = %v, want %v", contentStr, expectedContent) - } else if tt.template != nil { - // 对于使用模板的情况,至少检查基本内容是否包含 - if !strings.Contains(contentStr, tt.prompt) { - t.Errorf("File content should contain prompt: %v", tt.prompt) - } - } - }) - } -} - -// TestSentencesAvailability 测试句子库是否可用 -func TestSentencesAvailability(t *testing.T) { - if len(sentences) == 0 { - t.Error("sentences slice is empty") - } - - // 检查是否包含不同语言的句子 - languages := map[string]bool{ - "english": false, - "chinese": false, - "japanese": false, - "korean": false, - "french": false, - "german": false, - "spanish": false, - "russian": false, - "arabic": false, - } - - for _, sentence := range sentences { - if strings.Contains(sentence, "The quick brown fox") { - languages["english"] = true - } - if strings.Contains(sentence, "人工智能") { - languages["chinese"] = true - } - if strings.Contains(sentence, "人工知能") { - languages["japanese"] = true - } - if strings.Contains(sentence, "인공지능") { - languages["korean"] = true - } - if strings.Contains(sentence, "intelligence artificielle") { - languages["french"] = true - } - if strings.Contains(sentence, "künstlichen Intelligenz") { - languages["german"] = true - } - if strings.Contains(sentence, "inteligencia artificial") { - languages["spanish"] = true - } - if strings.Contains(sentence, "искусственного интеллекта") { - languages["russian"] = true - } - if strings.Contains(sentence, "الذكاء الاصطناعي") { - languages["arabic"] = true - } - } - - // 检查是否至少包含几种语言 - foundLanguages := 0 - for _, found := range languages { - if found { - foundLanguages++ - } - } - - if foundLanguages < 5 { - t.Errorf("Expected at least 5 different languages in sentences, found %d", foundLanguages) - } -} - -// BenchmarkGenerateRandomText 性能测试 -func BenchmarkGenerateRandomText(b *testing.B) { - for i := 0; i < b.N; i++ { - generateRandomText(100) - } -} - -// BenchmarkGenerateTaskID 任务ID生成性能测试 -func BenchmarkGenerateTaskID(b *testing.B) { - for i := 0; i < b.N; i++ { - generateTaskID() - } -} - -// BenchmarkTemplateApply 模板应用性能测试 -func BenchmarkTemplateApply(b *testing.B) { - template := &Template{ - Content: "内容: {{content}}, 序号: {{index}}, 时间: {{timestamp}}", - Variables: map[string]string{ - "custom1": "value1", - "custom2": "value2", - }, - } - - content := "测试内容" - timestamp := time.Now() - - b.ResetTimer() - for i := 0; i < b.N; i++ { - template.applyTemplate(content, i, timestamp) - } -} diff --git a/internal/display/display.go b/internal/display/display.go deleted file mode 100644 index 859b2a0..0000000 --- a/internal/display/display.go +++ /dev/null @@ -1,333 +0,0 @@ -package display - -import ( - "fmt" - "os" - "strconv" - "sync" - "time" - - "github.com/olekukonko/tablewriter" - "github.com/schollz/progressbar/v3" - "github.com/yinxulai/ait/internal/types" -) - -// Colors 定义终端颜色 - 导出供外部使用 -const ( - ColorReset = "\033[0m" - ColorRed = "\033[31m" - ColorGreen = "\033[32m" - ColorYellow = "\033[33m" - ColorBlue = "\033[34m" - ColorPurple = "\033[35m" - ColorCyan = "\033[36m" - ColorWhite = "\033[37m" - ColorBold = "\033[1m" -) - -type Input struct { - TaskId string // 任务 ID,随机生成的唯一标识符 - Protocol string - BaseUrl string - ApiKey string - Models []string // 多个模型列表 - Concurrency int - Count int - Stream bool - Thinking bool // 是否开启思考模式 - PromptText string // 用于显示的prompt文本 - PromptShouldTruncate bool // 是否需要截断显示 - IsFile bool // 是否为文件类型输入 - Report bool // 是否生成报告文件 - Timeout int // 请求超时时间(秒) -} - -// Displayer 测试显示器 -type Displayer struct { - progressBar *progressbar.ProgressBar - mu sync.Mutex -} - -// New 创建新的测试显示器 -func New() *Displayer { - return &Displayer{} -} - -func (td *Displayer) ShowWelcome(version string) { - fmt.Printf("\n") - // AIT ASCII 字符画和说明信息并排显示,使用统一的青色+粗体 - fmt.Printf("%s%s █████╗ ██╗ ████████╗%s %s🚀 %sAI 模型性能测试工具 %s(%s)%s\n", ColorBold, ColorCyan, ColorReset, ColorBold, ColorCyan, ColorGreen, version, ColorReset) - fmt.Printf("%s%s ██╔══██╗ ██║ ╚══██╔══╝%s %s一个强大的 CLI 工具,用于测试 AI 模型的性能指标%s\n", ColorBold, ColorCyan, ColorReset, ColorWhite, ColorReset) - fmt.Printf("%s%s ███████║ ██║ ██║%s %s🌐 项目地址: https://github.com/yinxulai/ait%s\n", ColorBold, ColorCyan, ColorReset, ColorBlue, ColorReset) - fmt.Printf("%s%s ██╔══██║ ██║ ██║%s \n", ColorBold, ColorCyan, ColorReset) - fmt.Printf("%s%s ██║ ██║ ██║ ██║%s %s✨ 功能特性:%s\n", ColorBold, ColorCyan, ColorReset, ColorBold, ColorReset) - fmt.Printf("%s%s ╚═╝ ╚═╝ ╚═╝ ╚═╝%s 🎯 多模型批量测试 ⚡ 并发压力测试 📊 实时进度显示\n", ColorBold, ColorCyan, ColorReset) - fmt.Printf(" 🌐 网络性能分析 📈 详细统计报告 🎨 美观界面输出\n") - fmt.Printf("\n") -} - -func (td *Displayer) ShowInput(data *Input) { - // 创建配置信息表格 - table := tablewriter.NewTable( - os.Stdout, - tablewriter.WithEastAsian(false), - ) - table.Header("配置项", "值", "说明") - - // 任务信息 - table.Append("🆔 任务 ID", data.TaskId, "本次测试的唯一标识符") - - // 基础配置 - table.Append("🔗 协议", data.Protocol, "API 协议类型") - table.Append("🌐 服务地址", data.BaseUrl, "API 基础 URL") - table.Append("🔑 API 密钥", maskApiKey(data.ApiKey), "API 访问密钥(已隐藏)") - - // 模型配置 - modelsStr := "" - if len(data.Models) > 0 { - for i, model := range data.Models { - if i > 0 { - modelsStr += ", " - } - modelsStr += model - } - } - table.Append("🤖 测试模型", modelsStr, "待测试的模型列表") - - // 测试参数 - table.Append("📊 请求总数", strconv.Itoa(data.Count), "每个模型的请求数量") - table.Append("⚡ 并发数", strconv.Itoa(data.Concurrency), "同时发送的请求数") - table.Append("🕐 超时时间", strconv.Itoa(data.Timeout)+"秒", "每个请求的超时时间") - table.Append("🌊 流式模式", strconv.FormatBool(data.Stream), "是否启用流式响应") - table.Append("🧠 思考模式", strconv.FormatBool(data.Thinking), "是否启用思考模式(仅OpenAI协议支持)") - - // 对于不需要截断的内容(文件或已包含长度信息的生成内容),直接显示 - var promptDisplay string - if !data.PromptShouldTruncate { - promptDisplay = data.PromptText - } else { - promptDisplay = truncatePrompt(data.PromptText) - } - - table.Append("📝 测试提示词", promptDisplay, "用于测试的提示内容") - - table.Append("📄 生成报告", strconv.FormatBool(data.Report), "是否生成测试报告文件") - - table.Render() -} - -// InitProgress 初始化进度条 -func (td *Displayer) InitProgress(total int, description string) { - td.mu.Lock() - defer td.mu.Unlock() - - td.progressBar = progressbar.NewOptions(total, - progressbar.OptionSetDescription(description), - progressbar.OptionSetTheme(progressbar.Theme{ - Saucer: "█", - SaucerPadding: "░", - BarStart: "[", - BarEnd: "]", - }), - progressbar.OptionShowCount(), - progressbar.OptionShowIts(), - progressbar.OptionSetWidth(50), - progressbar.OptionThrottle(100), // 限制更新频率 - progressbar.OptionShowElapsedTimeOnFinish(), - progressbar.OptionSetRenderBlankState(true), - ) -} - -func (td *Displayer) UpdateProgress(percent float64) { - td.mu.Lock() - defer td.mu.Unlock() - - if td.progressBar != nil { - // 计算当前进度值(基于进度条的最大值) - current := int(percent * float64(td.progressBar.GetMax()) / 100.0) - td.progressBar.Set(current) - } -} - -// FinishProgress 完成进度条 -func (td *Displayer) FinishProgress() { - td.mu.Lock() - defer td.mu.Unlock() - - if td.progressBar != nil { - td.progressBar.Finish() - fmt.Println() // 添加一个空行 - td.progressBar = nil - } -} - -func (td *Displayer) ShowErrorsReport(errors []*string) { - if len(errors) == 0 { - return - } - - // 统计错误信息和出现次数 - errorCounts := make(map[string]int) - totalErrors := 0 - - for _, errorPtr := range errors { - if errorPtr != nil { - errorMsg := *errorPtr - errorCounts[errorMsg]++ - totalErrors++ - } - } - - if totalErrors == 0 { - return - } - - fmt.Printf("%s%s❌ 错误信息报告%s\n", ColorBold, ColorRed, ColorReset) - fmt.Printf(" %s检测到 %d 个错误(%d 种不同类型)%s\n\n", ColorYellow, totalErrors, len(errorCounts), ColorReset) - - // 创建错误信息表格 - table := tablewriter.NewTable( - os.Stdout, - tablewriter.WithEastAsian(false), - ) - - table.Header("序号", "错误详情", "出现次数") - - // 添加错误信息到表格 - index := 1 - for errorMsg, count := range errorCounts { - // 如果错误信息太长,进行适当的截断和格式化 - displayMsg := errorMsg - if len(displayMsg) > 100 { - displayMsg = displayMsg[:97] + "..." - } - table.Append(fmt.Sprintf("%d", index), displayMsg, fmt.Sprintf("%d", count)) - index++ - } - - table.Render() - fmt.Println() -} - -// 将数据更新到终端上(刷新显示) -// 详细模式,展示所有 ReportData 的数据 -func (td *Displayer) ShowSignalReport(data *types.ReportData) { - // 单个综合表格 - table := tablewriter.NewTable( - os.Stdout, - tablewriter.WithEastAsian(false), - ) - - table.Header("指标", "最小值", "平均值", "标准差", "最大值", "单位", "采样方式说明") - - // 基础信息(这些只有单一值,只填最小值列) - table.Append("🔗 协议", data.Protocol, "", "", "", "-", "配置信息") - table.Append("🤖 模型", data.Model, "", "", "", "-", "配置信息") - table.Append("🌐 URL", data.BaseUrl, "", "", "", "-", "配置信息") - table.Append("🌊 流式", strconv.FormatBool(data.IsStream), "", "", "", "-", "配置信息") - table.Append("🧠 思考模式", strconv.FormatBool(data.IsThinking), "", "", "", "-", "配置信息") - table.Append("⚡ 并发数", strconv.Itoa(data.Concurrency), "", "", "", "个", "配置信息") - table.Append("📊 总请求数", strconv.Itoa(data.TotalRequests), "", "", "", "个", "完成的请求总数") - table.Append("✅ 成功率", fmt.Sprintf("%.2f", data.SuccessRate), "", "", "", "%", "成功请求占比") - - // 时间性能指标 - table.Append("🕐 总耗时", formatDuration(data.MinTotalTime), formatDuration(data.AvgTotalTime), fmt.Sprintf("±%s", formatDuration(data.StdDevTotalTime)), formatDuration(data.MaxTotalTime), "时间", "单个请求从发起到完全结束的时间") - - if data.TargetIP != "" { - table.Append("🎯 目标 IP", data.TargetIP, "", "", "", "-", "DNS 解析后的实际连接 IP") - } - // 内容性能指标 - if data.IsStream { - table.Append("⚡ TTFT", formatDuration(data.MinTTFT), formatDuration(data.AvgTTFT), fmt.Sprintf("±%s", formatDuration(data.StdDevTTFT)), formatDuration(data.MaxTTFT), "时间", "首个 token 响应时间 (含请求发送+网络+服务器处理)") - table.Append("⚡ TPOT", formatDuration(data.MinTPOT), formatDuration(data.AvgTPOT), fmt.Sprintf("±%s", formatDuration(data.StdDevTPOT)), formatDuration(data.MaxTPOT), "时间", "每个输出 token 的平均耗时 (除首token外)") - } - - // 网络性能指标 - table.Append("🔍 DNS 时间", formatDuration(data.MinDNSTime), formatDuration(data.AvgDNSTime), "", formatDuration(data.MaxDNSTime), "时间", "域名解析耗时 (httptrace)") - table.Append("🔒 TLS 时间", formatDuration(data.MinTLSHandshakeTime), formatDuration(data.AvgTLSHandshakeTime), "", formatDuration(data.MaxTLSHandshakeTime), "时间", "TLS 握手耗时 (httptrace)") - table.Append("🔌 TCP 连接时间", formatDuration(data.MinConnectTime), formatDuration(data.AvgConnectTime), "", formatDuration(data.MaxConnectTime), "时间", "TCP 连接建立耗时 (httptrace)") - - table.Append("🚀 输出 TPS", fmt.Sprintf("%.2f", data.MinTPS), fmt.Sprintf("%.2f", data.AvgTPS), fmt.Sprintf("±%.2f", data.StdDevTPS), fmt.Sprintf("%.2f", data.MaxTPS), "个/秒", "输出 tokens / 总耗时") - table.Append("🌐 吞吐 TPS", fmt.Sprintf("%.2f", data.MinTotalThroughputTPS), fmt.Sprintf("%.2f", data.AvgTotalThroughputTPS), fmt.Sprintf("±%.2f", data.StdDevTotalThroughputTPS), fmt.Sprintf("%.2f", data.MaxTotalThroughputTPS), "个/秒", "(输入+输出) tokens / 总耗时") - - // Token 数指标 - table.Append("📥 输入 Token 数", strconv.Itoa(data.MinInputTokenCount), strconv.Itoa(data.AvgInputTokenCount), fmt.Sprintf("±%.2f", data.StdDevInputTokenCount), strconv.Itoa(data.MaxInputTokenCount), "个", "API 请求的 prompt tokens") - table.Append("🎲 生成 Token 数", strconv.Itoa(data.MinOutputTokenCount), strconv.Itoa(data.AvgOutputTokenCount), fmt.Sprintf("±%.2f", data.StdDevOutputTokenCount), strconv.Itoa(data.MaxOutputTokenCount), "个", "API 返回的 completion tokens") - table.Append("🧠 思考 Token 数", strconv.Itoa(data.MinThinkingTokenCount), strconv.Itoa(data.AvgThinkingTokenCount), fmt.Sprintf("±%.2f", data.StdDevThinkingTokenCount), strconv.Itoa(data.MaxThinkingTokenCount), "个", "模型返回的 reasoning/thinking tokens") - - table.Render() - fmt.Println() -} - -// 将数据更新到终端上(刷新显示) -// 概览模式,每行一个,展示主要数据(平均值) -func (td *Displayer) ShowMultiReport(data []*types.ReportData) { - // 单个汇总表格,包含所有不同类型指标的平均值 - table := tablewriter.NewTable( - os.Stdout, - tablewriter.WithEastAsian(false), - ) - - table.Header("🤖 模型", "📊 请求数", "⚡ 并发", "✅ 成功率", - "🕐 平均总耗时", "⚡ 平均 TTFT", - "🚀 平均输出 TPS", "🌐 平均吞吐 TPS", - "🎲 平均输出 Token 数") - - for _, report := range data { - // TTFT 处理(流式模式才显示) - ttftStr := "-" - if report.IsStream { - ttftStr = formatDuration(report.AvgTTFT) - } - - table.Append( - report.Model, - strconv.Itoa(report.TotalRequests), - strconv.Itoa(report.Concurrency), - fmt.Sprintf("%.2f%%", report.SuccessRate), - formatDuration(report.AvgTotalTime), - ttftStr, - fmt.Sprintf("%.2f", report.AvgTPS), - fmt.Sprintf("%.2f", report.AvgTotalThroughputTPS), - strconv.Itoa(report.AvgOutputTokenCount), - ) - } - - table.Render() - fmt.Println() -} - -// maskApiKey 隐藏 API 密钥的敏感部分 -func maskApiKey(apiKey string) string { - if len(apiKey) <= 8 { - return "***" - } - return apiKey[:4] + "***" + apiKey[len(apiKey)-4:] -} - -// truncatePrompt 截断过长的提示词并显示长度信息 -func truncatePrompt(prompt string) string { - runes := []rune(prompt) - charCount := len(runes) - if charCount <= 50 { - return fmt.Sprintf("%s (长度: %d)", prompt, charCount) - } - return fmt.Sprintf("%s... (长度: %d)", string(runes[:47]), charCount) -} - -// formatDuration 格式化时间显示,保留2位小数 -func formatDuration(d time.Duration) string { - // 根据时间大小选择合适的单位 - if d >= time.Second { - // >= 1s: 显示为秒,保留2位小数 - return fmt.Sprintf("%.2fs", d.Seconds()) - } else if d >= time.Millisecond { - // >= 1ms: 显示为毫秒,保留2位小数 - return fmt.Sprintf("%.2fms", float64(d.Microseconds())/1000.0) - } else if d >= time.Microsecond { - // >= 1µs: 显示为微秒,保留2位小数 - return fmt.Sprintf("%.2fµs", float64(d.Nanoseconds())/1000.0) - } - // < 1µs: 显示为纳秒 - return fmt.Sprintf("%dns", d.Nanoseconds()) -} diff --git a/internal/display/display_test.go b/internal/display/display_test.go deleted file mode 100644 index 3e18567..0000000 --- a/internal/display/display_test.go +++ /dev/null @@ -1,118 +0,0 @@ -package display - -import ( - "testing" -) - -func TestTruncatePrompt(t *testing.T) { - // 测试用例 1: 短提示词 - t.Run("Short prompt", func(t *testing.T) { - result := truncatePrompt("你好") - expected := "你好 (长度: 2)" - if result != expected { - t.Errorf("Expected %q, got %q", expected, result) - } - }) - - // 测试用例 2: 长提示词 - t.Run("Long prompt", func(t *testing.T) { - longPrompt := "这是一个非常长的测试提示词,用于测试truncatePrompt函数的截断功能和长度显示" - result := truncatePrompt(longPrompt) - expected := "这是一个非常长的测试提示词,用于测试truncatePrompt函数的截断功能和长度显示 (长度: 44)" - if result != expected { - t.Errorf("Expected %q, got %q", expected, result) - } - }) - - // 测试用例 2.5: 超长提示词(需要截断) - t.Run("Very long prompt that needs truncation", func(t *testing.T) { - veryLongPrompt := "这是一个非常长的测试提示词,用于测试truncatePrompt函数的截断功能和长度显示,这个字符串超过五十个字符所以会被截断处理" - result := truncatePrompt(veryLongPrompt) - expected := "这是一个非常长的测试提示词,用于测试truncatePrompt函数的截断功能和长度显示,这个... (长度: 65)" - if result != expected { - t.Errorf("Expected %q, got %q", expected, result) - } - }) - - // 测试用例 3: 恰好50字符的提示词 - t.Run("Exactly 50 characters", func(t *testing.T) { - promptWith50Chars := "这个测试字符串包含恰好五十个字符用于测试边界条件是否能够正确处理各种情况的测试案例增加字符再加五十个" - result := truncatePrompt(promptWith50Chars) - expected := "这个测试字符串包含恰好五十个字符用于测试边界条件是否能够正确处理各种情况的测试案例增加字符再加五十个 (长度: 50)" - if result != expected { - t.Errorf("Expected %q, got %q", expected, result) - } - }) - - // 测试用例 4: 空字符串 - t.Run("Empty string", func(t *testing.T) { - result := truncatePrompt("") - expected := " (长度: 0)" - if result != expected { - t.Errorf("Expected %q, got %q", expected, result) - } - }) - - // 测试用例 5: 混合中英文 - t.Run("Mixed Chinese and English", func(t *testing.T) { - mixedPrompt := "Hello 世界 123" - result := truncatePrompt(mixedPrompt) - expected := "Hello 世界 123 (长度: 12)" - if result != expected { - t.Errorf("Expected %q, got %q", expected, result) - } - }) -} - -func TestShowErrorsReport(t *testing.T) { - displayer := New() - - // 测试用例 1: 空错误列表 - t.Run("Empty errors", func(t *testing.T) { - var errors []*string - displayer.ShowErrorsReport(errors) - // 应该不输出任何内容,函数应该直接返回 - }) - - // 测试用例 2: 单个错误 - t.Run("Single error", func(t *testing.T) { - error1 := "API 连接超时" - errors := []*string{&error1} - displayer.ShowErrorsReport(errors) - }) - - // 测试用例 3: 多个不同错误 - t.Run("Multiple different errors", func(t *testing.T) { - error1 := "API 连接超时" - error2 := "认证失败: 无效的 API 密钥" - error3 := "模型不存在或无法访问" - errors := []*string{&error1, &error2, &error3} - displayer.ShowErrorsReport(errors) - }) - - // 测试用例 4: 重复错误统计 - t.Run("Duplicate errors counting", func(t *testing.T) { - error1 := "API 连接超时" - error2 := "认证失败: 无效的 API 密钥" - error3 := "API 连接超时" // 重复错误 - error4 := "认证失败: 无效的 API 密钥" // 重复错误 - error5 := "API 连接超时" // 又一个重复错误 - errors := []*string{&error1, &error2, &error3, &error4, &error5} - displayer.ShowErrorsReport(errors) - }) - - // 测试用例 5: 包含长错误消息 - t.Run("Long error messages", func(t *testing.T) { - longError := "这是一个非常长的错误消息,用于测试当错误消息超过100个字符时的截断功能。这个错误消息包含了大量的细节信息,比如堆栈跟踪、请求参数、响应状态码等等。" - shortError := "短错误消息" - errors := []*string{&longError, &shortError, &longError} // 包含重复的长错误 - displayer.ShowErrorsReport(errors) - }) - - // 测试用例 6: 包含nil指针 - t.Run("With nil pointers", func(t *testing.T) { - error1 := "正常错误消息" - errors := []*string{&error1, nil, &error1, nil} - displayer.ShowErrorsReport(errors) - }) -} diff --git a/internal/tui/model.go b/internal/tui/model.go index 8129f7d..0b672ba 100644 --- a/internal/tui/model.go +++ b/internal/tui/model.go @@ -172,22 +172,32 @@ func (m *Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } func (m *Model) View() string { + if m.width < 4 || m.height < 4 { + return "..." + } + innerW := m.width - 2 + innerH := m.height - 2 + + var content string switch m.view { case viewTaskList: - return pages.RenderTaskList(m.taskList, m.styles, m.width, m.height) + content = pages.RenderTaskList(m.taskList, m.styles, innerW, innerH) case viewTaskDetail: - return pages.RenderTaskDetail(m.detail, m.styles, m.width, m.height) + content = pages.RenderTaskDetail(m.detail, m.styles, innerW, innerH) case viewWizard: bg := m.renderBgForWizard() - return pages.RenderWizard(m.wizard, bg, m.styles, m.width, m.height) + content = pages.RenderWizard(m.wizard, bg, m.styles, innerW, innerH) case viewDashboard: - return pages.RenderDashboard(m.dash, m.dashTaskName(), m.styles, m.width, m.height) + content = pages.RenderDashboard(m.dash, m.dashTaskName(), m.styles, innerW, innerH) case viewTurboDash: - return pages.RenderTurboDash(m.turboDash, m.turboDashTaskName(), m.styles, m.width, m.height) + content = pages.RenderTurboDash(m.turboDash, m.turboDashTaskName(), m.styles, innerW, innerH) case viewReqDetail: - return pages.RenderReqDetail(m.reqDetail, m.reqDetailTaskName(), m.styles, m.width, m.height) + content = pages.RenderReqDetail(m.reqDetail, m.reqDetailTaskName(), m.styles, innerW, innerH) + default: + content = "未知视图" } - return "未知视图" + + return m.styles.AppBorder.Width(innerW).Height(innerH).Render(content) } // ─── 键盘分发 ───────────────────────────────────────────────────────────────── @@ -412,10 +422,12 @@ func (m *Model) injectRunState(rs *server.RunState) { } func (m *Model) renderBgForWizard() string { + innerW := m.width - 2 + innerH := m.height - 2 if m.prevView == viewTaskDetail { - return pages.RenderTaskDetail(m.detail, m.styles, m.width, m.height) + return pages.RenderTaskDetail(m.detail, m.styles, innerW, innerH) } - return pages.RenderTaskList(m.taskList, m.styles, m.width, m.height) + return pages.RenderTaskList(m.taskList, m.styles, innerW, innerH) } func (m *Model) dashTaskName() string { diff --git a/internal/tui/pages/dashboard.go b/internal/tui/pages/dashboard.go index af6547d..638b8e9 100644 --- a/internal/tui/pages/dashboard.go +++ b/internal/tui/pages/dashboard.go @@ -191,36 +191,38 @@ func RenderDashboard(d *DashboardState, taskName string, st Styles, width, heigh footer := renderFooter(st, width, "[s] 停止", "[b] 后台运行", "[r] 提前报告", "[q] 退出") // ── 计算高度 ── + // 布局:header(2) + split面板(splitH) + 进度面板(3) + 请求面板(reqListH+2) + ctxBarH + footer(1) headerH := 2 ctxBarH := 0 if ctxBar != "" { ctxBarH = 1 } footerH := 1 - splitH := 9 // 上方双栏区域高度 - progressH := 1 // 进度条行高 - divH := 3 // 分隔线总行数(3条分隔线各占1行) - reqListH := height - headerH - ctxBarH - footerH - splitH - progressH - divH + splitH := 9 // 双栏面板外部总高度(含面板边框) + progressPanel := 3 // 进度条面板外部高度(1内容+2边框) + reqListH := height - headerH - ctxBarH - footerH - splitH - progressPanel - 2 // -2 for req panel border if reqListH < 3 { reqListH = 3 } - // ── 双栏(任务参数 ║ 实时指标)── - leftW := (width - 2) * 45 / 100 - rightW := width - 2 - leftW - 1 // -1 for separator │ - leftContent := buildDashParamsPanel(d, rs, st, splitH-1, leftW) - rightContent := buildDashMetricsPanel(rs, st, splitH-1, rightW) - splitDiv := dividerLine(st, width) - split := dualColumnLayout(st, leftContent, rightContent, leftW, rightW, splitH) + // ── 双栏面板(任务参数 | 实时指标)── + leftW := width * 45 / 100 + rightW := width - leftW + leftContent := buildDashParamsPanel(d, rs, st, splitH-2, leftW-2) + rightContent := buildDashMetricsPanel(rs, st, splitH-2, rightW-2) + leftPanel := st.Panel.Width(leftW - 2).Render(leftContent) + rightPanel := st.Panel.Width(rightW - 2).Render(rightContent) + split := lipgloss.JoinHorizontal(lipgloss.Top, leftPanel, rightPanel) - // ── 进度条 ── - progressLine := buildProgressLine(rs, st, width) + // ── 进度条面板 ── + progressLine := buildProgressLine(rs, st, width-2) + progressPanelStr := wrapPanel(st, progressLine, width) - // ── 请求列表 ── - reqDiv := dividerLine(st, width) - reqList := buildRequestList(d, rs, st, width, reqListH) + // ── 请求列表面板 ── + reqList := buildRequestList(d, rs, st, width-2, reqListH) + reqPanelStr := wrapPanel(st, reqList, width) - parts := []string{header, splitDiv, split, splitDiv, progressLine, reqDiv, reqList} + parts := []string{header, split, progressPanelStr, reqPanelStr} if ctxBar != "" { parts = append(parts, ctxBar) } diff --git a/internal/tui/pages/helpers.go b/internal/tui/pages/helpers.go index da6157f..d3a4df7 100644 --- a/internal/tui/pages/helpers.go +++ b/internal/tui/pages/helpers.go @@ -277,3 +277,11 @@ func boolLabel(b bool) string { func labelValue(st Styles, label, value string) string { return st.Label.Render(label) + " " + st.Value.Render(value) } + +// wrapPanel 用带边框的 Panel 包裹内容,outerW 为包含边框的总宽度。 +func wrapPanel(st Styles, content string, outerW int) string { + if outerW < 4 { + return content + } + return st.Panel.Width(outerW - 2).Render(content) +} diff --git a/internal/tui/pages/reqdetail.go b/internal/tui/pages/reqdetail.go index 11bf222..ecf3359 100644 --- a/internal/tui/pages/reqdetail.go +++ b/internal/tui/pages/reqdetail.go @@ -5,6 +5,7 @@ import ( "strings" tea "github.com/charmbracelet/bubbletea" + "github.com/charmbracelet/lipgloss" "github.com/yinxulai/ait/internal/server" ) @@ -131,6 +132,7 @@ func RenderReqDetail(s *ReqDetailState, taskName string, st Styles, width, heigh footer := renderFooter(st, width, "[b/Esc] 返回仪表盘", "[↑↓] 滚动", "[←→] 上/下一条请求") // ── 计算高度 ── + // 布局:header(2) + split面板(splitH) + 输入面板(inputH+2) + 输出面板(outputH+2) + ctxBarH + footer(1) headerH := 2 ctxBarH := 0 if ctxBar != "" { @@ -139,26 +141,29 @@ func RenderReqDetail(s *ReqDetailState, taskName string, st Styles, width, heigh footerH := 1 splitH := 9 inputH := 5 - outputH := height - headerH - ctxBarH - footerH - splitH - inputH - 3 // -3 for dividers + outputH := height - headerH - ctxBarH - footerH - splitH - inputH - 2 - 2 // -2 for input panel border, -2 for output panel border if outputH < 4 { outputH = 4 } - // ── 双栏(性能指标 ║ 网络指标)── - leftW := (width - 2) * 50 / 100 - rightW := width - 2 - leftW - 1 - leftContent := buildReqPerfPanel(r, st, splitH-1, leftW) - rightContent := buildReqNetworkPanel(r, st, splitH-1, rightW) - splitDiv := dividerLine(st, width) - split := dualColumnLayout(st, leftContent, rightContent, leftW, rightW, splitH) + // ── 双栏面板(性能指标 | 网络指标)── + leftW := width * 50 / 100 + rightW := width - leftW + leftContent := buildReqPerfPanel(r, st, splitH-2, leftW-2) + rightContent := buildReqNetworkPanel(r, st, splitH-2, rightW-2) + leftPanel := st.Panel.Width(leftW - 2).Render(leftContent) + rightPanel := st.Panel.Width(rightW - 2).Render(rightContent) + split := lipgloss.JoinHorizontal(lipgloss.Top, leftPanel, rightPanel) - // ── 输入区 ── - inputSection := buildInputSection(r, st, width, inputH) + // ── 输入区面板 ── + inputSection := buildInputSection(r, st, width-2, inputH) + inputPanelStr := wrapPanel(st, inputSection, width) - // ── 输出区 ── - outputSection := buildOutputSection(r, s.ScrollY, st, width, outputH) + // ── 输出区面板 ── + outputSection := buildOutputSection(r, s.ScrollY, st, width-2, outputH) + outputPanelStr := wrapPanel(st, outputSection, width) - parts := []string{header, splitDiv, split, splitDiv, inputSection, splitDiv, outputSection} + parts := []string{header, split, inputPanelStr, outputPanelStr} if ctxBar != "" { parts = append(parts, ctxBar) } diff --git a/internal/tui/pages/styles.go b/internal/tui/pages/styles.go index 9639911..d1b4548 100644 --- a/internal/tui/pages/styles.go +++ b/internal/tui/pages/styles.go @@ -24,6 +24,8 @@ const ( // Styles 汇聚所有 TUI 样式,由 NewStyles() 初始化。 type Styles struct { + AppBorder lipgloss.Style + Panel lipgloss.Style Header lipgloss.Style HeaderInfo lipgloss.Style Footer lipgloss.Style @@ -126,5 +128,11 @@ func NewStyles() Styles { Padding(0, 2), Divider: lipgloss.NewStyle(). Foreground(colorDivider), + AppBorder: lipgloss.NewStyle(). + Border(lipgloss.RoundedBorder()). + BorderForeground(colorPurple), + Panel: lipgloss.NewStyle(). + Border(lipgloss.NormalBorder()). + BorderForeground(colorDivider), } } diff --git a/internal/tui/pages/taskdetail.go b/internal/tui/pages/taskdetail.go index 56fd76b..0b6b59d 100644 --- a/internal/tui/pages/taskdetail.go +++ b/internal/tui/pages/taskdetail.go @@ -106,15 +106,16 @@ func RenderTaskDetail(s *TaskDetailState, st Styles, width, height int) string { ctxBarH = 1 } footerH := 1 - contentH := height - headerH - ctxBarH - footerH + contentH := height - headerH - ctxBarH - footerH - 2 // -2 for panel border if contentH < 6 { contentH = 6 } // ── 内容构建 ── - content := buildTaskDetailContent(s, st, t, inp, width, contentH) + content := buildTaskDetailContent(s, st, t, inp, width-2, contentH) + panel := wrapPanel(st, content, width) - parts := []string{header, content} + parts := []string{header, panel} if ctxBar != "" { parts = append(parts, ctxBar) } diff --git a/internal/tui/pages/tasklist.go b/internal/tui/pages/tasklist.go index 5956e07..c41d22c 100644 --- a/internal/tui/pages/tasklist.go +++ b/internal/tui/pages/tasklist.go @@ -172,15 +172,16 @@ func RenderTaskList(s *TaskListState, st Styles, width, height int) string { ctxBarH = 1 } footerH := 1 - contentH := height - headerH - ctxBarH - footerH + contentH := height - headerH - ctxBarH - footerH - 2 // -2 for panel border if contentH < 4 { contentH = 4 } // ── 内容区 ── - content := buildTaskListContent(s, st, width, contentH) + content := buildTaskListContent(s, st, width-2, contentH) + panel := wrapPanel(st, content, width) - parts := []string{header, content} + parts := []string{header, panel} if ctxBar != "" { parts = append(parts, ctxBar) } diff --git a/internal/tui/pages/turbodash.go b/internal/tui/pages/turbodash.go index 4f87a08..4f9a328 100644 --- a/internal/tui/pages/turbodash.go +++ b/internal/tui/pages/turbodash.go @@ -5,6 +5,7 @@ import ( "strings" tea "github.com/charmbracelet/bubbletea" + "github.com/charmbracelet/lipgloss" "github.com/yinxulai/ait/internal/server" "github.com/yinxulai/ait/internal/types" ) @@ -167,6 +168,7 @@ func RenderTurboDash(d *TurboDashState, taskName string, st Styles, width, heigh footer := renderFooter(st, width, "[s] 停止", "[b] 后台运行", "[m] 标记极限", "[r] 提前报告", "[q] 退出") // ── 计算高度 ── + // 布局:header(2) + split面板(splitH) + 进度面板(3) + 级别面板(levelListH+2) + ctxBarH + footer(1) headerH := 2 ctxBarH := 0 if ctxBar != "" { @@ -174,29 +176,30 @@ func RenderTurboDash(d *TurboDashState, taskName string, st Styles, width, heigh } footerH := 1 splitH := 9 - progressH := 1 - divH := 3 - levelListH := height - headerH - ctxBarH - footerH - splitH - progressH - divH + progressPanel := 3 + levelListH := height - headerH - ctxBarH - footerH - splitH - progressPanel - 2 if levelListH < 3 { levelListH = 3 } - // ── 双栏(任务参数 ║ 当前级别指标)── - leftW := (width - 2) * 45 / 100 - rightW := width - 2 - leftW - 1 - leftContent := buildTurboDashParams(rs, st, splitH-1, leftW) - rightContent := buildTurboDashMetrics(rs, st, splitH-1, rightW) - splitDiv := dividerLine(st, width) - split := dualColumnLayout(st, leftContent, rightContent, leftW, rightW, splitH) + // ── 双栏面板(任务参数 | 当前级别指标)── + leftW := width * 45 / 100 + rightW := width - leftW + leftContent := buildTurboDashParams(rs, st, splitH-2, leftW-2) + rightContent := buildTurboDashMetrics(rs, st, splitH-2, rightW-2) + leftPanel := st.Panel.Width(leftW - 2).Render(leftContent) + rightPanel := st.Panel.Width(rightW - 2).Render(rightContent) + split := lipgloss.JoinHorizontal(lipgloss.Top, leftPanel, rightPanel) - // ── 进度条 ── - progressLine := buildTurboProgressLine(rs, st, width) + // ── 进度条面板 ── + progressLine := buildTurboProgressLine(rs, st, width-2) + progressPanelStr := wrapPanel(st, progressLine, width) - // ── 级别列表 ── - levelDiv := dividerLine(st, width) - levelList := buildLevelList(d, rs, st, width, levelListH) + // ── 级别列表面板 ── + levelList := buildLevelList(d, rs, st, width-2, levelListH) + levelPanelStr := wrapPanel(st, levelList, width) - parts := []string{header, splitDiv, split, splitDiv, progressLine, levelDiv, levelList} + parts := []string{header, split, progressPanelStr, levelPanelStr} if ctxBar != "" { parts = append(parts, ctxBar) } From f9e48f7edf35d1c552c745e49a9780e4cbc43755 Mon Sep 17 00:00:00 2001 From: Alain Date: Sun, 17 May 2026 09:59:05 +0800 Subject: [PATCH 08/52] =?UTF-8?q?feat:=20=E9=87=8D=E6=9E=84=E9=A1=B5?= =?UTF-8?q?=E9=9D=A2=E5=B8=83=E5=B1=80=EF=BC=8C=E6=B7=BB=E5=8A=A0=E6=9C=80?= =?UTF-8?q?=E5=B0=8F=E5=B0=BA=E5=AF=B8=E4=BF=9D=E6=8A=A4=E5=92=8C=E9=80=9A?= =?UTF-8?q?=E7=94=A8=E9=A1=B5=E9=9D=A2=E7=BB=93=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/tui/pages/dashboard.go | 53 +++++++---------- internal/tui/pages/helpers.go | 51 ----------------- internal/tui/pages/layout.go | 98 ++++++++++++++++++++++++++++++++ internal/tui/pages/reqdetail.go | 51 ++++++----------- internal/tui/pages/taskdetail.go | 50 ++++------------ internal/tui/pages/tasklist.go | 46 ++++----------- internal/tui/pages/turbodash.go | 56 +++++++----------- 7 files changed, 176 insertions(+), 229 deletions(-) create mode 100644 internal/tui/pages/layout.go diff --git a/internal/tui/pages/dashboard.go b/internal/tui/pages/dashboard.go index 638b8e9..ce11cb9 100644 --- a/internal/tui/pages/dashboard.go +++ b/internal/tui/pages/dashboard.go @@ -145,12 +145,15 @@ func HandleDashboardKey(d *DashboardState, msg tea.KeyMsg, client Client) (*Dash // ║ [s] 停止 [b] 后台运行 [r] 报告 [q] 退出 ║ // ╚═════════════════════════════════════════╝ func RenderDashboard(d *DashboardState, taskName string, st Styles, width, height int) string { - if d == nil || width == 0 { - return "加载中..." + if TooSmall(width, height) { + return renderTooSmall(st, width, height) + } + if d == nil { + return renderTooSmall(st, width, height) } rs := d.RunState - // ── Header ── + // ── 状态标识 ── statusStr := "等待中" if rs != nil { switch rs.Status { @@ -171,36 +174,24 @@ func RenderDashboard(d *DashboardState, taskName string, st Styles, width, heigh "─", "─", 0, rs.TotalReqs) } - header := renderHeader(st, width, - "AIT 正在测试 ─ "+truncate(taskName, 25), - statusStr, - subtitle, - "", - ) - - // ── Context Bar ── var cbItems []ContextBarItem if d.ReqSel >= 0 && rs != nil && d.ReqSel < len(rs.Requests) { cbItems = CtxBar_Dashboard_Sel() } else { cbItems = CtxBar_Dashboard_NoSel() } - ctxBar := RenderContextBar(st, width, cbItems) - - // ── Footer ── - footer := renderFooter(st, width, "[s] 停止", "[b] 后台运行", "[r] 提前报告", "[q] 退出") + l := PageLayout{ + TitleLeft: "AIT 正在测试 ─ " + truncate(taskName, 25), + TitleRight: statusStr, + InfoLeft: subtitle, + CtxItems: cbItems, + FooterParts: []string{"[s] 停止", "[b] 后台运行", "[r] 提前报告", "[q] 退出"}, + } // ── 计算高度 ── - // 布局:header(2) + split面板(splitH) + 进度面板(3) + 请求面板(reqListH+2) + ctxBarH + footer(1) - headerH := 2 - ctxBarH := 0 - if ctxBar != "" { - ctxBarH = 1 - } - footerH := 1 - splitH := 9 // 双栏面板外部总高度(含面板边框) + splitH := 9 // 双栏面板外部总高度(含面板边框) progressPanel := 3 // 进度条面板外部高度(1内容+2边框) - reqListH := height - headerH - ctxBarH - footerH - splitH - progressPanel - 2 // -2 for req panel border + reqListH := height - l.ChromeHeight() - splitH - progressPanel - 2 // -2 for req panel border if reqListH < 3 { reqListH = 3 } @@ -215,19 +206,15 @@ func RenderDashboard(d *DashboardState, taskName string, st Styles, width, heigh split := lipgloss.JoinHorizontal(lipgloss.Top, leftPanel, rightPanel) // ── 进度条面板 ── - progressLine := buildProgressLine(rs, st, width-2) + progressLine := buildProgressLine(rs, st, ContentWidth(width)) progressPanelStr := wrapPanel(st, progressLine, width) // ── 请求列表面板 ── - reqList := buildRequestList(d, rs, st, width-2, reqListH) + reqList := buildRequestList(d, rs, st, ContentWidth(width), reqListH) reqPanelStr := wrapPanel(st, reqList, width) - parts := []string{header, split, progressPanelStr, reqPanelStr} - if ctxBar != "" { - parts = append(parts, ctxBar) - } - parts = append(parts, footer) - return strings.Join(parts, "\n") + content := strings.Join([]string{split, progressPanelStr, reqPanelStr}, "\n") + return l.Assemble(content, st, width) } // buildDashParamsPanel 构建左侧任务参数面板。 @@ -289,7 +276,6 @@ func buildProgressLine(rs *server.RunState, st Styles, width int) string { ratio = float64(done) / float64(total) } barW := 20 - bar := progressBar(ratio, barW) barRendered := st.Ok.Render(strings.Repeat("█", int(ratio*float64(barW)))) + st.Muted.Render(strings.Repeat("░", barW-int(ratio*float64(barW)))) @@ -301,7 +287,6 @@ func buildProgressLine(rs *server.RunState, st Styles, width int) string { line := fmt.Sprintf(" 进度 %s %d / %d %s", barRendered, done, total, elapsed) - _ = bar if lipgloss.Width(line) > width { line = truncate(line, width) } diff --git a/internal/tui/pages/helpers.go b/internal/tui/pages/helpers.go index d3a4df7..4d3a1e5 100644 --- a/internal/tui/pages/helpers.go +++ b/internal/tui/pages/helpers.go @@ -162,57 +162,6 @@ func renderFooter(st Styles, width int, parts ...string) string { return st.Footer.Width(w).Render(line) } -// dualColumnLayout 将左右两段文本排列为双栏,高度限定为 maxH。 -// 中间用竖线 │ 隔开。 -func dualColumnLayout(st Styles, left, right string, leftW, rightW, maxH int) string { - leftLines := strings.Split(left, "\n") - rightLines := strings.Split(right, "\n") - - if len(leftLines) > maxH { - leftLines = leftLines[:maxH] - } - if len(rightLines) > maxH { - rightLines = rightLines[:maxH] - } - for len(leftLines) < maxH { - leftLines = append(leftLines, "") - } - for len(rightLines) < maxH { - rightLines = append(rightLines, "") - } - - sep := st.Divider.Render("│") - var rows []string - for i := 0; i < maxH; i++ { - lLine := leftLines[i] - rLine := rightLines[i] - lW := lipgloss.Width(lLine) - if lW < leftW { - lLine += strings.Repeat(" ", leftW-lW) - } - rows = append(rows, lLine+sep+rLine) - } - return strings.Join(rows, "\n") -} - -// progressBar 生成进度条字符串(filled=已完成比例 0.0-1.0)。 -func progressBar(filled float64, width int) string { - if width <= 0 { - return "" - } - if filled < 0 { - filled = 0 - } - if filled > 1 { - filled = 1 - } - doneW := int(float64(width) * filled) - emptyW := width - doneW - done := strings.Repeat("█", doneW) - empty := strings.Repeat("░", emptyW) - return done + empty -} - // wrapIndex 循环索引(保证 0 ≤ result < count)。 func wrapIndex(idx, count int) int { if count <= 0 { diff --git a/internal/tui/pages/layout.go b/internal/tui/pages/layout.go new file mode 100644 index 0000000..bc666aa --- /dev/null +++ b/internal/tui/pages/layout.go @@ -0,0 +1,98 @@ +package pages + +import "strings" + +// ── 尺寸常量 ────────────────────────────────────────────────────────────────── + +const ( + // MinWidth / MinHeight:低于此值时显示"窗口过小"提示而非正常页面。 + MinWidth = 40 + MinHeight = 10 + + // chrome 各组成部分的行数 + chromeHeaderH = 2 // 双行标题栏 + chromeFooterH = 1 // 单行底部状态栏 + chromeCtxBarH = 2 // ContextBar (1) + 分隔线 (1) + + // panelBorderV 是单个面板的上下边框行数之和。 + panelBorderV = 2 +) + +// ── PageLayout ──────────────────────────────────────────────────────────────── + +// PageLayout 描述一个完整页面的 chrome(顶部标题栏、底部 ContextBar + 分隔线 + Footer)。 +// 各页面 Render 函数先构造 PageLayout,再调用 Assemble 拼装最终输出, +// 从而消除页面间重复的 header/footer/ctxbar 组装逻辑。 +type PageLayout struct { + TitleLeft string + TitleRight string + InfoLeft string + InfoRight string + CtxItems []ContextBarItem + FooterParts []string +} + +// ChromeHeight 返回 chrome 占用的总行数 +// (header + 若有 ctxbar 则含 ctxbar+分隔线 + footer)。 +func (l PageLayout) ChromeHeight() int { + h := chromeHeaderH + chromeFooterH + if len(l.CtxItems) > 0 { + h += chromeCtxBarH + } + return h +} + +// ContentHeight 返回单面板页面主内容区的可用行数 +// (总高度 - chrome 行数 - 面板上下边框)。 +func (l PageLayout) ContentHeight(totalH int) int { + h := totalH - l.ChromeHeight() - panelBorderV + if h < 2 { + h = 2 + } + return h +} + +// ContentWidth 返回单面板页面内容区的可用列宽 +// (总宽度 - 面板左右边框)。 +func ContentWidth(totalW int) int { + w := totalW - 2 + if w < 1 { + w = 1 + } + return w +} + +// Assemble 拼装完整页面输出: +// +// header +// content ← 由调用方构建(单面板或多面板) +// [ctxbar] ← 仅当 CtxItems 非空时输出 +// [divider] ← ctxbar 与 footer 之间的分隔线 +// footer +func (l PageLayout) Assemble(content string, st Styles, width int) string { + header := renderHeader(st, width, l.TitleLeft, l.TitleRight, l.InfoLeft, l.InfoRight) + footer := renderFooter(st, width, l.FooterParts...) + + parts := []string{header, content} + if len(l.CtxItems) > 0 { + ctxBar := RenderContextBar(st, width, l.CtxItems) + parts = append(parts, ctxBar, dividerLine(st, width)) + } + parts = append(parts, footer) + return strings.Join(parts, "\n") +} + +// ── 最小尺寸保护 ────────────────────────────────────────────────────────────── + +// TooSmall 返回 true 当终端小于最小可用尺寸。 +func TooSmall(width, height int) bool { + return width < MinWidth || height < MinHeight +} + +// renderTooSmall 返回终端过小时的简洁提示。 +func renderTooSmall(st Styles, width, _ int) string { + if width < 4 { + return "..." + } + return st.Muted.Render(truncate("窗口过小 ↔ 请放大终端", width)) +} diff --git a/internal/tui/pages/reqdetail.go b/internal/tui/pages/reqdetail.go index ecf3359..87dc103 100644 --- a/internal/tui/pages/reqdetail.go +++ b/internal/tui/pages/reqdetail.go @@ -96,11 +96,11 @@ func HandleReqDetailKey(s *ReqDetailState, msg tea.KeyMsg) (*ReqDetailState, Nav // ║ [b/Esc] 返回仪表盘 [↑↓] 滚动 [←→] 上/下一条 ║ // ╚═════════════════════════════════════════╝ func RenderReqDetail(s *ReqDetailState, taskName string, st Styles, width, height int) string { - if s == nil || width == 0 { - return "加载中..." + if TooSmall(width, height) { + return renderTooSmall(st, width, height) } - if len(s.Requests) == 0 { - return "无请求数据" + if s == nil || len(s.Requests) == 0 { + return renderTooSmall(st, width, height) } idx := s.Index @@ -112,36 +112,25 @@ func RenderReqDetail(s *ReqDetailState, taskName string, st Styles, width, heigh } r := s.Requests[idx] - // ── Header ── + // ── 状态标识 ── statusStr := st.Ok.Render("✓ 成功") if !r.Success { statusStr = st.ErrStyle.Render("✗ 失败") } - header := renderHeader(st, width, - fmt.Sprintf("AIT 请求详情 - %s #%d", truncate(taskName, 20), idx+1), - statusStr, - fmt.Sprintf("◆ AIT 任务: %s 请求 %d / %d", - truncate(taskName, 20), idx+1, len(s.Requests)), - "", - ) - - // ── Context Bar ── - ctxBar := RenderContextBar(st, width, CtxBar_ReqDetail()) - // ── Footer ── - footer := renderFooter(st, width, "[b/Esc] 返回仪表盘", "[↑↓] 滚动", "[←→] 上/下一条请求") + l := PageLayout{ + TitleLeft: fmt.Sprintf("AIT 请求详情 - %s #%d", truncate(taskName, 20), idx+1), + TitleRight: statusStr, + InfoLeft: fmt.Sprintf("◆ AIT 任务: %s 请求 %d / %d", + truncate(taskName, 20), idx+1, len(s.Requests)), + CtxItems: CtxBar_ReqDetail(), + FooterParts: []string{"[b/Esc] 返回仪表盘", "[↑↓] 滚动", "[←→] 上/下一条请求"}, + } // ── 计算高度 ── - // 布局:header(2) + split面板(splitH) + 输入面板(inputH+2) + 输出面板(outputH+2) + ctxBarH + footer(1) - headerH := 2 - ctxBarH := 0 - if ctxBar != "" { - ctxBarH = 1 - } - footerH := 1 splitH := 9 inputH := 5 - outputH := height - headerH - ctxBarH - footerH - splitH - inputH - 2 - 2 // -2 for input panel border, -2 for output panel border + outputH := height - l.ChromeHeight() - splitH - inputH - 2 - 2 // -2 for input panel border, -2 for output panel border if outputH < 4 { outputH = 4 } @@ -156,19 +145,15 @@ func RenderReqDetail(s *ReqDetailState, taskName string, st Styles, width, heigh split := lipgloss.JoinHorizontal(lipgloss.Top, leftPanel, rightPanel) // ── 输入区面板 ── - inputSection := buildInputSection(r, st, width-2, inputH) + inputSection := buildInputSection(r, st, ContentWidth(width), inputH) inputPanelStr := wrapPanel(st, inputSection, width) // ── 输出区面板 ── - outputSection := buildOutputSection(r, s.ScrollY, st, width-2, outputH) + outputSection := buildOutputSection(r, s.ScrollY, st, ContentWidth(width), outputH) outputPanelStr := wrapPanel(st, outputSection, width) - parts := []string{header, split, inputPanelStr, outputPanelStr} - if ctxBar != "" { - parts = append(parts, ctxBar) - } - parts = append(parts, footer) - return strings.Join(parts, "\n") + content := strings.Join([]string{split, inputPanelStr, outputPanelStr}, "\n") + return l.Assemble(content, st, width) } // buildReqPerfPanel 构建请求左侧性能指标面板。 diff --git a/internal/tui/pages/taskdetail.go b/internal/tui/pages/taskdetail.go index 0b6b59d..f9e0155 100644 --- a/internal/tui/pages/taskdetail.go +++ b/internal/tui/pages/taskdetail.go @@ -70,57 +70,29 @@ func HandleTaskDetailKey(s *TaskDetailState, msg tea.KeyMsg, client Client) (*Ta // ║ [b/Esc] 返回列表 ◆ AIT v0.1 ║ // ╚══════════════════════════════════════╝ func RenderTaskDetail(s *TaskDetailState, st Styles, width, height int) string { - if width == 0 { - return "加载中..." + if TooSmall(width, height) { + return renderTooSmall(st, width, height) } t := s.Task inp := t.Input - // ── Header ── updatedStr := timeAgo(t.UpdatedAt) - header := renderHeader(st, width, - "AIT 任务详情 ─ "+truncate(t.Name, 30), - "", - fmt.Sprintf("◆ AIT 任务 ID: %s 更新: %s %s", - truncate(t.ID, 10), t.UpdatedAt.Format("2006-01-02 15:04"), updatedStr), - "", - ) - - // ── Context Bar ── - hasHistory := len(s.History) > 0 var cbItems []ContextBarItem - if hasHistory { + if len(s.History) > 0 { cbItems = CtxBar_TaskDetail_HasHistory() } else { cbItems = CtxBar_TaskDetail_NoHistory() } - ctxBar := RenderContextBar(st, width, cbItems) - - // ── Footer ── - footer := renderFooter(st, width, "[b/Esc] 返回列表", "[r] 运行", "[e] 编辑", "◆ AIT v0.1") - - // ── 内容区高度 ── - headerH := 2 - ctxBarH := 0 - if ctxBar != "" { - ctxBarH = 1 - } - footerH := 1 - contentH := height - headerH - ctxBarH - footerH - 2 // -2 for panel border - if contentH < 6 { - contentH = 6 + l := PageLayout{ + TitleLeft: "AIT 任务详情 ─ " + truncate(t.Name, 30), + InfoLeft: fmt.Sprintf("◆ AIT 任务 ID: %s 更新: %s %s", + truncate(t.ID, 10), t.UpdatedAt.Format("2006-01-02 15:04"), updatedStr), + CtxItems: cbItems, + FooterParts: []string{"[b/Esc] 返回列表", "[r] 运行", "[e] 编辑", "◆ AIT v0.1"}, } - // ── 内容构建 ── - content := buildTaskDetailContent(s, st, t, inp, width-2, contentH) - panel := wrapPanel(st, content, width) - - parts := []string{header, panel} - if ctxBar != "" { - parts = append(parts, ctxBar) - } - parts = append(parts, footer) - return strings.Join(parts, "\n") + content := buildTaskDetailContent(s, st, t, inp, ContentWidth(width), l.ContentHeight(height)) + return l.Assemble(wrapPanel(st, content, width), st, width) } // buildTaskDetailContent 构建任务详情内容区。 diff --git a/internal/tui/pages/tasklist.go b/internal/tui/pages/tasklist.go index c41d22c..d16996d 100644 --- a/internal/tui/pages/tasklist.go +++ b/internal/tui/pages/tasklist.go @@ -135,23 +135,14 @@ func HandleTaskListKey(s *TaskListState, msg tea.KeyMsg, client Client) (*TaskLi // ║ [↑↓] 选择 [q] 退出 ◆ AIT ║ // ╚══════════════════════════════╝ func RenderTaskList(s *TaskListState, st Styles, width, height int) string { - if width == 0 { - return "加载中..." + if TooSmall(width, height) { + return renderTooSmall(st, width, height) } - // ── Header ── lastRunStr := "" if lt := s.latestRunAt(); lt != nil { lastRunStr = "最近运行: " + lt.Format("2006-01-02 15:04") } - header := renderHeader(st, width, - "AIT 任务中心", - "", - fmt.Sprintf("◆ AIT 已保存任务: %d %s", len(s.Tasks), lastRunStr), - "", - ) - - // ── Context Bar ── var cbItems []ContextBarItem if t, ok := s.CurrentTask(); ok { if s.IsTaskRunning(t.ID) { @@ -160,33 +151,15 @@ func RenderTaskList(s *TaskListState, st Styles, width, height int) string { cbItems = CtxBar_TaskList_Normal() } } - ctxBar := RenderContextBar(st, width, cbItems) - - // ── Footer ── - footer := renderFooter(st, width, "[↑↓] 选择", "[a] 新建", "[y] 复制", "[q] 退出", "◆ AIT v0.1") - - // ── 可用内容高度 ── - headerH := 2 - ctxBarH := 0 - if ctxBar != "" { - ctxBarH = 1 + l := PageLayout{ + TitleLeft: "AIT 任务中心", + InfoLeft: fmt.Sprintf("◆ AIT 已保存任务: %d %s", len(s.Tasks), lastRunStr), + CtxItems: cbItems, + FooterParts: []string{"[↑↓] 选择", "[a] 新建", "[y] 复制", "[q] 退出", "◆ AIT v0.1"}, } - footerH := 1 - contentH := height - headerH - ctxBarH - footerH - 2 // -2 for panel border - if contentH < 4 { - contentH = 4 - } - - // ── 内容区 ── - content := buildTaskListContent(s, st, width-2, contentH) - panel := wrapPanel(st, content, width) - parts := []string{header, panel} - if ctxBar != "" { - parts = append(parts, ctxBar) - } - parts = append(parts, footer) - return strings.Join(parts, "\n") + content := buildTaskListContent(s, st, ContentWidth(width), l.ContentHeight(height)) + return l.Assemble(wrapPanel(st, content, width), st, width) } // buildTaskListContent 构建任务列表内容区(含表头 + 任务条目)。 @@ -215,6 +188,7 @@ func buildTaskListContent(s *TaskListState, st Styles, width, maxH int) string { ) lines = append(lines, header) lines = append(lines, " "+st.Divider.Render(strings.Repeat("─", innerW-1))) + lines = append(lines, "") // 表头与第一条目之间的呼吸间距 if len(s.Tasks) == 0 { lines = append(lines, "") diff --git a/internal/tui/pages/turbodash.go b/internal/tui/pages/turbodash.go index 4f9a328..4356fb8 100644 --- a/internal/tui/pages/turbodash.go +++ b/internal/tui/pages/turbodash.go @@ -128,56 +128,44 @@ func HandleTurboDashKey(d *TurboDashState, msg tea.KeyMsg, client Client) (*Turb // ║ [s] 停止 [b] 后台 [m] 标记极限 [q] 退出 ║ // ╚═════════════════════════════════════════╝ func RenderTurboDash(d *TurboDashState, taskName string, st Styles, width, height int) string { - if d == nil || width == 0 { - return "加载中..." + if TooSmall(width, height) { + return renderTooSmall(st, width, height) + } + if d == nil { + return renderTooSmall(st, width, height) } rs := d.RunState - // ── Header ── - statusStr := "探测中" + // ── 状态标识 ── + statusStr := st.Ok.Render("探测中") if rs != nil && rs.Status != server.RunStatusRunning { statusStr = st.Muted.Render(string(rs.Status)) - } else { - statusStr = st.Ok.Render("探测中") } subtitle := "─" if rs != nil && len(rs.Levels) > 0 { - curLevel := rs.CurrentLevel subtitle = fmt.Sprintf("◆ AIT %s · 当前并发: %d 已完成 %d 级", - "─", curLevel, len(rs.Levels)) + "─", rs.CurrentLevel, len(rs.Levels)) } - header := renderHeader(st, width, - "AIT Turbo 探测 ─ "+truncate(taskName, 22), - statusStr, - subtitle, - "", - ) - - // ── Context Bar ── var cbItems []ContextBarItem if d.LevelSel >= 0 && rs != nil && d.LevelSel < len(rs.Levels) { cbItems = CtxBar_TurboDash_Sel() } else { cbItems = CtxBar_TurboDash_NoSel() } - ctxBar := RenderContextBar(st, width, cbItems) - - // ── Footer ── - footer := renderFooter(st, width, "[s] 停止", "[b] 后台运行", "[m] 标记极限", "[r] 提前报告", "[q] 退出") + l := PageLayout{ + TitleLeft: "AIT Turbo 探测 ─ " + truncate(taskName, 22), + TitleRight: statusStr, + InfoLeft: subtitle, + CtxItems: cbItems, + FooterParts: []string{"[s] 停止", "[b] 后台运行", "[m] 标记极限", "[r] 提前报告", "[q] 退出"}, + } // ── 计算高度 ── - // 布局:header(2) + split面板(splitH) + 进度面板(3) + 级别面板(levelListH+2) + ctxBarH + footer(1) - headerH := 2 - ctxBarH := 0 - if ctxBar != "" { - ctxBarH = 1 - } - footerH := 1 splitH := 9 progressPanel := 3 - levelListH := height - headerH - ctxBarH - footerH - splitH - progressPanel - 2 + levelListH := height - l.ChromeHeight() - splitH - progressPanel - 2 if levelListH < 3 { levelListH = 3 } @@ -192,19 +180,15 @@ func RenderTurboDash(d *TurboDashState, taskName string, st Styles, width, heigh split := lipgloss.JoinHorizontal(lipgloss.Top, leftPanel, rightPanel) // ── 进度条面板 ── - progressLine := buildTurboProgressLine(rs, st, width-2) + progressLine := buildTurboProgressLine(rs, st, ContentWidth(width)) progressPanelStr := wrapPanel(st, progressLine, width) // ── 级别列表面板 ── - levelList := buildLevelList(d, rs, st, width-2, levelListH) + levelList := buildLevelList(d, rs, st, ContentWidth(width), levelListH) levelPanelStr := wrapPanel(st, levelList, width) - parts := []string{header, split, progressPanelStr, levelPanelStr} - if ctxBar != "" { - parts = append(parts, ctxBar) - } - parts = append(parts, footer) - return strings.Join(parts, "\n") + content := strings.Join([]string{split, progressPanelStr, levelPanelStr}, "\n") + return l.Assemble(content, st, width) } // buildTurboDashParams 构建 Turbo 仪表盘左侧任务参数面板。 From 225b775ce1b9cb948a89f79be185c6aef7be5902 Mon Sep 17 00:00:00 2001 From: Alain Date: Sun, 17 May 2026 14:50:40 +0800 Subject: [PATCH 09/52] Refactor TUI components for improved usability and aesthetics - Updated request detail page to simplify footer and header information. - Changed color palette for header background and text for better visibility. - Enhanced task detail page with improved navigation and selection handling. - Implemented dual-column layout for task details and history for better information display. - Refined task list rendering with improved selection indicators and result display. - Simplified turbo dashboard metrics display and improved level selection visibility. - General code cleanup and optimization for better readability and maintainability. --- internal/tui/model.go | 34 +-- internal/tui/pages/contextbar.go | 68 +++-- internal/tui/pages/dashboard.go | 151 ++++++---- internal/tui/pages/helpers.go | 233 +++++++++++++-- internal/tui/pages/layout.go | 48 ++- internal/tui/pages/reqdetail.go | 4 +- internal/tui/pages/styles.go | 40 ++- internal/tui/pages/taskdetail.go | 365 +++++++++++++++-------- internal/tui/pages/tasklist.go | 153 +++++----- internal/tui/pages/turbodash.go | 92 +++--- internal/tui/pages/wizard.go | 492 +++++++++++++++++++++---------- 11 files changed, 1112 insertions(+), 568 deletions(-) diff --git a/internal/tui/model.go b/internal/tui/model.go index 0b672ba..52d18fc 100644 --- a/internal/tui/model.go +++ b/internal/tui/model.go @@ -29,14 +29,13 @@ const ( // Model 是 BubbleTea 的根状态机。 // 所有 Server 交互均通过 Client 发出 tea.Cmd;Model 不直接 import runner/task/turbo 等下层包。 type Model struct { - client *Client - styles pages.Styles - width int - height int - view viewState - prevView viewState // 向导叠加时记录背景视图 - status string - err error + client *Client + styles pages.Styles + width int + height int + view viewState + status string + err error // 页面局部状态(由 pages 包管理) taskList *pages.TaskListState @@ -175,8 +174,8 @@ func (m *Model) View() string { if m.width < 4 || m.height < 4 { return "..." } - innerW := m.width - 2 - innerH := m.height - 2 + innerW := m.width + innerH := m.height var content string switch m.view { @@ -185,8 +184,7 @@ func (m *Model) View() string { case viewTaskDetail: content = pages.RenderTaskDetail(m.detail, m.styles, innerW, innerH) case viewWizard: - bg := m.renderBgForWizard() - content = pages.RenderWizard(m.wizard, bg, m.styles, innerW, innerH) + content = pages.RenderWizard(m.wizard, m.styles, innerW, innerH) case viewDashboard: content = pages.RenderDashboard(m.dash, m.dashTaskName(), m.styles, innerW, innerH) case viewTurboDash: @@ -197,7 +195,7 @@ func (m *Model) View() string { content = "未知视图" } - return m.styles.AppBorder.Width(innerW).Height(innerH).Render(content) + return content } // ─── 键盘分发 ───────────────────────────────────────────────────────────────── @@ -273,7 +271,6 @@ func (m *Model) handleNav(nav pages.NavAction) tea.Cmd { } else { m.wizard = pages.NewWizardState() } - m.prevView = m.view m.view = viewWizard return nil @@ -421,15 +418,6 @@ func (m *Model) injectRunState(rs *server.RunState) { } } -func (m *Model) renderBgForWizard() string { - innerW := m.width - 2 - innerH := m.height - 2 - if m.prevView == viewTaskDetail { - return pages.RenderTaskDetail(m.detail, m.styles, innerW, innerH) - } - return pages.RenderTaskList(m.taskList, m.styles, innerW, innerH) -} - func (m *Model) dashTaskName() string { if m.dash == nil { return "─" diff --git a/internal/tui/pages/contextbar.go b/internal/tui/pages/contextbar.go index 5b4e82c..eb45bf7 100644 --- a/internal/tui/pages/contextbar.go +++ b/internal/tui/pages/contextbar.go @@ -1,35 +1,12 @@ package pages -import ( - "strings" - - "github.com/charmbracelet/lipgloss" -) - -// ContextBarItem 是 Context Bar 中的一个可操作项。 +// ContextBarItem 是底栏中的一个可操作项。 type ContextBarItem struct { Key string // 如 "Enter"、"r"、"↑↓" Desc string // 操作描述 } -// RenderContextBar 渲染 Context Bar。 -// 若 items 为空则返回空字符串(不占空间)。 -func RenderContextBar(st Styles, width int, items []ContextBarItem) string { - if len(items) == 0 { - return "" - } - var parts []string - for _, item := range items { - parts = append(parts, "["+item.Key+"] "+item.Desc) - } - line := " " + strings.Join(parts, " ") - if lipgloss.Width(line) > width { - line = truncate(line, width) - } - return st.CtxBar.Width(width).Render(line) -} - -// ─── 各页面 Context Bar 内容定义 ───────────────────────────────────────────── +// ─── 各页面底栏操作定义 ─────────────────────────────────────────────────────── // CtxBar_TaskList_Normal 普通任务选中时的 Context Bar。 func CtxBar_TaskList_Normal() []ContextBarItem { @@ -64,10 +41,43 @@ func CtxBar_TaskDetail_NoHistory() []ContextBarItem { // CtxBar_TaskDetail_HasHistory 任务详情页,有运行记录时。 func CtxBar_TaskDetail_HasHistory() []ContextBarItem { return []ContextBarItem{ - {Key: "r", Desc: "生成报告"}, - {Key: "c", Desc: "复制摘要"}, - {Key: "Enter/r", Desc: "再次运行"}, + {Key: "↑↓", Desc: "选择记录"}, + {Key: "r", Desc: "导出 JSON 报告"}, + {Key: "Enter", Desc: "再次运行"}, {Key: "e", Desc: "编辑"}, + {Key: "y", Desc: "复制任务"}, + {Key: "d", Desc: "删除"}, + } +} + +// CtxBar_Wizard_Step1 创建任务页,第 1 步。 +func CtxBar_Wizard_Step1() []ContextBarItem { + return []ContextBarItem{ + {Key: "Tab/↑↓", Desc: "切换字段"}, + {Key: "←→", Desc: "切换协议"}, + {Key: "Enter", Desc: "下一步"}, + {Key: "Esc", Desc: "返回列表"}, + } +} + +// CtxBar_Wizard_Step2 创建任务页,第 2 步。 +func CtxBar_Wizard_Step2() []ContextBarItem { + return []ContextBarItem{ + {Key: "Tab/↑↓", Desc: "切换字段"}, + {Key: "←→", Desc: "切换选项"}, + {Key: "Enter", Desc: "下一步"}, + {Key: "Esc", Desc: "返回上一步"}, + } +} + +// CtxBar_Wizard_Step3 创建任务页,第 3 步。 +func CtxBar_Wizard_Step3() []ContextBarItem { + return []ContextBarItem{ + {Key: "↑↓", Desc: "滚动"}, + {Key: "PgUp/PgDn", Desc: "翻页"}, + {Key: "Enter", Desc: "保存"}, + {Key: "r", Desc: "保存并运行"}, + {Key: "Esc", Desc: "返回修改"}, } } @@ -95,6 +105,7 @@ func CtxBar_TurboDash_NoSel() []ContextBarItem { {Key: "s", Desc: "停止"}, {Key: "b", Desc: "后台运行"}, {Key: "m", Desc: "标记极限"}, + {Key: "r", Desc: "提前报告"}, } } @@ -111,6 +122,7 @@ func CtxBar_TurboDash_Sel() []ContextBarItem { func CtxBar_ReqDetail() []ContextBarItem { return []ContextBarItem{ {Key: "b/Esc", Desc: "返回仪表盘"}, + {Key: "↑↓", Desc: "滚动"}, {Key: "←→", Desc: "上/下一条请求"}, } } diff --git a/internal/tui/pages/dashboard.go b/internal/tui/pages/dashboard.go index ce11cb9..b0484cf 100644 --- a/internal/tui/pages/dashboard.go +++ b/internal/tui/pages/dashboard.go @@ -18,6 +18,7 @@ type DashboardState struct { RunState *server.RunState ReqSel int // 选中请求索引(-1 = 无选中) ReqOff int // 滚动偏移 + ReqVis int // 当前可见请求数 } // NewDashboardState 创建仪表盘状态。 @@ -37,25 +38,26 @@ func (d *DashboardState) IsRunning() bool { return d.RunState.Status == server.RunStatusRunning } -// AdjustReqOffset 根据 ReqSel 调整列表可见窗口。 -func (d *DashboardState) AdjustReqOffset(visH int) { +// AdjustReqOffset 根据屏幕显示顺序调整列表可见窗口。 +func (d *DashboardState) AdjustReqOffset(visH, total int) { if d == nil { return } if visH < 3 { visH = 3 } - sel := d.ReqSel - off := d.ReqOff - if sel < 0 { + if total <= 0 || d.ReqSel < 0 { + d.ReqOff = 0 return } + sel := requestDisplayPos(d.ReqSel, total) + off := d.ReqOff if sel < off { off = sel } else if sel >= off+visH { off = sel - visH + 1 } - d.ReqOff = off + d.ReqOff = clampInt(off, 0, maxInt(0, total-visH)) } // HandleDashboardKey 处理仪表盘页按键。 @@ -75,23 +77,33 @@ func HandleDashboardKey(d *DashboardState, msg tea.KeyMsg, client Client) (*Dash if len(reqs) == 0 { break } - if d.ReqSel <= 0 { - d.ReqSel = len(reqs) - 1 + selPos := 0 + if d.ReqSel >= 0 { + selPos = requestDisplayPos(d.ReqSel, len(reqs)) + } + if selPos <= 0 { + selPos = len(reqs) - 1 } else { - d.ReqSel-- + selPos-- } - d.AdjustReqOffset(10) + d.ReqSel = requestIndexFromDisplayPos(selPos, len(reqs)) + d.AdjustReqOffset(d.ReqVis, len(reqs)) case "down", "j": if len(reqs) == 0 { break } - if d.ReqSel < len(reqs)-1 { - d.ReqSel++ + selPos := 0 + if d.ReqSel >= 0 { + selPos = requestDisplayPos(d.ReqSel, len(reqs)) + } + if selPos < len(reqs)-1 { + selPos++ } else { - d.ReqSel = 0 + selPos = 0 } - d.AdjustReqOffset(10) + d.ReqSel = requestIndexFromDisplayPos(selPos, len(reqs)) + d.AdjustReqOffset(d.ReqVis, len(reqs)) case "enter": if d.ReqSel >= 0 && d.ReqSel < len(reqs) { @@ -170,7 +182,7 @@ func RenderDashboard(d *DashboardState, taskName string, st Styles, width, heigh subtitle := "─" if rs != nil { - subtitle = fmt.Sprintf("◆ AIT %s · %s · 并发: %d · 请求: %d", + subtitle = fmt.Sprintf("%s · %s · 并发: %d · 请求: %d", "─", "─", 0, rs.TotalReqs) } @@ -185,7 +197,7 @@ func RenderDashboard(d *DashboardState, taskName string, st Styles, width, heigh TitleRight: statusStr, InfoLeft: subtitle, CtxItems: cbItems, - FooterParts: []string{"[s] 停止", "[b] 后台运行", "[r] 提前报告", "[q] 退出"}, + FooterParts: []string{"[q] 退出"}, } // ── 计算高度 ── @@ -248,7 +260,7 @@ func buildDashMetricsPanel(rs *server.RunState, st Styles, maxH, width int) stri lines = append(lines, " "+st.Muted.Render("等待数据...")) } else { lines = append(lines, " "+labelValue(st, "成功率 ", - st.MetricVal.Render(fmt.Sprintf("%.1f%%", rs.SuccessRate*100)))) + st.MetricVal.Render(fmt.Sprintf("%.1f%%", rs.SuccessRate)))) lines = append(lines, " "+labelValue(st, "avg TPS ", st.MetricVal.Render(fmt.Sprintf("%.1f tok/s", rs.AvgTPS)))) lines = append(lines, " "+labelValue(st, "avg TTFT", @@ -306,57 +318,70 @@ func buildRequestList(d *DashboardState, rs *server.RunState, st Styles, width, return strings.Join(lines, "\n") } - // 表头 - lines = append(lines, " "+st.TableHead.Render( - padRight("#", 6)+padRight("状态", 6)+padRight("总耗时", 10)+ - padRight("TTFT", 10)+padRight("Cache", 8)+padRight("输出Token", 10)+"TPS")) - lines = append(lines, " "+st.Divider.Render(strings.Repeat("─", width-2))) + // 列宽(header 与 content 行保持一致,前缀均为 2 字符) + const ( + markW = 2 // 选择标记列 + idW = 6 // "#1" 等 + statW = 5 // "✓" / "✗" 加空白 + timeW = 10 // 总耗时 + ttftW = 10 // TTFT + cacheW = 8 // Cache + tokW = 10 // Token + // TPS: 余量 + ) + hdr := padRight("", markW) + padRight("#", idW) + padRight("状态", statW) + padRight("总耗时", timeW) + + padRight("TTFT", ttftW) + padRight("Cache", cacheW) + padRight("Token", tokW) + "TPS" + lines = append(lines, renderTableHeader(st, width, hdr)) + lines = append(lines, dividerLine(st, width)) + d.ReqVis = listVisibleItems(maxH, 3) + d.AdjustReqOffset(d.ReqVis, len(rs.Requests)) reqs := rs.Requests - // 倒序展示(最新在上方) - off := d.ReqOff - for i := len(reqs) - 1 - off; i >= 0; i-- { - if len(lines) >= maxH { - break - } + start := d.ReqOff + end := minInt(len(reqs), start+d.ReqVis) + for pos := start; pos < end; pos++ { + i := requestIndexFromDisplayPos(pos, len(reqs)) r := reqs[i] isSel := i == d.ReqSel - statusStr := st.Ok.Render("✓") + statusText := "✓" if !r.Success { - statusStr = st.ErrStyle.Render("✗") + statusText = "✗" } - totalTime := fmtDuration(r.TotalTime) + totalText := fmtDuration(r.TotalTime) if !r.Success && r.ErrorMessage != "" { - totalTime = st.ErrStyle.Render("timeout") - } - ttft := fmtDuration(r.TTFT) - cache := fmt.Sprintf("%.0f%%", r.CacheHitRate*100) - tok := fmt.Sprintf("%dtok", r.CompletionTokens) - tps := fmt.Sprintf("%.1f/s", r.TPS) - - row := fmt.Sprintf(" %s %s %s %s %s %s %s", - padRight(fmt.Sprintf("#%d", r.Index+1), 5), - statusStr, - padRight(totalTime, 9), - padRight(ttft, 9), - padRight(cache, 7), - padRight(tok, 9), - tps, - ) - - var rendered string - cursorStr := " " - if isSel { - cursorStr = "▶ " + totalText = "timeout" } - if isSel { - rendered = st.TableRowSel.Render(cursorStr+row) + - strings.Repeat(" ", max(0, width-lipgloss.Width(cursorStr+row)-2)) + + statusStr := statusText + if r.Success { + statusStr = styleWhenNotSelected(isSel, st.Ok, statusText) } else { - rendered = " " + st.TableRow.Render(row) + statusStr = styleWhenNotSelected(isSel, st.ErrStyle, statusText) + } + totalStr := totalText + if !r.Success && r.ErrorMessage != "" { + totalStr = styleWhenNotSelected(isSel, st.ErrStyle, totalText) } + + marker := selectionMarker(isSel) + + rowContent := padRight(marker, markW) + + padRight(fmt.Sprintf("#%d", r.Index+1), idW) + + padRight(statusStr, statW) + + padRight(totalStr, timeW) + + padRight(fmtDuration(r.TTFT), ttftW) + + padRight(fmt.Sprintf("%.0f%%", r.CacheHitRate*100), cacheW) + + padRight(fmt.Sprintf("%dtok", r.CompletionTokens), tokW) + + fmt.Sprintf("%.1f/s", r.TPS) + + rendered := renderTableRow(st, width, isSel, rowContent) lines = append(lines, rendered) + + // 行间分隔线 + if pos < end-1 && len(lines) < maxH-1 { + lines = append(lines, dividerLine(st, width)) + } } for len(lines) < maxH { @@ -364,3 +389,17 @@ func buildRequestList(d *DashboardState, rs *server.RunState, st Styles, width, } return strings.Join(lines[:maxH], "\n") } + +func requestDisplayPos(reqIndex, total int) int { + if total <= 0 { + return 0 + } + return clampInt(total-1-reqIndex, 0, total-1) +} + +func requestIndexFromDisplayPos(pos, total int) int { + if total <= 0 { + return 0 + } + return clampInt(total-1-pos, 0, total-1) +} diff --git a/internal/tui/pages/helpers.go b/internal/tui/pages/helpers.go index 4d3a1e5..c73d047 100644 --- a/internal/tui/pages/helpers.go +++ b/internal/tui/pages/helpers.go @@ -4,7 +4,6 @@ import ( "fmt" "strings" "time" - "unicode/utf8" "github.com/charmbracelet/lipgloss" ) @@ -24,10 +23,7 @@ func truncate(s string, maxW int) string { runes := []rune(s) total := 0 for i, r := range runes { - rw := utf8.RuneLen(r) - if rw < 1 { - rw = 1 - } + rw := lipgloss.Width(string(r)) if total+rw > maxW-1 { return string(runes[:i]) + "…" } @@ -113,25 +109,60 @@ func fmtDuration(d time.Duration) string { // ─── 布局工具 ───────────────────────────────────────────────────────────────── // renderHeader 渲染顶部双行标题栏。 -// 第一行:titleLeft(左)/ titleRight(右),紫色背景加粗 +// 第一行:◆ brand(粉色)│ 页面标题(青色),深色背景 // 第二行:infoLeft(左)/ infoRight(右),较暗色背景 func renderHeader(st Styles, width int, titleLeft, titleRight, infoLeft, infoRight string) string { w := width if w < 1 { w = 80 } - // 第一行 - tl := " " + titleLeft - tr := titleRight + " " - tlW := lipgloss.Width(tl) - trW := lipgloss.Width(tr) - pad1 := w - tlW - trW + + // Line 1: avoid nested Render() fragments to prevent ANSI reset from breaking background. + brand := titleLeft + pageTitle := "" + if idx := strings.Index(titleLeft, " "); idx >= 0 { + brand = titleLeft[:idx] + pageTitle = strings.TrimSpace(titleLeft[idx:]) + } + + brandSeg := lipgloss.NewStyle(). + Background(colorHeaderBg). + Foreground(colorPink). + Bold(true). + Render(" ◆ " + brand) + sepSeg := lipgloss.NewStyle(). + Background(colorHeaderBg). + Foreground(colorDivider). + Render(" │ ") + titleSeg := lipgloss.NewStyle(). + Background(colorHeaderBg). + Foreground(colorCyan). + Bold(true). + Render(pageTitle) + + left1 := brandSeg + if pageTitle != "" { + left1 += sepSeg + titleSeg + } + right1 := "" + if titleRight != "" { + right1 = lipgloss.NewStyle(). + Background(colorHeaderBg). + Foreground(colorMuted). + Render(titleRight + " ") + } + left1W := lipgloss.Width(left1) + right1W := lipgloss.Width(right1) + pad1 := w - left1W - right1W if pad1 < 0 { pad1 = 0 } - line1 := tl + strings.Repeat(" ", pad1) + tr + padSeg := lipgloss.NewStyle(). + Background(colorHeaderBg). + Render(strings.Repeat(" ", pad1)) + line1 := left1 + padSeg + right1 - // 第二行 + // ─ Line 2: info bar ─ il := " " + infoLeft ir := infoRight + " " ilW := lipgloss.Width(il) @@ -140,10 +171,9 @@ func renderHeader(st Styles, width int, titleLeft, titleRight, infoLeft, infoRig if pad2 < 0 { pad2 = 0 } - line2 := il + strings.Repeat(" ", pad2) + ir + line2 := st.HeaderInfo.Width(w).Render(il + strings.Repeat(" ", pad2) + ir) - return st.Header.Width(w).Render(line1) + "\n" + - st.HeaderInfo.Width(w).Render(line2) + return line1 + "\n" + line2 } // renderFooter 渲染底部状态栏(单行,深色背景)。 @@ -162,6 +192,175 @@ func renderFooter(st Styles, width int, parts ...string) string { return st.Footer.Width(w).Render(line) } +// renderTableHeader 统一渲染列表表头。 +func renderTableHeader(st Styles, width int, row string) string { + return st.TableHead.Width(width).Render(row) +} + +// renderTableRow 统一渲染列表行(选中/未选中)。 +func renderTableRow(st Styles, width int, isSel bool, row string) string { + if isSel { + return st.TableRowSel.Width(width).Render(row) + } + return st.TableRow.Width(width).Render(row) +} + +// minInt 返回两个整数中的较小值。 +func minInt(a, b int) int { + if a < b { + return a + } + return b +} + +// maxInt 返回两个整数中的较大值。 +func maxInt(a, b int) int { + if a > b { + return a + } + return b +} + +// clampInt 将 v 约束在 [low, high] 区间内。 +func clampInt(v, low, high int) int { + if v < low { + return low + } + if v > high { + return high + } + return v +} + +// listVisibleItems 计算在给定高度下可自然滚动的列表项数量。 +// staticLines 是列表项区域前的固定行数(如 section/header/divider)。 +func listVisibleItems(maxLines, staticLines int) int { + visible := (maxLines - staticLines + 1) / 2 + if visible < 1 { + return 1 + } + return visible +} + +// ensureVisibleOffset 让 selected 始终位于 offset/visible 定义的可视窗口内。 +func ensureVisibleOffset(selected, count, offset, visible int) int { + if count <= 0 { + return 0 + } + if visible < 1 { + visible = 1 + } + selected = clampInt(selected, 0, count-1) + maxOffset := maxInt(0, count-visible) + offset = clampInt(offset, 0, maxOffset) + if selected < offset { + offset = selected + } + if selected >= offset+visible { + offset = selected - visible + 1 + } + return clampInt(offset, 0, maxOffset) +} + +// selectionMarker 返回统一的选中标记列内容。 +func selectionMarker(isSel bool) string { + if isSel { + return "▶" + } + return "" +} + +// styleWhenNotSelected 仅在未选中时应用局部样式,避免重置选中行背景。 +func styleWhenNotSelected(isSel bool, style lipgloss.Style, text string) string { + if isSel { + return text + } + return style.Render(text) +} + +// renderWelcomeHero 渲染任务中心顶部的品牌欢迎区。 +func renderWelcomeHero(st Styles, width int) []string { + if width < 42 { + return nil + } + + art := []string{ + " _ ___ _____", + " / \\ |_ _|_ _|", + " / _ \\ | | | | ", + " / ___ \\ | | | | ", + "/_/ \\_\\___| |_| ", + } + artStyles := []lipgloss.Style{ + lipgloss.NewStyle().Foreground(colorPink).Bold(true), + lipgloss.NewStyle().Foreground(colorCyan).Bold(true), + lipgloss.NewStyle().Foreground(colorGold).Bold(true), + lipgloss.NewStyle().Foreground(colorTeal).Bold(true), + lipgloss.NewStyle().Foreground(colorPurple).Bold(true), + } + + type heroTextLine struct { + style lipgloss.Style + text string + } + intro := []heroTextLine{ + {style: st.SectionHead, text: "AI 模型性能测试工作台"}, + {style: st.Value, text: "批量压测 OpenAI / Anthropic 协议模型,聚焦 TTFT、TPS、缓存与网络指标。"}, + {style: st.Muted, text: "从任务中心出发:创建任务、直接运行、查看执行记录、导出报告。"}, + {style: st.Muted, text: "[a] 新建任务 [Enter] 查看详情/进入仪表盘 [r] 立即运行"}, + } + + artW := 0 + for _, line := range art { + artW = maxInt(artW, lipgloss.Width(line)) + } + + if width >= 76 { + gap := 3 + rightW := maxInt(18, width-artW-gap) + wrapped := make([]string, 0, 8) + for i, line := range intro { + segments := wrapText(line.text, rightW) + if len(segments) == 0 { + segments = []string{""} + } + for _, segment := range segments { + wrapped = append(wrapped, line.style.Render(segment)) + } + if i == 0 { + wrapped = append(wrapped, "") + } + } + + total := maxInt(len(art), len(wrapped)) + lines := make([]string, 0, total) + for i := 0; i < total; i++ { + left := strings.Repeat(" ", artW) + if i < len(art) { + left = artStyles[i].Render(art[i]) + } + right := "" + if i < len(wrapped) { + right = wrapped[i] + } + lines = append(lines, padRight(left, artW)+strings.Repeat(" ", gap)+right) + } + return lines + } + + lines := make([]string, 0, len(art)+len(intro)+1) + for i, line := range art { + lines = append(lines, artStyles[i].Render(line)) + } + lines = append(lines, "") + for _, line := range intro { + for _, segment := range wrapText(line.text, width) { + lines = append(lines, line.style.Render(segment)) + } + } + return lines +} + // wrapIndex 循环索引(保证 0 ≤ result < count)。 func wrapIndex(idx, count int) int { if count <= 0 { diff --git a/internal/tui/pages/layout.go b/internal/tui/pages/layout.go index bc666aa..e10c285 100644 --- a/internal/tui/pages/layout.go +++ b/internal/tui/pages/layout.go @@ -9,10 +9,9 @@ const ( MinWidth = 40 MinHeight = 10 - // chrome 各组成部分的行数 - chromeHeaderH = 2 // 双行标题栏 - chromeFooterH = 1 // 单行底部状态栏 - chromeCtxBarH = 2 // ContextBar (1) + 分隔线 (1) + // chrome 各组成部分的行数(仅保留单条合并底栏) + chromeHeaderH = 0 // 顶部 header 已移除 + chromeFooterH = 1 // 单行底部状态栏(含上下文操作 + 全局导航,已合并) // panelBorderV 是单个面板的上下边框行数之和。 panelBorderV = 2 @@ -20,9 +19,8 @@ const ( // ── PageLayout ──────────────────────────────────────────────────────────────── -// PageLayout 描述一个完整页面的 chrome(顶部标题栏、底部 ContextBar + 分隔线 + Footer)。 -// 各页面 Render 函数先构造 PageLayout,再调用 Assemble 拼装最终输出, -// 从而消除页面间重复的 header/footer/ctxbar 组装逻辑。 +// PageLayout 描述一个完整页面的 chrome(底部 ContextBar + Footer)。 +// 各页面 Render 函数先构造 PageLayout,再调用 Assemble 拼装最终输出。 type PageLayout struct { TitleLeft string TitleRight string @@ -32,14 +30,9 @@ type PageLayout struct { FooterParts []string } -// ChromeHeight 返回 chrome 占用的总行数 -// (header + 若有 ctxbar 则含 ctxbar+分隔线 + footer)。 +// ChromeHeight 返回 chrome 占用的总行数(当前仅包含合并底栏)。 func (l PageLayout) ChromeHeight() int { - h := chromeHeaderH + chromeFooterH - if len(l.CtxItems) > 0 { - h += chromeCtxBarH - } - return h + return chromeHeaderH + chromeFooterH } // ContentHeight 返回单面板页面主内容区的可用行数 @@ -64,22 +57,21 @@ func ContentWidth(totalW int) int { // Assemble 拼装完整页面输出: // -// header -// content ← 由调用方构建(单面板或多面板) -// [ctxbar] ← 仅当 CtxItems 非空时输出 -// [divider] ← ctxbar 与 footer 之间的分隔线 -// footer +// content +// 底栏(上下文操作 · 全局导航,合并为单行) func (l PageLayout) Assemble(content string, st Styles, width int) string { - header := renderHeader(st, width, l.TitleLeft, l.TitleRight, l.InfoLeft, l.InfoRight) - footer := renderFooter(st, width, l.FooterParts...) - - parts := []string{header, content} - if len(l.CtxItems) > 0 { - ctxBar := RenderContextBar(st, width, l.CtxItems) - parts = append(parts, ctxBar, dividerLine(st, width)) + // 将上下文操作与全局导航合并为单条底栏,用 · 分隔 + var barParts []string + for _, item := range l.CtxItems { + barParts = append(barParts, "["+item.Key+"] "+item.Desc) } - parts = append(parts, footer) - return strings.Join(parts, "\n") + if len(l.CtxItems) > 0 && len(l.FooterParts) > 0 { + barParts = append(barParts, "·") + } + barParts = append(barParts, l.FooterParts...) + footer := renderFooter(st, width, barParts...) + + return strings.Join([]string{content, footer}, "\n") } // ── 最小尺寸保护 ────────────────────────────────────────────────────────────── diff --git a/internal/tui/pages/reqdetail.go b/internal/tui/pages/reqdetail.go index 87dc103..ece0dcc 100644 --- a/internal/tui/pages/reqdetail.go +++ b/internal/tui/pages/reqdetail.go @@ -121,10 +121,10 @@ func RenderReqDetail(s *ReqDetailState, taskName string, st Styles, width, heigh l := PageLayout{ TitleLeft: fmt.Sprintf("AIT 请求详情 - %s #%d", truncate(taskName, 20), idx+1), TitleRight: statusStr, - InfoLeft: fmt.Sprintf("◆ AIT 任务: %s 请求 %d / %d", + InfoLeft: fmt.Sprintf("任务: %s 请求 %d / %d", truncate(taskName, 20), idx+1, len(s.Requests)), CtxItems: CtxBar_ReqDetail(), - FooterParts: []string{"[b/Esc] 返回仪表盘", "[↑↓] 滚动", "[←→] 上/下一条请求"}, + FooterParts: []string{"[q] 退出"}, } // ── 计算高度 ── diff --git a/internal/tui/pages/styles.go b/internal/tui/pages/styles.go index d1b4548..0b497b2 100644 --- a/internal/tui/pages/styles.go +++ b/internal/tui/pages/styles.go @@ -4,7 +4,7 @@ import "github.com/charmbracelet/lipgloss" // Color palette const ( - colorHeaderBg = lipgloss.Color("57") // electric indigo — header background + colorHeaderBg = lipgloss.Color("17") // dark navy — refined header background colorFooterBg = lipgloss.Color("235") // near-black footer background colorCtxBarBg = lipgloss.Color("237") // slightly lighter than footer — context bar colorPink = lipgloss.Color("205") // vivid hot pink/magenta — brand primary @@ -18,13 +18,12 @@ const ( colorWhite = lipgloss.Color("255") // bright white colorMuted = lipgloss.Color("245") // muted gray colorGold = lipgloss.Color("214") // amber - colorHeaderFg = lipgloss.Color("212") // light pink — header right text - colorDivider = lipgloss.Color("238") // dim border gray + colorHeaderFg = lipgloss.Color("248") // light gray — header info text + colorDivider = lipgloss.Color("241") // dim border gray — slightly more visible ) // Styles 汇聚所有 TUI 样式,由 NewStyles() 初始化。 type Styles struct { - AppBorder lipgloss.Style Panel lipgloss.Style Header lipgloss.Style HeaderInfo lipgloss.Style @@ -41,7 +40,6 @@ type Styles struct { ErrStyle lipgloss.Style Key lipgloss.Style MetricVal lipgloss.Style - Dialog lipgloss.Style FieldActive lipgloss.Style FieldIdle lipgloss.Style Cursor lipgloss.Style @@ -71,14 +69,18 @@ func NewStyles() Styles { Foreground(colorPink). Bold(true), TableHead: lipgloss.NewStyle(). + Background(lipgloss.Color("234")). Foreground(colorCyan). - Bold(true), + Bold(true). + Padding(0, 0), TableRow: lipgloss.NewStyle(). - Foreground(colorWhite), + Foreground(colorWhite). + Padding(0, 0), TableRowSel: lipgloss.NewStyle(). Background(colorPurpleDim). Foreground(colorWhite). - Bold(true), + Bold(true). + Padding(0, 0), Label: lipgloss.NewStyle(). Foreground(colorTeal). Bold(true), @@ -97,29 +99,28 @@ func NewStyles() Styles { MetricVal: lipgloss.NewStyle(). Foreground(colorYellow). Bold(true), - Dialog: lipgloss.NewStyle(). - Border(lipgloss.RoundedBorder()). - BorderForeground(colorPurple). - Padding(0, 1), FieldActive: lipgloss.NewStyle(). - Background(lipgloss.Color("55")). + Background(lipgloss.Color("236")). Foreground(colorWhite). + Border(lipgloss.NormalBorder()). + BorderForeground(colorPink). + Bold(true). Padding(0, 1), FieldIdle: lipgloss.NewStyle(). + Background(lipgloss.Color("234")). + Foreground(colorWhite). Border(lipgloss.NormalBorder()). - BorderForeground(colorDivider). + BorderForeground(lipgloss.Color("238")). Padding(0, 1), Cursor: lipgloss.NewStyle(). Foreground(colorPink). Bold(true), TagTurbo: lipgloss.NewStyle(). - Background(colorGold). - Foreground(colorDivider). + Foreground(colorGold). Bold(true). Padding(0, 1), TagStd: lipgloss.NewStyle(). - Background(colorPurple). - Foreground(colorWhite). + Foreground(colorPurple). Padding(0, 1), BtnPrimary: lipgloss.NewStyle(). Background(colorPink). @@ -128,9 +129,6 @@ func NewStyles() Styles { Padding(0, 2), Divider: lipgloss.NewStyle(). Foreground(colorDivider), - AppBorder: lipgloss.NewStyle(). - Border(lipgloss.RoundedBorder()). - BorderForeground(colorPurple), Panel: lipgloss.NewStyle(). Border(lipgloss.NormalBorder()). BorderForeground(colorDivider), diff --git a/internal/tui/pages/taskdetail.go b/internal/tui/pages/taskdetail.go index f9e0155..39b0519 100644 --- a/internal/tui/pages/taskdetail.go +++ b/internal/tui/pages/taskdetail.go @@ -4,7 +4,9 @@ import ( "fmt" "strings" + "github.com/charmbracelet/lipgloss" tea "github.com/charmbracelet/bubbletea" + "github.com/yinxulai/ait/internal/server" "github.com/yinxulai/ait/internal/types" ) @@ -12,6 +14,10 @@ import ( type TaskDetailState struct { Task types.TaskDefinition History []types.TaskRunSummary + // HistorySel 当前选中的历史记录索引(0 = 最近一次) + HistorySel int + HistoryOff int + HistoryVis int // LatestExpanded 控制最近一次运行是否展开(运行结束后自动置 true) LatestExpanded bool } @@ -25,10 +31,29 @@ func NewTaskDetailState(task types.TaskDefinition) *TaskDetailState { func HandleTaskDetailKey(s *TaskDetailState, msg tea.KeyMsg, client Client) (*TaskDetailState, tea.Cmd, NavAction) { nav := NavAction{} switch msg.String() { + case "up", "k": + if s.HistorySel > 0 { + s.HistorySel-- + } + + case "down", "j": + if s.HistorySel < len(s.History)-1 { + s.HistorySel++ + } + case "left", "esc", "b": nav = NavAction{To: NavTaskList} - case "enter", "r": + case "enter": + return s, client.StartRunCmd(s.Task.ID), nav + + case "r": + if s.HistorySel >= 0 && s.HistorySel < len(s.History) { + runID := strings.TrimSpace(s.History[s.HistorySel].RunID) + if runID != "" { + return s, client.GenerateReportCmd(server.RunID(runID), server.ReportFormatJSON), nav + } + } return s, client.StartRunCmd(s.Task.ID), nav case "e": @@ -44,6 +69,7 @@ func HandleTaskDetailKey(s *TaskDetailState, msg tea.KeyMsg, client Client) (*Ta case "q", "ctrl+c": nav = NavAction{To: NavQuit} } + s.HistoryOff = ensureVisibleOffset(s.HistorySel, len(s.History), s.HistoryOff, s.HistoryVis) return s, nil, nav } @@ -85,160 +111,146 @@ func RenderTaskDetail(s *TaskDetailState, st Styles, width, height int) string { } l := PageLayout{ TitleLeft: "AIT 任务详情 ─ " + truncate(t.Name, 30), - InfoLeft: fmt.Sprintf("◆ AIT 任务 ID: %s 更新: %s %s", + InfoLeft: fmt.Sprintf("任务 ID: %s 更新: %s %s", truncate(t.ID, 10), t.UpdatedAt.Format("2006-01-02 15:04"), updatedStr), CtxItems: cbItems, - FooterParts: []string{"[b/Esc] 返回列表", "[r] 运行", "[e] 编辑", "◆ AIT v0.1"}, + FooterParts: []string{"[b/Esc] 返回列表", "◆ AIT v0.1"}, } content := buildTaskDetailContent(s, st, t, inp, ContentWidth(width), l.ContentHeight(height)) return l.Assemble(wrapPanel(st, content, width), st, width) } -// buildTaskDetailContent 构建任务详情内容区。 +// buildTaskDetailContent 构建任务详情内容区(左右双栏布局)。 +// 左栏(40%):配置摘要 右栏(60%):历史运行记录 func buildTaskDetailContent(s *TaskDetailState, st Styles, t types.TaskDefinition, inp types.Input, width, maxH int) string { - innerW := width - 2 - if innerW < 10 { - innerW = 10 + leftW := width * 4 / 10 + if leftW < 26 { + leftW = 26 } + rightW := width - leftW - 1 // 1 列用于 │ 分隔符 - var lines []string + // ─── 左栏:配置摘要 ───────────────────────────────────────── + var leftLines []string + leftLines = append(leftLines, padRight(" "+st.SectionHead.Render("配置摘要"), leftW)) + leftLines = append(leftLines, padRight(st.Divider.Render(strings.Repeat("─", leftW)), leftW)) + leftLines = append(leftLines, padRight("", leftW)) - // ─── 配置摘要 ───────────────────────────────────────────── - lines = append(lines, " "+st.SectionHead.Render("配置摘要")) - lines = append(lines, " "+dividerLine(st, innerW-2)) - - // 行1:协议 + 接口 proto := inp.NormalizedProtocol() - endpoint := truncate(inp.ResolvedEndpointURL(), innerW-30) - lines = append(lines, " "+ - st.Label.Render("协议")+" "+st.Value.Render(proto)+ - " "+st.Label.Render("接口")+" "+st.Value.Render(endpoint)) + leftLines = append(leftLines, padRight(" "+st.Label.Render("协议")+" "+st.Value.Render(proto), leftW)) + endpoint := truncate(inp.ResolvedEndpointURL(), leftW-8) + leftLines = append(leftLines, padRight(" "+st.Label.Render("接口")+" "+st.Value.Render(endpoint), leftW)) + leftLines = append(leftLines, padRight("", leftW)) - // 行2:模型 + 模式 + 并发 + 请求 + model := truncate(inp.Model, leftW-10) + leftLines = append(leftLines, padRight(" "+st.Label.Render("模型")+" "+st.Value.Render(model), leftW)) modeStr := "标准模式" if inp.Turbo { modeStr = "Turbo 模式" } + leftLines = append(leftLines, padRight(" "+st.Label.Render("模式")+" "+st.Value.Render(modeStr), leftW)) if inp.Turbo { tc := inp.TurboConfig - lines = append(lines, " "+ - st.Label.Render("模型")+" "+st.Value.Render(inp.Model)+ - " "+st.Label.Render("模式")+" "+st.Value.Render(modeStr)+ - " "+st.Label.Render("并发爬坡")+" "+ - st.Value.Render(fmt.Sprintf("%d → %d 步进+%d 每级%d请求", - tc.InitConcurrency, tc.MaxConcurrency, tc.StepSize, tc.LevelRequests))) + leftLines = append(leftLines, padRight(" "+st.Label.Render("并发")+" "+st.Value.Render( + fmt.Sprintf("%d → %d", tc.InitConcurrency, tc.MaxConcurrency)), leftW)) + leftLines = append(leftLines, padRight(" "+st.Label.Render("步进")+" "+st.Value.Render( + fmt.Sprintf("+%d 每级%d请求", tc.StepSize, tc.LevelRequests)), leftW)) } else { - lines = append(lines, " "+ - st.Label.Render("模型")+" "+st.Value.Render(inp.Model)+ - " "+st.Label.Render("模式")+" "+st.Value.Render(modeStr)+ - " "+st.Label.Render("并发")+" "+st.Value.Render(fmt.Sprintf("%d", inp.Concurrency))+ - " "+st.Label.Render("请求")+" "+st.Value.Render(fmt.Sprintf("%d", inp.Count))) + leftLines = append(leftLines, padRight(" "+st.Label.Render("并发")+" "+st.Value.Render( + fmt.Sprintf("%d", inp.Concurrency)), leftW)) + leftLines = append(leftLines, padRight(" "+st.Label.Render("请求")+" "+st.Value.Render( + fmt.Sprintf("%d", inp.Count)), leftW)) } - - // 行3:超时 + 流式 + Prompt + leftLines = append(leftLines, padRight("", leftW)) + leftLines = append(leftLines, padRight(" "+st.Label.Render("超时")+" "+st.Value.Render(fmtDuration(inp.Timeout)), leftW)) + leftLines = append(leftLines, padRight(" "+st.Label.Render("流式")+" "+st.Value.Render(boolLabel(inp.Stream)), leftW)) prompt := promptSummary(inp.PromptMode, inp.PromptText, inp.PromptFile, inp.PromptLength) - lines = append(lines, " "+ - st.Label.Render("超时")+" "+st.Value.Render(fmtDuration(inp.Timeout))+ - " "+st.Label.Render("流式")+" "+st.Value.Render(boolLabel(inp.Stream))+ - " "+st.Label.Render("Prompt")+" "+st.Value.Render(truncate(prompt, innerW-50))) + leftLines = append(leftLines, padRight(" "+st.Label.Render("Prompt")+" "+st.Value.Render(truncate(prompt, leftW-12)), leftW)) - lines = append(lines, "") + // ─── 右栏:历史运行记录 ───────────────────────────────────── + var rightLines []string + rightLines = append(rightLines, padRight(" "+st.SectionHead.Render("历史运行记录"), rightW)) + rightLines = append(rightLines, padRight(st.Divider.Render(strings.Repeat("─", rightW)), rightW)) - // ─── 最近运行 ────────────────────────────────────────────── - if len(s.History) > 0 { - latest := s.History[0] - statusStr := "✓ 完成" - if latest.Status != "completed" { - statusStr = "✗ " + latest.Status - } - elapsed := latest.FinishedAt.Sub(latest.StartedAt) - expandMark := "▼" - if !s.LatestExpanded { - expandMark = "▶" - } - lines = append(lines, fmt.Sprintf(" %s %s 最近运行 %s %s %d 请求 耗时 %s", - st.SectionHead.Render("最近运行"), - st.Ok.Render(expandMark), - latest.StartedAt.Format("2006-01-02 15:04"), - st.Ok.Render(statusStr), - 0, // 请求总数(运行摘要中需要补充该字段,暂用 0) - fmtDuration(elapsed), - )) - lines = append(lines, " "+dividerLine(st, innerW-2)) - - if s.LatestExpanded && len(lines) < maxH-10 { - // 指标表格 - lines = append(lines, " "+st.TableHead.Render( - padRight("指标", 16)+padRight("最小值", 10)+padRight("平均值", 10)+padRight("标准差", 10)+"最大值")) - lines = append(lines, " "+st.Divider.Render(strings.Repeat("─", innerW-2))) - - if len(lines) < maxH { - lines = append(lines, buildMetricRow(st, "TTFT", - fmtDuration(latest.AvgTTFT), fmtDuration(latest.AvgTTFT), "─", "─")) - } - if len(lines) < maxH { - lines = append(lines, buildMetricRow(st, "输出 TPS", - "─", fmt.Sprintf("%.1f", latest.AvgTPS), "─", "─")) - } - if len(lines) < maxH { - lines = append(lines, buildMetricRow(st, "成功率", - "─", fmt.Sprintf("%.1f%%", latest.SuccessRate*100), "─", "─")) - } - if latest.CacheHitRate > 0 && len(lines) < maxH { - lines = append(lines, buildMetricRow(st, "缓存命中率", - "─", fmt.Sprintf("%.1f%%", latest.CacheHitRate*100), "─", "─")) - } - if latest.ErrorSummary != "" && len(lines) < maxH { - lines = append(lines, " "+st.ErrStyle.Render("错误 "+truncate(latest.ErrorSummary, innerW-10))) + const ( + markW = 2 + statW = 2 + timeW = 17 + modeW = 7 + rateW = 8 + ttftW = 10 + ) + hdr := padRight("", markW) + padRight("", statW) + padRight("时间", timeW) + padRight("模式", modeW) + + padRight("成功率", rateW) + padRight("TTFT", ttftW) + "TPS" + rightLines = append(rightLines, padRight(renderTableHeader(st, rightW, hdr), rightW)) + rightLines = append(rightLines, padRight(st.Divider.Render(strings.Repeat("─", rightW)), rightW)) + + if len(s.History) == 0 { + rightLines = append(rightLines, padRight(" "+st.Muted.Render("暂无运行记录"), rightW)) + } else { + detailLines := buildTaskHistoryDetailLines(s, st, rightW) + tableMaxH := maxH - len(detailLines) + if tableMaxH < 5 { + allowedDetail := maxInt(0, maxH-5) + if len(detailLines) > allowedDetail { + detailLines = detailLines[:allowedDetail] } + tableMaxH = maxH - len(detailLines) } - lines = append(lines, "") - } + s.HistoryVis = listVisibleItems(tableMaxH, 4) + s.HistoryOff = ensureVisibleOffset(s.HistorySel, len(s.History), s.HistoryOff, s.HistoryVis) + start := s.HistoryOff + end := minInt(len(s.History), start+s.HistoryVis) - // ─── 历史运行记录 ────────────────────────────────────────── - if len(lines) < maxH-4 { - lines = append(lines, " "+st.SectionHead.Render("历史运行记录")) - lines = append(lines, " "+st.TableHead.Render( - padRight("时间", 20)+padRight("模式", 8)+padRight("成功率", 8)+ - padRight("TTFT", 10)+padRight("TPS", 10)+"Cache")) - lines = append(lines, " "+st.Divider.Render(strings.Repeat("─", innerW-2))) - - for _, run := range s.History { - if len(lines) >= maxH-1 { - break - } - statusIcon := st.Ok.Render("✓") + // ── 历史列表 ── + for idx := start; idx < end; idx++ { + run := s.History[idx] + statusText := "✓" if run.Status != "completed" { - statusIcon = st.ErrStyle.Render("✗") + statusText = "✗" } modeShort := "标准" if run.Mode == "turbo" { modeShort = "Turbo" } - cacheStr := "─" - if run.CacheHitRate > 0 { - cacheStr = fmt.Sprintf("%.1f%%", run.CacheHitRate*100) + isSel := idx == s.HistorySel + statusIcon := statusText + if run.Status == "completed" { + statusIcon = styleWhenNotSelected(isSel, st.Ok, statusText) + } else { + statusIcon = styleWhenNotSelected(isSel, st.ErrStyle, statusText) + } + + marker := selectionMarker(isSel) + + row := padRight(marker, markW) + + padRight(statusIcon, statW) + + padRight(run.StartedAt.Format("2006-01-02 15:04"), timeW) + + padRight(modeShort, modeW) + + padRight(fmt.Sprintf("%.1f%%", run.SuccessRate), rateW) + + padRight(fmtDuration(run.AvgTTFT), ttftW) + + fmt.Sprintf("%.1f", run.AvgTPS) + rightLines = append(rightLines, padRight(renderTableRow(st, rightW, isSel, row), rightW)) + if idx < end-1 { + rightLines = append(rightLines, padRight(st.Divider.Render(strings.Repeat("─", rightW)), rightW)) } - row := fmt.Sprintf(" %s %s %s %s %s %s %s", - statusIcon, - run.StartedAt.Format("2006-01-02 15:04"), - padRight(modeShort, 6), - padRight(fmt.Sprintf("%.1f%%", run.SuccessRate*100), 7), - padRight(fmtDuration(run.AvgTTFT), 9), - padRight(fmt.Sprintf("%.1f", run.AvgTPS), 9), - cacheStr, - ) - lines = append(lines, " "+st.TableRow.Render(row)) } + rightLines = append(rightLines, detailLines...) } - // 补齐剩余高度 - for len(lines) < maxH { - lines = append(lines, "") + // ─── 合并双栏 ────────────────────────────────────────────── + for len(leftLines) < maxH { + leftLines = append(leftLines, padRight("", leftW)) } - - return strings.Join(lines, "\n") + for len(rightLines) < maxH { + rightLines = append(rightLines, padRight("", rightW)) + } + sep := st.Divider.Render("│") + var combined []string + for i := 0; i < maxH; i++ { + combined = append(combined, leftLines[i]+sep+rightLines[i]) + } + return strings.Join(combined, "\n") } // buildMetricRow 构建指标表格一行。 @@ -257,8 +269,129 @@ func UpdateTaskDetailHistory(s *TaskDetailState, history []types.TaskRunSummary, return s } s.History = history + if len(history) == 0 { + s.HistorySel = 0 + s.HistoryOff = 0 + } else { + if s.HistorySel < 0 { + s.HistorySel = 0 + } + if s.HistorySel >= len(history) { + s.HistorySel = len(history) - 1 + } + s.HistoryOff = ensureVisibleOffset(s.HistorySel, len(history), s.HistoryOff, s.HistoryVis) + } if autoExpand && len(history) > 0 { s.LatestExpanded = true } return s } + +func buildTaskHistoryDetailLines(s *TaskDetailState, st Styles, width int) []string { + if s.HistorySel < 0 || s.HistorySel >= len(s.History) { + return nil + } + sel := s.History[s.HistorySel] + elapsed := sel.FinishedAt.Sub(sel.StartedAt) + labelW := 8 + indent := " " + gap := 4 + contentW := maxInt(12, width-lipgloss.Width(indent)) + useTwoCols := contentW >= 48 + + statusText := sel.Status + statusStyle := st.Value + switch sel.Status { + case "completed": + statusText = "完成" + statusStyle = st.Ok + case "failed": + statusText = "失败" + statusStyle = st.ErrStyle + case "stopped": + statusText = "已停止" + statusStyle = st.Muted + } + + modeText := "标准" + if sel.Mode == "turbo" { + modeText = "Turbo" + } + + renderCell := func(label, value string, valueStyle lipgloss.Style, cellW int) string { + prefix := st.Label.Render(padRight(label, labelW)) + available := maxInt(6, cellW-labelW-2) + return prefix + " " + valueStyle.Render(truncate(value, available)) + } + + appendSingleField := func(lines []string, label, value string, valueStyle lipgloss.Style) []string { + valueW := maxInt(10, contentW-labelW-2) + segments := wrapText(value, valueW) + if len(segments) == 0 { + segments = []string{""} + } + lines = append(lines, indent+st.Label.Render(padRight(label, labelW))+" "+valueStyle.Render(segments[0])) + contIndent := strings.Repeat(" ", lipgloss.Width(indent)+labelW+2) + for _, seg := range segments[1:] { + lines = append(lines, contIndent+valueStyle.Render(seg)) + } + return lines + } + + appendPairRow := func(lines []string, leftLabel, leftValue string, leftStyle lipgloss.Style, rightLabel, rightValue string, rightStyle lipgloss.Style) []string { + if !useTwoCols { + lines = appendSingleField(lines, leftLabel, leftValue, leftStyle) + return appendSingleField(lines, rightLabel, rightValue, rightStyle) + } + leftW := (contentW - gap) / 2 + rightW := contentW - gap - leftW + row := indent + padRight(renderCell(leftLabel, leftValue, leftStyle, leftW), leftW) + strings.Repeat(" ", gap) + + renderCell(rightLabel, rightValue, rightStyle, rightW) + return append(lines, row) + } + + lines := []string{ + padRight(st.Divider.Render(strings.Repeat("─", width)), width), + padRight(" "+st.SectionHead.Render("记录详情"), width), + } + + lines = appendPairRow(lines, + "状态", statusText, statusStyle, + "模式", modeText, st.Value, + ) + lines = appendPairRow(lines, + "开始", sel.StartedAt.Format("2006-01-02 15:04"), st.Value, + "结束", sel.FinishedAt.Format("2006-01-02 15:04"), st.Value, + ) + lines = appendPairRow(lines, + "耗时", fmtDuration(elapsed), st.Value, + "成功率", fmt.Sprintf("%.1f%%", sel.SuccessRate), st.Value, + ) + lines = appendPairRow(lines, + "TTFT", fmtDuration(sel.AvgTTFT), st.Value, + "TPS", fmt.Sprintf("%.1f", sel.AvgTPS), st.MetricVal, + ) + lines = appendSingleField(lines, "协议", shortProtocol(sel.Protocol), st.Value) + lines = appendSingleField(lines, "模型", sel.Model, st.Value) + if sel.CacheHitRate > 0 { + lines = appendSingleField(lines, "缓存", fmt.Sprintf("%.1f%%", sel.CacheHitRate*100), st.Value) + } + if sel.ReportJSONPath != "" || sel.ReportCSVPath != "" { + reports := make([]string, 0, 2) + if sel.ReportJSONPath != "" { + reports = append(reports, "JSON") + } + if sel.ReportCSVPath != "" { + reports = append(reports, "CSV") + } + lines = appendSingleField(lines, "报告", strings.Join(reports, " / "), st.Muted) + } + if sel.ErrorSummary != "" { + lines = append(lines, indent+st.Label.Render("错误摘要")) + for _, seg := range wrapText(sel.ErrorSummary, maxInt(10, contentW-2)) { + lines = append(lines, indent+" "+st.ErrStyle.Render(seg)) + } + } + + return lines +} diff --git a/internal/tui/pages/tasklist.go b/internal/tui/pages/tasklist.go index d16996d..97e62a1 100644 --- a/internal/tui/pages/tasklist.go +++ b/internal/tui/pages/tasklist.go @@ -15,6 +15,8 @@ import ( type TaskListState struct { Tasks []types.TaskDefinition Selected int + Offset int + Visible int // 运行中任务的进度(runID -> RunState 快照,由 Model 注入) ActiveRuns map[string]*server.RunState // taskID -> RunState } @@ -114,6 +116,8 @@ func HandleTaskListKey(s *TaskListState, msg tea.KeyMsg, client Client) (*TaskLi nav = NavAction{To: NavQuit} } + s.Offset = ensureVisibleOffset(s.Selected, len(s.Tasks), s.Offset, s.Visible) + return s, nil, nav } @@ -153,9 +157,9 @@ func RenderTaskList(s *TaskListState, st Styles, width, height int) string { } l := PageLayout{ TitleLeft: "AIT 任务中心", - InfoLeft: fmt.Sprintf("◆ AIT 已保存任务: %d %s", len(s.Tasks), lastRunStr), + InfoLeft: fmt.Sprintf("已保存任务: %d %s", len(s.Tasks), lastRunStr), CtxItems: cbItems, - FooterParts: []string{"[↑↓] 选择", "[a] 新建", "[y] 复制", "[q] 退出", "◆ AIT v0.1"}, + FooterParts: []string{"[↑↓] 选择", "[a] 新建", "[q] 退出", "◆ AIT v0.1"}, } content := buildTaskListContent(s, st, ContentWidth(width), l.ContentHeight(height)) @@ -164,31 +168,31 @@ func RenderTaskList(s *TaskListState, st Styles, width, height int) string { // buildTaskListContent 构建任务列表内容区(含表头 + 任务条目)。 func buildTaskListContent(s *TaskListState, st Styles, width, maxH int) string { - innerW := width - 2 - if innerW < 20 { - innerW = 20 - } - var lines []string + showHero := width >= 60 && maxH >= 14 + if showHero { + heroLines := renderWelcomeHero(st, width) + lines = append(lines, heroLines...) + lines = append(lines, dividerLine(st, width)) + } + listTopLines := len(lines) - // 表头行 + // 列宽(合计 = nameW + modeW + protoW + 结果列) nameW := 28 modeW := 8 protoW := 14 - resultW := innerW - nameW - modeW - protoW - 4 - if resultW < 10 { - resultW = 10 - } - header := st.TableHead.Render( - " " + padRight("任务名称", nameW) + + // 表头:2 空格前缀与正文行对齐(cursor=2) + header := renderTableHeader(st, width, + " " + padRight("任务名称", nameW) + padRight("模式", modeW) + padRight("协议", protoW) + - "上次结果", - ) + "上次结果") lines = append(lines, header) - lines = append(lines, " "+st.Divider.Render(strings.Repeat("─", innerW-1))) - lines = append(lines, "") // 表头与第一条目之间的呼吸间距 + lines = append(lines, dividerLine(st, width)) + listMaxH := maxInt(3, maxH-listTopLines) + s.Visible = listVisibleItems(listMaxH, 2) + s.Offset = ensureVisibleOffset(s.Selected, len(s.Tasks), s.Offset, s.Visible) if len(s.Tasks) == 0 { lines = append(lines, "") @@ -196,62 +200,72 @@ func buildTaskListContent(s *TaskListState, st Styles, width, maxH int) string { return strings.Join(lines, "\n") } - for i, t := range s.Tasks { - if len(lines) >= maxH { - break - } + start := s.Offset + end := minInt(len(s.Tasks), start+s.Visible) + for i := start; i < end; i++ { + t := s.Tasks[i] isRunning := s.IsTaskRunning(t.ID) isSel := i == s.Selected rs := s.ActiveRuns[t.ID] - // ── 指示符和运行中标记 ── - cursor := " " - if isSel { - cursor = "▶ " - } - runMark := " " - if isRunning { - runMark = st.Ok.Render("◉") + " " - } - prefix := cursor + runMark + // ── 指示符 ── + prefix := padRight(selectionMarker(isSel), 2) - // ── 模式标签 ── - var modeTag string + // ── 模式(选中行禁用嵌套样式,避免重置整行背景)── + modeText := "标准" + modeCol := padRight(modeText, modeW) if t.Input.Turbo { - modeTag = st.TagTurbo.Render("Turbo") + modeText = "Turbo" + modeCol = padRight(styleWhenNotSelected(isSel, lipgloss.NewStyle().Foreground(colorGold).Bold(true), modeText), modeW) } else { - modeTag = st.TagStd.Render("标准 ") - } - modeTagW := lipgloss.Width(modeTag) - modePad := modeW - modeTagW - if modePad < 0 { - modePad = 0 + modeCol = padRight(styleWhenNotSelected(isSel, lipgloss.NewStyle().Foreground(colorPurple), modeText), modeW) } - modeCol := modeTag + strings.Repeat(" ", modePad) // ── 协议 ── proto := padRight(shortProtocol(t.Input.NormalizedProtocol()), protoW) - // ── 上次结果 ── - lastResult := st.Muted.Render("从未运行") + // ── 上次结果(选中行禁用嵌套样式,避免重置整行背景)── + lastResultText := "从未运行" if t.LastRunSummary != nil { pct := t.LastRunSummary.SuccessRate if t.Input.Turbo { if t.LastRunSummary.MaxStableConcurrency > 0 { - lastResult = st.Ok.Render(fmt.Sprintf("★ 并发%d", t.LastRunSummary.MaxStableConcurrency)) + lastResultText = fmt.Sprintf("★ 并发%d", t.LastRunSummary.MaxStableConcurrency) } } else { switch { case pct >= 99: - lastResult = st.Ok.Render(fmt.Sprintf("✓ %.1f%%", pct)) + lastResultText = fmt.Sprintf("✓ %.1f%%", pct) case pct >= 90: - lastResult = st.MetricVal.Render(fmt.Sprintf("%.1f%%", pct)) + lastResultText = fmt.Sprintf("%.1f%%", pct) default: - lastResult = st.ErrStyle.Render(fmt.Sprintf("✗ %.1f%%", pct)) + lastResultText = fmt.Sprintf("✗ %.1f%%", pct) } } } + if isRunning && rs != nil { + lastResultText = fmt.Sprintf("◉ %d/%d %.0f%%", rs.DoneReqs, rs.TotalReqs, rs.SuccessRate) + } + + lastResult := lastResultText + if isRunning && rs != nil { + lastResult = styleWhenNotSelected(isSel, st.Ok, lastResultText) + } else if t.LastRunSummary == nil { + lastResult = styleWhenNotSelected(isSel, st.Muted, lastResultText) + } else if t.Input.Turbo && t.LastRunSummary.MaxStableConcurrency > 0 { + lastResult = styleWhenNotSelected(isSel, st.Ok, lastResultText) + } else if !t.Input.Turbo { + pct := t.LastRunSummary.SuccessRate + switch { + case pct >= 99: + lastResult = styleWhenNotSelected(isSel, st.Ok, lastResultText) + case pct >= 90: + lastResult = styleWhenNotSelected(isSel, st.MetricVal, lastResultText) + default: + lastResult = styleWhenNotSelected(isSel, st.ErrStyle, lastResultText) + } + } // ── 任务名称(裁剪)── name := truncate(t.Name, nameW) @@ -262,48 +276,13 @@ func buildTaskListContent(s *TaskListState, st Styles, width, maxH int) string { nameCol := name + strings.Repeat(" ", namePad) // ── 第一行 ── - prefixW := lipgloss.Width(prefix) row1Content := nameCol + modeCol + proto + lastResult - var row1 string - if isSel { - row1 = st.TableRowSel.Render(prefix+row1Content) + strings.Repeat(" ", max(0, width-prefixW-lipgloss.Width(row1Content)-2)) - } else { - row1 = " " + runMark + row1Content - } + row1 := renderTableRow(st, width, isSel, prefix+row1Content) lines = append(lines, row1) - // ── 第二行(模型 + 参数 + 实时进度)── - if len(lines) < maxH { - indent := " " // 5 空格缩进(对齐任务名) - var params string - if t.Input.Turbo { - tc := t.Input.TurboConfig - params = fmt.Sprintf("%s %d→%d 步进+%d", - truncate(t.Input.Model, 12), - tc.InitConcurrency, tc.MaxConcurrency, tc.StepSize) - if t.LastRunSummary != nil { - params += fmt.Sprintf(" 上次: 峰值 TPS %.1f", t.LastRunSummary.AvgTPS) - } - } else { - params = fmt.Sprintf("%s 并发%d 请求%d", - truncate(t.Input.Model, 12), - t.Input.Concurrency, t.Input.Count) - } - - // 实时进度 - if isRunning && rs != nil { - prog := fmt.Sprintf(" %s %d/%d 成功率 %.1f%%", - st.Ok.Render("◉"), rs.DoneReqs, rs.TotalReqs, rs.SuccessRate*100) - params += prog - } - - row2 := indent + st.Muted.Render(truncate(params, width-7)) - lines = append(lines, row2) - } - - // ── 空行分隔 ── - if i < len(s.Tasks)-1 && len(lines) < maxH-1 { - lines = append(lines, "") + // ── 分隔线 ── + if i < end-1 && len(lines) < maxH-1 { + lines = append(lines, dividerLine(st, width)) } } diff --git a/internal/tui/pages/turbodash.go b/internal/tui/pages/turbodash.go index 4356fb8..c07c394 100644 --- a/internal/tui/pages/turbodash.go +++ b/internal/tui/pages/turbodash.go @@ -18,6 +18,8 @@ type TurboDashState struct { CancelFn server.CancelFunc RunState *server.RunState LevelSel int // 选中的级别索引(-1 = 无选中) + LevelOff int + LevelVis int } // NewTurboDashState 创建 Turbo 仪表盘初始状态。 @@ -103,6 +105,7 @@ func HandleTurboDashKey(d *TurboDashState, msg tea.KeyMsg, client Client) (*Turb case "q", "ctrl+c": nav = NavAction{To: NavQuit} } + d.LevelOff = ensureVisibleOffset(d.LevelSel, len(levels), d.LevelOff, d.LevelVis) return d, nil, nav } @@ -144,7 +147,7 @@ func RenderTurboDash(d *TurboDashState, taskName string, st Styles, width, heigh subtitle := "─" if rs != nil && len(rs.Levels) > 0 { - subtitle = fmt.Sprintf("◆ AIT %s · 当前并发: %d 已完成 %d 级", + subtitle = fmt.Sprintf("%s · 当前并发: %d 已完成 %d 级", "─", rs.CurrentLevel, len(rs.Levels)) } @@ -159,7 +162,7 @@ func RenderTurboDash(d *TurboDashState, taskName string, st Styles, width, heigh TitleRight: statusStr, InfoLeft: subtitle, CtxItems: cbItems, - FooterParts: []string{"[s] 停止", "[b] 后台运行", "[m] 标记极限", "[r] 提前报告", "[q] 退出"}, + FooterParts: []string{"[q] 退出"}, } // ── 计算高度 ── @@ -228,7 +231,7 @@ func buildTurboDashMetrics(rs *server.RunState, st Styles, maxH, width int) stri if rs == nil { lines = append(lines, " "+st.Muted.Render("等待数据...")) } else { - lines = append(lines, " "+labelValue(st, "成功率 ", st.MetricVal.Render(fmt.Sprintf("%.1f%%", rs.SuccessRate*100)))) + lines = append(lines, " "+labelValue(st, "成功率 ", st.MetricVal.Render(fmt.Sprintf("%.1f%%", rs.SuccessRate)))) lines = append(lines, " "+labelValue(st, "TPS ", st.MetricVal.Render(fmt.Sprintf("%.1f tok/s", rs.AvgTPS)))) lines = append(lines, " "+labelValue(st, "TTFT ", st.MetricVal.Render(fmtDuration(rs.AvgTTFT)))) lines = append(lines, " "+labelValue(st, "Cache ", st.MetricVal.Render(fmt.Sprintf("%.1f%%", rs.CacheHitRate*100)))) @@ -274,51 +277,66 @@ func buildLevelList(d *TurboDashState, rs *server.RunState, st Styles, width, ma return strings.Join(lines, "\n") } - // 表头 - lines = append(lines, " "+st.TableHead.Render( - padRight("并发", 6)+padRight("成功率", 8)+padRight("TPS", 10)+ - padRight("TTFT", 10)+padRight("Cache", 8)+padRight("总耗时", 9)+"结论")) - lines = append(lines, " "+st.Divider.Render(strings.Repeat("─", width-2))) - - for i, lv := range rs.Levels { - if len(lines) >= maxH { - break - } + // 列宽(header 与 content 行保持一致,前缀均为 2 字符) + const ( + markW = 2 // 选择标记列 + concW = 6 // 并发数 + rateW = 8 // 成功率 + tpsW = 10 // TPS + ttftW = 10 // TTFT + cacheW = 8 // Cache + totW = 9 // 总耗时 + // 结论: 余量 + ) + hdr := padRight("", markW) + padRight("并发", concW) + padRight("成功率", rateW) + padRight("TPS", tpsW) + + padRight("TTFT", ttftW) + padRight("Cache", cacheW) + padRight("总耗时", totW) + "结论" + lines = append(lines, renderTableHeader(st, width, hdr)) + lines = append(lines, dividerLine(st, width)) + d.LevelVis = listVisibleItems(maxH, 3) + d.LevelOff = ensureVisibleOffset(d.LevelSel, len(rs.Levels), d.LevelOff, d.LevelVis) + start := d.LevelOff + end := minInt(len(rs.Levels), start+d.LevelVis) + + for i := start; i < end; i++ { + lv := rs.Levels[i] isSel := i == d.LevelSel - conclusion := st.Ok.Render("✓ 稳定") + conclusionText := "✓ 稳定" if !lv.Stable { - conclusion = st.ErrStyle.Render("✗ 降级") + conclusionText = "✗ 降级" } - // 当前进行中的级别 isCurrent := (i == len(rs.Levels)-1) && rs.Status == server.RunStatusRunning if isCurrent { - conclusion = st.MetricVal.Render("🔄 进行中") + conclusionText = "🔄 进行中" } - row := fmt.Sprintf(" %s%s%s%s%s%s%s", - padRight(fmt.Sprintf("%d", lv.Concurrency), 6), - padRight(fmt.Sprintf("%.1f%%", lv.SuccessRate*100), 8), - padRight(fmt.Sprintf("%.1f", lv.AvgTPS), 10), - padRight(fmtDuration(lv.AvgTTFT), 10), - padRight(fmt.Sprintf("%.1f%%", lv.CacheHitRate*100), 8), - padRight(fmtDuration(lv.AvgTotalTime), 9), - conclusion, - ) - - cursorStr := " " - if isSel { - cursorStr = "▶ " - } - - var rendered string - if isSel { - rendered = st.TableRowSel.Render(cursorStr+row) + - strings.Repeat(" ", max(0, width-len([]rune(cursorStr+row))-2)) + conclusion := conclusionText + if isCurrent { + conclusion = styleWhenNotSelected(isSel, st.MetricVal, conclusionText) + } else if lv.Stable { + conclusion = styleWhenNotSelected(isSel, st.Ok, conclusionText) } else { - rendered = " " + st.TableRow.Render(row) + conclusion = styleWhenNotSelected(isSel, st.ErrStyle, conclusionText) } + + marker := selectionMarker(isSel) + + rowContent := padRight(marker, markW) + + padRight(fmt.Sprintf("%d", lv.Concurrency), concW) + + padRight(fmt.Sprintf("%.1f%%", lv.SuccessRate*100), rateW) + + padRight(fmt.Sprintf("%.1f", lv.AvgTPS), tpsW) + + padRight(fmtDuration(lv.AvgTTFT), ttftW) + + padRight(fmt.Sprintf("%.1f%%", lv.CacheHitRate*100), cacheW) + + padRight(fmtDuration(lv.AvgTotalTime), totW) + + conclusion + + rendered := renderTableRow(st, width, isSel, rowContent) lines = append(lines, rendered) + + // 行间分隔线 + if i < end-1 && len(lines) < maxH-1 { + lines = append(lines, dividerLine(st, width)) + } } for len(lines) < maxH { diff --git a/internal/tui/pages/wizard.go b/internal/tui/pages/wizard.go index 99f9422..03ce931 100644 --- a/internal/tui/pages/wizard.go +++ b/internal/tui/pages/wizard.go @@ -63,22 +63,23 @@ type WizardState struct { // 当前活跃字段索引(Tab 切换) FieldIndex int + ScrollOff int } // NewWizardState 创建新建任务向导状态(使用默认值)。 func NewWizardState() *WizardState { return &WizardState{ - Step: wizardStep1, - Protocol: types.ProtocolOpenAICompletions, - Concurrency: 10, - Count: 100, - Timeout: 30, + Step: wizardStep1, + Protocol: types.ProtocolOpenAICompletions, + Concurrency: 10, + Count: 100, + Timeout: 30, InitConcurrency: 1, - MaxConcurrency: 50, - StepSize: 2, - LevelRequests: 30, - MinSuccessRate: 90, - PromptMode: PromptModeText, + MaxConcurrency: 50, + StepSize: 2, + LevelRequests: 30, + MinSuccessRate: 90, + PromptMode: PromptModeText, } } @@ -90,27 +91,27 @@ func NewWizardStateEdit(t *types.TaskDefinition) *WizardState { inp := t.Input tc := inp.TurboConfig wz := &WizardState{ - Step: wizardStep1, - EditingID: t.ID, - Name: t.Name, - Protocol: types.NormalizeProtocol(inp.Protocol), - EndpointURL: inp.EndpointURL, - APIKey: inp.ApiKey, - Model: inp.Model, - Turbo: inp.Turbo, - Stream: inp.Stream, - Concurrency: inp.Concurrency, - Count: inp.Count, - Timeout: int(inp.Timeout.Seconds()), + Step: wizardStep1, + EditingID: t.ID, + Name: t.Name, + Protocol: types.NormalizeProtocol(inp.Protocol), + EndpointURL: inp.EndpointURL, + APIKey: inp.ApiKey, + Model: inp.Model, + Turbo: inp.Turbo, + Stream: inp.Stream, + Concurrency: inp.Concurrency, + Count: inp.Count, + Timeout: int(inp.Timeout.Seconds()), InitConcurrency: tc.InitConcurrency, - MaxConcurrency: tc.MaxConcurrency, - StepSize: tc.StepSize, - LevelRequests: tc.LevelRequests, - MinSuccessRate: tc.MinSuccessRate * 100, // 转为百分比 - PromptMode: inp.PromptMode, - PromptText: inp.PromptText, - PromptFile: inp.PromptFile, - PromptLength: inp.PromptLength, + MaxConcurrency: tc.MaxConcurrency, + StepSize: tc.StepSize, + LevelRequests: tc.LevelRequests, + MinSuccessRate: tc.MinSuccessRate * 100, // 转为百分比 + PromptMode: inp.PromptMode, + PromptText: inp.PromptText, + PromptFile: inp.PromptFile, + PromptLength: inp.PromptLength, } if wz.PromptMode == "" { wz.PromptMode = PromptModeText @@ -293,11 +294,6 @@ func step2Fields(turbo bool) []fieldDef { } }, }, - fieldDef{ - kind: fieldBool, label: "流式模式", - get: func(wz *WizardState) string { return boolLabel(wz.Stream) }, - toggle: func(wz *WizardState, _ bool) { wz.Stream = !wz.Stream }, - }, ) } else { fields = append(fields, @@ -349,6 +345,14 @@ func step2Fields(turbo bool) []fieldDef { ) } + // 流式模式:与测试模式无关,两种模式均可配置 + fields = append(fields, fieldDef{ + kind: fieldBool, + label: "流式模式", + get: func(wz *WizardState) string { return boolLabel(wz.Stream) }, + toggle: func(wz *WizardState, _ bool) { wz.Stream = !wz.Stream }, + }) + // Prompt 字段(共用) promptModes := []string{PromptModeText, PromptModeFile, PromptModeGenerated} fields = append(fields, @@ -433,6 +437,19 @@ func HandleWizardKey(wz *WizardState, msg tea.KeyMsg, client Client) (*WizardSta case "esc": wz.Step = wizardStep2 wz.FieldIndex = 0 + wz.ScrollOff = 0 + case "up", "k": + wz.ScrollOff-- + case "down", "j": + wz.ScrollOff++ + case "pgup": + wz.ScrollOff -= 5 + case "pgdown", " ": + wz.ScrollOff += 5 + case "home": + wz.ScrollOff = 0 + case "end": + wz.ScrollOff = 1 << 30 case "enter": // 保存任务 cfg := wz.BuildTaskConfig() @@ -470,6 +487,7 @@ func HandleWizardKey(wz *WizardState, msg tea.KeyMsg, client Client) (*WizardSta } else { wz.Step-- wz.FieldIndex = 0 + wz.ScrollOff = 0 } case "tab", "down", "j": @@ -490,6 +508,7 @@ func HandleWizardKey(wz *WizardState, msg tea.KeyMsg, client Client) (*WizardSta // 如果切换了 turbo 模式,重置 fieldIndex if f.label == "测试模式" { wz.FieldIndex = 0 + wz.ScrollOff = 0 } } } @@ -501,6 +520,7 @@ func HandleWizardKey(wz *WizardState, msg tea.KeyMsg, client Client) (*WizardSta f.toggle(wz, true) if f.label == "测试模式" { wz.FieldIndex = 0 + wz.ScrollOff = 0 } } } @@ -509,6 +529,7 @@ func HandleWizardKey(wz *WizardState, msg tea.KeyMsg, client Client) (*WizardSta if wz.FieldIndex == maxField && int(wz.Step) < 2 { wz.Step++ wz.FieldIndex = 0 + wz.ScrollOff = 0 } else if wz.FieldIndex < maxField { wz.FieldIndex++ } @@ -541,116 +562,144 @@ func HandleWizardKey(wz *WizardState, msg tea.KeyMsg, client Client) (*WizardSta return wz, nil, nav } -// RenderWizard 渲染三步弹窗向导(overlay 覆盖在后台页面上)。 -func RenderWizard(wz *WizardState, bgView string, st Styles, width, height int) string { +// RenderWizard 渲染三步创建/编辑任务页。 +func RenderWizard(wz *WizardState, st Styles, width, height int) string { + if TooSmall(width, height) { + return renderTooSmall(st, width, height) + } if wz == nil { - return bgView + return renderTooSmall(st, width, height) + } + + stepTitles := []string{"基本信息", "测试参数", "确认保存"} + stepDescs := []string{ + "配置任务名称、模型协议和连接信息。", + "选择压测模式,并补全并发与 Prompt 参数。", + "保存前快速检查关键配置。", + } + action := "创建任务" + if wz.EditingID != "" { + action = "编辑任务" } - // 暗化背景 - bgLines := strings.Split(bgView, "\n") - for i, line := range bgLines { - bgLines[i] = st.Muted.Render(line) + l := PageLayout{ + TitleLeft: fmt.Sprintf("AIT %s", action), + InfoLeft: fmt.Sprintf("步骤 %d/3 · %s", int(wz.Step)+1, stepTitles[int(wz.Step)]), + CtxItems: wizardContextItems(wz.Step), + FooterParts: []string{"[q] 退出", "◆ AIT v0.1"}, } - // 弹窗尺寸 - dialogW := width - 8 - if dialogW > 72 { - dialogW = 72 + content := buildWizardPageContent(wz, st, action, stepTitles[int(wz.Step)], stepDescs[int(wz.Step)], ContentWidth(width), l.ContentHeight(height)) + return l.Assemble(wrapPanel(st, content, width), st, width) +} + +func buildWizardPageContent(wz *WizardState, st Styles, action, stepTitle, stepDesc string, width, maxH int) string { + titleLeft := st.SectionHead.Render(action) + titleRight := st.Muted.Render(fmt.Sprintf("步骤 %d/3 · %s", int(wz.Step)+1, stepTitle)) + var topLines []string + if lipgloss.Width(titleLeft)+lipgloss.Width(titleRight)+2 <= width { + topLines = append(topLines, titleLeft+strings.Repeat(" ", width-lipgloss.Width(titleLeft)-lipgloss.Width(titleRight))+titleRight) + } else { + topLines = append(topLines, titleLeft, titleRight) } - if dialogW < 40 { - dialogW = 40 + if maxH >= 8 { + for _, line := range wrapText(stepDesc, width) { + topLines = append(topLines, st.Muted.Render(line)) + } + } + if maxH >= 10 && width >= 46 { + topLines = append(topLines, renderWizardStepStrip(wz.Step)) } - var dialogLines []string + bottomCount := 1 + showBottomDivider := maxH >= 6 + if showBottomDivider { + bottomCount = 2 + } - stepTitles := []string{"1/3 · 基本信息", "2/3 · 测试参数", "3/3 · 确认保存"} - stepTitle := stepTitles[int(wz.Step)] - isEdit := wz.EditingID != "" - action := "新建任务" - if isEdit { - action = "编辑任务" + // 为 body 保留最少 5 行空间 + minBodyH := 5 + availableForContent := maxH - bottomCount + maxTopH := maxInt(1, availableForContent-minBodyH) + + // 限制 topLines 大小 + if len(topLines) > maxTopH { + topLines = topLines[:maxTopH] + } + if maxH >= 6 { + topLines = append(topLines, dividerLine(st, width)) + } + + bodyLines, focusLine := buildWizardBody(wz, st, width) + bodyH := maxInt(1, availableForContent-len(topLines)) + offset := 0 + if wz.Step == wizardStep3 { + offset = clampInt(wz.ScrollOff, 0, maxInt(0, len(bodyLines)-bodyH)) + } else if focusLine >= 0 { + offset = ensureVisibleOffset(focusLine, len(bodyLines), 0, bodyH) + } + end := minInt(len(bodyLines), offset+bodyH) + visibleBody := append([]string{}, bodyLines[offset:end]...) + for len(visibleBody) < bodyH { + visibleBody = append(visibleBody, "") + } + + lines := append([]string{}, topLines...) + lines = append(lines, visibleBody...) + if showBottomDivider { + lines = append(lines, dividerLine(st, width)) + } + lines = append(lines, st.Muted.Render(truncate(wizardStatusText(wz, focusLine, offset, end, len(bodyLines), bodyH), width))) + + if len(lines) > maxH { + lines = lines[:maxH] + } + for len(lines) < maxH { + lines = append(lines, "") + } + return strings.Join(lines, "\n") +} + +func buildWizardBody(wz *WizardState, st Styles, contentW int) ([]string, int) { + var lines []string + focusLine := -1 + + // appendField 将字段渲染结果按行展开追加,因为 FieldActive/FieldIdle 带 Border + // 会产生 3 行输出(顶部边框 + 内容 + 底部边框),必须逐行记录才能正确计算高度。 + appendField := func(rendered string, focused bool) { + if focused { + focusLine = len(lines) + } + for _, l := range strings.Split(rendered, "\n") { + lines = append(lines, l) + } } - dialogLines = append(dialogLines, st.SectionHead.Render(fmt.Sprintf(" %s %s", action, stepTitle))) - dialogLines = append(dialogLines, "") switch wz.Step { case wizardStep1: fields := step1Fields() for i, f := range fields { - dialogLines = append(dialogLines, renderWizardField(st, f, wz, i == wz.FieldIndex, dialogW-4)) - dialogLines = append(dialogLines, "") + appendField(renderWizardField(st, f, wz, i == wz.FieldIndex, contentW), i == wz.FieldIndex) } - dialogLines = append(dialogLines, dividerLine(st, dialogW-4)) - hintStyle := st.Muted - dialogLines = append(dialogLines, hintStyle.Render(" [Tab] 下一项 [↑↓] 切换协议 [Enter] 下一步 [Esc] 取消")) case wizardStep2: fields := step2Fields(wz.Turbo) for i, f := range fields { - dialogLines = append(dialogLines, renderWizardField(st, f, wz, i == wz.FieldIndex, dialogW-4)) - dialogLines = append(dialogLines, "") + if f.label == "输入方式" { + lines = append(lines, "", st.Muted.Render("Prompt 配置")) + } + appendField(renderWizardField(st, f, wz, i == wz.FieldIndex, contentW), i == wz.FieldIndex) } - dialogLines = append(dialogLines, dividerLine(st, dialogW-4)) - dialogLines = append(dialogLines, st.Muted.Render(" [Tab] 下一项 [←→] 切换模式 [Enter] 下一步 [Esc] 返回")) case wizardStep3: - dialogLines = append(dialogLines, renderStep3Summary(wz, st, dialogW-4)...) - dialogLines = append(dialogLines, "") - dialogLines = append(dialogLines, dividerLine(st, dialogW-4)) - dialogLines = append(dialogLines, st.Muted.Render(" [Enter] 保存任务 [r] 保存并运行 [Esc] 返回修改")) - } - - // 构建弹窗框 - innerLines := dialogLines - boxedLines := make([]string, len(innerLines)) - for i, l := range innerLines { - lW := lipgloss.Width(l) - pad := dialogW - 4 - lW - if pad < 0 { - pad = 0 - } - boxedLines[i] = " " + l + strings.Repeat(" ", pad) + lines = append(lines, renderStep3Summary(wz, st, contentW)...) } - // 用 lipgloss rounded border 包裹 - inner := strings.Join(boxedLines, "\n") - box := st.Dialog.Width(dialogW).Render(inner) - - // 将弹窗叠加在背景中间 - boxLines := strings.Split(box, "\n") - startRow := (height - len(boxLines)) / 2 - if startRow < 0 { - startRow = 0 - } - startCol := (width - dialogW) / 2 - if startCol < 0 { - startCol = 0 - } - - for i, boxLine := range boxLines { - row := startRow + i - if row >= len(bgLines) { - bgLines = append(bgLines, strings.Repeat(" ", width)) - } - bgLine := []rune(bgLines[row]) - boxRunes := []rune(boxLine) - // 替换对应列 - for j, r := range boxRunes { - col := startCol + j - if col < len(bgLine) { - bgLine[col] = r - } - } - bgLines[row] = string(bgLine) - } - - return strings.Join(bgLines, "\n") + return lines, focusLine } // renderWizardField 渲染向导的一个字段行。 func renderWizardField(st Styles, f fieldDef, wz *WizardState, active bool, maxW int) string { - label := padRight(f.label, 12) var valueStr string if f.get != nil { @@ -662,66 +711,203 @@ func renderWizardField(st Styles, f fieldDef, wz *WizardState, active bool, maxW valueStr = maskAPIKey(valueStr) } - var renderedValue string - if active { - if f.kind == fieldEnum || f.kind == fieldBool { - renderedValue = st.Ok.Render("● " + valueStr) - } else { - renderedValue = st.FieldActive.Width(maxW - 14).Render(valueStr + "█") // 光标 + // FieldActive/Idle: Width(n) = 内容区宽度(在 padding/border 之内) + // 总渲染宽度 = n + padding(2) + border(2) = n + 4 + // Line1 = label(14) + space(1) + (n+4) = n + 19 ≤ maxW → n = maxW - 19 + fieldW := maxInt(10, maxW-19) + valueStyle := st.Value + if valueStr == "" && !active { + valueStr = "未填写" + valueStyle = st.Muted + } + + // Width(fieldW) 是内容区宽度,padding 在其外侧叠加,文字区即为 fieldW + // 激活时保留 1 列给光标 █,非激活可用满 fieldW + if f.kind == fieldEnum || f.kind == fieldBool { + if active { + valueStr = "‹ " + valueStr + " ›" } + valueStr = truncate(valueStr, maxInt(4, fieldW)) } else { - if f.kind == fieldEnum || f.kind == fieldBool { - renderedValue = st.Muted.Render("○ " + valueStr) + if active { + valueStr = fitTail(valueStr, maxInt(1, fieldW-1)) + "█" } else { - renderedValue = st.FieldIdle.Width(maxW - 14).Render(valueStr) + valueStr = fitTail(valueStr, maxInt(1, fieldW)) } } - return " " + st.Label.Render(label) + " " + renderedValue + fieldStyle := st.FieldIdle + if active { + fieldStyle = st.FieldActive + } + + renderedValue := fieldStyle.Width(fieldW).Render(valueStyle.Render(valueStr)) + // 用 JoinHorizontal 而非字符串拼接:renderedValue 有 3 行(上边框/内容/下边框), + // 直接 + 只有第一行有 label 前缀,后两行会从列 0 开始,导致布局混乱。 + // JoinHorizontal(Top, ...) 会将 label 块和 field 块按顶部对齐水平拼接, + // label 块高度自动补齐到与 field 相同(3 行),布局整齐。 + labelBlock := lipgloss.NewStyle().Width(15).Render(st.Label.Render(wizardFieldLabel(f, wz))) + return lipgloss.JoinHorizontal(lipgloss.Top, labelBlock, renderedValue) } // renderStep3Summary 渲染步骤3的确认内容。 func renderStep3Summary(wz *WizardState, st Styles, innerW int) []string { var lines []string - addRow := func(label, value string) { - lines = append(lines, " "+st.Label.Render(padRight(label, 12))+" "+st.Value.Render(value)) + addRow := func(label, value string, valueStyle lipgloss.Style) { + appendWizardSummaryRow(&lines, st, label, value, innerW, valueStyle) } - addRow("任务名称", wz.Name) - addRow("协议", wz.Protocol) + lines = append(lines, st.SectionHead.Render("配置概览")) + addRow("任务名称", wizardFallback(wz.Name, "未命名任务"), st.Value) + addRow("协议", wz.Protocol, st.Value) endpointDisplay := wz.EndpointURL if endpointDisplay == "" { endpointDisplay = types.DefaultEndpointURL(wz.Protocol) } - addRow("接口地址", truncate(endpointDisplay, innerW-20)) - addRow("API 密钥", maskAPIKey(wz.APIKey)) - addRow("测试模型", wz.Model) + addRow("接口地址", endpointDisplay, st.Value) + addRow("API 密钥", wizardFallback(maskAPIKey(wz.APIKey), "未填写"), st.Value) + addRow("测试模型", wizardFallback(wz.Model, "未填写"), st.Value) + lines = append(lines, "", st.SectionHead.Render("执行参数")) if wz.Turbo { - addRow("测试模式", "Turbo 模式") - addRow("并发爬坡", fmt.Sprintf("%d → %d 步进 +%d 每级 %d 请求", - wz.InitConcurrency, wz.MaxConcurrency, wz.StepSize, wz.LevelRequests)) - addRow("停止条件", fmt.Sprintf("成功率 < %.0f%%", wz.MinSuccessRate)) + addRow("测试模式", "Turbo 模式", st.Value) + addRow("并发爬坡", fmt.Sprintf("%d → %d · 步进 +%d · 每级 %d 请求", + wz.InitConcurrency, wz.MaxConcurrency, wz.StepSize, wz.LevelRequests), st.Value) + addRow("停止条件", fmt.Sprintf("成功率 < %.0f%%", wz.MinSuccessRate), st.Value) } else { - addRow("测试模式", "标准模式") - addRow("并发数", strconv.Itoa(wz.Concurrency)) - addRow("请求总数", strconv.Itoa(wz.Count)) - addRow("超时", fmt.Sprintf("%ds", wz.Timeout)) - addRow("流式模式", boolLabel(wz.Stream)) - } - - promptDesc := wz.PromptText - if wz.PromptMode == PromptModeFile { - promptDesc = "文件: " + wz.PromptFile + addRow("测试模式", "标准模式", st.Value) + addRow("并发数", strconv.Itoa(wz.Concurrency), st.Value) + addRow("请求总数", strconv.Itoa(wz.Count), st.Value) + addRow("超时", fmt.Sprintf("%ds", wz.Timeout), st.Value) + addRow("流式模式", boolLabel(wz.Stream), st.Value) + } + + lines = append(lines, "", st.SectionHead.Render("Prompt")) + addRow("输入方式", wizardPromptModeLabel(wz.PromptMode), st.Value) + promptDesc := promptSummary(wz.PromptMode, wz.PromptText, wz.PromptFile, wz.PromptLength) + addRow("内容摘要", wizardFallback(promptDesc, "未填写"), st.Value) + if wz.PromptMode == PromptModeText { + addRow("字符数", strconv.Itoa(len([]rune(wz.PromptText))), st.Muted) } else if wz.PromptMode == PromptModeGenerated { - promptDesc = fmt.Sprintf("生成 %d 字符", wz.PromptLength) + addRow("目标长度", strconv.Itoa(wz.PromptLength), st.Muted) } - addRow("Prompt", truncate(promptDesc, innerW-20)+fmt.Sprintf(" (长度: %d)", len([]rune(wz.PromptText)))) - lines = append(lines, "") - lines = append(lines, " "+st.Muted.Render("保存任务到 ~/.ait/tasks.json [✓]")) - lines = append(lines, "") - lines = append(lines, " "+st.BtnPrimary.Render("▶ 保存任务")) + lines = append(lines, "", st.Muted.Render("保存位置: ~/.ait/tasks.json")) return lines } + +func renderWizardStepStrip(step wizardStep) string { + active := lipgloss.NewStyle().Background(colorPink).Foreground(colorWhite).Bold(true).Padding(0, 1) + done := lipgloss.NewStyle().Background(colorCyan).Foreground(lipgloss.Color("233")).Bold(true).Padding(0, 1) + idle := lipgloss.NewStyle().Background(lipgloss.Color("238")).Foreground(colorMuted).Padding(0, 1) + labels := []string{"1 基本信息", "2 测试参数", "3 确认保存"} + parts := make([]string, 0, len(labels)) + for i, label := range labels { + switch { + case i < int(step): + parts = append(parts, done.Render("✓ "+label)) + case i == int(step): + parts = append(parts, active.Render(label)) + default: + parts = append(parts, idle.Render(label)) + } + } + return strings.Join(parts, " ") +} + +func wizardFieldLabel(f fieldDef, wz *WizardState) string { + if f.label != "内容" { + if f.label == "最低成功率%" { + return "最低成功率" + } + return f.label + } + switch wz.PromptMode { + case PromptModeFile: + return "文件路径" + case PromptModeGenerated: + return "生成长度" + default: + return "Prompt" + } +} + +func wizardPromptModeLabel(mode string) string { + switch mode { + case PromptModeFile: + return "文件" + case PromptModeGenerated: + return "按长度生成" + default: + return "直接输入" + } +} + +func wizardFallback(value, fallback string) string { + if value == "" { + return fallback + } + return value +} + +func fitTail(s string, maxW int) string { + if maxW <= 0 { + return "" + } + if lipgloss.Width(s) <= maxW { + return s + } + runes := []rune(s) + width := 0 + for i := len(runes) - 1; i >= 0; i-- { + rw := lipgloss.Width(string(runes[i])) + if width+rw > maxW-1 { + return "…" + string(runes[i+1:]) + } + width += rw + } + return s +} + +func appendWizardSummaryRow(lines *[]string, st Styles, label, value string, width int, valueStyle lipgloss.Style) { + labelW := 14 + contentW := maxInt(8, width-labelW-1) + segments := wrapText(value, contentW) + if len(segments) == 0 { + segments = []string{""} + } + *lines = append(*lines, st.Label.Render(padRight(label, labelW))+" "+valueStyle.Render(segments[0])) + indent := strings.Repeat(" ", labelW+1) + for _, segment := range segments[1:] { + *lines = append(*lines, indent+valueStyle.Render(segment)) + } +} + +func wizardContextItems(step wizardStep) []ContextBarItem { + switch step { + case wizardStep1: + return CtxBar_Wizard_Step1() + case wizardStep2: + return CtxBar_Wizard_Step2() + default: + return CtxBar_Wizard_Step3() + } +} + +func wizardStatusText(wz *WizardState, focusLine, offset, end, total, visible int) string { + if total <= 0 { + return "暂无配置项" + } + if wz.Step == wizardStep3 { + if total > visible { + return fmt.Sprintf("确认项 %d-%d/%d", offset+1, end, total) + } + return fmt.Sprintf("共 %d 项待确认", total) + } + current := clampInt(focusLine+1, 1, total) + if total > visible { + return fmt.Sprintf("当前字段 %d/%d · 内容较多时自动滚动", current, total) + } + return fmt.Sprintf("当前字段 %d/%d", current, total) +} From 6c592ccad62e7a212cae765b0f27f5c52056383e Mon Sep 17 00:00:00 2001 From: Alain Date: Sun, 17 May 2026 15:00:50 +0800 Subject: [PATCH 10/52] =?UTF-8?q?refactor:=20=E7=A7=BB=E9=99=A4=E4=BB=AA?= =?UTF-8?q?=E8=A1=A8=E6=9D=BF=E5=92=8C=E4=BB=BB=E5=8A=A1=E8=AF=A6=E6=83=85?= =?UTF-8?q?=E9=A1=B5=E9=9D=A2=E7=9A=84=E7=8A=B6=E6=80=81=E6=A0=87=E8=AF=86?= =?UTF-8?q?=EF=BC=8C=E7=AE=80=E5=8C=96=E9=A1=B5=E9=9D=A2=E5=B8=83=E5=B1=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/tui/pages/dashboard.go | 24 --- internal/tui/pages/layout.go | 4 - internal/tui/pages/reqdetail.go | 10 -- internal/tui/pages/taskdetail.go | 4 - internal/tui/pages/tasklist.go | 6 - internal/tui/pages/turbodash.go | 15 -- internal/tui/pages/wizard.go | 249 ++++++++++++------------------- 7 files changed, 96 insertions(+), 216 deletions(-) diff --git a/internal/tui/pages/dashboard.go b/internal/tui/pages/dashboard.go index b0484cf..813afea 100644 --- a/internal/tui/pages/dashboard.go +++ b/internal/tui/pages/dashboard.go @@ -165,27 +165,6 @@ func RenderDashboard(d *DashboardState, taskName string, st Styles, width, heigh } rs := d.RunState - // ── 状态标识 ── - statusStr := "等待中" - if rs != nil { - switch rs.Status { - case server.RunStatusRunning: - statusStr = st.Ok.Render("运行中") - case server.RunStatusCompleted: - statusStr = st.Ok.Render("已完成") - case server.RunStatusFailed: - statusStr = st.ErrStyle.Render("失败") - case server.RunStatusStopped: - statusStr = st.Muted.Render("已停止") - } - } - - subtitle := "─" - if rs != nil { - subtitle = fmt.Sprintf("%s · %s · 并发: %d · 请求: %d", - "─", "─", 0, rs.TotalReqs) - } - var cbItems []ContextBarItem if d.ReqSel >= 0 && rs != nil && d.ReqSel < len(rs.Requests) { cbItems = CtxBar_Dashboard_Sel() @@ -193,9 +172,6 @@ func RenderDashboard(d *DashboardState, taskName string, st Styles, width, heigh cbItems = CtxBar_Dashboard_NoSel() } l := PageLayout{ - TitleLeft: "AIT 正在测试 ─ " + truncate(taskName, 25), - TitleRight: statusStr, - InfoLeft: subtitle, CtxItems: cbItems, FooterParts: []string{"[q] 退出"}, } diff --git a/internal/tui/pages/layout.go b/internal/tui/pages/layout.go index e10c285..324f123 100644 --- a/internal/tui/pages/layout.go +++ b/internal/tui/pages/layout.go @@ -22,10 +22,6 @@ const ( // PageLayout 描述一个完整页面的 chrome(底部 ContextBar + Footer)。 // 各页面 Render 函数先构造 PageLayout,再调用 Assemble 拼装最终输出。 type PageLayout struct { - TitleLeft string - TitleRight string - InfoLeft string - InfoRight string CtxItems []ContextBarItem FooterParts []string } diff --git a/internal/tui/pages/reqdetail.go b/internal/tui/pages/reqdetail.go index ece0dcc..62efab4 100644 --- a/internal/tui/pages/reqdetail.go +++ b/internal/tui/pages/reqdetail.go @@ -112,17 +112,7 @@ func RenderReqDetail(s *ReqDetailState, taskName string, st Styles, width, heigh } r := s.Requests[idx] - // ── 状态标识 ── - statusStr := st.Ok.Render("✓ 成功") - if !r.Success { - statusStr = st.ErrStyle.Render("✗ 失败") - } - l := PageLayout{ - TitleLeft: fmt.Sprintf("AIT 请求详情 - %s #%d", truncate(taskName, 20), idx+1), - TitleRight: statusStr, - InfoLeft: fmt.Sprintf("任务: %s 请求 %d / %d", - truncate(taskName, 20), idx+1, len(s.Requests)), CtxItems: CtxBar_ReqDetail(), FooterParts: []string{"[q] 退出"}, } diff --git a/internal/tui/pages/taskdetail.go b/internal/tui/pages/taskdetail.go index 39b0519..5438261 100644 --- a/internal/tui/pages/taskdetail.go +++ b/internal/tui/pages/taskdetail.go @@ -102,7 +102,6 @@ func RenderTaskDetail(s *TaskDetailState, st Styles, width, height int) string { t := s.Task inp := t.Input - updatedStr := timeAgo(t.UpdatedAt) var cbItems []ContextBarItem if len(s.History) > 0 { cbItems = CtxBar_TaskDetail_HasHistory() @@ -110,9 +109,6 @@ func RenderTaskDetail(s *TaskDetailState, st Styles, width, height int) string { cbItems = CtxBar_TaskDetail_NoHistory() } l := PageLayout{ - TitleLeft: "AIT 任务详情 ─ " + truncate(t.Name, 30), - InfoLeft: fmt.Sprintf("任务 ID: %s 更新: %s %s", - truncate(t.ID, 10), t.UpdatedAt.Format("2006-01-02 15:04"), updatedStr), CtxItems: cbItems, FooterParts: []string{"[b/Esc] 返回列表", "◆ AIT v0.1"}, } diff --git a/internal/tui/pages/tasklist.go b/internal/tui/pages/tasklist.go index 97e62a1..63a2b62 100644 --- a/internal/tui/pages/tasklist.go +++ b/internal/tui/pages/tasklist.go @@ -143,10 +143,6 @@ func RenderTaskList(s *TaskListState, st Styles, width, height int) string { return renderTooSmall(st, width, height) } - lastRunStr := "" - if lt := s.latestRunAt(); lt != nil { - lastRunStr = "最近运行: " + lt.Format("2006-01-02 15:04") - } var cbItems []ContextBarItem if t, ok := s.CurrentTask(); ok { if s.IsTaskRunning(t.ID) { @@ -156,8 +152,6 @@ func RenderTaskList(s *TaskListState, st Styles, width, height int) string { } } l := PageLayout{ - TitleLeft: "AIT 任务中心", - InfoLeft: fmt.Sprintf("已保存任务: %d %s", len(s.Tasks), lastRunStr), CtxItems: cbItems, FooterParts: []string{"[↑↓] 选择", "[a] 新建", "[q] 退出", "◆ AIT v0.1"}, } diff --git a/internal/tui/pages/turbodash.go b/internal/tui/pages/turbodash.go index c07c394..eb12fbe 100644 --- a/internal/tui/pages/turbodash.go +++ b/internal/tui/pages/turbodash.go @@ -139,18 +139,6 @@ func RenderTurboDash(d *TurboDashState, taskName string, st Styles, width, heigh } rs := d.RunState - // ── 状态标识 ── - statusStr := st.Ok.Render("探测中") - if rs != nil && rs.Status != server.RunStatusRunning { - statusStr = st.Muted.Render(string(rs.Status)) - } - - subtitle := "─" - if rs != nil && len(rs.Levels) > 0 { - subtitle = fmt.Sprintf("%s · 当前并发: %d 已完成 %d 级", - "─", rs.CurrentLevel, len(rs.Levels)) - } - var cbItems []ContextBarItem if d.LevelSel >= 0 && rs != nil && d.LevelSel < len(rs.Levels) { cbItems = CtxBar_TurboDash_Sel() @@ -158,9 +146,6 @@ func RenderTurboDash(d *TurboDashState, taskName string, st Styles, width, heigh cbItems = CtxBar_TurboDash_NoSel() } l := PageLayout{ - TitleLeft: "AIT Turbo 探测 ─ " + truncate(taskName, 22), - TitleRight: statusStr, - InfoLeft: subtitle, CtxItems: cbItems, FooterParts: []string{"[q] 退出"}, } diff --git a/internal/tui/pages/wizard.go b/internal/tui/pages/wizard.go index 03ce931..7fbf530 100644 --- a/internal/tui/pages/wizard.go +++ b/internal/tui/pages/wizard.go @@ -83,59 +83,52 @@ func NewWizardState() *WizardState { } } -// NewWizardStateEdit 创建编辑任务向导状态(预填任务数据)。 +// NewWizardStateEdit 创建编辑任务向导状态(预填任务数据,零值字段沿用默认值)。 func NewWizardStateEdit(t *types.TaskDefinition) *WizardState { if t == nil { return NewWizardState() } + wz := NewWizardState() inp := t.Input tc := inp.TurboConfig - wz := &WizardState{ - Step: wizardStep1, - EditingID: t.ID, - Name: t.Name, - Protocol: types.NormalizeProtocol(inp.Protocol), - EndpointURL: inp.EndpointURL, - APIKey: inp.ApiKey, - Model: inp.Model, - Turbo: inp.Turbo, - Stream: inp.Stream, - Concurrency: inp.Concurrency, - Count: inp.Count, - Timeout: int(inp.Timeout.Seconds()), - InitConcurrency: tc.InitConcurrency, - MaxConcurrency: tc.MaxConcurrency, - StepSize: tc.StepSize, - LevelRequests: tc.LevelRequests, - MinSuccessRate: tc.MinSuccessRate * 100, // 转为百分比 - PromptMode: inp.PromptMode, - PromptText: inp.PromptText, - PromptFile: inp.PromptFile, - PromptLength: inp.PromptLength, - } - if wz.PromptMode == "" { - wz.PromptMode = PromptModeText - } - if wz.Concurrency == 0 { - wz.Concurrency = 10 - } - if wz.Count == 0 { - wz.Count = 100 - } - if wz.Timeout == 0 { - wz.Timeout = 30 - } - if wz.MinSuccessRate == 0 { - wz.MinSuccessRate = 90 - } - if wz.StepSize == 0 { - wz.StepSize = 2 - } - if wz.LevelRequests == 0 { - wz.LevelRequests = 30 - } - if wz.MaxConcurrency == 0 { - wz.MaxConcurrency = 50 + + wz.EditingID = t.ID + wz.Name = t.Name + wz.Protocol = types.NormalizeProtocol(inp.Protocol) + wz.EndpointURL = inp.EndpointURL + wz.APIKey = inp.ApiKey + wz.Model = inp.Model + wz.Turbo = inp.Turbo + wz.Stream = inp.Stream + wz.PromptText = inp.PromptText + wz.PromptFile = inp.PromptFile + wz.PromptLength = inp.PromptLength + if inp.PromptMode != "" { + wz.PromptMode = inp.PromptMode + } + if inp.Concurrency > 0 { + wz.Concurrency = inp.Concurrency + } + if inp.Count > 0 { + wz.Count = inp.Count + } + if inp.Timeout > 0 { + wz.Timeout = int(inp.Timeout.Seconds()) + } + if tc.InitConcurrency > 0 { + wz.InitConcurrency = tc.InitConcurrency + } + if tc.MaxConcurrency > 0 { + wz.MaxConcurrency = tc.MaxConcurrency + } + if tc.StepSize > 0 { + wz.StepSize = tc.StepSize + } + if tc.LevelRequests > 0 { + wz.LevelRequests = tc.LevelRequests + } + if tc.MinSuccessRate > 0 { + wz.MinSuccessRate = tc.MinSuccessRate * 100 } return wz } @@ -193,6 +186,20 @@ const ( fieldEnum // 枚举循环 ) +// intField 构造一个整数输入字段(值 > 0 时才写入)。 +func intField(label string, get func(*WizardState) int, set func(*WizardState, int)) fieldDef { + return fieldDef{ + kind: fieldNumber, + label: label, + get: func(wz *WizardState) string { return strconv.Itoa(get(wz)) }, + set: func(wz *WizardState, v string) { + if n, err := strconv.Atoi(v); err == nil && n > 0 { + set(wz, n) + } + }, + } +} + // step1Fields 返回步骤1的字段列表。 func step1Fields() []fieldDef { protocols := []string{ @@ -267,75 +274,20 @@ func step2Fields(turbo bool) []fieldDef { if !turbo { fields = append(fields, - fieldDef{ - kind: fieldNumber, label: "并发数", - get: func(wz *WizardState) string { return strconv.Itoa(wz.Concurrency) }, - set: func(wz *WizardState, v string) { - if n, err := strconv.Atoi(v); err == nil && n > 0 { - wz.Concurrency = n - } - }, - }, - fieldDef{ - kind: fieldNumber, label: "请求总数", - get: func(wz *WizardState) string { return strconv.Itoa(wz.Count) }, - set: func(wz *WizardState, v string) { - if n, err := strconv.Atoi(v); err == nil && n > 0 { - wz.Count = n - } - }, - }, - fieldDef{ - kind: fieldNumber, label: "超时(秒)", - get: func(wz *WizardState) string { return strconv.Itoa(wz.Timeout) }, - set: func(wz *WizardState, v string) { - if n, err := strconv.Atoi(v); err == nil && n > 0 { - wz.Timeout = n - } - }, - }, + intField("并发数", func(wz *WizardState) int { return wz.Concurrency }, func(wz *WizardState, n int) { wz.Concurrency = n }), + intField("请求总数", func(wz *WizardState) int { return wz.Count }, func(wz *WizardState, n int) { wz.Count = n }), + intField("超时(秒)", func(wz *WizardState) int { return wz.Timeout }, func(wz *WizardState, n int) { wz.Timeout = n }), ) } else { fields = append(fields, + intField("初始并发", func(wz *WizardState) int { return wz.InitConcurrency }, func(wz *WizardState, n int) { wz.InitConcurrency = n }), + intField("最大并发", func(wz *WizardState) int { return wz.MaxConcurrency }, func(wz *WizardState, n int) { wz.MaxConcurrency = n }), + intField("步进值", func(wz *WizardState) int { return wz.StepSize }, func(wz *WizardState, n int) { wz.StepSize = n }), + intField("每级请求数", func(wz *WizardState) int { return wz.LevelRequests }, func(wz *WizardState, n int) { wz.LevelRequests = n }), fieldDef{ - kind: fieldNumber, label: "初始并发", - get: func(wz *WizardState) string { return strconv.Itoa(wz.InitConcurrency) }, - set: func(wz *WizardState, v string) { - if n, err := strconv.Atoi(v); err == nil && n > 0 { - wz.InitConcurrency = n - } - }, - }, - fieldDef{ - kind: fieldNumber, label: "最大并发", - get: func(wz *WizardState) string { return strconv.Itoa(wz.MaxConcurrency) }, - set: func(wz *WizardState, v string) { - if n, err := strconv.Atoi(v); err == nil && n > 0 { - wz.MaxConcurrency = n - } - }, - }, - fieldDef{ - kind: fieldNumber, label: "步进值", - get: func(wz *WizardState) string { return strconv.Itoa(wz.StepSize) }, - set: func(wz *WizardState, v string) { - if n, err := strconv.Atoi(v); err == nil && n > 0 { - wz.StepSize = n - } - }, - }, - fieldDef{ - kind: fieldNumber, label: "每级请求数", - get: func(wz *WizardState) string { return strconv.Itoa(wz.LevelRequests) }, - set: func(wz *WizardState, v string) { - if n, err := strconv.Atoi(v); err == nil && n > 0 { - wz.LevelRequests = n - } - }, - }, - fieldDef{ - kind: fieldNumber, label: "最低成功率%", - get: func(wz *WizardState) string { return fmt.Sprintf("%.0f", wz.MinSuccessRate) }, + kind: fieldNumber, + label: "最低成功率", + get: func(wz *WizardState) string { return fmt.Sprintf("%.0f", wz.MinSuccessRate) }, set: func(wz *WizardState, v string) { if f, err := strconv.ParseFloat(v, 64); err == nil && f > 0 && f <= 100 { wz.MinSuccessRate = f @@ -450,19 +402,7 @@ func HandleWizardKey(wz *WizardState, msg tea.KeyMsg, client Client) (*WizardSta wz.ScrollOff = 0 case "end": wz.ScrollOff = 1 << 30 - case "enter": - // 保存任务 - cfg := wz.BuildTaskConfig() - var cmd tea.Cmd - if wz.EditingID != "" { - cmd = client.UpdateTaskCmd(wz.EditingID, cfg) - } else { - cmd = client.CreateTaskCmd(cfg, true) // autoStart - } - nav = NavAction{To: NavTaskList} - return wz, cmd, nav - case "r": - // 保存并运行(强制启动,忽略干扰检测) + case "enter", "r": cfg := wz.BuildTaskConfig() var cmd tea.Cmd if wz.EditingID != "" { @@ -571,6 +511,16 @@ func RenderWizard(wz *WizardState, st Styles, width, height int) string { return renderTooSmall(st, width, height) } + l := PageLayout{ + CtxItems: wizardContextItems(wz.Step), + FooterParts: []string{"[q] 退出", "◆ AIT v0.1"}, + } + + content := buildWizardPageContent(wz, st, ContentWidth(width), l.ContentHeight(height)) + return l.Assemble(wrapPanel(st, content, width), st, width) +} + +func buildWizardPageContent(wz *WizardState, st Styles, width, maxH int) string { stepTitles := []string{"基本信息", "测试参数", "确认保存"} stepDescs := []string{ "配置任务名称、模型协议和连接信息。", @@ -581,19 +531,9 @@ func RenderWizard(wz *WizardState, st Styles, width, height int) string { if wz.EditingID != "" { action = "编辑任务" } + stepTitle := stepTitles[int(wz.Step)] + stepDesc := stepDescs[int(wz.Step)] - l := PageLayout{ - TitleLeft: fmt.Sprintf("AIT %s", action), - InfoLeft: fmt.Sprintf("步骤 %d/3 · %s", int(wz.Step)+1, stepTitles[int(wz.Step)]), - CtxItems: wizardContextItems(wz.Step), - FooterParts: []string{"[q] 退出", "◆ AIT v0.1"}, - } - - content := buildWizardPageContent(wz, st, action, stepTitles[int(wz.Step)], stepDescs[int(wz.Step)], ContentWidth(width), l.ContentHeight(height)) - return l.Assemble(wrapPanel(st, content, width), st, width) -} - -func buildWizardPageContent(wz *WizardState, st Styles, action, stepTitle, stepDesc string, width, maxH int) string { titleLeft := st.SectionHead.Render(action) titleRight := st.Muted.Render(fmt.Sprintf("步骤 %d/3 · %s", int(wz.Step)+1, stepTitle)) var topLines []string @@ -649,7 +589,7 @@ func buildWizardPageContent(wz *WizardState, st Styles, action, stepTitle, stepD if showBottomDivider { lines = append(lines, dividerLine(st, width)) } - lines = append(lines, st.Muted.Render(truncate(wizardStatusText(wz, focusLine, offset, end, len(bodyLines), bodyH), width))) + lines = append(lines, st.Muted.Render(truncate(wizardStatusText(wz, offset, end, len(bodyLines), bodyH), width))) if len(lines) > maxH { lines = lines[:maxH] @@ -779,8 +719,8 @@ func renderStep3Summary(wz *WizardState, st Styles, innerW int) []string { addRow("并发数", strconv.Itoa(wz.Concurrency), st.Value) addRow("请求总数", strconv.Itoa(wz.Count), st.Value) addRow("超时", fmt.Sprintf("%ds", wz.Timeout), st.Value) - addRow("流式模式", boolLabel(wz.Stream), st.Value) } + addRow("流式模式", boolLabel(wz.Stream), st.Value) lines = append(lines, "", st.SectionHead.Render("Prompt")) addRow("输入方式", wizardPromptModeLabel(wz.PromptMode), st.Value) @@ -818,9 +758,6 @@ func renderWizardStepStrip(step wizardStep) string { func wizardFieldLabel(f fieldDef, wz *WizardState) string { if f.label != "内容" { - if f.label == "最低成功率%" { - return "最低成功率" - } return f.label } switch wz.PromptMode { @@ -895,19 +832,25 @@ func wizardContextItems(step wizardStep) []ContextBarItem { } } -func wizardStatusText(wz *WizardState, focusLine, offset, end, total, visible int) string { - if total <= 0 { - return "暂无配置项" - } +func wizardStatusText(wz *WizardState, offset, end, scrollTotal, visible int) string { if wz.Step == wizardStep3 { - if total > visible { - return fmt.Sprintf("确认项 %d-%d/%d", offset+1, end, total) + if scrollTotal <= 0 { + return "暂无确认项" + } + if scrollTotal > visible { + return fmt.Sprintf("确认项 %d-%d/%d", offset+1, end, scrollTotal) } - return fmt.Sprintf("共 %d 项待确认", total) + return fmt.Sprintf("共 %d 项待确认", scrollTotal) } - current := clampInt(focusLine+1, 1, total) - if total > visible { - return fmt.Sprintf("当前字段 %d/%d · 内容较多时自动滚动", current, total) + var fieldTotal int + switch wz.Step { + case wizardStep1: + fieldTotal = len(step1Fields()) + case wizardStep2: + fieldTotal = len(step2Fields(wz.Turbo)) + } + if fieldTotal <= 0 { + return "暂无配置项" } - return fmt.Sprintf("当前字段 %d/%d", current, total) + return fmt.Sprintf("当前字段 %d/%d", wz.FieldIndex+1, fieldTotal) } From 35b995051d3f03079da44dd156cc29ce29b5af39 Mon Sep 17 00:00:00 2001 From: Alain Date: Sun, 17 May 2026 16:50:47 +0800 Subject: [PATCH 11/52] feat: Enhance request and response handling in client and server - Added RequestBody and ResponseBody fields to ResponseMetrics for detailed logging. - Updated AnthropicClient and OpenAIClient to populate new fields with raw request and response data. - Modified server event handling to include request and response details in metrics. - Implemented tests for event bus and request metrics mapping. - Improved UI to display request and response data in task and request detail views. - Refactored key bindings in task detail and dashboard pages for better user experience. --- internal/client/anthropic.go | 43 +- internal/client/client.go | 4 + internal/client/openai.go | 32 +- internal/server/run.go | 25 +- internal/server/server_test.go | 712 +++++++++++++++++++++++++++++++ internal/server/types.go | 10 +- internal/tui/model.go | 12 +- internal/tui/pages/contextbar.go | 7 +- internal/tui/pages/dashboard.go | 17 +- internal/tui/pages/reqdetail.go | 16 +- internal/tui/pages/taskdetail.go | 17 +- internal/tui/pages/turbodash.go | 2 +- 12 files changed, 838 insertions(+), 59 deletions(-) create mode 100644 internal/server/server_test.go diff --git a/internal/client/anthropic.go b/internal/client/anthropic.go index 235f23a..f778026 100644 --- a/internal/client/anthropic.go +++ b/internal/client/anthropic.go @@ -304,6 +304,7 @@ func (c *AnthropicClient) Request(prompt string, stream bool) (*ResponseMetrics, var inputTokens int var cachedInputTokens int var streamChunks []string // 用于记录所有流式数据块 + var rawResponseLines strings.Builder // 记录流式响应开始日志 if c.logger != nil && c.logger.IsEnabled() { @@ -320,6 +321,8 @@ func (c *AnthropicClient) Request(prompt string, stream bool) (*ResponseMetrics, for scanner.Scan() { line := scanner.Text() + rawResponseLines.WriteString(line) + rawResponseLines.WriteByte('\n') if strings.HasPrefix(line, "data: ") { data := strings.TrimPrefix(line, "data: ") if strings.TrimSpace(data) == "" { @@ -394,16 +397,18 @@ func (c *AnthropicClient) Request(prompt string, stream bool) (*ResponseMetrics, } return &ResponseMetrics{ - TimeToFirstToken: firstTokenTime, - TotalTime: totalTime, - DNSTime: dnsTime, - ConnectTime: connectTime, - TLSHandshakeTime: tlsTime, - TargetIP: targetIP, - PromptTokens: inputTokens, + TimeToFirstToken: firstTokenTime, + TotalTime: totalTime, + DNSTime: dnsTime, + ConnectTime: connectTime, + TLSHandshakeTime: tlsTime, + TargetIP: targetIP, + PromptTokens: inputTokens, CachedInputTokens: cachedInputTokens, - CompletionTokens: outputTokens, - ErrorMessage: "", + CompletionTokens: outputTokens, + RequestBody: string(reqBodyBytes), + ResponseBody: rawResponseLines.String(), + ErrorMessage: "", }, nil } else { // 非流式响应处理 @@ -495,16 +500,18 @@ func (c *AnthropicClient) Request(prompt string, stream bool) (*ResponseMetrics, } return &ResponseMetrics{ - TimeToFirstToken: totalTime, // 非流式模式下,所有token一次性返回,TTFT等于总时间 - TotalTime: totalTime, - DNSTime: dnsTime, - ConnectTime: connectTime, - TLSHandshakeTime: tlsTime, - TargetIP: targetIP, - PromptTokens: anthropicResp.Usage.InputTokens, + TimeToFirstToken: totalTime, // 非流式模式下,所有token一次性返回,TTFT等于总时间 + TotalTime: totalTime, + DNSTime: dnsTime, + ConnectTime: connectTime, + TLSHandshakeTime: tlsTime, + TargetIP: targetIP, + PromptTokens: anthropicResp.Usage.InputTokens, CachedInputTokens: anthropicResp.Usage.CacheReadInputTokens, - CompletionTokens: anthropicResp.Usage.OutputTokens, - ErrorMessage: "", + CompletionTokens: anthropicResp.Usage.OutputTokens, + RequestBody: string(reqBodyBytes), + ResponseBody: string(responseData), + ErrorMessage: "", }, nil } } diff --git a/internal/client/client.go b/internal/client/client.go index bbf9b9e..93b48f4 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -28,6 +28,10 @@ type ResponseMetrics struct { // 错误信息 ErrorMessage string // 错误信息(如果有) + + // 原始数据(供请求详情页展示和复制) + RequestBody string // 发送给 API 的原始 JSON 请求体 + ResponseBody string // API 返回的原始数据(非流式为 JSON,流式为所有 SSE 行拼接) } // ModelClient 定义统一的模型客户端接口 diff --git a/internal/client/openai.go b/internal/client/openai.go index 9f9bddb..841ea05 100644 --- a/internal/client/openai.go +++ b/internal/client/openai.go @@ -211,7 +211,7 @@ func (c *OpenAIClient) buildRequestBody(prompt string, stream bool) ([]byte, err return json.Marshal(reqBody) } -func (c *OpenAIClient) parseResponsesStream(resp *http.Response, t0 time.Time, dnsTime, connectTime, tlsTime time.Duration, targetIP string) (*ResponseMetrics, error) { +func (c *OpenAIClient) parseResponsesStream(resp *http.Response, t0 time.Time, dnsTime, connectTime, tlsTime time.Duration, targetIP string, requestBody []byte) (*ResponseMetrics, error) { scanner := bufio.NewScanner(resp.Body) firstTokenTime := time.Duration(0) gotFirst := false @@ -220,9 +220,12 @@ func (c *OpenAIClient) parseResponsesStream(resp *http.Response, t0 time.Time, d var cachedInputTokens int var thinkingTokens int var streamChunks []string + var rawResponseBody strings.Builder for scanner.Scan() { line := scanner.Text() + rawResponseBody.WriteString(line) + rawResponseBody.WriteByte('\n') if !strings.HasPrefix(line, "data: ") { continue } @@ -239,9 +242,11 @@ func (c *OpenAIClient) parseResponsesStream(resp *http.Response, t0 time.Time, d continue } - if !gotFirst && event.Delta != "" { - firstTokenTime = time.Since(t0) - gotFirst = true + if event.Delta != "" { + if !gotFirst { + firstTokenTime = time.Since(t0) + gotFirst = true + } } if event.Usage != nil { @@ -285,11 +290,13 @@ func (c *OpenAIClient) parseResponsesStream(resp *http.Response, t0 time.Time, d CachedInputTokens: cachedInputTokens, CompletionTokens: completionTokens, ThinkingTokens: thinkingTokens, + RequestBody: string(requestBody), + ResponseBody: rawResponseBody.String(), ErrorMessage: "", }, nil } -func (c *OpenAIClient) parseResponsesNonStream(responseData []byte, totalTime, dnsTime, connectTime, tlsTime time.Duration, targetIP string) (*ResponseMetrics, error) { +func (c *OpenAIClient) parseResponsesNonStream(responseData []byte, totalTime, dnsTime, connectTime, tlsTime time.Duration, targetIP string, requestBody []byte) (*ResponseMetrics, error) { var apiResp ResponsesAPIResponse if err := json.Unmarshal(responseData, &apiResp); err != nil { if c.logger != nil && c.logger.IsEnabled() { @@ -318,6 +325,8 @@ func (c *OpenAIClient) parseResponsesNonStream(responseData []byte, totalTime, d CachedInputTokens: extractCachedInputTokens(apiResp.Usage.InputTokensDetails), CompletionTokens: apiResp.Usage.OutputTokens, ThinkingTokens: extractThinkingTokens(apiResp.Usage.OutputTokensDetails), + RequestBody: string(requestBody), + ResponseBody: string(responseData), ErrorMessage: "", }, nil } @@ -531,7 +540,7 @@ func (c *OpenAIClient) Request(prompt string, stream bool) (*ResponseMetrics, er } if c.Provider == types.ProtocolOpenAIResponses { - return c.parseResponsesStream(resp, t0, dnsTime, connectTime, tlsTime, targetIP) + return c.parseResponsesStream(resp, t0, dnsTime, connectTime, tlsTime, targetIP, jsonData) } scanner := bufio.NewScanner(resp.Body) @@ -543,6 +552,7 @@ func (c *OpenAIClient) Request(prompt string, stream bool) (*ResponseMetrics, er var cachedInputTokens int var thinkingTokens int var streamChunks []string // 用于记录所有流式数据块 + var rawResponseLines strings.Builder // 记录流式响应开始日志 if c.logger != nil && c.logger.IsEnabled() { @@ -559,6 +569,8 @@ func (c *OpenAIClient) Request(prompt string, stream bool) (*ResponseMetrics, er for scanner.Scan() { line := scanner.Text() + rawResponseLines.WriteString(line) + rawResponseLines.WriteByte('\n') if strings.HasPrefix(line, "data: ") { data := strings.TrimPrefix(line, "data: ") if data == "[DONE]" { @@ -638,6 +650,8 @@ func (c *OpenAIClient) Request(prompt string, stream bool) (*ResponseMetrics, er CachedInputTokens: cachedInputTokens, CompletionTokens: completionTokens, ThinkingTokens: thinkingTokens, + RequestBody: string(jsonData), + ResponseBody: rawResponseLines.String(), ErrorMessage: "", }, nil } else { @@ -721,7 +735,7 @@ func (c *OpenAIClient) Request(prompt string, stream bool) (*ResponseMetrics, er } if c.Provider == types.ProtocolOpenAIResponses { - return c.parseResponsesNonStream(responseData, totalTime, dnsTime, connectTime, tlsTime, targetIP) + return c.parseResponsesNonStream(responseData, totalTime, dnsTime, connectTime, tlsTime, targetIP, jsonData) } var chatResp ChatCompletionResponse @@ -755,8 +769,10 @@ func (c *OpenAIClient) Request(prompt string, stream bool) (*ResponseMetrics, er CachedInputTokens: extractCachedInputTokens(chatResp.Usage.PromptTokensDetails), CompletionTokens: chatResp.Usage.CompletionTokens, ThinkingTokens: thinkingTokens, + RequestBody: string(jsonData), + ResponseBody: string(responseData), ErrorMessage: "", - }, nil + }, nil } } diff --git a/internal/server/run.go b/internal/server/run.go index e28c809..f1a172a 100644 --- a/internal/server/run.go +++ b/internal/server/run.go @@ -69,6 +69,8 @@ func mapRequestMetrics(m *client.ResponseMetrics, idx int, err error) *RequestMe if err != nil && rm.ErrorMessage == "" { rm.ErrorMessage = err.Error() } + rm.RequestBody = m.RequestBody + rm.ResponseBody = m.ResponseBody if m.TotalTime > 0 && m.CompletionTokens > 0 { rm.TPS = float64(m.CompletionTokens) / m.TotalTime.Seconds() @@ -146,6 +148,24 @@ func (s *serverImpl) runStandard(ar *activeRun, runID RunID, taskDef types.TaskD ar.rnr = rnr ar.mu.Unlock() + // 启动 500ms 进度快照 goroutine,定期向订阅者推送 EventProgressTick。 + stopTick := make(chan struct{}) + go func() { + ticker := time.NewTicker(500 * time.Millisecond) + defer ticker.Stop() + for { + select { + case <-ticker.C: + ar.mu.RLock() + snap := ar.snapshotState() + ar.mu.RUnlock() + s.bus.Publish(Event{RunID: runID, Kind: EventProgressTick, Payload: snap}) + case <-stopTick: + return + } + } + }() + reportData, err := rnr.RunWithCallback(func(metrics *client.ResponseMetrics, idx int, cbErr error) { rm := mapRequestMetrics(metrics, idx, cbErr) @@ -173,11 +193,14 @@ func (s *serverImpl) runStandard(ar *activeRun, runID RunID, taskDef types.TaskD if done > 0 { ar.state.SuccessRate = float64(successCount) / float64(done) * 100 } + snap := ar.snapshotState() ar.mu.Unlock() - s.bus.Publish(Event{RunID: runID, Kind: EventRequestDone, Payload: rm}) + s.bus.Publish(Event{RunID: runID, Kind: EventRequestDone, Payload: snap}) }) + close(stopTick) + if err != nil { s.failRun(ar, runID, taskDef, historyDir, err) return diff --git a/internal/server/server_test.go b/internal/server/server_test.go new file mode 100644 index 0000000..f5f946c --- /dev/null +++ b/internal/server/server_test.go @@ -0,0 +1,712 @@ +package server + +import ( + "errors" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/yinxulai/ait/internal/client" + "github.com/yinxulai/ait/internal/store" + "github.com/yinxulai/ait/internal/types" +) + +// ── test helpers ────────────────────────────────────────────────────────────── + +func newTestServer(t *testing.T) *serverImpl { + t.Helper() + dir := t.TempDir() + historyDir := filepath.Join(dir, "history") + if err := os.MkdirAll(historyDir, 0o755); err != nil { + t.Fatalf("mkdir history: %v", err) + } + ts := store.NewTaskStore(filepath.Join(dir, "tasks.json")) + if err := ts.Load(); err != nil { + t.Fatalf("load task store: %v", err) + } + return &serverImpl{ + taskStore: ts, + bus: newEventBus(), + activeRuns: make(map[RunID]*activeRun), + historyDir: historyDir, + } +} + +func makeTaskConfig(name string) TaskConfig { + return TaskConfig{ + Name: name, + Input: types.Input{ + Protocol: types.ProtocolOpenAICompletions, + EndpointURL: "http://localhost:19999", + Model: "test-model", + Concurrency: 1, + Count: 1, + PromptMode: "text", + PromptText: "hello", + }, + } +} + +// ── eventBus ────────────────────────────────────────────────────────────────── + +func TestEventBus_PublishDelivered(t *testing.T) { + bus := newEventBus() + rid := RunID("run_1") + ch, cancel := bus.Subscribe(rid) + defer cancel() + + want := Event{RunID: rid, Kind: EventRequestDone} + bus.Publish(want) + + select { + case got := <-ch: + if got != want { + t.Fatalf("got %v, want %v", got, want) + } + case <-time.After(time.Second): + t.Fatal("timeout waiting for event") + } +} + +func TestEventBus_MultipleSubscribers(t *testing.T) { + bus := newEventBus() + rid := RunID("run_multi") + const n = 3 + chs := make([]<-chan Event, n) + for i := range chs { + ch, cancel := bus.Subscribe(rid) + chs[i] = ch + defer cancel() + } + + ev := Event{RunID: rid, Kind: EventRunComplete} + bus.Publish(ev) + + for i, ch := range chs { + select { + case got := <-ch: + if got != ev { + t.Errorf("subscriber %d: got %v, want %v", i, got, ev) + } + case <-time.After(time.Second): + t.Errorf("subscriber %d: timeout", i) + } + } +} + +func TestEventBus_CancelClosesChannel(t *testing.T) { + bus := newEventBus() + rid := RunID("run_cancel") + ch, cancel := bus.Subscribe(rid) + cancel() + + select { + case _, ok := <-ch: + if ok { + t.Fatal("channel should be closed after cancel") + } + case <-time.After(time.Second): + t.Fatal("timeout: channel not closed after cancel") + } +} + +func TestEventBus_CloseRunClosesAllChannels(t *testing.T) { + bus := newEventBus() + rid := RunID("run_close") + ch1, _ := bus.Subscribe(rid) + ch2, _ := bus.Subscribe(rid) + + bus.CloseRun(rid) + + for i, ch := range []<-chan Event{ch1, ch2} { + select { + case _, ok := <-ch: + if ok { + t.Errorf("ch%d should be closed after CloseRun", i+1) + } + case <-time.After(time.Second): + t.Errorf("ch%d: timeout waiting for close", i+1) + } + } +} + +func TestEventBus_FullChannelDoesNotBlock(t *testing.T) { + bus := newEventBus() + rid := RunID("run_full") + // Subscribe but never drain the channel. + _, cancel := bus.Subscribe(rid) + defer cancel() + + done := make(chan struct{}) + go func() { + // Publish more events than the channel capacity (64) to verify non-blocking. + for i := 0; i < 100; i++ { + bus.Publish(Event{RunID: rid, Kind: EventRequestDone}) + } + close(done) + }() + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("Publish blocked on full subscriber channel") + } +} + +func TestEventBus_EventsOnlyDeliveredToMatchingRunID(t *testing.T) { + bus := newEventBus() + ch1, cancel1 := bus.Subscribe(RunID("run_a")) + ch2, cancel2 := bus.Subscribe(RunID("run_b")) + defer cancel1() + defer cancel2() + + bus.Publish(Event{RunID: "run_a", Kind: EventRequestDone}) + + select { + case <-ch1: + // expected + case <-time.After(time.Second): + t.Fatal("run_a subscriber should have received event") + } + + // run_b should NOT receive the event. + select { + case <-ch2: + t.Fatal("run_b subscriber should not receive event for run_a") + default: + } +} + +// ── mapRequestMetrics ───────────────────────────────────────────────────────── + +func TestMapRequestMetrics_NilMetricsNoError(t *testing.T) { + rm := mapRequestMetrics(nil, 3, nil) + if rm.Index != 3 { + t.Errorf("Index: got %d, want 3", rm.Index) + } + if rm.Success { + t.Error("expected Success=false") + } + if rm.ErrorMessage != "" { + t.Errorf("expected empty ErrorMessage, got %q", rm.ErrorMessage) + } +} + +func TestMapRequestMetrics_NilMetricsWithError(t *testing.T) { + err := errors.New("connection refused") + rm := mapRequestMetrics(nil, 0, err) + if rm.Success { + t.Error("expected Success=false") + } + if rm.ErrorMessage != err.Error() { + t.Errorf("ErrorMessage: got %q, want %q", rm.ErrorMessage, err.Error()) + } +} + +func TestMapRequestMetrics_SuccessFields(t *testing.T) { + m := &client.ResponseMetrics{ + TotalTime: 2 * time.Second, + TimeToFirstToken: 100 * time.Millisecond, + CompletionTokens: 100, + PromptTokens: 200, + CachedInputTokens: 50, + TargetIP: "1.2.3.4", + DNSTime: 5 * time.Millisecond, + ConnectTime: 10 * time.Millisecond, + TLSHandshakeTime: 15 * time.Millisecond, + } + rm := mapRequestMetrics(m, 5, nil) + + if !rm.Success { + t.Error("expected Success=true") + } + if rm.Index != 5 { + t.Errorf("Index: got %d, want 5", rm.Index) + } + // TPS = CompletionTokens / TotalTime.Seconds() = 100 / 2 = 50 + if rm.TPS != 50.0 { + t.Errorf("TPS: got %v, want 50", rm.TPS) + } + // CacheHitRate = CachedInputTokens / PromptTokens = 50 / 200 = 0.25 + if rm.CacheHitRate != 0.25 { + t.Errorf("CacheHitRate: got %v, want 0.25", rm.CacheHitRate) + } + if rm.TargetIP != "1.2.3.4" { + t.Errorf("TargetIP: got %q, want %q", rm.TargetIP, "1.2.3.4") + } + if rm.TTFT != 100*time.Millisecond { + t.Errorf("TTFT: got %v, want 100ms", rm.TTFT) + } + if rm.CompletionTokens != 100 { + t.Errorf("CompletionTokens: got %d, want 100", rm.CompletionTokens) + } + if rm.PromptTokens != 200 { + t.Errorf("PromptTokens: got %d, want 200", rm.PromptTokens) + } + if rm.CachedTokens != 50 { + t.Errorf("CachedTokens: got %d, want 50", rm.CachedTokens) + } +} + +func TestMapRequestMetrics_FailureFromErrorMessage(t *testing.T) { + m := &client.ResponseMetrics{ErrorMessage: "rate limit exceeded"} + rm := mapRequestMetrics(m, 0, nil) + if rm.Success { + t.Error("expected Success=false when ErrorMessage is set") + } + if rm.ErrorMessage != "rate limit exceeded" { + t.Errorf("ErrorMessage: got %q", rm.ErrorMessage) + } +} + +func TestMapRequestMetrics_ErrorOverridesEmptyMessage(t *testing.T) { + m := &client.ResponseMetrics{} + err := errors.New("transport error") + rm := mapRequestMetrics(m, 0, err) + if rm.Success { + t.Error("expected Success=false") + } + if rm.ErrorMessage != err.Error() { + t.Errorf("ErrorMessage: got %q, want %q", rm.ErrorMessage, err.Error()) + } +} + +func TestMapRequestMetrics_ZeroTotalTimeSkipsTPS(t *testing.T) { + m := &client.ResponseMetrics{CompletionTokens: 100} // TotalTime == 0 + rm := mapRequestMetrics(m, 0, nil) + if rm.TPS != 0 { + t.Errorf("expected TPS=0 when TotalTime=0, got %v", rm.TPS) + } +} + +func TestMapRequestMetrics_ZeroPromptTokensSkipsCacheHitRate(t *testing.T) { + m := &client.ResponseMetrics{CachedInputTokens: 10} // PromptTokens == 0 + rm := mapRequestMetrics(m, 0, nil) + if rm.CacheHitRate != 0 { + t.Errorf("expected CacheHitRate=0 when PromptTokens=0, got %v", rm.CacheHitRate) + } +} + +// ── snapshotState ───────────────────────────────────────────────────────────── + +func TestSnapshotState_DeepCopiesRequests(t *testing.T) { + original := &RequestMetrics{Index: 0, Success: true} + ar := &activeRun{ + state: &RunState{ + Requests: []*RequestMetrics{original}, + }, + } + snap := ar.snapshotState() + + // Mutate original slice — snapshot must remain unchanged. + ar.state.Requests[0] = &RequestMetrics{Index: 99} + if snap.Requests[0].Index != 0 { + t.Error("Requests slice was not deep-copied: snapshot reflects mutation of original") + } +} + +func TestSnapshotState_DeepCopiesLevels(t *testing.T) { + ar := &activeRun{ + state: &RunState{ + Levels: []types.TurboLevelResult{{Concurrency: 5}}, + }, + } + snap := ar.snapshotState() + + ar.state.Levels[0] = types.TurboLevelResult{Concurrency: 99} + if snap.Levels[0].Concurrency != 5 { + t.Error("Levels slice was not deep-copied: snapshot reflects mutation of original") + } +} + +func TestSnapshotState_EmptySlicesNotCopied(t *testing.T) { + ar := &activeRun{ + state: &RunState{ + RunID: "run_snap", + Status: RunStatusRunning, + }, + } + snap := ar.snapshotState() + if snap.RunID != "run_snap" { + t.Errorf("RunID: got %q, want %q", snap.RunID, "run_snap") + } + if snap.Requests != nil { + t.Error("expected nil Requests for empty state") + } +} + +// ── historyPath ─────────────────────────────────────────────────────────────── + +func TestHistoryPath(t *testing.T) { + got := historyPath("/data/history", "task-abc") + want := filepath.Join("/data/history", "task-abc.json") + if got != want { + t.Errorf("historyPath: got %q, want %q", got, want) + } +} + +// ── task CRUD ───────────────────────────────────────────────────────────────── + +func TestListTasks_Empty(t *testing.T) { + s := newTestServer(t) + tasks := s.ListTasks() + if len(tasks) != 0 { + t.Errorf("expected empty list, got %d tasks", len(tasks)) + } +} + +func TestCreateTask_ReturnsTaskWithID(t *testing.T) { + s := newTestServer(t) + got, err := s.CreateTask(makeTaskConfig("my-task")) + if err != nil { + t.Fatalf("CreateTask: %v", err) + } + if got.Name != "my-task" { + t.Errorf("Name: got %q, want %q", got.Name, "my-task") + } + if got.ID == "" { + t.Error("expected non-empty ID") + } +} + +func TestCreateTask_AppearsInList(t *testing.T) { + s := newTestServer(t) + s.CreateTask(makeTaskConfig("task-a")) + all := s.ListTasks() + if len(all) != 1 { + t.Errorf("expected 1 task, got %d", len(all)) + } + if all[0].Name != "task-a" { + t.Errorf("Name: got %q, want task-a", all[0].Name) + } +} + +func TestCreateTask_MultipleTasksAllListed(t *testing.T) { + s := newTestServer(t) + for _, name := range []string{"alpha", "beta", "gamma"} { + if _, err := s.CreateTask(makeTaskConfig(name)); err != nil { + t.Fatalf("CreateTask %q: %v", name, err) + } + } + if len(s.ListTasks()) != 3 { + t.Errorf("expected 3 tasks, got %d", len(s.ListTasks())) + } +} + +func TestGetTask_Found(t *testing.T) { + s := newTestServer(t) + created, _ := s.CreateTask(makeTaskConfig("task-get")) + got, ok := s.GetTask(created.ID) + if !ok { + t.Fatal("GetTask returned not found") + } + if got.ID != created.ID { + t.Errorf("ID mismatch: %q vs %q", got.ID, created.ID) + } +} + +func TestGetTask_NotFound(t *testing.T) { + s := newTestServer(t) + _, ok := s.GetTask("nonexistent") + if ok { + t.Fatal("expected not found for nonexistent ID") + } +} + +func TestUpdateTask_Success(t *testing.T) { + s := newTestServer(t) + created, _ := s.CreateTask(makeTaskConfig("original")) + updated, err := s.UpdateTask(created.ID, makeTaskConfig("renamed")) + if err != nil { + t.Fatalf("UpdateTask: %v", err) + } + if updated.Name != "renamed" { + t.Errorf("Name: got %q, want renamed", updated.Name) + } + // Verify persistence via GetTask. + fetched, ok := s.GetTask(created.ID) + if !ok || fetched.Name != "renamed" { + t.Errorf("GetTask after update: ok=%v name=%q", ok, fetched.Name) + } +} + +func TestUpdateTask_NotFound(t *testing.T) { + s := newTestServer(t) + _, err := s.UpdateTask("missing-id", makeTaskConfig("x")) + if err == nil { + t.Fatal("expected error for missing task") + } +} + +func TestDeleteTask_Success(t *testing.T) { + s := newTestServer(t) + created, _ := s.CreateTask(makeTaskConfig("to-delete")) + if err := s.DeleteTask(created.ID); err != nil { + t.Fatalf("DeleteTask: %v", err) + } + if _, ok := s.GetTask(created.ID); ok { + t.Error("task still accessible after delete") + } + if len(s.ListTasks()) != 0 { + t.Error("expected empty list after delete") + } +} + +func TestDeleteTask_NotFound(t *testing.T) { + s := newTestServer(t) + if err := s.DeleteTask("missing-id"); err == nil { + t.Fatal("expected error for missing task") + } +} + +func TestCopyTask_CreatesNewTask(t *testing.T) { + s := newTestServer(t) + original, _ := s.CreateTask(makeTaskConfig("original")) + copied, err := s.CopyTask(original.ID) + if err != nil { + t.Fatalf("CopyTask: %v", err) + } + if copied.ID == original.ID { + t.Error("copy should have a new ID") + } + if copied.Name != "original (copy)" { + t.Errorf("Name: got %q, want %q", copied.Name, "original (copy)") + } + if len(s.ListTasks()) != 2 { + t.Errorf("expected 2 tasks after copy, got %d", len(s.ListTasks())) + } +} + +func TestCopyTask_NotFound(t *testing.T) { + s := newTestServer(t) + _, err := s.CopyTask("missing-id") + if err == nil { + t.Fatal("expected error for missing task") + } +} + +// ── run management ──────────────────────────────────────────────────────────── + +func TestStartRun_TaskNotFound(t *testing.T) { + s := newTestServer(t) + _, err := s.StartRun("no-such-task") + if err == nil { + t.Fatal("expected error for missing task") + } +} + +func TestStartRun_ReturnsRunIDAndRegistersActiveRun(t *testing.T) { + s := newTestServer(t) + task, _ := s.CreateTask(makeTaskConfig("run-task")) + runID, err := s.StartRun(task.ID) + if err != nil { + t.Fatalf("StartRun: %v", err) + } + if runID == "" { + t.Fatal("expected non-empty RunID") + } + state, ok := s.GetRunState(runID) + if !ok { + t.Fatal("GetRunState: run not found immediately after StartRun") + } + if state.TaskID != task.ID { + t.Errorf("TaskID: got %q, want %q", state.TaskID, task.ID) + } + // Initial status should be running (goroutine may not have progressed yet). + if state.Status != RunStatusRunning { + t.Errorf("Status: got %q, want %q", state.Status, RunStatusRunning) + } +} + +func TestGetRunState_NotFound(t *testing.T) { + s := newTestServer(t) + _, ok := s.GetRunState("run_nonexistent") + if ok { + t.Fatal("expected not found for unknown RunID") + } +} + +func TestStopRun_NotFound(t *testing.T) { + s := newTestServer(t) + err := s.StopRun("run_nonexistent") + if err == nil { + t.Fatal("expected error for unknown RunID") + } +} + +func TestStopRun_ActiveRunNoRunner(t *testing.T) { + s := newTestServer(t) + // Inject an activeRun with no runner/engine (neither rnr nor turboEngine). + runID := RunID("run_no_engine") + s.mu.Lock() + s.activeRuns[runID] = &activeRun{ + state: &RunState{RunID: runID, Status: RunStatusRunning}, + } + s.mu.Unlock() + + // Should not panic; both rnr and engine are nil — stop is a no-op. + if err := s.StopRun(runID); err != nil { + t.Fatalf("StopRun: unexpected error: %v", err) + } +} + +func TestGetHistory_EmptyForNewTask(t *testing.T) { + s := newTestServer(t) + task, _ := s.CreateTask(makeTaskConfig("hist-task")) + history, err := s.GetHistory(task.ID, 0) + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 0 { + t.Errorf("expected empty history, got %d entries", len(history)) + } +} + +func TestGetHistory_PersistsAfterRun(t *testing.T) { + s := newTestServer(t) + task, _ := s.CreateTask(makeTaskConfig("persist-task")) + + summary := types.TaskRunSummary{ + RunID: "run_test", + TaskID: task.ID, + Mode: "standard", + Status: string(RunStatusCompleted), + StartedAt: time.Now().Add(-time.Second), + FinishedAt: time.Now(), + } + s.persistRunResult(task.ID, s.historyDir, summary) + + history, err := s.GetHistory(task.ID, 0) + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 1 { + t.Fatalf("expected 1 history entry, got %d", len(history)) + } + if history[0].RunID != "run_test" { + t.Errorf("RunID: got %q, want run_test", history[0].RunID) + } +} + +func TestGetHistory_LimitRespected(t *testing.T) { + s := newTestServer(t) + task, _ := s.CreateTask(makeTaskConfig("limit-task")) + + for i := 0; i < 5; i++ { + s.persistRunResult(task.ID, s.historyDir, types.TaskRunSummary{ + RunID: "run_" + string(rune('0'+i)), + TaskID: task.ID, + StartedAt: time.Now(), + FinishedAt: time.Now(), + }) + } + + history, err := s.GetHistory(task.ID, 3) + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 3 { + t.Errorf("expected 3 entries with limit=3, got %d", len(history)) + } +} + +// ── GenerateReport ──────────────────────────────────────────────────────────── + +func TestGenerateReport_RunNotFound(t *testing.T) { + s := newTestServer(t) + _, err := s.GenerateReport("run_missing", ReportFormatJSON) + if err == nil { + t.Fatal("expected error for missing run") + } +} + +func TestGenerateReport_StillRunning(t *testing.T) { + s := newTestServer(t) + runID := RunID("run_in_progress") + s.mu.Lock() + s.activeRuns[runID] = &activeRun{ + state: &RunState{RunID: runID, Status: RunStatusRunning, Mode: "standard"}, + } + s.mu.Unlock() + + _, err := s.GenerateReport(runID, ReportFormatJSON) + if err == nil { + t.Fatal("expected error for in-progress run") + } + if !strings.Contains(err.Error(), "in progress") { + t.Errorf("unexpected error: %v", err) + } +} + +func TestGenerateReport_TurboNotSupported(t *testing.T) { + s := newTestServer(t) + runID := RunID("run_turbo") + s.mu.Lock() + s.activeRuns[runID] = &activeRun{ + state: &RunState{RunID: runID, Status: RunStatusCompleted, Mode: "turbo"}, + } + s.mu.Unlock() + + _, err := s.GenerateReport(runID, ReportFormatJSON) + if err == nil { + t.Fatal("expected error for turbo run") + } + if !strings.Contains(err.Error(), "turbo") { + t.Errorf("unexpected error message: %v", err) + } +} + +func TestGenerateReport_NoResultData(t *testing.T) { + s := newTestServer(t) + runID := RunID("run_no_result") + s.mu.Lock() + s.activeRuns[runID] = &activeRun{ + state: &RunState{RunID: runID, Status: RunStatusFailed, Mode: "standard"}, + } + s.mu.Unlock() + + _, err := s.GenerateReport(runID, ReportFormatJSON) + if err == nil { + t.Fatal("expected error for run with no result data") + } +} + +// ── Subscribe ───────────────────────────────────────────────────────────────── + +func TestSubscribe_DelegatesEventBus(t *testing.T) { + s := newTestServer(t) + runID := RunID("run_sub") + ch, cancel := s.Subscribe(runID) + defer cancel() + + ev := Event{RunID: runID, Kind: EventRunComplete} + s.bus.Publish(ev) + + select { + case got := <-ch: + if got != ev { + t.Fatalf("got %v, want %v", got, ev) + } + case <-time.After(time.Second): + t.Fatal("timeout waiting for event via Subscribe") + } +} + +func TestSubscribe_ChannelClosedAfterCloseRun(t *testing.T) { + s := newTestServer(t) + runID := RunID("run_lifecycle") + ch, _ := s.Subscribe(runID) + + s.bus.CloseRun(runID) + + select { + case _, ok := <-ch: + if ok { + t.Fatal("channel should be closed after CloseRun") + } + case <-time.After(time.Second): + t.Fatal("timeout: channel not closed after CloseRun") + } +} diff --git a/internal/server/types.go b/internal/server/types.go index 09564e8..b3d85b5 100644 --- a/internal/server/types.go +++ b/internal/server/types.go @@ -89,9 +89,9 @@ type RequestMetrics struct { TLSTime time.Duration TargetIP string ErrorMessage string - // 以下字段当前为空,待 client.ResponseMetrics 支持后填充 - PromptText string - ResponseText string + // 原始请求/响应数据(供请求详情页展示和复制) + RequestBody string + ResponseBody string } // EventKind 事件类型枚举。 @@ -111,8 +111,8 @@ const ( ) // Event 是推送给 TUI 层的通知。Payload 类型随 Kind 不同: -// - EventRequestDone → *RequestMetrics -// - EventProgressTick → *RunState(快照) +// - EventRequestDone → *RunState(含最新请求结果的完整快照) +// - EventProgressTick → *RunState(定时聚合快照) // - EventLevelDone → types.TurboLevelResult // - EventRunComplete → *RunState(最终快照) // - EventRunFailed → error diff --git a/internal/tui/model.go b/internal/tui/model.go index 52d18fc..7fc43e3 100644 --- a/internal/tui/model.go +++ b/internal/tui/model.go @@ -333,11 +333,7 @@ func (m *Model) handleServerEvent(msg ServerEventMsg) (tea.Model, tea.Cmd) { if m.taskList != nil { delete(m.taskList.ActiveRuns, taskID) } - task := m.findTask(taskID) - if task != nil { - m.detail = pages.NewTaskDetailState(*task) - } - m.view = viewTaskDetail + // 在后台刷新任务列表和历史,不自动跳转页面;用户可按 b/Esc 返回 return m, tea.Batch( m.client.LoadTasksCmd(), m.client.LoadHistoryCmd(taskID, 10), @@ -361,11 +357,7 @@ func (m *Model) handleServerEvent(msg ServerEventMsg) (tea.Model, tea.Cmd) { if m.taskList != nil { delete(m.taskList.ActiveRuns, taskID) } - task := m.findTask(taskID) - if task != nil { - m.detail = pages.NewTaskDetailState(*task) - } - m.view = viewTaskDetail + // 在后台刷新任务列表和历史,不自动跳转页面;用户可按 b/Esc 返回 return m, tea.Batch( m.client.LoadTasksCmd(), m.client.LoadHistoryCmd(taskID, 10), diff --git a/internal/tui/pages/contextbar.go b/internal/tui/pages/contextbar.go index eb45bf7..0ceb7e1 100644 --- a/internal/tui/pages/contextbar.go +++ b/internal/tui/pages/contextbar.go @@ -31,7 +31,7 @@ func CtxBar_TaskList_Running() []ContextBarItem { // CtxBar_TaskDetail_NoHistory 任务详情页,无运行记录时。 func CtxBar_TaskDetail_NoHistory() []ContextBarItem { return []ContextBarItem{ - {Key: "Enter/r", Desc: "运行"}, + {Key: "r", Desc: "运行"}, {Key: "e", Desc: "编辑"}, {Key: "y", Desc: "复制"}, {Key: "d", Desc: "删除"}, @@ -42,8 +42,9 @@ func CtxBar_TaskDetail_NoHistory() []ContextBarItem { func CtxBar_TaskDetail_HasHistory() []ContextBarItem { return []ContextBarItem{ {Key: "↑↓", Desc: "选择记录"}, - {Key: "r", Desc: "导出 JSON 报告"}, - {Key: "Enter", Desc: "再次运行"}, + {Key: "Enter", Desc: "展开/折叠详情"}, + {Key: "r", Desc: "再次运行"}, + {Key: "g", Desc: "导出 JSON 报告"}, {Key: "e", Desc: "编辑"}, {Key: "y", Desc: "复制任务"}, {Key: "d", Desc: "删除"}, diff --git a/internal/tui/pages/dashboard.go b/internal/tui/pages/dashboard.go index 813afea..72745d3 100644 --- a/internal/tui/pages/dashboard.go +++ b/internal/tui/pages/dashboard.go @@ -173,7 +173,7 @@ func RenderDashboard(d *DashboardState, taskName string, st Styles, width, heigh } l := PageLayout{ CtxItems: cbItems, - FooterParts: []string{"[q] 退出"}, + FooterParts: []string{"[b/Esc] 返回列表", "[q] 退出"}, } // ── 计算高度 ── @@ -320,6 +320,21 @@ func buildRequestList(d *DashboardState, rs *server.RunState, st Styles, width, r := reqs[i] isSel := i == d.ReqSel + if r == nil { + // 该请求尚未开始,渲染为等待中 + marker := selectionMarker(isSel) + rowContent := padRight(marker, markW) + + padRight(fmt.Sprintf("#%d", i+1), idW) + + padRight(st.Muted.Render("…"), statW) + + padRight(st.Muted.Render("等待中"), timeW) + + padRight("─", ttftW) + + padRight("─", cacheW) + + padRight("─", tokW) + + "─" + lines = append(lines, renderTableRow(st, width, isSel, rowContent)) + continue + } + statusText := "✓" if !r.Success { statusText = "✗" diff --git a/internal/tui/pages/reqdetail.go b/internal/tui/pages/reqdetail.go index 62efab4..15cc0e6 100644 --- a/internal/tui/pages/reqdetail.go +++ b/internal/tui/pages/reqdetail.go @@ -195,16 +195,16 @@ func buildReqNetworkPanel(r *server.RequestMetrics, st Styles, maxH, width int) return strings.Join(lines[:maxH], "\n") } -// buildInputSection 构建输入 (Prompt) 区域。 +// buildInputSection 构建输入 (请求体) 区域。 func buildInputSection(r *server.RequestMetrics, st Styles, width, maxH int) string { var lines []string - lines = append(lines, " "+st.SectionHead.Render("输入 (Prompt)")) + lines = append(lines, " "+st.SectionHead.Render("请求体 (Request Body)")) lines = append(lines, " "+dividerLine(st, width-2)) - if r.PromptText == "" { + if r.RequestBody == "" { lines = append(lines, " "+st.Muted.Render("(未记录)")) } else { - for _, l := range wrapText(r.PromptText, width-3) { + for _, l := range wrapText(r.RequestBody, width-3) { if len(lines) >= maxH-1 { break } @@ -218,16 +218,16 @@ func buildInputSection(r *server.RequestMetrics, st Styles, width, maxH int) str return strings.Join(lines[:maxH], "\n") } -// buildOutputSection 构建输出 (Response) 区域。 +// buildOutputSection 构建输出 (响应体) 区域。 func buildOutputSection(r *server.RequestMetrics, scrollY int, st Styles, width, maxH int) string { var lines []string - lines = append(lines, " "+st.SectionHead.Render("输出 (Response)")) + lines = append(lines, " "+st.SectionHead.Render("响应体 (Response Body)")) lines = append(lines, " "+dividerLine(st, width-2)) - if r.ResponseText == "" { + if r.ResponseBody == "" { lines = append(lines, " "+st.Muted.Render("(未记录)")) } else { - allLines := wrapText(r.ResponseText, width-3) + allLines := wrapText(r.ResponseBody, width-3) if scrollY >= len(allLines) { scrollY = len(allLines) - 1 } diff --git a/internal/tui/pages/taskdetail.go b/internal/tui/pages/taskdetail.go index 5438261..88f99eb 100644 --- a/internal/tui/pages/taskdetail.go +++ b/internal/tui/pages/taskdetail.go @@ -34,27 +34,33 @@ func HandleTaskDetailKey(s *TaskDetailState, msg tea.KeyMsg, client Client) (*Ta case "up", "k": if s.HistorySel > 0 { s.HistorySel-- + s.LatestExpanded = false } case "down", "j": if s.HistorySel < len(s.History)-1 { s.HistorySel++ + s.LatestExpanded = false + } + + case "enter": + if len(s.History) > 0 { + s.LatestExpanded = !s.LatestExpanded } case "left", "esc", "b": nav = NavAction{To: NavTaskList} - case "enter": + case "r": return s, client.StartRunCmd(s.Task.ID), nav - case "r": + case "g": if s.HistorySel >= 0 && s.HistorySel < len(s.History) { runID := strings.TrimSpace(s.History[s.HistorySel].RunID) if runID != "" { return s, client.GenerateReportCmd(server.RunID(runID), server.ReportFormatJSON), nav } } - return s, client.StartRunCmd(s.Task.ID), nav case "e": t := s.Task @@ -184,7 +190,10 @@ func buildTaskDetailContent(s *TaskDetailState, st Styles, t types.TaskDefinitio if len(s.History) == 0 { rightLines = append(rightLines, padRight(" "+st.Muted.Render("暂无运行记录"), rightW)) } else { - detailLines := buildTaskHistoryDetailLines(s, st, rightW) + var detailLines []string + if s.LatestExpanded { + detailLines = buildTaskHistoryDetailLines(s, st, rightW) + } tableMaxH := maxH - len(detailLines) if tableMaxH < 5 { allowedDetail := maxInt(0, maxH-5) diff --git a/internal/tui/pages/turbodash.go b/internal/tui/pages/turbodash.go index eb12fbe..07e3a80 100644 --- a/internal/tui/pages/turbodash.go +++ b/internal/tui/pages/turbodash.go @@ -147,7 +147,7 @@ func RenderTurboDash(d *TurboDashState, taskName string, st Styles, width, heigh } l := PageLayout{ CtxItems: cbItems, - FooterParts: []string{"[q] 退出"}, + FooterParts: []string{"[b/Esc] 返回列表", "[q] 退出"}, } // ── 计算高度 ── From 53c779671a61fbac65eab217c26d3a8546d461e3 Mon Sep 17 00:00:00 2001 From: Alain Date: Sun, 17 May 2026 17:22:56 +0800 Subject: [PATCH 12/52] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81=E7=B3=BB?= =?UTF-8?q?=E7=BB=9F=E6=8F=90=E7=A4=BA=E5=92=8C=E7=94=A8=E6=88=B7=E6=8F=90?= =?UTF-8?q?=E7=A4=BA=E7=9A=84=E8=AF=B7=E6=B1=82=E5=A4=84=E7=90=86=EF=BC=8C?= =?UTF-8?q?=E5=A2=9E=E5=BC=BA=E8=AF=B7=E6=B1=82=E6=9E=84=E5=BB=BA=E9=80=BB?= =?UTF-8?q?=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/client/anthropic.go | 11 +++++-- internal/client/client.go | 3 +- internal/client/openai.go | 37 ++++++++++++++-------- internal/prompt/prompt.go | 53 +++++++++++++++++++++++++++----- internal/runner/runner.go | 22 ++++++++----- internal/server/run.go | 38 ++++++++++++++++++++--- internal/tui/client.go | 11 +++++++ internal/tui/messages.go | 3 +- internal/tui/model.go | 19 +++++++++--- internal/tui/pages/contextbar.go | 2 +- internal/tui/pages/nav.go | 2 ++ internal/tui/pages/taskdetail.go | 7 +++-- internal/types/types.go | 1 + 13 files changed, 163 insertions(+), 46 deletions(-) diff --git a/internal/client/anthropic.go b/internal/client/anthropic.go index f778026..8e5f041 100644 --- a/internal/client/anthropic.go +++ b/internal/client/anthropic.go @@ -106,10 +106,10 @@ func (c *AnthropicClient) SetLogger(l *logger.Logger) { } // Request 发送 Anthropic 协议请求(支持流式和非流式) -func (c *AnthropicClient) Request(prompt string, stream bool) (*ResponseMetrics, error) { +func (c *AnthropicClient) Request(systemPrompt, userPrompt string, stream bool) (*ResponseMetrics, error) { // 记录请求开始日志 if c.logger != nil && c.logger.IsEnabled() { - c.logger.LogTestStart(c.Model, prompt, map[string]interface{}{ + c.logger.LogTestStart(c.Model, userPrompt, map[string]interface{}{ "stream": stream, "protocol": c.Provider, "endpoint_url": c.EndpointURL, @@ -122,12 +122,17 @@ func (c *AnthropicClient) Request(prompt string, stream bool) (*ResponseMetrics, "messages": []map[string]interface{}{ { "role": "user", - "content": prompt, + "content": userPrompt, }, }, "stream": stream, } + // 如果有 system prompt,添加顶层 system 字段(Anthropic API 规范) + if systemPrompt != "" { + requestBody["system"] = systemPrompt + } + // 如果启用了 thinking 模式,添加 thinking 配置 if c.Thinking { requestBody["thinking"] = map[string]interface{}{ diff --git a/internal/client/client.go b/internal/client/client.go index 93b48f4..9bada71 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -36,7 +36,8 @@ type ResponseMetrics struct { // ModelClient 定义统一的模型客户端接口 type ModelClient interface { - Request(prompt string, stream bool) (*ResponseMetrics, error) + // Request 发送请求。systemPrompt 为空时行为与原来相同(不添加 system 消息)。 + Request(systemPrompt, userPrompt string, stream bool) (*ResponseMetrics, error) GetProtocol() string GetModel() string SetLogger(logger *logger.Logger) // 设置日志记录器 diff --git a/internal/client/openai.go b/internal/client/openai.go index 841ea05..b06e455 100644 --- a/internal/client/openai.go +++ b/internal/client/openai.go @@ -172,11 +172,15 @@ func extractCachedInputTokens(details *PromptTokensDetails) int { return details.CachedTokens } -func (c *OpenAIClient) buildRequestBody(prompt string, stream bool) ([]byte, error) { +func (c *OpenAIClient) buildRequestBody(systemPrompt, userPrompt string, stream bool) ([]byte, error) { if c.Provider == types.ProtocolOpenAIResponses { + input := userPrompt + if systemPrompt != "" { + input = systemPrompt + "\n\n" + userPrompt + } reqBody := ResponsesAPIRequest{ Model: c.Model, - Input: prompt, + Input: input, Stream: stream, } if c.Thinking { @@ -185,15 +189,22 @@ func (c *OpenAIClient) buildRequestBody(prompt string, stream bool) ([]byte, err return json.Marshal(reqBody) } + var messages []ChatCompletionMessage + if systemPrompt != "" { + messages = append(messages, ChatCompletionMessage{ + Role: "system", + Content: systemPrompt, + }) + } + messages = append(messages, ChatCompletionMessage{ + Role: "user", + Content: userPrompt, + }) + reqBody := ChatCompletionRequest{ - Model: c.Model, - Messages: []ChatCompletionMessage{ - { - Role: "user", - Content: prompt, - }, - }, - Stream: stream, + Model: c.Model, + Messages: messages, + Stream: stream, } if stream { @@ -379,17 +390,17 @@ func (c *OpenAIClient) SetLogger(l *logger.Logger) { } // Request 发送 OpenAI 协议请求(支持流式和非流式) -func (c *OpenAIClient) Request(prompt string, stream bool) (*ResponseMetrics, error) { +func (c *OpenAIClient) Request(systemPrompt, userPrompt string, stream bool) (*ResponseMetrics, error) { // 记录请求开始日志 if c.logger != nil && c.logger.IsEnabled() { - c.logger.LogTestStart(c.Model, prompt, map[string]interface{}{ + c.logger.LogTestStart(c.Model, userPrompt, map[string]interface{}{ "stream": stream, "protocol": c.Provider, "endpoint_url": c.endpointURL, }) } - jsonData, err := c.buildRequestBody(prompt, stream) + jsonData, err := c.buildRequestBody(systemPrompt, userPrompt, stream) if err != nil { // 记录错误日志 if c.logger != nil && c.logger.IsEnabled() { diff --git a/internal/prompt/prompt.go b/internal/prompt/prompt.go index 41062f9..9761f7d 100644 --- a/internal/prompt/prompt.go +++ b/internal/prompt/prompt.go @@ -16,6 +16,7 @@ type PromptSource struct { IsFile bool // 是否来自文件 FilePaths []string // 文件路径列表 Contents []string // prompt内容列表(仅用于非文件内容) + SystemContent string // 固定的系统消息内容(仅 generated 模式使用,用于触发前缀缓存) DisplayText string // 用于显示的文本 ShouldTruncate bool // 是否需要截断显示(对于已经包含长度信息的内容,不需要再次处理) } @@ -98,6 +99,12 @@ func loadMultipleFiles(pattern string) (*PromptSource, error) { }, nil } +// GetSystemContent 返回系统消息内容(固定的大段上下文,用于前缀缓存)。 +// 非 generated 模式返回空字符串,不影响原有请求结构。 +func (ps *PromptSource) GetSystemContent() string { + return ps.SystemContent +} + // GetRandomContent 随机获取一个prompt内容 func (ps *PromptSource) GetRandomContent() string { // 如果不是文件源,直接返回内容 @@ -144,10 +151,14 @@ func (ps *PromptSource) GetRandomContent() string { func (ps *PromptSource) GetContentByIndex(index int) string { // 如果不是文件源,直接返回内容 if !ps.IsFile { - if index < 0 || index >= len(ps.Contents) { + if len(ps.Contents) == 0 { return ps.GetRandomContent() } - return ps.Contents[index] + if index < 0 { + return ps.GetRandomContent() + } + // 用取模循环,确保多个请求在有限 Contents 上均匀分布 + return ps.Contents[index%len(ps.Contents)] } // 文件源:根据索引读取对应文件 @@ -261,20 +272,46 @@ func GeneratePromptByLength(length int) string { return builder.String() } -// LoadPromptByLength 创建指定长度的 PromptSource +// LoadPromptByLength 创建指定长度的 PromptSource。 +// +// 为了让测试中部分请求满足前缀缓存条件(Prefix Cache),内容被拆分为两部分: +// - SystemContent(约 90% 长度):固定不变的大段上下文,作为 system 消息发送; +// 同一批次所有请求共享相同的 system 消息,API 侧命中前缀缓存后可大幅降低延迟。 +// - Contents(user 消息候选列表):多条短问题,每个请求按 index 取模轮流使用, +// 既保证请求内容有差异,又确保 system 前缀不变以触发缓存。 func LoadPromptByLength(length int) (*PromptSource, error) { if length <= 0 { return nil, fmt.Errorf("prompt 长度必须大于 0") } - content := GeneratePromptByLength(length) - actualLength := utf8.RuneCountInString(content) + // 90% 作为 system 消息(固定,供缓存命中) + systemLen := length * 9 / 10 + if systemLen < 1 { + systemLen = 1 + } + systemContent := GeneratePromptByLength(systemLen) + actualSystemLen := utf8.RuneCountInString(systemContent) + + // 短而多样的 user 消息,各请求轮流使用(保证差异 + 共享 system 前缀) + userQuestions := []string{ + "请帮我总结一下上述内容的核心要点。", + "根据以上信息,有什么值得特别关注的地方?", + "上述内容中最重要的信息是什么?", + "请对以上内容进行简短分析。", + "上述内容的主要主题是什么,请概括。", + "从以上内容中能得出哪些结论?", + "以上内容有哪些值得深入探讨的点?", + "请提炼上述内容的关键信息。", + "对以上内容你有什么看法?", + "上述内容对实际应用有什么启示?", + } return &PromptSource{ IsFile: false, FilePaths: nil, - Contents: []string{content}, - DisplayText: fmt.Sprintf("生成内容 (长度: %d 字符)", actualLength), - ShouldTruncate: false, // 已经包含长度信息,不需要再次截断处理 + Contents: userQuestions, + SystemContent: systemContent, + DisplayText: fmt.Sprintf("生成内容 (系统消息: %d 字符,轮换用户问题 x%d)", actualSystemLen, len(userQuestions)), + ShouldTruncate: false, }, nil } diff --git a/internal/runner/runner.go b/internal/runner/runner.go index 47f20c9..01d209c 100644 --- a/internal/runner/runner.go +++ b/internal/runner/runner.go @@ -87,9 +87,10 @@ func (r *Runner) Run() (*types.ReportData, error) { defer func() { <-ch }() // 获取当前请求使用的prompt - currentPrompt := r.input.PromptSource.GetRandomContent() + systemPrompt := r.input.PromptSource.GetSystemContent() + userPrompt := r.input.PromptSource.GetContentByIndex(idx) - metrics, err := r.client.Request(currentPrompt, r.input.Stream) + metrics, err := r.client.Request(systemPrompt, userPrompt, r.input.Stream) if err != nil { // 即使有错误,也尝试保存 metrics(如果有的话) if metrics != nil { @@ -129,8 +130,9 @@ func (r *Runner) RunWithCallback(cb RequestDoneCallback) (*types.ReportData, err defer wg.Done() defer func() { <-ch }() - currentPrompt := r.input.PromptSource.GetRandomContent() - metrics, err := r.client.Request(currentPrompt, r.input.Stream) + systemPrompt := r.input.PromptSource.GetSystemContent() + userPrompt := r.input.PromptSource.GetContentByIndex(idx) + metrics, err := r.client.Request(systemPrompt, userPrompt, r.input.Stream) if metrics != nil { results[idx] = metrics } @@ -231,9 +233,10 @@ func (r *Runner) RunWithProgress(progressCallback func(types.StatsData)) (*types defer func() { <-ch }() // 获取当前请求使用的prompt - currentPrompt := r.input.PromptSource.GetRandomContent() + systemPrompt := r.input.PromptSource.GetSystemContent() + userPrompt := r.input.PromptSource.GetContentByIndex(idx) - metrics, err := r.client.Request(currentPrompt, r.input.Stream) + metrics, err := r.client.Request(systemPrompt, userPrompt, r.input.Stream) if err != nil { ttftsMutex.Lock() errorMessages = append(errorMessages, err.Error()) @@ -540,9 +543,12 @@ func (r *Runner) calculateResult(results []*client.ResponseMetrics, totalTime ti } } + // successCount 基于真正成功的请求(有输出 token 且无错误) + // validCount 可能是 successCount 的 fallback 集,仅用于计算平均指标,不参与成功率 + successCount := len(successResults) validCount := len(validResults) - errorRate := float64(requestCount-validCount) / float64(requestCount) * 100 - successRate := float64(validCount) / float64(requestCount) * 100 + errorRate := float64(requestCount-successCount) / float64(requestCount) * 100 + successRate := float64(successCount) / float64(requestCount) * 100 resolvedEndpoint := r.input.ResolvedEndpointURL() if validCount == 0 { diff --git a/internal/server/run.go b/internal/server/run.go index f1a172a..cf634b5 100644 --- a/internal/server/run.go +++ b/internal/server/run.go @@ -86,6 +86,17 @@ func historyPath(historyDir, taskID string) string { return filepath.Join(historyDir, taskID+".json") } +// runStatePath 返回指定运行的完整状态快照文件路径(用于历史回放)。 +func runStatePath(historyDir string, runID RunID) string { + return filepath.Join(historyDir, "runs", string(runID)+".json") +} + +// persistRunState 将完整 RunState 快照写入磁盘,供历史回放使用。 +func persistRunState(historyDir string, snap *RunState) { + st := store.NewJSONStore[*RunState](runStatePath(historyDir, snap.RunID)) + _ = st.Save(snap) // 失败不影响主流程 +} + // StartRun 启动一次新的运行,立即返回 RunID。 func (s *serverImpl) StartRun(taskID string) (RunID, error) { s.mu.RLock() @@ -246,6 +257,9 @@ func (s *serverImpl) completeStandardRun(ar *activeRun, runID RunID, taskDef typ s.bus.Publish(Event{RunID: runID, Kind: EventRunComplete, Payload: snap}) s.bus.CloseRun(runID) + // 将完整运行状态持久化到磁盘,供历史详情页回放 + persistRunState(historyDir, snap) + summary := types.TaskRunSummary{ RunID: string(runID), TaskID: taskDef.ID, @@ -282,6 +296,9 @@ func (s *serverImpl) completeTurboRun(ar *activeRun, runID RunID, taskDef types. s.bus.Publish(Event{RunID: runID, Kind: EventRunComplete, Payload: snap}) s.bus.CloseRun(runID) + // 将完整运行状态持久化到磁盘,供历史详情页回放 + persistRunState(historyDir, snap) + var maxStable int var peakTPS float64 if result != nil { @@ -319,6 +336,9 @@ func (s *serverImpl) failRun(ar *activeRun, runID RunID, taskDef types.TaskDefin s.bus.Publish(Event{RunID: runID, Kind: EventRunFailed, Payload: runErr}) s.bus.CloseRun(runID) + // 将完整运行状态持久化到磁盘,供历史详情页回放 + persistRunState(historyDir, snap) + summary := types.TaskRunSummary{ RunID: string(runID), TaskID: taskDef.ID, @@ -377,18 +397,26 @@ func (s *serverImpl) StopRun(runID RunID) error { } // GetRunState 返回指定运行的当前状态快照。 +// 先查内存中的 activeRuns;若不存在,再尝试从磁盘加载持久化的快照(历史回放)。 func (s *serverImpl) GetRunState(runID RunID) (*RunState, bool) { s.mu.RLock() ar, ok := s.activeRuns[runID] + historyDir := s.historyDir s.mu.RUnlock() - if !ok { - return nil, false + if ok { + ar.mu.RLock() + snap := ar.snapshotState() + ar.mu.RUnlock() + return snap, true } - ar.mu.RLock() - snap := ar.snapshotState() - ar.mu.RUnlock() + // 不在内存中,尝试从磁盘加载持久化的 RunState 快照 + st := store.NewJSONStore[*RunState](runStatePath(historyDir, runID)) + snap, err := st.Load() + if err != nil || snap == nil { + return nil, false + } return snap, true } diff --git a/internal/tui/client.go b/internal/tui/client.go index 9ddf935..7c924f9 100644 --- a/internal/tui/client.go +++ b/internal/tui/client.go @@ -135,6 +135,17 @@ func (c *Client) GetRunStateCmd(runID server.RunID) tea.Cmd { } } +// GetRunStateForHistoryCmd 从历史记录导航时异步加载运行状态快照。 +func (c *Client) GetRunStateForHistoryCmd(runID server.RunID) tea.Cmd { + return func() tea.Msg { + state, ok := c.srv.GetRunState(runID) + if !ok { + return ErrorMsg{Err: fmt.Errorf("该次运行数据不在内存中,请重新运行")} + } + return RunStateMsg{State: state, FromHistory: true} + } +} + // GenerateReportCmd 异步生成报告文件。 func (c *Client) GenerateReportCmd(runID server.RunID, format server.ReportFormat) tea.Cmd { return func() tea.Msg { diff --git a/internal/tui/messages.go b/internal/tui/messages.go index 7321d63..f64791f 100644 --- a/internal/tui/messages.go +++ b/internal/tui/messages.go @@ -40,7 +40,8 @@ type ServerEventMsg struct { // RunStateMsg server.GetRunState 的轮询结果(用于后台运行恢复仪表盘)。 type RunStateMsg struct { - State *server.RunState + State *server.RunState + FromHistory bool // true = 从历史记录导航过来,需新建 dash 并切换视图 } // ReportGeneratedMsg 报告文件生成完成。 diff --git a/internal/tui/model.go b/internal/tui/model.go index 7fc43e3..89a4efb 100644 --- a/internal/tui/model.go +++ b/internal/tui/model.go @@ -146,13 +146,20 @@ func (m *Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { case ServerEventMsg: return m.handleServerEvent(msg) - // ── 运行状态快照(重入仪表盘时) ── + // ── 运行状态快照(重入仪表盘时 / 从历史导航时) ── case RunStateMsg: - if m.dash != nil && msg.State != nil && m.dash.RunID == msg.State.RunID { - m.dash.RunState = msg.State + if msg.State == nil { + return m, nil } - if m.turboDash != nil && msg.State != nil && m.turboDash.RunID == msg.State.RunID { + if m.dash != nil && m.dash.RunID == msg.State.RunID { + m.dash.RunState = msg.State + } else if m.turboDash != nil && m.turboDash.RunID == msg.State.RunID { m.turboDash.RunState = msg.State + } else if msg.FromHistory { + // 从历史记录导航过来:用加载到的 RunState 新建 dash 并切换视图 + m.dash = pages.NewDashboardState(msg.State.RunID, msg.State.TaskID) + m.dash.RunState = msg.State + m.view = viewDashboard } return m, nil @@ -286,6 +293,10 @@ func (m *Model) handleNav(nav pages.NavAction) tea.Cmd { } return nil + case pages.NavRunDetail: + // 从历史记录进入某次运行的仪表盘 + return m.client.GetRunStateForHistoryCmd(nav.RunID) + case pages.NavReqDetail: reqs := m.collectRequests() m.reqDetail = pages.NewReqDetailState(m.currentRunID(), reqs, nav.ReqIndex) diff --git a/internal/tui/pages/contextbar.go b/internal/tui/pages/contextbar.go index 0ceb7e1..bb40c4c 100644 --- a/internal/tui/pages/contextbar.go +++ b/internal/tui/pages/contextbar.go @@ -42,7 +42,7 @@ func CtxBar_TaskDetail_NoHistory() []ContextBarItem { func CtxBar_TaskDetail_HasHistory() []ContextBarItem { return []ContextBarItem{ {Key: "↑↓", Desc: "选择记录"}, - {Key: "Enter", Desc: "展开/折叠详情"}, + {Key: "Enter", Desc: "查看运行详情"}, {Key: "r", Desc: "再次运行"}, {Key: "g", Desc: "导出 JSON 报告"}, {Key: "e", Desc: "编辑"}, diff --git a/internal/tui/pages/nav.go b/internal/tui/pages/nav.go index 765dcf7..1277179 100644 --- a/internal/tui/pages/nav.go +++ b/internal/tui/pages/nav.go @@ -19,6 +19,7 @@ const ( NavWizard // 打开向导(EditTask == nil 为新建) NavDashboard // 进入仪表盘(需 RunID + TaskID) NavTurboDash // 进入 Turbo 仪表盘(需 RunID + TaskID) + NavRunDetail // 从历史记录进入某次运行的仪表盘(需 RunID) NavReqDetail // 进入请求详情(需 ReqIndex) NavQuit // 退出程序 ) @@ -49,5 +50,6 @@ type Client interface { // 历史 & 报告 LoadHistoryCmd(taskID string, limit int) tea.Cmd GetRunStateCmd(runID server.RunID) tea.Cmd + GetRunStateForHistoryCmd(runID server.RunID) tea.Cmd GenerateReportCmd(runID server.RunID, format server.ReportFormat) tea.Cmd } diff --git a/internal/tui/pages/taskdetail.go b/internal/tui/pages/taskdetail.go index 88f99eb..018d840 100644 --- a/internal/tui/pages/taskdetail.go +++ b/internal/tui/pages/taskdetail.go @@ -44,8 +44,11 @@ func HandleTaskDetailKey(s *TaskDetailState, msg tea.KeyMsg, client Client) (*Ta } case "enter": - if len(s.History) > 0 { - s.LatestExpanded = !s.LatestExpanded + if s.HistorySel >= 0 && s.HistorySel < len(s.History) { + runID := strings.TrimSpace(s.History[s.HistorySel].RunID) + if runID != "" { + nav = NavAction{To: NavRunDetail, RunID: server.RunID(runID)} + } } case "left", "esc", "b": diff --git a/internal/types/types.go b/internal/types/types.go index a4d7836..024af8a 100644 --- a/internal/types/types.go +++ b/internal/types/types.go @@ -78,6 +78,7 @@ func ResolveEndpointURL(protocol, endpointURL, baseURL string) string { // PromptSource 需要前向声明,实际定义在 prompt 包中 type PromptSource interface { + GetSystemContent() string GetRandomContent() string GetContentByIndex(index int) string Count() int From 92fac50261893baa157936a6133c8eb7c617a17f Mon Sep 17 00:00:00 2001 From: Alain Date: Sun, 17 May 2026 23:19:34 +0800 Subject: [PATCH 13/52] =?UTF-8?q?feat:=20=E4=BC=98=E5=8C=96=E8=BF=9B?= =?UTF-8?q?=E5=BA=A6=E6=9D=A1=E6=98=BE=E7=A4=BA=E9=80=BB=E8=BE=91=EF=BC=8C?= =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E6=95=B0=E6=8D=AE=E5=8A=A0=E8=BD=BD=E6=8F=90?= =?UTF-8?q?=E7=A4=BA=EF=BC=8C=E5=A2=9E=E5=BC=BA=E5=90=91=E5=AF=BC=E5=AD=97?= =?UTF-8?q?=E6=AE=B5=E5=AE=9A=E4=B9=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/server/run.go | 177 ++++++++++++++++++++++++++++--- internal/tui/client.go | 32 +++++- internal/tui/model.go | 71 +++++++++++-- internal/tui/pages/contextbar.go | 60 ++++++++--- internal/tui/pages/dashboard.go | 106 +++++++++--------- internal/tui/pages/helpers.go | 22 ++++ internal/tui/pages/nav.go | 5 +- internal/tui/pages/reqdetail.go | 28 ++++- internal/tui/pages/taskdetail.go | 169 ++++++++++++++++++++--------- internal/tui/pages/tasklist.go | 118 ++++++++++----------- internal/tui/pages/turbodash.go | 52 ++++++--- internal/tui/pages/wizard.go | 17 ++- 12 files changed, 629 insertions(+), 228 deletions(-) diff --git a/internal/server/run.go b/internal/server/run.go index cf634b5..6265a2e 100644 --- a/internal/server/run.go +++ b/internal/server/run.go @@ -1,9 +1,14 @@ package server import ( + "bufio" + "encoding/json" "fmt" + "os" "path/filepath" + "sort" "sync" + "sync/atomic" "time" "github.com/yinxulai/ait/internal/client" @@ -28,6 +33,21 @@ type activeRun struct { doneCount int // 与 state.DoneReqs 保持同步,方便不加锁时计算 } +// callbackLevelRunner 包装 runner.Runner,在每次请求完成时调用回调, +// 使 turbo 运行也能逐请求采集详细指标数据。 +type callbackLevelRunner struct { + r *runner.Runner + cb runner.RequestDoneCallback +} + +func (c *callbackLevelRunner) Run() (*types.ReportData, error) { + return c.r.RunWithCallback(c.cb) +} + +func (c *callbackLevelRunner) Stop() { + c.r.Stop() +} + // snapshotState 返回 state 的深度拷贝(调用方须已持有 activeRun.mu 读锁)。 func (ar *activeRun) snapshotState() *RunState { s := ar.state @@ -91,10 +111,65 @@ func runStatePath(historyDir string, runID RunID) string { return filepath.Join(historyDir, "runs", string(runID)+".json") } -// persistRunState 将完整 RunState 快照写入磁盘,供历史回放使用。 +// requestsFilePath 返回指定运行的请求详情 JSONL 文件路径。 +func requestsFilePath(historyDir string, runID RunID) string { + return filepath.Join(historyDir, "runs", string(runID)+".jsonl") +} + +// appendRequestToDisk 将单条 RequestMetrics 以 JSON 行的形式追加写入磁盘。 +func appendRequestToDisk(historyDir string, runID RunID, rm *RequestMetrics) { + path := requestsFilePath(historyDir, runID) + f, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + return + } + defer f.Close() + data, err := json.Marshal(rm) + if err != nil { + return + } + _, _ = f.Write(data) + _, _ = f.Write([]byte{'\n'}) +} + +// loadRequestsFromDisk 从 JSONL 文件中加载所有 RequestMetrics,按 Index 排序。 +func loadRequestsFromDisk(historyDir string, runID RunID) []*RequestMetrics { + path := requestsFilePath(historyDir, runID) + f, err := os.Open(path) + if err != nil { + return nil + } + defer f.Close() + + const maxLineSize = 16 * 1024 * 1024 // 16 MB per line + buf := make([]byte, maxLineSize) + scanner := bufio.NewScanner(f) + scanner.Buffer(buf, maxLineSize) + + var reqs []*RequestMetrics + for scanner.Scan() { + line := scanner.Bytes() + if len(line) == 0 { + continue + } + var rm RequestMetrics + if err := json.Unmarshal(line, &rm); err != nil { + continue + } + reqs = append(reqs, &rm) + } + sort.Slice(reqs, func(i, j int) bool { + return reqs[i].Index < reqs[j].Index + }) + return reqs +} + +// persistRunState 将 RunState 元数据写入磁盘(不含 Requests,请求详情在 JSONL 文件中)。 func persistRunState(historyDir string, snap *RunState) { + toSave := *snap + toSave.Requests = nil // 请求详情已逐条写入 JSONL,避免重复存储 st := store.NewJSONStore[*RunState](runStatePath(historyDir, snap.RunID)) - _ = st.Save(snap) // 失败不影响主流程 + _ = st.Save(&toSave) // 失败不影响主流程 } // StartRun 启动一次新的运行,立即返回 RunID。 @@ -123,13 +198,18 @@ func (s *serverImpl) StartRun(taskID string) (RunID, error) { } state := &RunState{ - RunID: runID, - TaskID: taskID, - Status: RunStatusRunning, - Mode: mode, - StartedAt: now, - TotalReqs: hydratedInput.Count, - Requests: make([]*RequestMetrics, hydratedInput.Count), + RunID: runID, + TaskID: taskID, + Status: RunStatusRunning, + Mode: mode, + StartedAt: now, + } + if hydratedInput.Turbo { + // turbo 模式:跨多个并发级别探测,请求总数不固定,动态追加 + state.TotalReqs = 0 + } else { + // standard 模式:请求数固定,动态追加(按完成顺序) + state.TotalReqs = hydratedInput.Count } ar := &activeRun{state: state} @@ -179,11 +259,10 @@ func (s *serverImpl) runStandard(ar *activeRun, runID RunID, taskDef types.TaskD reportData, err := rnr.RunWithCallback(func(metrics *client.ResponseMetrics, idx int, cbErr error) { rm := mapRequestMetrics(metrics, idx, cbErr) + appendRequestToDisk(historyDir, runID, rm) ar.mu.Lock() - if idx < len(ar.state.Requests) { - ar.state.Requests[idx] = rm - } + ar.state.Requests = append(ar.state.Requests, rm) ar.state.DoneReqs++ if rm.Success { ar.state.SuccessReqs++ @@ -222,7 +301,50 @@ func (s *serverImpl) runStandard(ar *activeRun, runID RunID, taskDef types.TaskD // runTurbo 在 goroutine 中执行 Turbo 运行。 func (s *serverImpl) runTurbo(ar *activeRun, runID RunID, taskDef types.TaskDefinition, input types.Input, historyDir string) { - engine := turbo.New(turbo.DefaultRunnerFactory(taskDef.ID)) + // 全局请求计数器(原子递增),确保跨多个并发级别的请求索引唯一 + var globalIdx int64 + + factory := func(levelInput types.Input) (turbo.LevelRunner, error) { + r, err := runner.NewRunner(taskDef.ID, levelInput) + if err != nil { + return nil, err + } + return &callbackLevelRunner{ + r: r, + cb: func(metrics *client.ResponseMetrics, _ int, cbErr error) { + gIdx := int(atomic.AddInt64(&globalIdx, 1)) - 1 + rm := mapRequestMetrics(metrics, gIdx, cbErr) + appendRequestToDisk(historyDir, runID, rm) + + ar.mu.Lock() + ar.state.Requests = append(ar.state.Requests, rm) + ar.state.TotalReqs++ + ar.state.DoneReqs++ + if rm.Success { + ar.state.SuccessReqs++ + ar.tpsSum += rm.TPS + ar.ttftSum += rm.TTFT + ar.cacheSum += rm.CacheHitRate + } else { + ar.state.FailedReqs++ + } + if ar.state.SuccessReqs > 0 { + ar.state.AvgTPS = ar.tpsSum / float64(ar.state.SuccessReqs) + ar.state.AvgTTFT = ar.ttftSum / time.Duration(ar.state.SuccessReqs) + ar.state.CacheHitRate = ar.cacheSum / float64(ar.state.SuccessReqs) + } + if ar.state.DoneReqs > 0 { + ar.state.SuccessRate = float64(ar.state.SuccessReqs) / float64(ar.state.DoneReqs) * 100 + } + snap := ar.snapshotState() + ar.mu.Unlock() + + s.bus.Publish(Event{RunID: runID, Kind: EventRequestDone, Payload: snap}) + }, + }, nil + } + + engine := turbo.New(factory) ar.mu.Lock() ar.turboEngine = engine @@ -257,9 +379,14 @@ func (s *serverImpl) completeStandardRun(ar *activeRun, runID RunID, taskDef typ s.bus.Publish(Event{RunID: runID, Kind: EventRunComplete, Payload: snap}) s.bus.CloseRun(runID) - // 将完整运行状态持久化到磁盘,供历史详情页回放 + // 将元数据写入磁盘(Requests 已造话就写 JSONL,此处仅写 RunState 元数据) persistRunState(historyDir, snap) + // 运行已完成且数据已落盘,释放内存中的请求详情 + ar.mu.Lock() + ar.state.Requests = nil + ar.mu.Unlock() + summary := types.TaskRunSummary{ RunID: string(runID), TaskID: taskDef.ID, @@ -296,9 +423,14 @@ func (s *serverImpl) completeTurboRun(ar *activeRun, runID RunID, taskDef types. s.bus.Publish(Event{RunID: runID, Kind: EventRunComplete, Payload: snap}) s.bus.CloseRun(runID) - // 将完整运行状态持久化到磁盘,供历史详情页回放 + // 将元数据写入磁盘(Requests 已造话就写 JSONL,此处仅写 RunState 元数据) persistRunState(historyDir, snap) + // 运行已完成且数据已落盘,释放内存中的请求详情 + ar.mu.Lock() + ar.state.Requests = nil + ar.mu.Unlock() + var maxStable int var peakTPS float64 if result != nil { @@ -333,12 +465,17 @@ func (s *serverImpl) failRun(ar *activeRun, runID RunID, taskDef types.TaskDefin snap := ar.snapshotState() ar.mu.Unlock() - s.bus.Publish(Event{RunID: runID, Kind: EventRunFailed, Payload: runErr}) + s.bus.Publish(Event{RunID: runID, Kind: EventRunFailed, Payload: snap}) s.bus.CloseRun(runID) - // 将完整运行状态持久化到磁盘,供历史详情页回放 + // 将元数据写入磁盘(Requests 已造话就写 JSONL,此处仅写 RunState 元数据) persistRunState(historyDir, snap) + // 运行已完成且数据已落盘,释放内存中的请求详情 + ar.mu.Lock() + ar.state.Requests = nil + ar.mu.Unlock() + summary := types.TaskRunSummary{ RunID: string(runID), TaskID: taskDef.ID, @@ -408,6 +545,10 @@ func (s *serverImpl) GetRunState(runID RunID) (*RunState, bool) { ar.mu.RLock() snap := ar.snapshotState() ar.mu.RUnlock() + // 运行已完成且 Requests 已从内存清除(节省内存),从 JSONL 文件补充加载 + if snap.Status != RunStatusRunning && snap.Requests == nil { + snap.Requests = loadRequestsFromDisk(historyDir, snap.RunID) + } return snap, true } @@ -417,6 +558,8 @@ func (s *serverImpl) GetRunState(runID RunID) (*RunState, bool) { if err != nil || snap == nil { return nil, false } + // 从 JSONL 文件加载请求详情 + snap.Requests = loadRequestsFromDisk(historyDir, runID) return snap, true } diff --git a/internal/tui/client.go b/internal/tui/client.go index 7c924f9..aedbd5c 100644 --- a/internal/tui/client.go +++ b/internal/tui/client.go @@ -5,6 +5,7 @@ import ( tea "github.com/charmbracelet/bubbletea" "github.com/yinxulai/ait/internal/server" + "github.com/yinxulai/ait/internal/types" ) // Client 持有 server.Server,为 TUI 层提供 tea.Cmd 包装的异步调用。 @@ -136,16 +137,43 @@ func (c *Client) GetRunStateCmd(runID server.RunID) tea.Cmd { } // GetRunStateForHistoryCmd 从历史记录导航时异步加载运行状态快照。 -func (c *Client) GetRunStateForHistoryCmd(runID server.RunID) tea.Cmd { +// 若磁盘文件不存在,则用 summary 摘要数据构造最小化 RunState 作为回退。 +func (c *Client) GetRunStateForHistoryCmd(runID server.RunID, summary *types.TaskRunSummary) tea.Cmd { return func() tea.Msg { state, ok := c.srv.GetRunState(runID) if !ok { - return ErrorMsg{Err: fmt.Errorf("该次运行数据不在内存中,请重新运行")} + if summary != nil { + state = summaryToRunState(summary) + } else { + return ErrorMsg{Err: fmt.Errorf("该次运行数据不在内存中,请重新运行")} + } } return RunStateMsg{State: state, FromHistory: true} } } +// summaryToRunState 用 TaskRunSummary 摘要数据构造最小化 RunState,供无磁盘快照时回退展示。 +func summaryToRunState(s *types.TaskRunSummary) *server.RunState { + finished := s.FinishedAt + status := server.RunStatusCompleted + if s.Status == string(server.RunStatusFailed) { + status = server.RunStatusFailed + } + return &server.RunState{ + RunID: server.RunID(s.RunID), + TaskID: s.TaskID, + Status: status, + Mode: s.Mode, + StartedAt: s.StartedAt, + FinishedAt: &finished, + AvgTPS: s.AvgTPS, + AvgTTFT: s.AvgTTFT, + SuccessRate: s.SuccessRate, + CacheHitRate: s.CacheHitRate, + ErrorMsg: s.ErrorSummary, + } +} + // GenerateReportCmd 异步生成报告文件。 func (c *Client) GenerateReportCmd(runID server.RunID, format server.ReportFormat) tea.Cmd { return func() tea.Msg { diff --git a/internal/tui/model.go b/internal/tui/model.go index 89a4efb..7dc1ff4 100644 --- a/internal/tui/model.go +++ b/internal/tui/model.go @@ -156,10 +156,19 @@ func (m *Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } else if m.turboDash != nil && m.turboDash.RunID == msg.State.RunID { m.turboDash.RunState = msg.State } else if msg.FromHistory { - // 从历史记录导航过来:用加载到的 RunState 新建 dash 并切换视图 - m.dash = pages.NewDashboardState(msg.State.RunID, msg.State.TaskID) - m.dash.RunState = msg.State - m.view = viewDashboard + // 从历史记录导航过来:根据运行模式选择对应仪表盘,并设置返回目标为任务详情 + backNav := pages.NavAction{To: pages.NavTaskDetail} + if msg.State.Mode == "turbo" { + m.turboDash = pages.NewTurboDashState(msg.State.RunID, msg.State.TaskID) + m.turboDash.RunState = msg.State + m.turboDash.BackNav = backNav + m.view = viewTurboDash + } else { + m.dash = pages.NewDashboardState(msg.State.RunID, msg.State.TaskID) + m.dash.RunState = msg.State + m.dash.BackNav = backNav + m.view = viewDashboard + } } return m, nil @@ -266,6 +275,12 @@ func (m *Model) handleNav(nav pages.NavAction) tea.Cmd { } else if m.detail == nil { return nil } + // 若该任务有正在运行的实例,注入快照 + if m.detail != nil && m.taskList != nil { + if rs, ok := m.taskList.ActiveRuns[m.detail.Task.ID]; ok && rs != nil { + m.detail.ActiveRun = rs + } + } m.view = viewTaskDetail if m.detail != nil { return m.client.LoadHistoryCmd(m.detail.Task.ID, 10) @@ -295,11 +310,18 @@ func (m *Model) handleNav(nav pages.NavAction) tea.Cmd { case pages.NavRunDetail: // 从历史记录进入某次运行的仪表盘 - return m.client.GetRunStateForHistoryCmd(nav.RunID) + return m.client.GetRunStateForHistoryCmd(nav.RunID, nav.Summary) case pages.NavReqDetail: reqs := m.collectRequests() - m.reqDetail = pages.NewReqDetailState(m.currentRunID(), reqs, nav.ReqIndex) + s := pages.NewReqDetailState(m.currentRunID(), reqs, nav.ReqIndex) + // 记录来源页面,用于 b/esc 返回 + if m.view == viewTurboDash { + s.BackNav = pages.NavAction{To: pages.NavTurboDash} + } else { + s.BackNav = pages.NavAction{To: pages.NavDashboard} + } + m.reqDetail = s m.view = viewReqDetail return nil @@ -344,6 +366,9 @@ func (m *Model) handleServerEvent(msg ServerEventMsg) (tea.Model, tea.Cmd) { if m.taskList != nil { delete(m.taskList.ActiveRuns, taskID) } + if m.detail != nil && m.detail.Task.ID == taskID { + m.detail.ActiveRun = nil + } // 在后台刷新任务列表和历史,不自动跳转页面;用户可按 b/Esc 返回 return m, tea.Batch( m.client.LoadTasksCmd(), @@ -368,6 +393,9 @@ func (m *Model) handleServerEvent(msg ServerEventMsg) (tea.Model, tea.Cmd) { if m.taskList != nil { delete(m.taskList.ActiveRuns, taskID) } + if m.detail != nil && m.detail.Task.ID == taskID { + m.detail.ActiveRun = nil + } // 在后台刷新任务列表和历史,不自动跳转页面;用户可按 b/Esc 返回 return m, tea.Batch( m.client.LoadTasksCmd(), @@ -419,6 +447,14 @@ func (m *Model) injectRunState(rs *server.RunState) { } else { delete(m.taskList.ActiveRuns, rs.TaskID) } + // 如果详情页正在显示该任务,同步更新 ActiveRun + if m.detail != nil && m.detail.Task.ID == rs.TaskID { + if rs.Status == server.RunStatusRunning { + m.detail.ActiveRun = rs + } else { + m.detail.ActiveRun = nil + } + } } func (m *Model) dashTaskName() string { @@ -444,6 +480,14 @@ func (m *Model) turboDashTaskName() string { } func (m *Model) reqDetailTaskName() string { + // 根据 reqDetail 的来源视图确定任务名,避免两个面板均有状态时取错 + if m.reqDetail != nil && m.reqDetail.BackNav.To == pages.NavTurboDash { + if m.turboDash != nil { + if t := m.findTask(m.turboDash.TaskID); t != nil { + return t.Name + } + } + } if m.dash != nil { if t := m.findTask(m.dash.TaskID); t != nil { return t.Name @@ -478,11 +522,16 @@ func (m *Model) currentRunTaskID(isDash bool) string { } func (m *Model) collectRequests() []*server.RequestMetrics { - if m.dash != nil && m.dash.RunState != nil { - return m.dash.RunState.Requests - } - if m.turboDash != nil && m.turboDash.RunState != nil { - return m.turboDash.RunState.Requests + // 优先使用当前活跃视图的数据,避免两个面板均有 RunState 时取错 + switch m.view { + case viewTurboDash: + if m.turboDash != nil && m.turboDash.RunState != nil { + return m.turboDash.RunState.Requests + } + case viewDashboard: + if m.dash != nil && m.dash.RunState != nil { + return m.dash.RunState.Requests + } } return nil } diff --git a/internal/tui/pages/contextbar.go b/internal/tui/pages/contextbar.go index bb40c4c..ab3557d 100644 --- a/internal/tui/pages/contextbar.go +++ b/internal/tui/pages/contextbar.go @@ -22,7 +22,7 @@ func CtxBar_TaskList_Normal() []ContextBarItem { // CtxBar_TaskList_Running 运行中任务选中时的 Context Bar。 func CtxBar_TaskList_Running() []ContextBarItem { return []ContextBarItem{ - {Key: "Enter", Desc: "进入仪表盘"}, + {Key: "Enter", Desc: "查看详情"}, {Key: "s", Desc: "停止"}, {Key: "y", Desc: "复制"}, } @@ -82,17 +82,24 @@ func CtxBar_Wizard_Step3() []ContextBarItem { } } -// CtxBar_Dashboard_NoSel 标准仪表盘,无选中请求时。 -func CtxBar_Dashboard_NoSel() []ContextBarItem { +// CtxBar_Dashboard_Running_NoSel 标准仪表盘运行中,无选中请求时。 +func CtxBar_Dashboard_Running_NoSel() []ContextBarItem { return []ContextBarItem{ {Key: "s", Desc: "停止"}, - {Key: "b", Desc: "后台运行"}, - {Key: "r", Desc: "提前报告"}, + {Key: "b/Esc", Desc: "返回列表"}, + } +} + +// CtxBar_Dashboard_Done_NoSel 标准仪表盘完成后,无选中请求时。 +func CtxBar_Dashboard_Done_NoSel() []ContextBarItem { + return []ContextBarItem{ + {Key: "r", Desc: "生成报告"}, + {Key: "b/Esc", Desc: "返回列表"}, } } -// CtxBar_Dashboard_Sel 标准仪表盘,已选中请求时。 -func CtxBar_Dashboard_Sel() []ContextBarItem { +// CtxBar_Dashboard_Running_Sel 标准仪表盘运行中,已选中请求时。 +func CtxBar_Dashboard_Running_Sel() []ContextBarItem { return []ContextBarItem{ {Key: "Enter", Desc: "查看请求详情"}, {Key: "↑↓", Desc: "选择请求"}, @@ -100,25 +107,48 @@ func CtxBar_Dashboard_Sel() []ContextBarItem { } } -// CtxBar_TurboDash_NoSel Turbo 仪表盘,无选中级别时。 -func CtxBar_TurboDash_NoSel() []ContextBarItem { +// CtxBar_Dashboard_Done_Sel 标准仪表盘完成后,已选中请求时。 +func CtxBar_Dashboard_Done_Sel() []ContextBarItem { + return []ContextBarItem{ + {Key: "Enter", Desc: "查看请求详情"}, + {Key: "↑↓", Desc: "选择请求"}, + } +} + +// CtxBar_TurboDash_Running_NoSel Turbo 仪表盘运行中,无选中级别时。 +func CtxBar_TurboDash_Running_NoSel() []ContextBarItem { return []ContextBarItem{ {Key: "s", Desc: "停止"}, - {Key: "b", Desc: "后台运行"}, - {Key: "m", Desc: "标记极限"}, - {Key: "r", Desc: "提前报告"}, + {Key: "m", Desc: "标记极限并停止"}, + {Key: "b/Esc", Desc: "返回列表"}, + } +} + +// CtxBar_TurboDash_Done_NoSel Turbo 仪表盘完成后,无选中级别时。 +func CtxBar_TurboDash_Done_NoSel() []ContextBarItem { + return []ContextBarItem{ + {Key: "r", Desc: "生成报告"}, + {Key: "b/Esc", Desc: "返回列表"}, } } -// CtxBar_TurboDash_Sel Turbo 仪表盘,已选中已完成级别时。 -func CtxBar_TurboDash_Sel() []ContextBarItem { +// CtxBar_TurboDash_Running_Sel Turbo 仪表盘运行中,已选中级别时。 +func CtxBar_TurboDash_Running_Sel() []ContextBarItem { return []ContextBarItem{ - {Key: "Enter", Desc: "查看该级别请求列表"}, + {Key: "Enter", Desc: "查看该级别请求"}, {Key: "↑↓", Desc: "选择"}, {Key: "s", Desc: "停止"}, } } +// CtxBar_TurboDash_Done_Sel Turbo 仪表盘完成后,已选中级别时。 +func CtxBar_TurboDash_Done_Sel() []ContextBarItem { + return []ContextBarItem{ + {Key: "Enter", Desc: "查看该级别请求"}, + {Key: "↑↓", Desc: "选择"}, + } +} + // CtxBar_ReqDetail 请求详情页。 func CtxBar_ReqDetail() []ContextBarItem { return []ContextBarItem{ diff --git a/internal/tui/pages/dashboard.go b/internal/tui/pages/dashboard.go index 72745d3..ee44575 100644 --- a/internal/tui/pages/dashboard.go +++ b/internal/tui/pages/dashboard.go @@ -16,9 +16,10 @@ type DashboardState struct { EventCh <-chan server.Event // nil = 已后台或已结束 CancelFn server.CancelFunc RunState *server.RunState - ReqSel int // 选中请求索引(-1 = 无选中) - ReqOff int // 滚动偏移 - ReqVis int // 当前可见请求数 + ReqSel int // 选中请求索引(-1 = 无选中) + ReqOff int // 滚动偏移 + ReqVis int // 当前可见请求数 + BackNav NavAction // 按 b/esc 时的返回目标;Zero = 返回任务列表 } // NewDashboardState 创建仪表盘状态。 @@ -77,32 +78,36 @@ func HandleDashboardKey(d *DashboardState, msg tea.KeyMsg, client Client) (*Dash if len(reqs) == 0 { break } - selPos := 0 - if d.ReqSel >= 0 { - selPos = requestDisplayPos(d.ReqSel, len(reqs)) - } - if selPos <= 0 { - selPos = len(reqs) - 1 + if d.ReqSel < 0 { + // 无选中项:首次按键选中最新一条(显示列表顶部) + d.ReqSel = requestIndexFromDisplayPos(0, len(reqs)) } else { - selPos-- + selPos := requestDisplayPos(d.ReqSel, len(reqs)) + if selPos <= 0 { + selPos = len(reqs) - 1 + } else { + selPos-- + } + d.ReqSel = requestIndexFromDisplayPos(selPos, len(reqs)) } - d.ReqSel = requestIndexFromDisplayPos(selPos, len(reqs)) d.AdjustReqOffset(d.ReqVis, len(reqs)) case "down", "j": if len(reqs) == 0 { break } - selPos := 0 - if d.ReqSel >= 0 { - selPos = requestDisplayPos(d.ReqSel, len(reqs)) - } - if selPos < len(reqs)-1 { - selPos++ + if d.ReqSel < 0 { + // 无选中项:首次按键选中最新一条(显示列表顶部) + d.ReqSel = requestIndexFromDisplayPos(0, len(reqs)) } else { - selPos = 0 + selPos := requestDisplayPos(d.ReqSel, len(reqs)) + if selPos < len(reqs)-1 { + selPos++ + } else { + selPos = 0 + } + d.ReqSel = requestIndexFromDisplayPos(selPos, len(reqs)) } - d.ReqSel = requestIndexFromDisplayPos(selPos, len(reqs)) d.AdjustReqOffset(d.ReqVis, len(reqs)) case "enter": @@ -121,7 +126,11 @@ func HandleDashboardKey(d *DashboardState, msg tea.KeyMsg, client Client) (*Dash } d.EventCh = nil d.CancelFn = nil - nav = NavAction{To: NavTaskList} + if d.BackNav.To != NavNone { + nav = d.BackNav + } else { + nav = NavAction{To: NavTaskList} + } case "r": if d.RunState != nil && !d.IsRunning() { @@ -165,11 +174,18 @@ func RenderDashboard(d *DashboardState, taskName string, st Styles, width, heigh } rs := d.RunState + isRunning := d.IsRunning() + hasSel := d.ReqSel >= 0 && rs != nil && d.ReqSel < len(rs.Requests) var cbItems []ContextBarItem - if d.ReqSel >= 0 && rs != nil && d.ReqSel < len(rs.Requests) { - cbItems = CtxBar_Dashboard_Sel() - } else { - cbItems = CtxBar_Dashboard_NoSel() + switch { + case hasSel && isRunning: + cbItems = CtxBar_Dashboard_Running_Sel() + case hasSel && !isRunning: + cbItems = CtxBar_Dashboard_Done_Sel() + case !hasSel && isRunning: + cbItems = CtxBar_Dashboard_Running_NoSel() + default: + cbItems = CtxBar_Dashboard_Done_NoSel() } l := PageLayout{ CtxItems: cbItems, @@ -263,22 +279,23 @@ func buildProgressLine(rs *server.RunState, st Styles, width int) string { if total > 0 { ratio = float64(done) / float64(total) } - barW := 20 - barRendered := st.Ok.Render(strings.Repeat("█", int(ratio*float64(barW)))) + - st.Muted.Render(strings.Repeat("░", barW-int(ratio*float64(barW)))) - + prefix := " 进度 " elapsed := "" if !rs.StartedAt.IsZero() { - // elapsed time display elapsed = "─" } + suffix := fmt.Sprintf(" %d / %d %s", done, total, elapsed) - line := fmt.Sprintf(" 进度 %s %d / %d %s", - barRendered, done, total, elapsed) - if lipgloss.Width(line) > width { - line = truncate(line, width) + barW := width - lipgloss.Width(prefix) - lipgloss.Width(suffix) + if barW < 5 { + barW = 5 } - return line + + filled := int(ratio * float64(barW)) + barRendered := st.Ok.Render(strings.Repeat("█", filled)) + + st.Muted.Render(strings.Repeat("░", barW-filled)) + + return prefix + barRendered + suffix } // buildRequestList 构建请求列表区域。 @@ -287,7 +304,11 @@ func buildRequestList(d *DashboardState, rs *server.RunState, st Styles, width, lines = append(lines, " "+st.SectionHead.Render("请求列表")) if rs == nil || len(rs.Requests) == 0 { - lines = append(lines, " "+st.Muted.Render("等待请求...")) + msg := "等待请求..." + if rs != nil && rs.Status != server.RunStatusRunning { + msg = "无请求详情数据" + } + lines = append(lines, " "+st.Muted.Render(msg)) for len(lines) < maxH { lines = append(lines, "") } @@ -320,21 +341,6 @@ func buildRequestList(d *DashboardState, rs *server.RunState, st Styles, width, r := reqs[i] isSel := i == d.ReqSel - if r == nil { - // 该请求尚未开始,渲染为等待中 - marker := selectionMarker(isSel) - rowContent := padRight(marker, markW) + - padRight(fmt.Sprintf("#%d", i+1), idW) + - padRight(st.Muted.Render("…"), statW) + - padRight(st.Muted.Render("等待中"), timeW) + - padRight("─", ttftW) + - padRight("─", cacheW) + - padRight("─", tokW) + - "─" - lines = append(lines, renderTableRow(st, width, isSel, rowContent)) - continue - } - statusText := "✓" if !r.Success { statusText = "✗" diff --git a/internal/tui/pages/helpers.go b/internal/tui/pages/helpers.go index c73d047..779f637 100644 --- a/internal/tui/pages/helpers.go +++ b/internal/tui/pages/helpers.go @@ -106,6 +106,28 @@ func fmtDuration(d time.Duration) string { return fmt.Sprintf("%.0fm%.0fs", s/60, float64(int64(s)%60)) } +// fmtRelativeTime 将过去的时间格式化为「X 前」的简短形式。 +func fmtRelativeTime(t time.Time) string { + if t.IsZero() { + return "" + } + d := time.Since(t) + if d < time.Minute { + return "刚刚" + } + if d < time.Hour { + return fmt.Sprintf("%d 分钟前", int(d.Minutes())) + } + if d < 24*time.Hour { + return fmt.Sprintf("%d 小时前", int(d.Hours())) + } + days := int(d.Hours() / 24) + if days < 30 { + return fmt.Sprintf("%d 天前", days) + } + return t.Format("2006-01-02") +} + // ─── 布局工具 ───────────────────────────────────────────────────────────────── // renderHeader 渲染顶部双行标题栏。 diff --git a/internal/tui/pages/nav.go b/internal/tui/pages/nav.go index 1277179..a698508 100644 --- a/internal/tui/pages/nav.go +++ b/internal/tui/pages/nav.go @@ -30,7 +30,8 @@ type NavAction struct { TaskID string RunID server.RunID ReqIndex int - EditTask *types.TaskDefinition // 向导编辑模式时非空;nil 表示新建 + EditTask *types.TaskDefinition // 向导编辑模式时非空;nil 表示新建 + Summary *types.TaskRunSummary // NavRunDetail 时,磁盘文件缺失的回退数据 } // Client 定义 pages 包对外依赖的操作集合。 @@ -50,6 +51,6 @@ type Client interface { // 历史 & 报告 LoadHistoryCmd(taskID string, limit int) tea.Cmd GetRunStateCmd(runID server.RunID) tea.Cmd - GetRunStateForHistoryCmd(runID server.RunID) tea.Cmd + GetRunStateForHistoryCmd(runID server.RunID, summary *types.TaskRunSummary) tea.Cmd GenerateReportCmd(runID server.RunID, format server.ReportFormat) tea.Cmd } diff --git a/internal/tui/pages/reqdetail.go b/internal/tui/pages/reqdetail.go index 15cc0e6..55ae705 100644 --- a/internal/tui/pages/reqdetail.go +++ b/internal/tui/pages/reqdetail.go @@ -13,8 +13,9 @@ import ( type ReqDetailState struct { RunID server.RunID Requests []*server.RequestMetrics - Index int // 当前查看的请求索引 - ScrollY int // 输出区域滚动偏移 + Index int // 当前查看的请求索引 + ScrollY int // 输出区域滚动偏移 + BackNav NavAction // 按 b/esc 时的返回目标 } // NewReqDetailState 创建请求详情状态。 @@ -59,7 +60,11 @@ func HandleReqDetailKey(s *ReqDetailState, msg tea.KeyMsg) (*ReqDetailState, Nav s.ScrollY++ case "b", "esc", "backspace": - nav = NavAction{To: NavDashboard} + if s.BackNav.To != NavNone { + nav = s.BackNav + } else { + nav = NavAction{To: NavDashboard} + } case "q", "ctrl+c": nav = NavAction{To: NavQuit} @@ -152,6 +157,14 @@ func buildReqPerfPanel(r *server.RequestMetrics, st Styles, maxH, width int) str lines = append(lines, " "+st.SectionHead.Render("性能指标")) lines = append(lines, "") + if r == nil { + lines = append(lines, " "+st.Muted.Render("等待数据...")) + for len(lines) < maxH { + lines = append(lines, "") + } + return strings.Join(lines[:maxH], "\n") + } + statusStr := st.Ok.Render("✓ 成功") if !r.Success { statusStr = st.ErrStyle.Render("✗ 失败") @@ -182,6 +195,15 @@ func buildReqNetworkPanel(r *server.RequestMetrics, st Styles, maxH, width int) var lines []string lines = append(lines, " "+st.SectionHead.Render("网络指标")) lines = append(lines, "") + + if r == nil { + lines = append(lines, " "+st.Muted.Render("等待数据...")) + for len(lines) < maxH { + lines = append(lines, "") + } + return strings.Join(lines[:maxH], "\n") + } + lines = append(lines, " "+labelValue(st, "DNS ", fmtDuration(r.DNSTime))) lines = append(lines, " "+labelValue(st, "TCP 连接 ", fmtDuration(r.ConnectTime))) lines = append(lines, " "+labelValue(st, "TLS 握手 ", fmtDuration(r.TLSTime))) diff --git a/internal/tui/pages/taskdetail.go b/internal/tui/pages/taskdetail.go index 018d840..10d1523 100644 --- a/internal/tui/pages/taskdetail.go +++ b/internal/tui/pages/taskdetail.go @@ -14,12 +14,12 @@ import ( type TaskDetailState struct { Task types.TaskDefinition History []types.TaskRunSummary - // HistorySel 当前选中的历史记录索引(0 = 最近一次) + // HistorySel 当前选中的历史记录索引(0 = 最近一次;若有正在运行的实例,0 = 运行中条目) HistorySel int HistoryOff int HistoryVis int - // LatestExpanded 控制最近一次运行是否展开(运行结束后自动置 true) - LatestExpanded bool + // ActiveRun 当前正在运行的实例快照(nil = 无),由 model 注入 + ActiveRun *server.RunState } // NewTaskDetailState 创建初始任务详情状态。 @@ -30,24 +30,40 @@ func NewTaskDetailState(task types.TaskDefinition) *TaskDetailState { // HandleTaskDetailKey 处理任务详情页按键。 func HandleTaskDetailKey(s *TaskDetailState, msg tea.KeyMsg, client Client) (*TaskDetailState, tea.Cmd, NavAction) { nav := NavAction{} + hasActive := s.ActiveRun != nil + effectiveLen := len(s.History) + if hasActive { + effectiveLen++ + } + switch msg.String() { case "up", "k": if s.HistorySel > 0 { s.HistorySel-- - s.LatestExpanded = false } case "down", "j": - if s.HistorySel < len(s.History)-1 { + if s.HistorySel < effectiveLen-1 { s.HistorySel++ - s.LatestExpanded = false } case "enter": - if s.HistorySel >= 0 && s.HistorySel < len(s.History) { - runID := strings.TrimSpace(s.History[s.HistorySel].RunID) - if runID != "" { - nav = NavAction{To: NavRunDetail, RunID: server.RunID(runID)} + if s.HistorySel >= 0 && s.HistorySel < effectiveLen { + if hasActive && s.HistorySel == 0 { + // 进入正在运行的仪表盘 + nav = NavAction{To: NavRunDetail, RunID: s.ActiveRun.RunID} + } else { + histIdx := s.HistorySel + if hasActive { + histIdx-- + } + if histIdx >= 0 && histIdx < len(s.History) { + runID := strings.TrimSpace(s.History[histIdx].RunID) + if runID != "" { + sum := s.History[histIdx] + nav = NavAction{To: NavRunDetail, RunID: server.RunID(runID), Summary: &sum} + } + } } } @@ -58,10 +74,19 @@ func HandleTaskDetailKey(s *TaskDetailState, msg tea.KeyMsg, client Client) (*Ta return s, client.StartRunCmd(s.Task.ID), nav case "g": - if s.HistorySel >= 0 && s.HistorySel < len(s.History) { - runID := strings.TrimSpace(s.History[s.HistorySel].RunID) - if runID != "" { - return s, client.GenerateReportCmd(server.RunID(runID), server.ReportFormatJSON), nav + if s.HistorySel >= 0 && s.HistorySel < effectiveLen { + if hasActive && s.HistorySel == 0 { + break // 正在运行中,无法导出报告 + } + histIdx := s.HistorySel + if hasActive { + histIdx-- + } + if histIdx >= 0 && histIdx < len(s.History) { + runID := strings.TrimSpace(s.History[histIdx].RunID) + if runID != "" { + return s, client.GenerateReportCmd(server.RunID(runID), server.ReportFormatJSON), nav + } } } @@ -78,7 +103,7 @@ func HandleTaskDetailKey(s *TaskDetailState, msg tea.KeyMsg, client Client) (*Ta case "q", "ctrl+c": nav = NavAction{To: NavQuit} } - s.HistoryOff = ensureVisibleOffset(s.HistorySel, len(s.History), s.HistoryOff, s.HistoryVis) + s.HistoryOff = ensureVisibleOffset(s.HistorySel, effectiveLen, s.HistoryOff, s.HistoryVis) return s, nil, nav } @@ -112,7 +137,12 @@ func RenderTaskDetail(s *TaskDetailState, st Styles, width, height int) string { inp := t.Input var cbItems []ContextBarItem - if len(s.History) > 0 { + hasActive := s.ActiveRun != nil + effectiveLen := len(s.History) + if hasActive { + effectiveLen++ + } + if effectiveLen > 0 { cbItems = CtxBar_TaskDetail_HasHistory() } else { cbItems = CtxBar_TaskDetail_NoHistory() @@ -177,6 +207,12 @@ func buildTaskDetailContent(s *TaskDetailState, st Styles, t types.TaskDefinitio rightLines = append(rightLines, padRight(" "+st.SectionHead.Render("历史运行记录"), rightW)) rightLines = append(rightLines, padRight(st.Divider.Render(strings.Repeat("─", rightW)), rightW)) + hasActive := s.ActiveRun != nil + effectiveLen := len(s.History) + if hasActive { + effectiveLen++ + } + const ( markW = 2 statW = 2 @@ -190,12 +226,23 @@ func buildTaskDetailContent(s *TaskDetailState, st Styles, t types.TaskDefinitio rightLines = append(rightLines, padRight(renderTableHeader(st, rightW, hdr), rightW)) rightLines = append(rightLines, padRight(st.Divider.Render(strings.Repeat("─", rightW)), rightW)) - if len(s.History) == 0 { + if effectiveLen == 0 { rightLines = append(rightLines, padRight(" "+st.Muted.Render("暂无运行记录"), rightW)) } else { + // 始终为当前选中的历史条目显示详情面板 var detailLines []string - if s.LatestExpanded { - detailLines = buildTaskHistoryDetailLines(s, st, rightW) + { + histIdx := s.HistorySel + if hasActive { + if s.HistorySel == 0 { + histIdx = -1 // 运行中条目无详情 + } else { + histIdx-- + } + } + if histIdx >= 0 { + detailLines = buildTaskHistoryDetailLines(s, histIdx, st, rightW) + } } tableMaxH := maxH - len(detailLines) if tableMaxH < 5 { @@ -206,38 +253,63 @@ func buildTaskDetailContent(s *TaskDetailState, st Styles, t types.TaskDefinitio tableMaxH = maxH - len(detailLines) } s.HistoryVis = listVisibleItems(tableMaxH, 4) - s.HistoryOff = ensureVisibleOffset(s.HistorySel, len(s.History), s.HistoryOff, s.HistoryVis) + s.HistoryOff = ensureVisibleOffset(s.HistorySel, effectiveLen, s.HistoryOff, s.HistoryVis) start := s.HistoryOff - end := minInt(len(s.History), start+s.HistoryVis) + end := minInt(effectiveLen, start+s.HistoryVis) // ── 历史列表 ── for idx := start; idx < end; idx++ { - run := s.History[idx] - statusText := "✓" - if run.Status != "completed" { - statusText = "✗" - } - modeShort := "标准" - if run.Mode == "turbo" { - modeShort = "Turbo" - } isSel := idx == s.HistorySel - statusIcon := statusText - if run.Status == "completed" { - statusIcon = styleWhenNotSelected(isSel, st.Ok, statusText) + marker := selectionMarker(isSel) + var row string + + if hasActive && idx == 0 { + // 正在运行中的条目 + rs := s.ActiveRun + modeShort := "标准" + if rs.Mode == "turbo" { + modeShort = "Turbo" + } + statusIcon := styleWhenNotSelected(isSel, st.Ok, "●") + rateStr := "─" + if rs.TotalReqs > 0 { + rateStr = fmt.Sprintf("%.0f%%", rs.SuccessRate) + } + progStr := fmt.Sprintf("%d/%d 正在运行...", rs.DoneReqs, rs.TotalReqs) + row = padRight(marker, markW) + + padRight(statusIcon, statW) + + padRight(rs.StartedAt.Format("2006-01-02 15:04"), timeW) + + padRight(modeShort, modeW) + + padRight(rateStr, rateW) + + styleWhenNotSelected(isSel, st.Ok, progStr) } else { - statusIcon = styleWhenNotSelected(isSel, st.ErrStyle, statusText) + histIdx := idx + if hasActive { + histIdx-- + } + run := s.History[histIdx] + statusText := "✓" + if run.Status != "completed" { + statusText = "✗" + } + modeShort := "标准" + if run.Mode == "turbo" { + modeShort = "Turbo" + } + statusIcon := statusText + if run.Status == "completed" { + statusIcon = styleWhenNotSelected(isSel, st.Ok, statusText) + } else { + statusIcon = styleWhenNotSelected(isSel, st.ErrStyle, statusText) + } + row = padRight(marker, markW) + + padRight(statusIcon, statW) + + padRight(run.StartedAt.Format("2006-01-02 15:04"), timeW) + + padRight(modeShort, modeW) + + padRight(fmt.Sprintf("%.1f%%", run.SuccessRate), rateW) + + padRight(fmtDuration(run.AvgTTFT), ttftW) + + fmt.Sprintf("%.1f", run.AvgTPS) } - - marker := selectionMarker(isSel) - - row := padRight(marker, markW) + - padRight(statusIcon, statW) + - padRight(run.StartedAt.Format("2006-01-02 15:04"), timeW) + - padRight(modeShort, modeW) + - padRight(fmt.Sprintf("%.1f%%", run.SuccessRate), rateW) + - padRight(fmtDuration(run.AvgTTFT), ttftW) + - fmt.Sprintf("%.1f", run.AvgTPS) rightLines = append(rightLines, padRight(renderTableRow(st, rightW, isSel, row), rightW)) if idx < end-1 { rightLines = append(rightLines, padRight(st.Divider.Render(strings.Repeat("─", rightW)), rightW)) @@ -290,16 +362,17 @@ func UpdateTaskDetailHistory(s *TaskDetailState, history []types.TaskRunSummary, s.HistoryOff = ensureVisibleOffset(s.HistorySel, len(history), s.HistoryOff, s.HistoryVis) } if autoExpand && len(history) > 0 { - s.LatestExpanded = true + // autoExpand 参数保留接口兼容性,展开行为已由渲染层自动处理 + _ = autoExpand } return s } -func buildTaskHistoryDetailLines(s *TaskDetailState, st Styles, width int) []string { - if s.HistorySel < 0 || s.HistorySel >= len(s.History) { +func buildTaskHistoryDetailLines(s *TaskDetailState, histIdx int, st Styles, width int) []string { + if histIdx < 0 || histIdx >= len(s.History) { return nil } - sel := s.History[s.HistorySel] + sel := s.History[histIdx] elapsed := sel.FinishedAt.Sub(sel.StartedAt) labelW := 8 indent := " " diff --git a/internal/tui/pages/tasklist.go b/internal/tui/pages/tasklist.go index 63a2b62..8c0fc0d 100644 --- a/internal/tui/pages/tasklist.go +++ b/internal/tui/pages/tasklist.go @@ -91,13 +91,7 @@ func HandleTaskListKey(s *TaskListState, msg tea.KeyMsg, client Client) (*TaskLi case "enter": if t, ok := s.CurrentTask(); ok { - if s.IsTaskRunning(t.ID) { - if rs, ok := s.ActiveRuns[t.ID]; ok { - nav = NavAction{To: NavDashboard, TaskID: t.ID, RunID: rs.RunID} - } - } else { - nav = NavAction{To: NavTaskDetail, TaskID: t.ID} - } + nav = NavAction{To: NavTaskDetail, TaskID: t.ID} } case "r": @@ -171,17 +165,25 @@ func buildTaskListContent(s *TaskListState, st Styles, width, maxH int) string { } listTopLines := len(lines) - // 列宽(合计 = nameW + modeW + protoW + 结果列) - nameW := 28 - modeW := 8 - protoW := 14 + // 列宽(gap=2 作为列间距内置到每个非末尾列的宽度中) + const ( + modeW = 9 // 7 + 2 gap + protoW = 12 // 10 + 2 gap + lastRunW = 13 // 11 + 2 gap + ttftW = 12 // 10 + 2 gap + tpsW = 9 // 末尾列,无需额外 gap + ) + fixedW := 2 + modeW + protoW + lastRunW + ttftW + tpsW + nameW := maxInt(10, width-fixedW) // 表头:2 空格前缀与正文行对齐(cursor=2) header := renderTableHeader(st, width, - " " + padRight("任务名称", nameW) + - padRight("模式", modeW) + - padRight("协议", protoW) + - "上次结果") + " "+padRight("任务名称", nameW)+ + padRight("模式", modeW)+ + padRight("协议", protoW)+ + padRight("上次运行", lastRunW)+ + padRight("TTFT", ttftW)+ + "TPS") lines = append(lines, header) lines = append(lines, dividerLine(st, width)) listMaxH := maxInt(3, maxH-listTopLines) @@ -199,9 +201,9 @@ func buildTaskListContent(s *TaskListState, st Styles, width, maxH int) string { for i := start; i < end; i++ { t := s.Tasks[i] - isRunning := s.IsTaskRunning(t.ID) isSel := i == s.Selected rs := s.ActiveRuns[t.ID] + _, hasActiveRun := s.ActiveRuns[t.ID] // ── 指示符 ── prefix := padRight(selectionMarker(isSel), 2) @@ -219,48 +221,6 @@ func buildTaskListContent(s *TaskListState, st Styles, width, maxH int) string { // ── 协议 ── proto := padRight(shortProtocol(t.Input.NormalizedProtocol()), protoW) - // ── 上次结果(选中行禁用嵌套样式,避免重置整行背景)── - lastResultText := "从未运行" - if t.LastRunSummary != nil { - pct := t.LastRunSummary.SuccessRate - if t.Input.Turbo { - if t.LastRunSummary.MaxStableConcurrency > 0 { - lastResultText = fmt.Sprintf("★ 并发%d", t.LastRunSummary.MaxStableConcurrency) - } - } else { - switch { - case pct >= 99: - lastResultText = fmt.Sprintf("✓ %.1f%%", pct) - case pct >= 90: - lastResultText = fmt.Sprintf("%.1f%%", pct) - default: - lastResultText = fmt.Sprintf("✗ %.1f%%", pct) - } - } - } - if isRunning && rs != nil { - lastResultText = fmt.Sprintf("◉ %d/%d %.0f%%", rs.DoneReqs, rs.TotalReqs, rs.SuccessRate) - } - - lastResult := lastResultText - if isRunning && rs != nil { - lastResult = styleWhenNotSelected(isSel, st.Ok, lastResultText) - } else if t.LastRunSummary == nil { - lastResult = styleWhenNotSelected(isSel, st.Muted, lastResultText) - } else if t.Input.Turbo && t.LastRunSummary.MaxStableConcurrency > 0 { - lastResult = styleWhenNotSelected(isSel, st.Ok, lastResultText) - } else if !t.Input.Turbo { - pct := t.LastRunSummary.SuccessRate - switch { - case pct >= 99: - lastResult = styleWhenNotSelected(isSel, st.Ok, lastResultText) - case pct >= 90: - lastResult = styleWhenNotSelected(isSel, st.MetricVal, lastResultText) - default: - lastResult = styleWhenNotSelected(isSel, st.ErrStyle, lastResultText) - } - } - // ── 任务名称(裁剪)── name := truncate(t.Name, nameW) namePad := nameW - lipgloss.Width(name) @@ -269,10 +229,44 @@ func buildTaskListContent(s *TaskListState, st Styles, width, maxH int) string { } nameCol := name + strings.Repeat(" ", namePad) - // ── 第一行 ── - row1Content := nameCol + modeCol + proto + lastResult - row1 := renderTableRow(st, width, isSel, prefix+row1Content) - lines = append(lines, row1) + // ── 上次运行时间 ── + lastRunText := "─" + if hasActiveRun { + lastRunText = "运行中" + } else if t.LastRunAt != nil { + lastRunText = fmtRelativeTime(*t.LastRunAt) + } + lastRunStyle := st.Muted + if hasActiveRun { + lastRunStyle = st.Ok + } + lastRunCol := padRight(styleWhenNotSelected(isSel, lastRunStyle, lastRunText), lastRunW) + + // ── TTFT ── + ttftText := "─" + if hasActiveRun && rs != nil && rs.AvgTTFT > 0 { + ttftText = fmtDuration(rs.AvgTTFT) + } else if !hasActiveRun && t.LastRunSummary != nil { + ttftText = fmtDuration(t.LastRunSummary.AvgTTFT) + } + ttftCol := padRight(styleWhenNotSelected(isSel, st.Value, ttftText), ttftW) + + // ── TPS ── + tpsText := "─" + if hasActiveRun && rs != nil && rs.AvgTPS > 0 { + tpsText = fmt.Sprintf("%.1f", rs.AvgTPS) + } else if !hasActiveRun && t.LastRunSummary != nil { + if t.Input.Turbo && t.LastRunSummary.MaxStableConcurrency > 0 { + tpsText = fmt.Sprintf("并发%d", t.LastRunSummary.MaxStableConcurrency) + } else if !t.Input.Turbo { + tpsText = fmt.Sprintf("%.1f", t.LastRunSummary.AvgTPS) + } + } + tpsCol := styleWhenNotSelected(isSel, st.Value, tpsText) + + // ── 单行:名称 | 模式 | 协议 | 上次运行 | TTFT | TPS ── + rowContent := nameCol + modeCol + proto + lastRunCol + ttftCol + tpsCol + lines = append(lines, renderTableRow(st, width, isSel, prefix+rowContent)) // ── 分隔线 ── if i < end-1 && len(lines) < maxH-1 { diff --git a/internal/tui/pages/turbodash.go b/internal/tui/pages/turbodash.go index 07e3a80..e706586 100644 --- a/internal/tui/pages/turbodash.go +++ b/internal/tui/pages/turbodash.go @@ -17,9 +17,10 @@ type TurboDashState struct { EventCh <-chan server.Event CancelFn server.CancelFunc RunState *server.RunState - LevelSel int // 选中的级别索引(-1 = 无选中) + LevelSel int // 选中的级别索引(-1 = 无选中) LevelOff int LevelVis int + BackNav NavAction // 按 b/esc 时的返回目标;Zero = 返回任务列表 } // NewTurboDashState 创建 Turbo 仪表盘初始状态。 @@ -73,9 +74,13 @@ func HandleTurboDashKey(d *TurboDashState, msg tea.KeyMsg, client Client) (*Turb } case "enter": - // 进入该级别的请求列表(使用标准仪表盘的请求详情,此处导航到 ReqDetail) + // 进入该级别的请求列表,定位到该级别第一条请求 if d.LevelSel >= 0 && d.LevelSel < len(levels) { - nav = NavAction{To: NavReqDetail, ReqIndex: 0} + startIdx := 0 + for j := 0; j < d.LevelSel; j++ { + startIdx += levels[j].TotalRequests + } + nav = NavAction{To: NavReqDetail, ReqIndex: startIdx} } case "s": @@ -95,7 +100,11 @@ func HandleTurboDashKey(d *TurboDashState, msg tea.KeyMsg, client Client) (*Turb } d.EventCh = nil d.CancelFn = nil - nav = NavAction{To: NavTaskList} + if d.BackNav.To != NavNone { + nav = d.BackNav + } else { + nav = NavAction{To: NavTaskList} + } case "r": if d.RunState != nil && !d.IsRunning() { @@ -139,11 +148,18 @@ func RenderTurboDash(d *TurboDashState, taskName string, st Styles, width, heigh } rs := d.RunState + isRunning := d.IsRunning() + hasSel := d.LevelSel >= 0 && rs != nil && d.LevelSel < len(rs.Levels) var cbItems []ContextBarItem - if d.LevelSel >= 0 && rs != nil && d.LevelSel < len(rs.Levels) { - cbItems = CtxBar_TurboDash_Sel() - } else { - cbItems = CtxBar_TurboDash_NoSel() + switch { + case hasSel && isRunning: + cbItems = CtxBar_TurboDash_Running_Sel() + case hasSel && !isRunning: + cbItems = CtxBar_TurboDash_Done_Sel() + case !hasSel && isRunning: + cbItems = CtxBar_TurboDash_Running_NoSel() + default: + cbItems = CtxBar_TurboDash_Done_NoSel() } l := PageLayout{ CtxItems: cbItems, @@ -239,14 +255,20 @@ func buildTurboProgressLine(rs *server.RunState, st Styles, width int) string { if total > 0 { ratio = float64(done) / float64(total) } - barW := 15 - barRendered := st.Ok.Render(strings.Repeat("█", int(ratio*float64(barW)))) + - st.Muted.Render(strings.Repeat("░", barW-int(ratio*float64(barW)))) - levelTotal := len(rs.Levels) - line := fmt.Sprintf(" 进度 %s %d/%d 当前并发 %d 总进度: 已完成 %d/~? 级", - barRendered, done, total, rs.CurrentLevel, levelTotal) - return line + prefix := " 进度 " + suffix := fmt.Sprintf(" %d/%d 当前并发 %d 总进度: 已完成 %d/~? 级", done, total, rs.CurrentLevel, levelTotal) + + barW := width - lipgloss.Width(prefix) - lipgloss.Width(suffix) + if barW < 5 { + barW = 5 + } + + filled := int(ratio * float64(barW)) + barRendered := st.Ok.Render(strings.Repeat("█", filled)) + + st.Muted.Render(strings.Repeat("░", barW-filled)) + + return prefix + barRendered + suffix } // buildLevelList 构建 Turbo 级别列表区域。 diff --git a/internal/tui/pages/wizard.go b/internal/tui/pages/wizard.go index 7fbf530..3210f77 100644 --- a/internal/tui/pages/wizard.go +++ b/internal/tui/pages/wizard.go @@ -169,8 +169,10 @@ func (wz *WizardState) BuildTaskConfig() server.TaskConfig { type fieldDef struct { kind fieldKind label string - // 获取当前值(字符串形式) + // 获取当前值(字符串形式),用于显示;可能包含占位默认值 get func(wz *WizardState) string + // 获取实际存储值(用于编辑操作);若为 nil 则退回到 get + getRaw func(wz *WizardState) string // 设置文本值 set func(wz *WizardState, v string) // 枚举/布尔切换 @@ -242,6 +244,7 @@ func step1Fields() []fieldDef { } return types.DefaultEndpointURL(wz.Protocol) }, + getRaw: func(wz *WizardState) string { return wz.EndpointURL }, set: func(wz *WizardState, v string) { wz.EndpointURL = v }, }, { @@ -478,7 +481,11 @@ func HandleWizardKey(wz *WizardState, msg tea.KeyMsg, client Client) (*WizardSta if wz.FieldIndex < len(fields) { f := fields[wz.FieldIndex] if f.set != nil && f.kind == fieldText { - v := f.get(wz) + getEdit := f.get + if f.getRaw != nil { + getEdit = f.getRaw + } + v := getEdit(wz) r := []rune(v) if len(r) > 0 { f.set(wz, string(r[:len(r)-1])) @@ -494,7 +501,11 @@ func HandleWizardKey(wz *WizardState, msg tea.KeyMsg, client Client) (*WizardSta if len(msg.Runes) > 0 && wz.FieldIndex < len(fields) { f := fields[wz.FieldIndex] if f.set != nil && (f.kind == fieldText || f.kind == fieldNumber) { - f.set(wz, f.get(wz)+string(msg.Runes)) + getEdit := f.get + if f.getRaw != nil { + getEdit = f.getRaw + } + f.set(wz, getEdit(wz)+string(msg.Runes)) } } } From 54ed8f8c2898c49f40476aab60833aa0841b3158 Mon Sep 17 00:00:00 2001 From: Alain Date: Sun, 17 May 2026 23:29:03 +0800 Subject: [PATCH 14/52] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E4=BB=BB?= =?UTF-8?q?=E5=8A=A1=E8=B6=85=E6=97=B6=E9=85=8D=E7=BD=AE=EF=BC=8C=E4=BC=98?= =?UTF-8?q?=E5=8C=96=E4=BB=BB=E5=8A=A1=E5=88=9B=E5=BB=BA=E5=92=8C=E7=BC=96?= =?UTF-8?q?=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/tui/model.go | 3 ++- internal/tui/pages/wizard.go | 20 ++++++++++++++++++-- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/internal/tui/model.go b/internal/tui/model.go index 7dc1ff4..ee930ed 100644 --- a/internal/tui/model.go +++ b/internal/tui/model.go @@ -99,7 +99,8 @@ func (m *Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { // ── 任务保存完成(新建或更新) ── case TaskSavedMsg: m.status = fmt.Sprintf("任务 %q 已保存", msg.Task.Name) - if msg.AutoStart && (m.dash == nil || !m.dash.IsRunning()) { + notRunning := (m.dash == nil || !m.dash.IsRunning()) && (m.turboDash == nil || !m.turboDash.IsRunning()) + if msg.AutoStart && notRunning { return m, tea.Batch( m.client.LoadTasksCmd(), m.client.StartRunCmd(msg.Task.ID), diff --git a/internal/tui/pages/wizard.go b/internal/tui/pages/wizard.go index 3210f77..3cce109 100644 --- a/internal/tui/pages/wizard.go +++ b/internal/tui/pages/wizard.go @@ -4,6 +4,7 @@ import ( "fmt" "strconv" "strings" + "time" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" @@ -139,6 +140,10 @@ func (wz *WizardState) BuildTaskConfig() server.TaskConfig { if turboRate <= 0 { turboRate = 0.9 } + var timeout time.Duration + if wz.Timeout > 0 { + timeout = time.Duration(wz.Timeout) * time.Second + } return server.TaskConfig{ Name: wz.Name, Input: types.Input{ @@ -148,6 +153,7 @@ func (wz *WizardState) BuildTaskConfig() server.TaskConfig { Model: wz.Model, Concurrency: wz.Concurrency, Count: wz.Count, + Timeout: timeout, Stream: wz.Stream, Turbo: wz.Turbo, TurboConfig: types.TurboConfig{ @@ -405,13 +411,23 @@ func HandleWizardKey(wz *WizardState, msg tea.KeyMsg, client Client) (*WizardSta wz.ScrollOff = 0 case "end": wz.ScrollOff = 1 << 30 - case "enter", "r": + case "enter": + cfg := wz.BuildTaskConfig() + var cmd tea.Cmd + if wz.EditingID != "" { + cmd = client.UpdateTaskCmd(wz.EditingID, cfg) + } else { + cmd = client.CreateTaskCmd(cfg, false) // 仅保存,不自动运行 + } + nav = NavAction{To: NavTaskList} + return wz, cmd, nav + case "r": cfg := wz.BuildTaskConfig() var cmd tea.Cmd if wz.EditingID != "" { cmd = client.UpdateTaskCmd(wz.EditingID, cfg) } else { - cmd = client.CreateTaskCmd(cfg, true) + cmd = client.CreateTaskCmd(cfg, true) // 保存并运行 } nav = NavAction{To: NavTaskList} return wz, cmd, nav From b9f3fb947e9fa257c74ca726b9c0ea9b03f387dc Mon Sep 17 00:00:00 2001 From: Alain Date: Sun, 17 May 2026 23:43:11 +0800 Subject: [PATCH 15/52] =?UTF-8?q?feat:=20=E5=A2=9E=E5=BC=BA=E4=BB=BB?= =?UTF-8?q?=E5=8A=A1=E8=AF=A6=E6=83=85=E9=A1=B5=E7=9A=84=E4=B8=8A=E4=B8=8B?= =?UTF-8?q?=E6=96=87=E6=9D=A1=EF=BC=8C=E6=B7=BB=E5=8A=A0=E4=BB=BB=E5=8A=A1?= =?UTF-8?q?=E6=AD=A3=E5=9C=A8=E8=BF=90=E8=A1=8C=E6=97=B6=E7=9A=84=E7=8A=B6?= =?UTF-8?q?=E6=80=81=E5=A4=84=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/tui/pages/contextbar.go | 13 ++++++++++++- internal/tui/pages/dashboard.go | 9 +++++++-- internal/tui/pages/taskdetail.go | 11 ++++++++--- internal/tui/pages/tasklist.go | 2 +- internal/tui/pages/turbodash.go | 10 ++++++++-- 5 files changed, 36 insertions(+), 9 deletions(-) diff --git a/internal/tui/pages/contextbar.go b/internal/tui/pages/contextbar.go index ab3557d..8b8f2f3 100644 --- a/internal/tui/pages/contextbar.go +++ b/internal/tui/pages/contextbar.go @@ -38,7 +38,7 @@ func CtxBar_TaskDetail_NoHistory() []ContextBarItem { } } -// CtxBar_TaskDetail_HasHistory 任务详情页,有运行记录时。 +// CtxBar_TaskDetail_HasHistory 任务详情页,有运行记录且未运行时。 func CtxBar_TaskDetail_HasHistory() []ContextBarItem { return []ContextBarItem{ {Key: "↑↓", Desc: "选择记录"}, @@ -51,6 +51,17 @@ func CtxBar_TaskDetail_HasHistory() []ContextBarItem { } } +// CtxBar_TaskDetail_Running 任务详情页,任务正在运行时。 +func CtxBar_TaskDetail_Running() []ContextBarItem { + return []ContextBarItem{ + {Key: "↑↓", Desc: "选择记录"}, + {Key: "Enter", Desc: "进入运行中仓表盘"}, + {Key: "g", Desc: "导出历史 JSON"}, + {Key: "e", Desc: "编辑"}, + {Key: "y", Desc: "复制任务"}, + } +} + // CtxBar_Wizard_Step1 创建任务页,第 1 步。 func CtxBar_Wizard_Step1() []ContextBarItem { return []ContextBarItem{ diff --git a/internal/tui/pages/dashboard.go b/internal/tui/pages/dashboard.go index ee44575..0e2f698 100644 --- a/internal/tui/pages/dashboard.go +++ b/internal/tui/pages/dashboard.go @@ -3,6 +3,7 @@ package pages import ( "fmt" "strings" + "time" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" @@ -280,9 +281,13 @@ func buildProgressLine(rs *server.RunState, st Styles, width int) string { ratio = float64(done) / float64(total) } prefix := " 进度 " - elapsed := "" + elapsed := "─" if !rs.StartedAt.IsZero() { - elapsed = "─" + if rs.FinishedAt != nil { + elapsed = fmtDuration(rs.FinishedAt.Sub(rs.StartedAt)) + } else { + elapsed = fmtDuration(time.Since(rs.StartedAt)) + } } suffix := fmt.Sprintf(" %d / %d %s", done, total, elapsed) diff --git a/internal/tui/pages/taskdetail.go b/internal/tui/pages/taskdetail.go index 10d1523..3df591b 100644 --- a/internal/tui/pages/taskdetail.go +++ b/internal/tui/pages/taskdetail.go @@ -71,7 +71,9 @@ func HandleTaskDetailKey(s *TaskDetailState, msg tea.KeyMsg, client Client) (*Ta nav = NavAction{To: NavTaskList} case "r": - return s, client.StartRunCmd(s.Task.ID), nav + if s.ActiveRun == nil { + return s, client.StartRunCmd(s.Task.ID), nav + } case "g": if s.HistorySel >= 0 && s.HistorySel < effectiveLen { @@ -142,9 +144,12 @@ func RenderTaskDetail(s *TaskDetailState, st Styles, width, height int) string { if hasActive { effectiveLen++ } - if effectiveLen > 0 { + switch { + case hasActive: + cbItems = CtxBar_TaskDetail_Running() + case effectiveLen > 0: cbItems = CtxBar_TaskDetail_HasHistory() - } else { + default: cbItems = CtxBar_TaskDetail_NoHistory() } l := PageLayout{ diff --git a/internal/tui/pages/tasklist.go b/internal/tui/pages/tasklist.go index 8c0fc0d..7d60d6e 100644 --- a/internal/tui/pages/tasklist.go +++ b/internal/tui/pages/tasklist.go @@ -95,7 +95,7 @@ func HandleTaskListKey(s *TaskListState, msg tea.KeyMsg, client Client) (*TaskLi } case "r": - if t, ok := s.CurrentTask(); ok { + if t, ok := s.CurrentTask(); ok && !s.IsTaskRunning(t.ID) { return s, client.StartRunCmd(t.ID), nav } diff --git a/internal/tui/pages/turbodash.go b/internal/tui/pages/turbodash.go index e706586..09c115b 100644 --- a/internal/tui/pages/turbodash.go +++ b/internal/tui/pages/turbodash.go @@ -57,7 +57,10 @@ func HandleTurboDashKey(d *TurboDashState, msg tea.KeyMsg, client Client) (*Turb if len(levels) == 0 { break } - if d.LevelSel <= 0 { + if d.LevelSel < 0 { + // 首次按键:跳到最后一级(最新/最高并发),与 ↓ 保持一致 + d.LevelSel = len(levels) - 1 + } else if d.LevelSel <= 0 { d.LevelSel = len(levels) - 1 } else { d.LevelSel-- @@ -67,7 +70,10 @@ func HandleTurboDashKey(d *TurboDashState, msg tea.KeyMsg, client Client) (*Turb if len(levels) == 0 { break } - if d.LevelSel < len(levels)-1 { + if d.LevelSel < 0 { + // 首次按键:跳到最后一级(最新/最高并发),与 ↑ 保持一致 + d.LevelSel = len(levels) - 1 + } else if d.LevelSel < len(levels)-1 { d.LevelSel++ } else { d.LevelSel = 0 From c4b016c58c4b12e4cda36a7de3b9bb6404c83e84 Mon Sep 17 00:00:00 2001 From: Alain Date: Mon, 18 May 2026 08:41:40 +0800 Subject: [PATCH 16/52] =?UTF-8?q?feat:=20=E5=A2=9E=E5=BC=BA=E6=8A=A5?= =?UTF-8?q?=E5=91=8A=E6=95=B0=E6=8D=AE=E7=9A=84=20JSON=20=E5=8F=8D?= =?UTF-8?q?=E5=BA=8F=E5=88=97=E5=8C=96=EF=BC=8C=E6=94=AF=E6=8C=81=E6=97=B6?= =?UTF-8?q?=E9=97=B4=E5=AD=97=E6=AE=B5=E7=9A=84=E8=A7=A3=E6=9E=90=E5=92=8C?= =?UTF-8?q?=E5=8A=A0=E8=BD=BD=20feat:=20=E6=9B=B4=E6=96=B0=E4=BB=AA?= =?UTF-8?q?=E8=A1=A8=E6=9D=BF=E4=BB=BB=E5=8A=A1=E5=8F=82=E6=95=B0=E9=9D=A2?= =?UTF-8?q?=E6=9D=BF=E6=A0=87=E9=A2=98=E4=B8=BA=E2=80=9C=E8=BF=90=E8=A1=8C?= =?UTF-8?q?=E8=BF=9B=E5=BA=A6=E2=80=9D=20feat:=20=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E6=8A=A5=E5=91=8A=E7=94=9F=E6=88=90=E9=80=BB=E8=BE=91=EF=BC=8C?= =?UTF-8?q?=E6=94=AF=E6=8C=81=E4=BB=8E=E7=A3=81=E7=9B=98=E5=BF=AB=E7=85=A7?= =?UTF-8?q?=E5=8A=A0=E8=BD=BD=E5=8E=86=E5=8F=B2=E8=BF=90=E8=A1=8C=E7=8A=B6?= =?UTF-8?q?=E6=80=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/server/run.go | 30 +++++++++++---- internal/tui/pages/dashboard.go | 2 +- internal/types/types.go | 67 +++++++++++++++++++++++++++++++++ 3 files changed, 90 insertions(+), 9 deletions(-) diff --git a/internal/server/run.go b/internal/server/run.go index 6265a2e..9a6810f 100644 --- a/internal/server/run.go +++ b/internal/server/run.go @@ -579,20 +579,34 @@ func (s *serverImpl) GetHistory(taskID string, limit int) ([]types.TaskRunSummar } // GenerateReport 为已完成的标准运行生成报告文件。 +// 先查内存中的 activeRuns,若不存在则从磁盘快照加载(支持跨 session 历史运行)。 func (s *serverImpl) GenerateReport(runID RunID, format ReportFormat) (string, error) { s.mu.RLock() ar, ok := s.activeRuns[runID] + historyDir := s.historyDir s.mu.RUnlock() - if !ok { - return "", fmt.Errorf("run %q not found", runID) - } + var status RunStatus + var mode string + var standardResult *types.ReportData - ar.mu.RLock() - status := ar.state.Status - mode := ar.state.Mode - standardResult := ar.state.StandardResult - ar.mu.RUnlock() + if ok { + ar.mu.RLock() + status = ar.state.Status + mode = ar.state.Mode + standardResult = ar.state.StandardResult + ar.mu.RUnlock() + } else { + // 不在内存中,尝试从磁盘快照加载(历史运行 / 程序重启后) + st := store.NewJSONStore[*RunState](runStatePath(historyDir, runID)) + snap, err := st.Load() + if err != nil || snap == nil { + return "", fmt.Errorf("run %q not found", runID) + } + status = snap.Status + mode = snap.Mode + standardResult = snap.StandardResult + } if status == RunStatusRunning { return "", fmt.Errorf("run %q is still in progress", runID) diff --git a/internal/tui/pages/dashboard.go b/internal/tui/pages/dashboard.go index 0e2f698..ac07c78 100644 --- a/internal/tui/pages/dashboard.go +++ b/internal/tui/pages/dashboard.go @@ -225,7 +225,7 @@ func RenderDashboard(d *DashboardState, taskName string, st Styles, width, heigh // buildDashParamsPanel 构建左侧任务参数面板。 func buildDashParamsPanel(d *DashboardState, rs *server.RunState, st Styles, maxH, width int) string { var lines []string - lines = append(lines, " "+st.SectionHead.Render("任务参数")) + lines = append(lines, " "+st.SectionHead.Render("运行进度")) lines = append(lines, "") if rs == nil { diff --git a/internal/types/types.go b/internal/types/types.go index 024af8a..dd458be 100644 --- a/internal/types/types.go +++ b/internal/types/types.go @@ -350,6 +350,73 @@ func (r *ReportData) MarshalJSON() ([]byte, error) { }) } +// UnmarshalJSON 自定义 JSON 反序列化,将字符串形式的 Duration 还原为 time.Duration。 +// 与 MarshalJSON 配对使用,确保持久化后的数据能正确加载。 +func (r *ReportData) UnmarshalJSON(data []byte) error { + type Alias ReportData + aux := &struct { + *Alias + TotalTime string `json:"total_time"` + AvgTotalTime string `json:"avg_total_time"` + MinTotalTime string `json:"min_total_time"` + MaxTotalTime string `json:"max_total_time"` + AvgDNSTime string `json:"avg_dns_time"` + MinDNSTime string `json:"min_dns_time"` + MaxDNSTime string `json:"max_dns_time"` + AvgConnectTime string `json:"avg_connect_time"` + MinConnectTime string `json:"min_connect_time"` + MaxConnectTime string `json:"max_connect_time"` + AvgTLSHandshakeTime string `json:"avg_tls_handshake_time"` + MinTLSHandshakeTime string `json:"min_tls_handshake_time"` + MaxTLSHandshakeTime string `json:"max_tls_handshake_time"` + AvgTTFT string `json:"avg_ttft"` + MinTTFT string `json:"min_ttft"` + MaxTTFT string `json:"max_ttft"` + AvgTPOT string `json:"avg_tpot"` + MinTPOT string `json:"min_tpot"` + MaxTPOT string `json:"max_tpot"` + StdDevTotalTime string `json:"stddev_total_time"` + StdDevTTFT string `json:"stddev_ttft"` + StdDevTPOT string `json:"stddev_tpot"` + }{Alias: (*Alias)(r)} + + if err := json.Unmarshal(data, aux); err != nil { + return err + } + + parseDur := func(s string) time.Duration { + if s == "" || s == "-" { + return 0 + } + d, _ := time.ParseDuration(s) + return d + } + + r.TotalTime = parseDur(aux.TotalTime) + r.AvgTotalTime = parseDur(aux.AvgTotalTime) + r.MinTotalTime = parseDur(aux.MinTotalTime) + r.MaxTotalTime = parseDur(aux.MaxTotalTime) + r.AvgDNSTime = parseDur(aux.AvgDNSTime) + r.MinDNSTime = parseDur(aux.MinDNSTime) + r.MaxDNSTime = parseDur(aux.MaxDNSTime) + r.AvgConnectTime = parseDur(aux.AvgConnectTime) + r.MinConnectTime = parseDur(aux.MinConnectTime) + r.MaxConnectTime = parseDur(aux.MaxConnectTime) + r.AvgTLSHandshakeTime = parseDur(aux.AvgTLSHandshakeTime) + r.MinTLSHandshakeTime = parseDur(aux.MinTLSHandshakeTime) + r.MaxTLSHandshakeTime = parseDur(aux.MaxTLSHandshakeTime) + r.AvgTTFT = parseDur(aux.AvgTTFT) + r.MinTTFT = parseDur(aux.MinTTFT) + r.MaxTTFT = parseDur(aux.MaxTTFT) + r.AvgTPOT = parseDur(aux.AvgTPOT) + r.MinTPOT = parseDur(aux.MinTPOT) + r.MaxTPOT = parseDur(aux.MaxTPOT) + r.StdDevTotalTime = parseDur(aux.StdDevTotalTime) + r.StdDevTTFT = parseDur(aux.StdDevTTFT) + r.StdDevTPOT = parseDur(aux.StdDevTPOT) + return nil +} + // formatTTFT 格式化 TTFT 字段,非流式模式返回 "-" func formatTTFT(duration time.Duration, isStream bool) string { if !isStream { From f81b783e6559106312bfd3a4c7f25d634f510e53 Mon Sep 17 00:00:00 2001 From: Alain Date: Tue, 19 May 2026 22:45:26 +0800 Subject: [PATCH 17/52] Refactor task management to use individual JSON files for persistence - Changed TaskStore to manage tasks in a directory instead of a single JSON file. - Implemented loading and saving of tasks as individual JSON files. - Removed the history management code and tests, as they are no longer needed. - Updated task-related types and methods to reflect the new structure. - Adjusted TUI components to work with the new task overview structure. - Added tests for task detail history entries to ensure correct behavior. --- docs/design.md | 4 +- docs/storage.md | 368 ++++++++++++++++++++++++++ internal/config/config.go | 81 ++++-- internal/config/config_test.go | 61 +++++ internal/server/run.go | 364 ++++++++++++------------- internal/server/server.go | 14 +- internal/server/server_test.go | 198 ++++++++++++-- internal/server/task.go | 78 +++++- internal/server/types.go | 30 +-- internal/store/history.go | 49 ---- internal/store/run.go | 311 ++++++++++++++++++++++ internal/store/store.go | 155 ++++++++++- internal/store/store_test.go | 77 ++++++ internal/store/task.go | 129 +++++++-- internal/task/history.go | 89 ------- internal/task/history_test.go | 41 --- internal/task/store.go | 102 ------- internal/task/store_test.go | 76 ------ internal/tui/client.go | 16 +- internal/tui/client_test.go | 30 +++ internal/tui/messages.go | 2 +- internal/tui/model.go | 4 +- internal/tui/model_test.go | 2 +- internal/tui/pages/dashboard.go | 3 +- internal/tui/pages/reqdetail.go | 13 +- internal/tui/pages/taskdetail.go | 104 +++++--- internal/tui/pages/taskdetail_test.go | 26 ++ internal/tui/pages/tasklist.go | 31 +-- internal/tui/pages/wizard.go | 2 +- internal/types/types.go | 33 ++- 30 files changed, 1745 insertions(+), 748 deletions(-) create mode 100644 docs/storage.md delete mode 100644 internal/store/history.go create mode 100644 internal/store/run.go create mode 100644 internal/store/store_test.go delete mode 100644 internal/task/history.go delete mode 100644 internal/task/history_test.go delete mode 100644 internal/task/store.go delete mode 100644 internal/task/store_test.go create mode 100644 internal/tui/client_test.go create mode 100644 internal/tui/pages/taskdetail_test.go diff --git a/docs/design.md b/docs/design.md index bb3a9eb..ea36792 100644 --- a/docs/design.md +++ b/docs/design.md @@ -198,12 +198,14 @@ type CancelFunc func() | **Server 层** | `internal/server` | 暴露业务 API;编排下层;管理运行状态;分发 Event | | **执行层** | `internal/runner` `internal/turbo` | 并发请求执行;回调推送指标;**不感知 UI** | | **协议层** | `internal/client` | OpenAI / Anthropic HTTP 客户端;不感知上层 | -| **持久化层** | `internal/task` `internal/config` | 任务 / 历史 / 配置的 JSON 文件读写 | +| **持久化层** | `internal/store` `internal/config` | `internal/store` 是下一版唯一持久化实现;`internal/config` 仅负责应用目录与路径解析 | | **渲染层** | `internal/report` | JSON / CSV / Turbo 报告渲染;纯函数,无副作用 | | **工具层** | `internal/prompt` `internal/network` `internal/logger` `internal/upload` | 公共工具,无业务依赖 | | **TUI Client** | `internal/tui` | BubbleTea 状态机;**只依赖 server.Server 接口**;渲染终端 UI | | **Web Client** _(Future)_ | `internal/webui` | HTTP/WS 桥接;**只依赖 server.Server 接口**;提供 Web API | +> 存储设计请优先参考 [docs/storage.md](storage.md)。该文档描述目标存储架构,不以当前实现与兼容层为约束。 + ### 3.4 目录结构 ``` diff --git a/docs/storage.md b/docs/storage.md new file mode 100644 index 0000000..e4d4c25 --- /dev/null +++ b/docs/storage.md @@ -0,0 +1,368 @@ +# 存储设计(精简目标方案) + +本文定义下一版存储架构的目标约束: + +- 只存业务数据。 +- 不存查询视图。 +- 不存中间态快照。 +- 不存可以稳定从底层业务数据重新聚合出来的数据副本。 + +如果某份数据既不是业务本体,也不是最终必须落库的业务结果,就不进入持久化层。 + +## 1. 只存什么 + +下一版只持久化四类数据: + +1. 配置。 +2. 任务。 +3. 请求明细。 +4. 测试结果指标。 + +除此之外,默认都不存。 + +## 2. 明确不存什么 + +以下内容不进入持久化层: + +1. 任务列表缓存。 +2. 历史列表索引。 +3. active runs 视图。 +4. recent runs 视图。 +5. last run summary 之类的冗余摘要。 +6. 运行中的快照文件。 +7. lifecycle 事件流。 +8. Turbo 级别事件日志。 +9. 报告产物清单。 +10. schema 文件。 +11. task meta、notes、tags、owner 这类当前非核心业务字段。 + +这些内容要么属于界面视图,要么属于运行时控制信息,要么可以从底层业务数据重新计算。 + +## 3. 设计原则 + +### 3.1 单份业务事实 + +同一类业务信息只保存一份: + +- 任务定义只保存在任务文件里。 +- 请求事实只保存在请求日志里。 +- 最终测试指标只保存在结果文件里。 + +### 3.2 不为查询速度存冗余副本 + +任务列表、任务历史、最近运行、当前运行,都通过扫描任务文件和运行目录实时构造。 + +本方案优先保证存储模型干净,而不是用额外索引换读取速度。 + +### 3.3 运行态不落盘 + +运行中的中间进度、实时聚合值、仪表盘卡片数据都只存在内存。 + +落盘只发生在两种时机: + +1. 请求完成,追加请求事实。 +2. 运行结束,写最终指标结果。 + +### 3.4 结果允许保留一份最终聚合 + +虽然最终指标可以从请求明细重新计算,但“最终测试指标”本身属于业务结果,因此允许保留一份最终结果文件。 + +约束是: + +- 只保留一份最终结果。 +- 不再额外保留面向不同页面的摘要副本。 + +## 4. 目录布局 + +推荐目录布局如下: + +```text +~/.ait/ + config.json + tasks/ + .json + runs/ + / + / + run.json + requests.jsonl + result.json +``` + +这就是完整持久化集合,不再额外引入 views、artifacts、meta、snapshot 等目录。 + +## 5. 核心数据模型 + +### 5.1 config.json + +只保存全局配置,例如: + +- 默认协议 +- 上次选中的任务 +- 是否保存 API Key + +不保存任何运行态信息。 + +### 5.2 tasks/.json + +每个任务只对应一个文件。 + +建议字段: + +- task_id +- name +- input +- created_at +- updated_at + +这里的 input 就是任务配置本体: + +- 协议 +- endpoint/base_url +- model +- concurrency +- count +- stream +- thinking +- turbo 配置 +- prompt 配置 +- timeout + +不再在任务文件里嵌入: + +- last_run_at +- last_run_summary +- history +- report path + +因为这些都不是任务定义本身。 + +### 5.3 runs///run.json + +run.json 只保存最小运行元数据,用来描述这次运行属于谁、何时开始、何时结束、最终状态是什么。 + +建议字段: + +- run_id +- task_id +- mode +- protocol +- model +- status +- started_at +- finished_at + +它不承担指标存储,不承担请求明细存储,也不承担列表视图存储。 + +### 5.4 runs///requests.jsonl + +这是请求级业务事实源。 + +每行一条请求记录,建议至少包含: + +- request_index +- success +- total_time +- ttft +- tps +- input_tokens +- output_tokens +- cached_tokens +- cache_hit_rate +- dns_time +- connect_time +- tls_time +- target_ip +- error_message +- request_body +- response_body + +所有请求详情、离线分析、结果复算,都以这里为准。 + +### 5.5 runs///result.json + +这是单次运行的最终业务结果。 + +建议保存: + +- 总请求数 +- 成功率 / 错误率 +- TTFT / TPS / 总耗时等最终聚合指标 +- 输入输出 token 统计 +- Turbo 模式的最终级别结果 +- 最终错误摘要(如有) + +这个文件只在运行结束后写一次。 + +它是最终对外展示的“测试结果”,而不是运行中的中间态快照。 + +## 6. 为什么不再存其他内容 + +### 6.1 不存任务历史索引 + +任务历史本质上可以通过扫描 `runs//` 下的 run 目录得到。 + +因此不再单独维护: + +- `history/.json` +- task-runs 视图文件 + +### 6.2 不存任务列表摘要 + +任务列表可以通过扫描 `tasks/` 目录构造。 + +因此不再单独维护: + +- `tasks.json` +- task-list view +- last_run_summary 类冗余字段 + +### 6.3 不存运行中快照 + +运行中快照本质上是 UI 关心的中间态,不是必须持久化的业务数据。 + +因此不再存: + +- snapshot.json +- active-runs.json +- recent-runs.json + +如果进程退出,运行中状态丢失是可接受的;真正的业务结果以已落盘的请求记录和最终 result 为准。 + +### 6.4 不存导出产物清单 + +报告是导出物,不属于核心业务数据。 + +下一版建议: + +- 报告按需生成。 +- 默认不纳入主存储目录。 +- 如果用户要导出,就直接输出到指定路径。 + +因此不再存: + +- artifacts.json +- report 路径索引 +- report availability 标志 + +### 6.5 不存额外生命周期日志 + +状态变更、停止原因、页面卡片状态等,如果不是最终 run.json 必须字段,就不单独落日志。 + +因此不再存: + +- lifecycle.jsonl +- 中间状态事件流 + +## 7. 写入模型 + +### 7.1 创建任务 + +只写一个任务文件: + +1. 写 `tasks/.json` + +### 7.2 更新任务 + +只覆盖任务文件: + +1. 覆盖 `tasks/.json` + +### 7.3 启动运行 + +启动运行时只写最小运行元数据: + +1. 创建 `runs///` +2. 写 `run.json`,状态为 running + +不写摘要,不写视图,不写任务回填字段。 + +### 7.4 运行中 + +每完成一个请求: + +1. 追加一条到 `requests.jsonl` + +不写快照,不写任务历史,不写任务摘要。 + +### 7.5 运行结束 + +运行结束时: + +1. 更新 `run.json` 的最终状态和结束时间 +2. 写 `result.json` + +到此结束,不做额外索引更新。 + +## 8. 读取模型 + +### 8.1 任务列表 + +通过扫描 `tasks/` 目录得到。 + +### 8.2 任务历史 + +通过扫描 `runs//` 下的所有 run 目录得到。 + +历史卡片需要的展示字段,来自: + +- `run.json` +- `result.json` + +### 8.3 单次运行详情 + +单次运行详情只读三类文件: + +1. `run.json` +2. `requests.jsonl` +3. `result.json` + +没有 fallback 文件,没有额外索引文件。 + +## 9. Repository 建议 + +下一版 `internal/store` 可以收敛成最小集合: + +```text +internal/store/ + fs.go # 原子写、JSON/JSONL、目录工具 + config_repo.go # config.json + task_repo.go # tasks/.json + run_repo.go # run.json / result.json + request_log.go # requests.jsonl append/read +``` + +不再需要: + +- view_repo +- projector +- artifact_repo +- rebuild +- history store +- task summary store + +## 10. 代价与取舍 + +这个方案的代价很明确: + +1. 任务列表和任务历史读取时需要扫描目录。 +2. 运行中的跨进程恢复能力会变弱。 +3. 某些 UI 页面首次读取会比“预存视图”慢。 + +但换来的好处更符合目标: + +1. 存储模型极简。 +2. 不再维护多份互相覆盖的摘要。 +3. 高低频写路径都更清晰。 +4. 数据结构更接近业务本体。 + +## 11. 一句话总结 + +下一版只存: + +- 配置 +- 任务 +- 请求明细 +- 最终测试指标 + +其他一律不存;需要展示时,从这四类业务数据实时聚合。 diff --git a/internal/config/config.go b/internal/config/config.go index c989c08..ecb2fae 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,17 +1,20 @@ package config import ( - "encoding/json" - "errors" "os" "path/filepath" + + storepkg "github.com/yinxulai/ait/internal/store" ) const ( appDirName = ".ait" configJSON = "config.json" - tasksJSON = "tasks.json" - historyDirName = "history" + tasksDirName = "tasks" + runsDirName = "runs" + runMetaJSON = "run.json" + runResultJSON = "result.json" + runReqsJSONL = "requests.jsonl" ) type Config struct { @@ -26,19 +29,11 @@ func Load() (*Config, error) { return nil, err } - data, err := os.ReadFile(path) - if errors.Is(err, os.ErrNotExist) { - return &Config{}, nil - } + loaded, err := storepkg.NewJSONStore[Config](path).Load() if err != nil { return nil, err } - - var cfg Config - if err := json.Unmarshal(data, &cfg); err != nil { - return nil, err - } - return &cfg, nil + return &loaded, nil } func (c *Config) Save() error { @@ -46,15 +41,7 @@ func (c *Config) Save() error { if err != nil { return err } - if _, err := EnsureAppDir(); err != nil { - return err - } - - data, err := json.MarshalIndent(c, "", " ") - if err != nil { - return err - } - return os.WriteFile(path, data, 0o644) + return storepkg.NewJSONStore[Config](path).Save(*c) } func AppDir() (string, error) { @@ -84,18 +71,58 @@ func ConfigPath() (string, error) { return filepath.Join(dir, configJSON), nil } -func TasksPath() (string, error) { +func TasksDir() (string, error) { dir, err := AppDir() if err != nil { return "", err } - return filepath.Join(dir, tasksJSON), nil + return filepath.Join(dir, tasksDirName), nil } -func HistoryDir() (string, error) { +func RunsDir() (string, error) { dir, err := AppDir() if err != nil { return "", err } - return filepath.Join(dir, historyDirName), nil + return filepath.Join(dir, runsDirName), nil +} + +func TaskPath(taskID string) (string, error) { + dir, err := TasksDir() + if err != nil { + return "", err + } + return filepath.Join(dir, taskID+".json"), nil +} + +func RunDir(taskID, runID string) (string, error) { + dir, err := RunsDir() + if err != nil { + return "", err + } + return filepath.Join(dir, taskID, runID), nil +} + +func RunMetadataPath(taskID, runID string) (string, error) { + dir, err := RunDir(taskID, runID) + if err != nil { + return "", err + } + return filepath.Join(dir, runMetaJSON), nil +} + +func RunResultPath(taskID, runID string) (string, error) { + dir, err := RunDir(taskID, runID) + if err != nil { + return "", err + } + return filepath.Join(dir, runResultJSON), nil +} + +func RunRequestsPath(taskID, runID string) (string, error) { + dir, err := RunDir(taskID, runID) + if err != nil { + return "", err + } + return filepath.Join(dir, runReqsJSONL), nil } diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 622c961..657f2b8 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -49,3 +49,64 @@ func TestConfigSaveAndLoadRoundTrip(t *testing.T) { t.Fatalf("unexpected loaded config: %+v", loaded) } } + +func TestStoragePaths(t *testing.T) { + homeDir := t.TempDir() + t.Setenv("HOME", homeDir) + + tasksDir, err := TasksDir() + if err != nil { + t.Fatalf("TasksDir() returned unexpected error: %v", err) + } + if want := filepath.Join(homeDir, ".ait", "tasks"); tasksDir != want { + t.Fatalf("expected tasks dir %s, got %s", want, tasksDir) + } + + runsDir, err := RunsDir() + if err != nil { + t.Fatalf("RunsDir() returned unexpected error: %v", err) + } + if want := filepath.Join(homeDir, ".ait", "runs"); runsDir != want { + t.Fatalf("expected runs dir %s, got %s", want, runsDir) + } + + taskPath, err := TaskPath("task-1") + if err != nil { + t.Fatalf("TaskPath() returned unexpected error: %v", err) + } + if want := filepath.Join(homeDir, ".ait", "tasks", "task-1.json"); taskPath != want { + t.Fatalf("expected task path %s, got %s", want, taskPath) + } + + runDir, err := RunDir("task-1", "run-1") + if err != nil { + t.Fatalf("RunDir() returned unexpected error: %v", err) + } + if want := filepath.Join(homeDir, ".ait", "runs", "task-1", "run-1"); runDir != want { + t.Fatalf("expected run dir %s, got %s", want, runDir) + } + + runMetaPath, err := RunMetadataPath("task-1", "run-1") + if err != nil { + t.Fatalf("RunMetadataPath() returned unexpected error: %v", err) + } + if want := filepath.Join(homeDir, ".ait", "runs", "task-1", "run-1", "run.json"); runMetaPath != want { + t.Fatalf("expected run metadata path %s, got %s", want, runMetaPath) + } + + runResultPath, err := RunResultPath("task-1", "run-1") + if err != nil { + t.Fatalf("RunResultPath() returned unexpected error: %v", err) + } + if want := filepath.Join(homeDir, ".ait", "runs", "task-1", "run-1", "result.json"); runResultPath != want { + t.Fatalf("expected run result path %s, got %s", want, runResultPath) + } + + runRequestsPath, err := RunRequestsPath("task-1", "run-1") + if err != nil { + t.Fatalf("RunRequestsPath() returned unexpected error: %v", err) + } + if want := filepath.Join(homeDir, ".ait", "runs", "task-1", "run-1", "requests.jsonl"); runRequestsPath != want { + t.Fatalf("expected run requests path %s, got %s", want, runRequestsPath) + } +} diff --git a/internal/server/run.go b/internal/server/run.go index 9a6810f..bc4491c 100644 --- a/internal/server/run.go +++ b/internal/server/run.go @@ -1,12 +1,7 @@ package server import ( - "bufio" - "encoding/json" "fmt" - "os" - "path/filepath" - "sort" "sync" "sync/atomic" "time" @@ -54,7 +49,7 @@ func (ar *activeRun) snapshotState() *RunState { snap := *s // 深拷贝切片 if len(s.Requests) > 0 { - snap.Requests = make([]*RequestMetrics, len(s.Requests)) + snap.Requests = make([]*types.RequestMetrics, len(s.Requests)) copy(snap.Requests, s.Requests) } if len(s.Levels) > 0 { @@ -64,9 +59,9 @@ func (ar *activeRun) snapshotState() *RunState { return &snap } -// mapRequestMetrics 将 client.ResponseMetrics 映射到 server.RequestMetrics。 -func mapRequestMetrics(m *client.ResponseMetrics, idx int, err error) *RequestMetrics { - rm := &RequestMetrics{Index: idx} +// mapRequestMetrics 将 client.ResponseMetrics 映射到 types.RequestMetrics。 +func mapRequestMetrics(m *client.ResponseMetrics, idx int, err error) *types.RequestMetrics { + rm := &types.RequestMetrics{Index: idx} if m == nil { rm.Success = false if err != nil { @@ -101,82 +96,132 @@ func mapRequestMetrics(m *client.ResponseMetrics, idx int, err error) *RequestMe return rm } -// historyPath 返回指定任务的历史文件路径。 -func historyPath(historyDir, taskID string) string { - return filepath.Join(historyDir, taskID+".json") -} - -// runStatePath 返回指定运行的完整状态快照文件路径(用于历史回放)。 -func runStatePath(historyDir string, runID RunID) string { - return filepath.Join(historyDir, "runs", string(runID)+".json") +func requestPointers(requests []types.RequestMetrics) []*types.RequestMetrics { + if len(requests) == 0 { + return nil + } + pointers := make([]*types.RequestMetrics, 0, len(requests)) + for i := range requests { + request := requests[i] + pointers = append(pointers, &request) + } + return pointers } -// requestsFilePath 返回指定运行的请求详情 JSONL 文件路径。 -func requestsFilePath(historyDir string, runID RunID) string { - return filepath.Join(historyDir, "runs", string(runID)+".jsonl") +func buildStoredRunMetadata(taskDef types.TaskDefinition, snap *RunState) store.RunMetadata { + var finishedAt *time.Time + if snap.FinishedAt != nil { + finished := *snap.FinishedAt + finishedAt = &finished + } + return store.RunMetadata{ + RunID: string(snap.RunID), + TaskID: snap.TaskID, + Mode: snap.Mode, + Protocol: taskDef.Input.NormalizedProtocol(), + Model: taskDef.Input.Model, + Status: string(snap.Status), + StartedAt: snap.StartedAt, + FinishedAt: finishedAt, + } } -// appendRequestToDisk 将单条 RequestMetrics 以 JSON 行的形式追加写入磁盘。 -func appendRequestToDisk(historyDir string, runID RunID, rm *RequestMetrics) { - path := requestsFilePath(historyDir, runID) - f, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) - if err != nil { - return - } - defer f.Close() - data, err := json.Marshal(rm) - if err != nil { - return - } - _, _ = f.Write(data) - _, _ = f.Write([]byte{'\n'}) +func buildStoredRunResult(snap *RunState) store.RunResult { + result := store.RunResult{ + TotalReqs: snap.TotalReqs, + DoneReqs: snap.DoneReqs, + SuccessReqs: snap.SuccessReqs, + FailedReqs: snap.FailedReqs, + SuccessRate: snap.SuccessRate, + AvgTTFT: snap.AvgTTFT, + AvgTPS: snap.AvgTPS, + CacheHitRate: snap.CacheHitRate, + ErrorSummary: snap.ErrorMsg, + StandardResult: snap.StandardResult, + TurboResult: snap.TurboResult, + } + if snap.TurboResult != nil { + result.MaxStableConcurrency = snap.TurboResult.MaxStableConcurrency + } else if snap.CurrentLevel > 0 { + result.MaxStableConcurrency = snap.CurrentLevel + } + return result } -// loadRequestsFromDisk 从 JSONL 文件中加载所有 RequestMetrics,按 Index 排序。 -func loadRequestsFromDisk(historyDir string, runID RunID) []*RequestMetrics { - path := requestsFilePath(historyDir, runID) - f, err := os.Open(path) - if err != nil { +func buildRunStateFromStoredRun(run *store.StoredRun, requests []*types.RequestMetrics) *RunState { + if run == nil { return nil } - defer f.Close() - - const maxLineSize = 16 * 1024 * 1024 // 16 MB per line - buf := make([]byte, maxLineSize) - scanner := bufio.NewScanner(f) - scanner.Buffer(buf, maxLineSize) - var reqs []*RequestMetrics - for scanner.Scan() { - line := scanner.Bytes() - if len(line) == 0 { - continue - } - var rm RequestMetrics - if err := json.Unmarshal(line, &rm); err != nil { - continue - } - reqs = append(reqs, &rm) - } - sort.Slice(reqs, func(i, j int) bool { - return reqs[i].Index < reqs[j].Index - }) - return reqs + state := &RunState{ + RunID: RunID(run.Metadata.RunID), + TaskID: run.Metadata.TaskID, + Status: RunStatus(run.Metadata.Status), + Mode: run.Metadata.Mode, + StartedAt: run.Metadata.StartedAt, + Requests: requests, + } + if run.Metadata.FinishedAt != nil { + finished := *run.Metadata.FinishedAt + state.FinishedAt = &finished + } + if run.Result == nil { + return state + } + + state.TotalReqs = run.Result.TotalReqs + state.DoneReqs = run.Result.DoneReqs + state.SuccessReqs = run.Result.SuccessReqs + state.FailedReqs = run.Result.FailedReqs + state.SuccessRate = run.Result.SuccessRate + state.AvgTTFT = run.Result.AvgTTFT + state.AvgTPS = run.Result.AvgTPS + state.CacheHitRate = run.Result.CacheHitRate + state.StandardResult = run.Result.StandardResult + state.TurboResult = run.Result.TurboResult + state.ErrorMsg = run.Result.ErrorSummary + state.CurrentLevel = run.Result.MaxStableConcurrency + if state.DoneReqs == 0 && len(requests) > 0 { + state.DoneReqs = len(requests) + } + if state.TotalReqs == 0 && len(requests) > 0 { + state.TotalReqs = len(requests) + } + if run.Result.TurboResult != nil { + state.Levels = run.Result.TurboResult.Levels + state.CurrentLevel = run.Result.TurboResult.MaxStableConcurrency + } + return state } -// persistRunState 将 RunState 元数据写入磁盘(不含 Requests,请求详情在 JSONL 文件中)。 -func persistRunState(historyDir string, snap *RunState) { - toSave := *snap - toSave.Requests = nil // 请求详情已逐条写入 JSONL,避免重复存储 - st := store.NewJSONStore[*RunState](runStatePath(historyDir, snap.RunID)) - _ = st.Save(&toSave) // 失败不影响主流程 +func buildRunningRunSummary(taskDef types.TaskDefinition, snap *RunState) types.TaskRunSummary { + summary := types.TaskRunSummary{ + RunID: string(snap.RunID), + TaskID: taskDef.ID, + Mode: snap.Mode, + Status: string(snap.Status), + Protocol: taskDef.Input.NormalizedProtocol(), + Model: taskDef.Input.Model, + StartedAt: snap.StartedAt, + SuccessRate: snap.SuccessRate, + AvgTTFT: snap.AvgTTFT, + AvgTPS: snap.AvgTPS, + CacheHitRate: snap.CacheHitRate, + } + if snap.FinishedAt != nil { + summary.FinishedAt = *snap.FinishedAt + } + if snap.ErrorMsg != "" { + summary.ErrorSummary = snap.ErrorMsg + } + return summary } // StartRun 启动一次新的运行,立即返回 RunID。 func (s *serverImpl) StartRun(taskID string) (RunID, error) { s.mu.RLock() taskDef, ok := s.taskStore.Get(taskID) - historyDir := s.historyDir + runStore := s.runStore s.mu.RUnlock() if !ok { @@ -219,19 +264,19 @@ func (s *serverImpl) StartRun(taskID string) (RunID, error) { s.mu.Unlock() if hydratedInput.Turbo { - go s.runTurbo(ar, runID, taskDef, hydratedInput, historyDir) + go s.runTurbo(ar, runID, taskDef, hydratedInput, runStore) } else { - go s.runStandard(ar, runID, taskDef, hydratedInput, historyDir) + go s.runStandard(ar, runID, taskDef, hydratedInput, runStore) } return runID, nil } // runStandard 在 goroutine 中执行标准运行。 -func (s *serverImpl) runStandard(ar *activeRun, runID RunID, taskDef types.TaskDefinition, input types.Input, historyDir string) { +func (s *serverImpl) runStandard(ar *activeRun, runID RunID, taskDef types.TaskDefinition, input types.Input, runStore *store.RunStore) { rnr, err := runner.NewRunner(taskDef.ID, input) if err != nil { - s.failRun(ar, runID, taskDef, historyDir, err) + s.failRun(ar, runID, taskDef, runStore, err) return } @@ -259,7 +304,7 @@ func (s *serverImpl) runStandard(ar *activeRun, runID RunID, taskDef types.TaskD reportData, err := rnr.RunWithCallback(func(metrics *client.ResponseMetrics, idx int, cbErr error) { rm := mapRequestMetrics(metrics, idx, cbErr) - appendRequestToDisk(historyDir, runID, rm) + _ = runStore.AppendRequest(taskDef.ID, string(runID), *rm) ar.mu.Lock() ar.state.Requests = append(ar.state.Requests, rm) @@ -292,15 +337,15 @@ func (s *serverImpl) runStandard(ar *activeRun, runID RunID, taskDef types.TaskD close(stopTick) if err != nil { - s.failRun(ar, runID, taskDef, historyDir, err) + s.failRun(ar, runID, taskDef, runStore, err) return } - s.completeStandardRun(ar, runID, taskDef, historyDir, reportData) + s.completeStandardRun(ar, runID, taskDef, runStore, reportData) } // runTurbo 在 goroutine 中执行 Turbo 运行。 -func (s *serverImpl) runTurbo(ar *activeRun, runID RunID, taskDef types.TaskDefinition, input types.Input, historyDir string) { +func (s *serverImpl) runTurbo(ar *activeRun, runID RunID, taskDef types.TaskDefinition, input types.Input, runStore *store.RunStore) { // 全局请求计数器(原子递增),确保跨多个并发级别的请求索引唯一 var globalIdx int64 @@ -314,7 +359,7 @@ func (s *serverImpl) runTurbo(ar *activeRun, runID RunID, taskDef types.TaskDefi cb: func(metrics *client.ResponseMetrics, _ int, cbErr error) { gIdx := int(atomic.AddInt64(&globalIdx, 1)) - 1 rm := mapRequestMetrics(metrics, gIdx, cbErr) - appendRequestToDisk(historyDir, runID, rm) + _ = runStore.AppendRequest(taskDef.ID, string(runID), *rm) ar.mu.Lock() ar.state.Requests = append(ar.state.Requests, rm) @@ -352,15 +397,15 @@ func (s *serverImpl) runTurbo(ar *activeRun, runID RunID, taskDef types.TaskDefi turboResult, err := engine.Run(input) if err != nil { - s.failRun(ar, runID, taskDef, historyDir, err) + s.failRun(ar, runID, taskDef, runStore, err) return } - s.completeTurboRun(ar, runID, taskDef, historyDir, turboResult) + s.completeTurboRun(ar, runID, taskDef, runStore, turboResult) } // completeStandardRun 处理标准运行成功完成的后续工作。 -func (s *serverImpl) completeStandardRun(ar *activeRun, runID RunID, taskDef types.TaskDefinition, historyDir string, data *types.ReportData) { +func (s *serverImpl) completeStandardRun(ar *activeRun, runID RunID, taskDef types.TaskDefinition, runStore *store.RunStore, data *types.ReportData) { finishedAt := time.Now() ar.mu.Lock() @@ -378,35 +423,13 @@ func (s *serverImpl) completeStandardRun(ar *activeRun, runID RunID, taskDef typ s.bus.Publish(Event{RunID: runID, Kind: EventRunComplete, Payload: snap}) s.bus.CloseRun(runID) - - // 将元数据写入磁盘(Requests 已造话就写 JSONL,此处仅写 RunState 元数据) - persistRunState(historyDir, snap) - - // 运行已完成且数据已落盘,释放内存中的请求详情 - ar.mu.Lock() - ar.state.Requests = nil - ar.mu.Unlock() - - summary := types.TaskRunSummary{ - RunID: string(runID), - TaskID: taskDef.ID, - Mode: "standard", - Status: string(RunStatusCompleted), - Protocol: taskDef.Input.NormalizedProtocol(), - Model: taskDef.Input.Model, - StartedAt: snap.StartedAt, - FinishedAt: finishedAt, - SuccessRate: snap.SuccessRate, - AvgTTFT: snap.AvgTTFT, - AvgTPS: snap.AvgTPS, - CacheHitRate: snap.CacheHitRate, + if err := s.persistFinalRun(runStore, taskDef, snap); err == nil { + s.removeActiveRun(runID) } - - s.persistRunResult(taskDef.ID, historyDir, summary) } // completeTurboRun 处理 Turbo 运行成功完成的后续工作。 -func (s *serverImpl) completeTurboRun(ar *activeRun, runID RunID, taskDef types.TaskDefinition, historyDir string, result *types.TurboResult) { +func (s *serverImpl) completeTurboRun(ar *activeRun, runID RunID, taskDef types.TaskDefinition, runStore *store.RunStore, result *types.TurboResult) { finishedAt := time.Now() ar.mu.Lock() @@ -422,40 +445,13 @@ func (s *serverImpl) completeTurboRun(ar *activeRun, runID RunID, taskDef types. s.bus.Publish(Event{RunID: runID, Kind: EventRunComplete, Payload: snap}) s.bus.CloseRun(runID) - - // 将元数据写入磁盘(Requests 已造话就写 JSONL,此处仅写 RunState 元数据) - persistRunState(historyDir, snap) - - // 运行已完成且数据已落盘,释放内存中的请求详情 - ar.mu.Lock() - ar.state.Requests = nil - ar.mu.Unlock() - - var maxStable int - var peakTPS float64 - if result != nil { - maxStable = result.MaxStableConcurrency - peakTPS = result.PeakTPS + if err := s.persistFinalRun(runStore, taskDef, snap); err == nil { + s.removeActiveRun(runID) } - - summary := types.TaskRunSummary{ - RunID: string(runID), - TaskID: taskDef.ID, - Mode: "turbo", - Status: string(RunStatusCompleted), - Protocol: taskDef.Input.NormalizedProtocol(), - Model: taskDef.Input.Model, - StartedAt: snap.StartedAt, - FinishedAt: finishedAt, - MaxStableConcurrency: maxStable, - AvgTPS: peakTPS, - } - - s.persistRunResult(taskDef.ID, historyDir, summary) } // failRun 处理运行失败的后续工作。 -func (s *serverImpl) failRun(ar *activeRun, runID RunID, taskDef types.TaskDefinition, historyDir string, runErr error) { +func (s *serverImpl) failRun(ar *activeRun, runID RunID, taskDef types.TaskDefinition, runStore *store.RunStore, runErr error) { finishedAt := time.Now() ar.mu.Lock() @@ -467,46 +463,23 @@ func (s *serverImpl) failRun(ar *activeRun, runID RunID, taskDef types.TaskDefin s.bus.Publish(Event{RunID: runID, Kind: EventRunFailed, Payload: snap}) s.bus.CloseRun(runID) - - // 将元数据写入磁盘(Requests 已造话就写 JSONL,此处仅写 RunState 元数据) - persistRunState(historyDir, snap) - - // 运行已完成且数据已落盘,释放内存中的请求详情 - ar.mu.Lock() - ar.state.Requests = nil - ar.mu.Unlock() - - summary := types.TaskRunSummary{ - RunID: string(runID), - TaskID: taskDef.ID, - Mode: ar.state.Mode, - Status: string(RunStatusFailed), - Protocol: taskDef.Input.NormalizedProtocol(), - Model: taskDef.Input.Model, - StartedAt: snap.StartedAt, - FinishedAt: finishedAt, - ErrorSummary: runErr.Error(), + if err := s.persistFinalRun(runStore, taskDef, snap); err == nil { + s.removeActiveRun(runID) } +} - s.persistRunResult(taskDef.ID, historyDir, summary) +func (s *serverImpl) persistFinalRun(runStore *store.RunStore, taskDef types.TaskDefinition, snap *RunState) error { + return runStore.SaveFinal(buildStoredRunMetadata(taskDef, snap), buildStoredRunResult(snap)) } -// persistRunResult 将运行摘要写入历史文件,并更新任务的 LastRunAt/LastRunSummary。 -func (s *serverImpl) persistRunResult(taskID, historyDir string, summary types.TaskRunSummary) { - // 写历史文件 - hs := store.NewHistoryStore(historyPath(historyDir, taskID)) - _ = hs.Append(summary) // 历史记录失败不影响主流程 +func (s *serverImpl) persistRunResult(summary types.TaskRunSummary) error { + return s.runStore.SaveSummary(summary) +} - // 更新任务的最后运行时间和摘要 +func (s *serverImpl) removeActiveRun(runID RunID) { s.mu.Lock() defer s.mu.Unlock() - existing, ok := s.taskStore.Get(taskID) - if ok { - existing.LastRunAt = &summary.FinishedAt - existing.LastRunSummary = &summary - s.taskStore.Upsert(existing) - _ = s.taskStore.Save() - } + delete(s.activeRuns, runID) } // StopRun 请求停止指定运行。 @@ -534,33 +507,29 @@ func (s *serverImpl) StopRun(runID RunID) error { } // GetRunState 返回指定运行的当前状态快照。 -// 先查内存中的 activeRuns;若不存在,再尝试从磁盘加载持久化的快照(历史回放)。 +// 先查内存中的 activeRuns;若不存在,再尝试从磁盘加载最终运行结果(历史回放)。 func (s *serverImpl) GetRunState(runID RunID) (*RunState, bool) { s.mu.RLock() ar, ok := s.activeRuns[runID] - historyDir := s.historyDir + runStore := s.runStore s.mu.RUnlock() if ok { ar.mu.RLock() snap := ar.snapshotState() ar.mu.RUnlock() - // 运行已完成且 Requests 已从内存清除(节省内存),从 JSONL 文件补充加载 - if snap.Status != RunStatusRunning && snap.Requests == nil { - snap.Requests = loadRequestsFromDisk(historyDir, snap.RunID) - } return snap, true } - // 不在内存中,尝试从磁盘加载持久化的 RunState 快照 - st := store.NewJSONStore[*RunState](runStatePath(historyDir, runID)) - snap, err := st.Load() - if err != nil || snap == nil { + run, err := runStore.LoadByRunID(string(runID)) + if err != nil || run == nil { return nil, false } - // 从 JSONL 文件加载请求详情 - snap.Requests = loadRequestsFromDisk(historyDir, runID) - return snap, true + requests, err := runStore.LoadRequests(run.Metadata.TaskID, string(runID)) + if err != nil { + return nil, false + } + return buildRunStateFromStoredRun(run, requestPointers(requests)), true } // Subscribe 订阅指定运行的事件流。 @@ -570,20 +539,23 @@ func (s *serverImpl) Subscribe(runID RunID) (<-chan Event, CancelFunc) { // GetHistory 返回任务的历史运行摘要,最新在前。 func (s *serverImpl) GetHistory(taskID string, limit int) ([]types.TaskRunSummary, error) { - s.mu.RLock() - historyDir := s.historyDir - s.mu.RUnlock() - - hs := store.NewHistoryStore(historyPath(historyDir, taskID)) - return hs.Load(limit) + runs, err := s.runStore.ListByTask(taskID, limit) + if err != nil { + return nil, err + } + history := make([]types.TaskRunSummary, 0, len(runs)) + for _, run := range runs { + history = append(history, run.Summary()) + } + return history, nil } // GenerateReport 为已完成的标准运行生成报告文件。 -// 先查内存中的 activeRuns,若不存在则从磁盘快照加载(支持跨 session 历史运行)。 +// 先查内存中的 activeRuns,若不存在则从最终结果文件加载(支持跨 session 历史运行)。 func (s *serverImpl) GenerateReport(runID RunID, format ReportFormat) (string, error) { s.mu.RLock() ar, ok := s.activeRuns[runID] - historyDir := s.historyDir + runStore := s.runStore s.mu.RUnlock() var status RunStatus @@ -597,15 +569,15 @@ func (s *serverImpl) GenerateReport(runID RunID, format ReportFormat) (string, e standardResult = ar.state.StandardResult ar.mu.RUnlock() } else { - // 不在内存中,尝试从磁盘快照加载(历史运行 / 程序重启后) - st := store.NewJSONStore[*RunState](runStatePath(historyDir, runID)) - snap, err := st.Load() - if err != nil || snap == nil { + run, err := runStore.LoadByRunID(string(runID)) + if err != nil || run == nil { return "", fmt.Errorf("run %q not found", runID) } - status = snap.Status - mode = snap.Mode - standardResult = snap.StandardResult + status = RunStatus(run.Metadata.Status) + mode = run.Metadata.Mode + if run.Result != nil { + standardResult = run.Result.StandardResult + } } if status == RunStatusRunning { diff --git a/internal/server/server.go b/internal/server/server.go index ddd0d83..d095cbf 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -12,7 +12,7 @@ import ( // 所有方法均为线程安全。 type Server interface { // --- 任务 CRUD --- - ListTasks() []types.TaskDefinition + ListTasks() []TaskOverview GetTask(id string) (types.TaskDefinition, bool) CreateTask(cfg TaskConfig) (types.TaskDefinition, error) UpdateTask(id string, cfg TaskConfig) (types.TaskDefinition, error) @@ -46,37 +46,37 @@ type Server interface { type serverImpl struct { mu sync.RWMutex taskStore *store.TaskStore + runStore *store.RunStore bus *eventBus activeRuns map[RunID]*activeRun - historyDir string } // New 创建并初始化 Server 实例。 -// 会自动加载 ~/.ait/tasks.json;historyDir 用于存放每个任务的运行历史文件。 +// 会自动加载 ~/.ait/tasks/ 与 ~/.ait/runs/ 下的业务数据。 func New() (Server, error) { if _, err := config.EnsureAppDir(); err != nil { return nil, err } - tasksPath, err := config.TasksPath() + tasksDir, err := config.TasksDir() if err != nil { return nil, err } - historyDir, err := config.HistoryDir() + runsDir, err := config.RunsDir() if err != nil { return nil, err } - ts := store.NewTaskStore(tasksPath) + ts := store.NewTaskStore(tasksDir) if err := ts.Load(); err != nil { return nil, err } return &serverImpl{ taskStore: ts, + runStore: store.NewRunStore(runsDir), bus: newEventBus(), activeRuns: make(map[RunID]*activeRun), - historyDir: historyDir, }, nil } diff --git a/internal/server/server_test.go b/internal/server/server_test.go index f5f946c..0b1a52d 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -18,19 +18,20 @@ import ( func newTestServer(t *testing.T) *serverImpl { t.Helper() dir := t.TempDir() - historyDir := filepath.Join(dir, "history") - if err := os.MkdirAll(historyDir, 0o755); err != nil { - t.Fatalf("mkdir history: %v", err) + tasksDir := filepath.Join(dir, "tasks") + runsDir := filepath.Join(dir, "runs") + if err := os.MkdirAll(runsDir, 0o755); err != nil { + t.Fatalf("mkdir runs: %v", err) } - ts := store.NewTaskStore(filepath.Join(dir, "tasks.json")) + ts := store.NewTaskStore(tasksDir) if err := ts.Load(); err != nil { t.Fatalf("load task store: %v", err) } return &serverImpl{ taskStore: ts, + runStore: store.NewRunStore(runsDir), bus: newEventBus(), activeRuns: make(map[RunID]*activeRun), - historyDir: historyDir, } } @@ -292,16 +293,16 @@ func TestMapRequestMetrics_ZeroPromptTokensSkipsCacheHitRate(t *testing.T) { // ── snapshotState ───────────────────────────────────────────────────────────── func TestSnapshotState_DeepCopiesRequests(t *testing.T) { - original := &RequestMetrics{Index: 0, Success: true} + original := &types.RequestMetrics{Index: 0, Success: true} ar := &activeRun{ state: &RunState{ - Requests: []*RequestMetrics{original}, + Requests: []*types.RequestMetrics{original}, }, } snap := ar.snapshotState() // Mutate original slice — snapshot must remain unchanged. - ar.state.Requests[0] = &RequestMetrics{Index: 99} + ar.state.Requests[0] = &types.RequestMetrics{Index: 99} if snap.Requests[0].Index != 0 { t.Error("Requests slice was not deep-copied: snapshot reflects mutation of original") } @@ -337,13 +338,93 @@ func TestSnapshotState_EmptySlicesNotCopied(t *testing.T) { } } -// ── historyPath ─────────────────────────────────────────────────────────────── +func TestAppendRequestToDisk_CreatesParentDirectory(t *testing.T) { + s := newTestServer(t) + taskID := "task-1" + runID := RunID("run_disk_append") + req := types.RequestMetrics{Index: 0, Success: true, TotalTime: time.Second, TTFT: 100 * time.Millisecond, TPS: 12.5} + + if err := s.runStore.AppendRequest(taskID, string(runID), req); err != nil { + t.Fatalf("AppendRequest() returned unexpected error: %v", err) + } + + reqs, err := s.runStore.LoadRequests(taskID, string(runID)) + if err != nil { + t.Fatalf("LoadRequests() returned unexpected error: %v", err) + } + if len(reqs) != 1 { + t.Fatalf("expected 1 request loaded from disk, got %d", len(reqs)) + } + if reqs[0].Index != req.Index { + t.Errorf("Index: got %d, want %d", reqs[0].Index, req.Index) + } + if reqs[0].TPS != req.TPS { + t.Errorf("TPS: got %v, want %v", reqs[0].TPS, req.TPS) + } +} -func TestHistoryPath(t *testing.T) { - got := historyPath("/data/history", "task-abc") - want := filepath.Join("/data/history", "task-abc.json") - if got != want { - t.Errorf("historyPath: got %q, want %q", got, want) +func TestGetRunState_LoadsCompletedRunFromDisk(t *testing.T) { + s := newTestServer(t) + runID := RunID("run_disk_result") + taskID := "task-1" + startedAt := time.Now().Add(-2 * time.Second).UTC().Truncate(time.Second) + finishedAt := time.Now().UTC().Truncate(time.Second) + + if err := s.runStore.SaveFinal(store.RunMetadata{ + RunID: string(runID), + TaskID: taskID, + Mode: "standard", + Protocol: types.ProtocolOpenAICompletions, + Model: "test-model", + Status: string(RunStatusCompleted), + StartedAt: startedAt, + FinishedAt: &finishedAt, + }, store.RunResult{ + TotalReqs: 4, + DoneReqs: 1, + SuccessReqs: 1, + AvgTPS: 18.5, + AvgTTFT: 120 * time.Millisecond, + SuccessRate: 25, + CacheHitRate: 0.4, + ErrorSummary: "", + StandardResult: &types.ReportData{TotalRequests: 4, AvgTPS: 18.5, AvgTTFT: 120 * time.Millisecond, SuccessRate: 25}, + }); err != nil { + t.Fatalf("SaveFinal() returned unexpected error: %v", err) + } + if err := s.runStore.AppendRequest(taskID, string(runID), types.RequestMetrics{ + Index: 0, + Success: true, + TotalTime: time.Second, + TTFT: 120 * time.Millisecond, + TPS: 18.5, + PromptTokens: 100, + CompletionTokens: 18, + }); err != nil { + t.Fatalf("AppendRequest() returned unexpected error: %v", err) + } + + state, ok := s.GetRunState(runID) + if !ok { + t.Fatal("expected completed run to load from disk") + } + if state.Status != RunStatusCompleted { + t.Fatalf("Status: got %q, want %q", state.Status, RunStatusCompleted) + } + if state.TaskID != taskID { + t.Errorf("TaskID: got %q, want %q", state.TaskID, taskID) + } + if state.DoneReqs != 1 { + t.Errorf("DoneReqs: got %d, want 1", state.DoneReqs) + } + if len(state.Requests) != 1 { + t.Fatalf("expected 1 request in loaded state, got %d", len(state.Requests)) + } + if state.Requests[0].Index != 0 { + t.Errorf("request index: got %d, want 0", state.Requests[0].Index) + } + if state.Requests[0].TTFT != 120*time.Millisecond { + t.Errorf("request TTFT: got %v, want 120ms", state.Requests[0].TTFT) } } @@ -576,7 +657,9 @@ func TestGetHistory_PersistsAfterRun(t *testing.T) { StartedAt: time.Now().Add(-time.Second), FinishedAt: time.Now(), } - s.persistRunResult(task.ID, s.historyDir, summary) + if err := s.persistRunResult(summary); err != nil { + t.Fatalf("persistRunResult: %v", err) + } history, err := s.GetHistory(task.ID, 0) if err != nil { @@ -590,17 +673,98 @@ func TestGetHistory_PersistsAfterRun(t *testing.T) { } } +func TestGetTask_DerivesRunningSummaryFromActiveRun(t *testing.T) { + s := newTestServer(t) + task, _ := s.CreateTask(makeTaskConfig("running-task")) + startedAt := time.Now().Add(-2 * time.Second) + + runID := RunID("run_live") + s.mu.Lock() + s.activeRuns[runID] = &activeRun{ + state: &RunState{ + RunID: runID, + TaskID: task.ID, + Mode: "standard", + Status: RunStatusRunning, + StartedAt: startedAt, + SuccessRate: 100, + AvgTTFT: 120 * time.Millisecond, + AvgTPS: 18.5, + }, + } + s.mu.Unlock() + + tasks := s.ListTasks() + if len(tasks) != 1 { + t.Fatalf("expected 1 task overview, got %d", len(tasks)) + } + if tasks[0].LatestRun == nil { + t.Fatal("expected LatestRun to be derived when run starts") + } + if tasks[0].LatestRun.Status != string(RunStatusRunning) { + t.Fatalf("LatestRun.Status: got %q, want %q", tasks[0].LatestRun.Status, RunStatusRunning) + } + if !tasks[0].LatestRun.FinishedAt.IsZero() { + t.Fatal("expected running LatestRun to have zero FinishedAt") + } +} + +func TestPersistRunResult_DerivesLatestTaskSummary(t *testing.T) { + s := newTestServer(t) + task, _ := s.CreateTask(makeTaskConfig("finalize-task")) + startedAt := time.Now().Add(-2 * time.Second) + + finishedAt := time.Now() + if err := s.persistRunResult(types.TaskRunSummary{ + RunID: "run_same", + TaskID: task.ID, + Mode: "standard", + Status: string(RunStatusCompleted), + StartedAt: startedAt, + FinishedAt: finishedAt, + SuccessRate: 100, + AvgTTFT: 80 * time.Millisecond, + AvgTPS: 24.2, + }); err != nil { + t.Fatalf("persistRunResult: %v", err) + } + + history, err := s.GetHistory(task.ID, 0) + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 1 { + t.Fatalf("expected 1 history entry after finalize, got %d", len(history)) + } + if history[0].Status != string(RunStatusCompleted) { + t.Fatalf("Status: got %q, want %q", history[0].Status, RunStatusCompleted) + } + if history[0].FinishedAt.IsZero() { + t.Fatal("expected FinishedAt to be persisted for completed run") + } + + tasks := s.ListTasks() + if len(tasks) != 1 { + t.Fatalf("expected 1 task overview, got %d", len(tasks)) + } + if tasks[0].LatestRun == nil || !tasks[0].LatestRun.FinishedAt.Equal(finishedAt) { + t.Fatalf("expected LatestRun.FinishedAt to equal %v, got %+v", finishedAt, tasks[0].LatestRun) + } +} + func TestGetHistory_LimitRespected(t *testing.T) { s := newTestServer(t) task, _ := s.CreateTask(makeTaskConfig("limit-task")) for i := 0; i < 5; i++ { - s.persistRunResult(task.ID, s.historyDir, types.TaskRunSummary{ + if err := s.persistRunResult(types.TaskRunSummary{ RunID: "run_" + string(rune('0'+i)), TaskID: task.ID, StartedAt: time.Now(), FinishedAt: time.Now(), - }) + }); err != nil { + t.Fatalf("persistRunResult: %v", err) + } } history, err := s.GetHistory(task.ID, 3) diff --git a/internal/server/task.go b/internal/server/task.go index 3069bba..29c3ebc 100644 --- a/internal/server/task.go +++ b/internal/server/task.go @@ -7,17 +7,24 @@ import ( ) // ListTasks 返回所有任务(最近更新排在前面)。 -func (s *serverImpl) ListTasks() []types.TaskDefinition { +func (s *serverImpl) ListTasks() []TaskOverview { s.mu.RLock() - defer s.mu.RUnlock() - return s.taskStore.All() + tasks := s.taskStore.All() + running := s.runningTaskSummariesLocked(tasks) + s.mu.RUnlock() + + return s.buildTaskOverviews(tasks, running) } // GetTask 按 ID 查找任务。 func (s *serverImpl) GetTask(id string) (types.TaskDefinition, bool) { s.mu.RLock() - defer s.mu.RUnlock() - return s.taskStore.Get(id) + task, ok := s.taskStore.Get(id) + s.mu.RUnlock() + if !ok { + return types.TaskDefinition{}, false + } + return task, true } // CreateTask 新建任务并持久化。 @@ -72,7 +79,10 @@ func (s *serverImpl) DeleteTask(id string) error { if err := s.taskStore.Delete(id); err != nil { return err } - return s.taskStore.Save() + if err := s.taskStore.Save(); err != nil { + return err + } + return s.runStore.DeleteTask(id) } // CopyTask 复制指定任务(ID 和时间戳重置,名称加 " (copy)" 后缀)。 @@ -101,3 +111,59 @@ func (s *serverImpl) CopyTask(id string) (types.TaskDefinition, error) { } return copied, nil } + +func (s *serverImpl) buildTaskOverviews(tasks []types.TaskDefinition, running map[string]types.TaskRunSummary) []TaskOverview { + decorated := make([]TaskOverview, 0, len(tasks)) + for _, task := range tasks { + decorated = append(decorated, s.buildTaskOverview(task, running)) + } + return decorated +} + + +func (s *serverImpl) buildTaskOverview(task types.TaskDefinition, running map[string]types.TaskRunSummary) TaskOverview { + overview := TaskOverview{TaskDefinition: task} + latest, err := s.runStore.LatestByTask(task.ID) + if err == nil && latest != nil { + summary := latest.Summary() + overview.LatestRun = &summary + } + if summary, ok := running[task.ID]; ok { + runningSummary := summary + overview.LatestRun = &runningSummary + } + return overview +} + +func (s *serverImpl) runningTaskSummariesLocked(tasks []types.TaskDefinition) map[string]types.TaskRunSummary { + if len(tasks) == 0 || len(s.activeRuns) == 0 { + return nil + } + + taskByID := make(map[string]types.TaskDefinition, len(tasks)) + for _, task := range tasks { + taskByID[task.ID] = task + } + + running := make(map[string]types.TaskRunSummary) + for _, ar := range s.activeRuns { + ar.mu.RLock() + if ar.state == nil || ar.state.Status != RunStatusRunning { + ar.mu.RUnlock() + continue + } + taskDef, ok := taskByID[ar.state.TaskID] + if !ok { + ar.mu.RUnlock() + continue + } + summary := buildRunningRunSummary(taskDef, ar.snapshotState()) + ar.mu.RUnlock() + running[taskDef.ID] = summary + } + + if len(running) == 0 { + return nil + } + return running +} diff --git a/internal/server/types.go b/internal/server/types.go index b3d85b5..f98f361 100644 --- a/internal/server/types.go +++ b/internal/server/types.go @@ -27,6 +27,13 @@ type TaskConfig struct { Input types.Input } +// TaskOverview 是面向列表/摘要读取的任务视图。 +// 它组合任务定义本体与最近一次运行摘要,不进入持久化层。 +type TaskOverview struct { + types.TaskDefinition + LatestRun *types.TaskRunSummary +} + // RunStatus 运行的生命周期状态。 type RunStatus string @@ -60,7 +67,7 @@ type RunState struct { CacheHitRate float64 // 详细请求列表(按 index 排序) - Requests []*RequestMetrics + Requests []*types.RequestMetrics // Turbo 专用 Levels []types.TurboLevelResult @@ -73,27 +80,6 @@ type RunState struct { ErrorMsg string } -// RequestMetrics 单次请求的详细指标,供请求列表页展示。 -type RequestMetrics struct { - Index int - Success bool - TotalTime time.Duration - TTFT time.Duration - TPS float64 - PromptTokens int - CompletionTokens int - CachedTokens int - CacheHitRate float64 - DNSTime time.Duration - ConnectTime time.Duration - TLSTime time.Duration - TargetIP string - ErrorMessage string - // 原始请求/响应数据(供请求详情页展示和复制) - RequestBody string - ResponseBody string -} - // EventKind 事件类型枚举。 type EventKind string diff --git a/internal/store/history.go b/internal/store/history.go deleted file mode 100644 index 8b30cd6..0000000 --- a/internal/store/history.go +++ /dev/null @@ -1,49 +0,0 @@ -package store - -import "github.com/yinxulai/ait/internal/types" - -// HistoryStore 管理单个任务的运行历史文件(~/.ait/history/.json)。 -// 每个任务对应独立的 HistoryStore 实例和独立的文件。 -type HistoryStore struct { - store *JSONStore[[]types.TaskRunSummary] -} - -// NewHistoryStore 创建持久化到 path 的 HistoryStore。 -func NewHistoryStore(path string) *HistoryStore { - return &HistoryStore{store: NewJSONStore[[]types.TaskRunSummary](path)} -} - -// Append 追加一条运行摘要到历史文件。 -func (s *HistoryStore) Append(run types.TaskRunSummary) error { - runs, err := s.store.Load() - if err != nil { - return err - } - if runs == nil { - runs = []types.TaskRunSummary{} - } - runs = append(runs, run) - return s.store.Save(runs) -} - -// Load 返回运行历史,最新的排在前面。limit <= 0 表示不限制条数。 -func (s *HistoryStore) Load(limit int) ([]types.TaskRunSummary, error) { - runs, err := s.store.Load() - if err != nil { - return nil, err - } - if runs == nil { - return []types.TaskRunSummary{}, nil - } - - // 反转(最新在前) - reversed := make([]types.TaskRunSummary, len(runs)) - for i, r := range runs { - reversed[len(runs)-1-i] = r - } - - if limit > 0 && len(reversed) > limit { - reversed = reversed[:limit] - } - return reversed, nil -} diff --git a/internal/store/run.go b/internal/store/run.go new file mode 100644 index 0000000..0acdf7d --- /dev/null +++ b/internal/store/run.go @@ -0,0 +1,311 @@ +package store + +import ( + "bufio" + "encoding/json" + "fmt" + "os" + "path/filepath" + "sort" + "time" + + "github.com/yinxulai/ait/internal/types" +) + +type RunMetadata struct { + RunID string `json:"run_id"` + TaskID string `json:"task_id"` + Mode string `json:"mode"` + Protocol string `json:"protocol"` + Model string `json:"model"` + Status string `json:"status"` + StartedAt time.Time `json:"started_at"` + FinishedAt *time.Time `json:"finished_at,omitempty"` +} + +type RunResult struct { + TotalReqs int `json:"total_reqs"` + DoneReqs int `json:"done_reqs"` + SuccessReqs int `json:"success_reqs"` + FailedReqs int `json:"failed_reqs"` + SuccessRate float64 `json:"success_rate"` + AvgTTFT time.Duration `json:"avg_ttft"` + AvgTPS float64 `json:"avg_tps"` + CacheHitRate float64 `json:"cache_hit_rate"` + MaxStableConcurrency int `json:"max_stable_concurrency,omitempty"` + ErrorSummary string `json:"error_summary,omitempty"` + StandardResult *types.ReportData `json:"standard_result,omitempty"` + TurboResult *types.TurboResult `json:"turbo_result,omitempty"` +} + +type StoredRun struct { + Metadata RunMetadata + Result *RunResult +} + +type RunStore struct { + root string +} + +func NewRunStore(root string) *RunStore { + return &RunStore{root: root} +} + +func (s *RunStore) TaskDir(taskID string) string { + return filepath.Join(s.root, taskID) +} + +func (s *RunStore) RunDir(taskID, runID string) string { + return filepath.Join(s.TaskDir(taskID), runID) +} + +func (s *RunStore) MetadataPath(taskID, runID string) string { + return filepath.Join(s.RunDir(taskID, runID), "run.json") +} + +func (s *RunStore) ResultPath(taskID, runID string) string { + return filepath.Join(s.RunDir(taskID, runID), "result.json") +} + +func (s *RunStore) RequestsPath(taskID, runID string) string { + return filepath.Join(s.RunDir(taskID, runID), "requests.jsonl") +} + +func (s *RunStore) AppendRequest(taskID, runID string, request types.RequestMetrics) error { + if taskID == "" || runID == "" { + return fmt.Errorf("task id and run id are required") + } + if err := os.MkdirAll(s.RunDir(taskID, runID), 0o755); err != nil { + return err + } + + data, err := json.Marshal(request) + if err != nil { + return err + } + + f, err := os.OpenFile(s.RequestsPath(taskID, runID), os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644) + if err != nil { + return err + } + defer f.Close() + + if _, err := f.Write(data); err != nil { + return err + } + _, err = f.Write([]byte{'\n'}) + return err +} + +func (s *RunStore) LoadRequests(taskID, runID string) ([]types.RequestMetrics, error) { + f, err := os.Open(s.RequestsPath(taskID, runID)) + if os.IsNotExist(err) { + return []types.RequestMetrics{}, nil + } + if err != nil { + return nil, err + } + defer f.Close() + + const maxLineSize = 16 * 1024 * 1024 + buf := make([]byte, maxLineSize) + scanner := bufio.NewScanner(f) + scanner.Buffer(buf, maxLineSize) + + requests := make([]types.RequestMetrics, 0) + for scanner.Scan() { + line := scanner.Bytes() + if len(line) == 0 { + continue + } + var request types.RequestMetrics + if err := json.Unmarshal(line, &request); err != nil { + continue + } + requests = append(requests, request) + } + if err := scanner.Err(); err != nil { + return nil, err + } + + sort.Slice(requests, func(i, j int) bool { + return requests[i].Index < requests[j].Index + }) + return requests, nil +} + +func (s *RunStore) SaveFinal(meta RunMetadata, result RunResult) error { + if meta.TaskID == "" || meta.RunID == "" { + return fmt.Errorf("task id and run id are required") + } + if err := os.MkdirAll(s.RunDir(meta.TaskID, meta.RunID), 0o755); err != nil { + return err + } + if err := NewJSONStore[RunMetadata](s.MetadataPath(meta.TaskID, meta.RunID)).Save(meta); err != nil { + return err + } + return NewJSONStore[RunResult](s.ResultPath(meta.TaskID, meta.RunID)).Save(result) +} + +func (s *RunStore) SaveSummary(summary types.TaskRunSummary) error { + if summary.TaskID == "" || summary.RunID == "" { + return fmt.Errorf("task id and run id are required") + } + + var finishedAt *time.Time + if !summary.FinishedAt.IsZero() { + finished := summary.FinishedAt + finishedAt = &finished + } + + return s.SaveFinal(RunMetadata{ + RunID: summary.RunID, + TaskID: summary.TaskID, + Mode: summary.Mode, + Protocol: summary.Protocol, + Model: summary.Model, + Status: summary.Status, + StartedAt: summary.StartedAt, + FinishedAt: finishedAt, + }, RunResult{ + SuccessRate: summary.SuccessRate, + AvgTTFT: summary.AvgTTFT, + AvgTPS: summary.AvgTPS, + CacheHitRate: summary.CacheHitRate, + MaxStableConcurrency: summary.MaxStableConcurrency, + ErrorSummary: summary.ErrorSummary, + }) +} + +func (s *RunStore) Load(taskID, runID string) (*StoredRun, error) { + metaPath := s.MetadataPath(taskID, runID) + if _, err := os.Stat(metaPath); os.IsNotExist(err) { + return nil, nil + } else if err != nil { + return nil, err + } + + meta, err := NewJSONStore[RunMetadata](metaPath).Load() + if err != nil { + return nil, err + } + + resultPath := s.ResultPath(taskID, runID) + var result *RunResult + if _, err := os.Stat(resultPath); err == nil { + loaded, err := NewJSONStore[RunResult](resultPath).Load() + if err != nil { + return nil, err + } + result = &loaded + } else if !os.IsNotExist(err) { + return nil, err + } + + return &StoredRun{Metadata: meta, Result: result}, nil +} + +func (s *RunStore) LoadByRunID(runID string) (*StoredRun, error) { + taskEntries, err := os.ReadDir(s.root) + if os.IsNotExist(err) { + return nil, nil + } + if err != nil { + return nil, err + } + + for _, taskEntry := range taskEntries { + if !taskEntry.IsDir() { + continue + } + candidate, err := s.Load(taskEntry.Name(), runID) + if err != nil { + return nil, err + } + if candidate != nil { + return candidate, nil + } + } + + return nil, nil +} + +func (s *RunStore) ListByTask(taskID string, limit int) ([]StoredRun, error) { + runEntries, err := os.ReadDir(s.TaskDir(taskID)) + if os.IsNotExist(err) { + return []StoredRun{}, nil + } + if err != nil { + return nil, err + } + + runs := make([]StoredRun, 0, len(runEntries)) + for _, runEntry := range runEntries { + if !runEntry.IsDir() { + continue + } + run, err := s.Load(taskID, runEntry.Name()) + if err != nil { + return nil, err + } + if run == nil { + continue + } + runs = append(runs, *run) + } + + sort.Slice(runs, func(i, j int) bool { + return runSortTime(runs[i]).After(runSortTime(runs[j])) + }) + + if limit > 0 && len(runs) > limit { + runs = runs[:limit] + } + return runs, nil +} + +func (s *RunStore) LatestByTask(taskID string) (*StoredRun, error) { + runs, err := s.ListByTask(taskID, 1) + if err != nil { + return nil, err + } + if len(runs) == 0 { + return nil, nil + } + return &runs[0], nil +} + +func (s *RunStore) DeleteTask(taskID string) error { + return os.RemoveAll(s.TaskDir(taskID)) +} + +func (r StoredRun) Summary() types.TaskRunSummary { + summary := types.TaskRunSummary{ + RunID: r.Metadata.RunID, + TaskID: r.Metadata.TaskID, + Mode: r.Metadata.Mode, + Status: r.Metadata.Status, + Protocol: r.Metadata.Protocol, + Model: r.Metadata.Model, + StartedAt: r.Metadata.StartedAt, + } + if r.Metadata.FinishedAt != nil { + summary.FinishedAt = *r.Metadata.FinishedAt + } + if r.Result != nil { + summary.SuccessRate = r.Result.SuccessRate + summary.AvgTTFT = r.Result.AvgTTFT + summary.AvgTPS = r.Result.AvgTPS + summary.CacheHitRate = r.Result.CacheHitRate + summary.MaxStableConcurrency = r.Result.MaxStableConcurrency + summary.ErrorSummary = r.Result.ErrorSummary + } + return summary +} + +func runSortTime(run StoredRun) time.Time { + if run.Metadata.FinishedAt != nil && !run.Metadata.FinishedAt.IsZero() { + return *run.Metadata.FinishedAt + } + return run.Metadata.StartedAt +} diff --git a/internal/store/store.go b/internal/store/store.go index 6bc44fc..38def8d 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -6,26 +6,157 @@ import ( "os" "path/filepath" "sync" + "time" ) // JSONStore 是泛型 JSON 文件持久化基类。 // 内置进程级互斥锁,防止同一进程内并发读写;文件操作通过原子写入保证安全。 type JSONStore[T any] struct { - path string - mu sync.Mutex + path string + mu sync.Mutex + debounce time.Duration } +type debouncedPathState struct { + path string + mu sync.Mutex + cond *sync.Cond + timer *time.Timer + generation uint64 + flushRunning bool + pending []byte + lastErr error +} + +var debouncedPathStates = struct { + mu sync.Mutex + states map[string]*debouncedPathState +}{states: make(map[string]*debouncedPathState)} + // NewJSONStore 创建指向指定路径的 JSONStore。 func NewJSONStore[T any](path string) *JSONStore[T] { return &JSONStore[T]{path: path} } +// NewDebouncedJSONStore 创建带写入防抖的 JSONStore。 +// 同一路径的多个实例会共享同一个防抖器,用于合并高频 Save 调用。 +func NewDebouncedJSONStore[T any](path string, debounce time.Duration) *JSONStore[T] { + return &JSONStore[T]{path: path, debounce: debounce} +} + +func getDebouncedPathState(path string) *debouncedPathState { + debouncedPathStates.mu.Lock() + defer debouncedPathStates.mu.Unlock() + + if state, ok := debouncedPathStates.states[path]; ok { + return state + } + + state := &debouncedPathState{path: path} + state.cond = sync.NewCond(&state.mu) + debouncedPathStates.states[path] = state + return state +} + +func (s *debouncedPathState) schedule(data []byte, delay time.Duration) error { + s.mu.Lock() + defer s.mu.Unlock() + + s.pending = append([]byte(nil), data...) + s.lastErr = nil + s.generation++ + gen := s.generation + + if s.timer != nil { + s.timer.Stop() + } + s.timer = time.AfterFunc(delay, func() { + s.flushGeneration(gen) + }) + return nil +} + +func (s *debouncedPathState) saveNow(data []byte) error { + s.mu.Lock() + if s.timer != nil { + s.timer.Stop() + s.timer = nil + } + s.pending = append([]byte(nil), data...) + s.generation++ + s.lastErr = nil + s.mu.Unlock() + return s.flush() +} + +func (s *debouncedPathState) flushGeneration(gen uint64) { + s.mu.Lock() + if gen != s.generation { + s.mu.Unlock() + return + } + s.mu.Unlock() + _ = s.flush() +} + +func (s *debouncedPathState) flush() error { + s.mu.Lock() + for s.flushRunning { + s.cond.Wait() + } + if s.timer != nil { + s.timer.Stop() + s.timer = nil + } + if s.pending == nil { + err := s.lastErr + s.mu.Unlock() + return err + } + data := append([]byte(nil), s.pending...) + s.pending = nil + s.flushRunning = true + s.mu.Unlock() + + err := writeJSONFileAtomic(s.path, data) + + s.mu.Lock() + s.flushRunning = false + s.lastErr = err + hasPending := s.pending != nil + s.cond.Broadcast() + s.mu.Unlock() + + if err != nil { + return err + } + if hasPending { + return s.flush() + } + return nil +} + +func writeJSONFileAtomic(path string, data []byte) error { + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + return err + } + + tmp := path + ".tmp" + if err := os.WriteFile(tmp, data, 0o644); err != nil { + return err + } + return os.Rename(tmp, path) +} + // Load 从文件读取并反序列化为 T。文件不存在时返回零值(无错误)。 func (s *JSONStore[T]) Load() (T, error) { s.mu.Lock() defer s.mu.Unlock() var zero T + if err := getDebouncedPathState(s.path).flush(); err != nil { + return zero, err + } data, err := os.ReadFile(s.path) if errors.Is(err, os.ErrNotExist) { return zero, nil @@ -47,19 +178,21 @@ func (s *JSONStore[T]) Save(v T) error { s.mu.Lock() defer s.mu.Unlock() - if err := os.MkdirAll(filepath.Dir(s.path), 0o755); err != nil { - return err - } - data, err := json.MarshalIndent(v, "", " ") if err != nil { return err } - // 原子写入:先写临时文件,再重命名 - tmp := s.path + ".tmp" - if err := os.WriteFile(tmp, data, 0o644); err != nil { - return err + state := getDebouncedPathState(s.path) + if s.debounce > 0 { + return state.schedule(data, s.debounce) } - return os.Rename(tmp, s.path) + return state.saveNow(data) +} + +// Flush 立即落盘当前路径上的 pending 数据。 +func (s *JSONStore[T]) Flush() error { + s.mu.Lock() + defer s.mu.Unlock() + return getDebouncedPathState(s.path).flush() } diff --git a/internal/store/store_test.go b/internal/store/store_test.go new file mode 100644 index 0000000..d94b5ab --- /dev/null +++ b/internal/store/store_test.go @@ -0,0 +1,77 @@ +package store + +import ( + "path/filepath" + "testing" + "time" +) + +type testPayload struct { + Value string `json:"value"` +} + +func TestDebouncedJSONStore_LoadFlushesPendingWrite(t *testing.T) { + path := filepath.Join(t.TempDir(), "state.json") + debounced := NewDebouncedJSONStore[testPayload](path, 200*time.Millisecond) + + if err := debounced.Save(testPayload{Value: "pending"}); err != nil { + t.Fatalf("Save() returned unexpected error: %v", err) + } + + plain := NewJSONStore[testPayload](path) + loaded, err := plain.Load() + if err != nil { + t.Fatalf("Load() returned unexpected error: %v", err) + } + if loaded.Value != "pending" { + t.Fatalf("Value: got %q, want %q", loaded.Value, "pending") + } + if err := debounced.Flush(); err != nil { + t.Fatalf("Flush() returned unexpected error: %v", err) + } +} + +func TestDebouncedJSONStore_CoalescesWrites(t *testing.T) { + path := filepath.Join(t.TempDir(), "state.json") + store := NewDebouncedJSONStore[testPayload](path, 40*time.Millisecond) + + if err := store.Save(testPayload{Value: "first"}); err != nil { + t.Fatalf("Save(first) returned unexpected error: %v", err) + } + if err := store.Save(testPayload{Value: "second"}); err != nil { + t.Fatalf("Save(second) returned unexpected error: %v", err) + } + + time.Sleep(100 * time.Millisecond) + + loaded, err := NewJSONStore[testPayload](path).Load() + if err != nil { + t.Fatalf("Load() returned unexpected error: %v", err) + } + if loaded.Value != "second" { + t.Fatalf("Value: got %q, want %q", loaded.Value, "second") + } +} + +func TestJSONStore_ImmediateSaveOverridesPendingDebouncedWrite(t *testing.T) { + path := filepath.Join(t.TempDir(), "state.json") + debounced := NewDebouncedJSONStore[testPayload](path, 60*time.Millisecond) + plain := NewJSONStore[testPayload](path) + + if err := debounced.Save(testPayload{Value: "stale"}); err != nil { + t.Fatalf("debounced Save() returned unexpected error: %v", err) + } + if err := plain.Save(testPayload{Value: "final"}); err != nil { + t.Fatalf("plain Save() returned unexpected error: %v", err) + } + + time.Sleep(120 * time.Millisecond) + + loaded, err := plain.Load() + if err != nil { + t.Fatalf("Load() returned unexpected error: %v", err) + } + if loaded.Value != "final" { + t.Fatalf("Value: got %q, want %q", loaded.Value, "final") + } +} diff --git a/internal/store/task.go b/internal/store/task.go index 0e4e5ec..8008365 100644 --- a/internal/store/task.go +++ b/internal/store/task.go @@ -2,54 +2,131 @@ package store import ( "fmt" + "os" + "path/filepath" + "sort" + "strings" "time" "github.com/yinxulai/ait/internal/types" ) -type taskStoreData struct { - Tasks []types.TaskDefinition `json:"tasks"` +type persistedTaskDefinition struct { + ID string `json:"id"` + Name string `json:"name"` + Input types.Input `json:"input"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` } -// TaskStore 管理 ~/.ait/tasks.json 的任务列表持久化。 +// TaskStore 管理 ~/.ait/tasks/ 下的任务文件持久化。 type TaskStore struct { - store *JSONStore[taskStoreData] - data taskStoreData + dir string + data []types.TaskDefinition } -// NewTaskStore 创建持久化到 path 的 TaskStore(调用方需先调用 Load)。 -func NewTaskStore(path string) *TaskStore { - return &TaskStore{store: NewJSONStore[taskStoreData](path)} +// NewTaskStore 创建持久化到 dir 的 TaskStore(调用方需先调用 Load)。 +func NewTaskStore(dir string) *TaskStore { + return &TaskStore{dir: dir} } -// Load 从磁盘加载任务列表,文件不存在时初始化为空列表。 +// Load 从磁盘加载任务列表,目录不存在时初始化为空列表。 func (s *TaskStore) Load() error { - data, err := s.store.Load() + entries, err := os.ReadDir(s.dir) + if os.IsNotExist(err) { + s.data = []types.TaskDefinition{} + return nil + } if err != nil { return err } - if data.Tasks == nil { - data.Tasks = []types.TaskDefinition{} + + tasks := make([]types.TaskDefinition, 0, len(entries)) + for _, entry := range entries { + if entry.IsDir() || filepath.Ext(entry.Name()) != ".json" { + continue + } + + path := filepath.Join(s.dir, entry.Name()) + stored, err := NewJSONStore[persistedTaskDefinition](path).Load() + if err != nil { + return err + } + if strings.TrimSpace(stored.ID) == "" { + stored.ID = strings.TrimSuffix(entry.Name(), filepath.Ext(entry.Name())) + } + tasks = append(tasks, types.TaskDefinition{ + ID: stored.ID, + Name: stored.Name, + Input: stored.Input, + CreatedAt: stored.CreatedAt, + UpdatedAt: stored.UpdatedAt, + }) } - s.data = data + + sort.Slice(tasks, func(i, j int) bool { + return tasks[i].UpdatedAt.After(tasks[j].UpdatedAt) + }) + + s.data = tasks return nil } -// Save 将当前内存中的任务列表持久化到磁盘。 +// Save 将当前内存中的任务列表持久化到磁盘,按任务拆分成独立文件。 func (s *TaskStore) Save() error { - return s.store.Save(s.data) + if err := os.MkdirAll(s.dir, 0o755); err != nil { + return err + } + + keep := make(map[string]struct{}, len(s.data)) + for _, task := range s.data { + if strings.TrimSpace(task.ID) == "" { + return fmt.Errorf("task id cannot be empty") + } + keep[task.ID] = struct{}{} + path := filepath.Join(s.dir, task.ID+".json") + stored := persistedTaskDefinition{ + ID: task.ID, + Name: task.Name, + Input: task.Input, + CreatedAt: task.CreatedAt, + UpdatedAt: task.UpdatedAt, + } + if err := NewJSONStore[persistedTaskDefinition](path).Save(stored); err != nil { + return err + } + } + + entries, err := os.ReadDir(s.dir) + if err != nil { + return err + } + for _, entry := range entries { + if entry.IsDir() || filepath.Ext(entry.Name()) != ".json" { + continue + } + taskID := strings.TrimSuffix(entry.Name(), filepath.Ext(entry.Name())) + if _, ok := keep[taskID]; ok { + continue + } + if err := os.Remove(filepath.Join(s.dir, entry.Name())); err != nil && !os.IsNotExist(err) { + return err + } + } + + return nil } // All 返回所有任务的副本,最近更新的排在前面。 func (s *TaskStore) All() []types.TaskDefinition { - result := make([]types.TaskDefinition, len(s.data.Tasks)) - copy(result, s.data.Tasks) + result := make([]types.TaskDefinition, len(s.data)) + copy(result, s.data) return result } // Get 按 ID 查找任务,返回副本。 func (s *TaskStore) Get(id string) (types.TaskDefinition, bool) { - for _, t := range s.data.Tasks { + for _, t := range s.data { if t.ID == id { return t, true } @@ -66,7 +143,7 @@ func (s *TaskStore) Upsert(task types.TaskDefinition) { task.ID = fmt.Sprintf("task_%d", now.UnixNano()) } - for i, existing := range s.data.Tasks { + for i, existing := range s.data { if existing.ID != task.ID { continue } @@ -75,11 +152,11 @@ func (s *TaskStore) Upsert(task types.TaskDefinition) { } task.UpdatedAt = now // 移至列表头部 - tasks := make([]types.TaskDefinition, 0, len(s.data.Tasks)) + tasks := make([]types.TaskDefinition, 0, len(s.data)) tasks = append(tasks, task) - tasks = append(tasks, s.data.Tasks[:i]...) - tasks = append(tasks, s.data.Tasks[i+1:]...) - s.data.Tasks = tasks + tasks = append(tasks, s.data[:i]...) + tasks = append(tasks, s.data[i+1:]...) + s.data = tasks return } @@ -88,14 +165,14 @@ func (s *TaskStore) Upsert(task types.TaskDefinition) { task.CreatedAt = now } task.UpdatedAt = now - s.data.Tasks = append([]types.TaskDefinition{task}, s.data.Tasks...) + s.data = append([]types.TaskDefinition{task}, s.data...) } // Delete 按 ID 删除任务,任务不存在时返回错误。 func (s *TaskStore) Delete(id string) error { - for i, t := range s.data.Tasks { + for i, t := range s.data { if t.ID == id { - s.data.Tasks = append(s.data.Tasks[:i], s.data.Tasks[i+1:]...) + s.data = append(s.data[:i], s.data[i+1:]...) return nil } } diff --git a/internal/task/history.go b/internal/task/history.go deleted file mode 100644 index ae0063b..0000000 --- a/internal/task/history.go +++ /dev/null @@ -1,89 +0,0 @@ -package task - -import ( - "encoding/json" - "errors" - "os" - "path/filepath" - - "github.com/yinxulai/ait/internal/config" - "github.com/yinxulai/ait/internal/types" -) - -func AppendRun(taskID string, run types.TaskRunSummary) error { - runs, err := loadHistoryFile(taskID) - if err != nil && !errors.Is(err, os.ErrNotExist) { - return err - } - if run.TaskID == "" { - run.TaskID = taskID - } - runs = append(runs, run) - return saveHistoryFile(taskID, runs) -} - -func LoadHistory(taskID string, limit int) ([]types.TaskRunSummary, error) { - runs, err := loadHistoryFile(taskID) - if errors.Is(err, os.ErrNotExist) { - return []types.TaskRunSummary{}, nil - } - if err != nil { - return nil, err - } - - reversed := make([]types.TaskRunSummary, 0, len(runs)) - for i := len(runs) - 1; i >= 0; i-- { - reversed = append(reversed, runs[i]) - } - if limit > 0 && len(reversed) > limit { - reversed = reversed[:limit] - } - return reversed, nil -} - -func loadHistoryFile(taskID string) ([]types.TaskRunSummary, error) { - path, err := historyPath(taskID) - if err != nil { - return nil, err - } - - data, err := os.ReadFile(path) - if errors.Is(err, os.ErrNotExist) { - return []types.TaskRunSummary{}, os.ErrNotExist - } - if err != nil { - return nil, err - } - - var runs []types.TaskRunSummary - if err := json.Unmarshal(data, &runs); err != nil { - return nil, err - } - return runs, nil -} - -func saveHistoryFile(taskID string, runs []types.TaskRunSummary) error { - dir, err := config.HistoryDir() - if err != nil { - return err - } - if err := os.MkdirAll(dir, 0o755); err != nil { - return err - } - - data, err := json.MarshalIndent(runs, "", " ") - if err != nil { - return err - } - - path := filepath.Join(dir, taskID+".json") - return os.WriteFile(path, data, 0o644) -} - -func historyPath(taskID string) (string, error) { - dir, err := config.HistoryDir() - if err != nil { - return "", err - } - return filepath.Join(dir, taskID+".json"), nil -} diff --git a/internal/task/history_test.go b/internal/task/history_test.go deleted file mode 100644 index 6426fac..0000000 --- a/internal/task/history_test.go +++ /dev/null @@ -1,41 +0,0 @@ -package task - -import ( - "testing" - "time" - - "github.com/yinxulai/ait/internal/types" -) - -func TestAppendRunAndLoadHistoryNewestFirst(t *testing.T) { - t.Setenv("HOME", t.TempDir()) - - first := types.TaskRunSummary{RunID: "run-1", StartedAt: time.Unix(100, 0), FinishedAt: time.Unix(110, 0)} - second := types.TaskRunSummary{RunID: "run-2", StartedAt: time.Unix(200, 0), FinishedAt: time.Unix(210, 0)} - - if err := AppendRun("task-1", first); err != nil { - t.Fatalf("AppendRun(first) returned unexpected error: %v", err) - } - if err := AppendRun("task-1", second); err != nil { - t.Fatalf("AppendRun(second) returned unexpected error: %v", err) - } - - history, err := LoadHistory("task-1", 0) - if err != nil { - t.Fatalf("LoadHistory() returned unexpected error: %v", err) - } - if len(history) != 2 { - t.Fatalf("expected 2 history items, got %d", len(history)) - } - if history[0].RunID != "run-2" || history[1].RunID != "run-1" { - t.Fatalf("expected newest-first order, got %+v", history) - } - - limited, err := LoadHistory("task-1", 1) - if err != nil { - t.Fatalf("LoadHistory(limit) returned unexpected error: %v", err) - } - if len(limited) != 1 || limited[0].RunID != "run-2" { - t.Fatalf("unexpected limited history: %+v", limited) - } -} diff --git a/internal/task/store.go b/internal/task/store.go deleted file mode 100644 index 4c37620..0000000 --- a/internal/task/store.go +++ /dev/null @@ -1,102 +0,0 @@ -package task - -import ( - "encoding/json" - "errors" - "fmt" - "os" - "time" - - "github.com/yinxulai/ait/internal/config" - "github.com/yinxulai/ait/internal/types" -) - -type TaskStore struct { - Tasks []types.TaskDefinition `json:"tasks"` -} - -func LoadTasks() (*TaskStore, error) { - path, err := config.TasksPath() - if err != nil { - return nil, err - } - - data, err := os.ReadFile(path) - if errors.Is(err, os.ErrNotExist) { - return &TaskStore{Tasks: []types.TaskDefinition{}}, nil - } - if err != nil { - return nil, err - } - - var store TaskStore - if err := json.Unmarshal(data, &store); err != nil { - return nil, err - } - if store.Tasks == nil { - store.Tasks = []types.TaskDefinition{} - } - return &store, nil -} - -func (s *TaskStore) Save() error { - if _, err := config.EnsureAppDir(); err != nil { - return err - } - path, err := config.TasksPath() - if err != nil { - return err - } - data, err := json.MarshalIndent(s, "", " ") - if err != nil { - return err - } - return os.WriteFile(path, data, 0o644) -} - -func (s *TaskStore) Upsert(task types.TaskDefinition) { - now := time.Now() - if task.ID == "" { - task.ID = fmt.Sprintf("task_%d", now.UnixNano()) - } - - for i, existing := range s.Tasks { - if existing.ID != task.ID { - continue - } - if task.CreatedAt.IsZero() { - task.CreatedAt = existing.CreatedAt - } - task.UpdatedAt = now - updated := append([]types.TaskDefinition{task}, append(s.Tasks[:i], s.Tasks[i+1:]...)...) - s.Tasks = updated - return - } - - if task.CreatedAt.IsZero() { - task.CreatedAt = now - } - task.UpdatedAt = now - - s.Tasks = append([]types.TaskDefinition{task}, s.Tasks...) -} - -func (s *TaskStore) Delete(taskID string) error { - for i, task := range s.Tasks { - if task.ID != taskID { - continue - } - s.Tasks = append(s.Tasks[:i], s.Tasks[i+1:]...) - return nil - } - return os.ErrNotExist -} - -func (s *TaskStore) Get(taskID string) (types.TaskDefinition, bool) { - for _, task := range s.Tasks { - if task.ID == taskID { - return task, true - } - } - return types.TaskDefinition{}, false -} diff --git a/internal/task/store_test.go b/internal/task/store_test.go deleted file mode 100644 index 9959da0..0000000 --- a/internal/task/store_test.go +++ /dev/null @@ -1,76 +0,0 @@ -package task - -import ( - "errors" - "os" - "testing" - "time" - - "github.com/yinxulai/ait/internal/types" -) - -func TestLoadTasksReturnsEmptyStoreWhenMissing(t *testing.T) { - t.Setenv("HOME", t.TempDir()) - - store, err := LoadTasks() - if err != nil { - t.Fatalf("LoadTasks() returned unexpected error: %v", err) - } - if len(store.Tasks) != 0 { - t.Fatalf("expected no tasks, got %d", len(store.Tasks)) - } -} - -func TestTaskStoreUpsertSaveAndReload(t *testing.T) { - t.Setenv("HOME", t.TempDir()) - - store := &TaskStore{} - task := types.TaskDefinition{ - ID: "task-1", - Name: "nightly-openai", - Input: types.Input{ - Protocol: types.ProtocolOpenAIResponses, - EndpointURL: "https://api.openai.com/v1/responses", - Model: "gpt-4.1", - }, - } - store.Upsert(task) - if err := store.Save(); err != nil { - t.Fatalf("Save() returned unexpected error: %v", err) - } - - loaded, err := LoadTasks() - if err != nil { - t.Fatalf("LoadTasks() returned unexpected error: %v", err) - } - if len(loaded.Tasks) != 1 || loaded.Tasks[0].ID != "task-1" { - t.Fatalf("unexpected loaded tasks: %+v", loaded.Tasks) - } - - firstUpdatedAt := loaded.Tasks[0].UpdatedAt - time.Sleep(10 * time.Millisecond) - task.Name = "nightly-openai-updated" - loaded.Upsert(task) - if len(loaded.Tasks) != 1 { - t.Fatalf("expected one task after update, got %d", len(loaded.Tasks)) - } - if loaded.Tasks[0].Name != "nightly-openai-updated" { - t.Fatalf("expected updated task name, got %s", loaded.Tasks[0].Name) - } - if !loaded.Tasks[0].UpdatedAt.After(firstUpdatedAt) { - t.Fatalf("expected UpdatedAt to advance after Upsert") - } -} - -func TestTaskStoreDelete(t *testing.T) { - store := &TaskStore{Tasks: []types.TaskDefinition{{ID: "task-1"}, {ID: "task-2"}}} - if err := store.Delete("task-1"); err != nil { - t.Fatalf("Delete() returned unexpected error: %v", err) - } - if len(store.Tasks) != 1 || store.Tasks[0].ID != "task-2" { - t.Fatalf("unexpected tasks after delete: %+v", store.Tasks) - } - if err := store.Delete("missing"); !errors.Is(err, os.ErrNotExist) { - t.Fatalf("expected os.ErrNotExist, got %v", err) - } -} diff --git a/internal/tui/client.go b/internal/tui/client.go index aedbd5c..b605c58 100644 --- a/internal/tui/client.go +++ b/internal/tui/client.go @@ -2,6 +2,7 @@ package tui import ( "fmt" + "time" tea "github.com/charmbracelet/bubbletea" "github.com/yinxulai/ait/internal/server" @@ -154,10 +155,19 @@ func (c *Client) GetRunStateForHistoryCmd(runID server.RunID, summary *types.Tas // summaryToRunState 用 TaskRunSummary 摘要数据构造最小化 RunState,供无磁盘快照时回退展示。 func summaryToRunState(s *types.TaskRunSummary) *server.RunState { - finished := s.FinishedAt status := server.RunStatusCompleted - if s.Status == string(server.RunStatusFailed) { + switch s.Status { + case string(server.RunStatusRunning): + status = server.RunStatusRunning + case string(server.RunStatusFailed): status = server.RunStatusFailed + case string(server.RunStatusStopped): + status = server.RunStatusStopped + } + var finished *time.Time + if !s.FinishedAt.IsZero() { + finishedAt := s.FinishedAt + finished = &finishedAt } return &server.RunState{ RunID: server.RunID(s.RunID), @@ -165,7 +175,7 @@ func summaryToRunState(s *types.TaskRunSummary) *server.RunState { Status: status, Mode: s.Mode, StartedAt: s.StartedAt, - FinishedAt: &finished, + FinishedAt: finished, AvgTPS: s.AvgTPS, AvgTTFT: s.AvgTTFT, SuccessRate: s.SuccessRate, diff --git a/internal/tui/client_test.go b/internal/tui/client_test.go new file mode 100644 index 0000000..1cc43ad --- /dev/null +++ b/internal/tui/client_test.go @@ -0,0 +1,30 @@ +package tui + +import ( + "testing" + "time" + + "github.com/yinxulai/ait/internal/server" + "github.com/yinxulai/ait/internal/types" +) + +func TestSummaryToRunState_RunningSummaryKeepsNilFinishedAt(t *testing.T) { + state := summaryToRunState(&types.TaskRunSummary{ + RunID: "run-1", + TaskID: "task-1", + Status: string(server.RunStatusRunning), + StartedAt: time.Unix(100, 0), + AvgTPS: 12.5, + SuccessRate: 50, + }) + + if state.Status != server.RunStatusRunning { + t.Fatalf("Status: got %q, want %q", state.Status, server.RunStatusRunning) + } + if state.FinishedAt != nil { + t.Fatal("expected FinishedAt to stay nil for running summary fallback") + } + if state.AvgTPS != 12.5 { + t.Fatalf("AvgTPS: got %v, want %v", state.AvgTPS, 12.5) + } +} diff --git a/internal/tui/messages.go b/internal/tui/messages.go index f64791f..3099253 100644 --- a/internal/tui/messages.go +++ b/internal/tui/messages.go @@ -7,7 +7,7 @@ import ( // TasksLoadedMsg 任务列表加载完成(初始化或刷新后)。 type TasksLoadedMsg struct { - Tasks []types.TaskDefinition + Tasks []server.TaskOverview } // TaskSavedMsg 新建或更新任务完成。 diff --git a/internal/tui/model.go b/internal/tui/model.go index ee930ed..37b881d 100644 --- a/internal/tui/model.go +++ b/internal/tui/model.go @@ -433,7 +433,7 @@ func (m *Model) findTask(taskID string) *types.TaskDefinition { } for i := range m.taskList.Tasks { if m.taskList.Tasks[i].ID == taskID { - return &m.taskList.Tasks[i] + return &m.taskList.Tasks[i].TaskDefinition } } return nil @@ -522,7 +522,7 @@ func (m *Model) currentRunTaskID(isDash bool) string { return "" } -func (m *Model) collectRequests() []*server.RequestMetrics { +func (m *Model) collectRequests() []*types.RequestMetrics { // 优先使用当前活跃视图的数据,避免两个面板均有 RunState 时取错 switch m.view { case viewTurboDash: diff --git a/internal/tui/model_test.go b/internal/tui/model_test.go index 03ee1b5..e45c5a3 100644 --- a/internal/tui/model_test.go +++ b/internal/tui/model_test.go @@ -11,7 +11,7 @@ import ( // stubServer 是 server.Server 的测试桩,所有方法都返回零值。 type stubServer struct{} -func (s *stubServer) ListTasks() []types.TaskDefinition { return nil } +func (s *stubServer) ListTasks() []server.TaskOverview { return nil } func (s *stubServer) GetTask(id string) (types.TaskDefinition, bool) { return types.TaskDefinition{}, false } func (s *stubServer) CreateTask(cfg server.TaskConfig) (types.TaskDefinition, error) { return types.TaskDefinition{}, nil } func (s *stubServer) UpdateTask(id string, cfg server.TaskConfig) (types.TaskDefinition, error) { diff --git a/internal/tui/pages/dashboard.go b/internal/tui/pages/dashboard.go index ac07c78..32a2b33 100644 --- a/internal/tui/pages/dashboard.go +++ b/internal/tui/pages/dashboard.go @@ -8,6 +8,7 @@ import ( tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" "github.com/yinxulai/ait/internal/server" + "github.com/yinxulai/ait/internal/types" ) // DashboardState 标准模式运行仪表盘页状态。 @@ -69,7 +70,7 @@ func HandleDashboardKey(d *DashboardState, msg tea.KeyMsg, client Client) (*Dash return d, nil, NavAction{To: NavTaskList} } - var reqs []*server.RequestMetrics + var reqs []*types.RequestMetrics if d.RunState != nil { reqs = d.RunState.Requests } diff --git a/internal/tui/pages/reqdetail.go b/internal/tui/pages/reqdetail.go index 55ae705..1e2d23f 100644 --- a/internal/tui/pages/reqdetail.go +++ b/internal/tui/pages/reqdetail.go @@ -7,19 +7,20 @@ import ( tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" "github.com/yinxulai/ait/internal/server" + "github.com/yinxulai/ait/internal/types" ) // ReqDetailState 请求详情页状态。 type ReqDetailState struct { RunID server.RunID - Requests []*server.RequestMetrics + Requests []*types.RequestMetrics Index int // 当前查看的请求索引 ScrollY int // 输出区域滚动偏移 BackNav NavAction // 按 b/esc 时的返回目标 } // NewReqDetailState 创建请求详情状态。 -func NewReqDetailState(runID server.RunID, reqs []*server.RequestMetrics, index int) *ReqDetailState { +func NewReqDetailState(runID server.RunID, reqs []*types.RequestMetrics, index int) *ReqDetailState { return &ReqDetailState{ RunID: runID, Requests: reqs, @@ -152,7 +153,7 @@ func RenderReqDetail(s *ReqDetailState, taskName string, st Styles, width, heigh } // buildReqPerfPanel 构建请求左侧性能指标面板。 -func buildReqPerfPanel(r *server.RequestMetrics, st Styles, maxH, width int) string { +func buildReqPerfPanel(r *types.RequestMetrics, st Styles, maxH, width int) string { var lines []string lines = append(lines, " "+st.SectionHead.Render("性能指标")) lines = append(lines, "") @@ -191,7 +192,7 @@ func buildReqPerfPanel(r *server.RequestMetrics, st Styles, maxH, width int) str } // buildReqNetworkPanel 构建请求右侧网络指标面板。 -func buildReqNetworkPanel(r *server.RequestMetrics, st Styles, maxH, width int) string { +func buildReqNetworkPanel(r *types.RequestMetrics, st Styles, maxH, width int) string { var lines []string lines = append(lines, " "+st.SectionHead.Render("网络指标")) lines = append(lines, "") @@ -218,7 +219,7 @@ func buildReqNetworkPanel(r *server.RequestMetrics, st Styles, maxH, width int) } // buildInputSection 构建输入 (请求体) 区域。 -func buildInputSection(r *server.RequestMetrics, st Styles, width, maxH int) string { +func buildInputSection(r *types.RequestMetrics, st Styles, width, maxH int) string { var lines []string lines = append(lines, " "+st.SectionHead.Render("请求体 (Request Body)")) lines = append(lines, " "+dividerLine(st, width-2)) @@ -241,7 +242,7 @@ func buildInputSection(r *server.RequestMetrics, st Styles, width, maxH int) str } // buildOutputSection 构建输出 (响应体) 区域。 -func buildOutputSection(r *server.RequestMetrics, scrollY int, st Styles, width, maxH int) string { +func buildOutputSection(r *types.RequestMetrics, scrollY int, st Styles, width, maxH int) string { var lines []string lines = append(lines, " "+st.SectionHead.Render("响应体 (Response Body)")) lines = append(lines, " "+dividerLine(st, width-2)) diff --git a/internal/tui/pages/taskdetail.go b/internal/tui/pages/taskdetail.go index 3df591b..2bb6b28 100644 --- a/internal/tui/pages/taskdetail.go +++ b/internal/tui/pages/taskdetail.go @@ -3,6 +3,7 @@ package pages import ( "fmt" "strings" + "time" "github.com/charmbracelet/lipgloss" tea "github.com/charmbracelet/bubbletea" @@ -27,11 +28,25 @@ func NewTaskDetailState(task types.TaskDefinition) *TaskDetailState { return &TaskDetailState{Task: task} } +func taskDetailHistoryEntries(s *TaskDetailState) []types.TaskRunSummary { + if s == nil || len(s.History) == 0 { + return nil + } + if s.ActiveRun == nil { + return s.History + } + if strings.TrimSpace(s.History[0].RunID) == strings.TrimSpace(string(s.ActiveRun.RunID)) { + return s.History[1:] + } + return s.History +} + // HandleTaskDetailKey 处理任务详情页按键。 func HandleTaskDetailKey(s *TaskDetailState, msg tea.KeyMsg, client Client) (*TaskDetailState, tea.Cmd, NavAction) { nav := NavAction{} hasActive := s.ActiveRun != nil - effectiveLen := len(s.History) + historyEntries := taskDetailHistoryEntries(s) + effectiveLen := len(historyEntries) if hasActive { effectiveLen++ } @@ -57,10 +72,10 @@ func HandleTaskDetailKey(s *TaskDetailState, msg tea.KeyMsg, client Client) (*Ta if hasActive { histIdx-- } - if histIdx >= 0 && histIdx < len(s.History) { - runID := strings.TrimSpace(s.History[histIdx].RunID) + if histIdx >= 0 && histIdx < len(historyEntries) { + runID := strings.TrimSpace(historyEntries[histIdx].RunID) if runID != "" { - sum := s.History[histIdx] + sum := historyEntries[histIdx] nav = NavAction{To: NavRunDetail, RunID: server.RunID(runID), Summary: &sum} } } @@ -84,8 +99,11 @@ func HandleTaskDetailKey(s *TaskDetailState, msg tea.KeyMsg, client Client) (*Ta if hasActive { histIdx-- } - if histIdx >= 0 && histIdx < len(s.History) { - runID := strings.TrimSpace(s.History[histIdx].RunID) + if histIdx >= 0 && histIdx < len(historyEntries) { + if historyEntries[histIdx].Status == string(server.RunStatusRunning) { + break + } + runID := strings.TrimSpace(historyEntries[histIdx].RunID) if runID != "" { return s, client.GenerateReportCmd(server.RunID(runID), server.ReportFormatJSON), nav } @@ -140,7 +158,7 @@ func RenderTaskDetail(s *TaskDetailState, st Styles, width, height int) string { var cbItems []ContextBarItem hasActive := s.ActiveRun != nil - effectiveLen := len(s.History) + effectiveLen := len(taskDetailHistoryEntries(s)) if hasActive { effectiveLen++ } @@ -209,11 +227,12 @@ func buildTaskDetailContent(s *TaskDetailState, st Styles, t types.TaskDefinitio // ─── 右栏:历史运行记录 ───────────────────────────────────── var rightLines []string + historyEntries := taskDetailHistoryEntries(s) rightLines = append(rightLines, padRight(" "+st.SectionHead.Render("历史运行记录"), rightW)) rightLines = append(rightLines, padRight(st.Divider.Render(strings.Repeat("─", rightW)), rightW)) hasActive := s.ActiveRun != nil - effectiveLen := len(s.History) + effectiveLen := len(historyEntries) if hasActive { effectiveLen++ } @@ -246,7 +265,7 @@ func buildTaskDetailContent(s *TaskDetailState, st Styles, t types.TaskDefinitio } } if histIdx >= 0 { - detailLines = buildTaskHistoryDetailLines(s, histIdx, st, rightW) + detailLines = buildTaskHistoryDetailLines(historyEntries, histIdx, st, rightW) } } tableMaxH := maxH - len(detailLines) @@ -292,21 +311,25 @@ func buildTaskDetailContent(s *TaskDetailState, st Styles, t types.TaskDefinitio if hasActive { histIdx-- } - run := s.History[histIdx] - statusText := "✓" - if run.Status != "completed" { - statusText = "✗" + run := historyEntries[histIdx] + statusText := "✗" + statusStyle := st.ErrStyle + switch run.Status { + case string(server.RunStatusRunning): + statusText = "●" + statusStyle = st.Ok + case string(server.RunStatusCompleted): + statusText = "✓" + statusStyle = st.Ok + case string(server.RunStatusStopped): + statusText = "■" + statusStyle = st.Muted } modeShort := "标准" if run.Mode == "turbo" { modeShort = "Turbo" } - statusIcon := statusText - if run.Status == "completed" { - statusIcon = styleWhenNotSelected(isSel, st.Ok, statusText) - } else { - statusIcon = styleWhenNotSelected(isSel, st.ErrStyle, statusText) - } + statusIcon := styleWhenNotSelected(isSel, statusStyle, statusText) row = padRight(marker, markW) + padRight(statusIcon, statW) + padRight(run.StartedAt.Format("2006-01-02 15:04"), timeW) + @@ -354,31 +377,41 @@ func UpdateTaskDetailHistory(s *TaskDetailState, history []types.TaskRunSummary, return s } s.History = history - if len(history) == 0 { + effectiveLen := len(taskDetailHistoryEntries(s)) + if s.ActiveRun != nil { + effectiveLen++ + } + if effectiveLen == 0 { s.HistorySel = 0 s.HistoryOff = 0 } else { if s.HistorySel < 0 { s.HistorySel = 0 } - if s.HistorySel >= len(history) { - s.HistorySel = len(history) - 1 + if s.HistorySel >= effectiveLen { + s.HistorySel = effectiveLen - 1 } - s.HistoryOff = ensureVisibleOffset(s.HistorySel, len(history), s.HistoryOff, s.HistoryVis) + s.HistoryOff = ensureVisibleOffset(s.HistorySel, effectiveLen, s.HistoryOff, s.HistoryVis) } - if autoExpand && len(history) > 0 { + if autoExpand && len(taskDetailHistoryEntries(s)) > 0 { // autoExpand 参数保留接口兼容性,展开行为已由渲染层自动处理 _ = autoExpand } return s } -func buildTaskHistoryDetailLines(s *TaskDetailState, histIdx int, st Styles, width int) []string { - if histIdx < 0 || histIdx >= len(s.History) { +func buildTaskHistoryDetailLines(history []types.TaskRunSummary, histIdx int, st Styles, width int) []string { + if histIdx < 0 || histIdx >= len(history) { return nil } - sel := s.History[histIdx] + sel := history[histIdx] elapsed := sel.FinishedAt.Sub(sel.StartedAt) + elapsedText := fmtDuration(elapsed) + finishedText := sel.FinishedAt.Format("2006-01-02 15:04") + if sel.FinishedAt.IsZero() { + elapsedText = fmtDuration(time.Since(sel.StartedAt)) + finishedText = "进行中" + } labelW := 8 indent := " " gap := 4 @@ -388,6 +421,9 @@ func buildTaskHistoryDetailLines(s *TaskDetailState, histIdx int, st Styles, wid statusText := sel.Status statusStyle := st.Value switch sel.Status { + case "running": + statusText = "运行中" + statusStyle = st.Ok case "completed": statusText = "完成" statusStyle = st.Ok @@ -447,10 +483,10 @@ func buildTaskHistoryDetailLines(s *TaskDetailState, histIdx int, st Styles, wid ) lines = appendPairRow(lines, "开始", sel.StartedAt.Format("2006-01-02 15:04"), st.Value, - "结束", sel.FinishedAt.Format("2006-01-02 15:04"), st.Value, + "结束", finishedText, st.Value, ) lines = appendPairRow(lines, - "耗时", fmtDuration(elapsed), st.Value, + "耗时", elapsedText, st.Value, "成功率", fmt.Sprintf("%.1f%%", sel.SuccessRate), st.Value, ) lines = appendPairRow(lines, @@ -462,16 +498,6 @@ func buildTaskHistoryDetailLines(s *TaskDetailState, histIdx int, st Styles, wid if sel.CacheHitRate > 0 { lines = appendSingleField(lines, "缓存", fmt.Sprintf("%.1f%%", sel.CacheHitRate*100), st.Value) } - if sel.ReportJSONPath != "" || sel.ReportCSVPath != "" { - reports := make([]string, 0, 2) - if sel.ReportJSONPath != "" { - reports = append(reports, "JSON") - } - if sel.ReportCSVPath != "" { - reports = append(reports, "CSV") - } - lines = appendSingleField(lines, "报告", strings.Join(reports, " / "), st.Muted) - } if sel.ErrorSummary != "" { lines = append(lines, indent+st.Label.Render("错误摘要")) for _, seg := range wrapText(sel.ErrorSummary, maxInt(10, contentW-2)) { diff --git a/internal/tui/pages/taskdetail_test.go b/internal/tui/pages/taskdetail_test.go new file mode 100644 index 0000000..20e0cfb --- /dev/null +++ b/internal/tui/pages/taskdetail_test.go @@ -0,0 +1,26 @@ +package pages + +import ( + "testing" + + "github.com/yinxulai/ait/internal/server" + "github.com/yinxulai/ait/internal/types" +) + +func TestTaskDetailHistoryEntries_SkipsActiveRunDuplicate(t *testing.T) { + state := &TaskDetailState{ + ActiveRun: &server.RunState{RunID: "run-2"}, + History: []types.TaskRunSummary{ + {RunID: "run-2", Status: string(server.RunStatusRunning)}, + {RunID: "run-1", Status: string(server.RunStatusCompleted)}, + }, + } + + entries := taskDetailHistoryEntries(state) + if len(entries) != 1 { + t.Fatalf("expected 1 visible history entry, got %d", len(entries)) + } + if entries[0].RunID != "run-1" { + t.Fatalf("RunID: got %q, want %q", entries[0].RunID, "run-1") + } +} diff --git a/internal/tui/pages/tasklist.go b/internal/tui/pages/tasklist.go index 7d60d6e..5352b04 100644 --- a/internal/tui/pages/tasklist.go +++ b/internal/tui/pages/tasklist.go @@ -13,7 +13,7 @@ import ( // TaskListState 任务列表页状态。 type TaskListState struct { - Tasks []types.TaskDefinition + Tasks []server.TaskOverview Selected int Offset int Visible int @@ -33,7 +33,7 @@ func (s *TaskListState) CurrentTask() (types.TaskDefinition, bool) { if len(s.Tasks) == 0 || s.Selected < 0 || s.Selected >= len(s.Tasks) { return types.TaskDefinition{}, false } - return s.Tasks[s.Selected], true + return s.Tasks[s.Selected].TaskDefinition, true } // IsTaskRunning 判断某任务是否正在运行。 @@ -48,9 +48,10 @@ func (s *TaskListState) IsTaskRunning(taskID string) bool { func (s *TaskListState) latestRunAt() *time.Time { var latest *time.Time for _, t := range s.Tasks { - if t.LastRunAt != nil { - if latest == nil || t.LastRunAt.After(*latest) { - latest = t.LastRunAt + if t.LatestRun != nil && !t.LatestRun.FinishedAt.IsZero() { + finishedAt := t.LatestRun.FinishedAt + if latest == nil || finishedAt.After(*latest) { + latest = &finishedAt } } } @@ -231,13 +232,13 @@ func buildTaskListContent(s *TaskListState, st Styles, width, maxH int) string { // ── 上次运行时间 ── lastRunText := "─" - if hasActiveRun { + if hasActiveRun || (t.LatestRun != nil && t.LatestRun.Status == string(server.RunStatusRunning)) { lastRunText = "运行中" - } else if t.LastRunAt != nil { - lastRunText = fmtRelativeTime(*t.LastRunAt) + } else if t.LatestRun != nil && !t.LatestRun.FinishedAt.IsZero() { + lastRunText = fmtRelativeTime(t.LatestRun.FinishedAt) } lastRunStyle := st.Muted - if hasActiveRun { + if hasActiveRun || (t.LatestRun != nil && t.LatestRun.Status == string(server.RunStatusRunning)) { lastRunStyle = st.Ok } lastRunCol := padRight(styleWhenNotSelected(isSel, lastRunStyle, lastRunText), lastRunW) @@ -246,8 +247,8 @@ func buildTaskListContent(s *TaskListState, st Styles, width, maxH int) string { ttftText := "─" if hasActiveRun && rs != nil && rs.AvgTTFT > 0 { ttftText = fmtDuration(rs.AvgTTFT) - } else if !hasActiveRun && t.LastRunSummary != nil { - ttftText = fmtDuration(t.LastRunSummary.AvgTTFT) + } else if !hasActiveRun && t.LatestRun != nil { + ttftText = fmtDuration(t.LatestRun.AvgTTFT) } ttftCol := padRight(styleWhenNotSelected(isSel, st.Value, ttftText), ttftW) @@ -255,11 +256,11 @@ func buildTaskListContent(s *TaskListState, st Styles, width, maxH int) string { tpsText := "─" if hasActiveRun && rs != nil && rs.AvgTPS > 0 { tpsText = fmt.Sprintf("%.1f", rs.AvgTPS) - } else if !hasActiveRun && t.LastRunSummary != nil { - if t.Input.Turbo && t.LastRunSummary.MaxStableConcurrency > 0 { - tpsText = fmt.Sprintf("并发%d", t.LastRunSummary.MaxStableConcurrency) + } else if !hasActiveRun && t.LatestRun != nil { + if t.Input.Turbo && t.LatestRun.MaxStableConcurrency > 0 { + tpsText = fmt.Sprintf("并发%d", t.LatestRun.MaxStableConcurrency) } else if !t.Input.Turbo { - tpsText = fmt.Sprintf("%.1f", t.LastRunSummary.AvgTPS) + tpsText = fmt.Sprintf("%.1f", t.LatestRun.AvgTPS) } } tpsCol := styleWhenNotSelected(isSel, st.Value, tpsText) diff --git a/internal/tui/pages/wizard.go b/internal/tui/pages/wizard.go index 3cce109..d2ef1f2 100644 --- a/internal/tui/pages/wizard.go +++ b/internal/tui/pages/wizard.go @@ -759,7 +759,7 @@ func renderStep3Summary(wz *WizardState, st Styles, innerW int) []string { addRow("目标长度", strconv.Itoa(wz.PromptLength), st.Muted) } - lines = append(lines, "", st.Muted.Render("保存位置: ~/.ait/tasks.json")) + lines = append(lines, "", st.Muted.Render("保存位置: ~/.ait/tasks/.json")) return lines } diff --git a/internal/types/types.go b/internal/types/types.go index dd458be..f9523b7 100644 --- a/internal/types/types.go +++ b/internal/types/types.go @@ -230,13 +230,11 @@ type ReportData struct { } type TaskDefinition struct { - ID string `json:"id"` - Name string `json:"name"` - Input Input `json:"input"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` - LastRunAt *time.Time `json:"last_run_at,omitempty"` - LastRunSummary *TaskRunSummary `json:"last_run_summary,omitempty"` + ID string `json:"id"` + Name string `json:"name"` + Input Input `json:"input"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` } type TaskRunSummary struct { @@ -253,11 +251,28 @@ type TaskRunSummary struct { AvgTPS float64 `json:"avg_tps"` CacheHitRate float64 `json:"cache_hit_rate"` MaxStableConcurrency int `json:"max_stable_concurrency,omitempty"` - ReportJSONPath string `json:"report_json_path,omitempty"` - ReportCSVPath string `json:"report_csv_path,omitempty"` ErrorSummary string `json:"error_summary,omitempty"` } +type RequestMetrics struct { + Index int `json:"index"` + Success bool `json:"success"` + TotalTime time.Duration `json:"total_time"` + TTFT time.Duration `json:"ttft"` + TPS float64 `json:"tps"` + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + CachedTokens int `json:"cached_tokens"` + CacheHitRate float64 `json:"cache_hit_rate"` + DNSTime time.Duration `json:"dns_time"` + ConnectTime time.Duration `json:"connect_time"` + TLSTime time.Duration `json:"tls_time"` + TargetIP string `json:"target_ip"` + ErrorMessage string `json:"error_message,omitempty"` + RequestBody string `json:"request_body,omitempty"` + ResponseBody string `json:"response_body,omitempty"` +} + type TurboConfig struct { InitConcurrency int `json:"init_concurrency"` MaxConcurrency int `json:"max_concurrency"` From ed738c31dc1bdc6b7a6e97996362cc48e4fcb252 Mon Sep 17 00:00:00 2001 From: Alain Date: Wed, 20 May 2026 00:43:18 +0800 Subject: [PATCH 18/52] =?UTF-8?q?feat:=20=E9=87=8D=E6=9E=84=E4=BB=BB?= =?UTF-8?q?=E5=8A=A1=E7=AE=A1=E7=90=86=EF=BC=8C=E6=94=AF=E6=8C=81=E4=BB=BB?= =?UTF-8?q?=E5=8A=A1=E8=A7=86=E5=9B=BE=E5=92=8C=E9=94=99=E8=AF=AF=E5=A4=84?= =?UTF-8?q?=E7=90=86=EF=BC=8C=E4=BC=98=E5=8C=96=E4=BB=BB=E5=8A=A1=E5=88=97?= =?UTF-8?q?=E8=A1=A8=E5=8A=A0=E8=BD=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/storage.md | 25 +++-- internal/server/run.go | 81 ++++++-------- internal/server/server.go | 12 +-- internal/server/server_test.go | 152 ++++++++++++++++----------- internal/server/task.go | 144 ++++++++++++------------- internal/server/types.go | 7 -- internal/store/run.go | 168 +++++++++++++++++++++-------- internal/store/run_test.go | 96 +++++++++++++++++ internal/store/task.go | 186 ++++++++++++++------------------- internal/store/task_test.go | 163 +++++++++++++++++++++++++++++ internal/store/taskview.go | 35 +++++++ internal/tui/client.go | 6 +- internal/tui/messages.go | 2 +- internal/tui/model_test.go | 4 +- internal/tui/pages/tasklist.go | 2 +- internal/types/types.go | 5 + 16 files changed, 721 insertions(+), 367 deletions(-) create mode 100644 internal/store/run_test.go create mode 100644 internal/store/task_test.go create mode 100644 internal/store/taskview.go diff --git a/docs/storage.md b/docs/storage.md index e4d4c25..d819631 100644 --- a/docs/storage.md +++ b/docs/storage.md @@ -109,12 +109,15 @@ 建议字段: -- task_id - name - input - created_at - updated_at +说明: + +- task id 由文件名 `.json` 表达,不在文件体内重复存一份。 + 这里的 input 就是任务配置本体: - 协议 @@ -143,8 +146,6 @@ run.json 只保存最小运行元数据,用来描述这次运行属于谁、 建议字段: -- run_id -- task_id - mode - protocol - model @@ -152,6 +153,10 @@ run.json 只保存最小运行元数据,用来描述这次运行属于谁、 - started_at - finished_at +说明: + +- run id 和 task id 由目录路径 `runs///run.json` 表达,不在 run.json 里重复存储。 + 它不承担指标存储,不承担请求明细存储,也不承担列表视图存储。 ### 5.4 runs///requests.jsonl @@ -185,12 +190,16 @@ run.json 只保存最小运行元数据,用来描述这次运行属于谁、 建议保存: -- 总请求数 -- 成功率 / 错误率 -- TTFT / TPS / 总耗时等最终聚合指标 -- 输入输出 token 统计 +- 标准模式的最终报告结果 - Turbo 模式的最终级别结果 -- 最终错误摘要(如有) +- 无法从 requests.jsonl 重新推出的运行结果字段 + 例如提前失败时的计划总请求数、最终错误摘要 + +明确不再保存: + +- 可以从 requests.jsonl 直接重算的 done/success/failed 计数 +- 可以从 requests.jsonl 或 final report 重算的 success_rate / avg_ttft / avg_tps / cache_hit_rate +- 仅为任务列表或历史列表服务的摘要副本 这个文件只在运行结束后写一次。 diff --git a/internal/server/run.go b/internal/server/run.go index bc4491c..76a670e 100644 --- a/internal/server/run.go +++ b/internal/server/run.go @@ -1,6 +1,7 @@ package server import ( + "errors" "fmt" "sync" "sync/atomic" @@ -128,65 +129,57 @@ func buildStoredRunMetadata(taskDef types.TaskDefinition, snap *RunState) store. func buildStoredRunResult(snap *RunState) store.RunResult { result := store.RunResult{ - TotalReqs: snap.TotalReqs, - DoneReqs: snap.DoneReqs, - SuccessReqs: snap.SuccessReqs, - FailedReqs: snap.FailedReqs, - SuccessRate: snap.SuccessRate, - AvgTTFT: snap.AvgTTFT, - AvgTPS: snap.AvgTPS, - CacheHitRate: snap.CacheHitRate, ErrorSummary: snap.ErrorMsg, StandardResult: snap.StandardResult, TurboResult: snap.TurboResult, } - if snap.TurboResult != nil { - result.MaxStableConcurrency = snap.TurboResult.MaxStableConcurrency - } else if snap.CurrentLevel > 0 { + if snap.StandardResult == nil && snap.TurboResult == nil && snap.TotalReqs > 0 { + result.TotalReqs = snap.TotalReqs + } + if snap.TurboResult == nil && snap.CurrentLevel > 0 { result.MaxStableConcurrency = snap.CurrentLevel } return result } -func buildRunStateFromStoredRun(run *store.StoredRun, requests []*types.RequestMetrics) *RunState { +func buildRunStateFromStoredRun(run *store.StoredRun, requests []types.RequestMetrics) *RunState { if run == nil { return nil } + summary := run.Summary(requests) state := &RunState{ RunID: RunID(run.Metadata.RunID), TaskID: run.Metadata.TaskID, Status: RunStatus(run.Metadata.Status), Mode: run.Metadata.Mode, StartedAt: run.Metadata.StartedAt, - Requests: requests, + Requests: requestPointers(requests), + AvgTTFT: summary.AvgTTFT, + AvgTPS: summary.AvgTPS, + SuccessRate: summary.SuccessRate, + CacheHitRate: summary.CacheHitRate, + ErrorMsg: summary.ErrorSummary, + CurrentLevel: summary.MaxStableConcurrency, } if run.Metadata.FinishedAt != nil { finished := *run.Metadata.FinishedAt state.FinishedAt = &finished } + state.DoneReqs = len(requests) + for _, request := range requests { + if request.Success { + state.SuccessReqs++ + } + } + state.FailedReqs = state.DoneReqs - state.SuccessReqs + state.TotalReqs = run.TotalReqs(requests) if run.Result == nil { return state } - state.TotalReqs = run.Result.TotalReqs - state.DoneReqs = run.Result.DoneReqs - state.SuccessReqs = run.Result.SuccessReqs - state.FailedReqs = run.Result.FailedReqs - state.SuccessRate = run.Result.SuccessRate - state.AvgTTFT = run.Result.AvgTTFT - state.AvgTPS = run.Result.AvgTPS - state.CacheHitRate = run.Result.CacheHitRate state.StandardResult = run.Result.StandardResult state.TurboResult = run.Result.TurboResult - state.ErrorMsg = run.Result.ErrorSummary - state.CurrentLevel = run.Result.MaxStableConcurrency - if state.DoneReqs == 0 && len(requests) > 0 { - state.DoneReqs = len(requests) - } - if state.TotalReqs == 0 && len(requests) > 0 { - state.TotalReqs = len(requests) - } if run.Result.TurboResult != nil { state.Levels = run.Result.TurboResult.Levels state.CurrentLevel = run.Result.TurboResult.MaxStableConcurrency @@ -219,14 +212,14 @@ func buildRunningRunSummary(taskDef types.TaskDefinition, snap *RunState) types. // StartRun 启动一次新的运行,立即返回 RunID。 func (s *serverImpl) StartRun(taskID string) (RunID, error) { - s.mu.RLock() - taskDef, ok := s.taskStore.Get(taskID) - runStore := s.runStore - s.mu.RUnlock() - - if !ok { - return "", fmt.Errorf("task %q not found", taskID) + taskDef, err := s.taskStore.Get(taskID) + if err != nil { + if errors.Is(err, store.ErrTaskNotFound) { + return "", fmt.Errorf("task %q not found: %w", taskID, err) + } + return "", fmt.Errorf("get task %q: %w", taskID, err) } + runStore := s.runStore // 解析 PromptSource(将 PromptText/PromptFile 转换为可调用的 PromptSource) hydratedInput, err := task.HydrateInput(taskDef.Input) @@ -472,10 +465,6 @@ func (s *serverImpl) persistFinalRun(runStore *store.RunStore, taskDef types.Tas return runStore.SaveFinal(buildStoredRunMetadata(taskDef, snap), buildStoredRunResult(snap)) } -func (s *serverImpl) persistRunResult(summary types.TaskRunSummary) error { - return s.runStore.SaveSummary(summary) -} - func (s *serverImpl) removeActiveRun(runID RunID) { s.mu.Lock() defer s.mu.Unlock() @@ -529,7 +518,7 @@ func (s *serverImpl) GetRunState(runID RunID) (*RunState, bool) { if err != nil { return nil, false } - return buildRunStateFromStoredRun(run, requestPointers(requests)), true + return buildRunStateFromStoredRun(run, requests), true } // Subscribe 订阅指定运行的事件流。 @@ -539,15 +528,7 @@ func (s *serverImpl) Subscribe(runID RunID) (<-chan Event, CancelFunc) { // GetHistory 返回任务的历史运行摘要,最新在前。 func (s *serverImpl) GetHistory(taskID string, limit int) ([]types.TaskRunSummary, error) { - runs, err := s.runStore.ListByTask(taskID, limit) - if err != nil { - return nil, err - } - history := make([]types.TaskRunSummary, 0, len(runs)) - for _, run := range runs { - history = append(history, run.Summary()) - } - return history, nil + return s.runStore.ListSummariesByTask(taskID, limit) } // GenerateReport 为已完成的标准运行生成报告文件。 diff --git a/internal/server/server.go b/internal/server/server.go index d095cbf..03d4b58 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -12,8 +12,8 @@ import ( // 所有方法均为线程安全。 type Server interface { // --- 任务 CRUD --- - ListTasks() []TaskOverview - GetTask(id string) (types.TaskDefinition, bool) + ListTasks() ([]types.TaskOverview, error) + GetTask(id string) (types.TaskDefinition, error) CreateTask(cfg TaskConfig) (types.TaskDefinition, error) UpdateTask(id string, cfg TaskConfig) (types.TaskDefinition, error) DeleteTask(id string) error @@ -46,6 +46,7 @@ type Server interface { type serverImpl struct { mu sync.RWMutex taskStore *store.TaskStore + taskViews *store.TaskViewStore runStore *store.RunStore bus *eventBus activeRuns map[RunID]*activeRun @@ -69,13 +70,12 @@ func New() (Server, error) { } ts := store.NewTaskStore(tasksDir) - if err := ts.Load(); err != nil { - return nil, err - } + rs := store.NewRunStore(runsDir) return &serverImpl{ taskStore: ts, - runStore: store.NewRunStore(runsDir), + taskViews: store.NewTaskViewStore(ts, rs), + runStore: rs, bus: newEventBus(), activeRuns: make(map[RunID]*activeRun), }, nil diff --git a/internal/server/server_test.go b/internal/server/server_test.go index 0b1a52d..4a9521d 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -24,12 +24,11 @@ func newTestServer(t *testing.T) *serverImpl { t.Fatalf("mkdir runs: %v", err) } ts := store.NewTaskStore(tasksDir) - if err := ts.Load(); err != nil { - t.Fatalf("load task store: %v", err) - } + rs := store.NewRunStore(runsDir) return &serverImpl{ taskStore: ts, - runStore: store.NewRunStore(runsDir), + taskViews: store.NewTaskViewStore(ts, rs), + runStore: rs, bus: newEventBus(), activeRuns: make(map[RunID]*activeRun), } @@ -380,13 +379,6 @@ func TestGetRunState_LoadsCompletedRunFromDisk(t *testing.T) { StartedAt: startedAt, FinishedAt: &finishedAt, }, store.RunResult{ - TotalReqs: 4, - DoneReqs: 1, - SuccessReqs: 1, - AvgTPS: 18.5, - AvgTTFT: 120 * time.Millisecond, - SuccessRate: 25, - CacheHitRate: 0.4, ErrorSummary: "", StandardResult: &types.ReportData{TotalRequests: 4, AvgTPS: 18.5, AvgTTFT: 120 * time.Millisecond, SuccessRate: 25}, }); err != nil { @@ -432,7 +424,10 @@ func TestGetRunState_LoadsCompletedRunFromDisk(t *testing.T) { func TestListTasks_Empty(t *testing.T) { s := newTestServer(t) - tasks := s.ListTasks() + tasks, err := s.ListTasks() + if err != nil { + t.Fatalf("ListTasks: %v", err) + } if len(tasks) != 0 { t.Errorf("expected empty list, got %d tasks", len(tasks)) } @@ -455,7 +450,10 @@ func TestCreateTask_ReturnsTaskWithID(t *testing.T) { func TestCreateTask_AppearsInList(t *testing.T) { s := newTestServer(t) s.CreateTask(makeTaskConfig("task-a")) - all := s.ListTasks() + all, err := s.ListTasks() + if err != nil { + t.Fatalf("ListTasks: %v", err) + } if len(all) != 1 { t.Errorf("expected 1 task, got %d", len(all)) } @@ -471,17 +469,21 @@ func TestCreateTask_MultipleTasksAllListed(t *testing.T) { t.Fatalf("CreateTask %q: %v", name, err) } } - if len(s.ListTasks()) != 3 { - t.Errorf("expected 3 tasks, got %d", len(s.ListTasks())) + tasks, err := s.ListTasks() + if err != nil { + t.Fatalf("ListTasks: %v", err) + } + if len(tasks) != 3 { + t.Errorf("expected 3 tasks, got %d", len(tasks)) } } func TestGetTask_Found(t *testing.T) { s := newTestServer(t) created, _ := s.CreateTask(makeTaskConfig("task-get")) - got, ok := s.GetTask(created.ID) - if !ok { - t.Fatal("GetTask returned not found") + got, err := s.GetTask(created.ID) + if err != nil { + t.Fatalf("GetTask: %v", err) } if got.ID != created.ID { t.Errorf("ID mismatch: %q vs %q", got.ID, created.ID) @@ -490,9 +492,9 @@ func TestGetTask_Found(t *testing.T) { func TestGetTask_NotFound(t *testing.T) { s := newTestServer(t) - _, ok := s.GetTask("nonexistent") - if ok { - t.Fatal("expected not found for nonexistent ID") + _, err := s.GetTask("nonexistent") + if !errors.Is(err, store.ErrTaskNotFound) { + t.Fatalf("expected ErrTaskNotFound, got %v", err) } } @@ -507,9 +509,9 @@ func TestUpdateTask_Success(t *testing.T) { t.Errorf("Name: got %q, want renamed", updated.Name) } // Verify persistence via GetTask. - fetched, ok := s.GetTask(created.ID) - if !ok || fetched.Name != "renamed" { - t.Errorf("GetTask after update: ok=%v name=%q", ok, fetched.Name) + fetched, err := s.GetTask(created.ID) + if err != nil || fetched.Name != "renamed" { + t.Errorf("GetTask after update: err=%v name=%q", err, fetched.Name) } } @@ -527,10 +529,14 @@ func TestDeleteTask_Success(t *testing.T) { if err := s.DeleteTask(created.ID); err != nil { t.Fatalf("DeleteTask: %v", err) } - if _, ok := s.GetTask(created.ID); ok { - t.Error("task still accessible after delete") + if _, err := s.GetTask(created.ID); !errors.Is(err, store.ErrTaskNotFound) { + t.Errorf("expected deleted task to be missing, got %v", err) + } + tasks, err := s.ListTasks() + if err != nil { + t.Fatalf("ListTasks: %v", err) } - if len(s.ListTasks()) != 0 { + if len(tasks) != 0 { t.Error("expected empty list after delete") } } @@ -542,6 +548,26 @@ func TestDeleteTask_NotFound(t *testing.T) { } } +func TestDeleteTask_RunningTaskRejected(t *testing.T) { + s := newTestServer(t) + created, _ := s.CreateTask(makeTaskConfig("still-running")) + runID := RunID("run_live_delete") + + s.mu.Lock() + s.activeRuns[runID] = &activeRun{ + state: &RunState{ + RunID: runID, + TaskID: created.ID, + Status: RunStatusRunning, + }, + } + s.mu.Unlock() + + if err := s.DeleteTask(created.ID); err == nil { + t.Fatal("expected delete to fail while task is running") + } +} + func TestCopyTask_CreatesNewTask(t *testing.T) { s := newTestServer(t) original, _ := s.CreateTask(makeTaskConfig("original")) @@ -555,8 +581,12 @@ func TestCopyTask_CreatesNewTask(t *testing.T) { if copied.Name != "original (copy)" { t.Errorf("Name: got %q, want %q", copied.Name, "original (copy)") } - if len(s.ListTasks()) != 2 { - t.Errorf("expected 2 tasks after copy, got %d", len(s.ListTasks())) + tasks, err := s.ListTasks() + if err != nil { + t.Fatalf("ListTasks: %v", err) + } + if len(tasks) != 2 { + t.Errorf("expected 2 tasks after copy, got %d", len(tasks)) } } @@ -649,16 +679,16 @@ func TestGetHistory_PersistsAfterRun(t *testing.T) { s := newTestServer(t) task, _ := s.CreateTask(makeTaskConfig("persist-task")) - summary := types.TaskRunSummary{ - RunID: "run_test", - TaskID: task.ID, - Mode: "standard", - Status: string(RunStatusCompleted), - StartedAt: time.Now().Add(-time.Second), - FinishedAt: time.Now(), - } - if err := s.persistRunResult(summary); err != nil { - t.Fatalf("persistRunResult: %v", err) + finishedAt := time.Now() + if err := s.runStore.SaveFinal(store.RunMetadata{ + RunID: "run_test", + TaskID: task.ID, + Mode: "standard", + Status: string(RunStatusCompleted), + StartedAt: finishedAt.Add(-time.Second), + FinishedAt: &finishedAt, + }, store.RunResult{}); err != nil { + t.Fatalf("SaveFinal: %v", err) } history, err := s.GetHistory(task.ID, 0) @@ -694,7 +724,10 @@ func TestGetTask_DerivesRunningSummaryFromActiveRun(t *testing.T) { } s.mu.Unlock() - tasks := s.ListTasks() + tasks, err := s.ListTasks() + if err != nil { + t.Fatalf("ListTasks: %v", err) + } if len(tasks) != 1 { t.Fatalf("expected 1 task overview, got %d", len(tasks)) } @@ -709,24 +742,21 @@ func TestGetTask_DerivesRunningSummaryFromActiveRun(t *testing.T) { } } -func TestPersistRunResult_DerivesLatestTaskSummary(t *testing.T) { +func TestStoredRun_DerivesLatestTaskSummary(t *testing.T) { s := newTestServer(t) task, _ := s.CreateTask(makeTaskConfig("finalize-task")) startedAt := time.Now().Add(-2 * time.Second) finishedAt := time.Now() - if err := s.persistRunResult(types.TaskRunSummary{ - RunID: "run_same", - TaskID: task.ID, - Mode: "standard", - Status: string(RunStatusCompleted), - StartedAt: startedAt, - FinishedAt: finishedAt, - SuccessRate: 100, - AvgTTFT: 80 * time.Millisecond, - AvgTPS: 24.2, - }); err != nil { - t.Fatalf("persistRunResult: %v", err) + if err := s.runStore.SaveFinal(store.RunMetadata{ + RunID: "run_same", + TaskID: task.ID, + Mode: "standard", + Status: string(RunStatusCompleted), + StartedAt: startedAt, + FinishedAt: &finishedAt, + }, store.RunResult{}); err != nil { + t.Fatalf("SaveFinal: %v", err) } history, err := s.GetHistory(task.ID, 0) @@ -743,7 +773,10 @@ func TestPersistRunResult_DerivesLatestTaskSummary(t *testing.T) { t.Fatal("expected FinishedAt to be persisted for completed run") } - tasks := s.ListTasks() + tasks, err := s.ListTasks() + if err != nil { + t.Fatalf("ListTasks: %v", err) + } if len(tasks) != 1 { t.Fatalf("expected 1 task overview, got %d", len(tasks)) } @@ -757,13 +790,14 @@ func TestGetHistory_LimitRespected(t *testing.T) { task, _ := s.CreateTask(makeTaskConfig("limit-task")) for i := 0; i < 5; i++ { - if err := s.persistRunResult(types.TaskRunSummary{ + finishedAt := time.Now() + if err := s.runStore.SaveFinal(store.RunMetadata{ RunID: "run_" + string(rune('0'+i)), TaskID: task.ID, - StartedAt: time.Now(), - FinishedAt: time.Now(), - }); err != nil { - t.Fatalf("persistRunResult: %v", err) + StartedAt: finishedAt, + FinishedAt: &finishedAt, + }, store.RunResult{}); err != nil { + t.Fatalf("SaveFinal: %v", err) } } diff --git a/internal/server/task.go b/internal/server/task.go index 29c3ebc..e0be2a8 100644 --- a/internal/server/task.go +++ b/internal/server/task.go @@ -1,73 +1,61 @@ package server import ( + "errors" "fmt" + storepkg "github.com/yinxulai/ait/internal/store" "github.com/yinxulai/ait/internal/types" ) // ListTasks 返回所有任务(最近更新排在前面)。 -func (s *serverImpl) ListTasks() []TaskOverview { +func (s *serverImpl) ListTasks() ([]types.TaskOverview, error) { + overviews, err := s.taskViews.List() + if err != nil { + return nil, err + } + s.mu.RLock() - tasks := s.taskStore.All() - running := s.runningTaskSummariesLocked(tasks) + running := s.runningTaskSummariesLocked(overviews) s.mu.RUnlock() - return s.buildTaskOverviews(tasks, running) + return s.overlayRunningTaskOverviews(overviews, running), nil } // GetTask 按 ID 查找任务。 -func (s *serverImpl) GetTask(id string) (types.TaskDefinition, bool) { - s.mu.RLock() - task, ok := s.taskStore.Get(id) - s.mu.RUnlock() - if !ok { - return types.TaskDefinition{}, false - } - return task, true +func (s *serverImpl) GetTask(id string) (types.TaskDefinition, error) { + return s.taskStore.Get(id) } // CreateTask 新建任务并持久化。 func (s *serverImpl) CreateTask(cfg TaskConfig) (types.TaskDefinition, error) { - s.mu.Lock() - defer s.mu.Unlock() - - task := types.TaskDefinition{ + created, err := s.taskStore.Create(types.TaskDefinition{ Name: cfg.Name, Input: cfg.Input, + }) + if err != nil { + return types.TaskDefinition{}, fmt.Errorf("create task: %w", err) } - s.taskStore.Upsert(task) - if err := s.taskStore.Save(); err != nil { - return types.TaskDefinition{}, fmt.Errorf("save tasks: %w", err) - } - - // 返回已生成 ID 和时间戳的最新状态 - all := s.taskStore.All() - if len(all) > 0 { - return all[0], nil - } - return task, nil + return created, nil } // UpdateTask 更新指定任务,任务不存在时返回错误。 func (s *serverImpl) UpdateTask(id string, cfg TaskConfig) (types.TaskDefinition, error) { - s.mu.Lock() - defer s.mu.Unlock() - - existing, ok := s.taskStore.Get(id) - if !ok { - return types.TaskDefinition{}, fmt.Errorf("task %q not found", id) + existing, err := s.taskStore.Get(id) + if err != nil { + if errors.Is(err, storepkg.ErrTaskNotFound) { + return types.TaskDefinition{}, fmt.Errorf("task %q not found: %w", id, err) + } + return types.TaskDefinition{}, fmt.Errorf("get task %q: %w", id, err) } existing.Name = cfg.Name existing.Input = cfg.Input - s.taskStore.Upsert(existing) - if err := s.taskStore.Save(); err != nil { - return types.TaskDefinition{}, fmt.Errorf("save tasks: %w", err) + updated, err := s.taskStore.Update(existing) + if err != nil { + return types.TaskDefinition{}, fmt.Errorf("update task %q: %w", id, err) } - - updated, _ := s.taskStore.Get(id) return updated, nil } @@ -76,10 +64,10 @@ func (s *serverImpl) DeleteTask(id string) error { s.mu.Lock() defer s.mu.Unlock() - if err := s.taskStore.Delete(id); err != nil { - return err + if s.hasRunningTaskLocked(id) { + return fmt.Errorf("task %q is currently running", id) } - if err := s.taskStore.Save(); err != nil { + if err := s.taskStore.Delete(id); err != nil { return err } return s.runStore.DeleteTask(id) @@ -87,62 +75,48 @@ func (s *serverImpl) DeleteTask(id string) error { // CopyTask 复制指定任务(ID 和时间戳重置,名称加 " (copy)" 后缀)。 func (s *serverImpl) CopyTask(id string) (types.TaskDefinition, error) { - s.mu.Lock() - defer s.mu.Unlock() - - src, ok := s.taskStore.Get(id) - if !ok { - return types.TaskDefinition{}, fmt.Errorf("task %q not found", id) + src, err := s.taskStore.Get(id) + if err != nil { + if errors.Is(err, storepkg.ErrTaskNotFound) { + return types.TaskDefinition{}, fmt.Errorf("task %q not found: %w", id, err) + } + return types.TaskDefinition{}, fmt.Errorf("get task %q: %w", id, err) } - copied := types.TaskDefinition{ + created, err := s.taskStore.Create(types.TaskDefinition{ Name: src.Name + " (copy)", Input: src.Input, + }) + if err != nil { + return types.TaskDefinition{}, fmt.Errorf("copy task %q: %w", id, err) } - s.taskStore.Upsert(copied) - - if err := s.taskStore.Save(); err != nil { - return types.TaskDefinition{}, fmt.Errorf("save tasks: %w", err) - } - - all := s.taskStore.All() - if len(all) > 0 { - return all[0], nil - } - return copied, nil + return created, nil } -func (s *serverImpl) buildTaskOverviews(tasks []types.TaskDefinition, running map[string]types.TaskRunSummary) []TaskOverview { - decorated := make([]TaskOverview, 0, len(tasks)) - for _, task := range tasks { - decorated = append(decorated, s.buildTaskOverview(task, running)) +func (s *serverImpl) overlayRunningTaskOverviews(tasks []types.TaskOverview, running map[string]types.TaskRunSummary) []types.TaskOverview { + if len(running) == 0 { + return tasks } - return decorated -} - -func (s *serverImpl) buildTaskOverview(task types.TaskDefinition, running map[string]types.TaskRunSummary) TaskOverview { - overview := TaskOverview{TaskDefinition: task} - latest, err := s.runStore.LatestByTask(task.ID) - if err == nil && latest != nil { - summary := latest.Summary() - overview.LatestRun = &summary - } - if summary, ok := running[task.ID]; ok { - runningSummary := summary - overview.LatestRun = &runningSummary + overlaid := make([]types.TaskOverview, len(tasks)) + copy(overlaid, tasks) + for i := range overlaid { + if summary, ok := running[overlaid[i].ID]; ok { + runningSummary := summary + overlaid[i].LatestRun = &runningSummary + } } - return overview + return overlaid } -func (s *serverImpl) runningTaskSummariesLocked(tasks []types.TaskDefinition) map[string]types.TaskRunSummary { +func (s *serverImpl) runningTaskSummariesLocked(tasks []types.TaskOverview) map[string]types.TaskRunSummary { if len(tasks) == 0 || len(s.activeRuns) == 0 { return nil } taskByID := make(map[string]types.TaskDefinition, len(tasks)) for _, task := range tasks { - taskByID[task.ID] = task + taskByID[task.ID] = task.TaskDefinition } running := make(map[string]types.TaskRunSummary) @@ -167,3 +141,15 @@ func (s *serverImpl) runningTaskSummariesLocked(tasks []types.TaskDefinition) ma } return running } + +func (s *serverImpl) hasRunningTaskLocked(taskID string) bool { + for _, ar := range s.activeRuns { + ar.mu.RLock() + running := ar.state != nil && ar.state.TaskID == taskID && ar.state.Status == RunStatusRunning + ar.mu.RUnlock() + if running { + return true + } + } + return false +} diff --git a/internal/server/types.go b/internal/server/types.go index f98f361..1feda74 100644 --- a/internal/server/types.go +++ b/internal/server/types.go @@ -27,13 +27,6 @@ type TaskConfig struct { Input types.Input } -// TaskOverview 是面向列表/摘要读取的任务视图。 -// 它组合任务定义本体与最近一次运行摘要,不进入持久化层。 -type TaskOverview struct { - types.TaskDefinition - LatestRun *types.TaskRunSummary -} - // RunStatus 运行的生命周期状态。 type RunStatus string diff --git a/internal/store/run.go b/internal/store/run.go index 0acdf7d..4321e0d 100644 --- a/internal/store/run.go +++ b/internal/store/run.go @@ -13,8 +13,8 @@ import ( ) type RunMetadata struct { - RunID string `json:"run_id"` - TaskID string `json:"task_id"` + RunID string `json:"-"` + TaskID string `json:"-"` Mode string `json:"mode"` Protocol string `json:"protocol"` Model string `json:"model"` @@ -24,14 +24,7 @@ type RunMetadata struct { } type RunResult struct { - TotalReqs int `json:"total_reqs"` - DoneReqs int `json:"done_reqs"` - SuccessReqs int `json:"success_reqs"` - FailedReqs int `json:"failed_reqs"` - SuccessRate float64 `json:"success_rate"` - AvgTTFT time.Duration `json:"avg_ttft"` - AvgTPS float64 `json:"avg_tps"` - CacheHitRate float64 `json:"cache_hit_rate"` + TotalReqs int `json:"total_reqs,omitempty"` MaxStableConcurrency int `json:"max_stable_concurrency,omitempty"` ErrorSummary string `json:"error_summary,omitempty"` StandardResult *types.ReportData `json:"standard_result,omitempty"` @@ -147,36 +140,6 @@ func (s *RunStore) SaveFinal(meta RunMetadata, result RunResult) error { return NewJSONStore[RunResult](s.ResultPath(meta.TaskID, meta.RunID)).Save(result) } -func (s *RunStore) SaveSummary(summary types.TaskRunSummary) error { - if summary.TaskID == "" || summary.RunID == "" { - return fmt.Errorf("task id and run id are required") - } - - var finishedAt *time.Time - if !summary.FinishedAt.IsZero() { - finished := summary.FinishedAt - finishedAt = &finished - } - - return s.SaveFinal(RunMetadata{ - RunID: summary.RunID, - TaskID: summary.TaskID, - Mode: summary.Mode, - Protocol: summary.Protocol, - Model: summary.Model, - Status: summary.Status, - StartedAt: summary.StartedAt, - FinishedAt: finishedAt, - }, RunResult{ - SuccessRate: summary.SuccessRate, - AvgTTFT: summary.AvgTTFT, - AvgTPS: summary.AvgTPS, - CacheHitRate: summary.CacheHitRate, - MaxStableConcurrency: summary.MaxStableConcurrency, - ErrorSummary: summary.ErrorSummary, - }) -} - func (s *RunStore) Load(taskID, runID string) (*StoredRun, error) { metaPath := s.MetadataPath(taskID, runID) if _, err := os.Stat(metaPath); os.IsNotExist(err) { @@ -189,6 +152,8 @@ func (s *RunStore) Load(taskID, runID string) (*StoredRun, error) { if err != nil { return nil, err } + meta.TaskID = taskID + meta.RunID = runID resultPath := s.ResultPath(taskID, runID) var result *RunResult @@ -279,7 +244,47 @@ func (s *RunStore) DeleteTask(taskID string) error { return os.RemoveAll(s.TaskDir(taskID)) } -func (r StoredRun) Summary() types.TaskRunSummary { +func (s *RunStore) LoadSummary(taskID, runID string) (*types.TaskRunSummary, error) { + run, err := s.Load(taskID, runID) + if err != nil || run == nil { + return nil, err + } + requests, err := s.LoadRequests(taskID, runID) + if err != nil { + return nil, err + } + summary := run.Summary(requests) + return &summary, nil +} + +func (s *RunStore) LatestSummaryByTask(taskID string) (*types.TaskRunSummary, error) { + summaries, err := s.ListSummariesByTask(taskID, 1) + if err != nil { + return nil, err + } + if len(summaries) == 0 { + return nil, nil + } + return &summaries[0], nil +} + +func (s *RunStore) ListSummariesByTask(taskID string, limit int) ([]types.TaskRunSummary, error) { + runs, err := s.ListByTask(taskID, limit) + if err != nil { + return nil, err + } + summaries := make([]types.TaskRunSummary, 0, len(runs)) + for _, run := range runs { + requests, err := s.LoadRequests(run.Metadata.TaskID, run.Metadata.RunID) + if err != nil { + return nil, err + } + summaries = append(summaries, run.Summary(requests)) + } + return summaries, nil +} + +func (r StoredRun) Summary(requests []types.RequestMetrics) types.TaskRunSummary { summary := types.TaskRunSummary{ RunID: r.Metadata.RunID, TaskID: r.Metadata.TaskID, @@ -292,14 +297,85 @@ func (r StoredRun) Summary() types.TaskRunSummary { if r.Metadata.FinishedAt != nil { summary.FinishedAt = *r.Metadata.FinishedAt } + derived := summarizeRequests(requests) + summary.SuccessRate = derived.SuccessRate + summary.AvgTTFT = derived.AvgTTFT + summary.AvgTPS = derived.AvgTPS + summary.CacheHitRate = derived.CacheHitRate if r.Result != nil { - summary.SuccessRate = r.Result.SuccessRate - summary.AvgTTFT = r.Result.AvgTTFT - summary.AvgTPS = r.Result.AvgTPS - summary.CacheHitRate = r.Result.CacheHitRate - summary.MaxStableConcurrency = r.Result.MaxStableConcurrency summary.ErrorSummary = r.Result.ErrorSummary + summary.MaxStableConcurrency = r.Result.MaxStableConcurrency + if r.Result.StandardResult != nil { + summary.SuccessRate = r.Result.StandardResult.SuccessRate + summary.AvgTTFT = r.Result.StandardResult.AvgTTFT + summary.AvgTPS = r.Result.StandardResult.AvgTPS + summary.CacheHitRate = r.Result.StandardResult.AvgCacheHitRate + } + if r.Result.TurboResult != nil { + summary.MaxStableConcurrency = r.Result.TurboResult.MaxStableConcurrency + } + } + return summary +} + +func (r StoredRun) TotalReqs(requests []types.RequestMetrics) int { + if r.Result == nil { + return len(requests) + } + if r.Result.StandardResult != nil && r.Result.StandardResult.TotalRequests > 0 { + return r.Result.StandardResult.TotalRequests + } + if r.Result.TurboResult != nil { + total := 0 + for _, level := range r.Result.TurboResult.Levels { + total += level.TotalRequests + } + if total > 0 { + return total + } } + if r.Result.TotalReqs > 0 { + return r.Result.TotalReqs + } + return len(requests) +} + +type requestSummary struct { + DoneReqs int + SuccessReqs int + FailedReqs int + SuccessRate float64 + AvgTTFT time.Duration + AvgTPS float64 + CacheHitRate float64 +} + +func summarizeRequests(requests []types.RequestMetrics) requestSummary { + summary := requestSummary{DoneReqs: len(requests)} + var ttftSum time.Duration + var tpsSum float64 + var cacheSum float64 + + for _, request := range requests { + if !request.Success { + continue + } + summary.SuccessReqs++ + ttftSum += request.TTFT + tpsSum += request.TPS + cacheSum += request.CacheHitRate + } + + summary.FailedReqs = summary.DoneReqs - summary.SuccessReqs + if summary.DoneReqs > 0 { + summary.SuccessRate = float64(summary.SuccessReqs) / float64(summary.DoneReqs) * 100 + } + if summary.SuccessReqs > 0 { + summary.AvgTTFT = ttftSum / time.Duration(summary.SuccessReqs) + summary.AvgTPS = tpsSum / float64(summary.SuccessReqs) + summary.CacheHitRate = cacheSum / float64(summary.SuccessReqs) + } + return summary } diff --git a/internal/store/run_test.go b/internal/store/run_test.go new file mode 100644 index 0000000..0b45d14 --- /dev/null +++ b/internal/store/run_test.go @@ -0,0 +1,96 @@ +package store + +import ( + "encoding/json" + "os" + "strings" + "testing" + "time" + + "github.com/yinxulai/ait/internal/types" +) + +func TestRunStore_MetadataOmitsPathIdentifiers(t *testing.T) { + store := NewRunStore(t.TempDir()) + finishedAt := time.Now().UTC().Truncate(time.Second) + + meta := RunMetadata{ + RunID: "run-1", + TaskID: "task-1", + Mode: "standard", + Protocol: "openai-completions", + Model: "test-model", + Status: "completed", + StartedAt: finishedAt.Add(-time.Second), + FinishedAt: &finishedAt, + } + if err := store.SaveFinal(meta, RunResult{}); err != nil { + t.Fatalf("SaveFinal: %v", err) + } + + raw, err := os.ReadFile(store.MetadataPath(meta.TaskID, meta.RunID)) + if err != nil { + t.Fatalf("ReadFile: %v", err) + } + text := string(raw) + if strings.Contains(text, "run_id") || strings.Contains(text, "task_id") { + t.Fatalf("expected run metadata to omit path identifiers, got %s", raw) + } + + loaded, err := store.Load(meta.TaskID, meta.RunID) + if err != nil { + t.Fatalf("Load: %v", err) + } + if loaded == nil { + t.Fatal("expected stored run to load") + } + if loaded.Metadata.TaskID != meta.TaskID || loaded.Metadata.RunID != meta.RunID { + t.Fatalf("expected identifiers to be reconstructed from path, got task=%q run=%q", loaded.Metadata.TaskID, loaded.Metadata.RunID) + } +} + +func TestRunStore_ResultOmitsDerivedSummaryFields(t *testing.T) { + store := NewRunStore(t.TempDir()) + finishedAt := time.Now().UTC().Truncate(time.Second) + + if err := store.SaveFinal(RunMetadata{ + RunID: "run-2", + TaskID: "task-2", + Mode: "standard", + Status: "completed", + StartedAt: finishedAt.Add(-time.Second), + FinishedAt: &finishedAt, + }, RunResult{ + ErrorSummary: "boom", + StandardResult: &types.ReportData{ + TotalRequests: 4, + SuccessRate: 75, + AvgTPS: 12.5, + AvgTTFT: 100 * time.Millisecond, + }, + }); err != nil { + t.Fatalf("SaveFinal: %v", err) + } + + raw, err := os.ReadFile(store.ResultPath("task-2", "run-2")) + if err != nil { + t.Fatalf("ReadFile: %v", err) + } + + var payload map[string]json.RawMessage + if err := json.Unmarshal(raw, &payload); err != nil { + t.Fatalf("Unmarshal: %v", err) + } + + for _, key := range []string{"done_reqs", "success_reqs", "failed_reqs", "success_rate", "avg_ttft", "avg_tps", "cache_hit_rate"} { + if _, ok := payload[key]; ok { + t.Fatalf("expected derived summary field %q to be omitted from result.json, got %s", key, raw) + } + } + if _, ok := payload["standard_result"]; !ok { + t.Fatalf("expected final report payload to remain in result.json, got %s", raw) + } + if _, ok := payload["error_summary"]; !ok { + t.Fatalf("expected error_summary to remain in result.json, got %s", raw) + } +} diff --git a/internal/store/task.go b/internal/store/task.go index 8008365..ef140a5 100644 --- a/internal/store/task.go +++ b/internal/store/task.go @@ -1,6 +1,7 @@ package store import ( + "errors" "fmt" "os" "path/filepath" @@ -12,33 +13,31 @@ import ( ) type persistedTaskDefinition struct { - ID string `json:"id"` Name string `json:"name"` Input types.Input `json:"input"` CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` } +var ErrTaskNotFound = errors.New("task not found") + // TaskStore 管理 ~/.ait/tasks/ 下的任务文件持久化。 +// 它是无状态仓储:每次调用直接从磁盘读取或写入单任务文件。 type TaskStore struct { - dir string - data []types.TaskDefinition + dir string } -// NewTaskStore 创建持久化到 dir 的 TaskStore(调用方需先调用 Load)。 func NewTaskStore(dir string) *TaskStore { return &TaskStore{dir: dir} } -// Load 从磁盘加载任务列表,目录不存在时初始化为空列表。 -func (s *TaskStore) Load() error { +func (s *TaskStore) List() ([]types.TaskDefinition, error) { entries, err := os.ReadDir(s.dir) if os.IsNotExist(err) { - s.data = []types.TaskDefinition{} - return nil + return []types.TaskDefinition{}, nil } if err != nil { - return err + return nil, err } tasks := make([]types.TaskDefinition, 0, len(entries)) @@ -50,131 +49,104 @@ func (s *TaskStore) Load() error { path := filepath.Join(s.dir, entry.Name()) stored, err := NewJSONStore[persistedTaskDefinition](path).Load() if err != nil { - return err + return nil, err } - if strings.TrimSpace(stored.ID) == "" { - stored.ID = strings.TrimSuffix(entry.Name(), filepath.Ext(entry.Name())) - } - tasks = append(tasks, types.TaskDefinition{ - ID: stored.ID, - Name: stored.Name, - Input: stored.Input, - CreatedAt: stored.CreatedAt, - UpdatedAt: stored.UpdatedAt, - }) + tasks = append(tasks, normalizeTaskDefinition(strings.TrimSuffix(entry.Name(), filepath.Ext(entry.Name())), stored)) } sort.Slice(tasks, func(i, j int) bool { return tasks[i].UpdatedAt.After(tasks[j].UpdatedAt) }) - s.data = tasks - return nil + return tasks, nil } -// Save 将当前内存中的任务列表持久化到磁盘,按任务拆分成独立文件。 -func (s *TaskStore) Save() error { - if err := os.MkdirAll(s.dir, 0o755); err != nil { - return err +func (s *TaskStore) Get(id string) (types.TaskDefinition, error) { + if strings.TrimSpace(id) == "" { + return types.TaskDefinition{}, ErrTaskNotFound } - keep := make(map[string]struct{}, len(s.data)) - for _, task := range s.data { - if strings.TrimSpace(task.ID) == "" { - return fmt.Errorf("task id cannot be empty") - } - keep[task.ID] = struct{}{} - path := filepath.Join(s.dir, task.ID+".json") - stored := persistedTaskDefinition{ - ID: task.ID, - Name: task.Name, - Input: task.Input, - CreatedAt: task.CreatedAt, - UpdatedAt: task.UpdatedAt, - } - if err := NewJSONStore[persistedTaskDefinition](path).Save(stored); err != nil { - return err - } + path := s.taskPath(id) + if _, err := os.Stat(path); os.IsNotExist(err) { + return types.TaskDefinition{}, ErrTaskNotFound + } else if err != nil { + return types.TaskDefinition{}, err } - entries, err := os.ReadDir(s.dir) + stored, err := NewJSONStore[persistedTaskDefinition](path).Load() if err != nil { - return err + return types.TaskDefinition{}, err } - for _, entry := range entries { - if entry.IsDir() || filepath.Ext(entry.Name()) != ".json" { - continue - } - taskID := strings.TrimSuffix(entry.Name(), filepath.Ext(entry.Name())) - if _, ok := keep[taskID]; ok { - continue - } - if err := os.Remove(filepath.Join(s.dir, entry.Name())); err != nil && !os.IsNotExist(err) { - return err - } - } - - return nil -} - -// All 返回所有任务的副本,最近更新的排在前面。 -func (s *TaskStore) All() []types.TaskDefinition { - result := make([]types.TaskDefinition, len(s.data)) - copy(result, s.data) - return result + return normalizeTaskDefinition(id, stored), nil } -// Get 按 ID 查找任务,返回副本。 -func (s *TaskStore) Get(id string) (types.TaskDefinition, bool) { - for _, t := range s.data { - if t.ID == id { - return t, true - } - } - return types.TaskDefinition{}, false -} - -// Upsert 新建或更新任务。 -// - 若 task.ID 为空,自动生成唯一 ID。 -// - 更新时将任务移至列表头部(最近活跃排序)。 -func (s *TaskStore) Upsert(task types.TaskDefinition) { +func (s *TaskStore) Create(task types.TaskDefinition) (types.TaskDefinition, error) { now := time.Now() - if task.ID == "" { + if strings.TrimSpace(task.ID) == "" { task.ID = fmt.Sprintf("task_%d", now.UnixNano()) } + if task.CreatedAt.IsZero() { + task.CreatedAt = now + } + task.UpdatedAt = now + if err := s.writeTask(task); err != nil { + return types.TaskDefinition{}, err + } + return task, nil +} - for i, existing := range s.data { - if existing.ID != task.ID { - continue - } - if task.CreatedAt.IsZero() { - task.CreatedAt = existing.CreatedAt - } - task.UpdatedAt = now - // 移至列表头部 - tasks := make([]types.TaskDefinition, 0, len(s.data)) - tasks = append(tasks, task) - tasks = append(tasks, s.data[:i]...) - tasks = append(tasks, s.data[i+1:]...) - s.data = tasks - return +func (s *TaskStore) Update(task types.TaskDefinition) (types.TaskDefinition, error) { + if strings.TrimSpace(task.ID) == "" { + return types.TaskDefinition{}, ErrTaskNotFound } - // 新增 + existing, err := s.Get(task.ID) + if err != nil { + return types.TaskDefinition{}, err + } if task.CreatedAt.IsZero() { - task.CreatedAt = now + task.CreatedAt = existing.CreatedAt } - task.UpdatedAt = now - s.data = append([]types.TaskDefinition{task}, s.data...) + task.UpdatedAt = time.Now() + if err := s.writeTask(task); err != nil { + return types.TaskDefinition{}, err + } + return task, nil } -// Delete 按 ID 删除任务,任务不存在时返回错误。 func (s *TaskStore) Delete(id string) error { - for i, t := range s.data { - if t.ID == id { - s.data = append(s.data[:i], s.data[i+1:]...) - return nil - } + err := os.Remove(s.taskPath(id)) + if os.IsNotExist(err) { + return ErrTaskNotFound + } + return err +} + +func (s *TaskStore) taskPath(id string) string { + return filepath.Join(s.dir, id+".json") +} + +func (s *TaskStore) writeTask(task types.TaskDefinition) error { + if strings.TrimSpace(task.ID) == "" { + return fmt.Errorf("task id cannot be empty") + } + if err := os.MkdirAll(s.dir, 0o755); err != nil { + return err + } + return NewJSONStore[persistedTaskDefinition](s.taskPath(task.ID)).Save(persistedTaskDefinition{ + Name: task.Name, + Input: task.Input, + CreatedAt: task.CreatedAt, + UpdatedAt: task.UpdatedAt, + }) +} + +func normalizeTaskDefinition(fallbackID string, stored persistedTaskDefinition) types.TaskDefinition { + return types.TaskDefinition{ + ID: fallbackID, + Name: stored.Name, + Input: stored.Input, + CreatedAt: stored.CreatedAt, + UpdatedAt: stored.UpdatedAt, } - return fmt.Errorf("task %q not found", id) } diff --git a/internal/store/task_test.go b/internal/store/task_test.go new file mode 100644 index 0000000..d740e21 --- /dev/null +++ b/internal/store/task_test.go @@ -0,0 +1,163 @@ +package store + +import ( + "errors" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/yinxulai/ait/internal/types" +) + +func TestTaskStore_CRUD(t *testing.T) { + store := NewTaskStore(filepath.Join(t.TempDir(), "tasks")) + + created, err := store.Create(types.TaskDefinition{ + Name: "task-a", + Input: types.Input{ + Protocol: types.ProtocolOpenAICompletions, + EndpointURL: "http://localhost:19999", + Model: "test-model", + PromptMode: "text", + PromptText: "hello", + }, + }) + if err != nil { + t.Fatalf("Create: %v", err) + } + if created.ID == "" { + t.Fatal("expected created task to have ID") + } + if created.CreatedAt.IsZero() || created.UpdatedAt.IsZero() { + t.Fatal("expected timestamps to be populated") + } + + loaded, err := store.Get(created.ID) + if err != nil { + t.Fatalf("Get: %v", err) + } + if loaded.Name != created.Name { + t.Fatalf("Name: got %q, want %q", loaded.Name, created.Name) + } + + tasks, err := store.List() + if err != nil { + t.Fatalf("List: %v", err) + } + if len(tasks) != 1 { + t.Fatalf("expected 1 task, got %d", len(tasks)) + } + + raw, err := os.ReadFile(store.taskPath(created.ID)) + if err != nil { + t.Fatalf("ReadFile: %v", err) + } + if strings.Contains(string(raw), "\"id\"") { + t.Fatalf("expected task file to omit duplicated id field, got %s", raw) + } + + time.Sleep(time.Millisecond) + updated, err := store.Update(types.TaskDefinition{ + ID: created.ID, + Name: "task-b", + Input: created.Input, + CreatedAt: created.CreatedAt, + }) + if err != nil { + t.Fatalf("Update: %v", err) + } + if updated.Name != "task-b" { + t.Fatalf("Name after update: got %q, want %q", updated.Name, "task-b") + } + if !updated.UpdatedAt.After(created.UpdatedAt) { + t.Fatalf("expected UpdatedAt to advance, created=%v updated=%v", created.UpdatedAt, updated.UpdatedAt) + } + + if err := store.Delete(created.ID); err != nil { + t.Fatalf("Delete: %v", err) + } + if _, err := store.Get(created.ID); !errors.Is(err, ErrTaskNotFound) { + t.Fatalf("expected ErrTaskNotFound after delete, got %v", err) + } + if err := store.Delete(created.ID); !errors.Is(err, ErrTaskNotFound) { + t.Fatalf("expected ErrTaskNotFound on second delete, got %v", err) + } +} + +func TestTaskViewStore_ListIncludesLatestRun(t *testing.T) { + root := t.TempDir() + tasks := NewTaskStore(filepath.Join(root, "tasks")) + runs := NewRunStore(filepath.Join(root, "runs")) + views := NewTaskViewStore(tasks, runs) + + taskWithRun, err := tasks.Create(types.TaskDefinition{ + Name: "with-run", + Input: types.Input{Protocol: types.ProtocolOpenAICompletions}, + }) + if err != nil { + t.Fatalf("Create(with-run): %v", err) + } + _, err = tasks.Create(types.TaskDefinition{ + Name: "no-run", + Input: types.Input{Protocol: types.ProtocolOpenAICompletions}, + }) + if err != nil { + t.Fatalf("Create(no-run): %v", err) + } + + olderFinishedAt := time.Now().Add(-2 * time.Minute) + newerFinishedAt := time.Now().Add(-time.Minute) + if err := runs.SaveFinal(RunMetadata{ + RunID: "run-older", + TaskID: taskWithRun.ID, + Mode: "standard", + Status: "completed", + StartedAt: olderFinishedAt.Add(-time.Second), + FinishedAt: &olderFinishedAt, + }, RunResult{}); err != nil { + t.Fatalf("SaveFinal(older): %v", err) + } + if err := runs.SaveFinal(RunMetadata{ + RunID: "run-newer", + TaskID: taskWithRun.ID, + Mode: "standard", + Status: "completed", + StartedAt: newerFinishedAt.Add(-time.Second), + FinishedAt: &newerFinishedAt, + }, RunResult{}); err != nil { + t.Fatalf("SaveFinal(newer): %v", err) + } + + overviews, err := views.List() + if err != nil { + t.Fatalf("List: %v", err) + } + if len(overviews) != 2 { + t.Fatalf("expected 2 task overviews, got %d", len(overviews)) + } + + seenRun := false + seenNoRun := false + for _, overview := range overviews { + switch overview.Name { + case "with-run": + seenRun = true + if overview.LatestRun == nil { + t.Fatal("expected LatestRun for task with persisted runs") + } + if overview.LatestRun.RunID != "run-newer" { + t.Fatalf("LatestRun.RunID: got %q, want %q", overview.LatestRun.RunID, "run-newer") + } + case "no-run": + seenNoRun = true + if overview.LatestRun != nil { + t.Fatal("expected LatestRun to be nil when no persisted run exists") + } + } + } + if !seenRun || !seenNoRun { + t.Fatalf("expected to see both task overviews, got %+v", overviews) + } +} diff --git a/internal/store/taskview.go b/internal/store/taskview.go new file mode 100644 index 0000000..56ff847 --- /dev/null +++ b/internal/store/taskview.go @@ -0,0 +1,35 @@ +package store + +import "github.com/yinxulai/ait/internal/types" + +// TaskViewStore 负责聚合任务定义和最近一次已完成运行,用于列表读取。 +type TaskViewStore struct { + tasks *TaskStore + runs *RunStore +} + +func NewTaskViewStore(tasks *TaskStore, runs *RunStore) *TaskViewStore { + return &TaskViewStore{tasks: tasks, runs: runs} +} + +func (s *TaskViewStore) List() ([]types.TaskOverview, error) { + tasks, err := s.tasks.List() + if err != nil { + return nil, err + } + + overviews := make([]types.TaskOverview, 0, len(tasks)) + for _, task := range tasks { + overview := types.TaskOverview{TaskDefinition: task} + latest, err := s.runs.LatestSummaryByTask(task.ID) + if err != nil { + return nil, err + } + if latest != nil { + overview.LatestRun = latest + } + overviews = append(overviews, overview) + } + + return overviews, nil +} diff --git a/internal/tui/client.go b/internal/tui/client.go index b605c58..065555d 100644 --- a/internal/tui/client.go +++ b/internal/tui/client.go @@ -25,7 +25,11 @@ func NewClient(srv server.Server) *Client { // LoadTasksCmd 异步加载任务列表。 func (c *Client) LoadTasksCmd() tea.Cmd { return func() tea.Msg { - return TasksLoadedMsg{Tasks: c.srv.ListTasks()} + tasks, err := c.srv.ListTasks() + if err != nil { + return ErrorMsg{Err: fmt.Errorf("加载任务失败: %w", err)} + } + return TasksLoadedMsg{Tasks: tasks} } } diff --git a/internal/tui/messages.go b/internal/tui/messages.go index 3099253..c572f1b 100644 --- a/internal/tui/messages.go +++ b/internal/tui/messages.go @@ -7,7 +7,7 @@ import ( // TasksLoadedMsg 任务列表加载完成(初始化或刷新后)。 type TasksLoadedMsg struct { - Tasks []server.TaskOverview + Tasks []types.TaskOverview } // TaskSavedMsg 新建或更新任务完成。 diff --git a/internal/tui/model_test.go b/internal/tui/model_test.go index e45c5a3..4b170aa 100644 --- a/internal/tui/model_test.go +++ b/internal/tui/model_test.go @@ -11,8 +11,8 @@ import ( // stubServer 是 server.Server 的测试桩,所有方法都返回零值。 type stubServer struct{} -func (s *stubServer) ListTasks() []server.TaskOverview { return nil } -func (s *stubServer) GetTask(id string) (types.TaskDefinition, bool) { return types.TaskDefinition{}, false } +func (s *stubServer) ListTasks() ([]types.TaskOverview, error) { return nil, nil } +func (s *stubServer) GetTask(id string) (types.TaskDefinition, error) { return types.TaskDefinition{}, nil } func (s *stubServer) CreateTask(cfg server.TaskConfig) (types.TaskDefinition, error) { return types.TaskDefinition{}, nil } func (s *stubServer) UpdateTask(id string, cfg server.TaskConfig) (types.TaskDefinition, error) { return types.TaskDefinition{}, nil diff --git a/internal/tui/pages/tasklist.go b/internal/tui/pages/tasklist.go index 5352b04..a42b6c2 100644 --- a/internal/tui/pages/tasklist.go +++ b/internal/tui/pages/tasklist.go @@ -13,7 +13,7 @@ import ( // TaskListState 任务列表页状态。 type TaskListState struct { - Tasks []server.TaskOverview + Tasks []types.TaskOverview Selected int Offset int Visible int diff --git a/internal/types/types.go b/internal/types/types.go index f9523b7..43c9406 100644 --- a/internal/types/types.go +++ b/internal/types/types.go @@ -237,6 +237,11 @@ type TaskDefinition struct { UpdatedAt time.Time `json:"updated_at"` } +type TaskOverview struct { + TaskDefinition + LatestRun *TaskRunSummary `json:"latest_run,omitempty"` +} + type TaskRunSummary struct { RunID string `json:"run_id"` TaskID string `json:"task_id"` From 7dc5ec3e22e2c4f56ecdb471fc4087ac6c63e63e Mon Sep 17 00:00:00 2001 From: Alain Date: Wed, 20 May 2026 08:53:20 +0800 Subject: [PATCH 19/52] feat: Add proxy support for HTTP clients and update related components - Introduced ProxyURL field in Input struct to allow configuration of proxy settings. - Implemented newMeasuredTransport function to create HTTP transport with proxy support. - Updated OpenAIClient and other clients to utilize the new transport with proxy settings. - Modified tests to validate proxy functionality and ensure correct behavior with various configurations. - Enhanced TUI components to display proxy information and maintain navigation state. - Refactored dashboard and task detail pages to improve layout and user experience. - Adjusted wizard state to include proxy configuration and ensure proper defaults. --- internal/client/anthropic.go | 6 +- internal/client/anthropic_test.go | 206 +++++++++++++------------- internal/client/client_test.go | 16 +- internal/client/integration_test.go | 12 +- internal/client/openai.go | 7 +- internal/client/openai_test.go | 42 +++--- internal/client/transport.go | 34 +++++ internal/client/transport_test.go | 81 ++++++++++ internal/runner/runner_test.go | 4 +- internal/tui/model.go | 51 +++++-- internal/tui/model_test.go | 64 +++++++- internal/tui/pages/dashboard.go | 25 ++-- internal/tui/pages/reqdetail.go | 18 ++- internal/tui/pages/taskdetail.go | 15 +- internal/tui/pages/taskdetail_test.go | 13 ++ internal/tui/pages/turbodash.go | 25 ++-- internal/tui/pages/wizard.go | 37 ++++- internal/tui/pages/wizard_test.go | 60 ++++++++ internal/types/types.go | 1 + 19 files changed, 508 insertions(+), 209 deletions(-) create mode 100644 internal/client/transport.go create mode 100644 internal/client/transport_test.go create mode 100644 internal/tui/pages/wizard_test.go diff --git a/internal/client/anthropic.go b/internal/client/anthropic.go index 8e5f041..7959fe3 100644 --- a/internal/client/anthropic.go +++ b/internal/client/anthropic.go @@ -80,11 +80,7 @@ type AnthropicClient struct { // 网络栈性能,包括 DNS 解析、TCP 连接建立、TLS 握手等。 // - DisableCompression=false: 启用压缩以节省带宽 func NewAnthropicClient(config types.Input) *AnthropicClient { - // 禁用连接复用以确保每个请求都是独立的 - transport := &http.Transport{ - DisableKeepAlives: true, - DisableCompression: false, - } + transport := newMeasuredTransport(config) return &AnthropicClient{ EndpointURL: config.ResolvedEndpointURL(), diff --git a/internal/client/anthropic_test.go b/internal/client/anthropic_test.go index 715747d..63d3fb9 100644 --- a/internal/client/anthropic_test.go +++ b/internal/client/anthropic_test.go @@ -124,10 +124,10 @@ func TestNewAnthropicClient(t *testing.T) { }, want: &AnthropicClient{ EndpointURL: "https://api.anthropic.com/v1/messages", - ApiKey: "test-key", - Model: "claude-3-sonnet-20240229", - Provider: types.ProtocolAnthropicMessages, - Thinking: false, + ApiKey: "test-key", + Model: "claude-3-sonnet-20240229", + Provider: types.ProtocolAnthropicMessages, + Thinking: false, }, }, } @@ -183,7 +183,7 @@ func TestAnthropicClient_Request_NonStream(t *testing.T) { client := NewAnthropicClient(createTestConfig(server.URL, "test-key", "claude-3-sonnet-20240229", 30*time.Second, false)) start := time.Now() - metrics, err := client.Request("test prompt", false) + metrics, err := client.Request("", "test prompt", false) elapsed := time.Since(start) if err != nil { @@ -220,7 +220,7 @@ func TestAnthropicClient_Request_Stream(t *testing.T) { client := NewAnthropicClient(createTestConfig(server.URL, "test-key", "claude-3-sonnet-20240229", 30*time.Second, false)) start := time.Now() - metrics, err := client.Request("test prompt", true) + metrics, err := client.Request("", "test prompt", true) elapsed := time.Since(start) if err != nil { @@ -252,7 +252,7 @@ func TestAnthropicClient_Request_ServerError(t *testing.T) { client := NewAnthropicClient(createTestConfig(server.URL, "test-key", "claude-3-sonnet-20240229", 30*time.Second, false)) - metrics, err := client.Request("test prompt", false) + metrics, err := client.Request("", "test prompt", false) if err == nil { t.Error("Request() should return error for server error") @@ -294,7 +294,7 @@ func TestAnthropicClient_Request_InvalidEndpoint(t *testing.T) { client := NewAnthropicClient(createTestConfig(server.URL, "test-key", "claude-3-sonnet-20240229", 30*time.Second, false)) // 这应该成功,因为我们使用的是正确的端点 - _, err := client.Request("test prompt", false) + _, err := client.Request("", "test prompt", false) if err != nil { t.Errorf("Request() should succeed with correct endpoint, got error: %v", err) } @@ -326,7 +326,7 @@ func TestAnthropicClient_Request_MissingHeaders(t *testing.T) { client := NewAnthropicClient(createTestConfig(server.URL, "test-key", "claude-3-sonnet-20240229", 30*time.Second, false)) // 这应该成功,因为我们的客户端发送了正确的请求头 - _, err := client.Request("test prompt", false) + _, err := client.Request("", "test prompt", false) if err != nil { t.Errorf("Request() should succeed with correct headers, got error: %v", err) } @@ -336,7 +336,7 @@ func TestAnthropicClient_Request_NetworkError(t *testing.T) { // 使用一个无效的地址来模拟网络错误 client := NewAnthropicClient(createTestConfig("http://invalid-host-that-does-not-exist.example", "test-key", "claude-3-sonnet-20240229", 30*time.Second, false)) - metrics, err := client.Request("test prompt", false) + metrics, err := client.Request("", "test prompt", false) // 应该返回错误 if err == nil { @@ -367,7 +367,7 @@ func TestAnthropicClient_Request_InvalidURL(t *testing.T) { // 使用一个格式错误的 URL client := NewAnthropicClient(createTestConfig("://invalid-url", "test-key", "claude-3-sonnet-20240229", 30*time.Second, false)) - metrics, err := client.Request("test prompt", false) + metrics, err := client.Request("", "test prompt", false) // 应该返回错误 if err == nil { @@ -465,17 +465,17 @@ func TestAnthropicClient_ConnectionReuse(t *testing.T) { // 创建一个测试服务器,记录连接数 connectionCount := 0 var connMu sync.Mutex - + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // 每个请求到达时记录 connMu.Lock() connectionCount++ currentCount := connectionCount connMu.Unlock() - + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - + // 返回简单的非流式响应 response := fmt.Sprintf(`{"id":"msg-%d","type":"message","role":"assistant","content":[{"type":"text","text":"Response %d"}],"model":"claude-3","usage":{"input_tokens":1,"output_tokens":1}}`, currentCount, currentCount) w.Write([]byte(response)) @@ -489,7 +489,7 @@ func TestAnthropicClient_ConnectionReuse(t *testing.T) { if !ok { t.Fatal("Expected client to use http.Transport") } - + if !transport.DisableKeepAlives { t.Error("Expected DisableKeepAlives to be true to prevent connection reuse") } @@ -497,17 +497,17 @@ func TestAnthropicClient_ConnectionReuse(t *testing.T) { // 发送多个串行请求来验证不复用连接的行为 requestCount := 3 for i := 0; i < requestCount; i++ { - metrics, err := client.Request(fmt.Sprintf("test prompt %d", i), false) + metrics, err := client.Request("", fmt.Sprintf("test prompt %d", i), false) if err != nil { t.Errorf("Request %d failed: %v", i, err) continue } - + if metrics == nil { t.Errorf("Request %d returned nil metrics", i) continue } - + // 验证每个请求都有合理的时间指标 if metrics.TotalTime <= 0 { t.Errorf("Request %d has invalid TotalTime: %v", i, metrics.TotalTime) @@ -518,7 +518,7 @@ func TestAnthropicClient_ConnectionReuse(t *testing.T) { connMu.Lock() finalCount := connectionCount connMu.Unlock() - + if finalCount != requestCount { t.Errorf("Expected %d requests to reach server, got %d", requestCount, finalCount) } @@ -528,17 +528,17 @@ func TestAnthropicClient_ConnectionReuse(t *testing.T) { func TestAnthropicClient_NoConnectionReuse(t *testing.T) { // 验证客户端的 Transport 配置确实禁用了连接复用 client := NewAnthropicClient(createTestConfig("https://api.anthropic.com", "test-key", "claude-3-sonnet", 30*time.Second, false)) - + transport, ok := client.httpClient.Transport.(*http.Transport) if !ok { t.Fatal("Expected client to use http.Transport") } - + // 关键验证:DisableKeepAlives 应该为 true if !transport.DisableKeepAlives { t.Error("DisableKeepAlives should be true to prevent connection reuse, which could affect timing measurements") } - + // DisableCompression 应该为 false(我们想要压缩以节省带宽) if transport.DisableCompression { t.Error("DisableCompression should be false to enable compression") @@ -591,7 +591,7 @@ func TestAnthropicClient_Request_MalformedJSON(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - + if strings.Contains(r.Header.Get("Accept"), "text/event-stream") { // 流式响应:发送畸形的 JSON w.Write([]byte("event: content_block_delta\n")) @@ -611,7 +611,7 @@ func TestAnthropicClient_Request_MalformedJSON(t *testing.T) { // 测试非流式请求的 JSON 解析错误 t.Run("non-stream malformed JSON", func(t *testing.T) { - _, err := client.Request("test prompt", false) + _, err := client.Request("", "test prompt", false) if err == nil { t.Error("Expected error for malformed JSON response") } @@ -619,7 +619,7 @@ func TestAnthropicClient_Request_MalformedJSON(t *testing.T) { // 测试流式请求(应该跳过畸形的 JSON 并处理有效的) t.Run("stream with some malformed JSON", func(t *testing.T) { - metrics, err := client.Request("test prompt", true) + metrics, err := client.Request("", "test prompt", true) if err != nil { t.Errorf("Request should succeed even with some malformed JSON: %v", err) } @@ -649,7 +649,7 @@ func TestAnthropicClient_Request_BodyReadError(t *testing.T) { client := NewAnthropicClient(createTestConfig(server.URL, "test-key", "claude-3-sonnet", 30*time.Second, false)) - _, err := client.Request("test prompt", false) + _, err := client.Request("", "test prompt", false) if err == nil { t.Error("Expected error when response body cannot be read") } @@ -660,7 +660,7 @@ func TestAnthropicClient_Request_ScannerError(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/event-stream") w.WriteHeader(http.StatusOK) - + // 发送一个非常长的行,可能导致 scanner 错误 longLine := strings.Repeat("x", 1024*1024) // 1MB 的数据 fmt.Fprintf(w, "event: content_block_delta\ndata: %s\n\n", longLine) @@ -671,7 +671,7 @@ func TestAnthropicClient_Request_ScannerError(t *testing.T) { client := NewAnthropicClient(createTestConfig(server.URL, "test-key", "claude-3-sonnet", 30*time.Second, false)) // 这个测试可能会因为 scanner 的缓冲区限制而失败 - metrics, err := client.Request("test prompt", true) + metrics, err := client.Request("", "test prompt", true) // 无论成功还是失败都是正常的,关键是要覆盖这个代码路径 if err != nil { t.Logf("Scanner error (expected in some cases): %v", err) @@ -724,7 +724,7 @@ func TestAnthropicClient_Request_EdgeCases(t *testing.T) { defer server.Close() client := NewAnthropicClient(createTestConfig(server.URL, "test-key", "claude-3-sonnet", 30*time.Second, false)) - _, err := client.Request("test", tt.stream) + _, err := client.Request("", "test", tt.stream) if tt.expectError && err == nil { t.Error("Expected error but got none") @@ -742,14 +742,14 @@ func TestAnthropicClient_ConcurrentRequests(t *testing.T) { time.Sleep(50 * time.Millisecond) // 模拟慢响应 w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - + response := `{"id":"test","type":"message","content":[{"type":"text","text":"concurrent response"}],"usage":{"output_tokens":2}}` w.Write([]byte(response)) })) defer server.Close() client := NewAnthropicClient(createTestConfig(server.URL, "test-key", "claude-3-sonnet", 30*time.Second, false)) - + // 并发执行多个请求 numRequests := 10 var wg sync.WaitGroup @@ -761,9 +761,9 @@ func TestAnthropicClient_ConcurrentRequests(t *testing.T) { wg.Add(1) go func(id int) { defer wg.Done() - - metrics, err := client.Request(fmt.Sprintf("concurrent test %d", id), false) - + + metrics, err := client.Request("", fmt.Sprintf("concurrent test %d", id), false) + mu.Lock() if err != nil { errors = append(errors, err) @@ -785,7 +785,7 @@ func TestAnthropicClient_ConcurrentRequests(t *testing.T) { t.Errorf("Concurrent request error: %v", err) } } - + if successCount != numRequests { t.Errorf("Expected %d successful requests, got %d", numRequests, successCount) } @@ -802,12 +802,12 @@ func TestAnthropicClient_Request_TimeoutHandling(t *testing.T) { // 创建一个超时时间很短的客户端 client := NewAnthropicClient(createTestConfig(server.URL, "test-key", "claude-3-sonnet", 100*time.Millisecond, false)) - - _, err := client.Request("timeout test", false) + + _, err := client.Request("", "timeout test", false) if err == nil { t.Error("Expected timeout error but got none") } - + // 确保错误信息包含超时相关内容 if !strings.Contains(err.Error(), "timeout") && !strings.Contains(err.Error(), "context deadline exceeded") { t.Errorf("Expected timeout-related error, got: %v", err) @@ -818,7 +818,7 @@ func TestAnthropicClient_Request_EmptyContentArray(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - + if strings.Contains(r.Header.Get("Accept"), "text/event-stream") { // 流式响应:发送空的 content w.Write([]byte("event: message_start\n")) @@ -835,7 +835,7 @@ func TestAnthropicClient_Request_EmptyContentArray(t *testing.T) { client := NewAnthropicClient(createTestConfig(server.URL, "test-key", "claude-3-sonnet", 30*time.Second, false)) // 测试非流式请求 - metrics, err := client.Request("test", false) + metrics, err := client.Request("", "test", false) if err != nil { t.Errorf("Request should succeed with empty content: %v", err) } @@ -844,7 +844,7 @@ func TestAnthropicClient_Request_EmptyContentArray(t *testing.T) { } // 测试流式请求 - metrics, err = client.Request("test", true) + metrics, err = client.Request("", "test", true) if err != nil { t.Errorf("Stream request should succeed with empty content: %v", err) } @@ -858,26 +858,26 @@ func TestAnthropicClient_Request_StreamWithThinking(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.Header().Set("Transfer-Encoding", "chunked") - + flusher, _ := w.(http.Flusher) - + // 发送开始事件 fmt.Fprint(w, "event: message_start\n") fmt.Fprint(w, `data: {"type": "message_start", "message": {"id": "msg_test", "type": "message", "role": "assistant", "content": [], "model": "claude-3-sonnet", "usage": {"input_tokens": 10, "output_tokens": 0}}}`+"\n\n") flusher.Flush() - + // 模拟延迟,然后发送 thinking 内容 time.Sleep(10 * time.Millisecond) fmt.Fprint(w, "event: content_block_delta\n") fmt.Fprint(w, `data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "thinking": "Let me think about this..."}}`+"\n\n") flusher.Flush() - + // 再发送一些普通文本 time.Sleep(5 * time.Millisecond) fmt.Fprint(w, "event: content_block_delta\n") fmt.Fprint(w, `data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "Hello there!"}}`+"\n\n") flusher.Flush() - + // 发送结束事件 fmt.Fprint(w, "event: message_delta\n") fmt.Fprint(w, `data: {"type": "message_delta", "delta": {"stop_reason": "end_turn"}, "usage": {"output_tokens": 10}}`+"\n\n") @@ -886,23 +886,23 @@ func TestAnthropicClient_Request_StreamWithThinking(t *testing.T) { defer server.Close() client := NewAnthropicClient(createTestConfig(server.URL, "test-key", "claude-3-sonnet", 30*time.Second, false)) - + start := time.Now() - metrics, err := client.Request("test prompt", true) - + metrics, err := client.Request("", "test prompt", true) + if err != nil { t.Errorf("Request() error = %v", err) } - + if metrics.TimeToFirstToken <= 0 { t.Errorf("Request() TTFT should be > 0 when thinking content is present, got %v", metrics.TimeToFirstToken) } - + // TTFT 应该在第一个 thinking 输出时就开始计算 if metrics.TimeToFirstToken > time.Since(start) { t.Errorf("TTFT should be calculated from thinking output, got %v", metrics.TimeToFirstToken) } - + if metrics.CompletionTokens != 10 { t.Errorf("Request() CompletionTokens = %v, want 10", metrics.CompletionTokens) } @@ -913,26 +913,26 @@ func TestAnthropicClient_Request_StreamWithPartialJSON(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.Header().Set("Transfer-Encoding", "chunked") - + flusher, _ := w.(http.Flusher) - + // 发送开始事件 fmt.Fprint(w, "event: message_start\n") fmt.Fprint(w, `data: {"type": "message_start", "message": {"id": "msg_test", "type": "message", "role": "assistant", "content": [], "model": "claude-3-sonnet", "usage": {"input_tokens": 10, "output_tokens": 0}}}`+"\n\n") flusher.Flush() - + // 模拟延迟,然后发送 partial_json 内容 time.Sleep(10 * time.Millisecond) fmt.Fprint(w, "event: content_block_delta\n") fmt.Fprint(w, `data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "partial_json": "{\"name\": \"John\""}}`+"\n\n") flusher.Flush() - + // 继续发送更多的 partial_json time.Sleep(5 * time.Millisecond) fmt.Fprint(w, "event: content_block_delta\n") fmt.Fprint(w, `data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "partial_json": ", \"age\": 30}"}}`+"\n\n") flusher.Flush() - + // 发送结束事件 fmt.Fprint(w, "event: message_delta\n") fmt.Fprint(w, `data: {"type": "message_delta", "delta": {"stop_reason": "end_turn"}, "usage": {"output_tokens": 8}}`+"\n\n") @@ -941,23 +941,23 @@ func TestAnthropicClient_Request_StreamWithPartialJSON(t *testing.T) { defer server.Close() client := NewAnthropicClient(createTestConfig(server.URL, "test-key", "claude-3-sonnet", 30*time.Second, false)) - + start := time.Now() - metrics, err := client.Request("test prompt", true) - + metrics, err := client.Request("", "test prompt", true) + if err != nil { t.Errorf("Request() error = %v", err) } - + if metrics.TimeToFirstToken <= 0 { t.Errorf("Request() TTFT should be > 0 when partial_json content is present, got %v", metrics.TimeToFirstToken) } - + // TTFT 应该在第一个 partial_json 输出时就开始计算 if metrics.TimeToFirstToken > time.Since(start) { t.Errorf("TTFT should be calculated from partial_json output, got %v", metrics.TimeToFirstToken) } - + if metrics.CompletionTokens != 8 { t.Errorf("Request() CompletionTokens = %v, want 8", metrics.CompletionTokens) } @@ -968,32 +968,32 @@ func TestAnthropicClient_Request_StreamWithMixedContent(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.Header().Set("Transfer-Encoding", "chunked") - + flusher, _ := w.(http.Flusher) - + // 发送开始事件 fmt.Fprint(w, "event: message_start\n") fmt.Fprint(w, `data: {"type": "message_start", "message": {"id": "msg_test", "type": "message", "role": "assistant", "content": [], "model": "claude-3-sonnet", "usage": {"input_tokens": 10, "output_tokens": 0}}}`+"\n\n") flusher.Flush() - + // 首先发送 thinking 内容 time.Sleep(15 * time.Millisecond) fmt.Fprint(w, "event: content_block_delta\n") fmt.Fprint(w, `data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "thinking": "I need to analyze this carefully..."}}`+"\n\n") flusher.Flush() - + // 然后发送 partial_json time.Sleep(5 * time.Millisecond) fmt.Fprint(w, "event: content_block_delta\n") fmt.Fprint(w, `data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "partial_json": "{\"result\": \""}}`+"\n\n") flusher.Flush() - + // 最后发送普通文本 time.Sleep(5 * time.Millisecond) fmt.Fprint(w, "event: content_block_delta\n") fmt.Fprint(w, `data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "This is the final answer."}}`+"\n\n") flusher.Flush() - + // 发送结束事件 fmt.Fprint(w, "event: message_delta\n") fmt.Fprint(w, `data: {"type": "message_delta", "delta": {"stop_reason": "end_turn"}, "usage": {"output_tokens": 20}}`+"\n\n") @@ -1002,29 +1002,29 @@ func TestAnthropicClient_Request_StreamWithMixedContent(t *testing.T) { defer server.Close() client := NewAnthropicClient(createTestConfig(server.URL, "test-key", "claude-3-sonnet", 30*time.Second, false)) - + start := time.Now() - metrics, err := client.Request("test prompt", true) - + metrics, err := client.Request("", "test prompt", true) + if err != nil { t.Errorf("Request() error = %v", err) } - + if metrics.TimeToFirstToken <= 0 { t.Errorf("Request() TTFT should be > 0 with mixed content, got %v", metrics.TimeToFirstToken) } - + // TTFT 应该在第一个内容输出时就开始计算(thinking 内容) - expectedMinTime := 10 * time.Millisecond // 小于第一个 thinking 输出的延迟 + expectedMinTime := 10 * time.Millisecond // 小于第一个 thinking 输出的延迟 if metrics.TimeToFirstToken < expectedMinTime { t.Errorf("TTFT seems too fast, expected >= %v, got %v", expectedMinTime, metrics.TimeToFirstToken) } - + expectedMaxTime := time.Since(start) if metrics.TimeToFirstToken > expectedMaxTime { t.Errorf("TTFT should be calculated from first output, got %v", metrics.TimeToFirstToken) } - + if metrics.CompletionTokens != 20 { t.Errorf("Request() CompletionTokens = %v, want 20", metrics.CompletionTokens) } @@ -1035,32 +1035,32 @@ func TestAnthropicClient_Request_StreamWithEmptyThinkingAndPartialJSON(t *testin server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.Header().Set("Transfer-Encoding", "chunked") - + flusher, _ := w.(http.Flusher) - + // 发送开始事件 fmt.Fprint(w, "event: message_start\n") fmt.Fprint(w, `data: {"type": "message_start", "message": {"id": "msg_test", "type": "message", "role": "assistant", "content": [], "model": "claude-3-sonnet", "usage": {"input_tokens": 10, "output_tokens": 0}}}`+"\n\n") flusher.Flush() - + // 发送空的 thinking 内容(不应该触发 TTFT) time.Sleep(10 * time.Millisecond) fmt.Fprint(w, "event: content_block_delta\n") fmt.Fprint(w, `data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "thinking": ""}}`+"\n\n") flusher.Flush() - + // 发送空的 partial_json 内容(不应该触发 TTFT) time.Sleep(5 * time.Millisecond) fmt.Fprint(w, "event: content_block_delta\n") fmt.Fprint(w, `data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "partial_json": ""}}`+"\n\n") flusher.Flush() - + // 最后发送真正的文本内容(应该触发 TTFT) time.Sleep(5 * time.Millisecond) fmt.Fprint(w, "event: content_block_delta\n") fmt.Fprint(w, `data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "Real content here"}}`+"\n\n") flusher.Flush() - + // 发送结束事件 fmt.Fprint(w, "event: message_delta\n") fmt.Fprint(w, `data: {"type": "message_delta", "delta": {"stop_reason": "end_turn"}, "usage": {"output_tokens": 5}}`+"\n\n") @@ -1069,29 +1069,29 @@ func TestAnthropicClient_Request_StreamWithEmptyThinkingAndPartialJSON(t *testin defer server.Close() client := NewAnthropicClient(createTestConfig(server.URL, "test-key", "claude-3-sonnet", 30*time.Second, false)) - + start := time.Now() - metrics, err := client.Request("test prompt", true) - + metrics, err := client.Request("", "test prompt", true) + if err != nil { t.Errorf("Request() error = %v", err) } - + if metrics.TimeToFirstToken <= 0 { t.Errorf("Request() TTFT should be > 0 when real text content is present, got %v", metrics.TimeToFirstToken) } - + // TTFT 应该在真正的文本内容输出时计算,而不是空的 thinking/partial_json - expectedMinTime := 15 * time.Millisecond // 应该大于前两个空内容的延迟总和 + expectedMinTime := 15 * time.Millisecond // 应该大于前两个空内容的延迟总和 if metrics.TimeToFirstToken < expectedMinTime { t.Errorf("TTFT should be calculated from real text content, expected >= %v, got %v", expectedMinTime, metrics.TimeToFirstToken) } - + expectedMaxTime := time.Since(start) if metrics.TimeToFirstToken > expectedMaxTime { t.Errorf("TTFT calculation error, got %v", metrics.TimeToFirstToken) } - + if metrics.CompletionTokens != 5 { t.Errorf("Request() CompletionTokens = %v, want 5", metrics.CompletionTokens) } @@ -1107,7 +1107,7 @@ func TestAnthropicClient_Request_ErrorHandlingFixes(t *testing.T) { defer server.Close() client := NewAnthropicClient(createTestConfig(server.URL, "test-key", "claude-3-sonnet", 30*time.Second, false)) - metrics, err := client.Request("test prompt", false) + metrics, err := client.Request("", "test prompt", false) // 应该有错误 if err == nil { @@ -1146,7 +1146,7 @@ func TestAnthropicClient_Request_ErrorHandlingFixes(t *testing.T) { defer server.Close() client := NewAnthropicClient(createTestConfig(server.URL, "test-key", "claude-3-sonnet", 30*time.Second, false)) - metrics, err := client.Request("test prompt", false) + metrics, err := client.Request("", "test prompt", false) // 应该有错误 if err == nil { @@ -1172,7 +1172,7 @@ func TestAnthropicClient_Request_ErrorHandlingFixes(t *testing.T) { t.Run("Response body read error returns metrics", func(t *testing.T) { // 测试策略:创建一个声称有内容但实际没有完整内容的响应 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Length", "1000") + w.Header().Set("Content-Length", "1000") w.WriteHeader(http.StatusOK) w.Write([]byte("incomplete")) // 在 httptest 环境中,这种情况通常不会导致 io.ReadAll 错误 @@ -1181,7 +1181,7 @@ func TestAnthropicClient_Request_ErrorHandlingFixes(t *testing.T) { defer server.Close() client := NewAnthropicClient(createTestConfig(server.URL, "test-key", "claude-3-sonnet", 30*time.Second, false)) - metrics, err := client.Request("test prompt", false) + metrics, err := client.Request("", "test prompt", false) // 这种情况下通常会是 JSON 解析错误而不是读取错误 if metrics == nil && err != nil { @@ -1193,7 +1193,7 @@ func TestAnthropicClient_Request_ErrorHandlingFixes(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/event-stream") w.WriteHeader(http.StatusOK) - + // 发送一些无效的 JSON 数据块,然后发送有效的 w.Write([]byte("data: {invalid json}\n\n")) w.Write([]byte("data: {\"type\":\"content_block_delta\",\"delta\":{\"type\":\"text_delta\",\"text\":\"valid\"}}\n\n")) @@ -1202,7 +1202,7 @@ func TestAnthropicClient_Request_ErrorHandlingFixes(t *testing.T) { defer server.Close() client := NewAnthropicClient(createTestConfig(server.URL, "test-key", "claude-3-sonnet", 30*time.Second, false)) - metrics, err := client.Request("test prompt", true) + metrics, err := client.Request("", "test prompt", true) // 流式处理应该继续,即使有些 JSON 块无效 if err != nil { @@ -1253,7 +1253,7 @@ func TestAnthropicClient_Request_ErrorHandlingFixes(t *testing.T) { defer server.Close() client := NewAnthropicClient(createTestConfig(server.URL, "test-key", "claude-3-sonnet", 30*time.Second, false)) - metrics, err := client.Request("test prompt", false) + metrics, err := client.Request("", "test prompt", false) // 所有类型的错误都应该返回错误 if err == nil { @@ -1267,7 +1267,7 @@ func TestAnthropicClient_Request_ErrorHandlingFixes(t *testing.T) { // 验证错误信息包含预期内容 if !strings.Contains(metrics.ErrorMessage, tc.expectedErrMsg) { - t.Errorf("Expected ErrorMessage to contain '%s' for %s, got: %s", + t.Errorf("Expected ErrorMessage to contain '%s' for %s, got: %s", tc.expectedErrMsg, tc.name, metrics.ErrorMessage) } @@ -1301,21 +1301,21 @@ func TestAnthropicClientWithConfig(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { client := NewAnthropicClient(createTestConfig("https://api.anthropic.com", "test-key", "claude-3-sonnet", 30*time.Second, tt.thinking)) - + // 验证 thinking 字段设置正确 if client.Thinking != tt.thinking { t.Errorf("Expected Thinking = %v, got %v", tt.thinking, client.Thinking) } - + // 验证其他基本字段 if client.EndpointURL != "https://api.anthropic.com/v1/messages" { t.Errorf("Expected EndpointURL = https://api.anthropic.com/v1/messages, got %s", client.EndpointURL) } - + if client.Model != "claude-3-sonnet" { t.Errorf("Expected Model = claude-3-sonnet, got %s", client.Model) } - + if client.Provider != types.ProtocolAnthropicMessages { t.Errorf("Expected Provider = %s, got %s", types.ProtocolAnthropicMessages, client.Provider) } diff --git a/internal/client/client_test.go b/internal/client/client_test.go index 1f6b899..e4d9158 100644 --- a/internal/client/client_test.go +++ b/internal/client/client_test.go @@ -70,11 +70,11 @@ func TestNewClient(t *testing.T) { { name: "invalid provider", config: types.Input{ - Protocol: "invalid", + Protocol: "invalid", EndpointURL: "https://api.test.com/v1/anything", - ApiKey: "test-key", - Model: "test-model", - Timeout: 30 * time.Second, + ApiKey: "test-key", + Model: "test-model", + Timeout: 30 * time.Second, }, wantError: true, }, @@ -157,11 +157,11 @@ func TestNewClientWithTimeout(t *testing.T) { { name: "invalid provider with timeout", config: types.Input{ - Protocol: "invalid", + Protocol: "invalid", EndpointURL: "https://api.test.com/v1/anything", - ApiKey: "test-key", - Model: "test-model", - Timeout: 5 * time.Second, + ApiKey: "test-key", + Model: "test-model", + Timeout: 5 * time.Second, }, wantError: true, }, diff --git a/internal/client/integration_test.go b/internal/client/integration_test.go index aa083cb..59002ae 100644 --- a/internal/client/integration_test.go +++ b/internal/client/integration_test.go @@ -80,7 +80,7 @@ func TestOpenAIClient_Request_NonStream(t *testing.T) { client := NewOpenAIClient(createIntegrationTestConfig(server.URL, "test-key", "gpt-3.5-turbo")) start := time.Now() - metrics, err := client.Request("test prompt", false) + metrics, err := client.Request("", "test prompt", false) elapsed := time.Since(start) if err != nil { @@ -104,7 +104,7 @@ func TestOpenAIClient_Request_Stream(t *testing.T) { client := NewOpenAIClient(createIntegrationTestConfig(server.URL, "test-key", "gpt-3.5-turbo")) start := time.Now() - metrics, err := client.Request("test prompt", true) + metrics, err := client.Request("", "test prompt", true) elapsed := time.Since(start) if err != nil { @@ -129,7 +129,7 @@ func TestOpenAIClient_Request_ServerError(t *testing.T) { client := NewOpenAIClient(createIntegrationTestConfig(server.URL, "test-key", "gpt-3.5-turbo")) - _, err := client.Request("test prompt", false) + _, err := client.Request("", "test prompt", false) if err == nil { t.Error("Request() should return error for server error") @@ -140,7 +140,7 @@ func TestOpenAIClient_Request_NetworkError(t *testing.T) { // 使用一个无效的地址来模拟网络错误 client := NewOpenAIClient(createIntegrationTestConfig("http://invalid-host-that-does-not-exist.example", "test-key", "gpt-3.5-turbo")) - metrics, err := client.Request("test prompt", false) + metrics, err := client.Request("", "test prompt", false) // 应该返回错误 if err == nil { @@ -171,7 +171,7 @@ func TestOpenAIClient_Request_StreamNetworkError(t *testing.T) { // 测试流式模式下的网络错误 client := NewOpenAIClient(createIntegrationTestConfig("http://invalid-host-that-does-not-exist.example", "test-key", "gpt-3.5-turbo")) - metrics, err := client.Request("test prompt", true) + metrics, err := client.Request("", "test prompt", true) // 应该返回错误 if err == nil { @@ -194,7 +194,7 @@ func TestOpenAIClient_Request_InvalidURL(t *testing.T) { // 使用一个格式错误的 URL client := NewOpenAIClient(createIntegrationTestConfig("://invalid-url", "test-key", "gpt-3.5-turbo")) - metrics, err := client.Request("test prompt", false) + metrics, err := client.Request("", "test prompt", false) // 应该返回错误 if err == nil { diff --git a/internal/client/openai.go b/internal/client/openai.go index b06e455..993a520 100644 --- a/internal/client/openai.go +++ b/internal/client/openai.go @@ -363,12 +363,7 @@ type OpenAIClient struct { // - DisableCompression=false: 启用压缩以节省带宽 func NewOpenAIClient(config types.Input) *OpenAIClient { endpointURL := config.ResolvedEndpointURL() - - // 禁用连接复用以确保每个请求都是独立的 - transport := &http.Transport{ - DisableKeepAlives: true, - DisableCompression: false, - } + transport := newMeasuredTransport(config) return &OpenAIClient{ httpClient: &http.Client{ diff --git a/internal/client/openai_test.go b/internal/client/openai_test.go index 23505e1..20a8a39 100644 --- a/internal/client/openai_test.go +++ b/internal/client/openai_test.go @@ -216,7 +216,7 @@ func TestOpenAIClient_ConnectionReuse(t *testing.T) { // 发送多个串行请求来验证不复用连接的行为 requestCount := 3 for i := 0; i < requestCount; i++ { - metrics, err := client.Request(fmt.Sprintf("test prompt %d", i), false) + metrics, err := client.Request("", fmt.Sprintf("test prompt %d", i), false) if err != nil { t.Errorf("Request %d failed: %v", i, err) continue @@ -310,7 +310,7 @@ func TestOpenAIClient_ConnectionReuseImpactOnTiming(t *testing.T) { // 发送多个请求,每次都应该包含完整的连接建立时间 var totalTimes []time.Duration for i := 0; i < 3; i++ { - metrics, err := clientWithoutReuse.Request("test", false) + metrics, err := clientWithoutReuse.Request("", "test", false) if err != nil { t.Fatalf("Request failed: %v", err) } @@ -330,13 +330,13 @@ func TestOpenAIClient_ConnectionReuseImpactOnTiming(t *testing.T) { t.Run("with connection reuse demonstration", func(t *testing.T) { // 这里我们演示连接复用的情况,但在实际的性能测试工具中应该避免 // 首个请求建立连接 - metrics1, err := clientWithReuse.Request("test", false) + metrics1, err := clientWithReuse.Request("", "test", false) if err != nil { t.Fatalf("First request failed: %v", err) } // 后续请求可能复用连接,时间可能更短 - metrics2, err := clientWithReuse.Request("test", false) + metrics2, err := clientWithReuse.Request("", "test", false) if err != nil { t.Fatalf("Second request failed: %v", err) } @@ -424,7 +424,7 @@ func TestOpenAIClient_Request_MalformedJSON(t *testing.T) { // 测试非流式请求的 JSON 解析错误 t.Run("non-stream malformed JSON", func(t *testing.T) { - _, err := client.Request("test prompt", false) + _, err := client.Request("", "test prompt", false) if err == nil { t.Error("Expected error for malformed JSON response") } @@ -432,7 +432,7 @@ func TestOpenAIClient_Request_MalformedJSON(t *testing.T) { // 测试流式请求(应该跳过畸形的 JSON 并处理有效的) t.Run("stream with some malformed JSON", func(t *testing.T) { - metrics, err := client.Request("test prompt", true) + metrics, err := client.Request("", "test prompt", true) if err != nil { t.Errorf("Request should succeed even with some malformed JSON: %v", err) } @@ -458,7 +458,7 @@ func TestOpenAIClient_Request_OpenAIResponses_NonStream(t *testing.T) { defer server.Close() client := NewOpenAIClient(createOpenAIResponsesTestConfig(server.URL, "test-key", "gpt-4.1-mini", 30*time.Second, false)) - metrics, err := client.Request("hello from responses", false) + metrics, err := client.Request("", "hello from responses", false) if err != nil { t.Fatalf("Request() unexpected error: %v", err) } @@ -493,7 +493,7 @@ func TestOpenAIClient_Request_OpenAIResponses_Stream(t *testing.T) { defer server.Close() client := NewOpenAIClient(createOpenAIResponsesTestConfig(server.URL, "test-key", "gpt-4.1-mini", 30*time.Second, false)) - metrics, err := client.Request("stream me", true) + metrics, err := client.Request("", "stream me", true) if err != nil { t.Fatalf("Request() unexpected error: %v", err) } @@ -528,7 +528,7 @@ func TestOpenAIClient_Request_BodyReadError(t *testing.T) { client := NewOpenAIClient(createOpenAITestConfig(server.URL, "test-key", "test-model", 0, false)) - _, err := client.Request("test prompt", false) + _, err := client.Request("", "test prompt", false) if err == nil { t.Error("Expected error when response body cannot be read") } @@ -550,7 +550,7 @@ func TestOpenAIClient_Request_ScannerError(t *testing.T) { client := NewOpenAIClient(createOpenAITestConfig(server.URL, "test-key", "test-model", 0, false)) // 这个测试可能会因为 scanner 的缓冲区限制而失败 - metrics, err := client.Request("test prompt", true) + metrics, err := client.Request("", "test prompt", true) // 无论成功还是失败都是正常的,关键是要覆盖这个代码路径 if err != nil { t.Logf("Scanner error (expected in some cases): %v", err) @@ -603,7 +603,7 @@ func TestOpenAIClient_Request_EdgeCases(t *testing.T) { defer server.Close() client := NewOpenAIClient(createOpenAITestConfig(server.URL, "test-key", "test-model", 0, false)) - _, err := client.Request("test", tt.stream) + _, err := client.Request("", "test", tt.stream) if tt.expectError && err == nil { t.Error("Expected error but got none") @@ -641,7 +641,7 @@ func TestOpenAIClient_ConcurrentRequests(t *testing.T) { go func(id int) { defer wg.Done() - metrics, err := client.Request(fmt.Sprintf("concurrent test %d", id), false) + metrics, err := client.Request("", fmt.Sprintf("concurrent test %d", id), false) mu.Lock() if err != nil { @@ -682,7 +682,7 @@ func TestOpenAIClient_Request_TimeoutHandling(t *testing.T) { // 创建一个超时时间很短的客户端 client := NewOpenAIClient(createOpenAITestConfig(server.URL, "test-key", "test-model", 100*time.Millisecond, false)) - _, err := client.Request("timeout test", false) + _, err := client.Request("", "timeout test", false) if err == nil { t.Error("Expected timeout error but got none") } @@ -712,7 +712,7 @@ func TestOpenAIClient_Request_EmptyChoicesArray(t *testing.T) { client := NewOpenAIClient(createOpenAITestConfig(server.URL, "test-key", "test-model", 0, false)) // 测试非流式请求 - metrics, err := client.Request("test", false) + metrics, err := client.Request("", "test", false) if err != nil { t.Errorf("Request should succeed with empty choices: %v", err) } @@ -721,7 +721,7 @@ func TestOpenAIClient_Request_EmptyChoicesArray(t *testing.T) { } // 测试流式请求 - metrics, err = client.Request("test", true) + metrics, err = client.Request("", "test", true) if err != nil { t.Errorf("Stream request should succeed with empty choices: %v", err) } @@ -830,7 +830,7 @@ func TestOpenAIClient_Request_ThinkingContent(t *testing.T) { client := NewOpenAIClient(createOpenAITestConfig(server.URL, "test-key", "test-model", 0, false)) - metrics, err := client.Request("test prompt", true) + metrics, err := client.Request("", "test prompt", true) if err != nil { t.Errorf("Request failed: %v", err) return @@ -890,7 +890,7 @@ func TestOpenAIClient_Request_TTFTAccuracy(t *testing.T) { client := NewOpenAIClient(createOpenAITestConfig(server.URL, "test-key", "test-model", 0, false)) start := time.Now() - metrics, err := client.Request("test prompt", true) + metrics, err := client.Request("", "test prompt", true) totalDuration := time.Since(start) if err != nil { @@ -933,7 +933,7 @@ func TestOpenAIClient_Request_ErrorHandlingFixes(t *testing.T) { defer server.Close() client := NewOpenAIClient(createOpenAITestConfig(server.URL, "test-key", "test-model", 0, false)) - metrics, err := client.Request("test prompt", false) + metrics, err := client.Request("", "test prompt", false) // 应该有错误 if err == nil { @@ -972,7 +972,7 @@ func TestOpenAIClient_Request_ErrorHandlingFixes(t *testing.T) { defer server.Close() client := NewOpenAIClient(createOpenAITestConfig(server.URL, "test-key", "test-model", 0, false)) - metrics, err := client.Request("test prompt", false) + metrics, err := client.Request("", "test prompt", false) // 应该有错误 if err == nil { @@ -1009,7 +1009,7 @@ func TestOpenAIClient_Request_ErrorHandlingFixes(t *testing.T) { defer server.Close() client := NewOpenAIClient(createOpenAITestConfig(server.URL, "test-key", "test-model", 0, false)) - metrics, err := client.Request("test prompt", false) + metrics, err := client.Request("", "test prompt", false) // 注意:这个测试可能不会触发 io.ReadAll 错误,因为 httptest 服务器的行为 // 但我们仍然验证基本的错误处理逻辑 @@ -1031,7 +1031,7 @@ func TestOpenAIClient_Request_ErrorHandlingFixes(t *testing.T) { defer server.Close() client := NewOpenAIClient(createOpenAITestConfig(server.URL, "test-key", "test-model", 0, false)) - metrics, err := client.Request("test prompt", true) + metrics, err := client.Request("", "test prompt", true) // 流式处理应该继续,即使有些 JSON 块无效 if err != nil { diff --git a/internal/client/transport.go b/internal/client/transport.go new file mode 100644 index 0000000..d0eebbc --- /dev/null +++ b/internal/client/transport.go @@ -0,0 +1,34 @@ +package client + +import ( + "fmt" + "net/http" + "net/url" + "strings" + + "github.com/yinxulai/ait/internal/types" +) + +func newMeasuredTransport(config types.Input) *http.Transport { + transport := &http.Transport{ + DisableKeepAlives: true, + DisableCompression: false, + Proxy: http.ProxyFromEnvironment, + } + + proxyURL := strings.TrimSpace(config.ProxyURL) + if proxyURL == "" { + return transport + } + + parsed, err := url.Parse(proxyURL) + if err == nil && parsed.Scheme != "" && parsed.Host != "" { + transport.Proxy = http.ProxyURL(parsed) + return transport + } + + transport.Proxy = func(*http.Request) (*url.URL, error) { + return nil, fmt.Errorf("invalid proxy_url: %s", proxyURL) + } + return transport +} diff --git a/internal/client/transport_test.go b/internal/client/transport_test.go new file mode 100644 index 0000000..c72789a --- /dev/null +++ b/internal/client/transport_test.go @@ -0,0 +1,81 @@ +package client + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/yinxulai/ait/internal/types" +) + +func TestNewMeasuredTransport_ExplicitProxy(t *testing.T) { + transport := newMeasuredTransport(types.Input{ProxyURL: "http://proxy.example:8080"}) + if transport.Proxy == nil { + t.Fatal("Proxy should be configured when proxy_url is provided") + } + + proxy, err := transport.Proxy(httptest.NewRequest(http.MethodGet, "https://api.example.com", nil)) + if err != nil { + t.Fatalf("Proxy returned error: %v", err) + } + if proxy == nil || proxy.String() != "http://proxy.example:8080" { + t.Fatalf("proxy = %v, want http://proxy.example:8080", proxy) + } +} + +func TestNewMeasuredTransport_InvalidProxy(t *testing.T) { + transport := newMeasuredTransport(types.Input{ProxyURL: "://bad proxy"}) + if transport.Proxy == nil { + t.Fatal("Proxy callback should be set for invalid proxy_url") + } + + if _, err := transport.Proxy(httptest.NewRequest(http.MethodGet, "https://api.example.com", nil)); err == nil { + t.Fatal("expected invalid proxy_url to return an error") + } +} + +func TestNewClients_UseConfiguredProxy(t *testing.T) { + constructors := []struct { + name string + transport func() *http.Transport + }{ + { + name: "openai", + transport: func() *http.Transport { + client := NewOpenAIClient(types.Input{ + Protocol: types.ProtocolOpenAICompletions, + ProxyURL: "http://proxy.example:8080", + }) + transport, _ := client.httpClient.Transport.(*http.Transport) + return transport + }, + }, + { + name: "anthropic", + transport: func() *http.Transport { + client := NewAnthropicClient(types.Input{ + Protocol: types.ProtocolAnthropicMessages, + ProxyURL: "http://proxy.example:8080", + }) + transport, _ := client.httpClient.Transport.(*http.Transport) + return transport + }, + }, + } + + for _, tt := range constructors { + t.Run(tt.name, func(t *testing.T) { + transport := tt.transport() + if transport == nil { + t.Fatal("expected http.Transport") + } + proxy, err := transport.Proxy(httptest.NewRequest(http.MethodGet, "https://api.example.com", nil)) + if err != nil { + t.Fatalf("Proxy returned error: %v", err) + } + if proxy == nil || proxy.String() != "http://proxy.example:8080" { + t.Fatalf("proxy = %v, want http://proxy.example:8080", proxy) + } + }) + } +} diff --git a/internal/runner/runner_test.go b/internal/runner/runner_test.go index e5f227d..1dd42d7 100644 --- a/internal/runner/runner_test.go +++ b/internal/runner/runner_test.go @@ -32,7 +32,7 @@ type MockClient struct { model string } -func (m *MockClient) Request(prompt string, stream bool) (*client.ResponseMetrics, error) { +func (m *MockClient) Request(systemPrompt, prompt string, stream bool) (*client.ResponseMetrics, error) { callIndex := atomic.AddInt64(&m.callCount, 1) - 1 if m.requestDelay > 0 { @@ -1499,7 +1499,7 @@ type MockClientWithErrorMetrics struct { errorMetrics *client.ResponseMetrics } -func (m *MockClientWithErrorMetrics) Request(prompt string, stream bool) (*client.ResponseMetrics, error) { +func (m *MockClientWithErrorMetrics) Request(systemPrompt, prompt string, stream bool) (*client.ResponseMetrics, error) { callIndex := atomic.AddInt64(&m.callCount, 1) - 1 // 检查是否应该失败 diff --git a/internal/tui/model.go b/internal/tui/model.go index 37b881d..8c79ea0 100644 --- a/internal/tui/model.go +++ b/internal/tui/model.go @@ -152,24 +152,29 @@ func (m *Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { if msg.State == nil { return m, nil } - if m.dash != nil && m.dash.RunID == msg.State.RunID { - m.dash.RunState = msg.State - } else if m.turboDash != nil && m.turboDash.RunID == msg.State.RunID { - m.turboDash.RunState = msg.State - } else if msg.FromHistory { - // 从历史记录导航过来:根据运行模式选择对应仪表盘,并设置返回目标为任务详情 - backNav := pages.NavAction{To: pages.NavTaskDetail} + if msg.FromHistory { + backNav := pages.NavAction{To: pages.NavTaskDetail, TaskID: msg.State.TaskID} if msg.State.Mode == "turbo" { - m.turboDash = pages.NewTurboDashState(msg.State.RunID, msg.State.TaskID) + if m.turboDash == nil || m.turboDash.RunID != msg.State.RunID { + m.turboDash = pages.NewTurboDashState(msg.State.RunID, msg.State.TaskID) + } m.turboDash.RunState = msg.State m.turboDash.BackNav = backNav m.view = viewTurboDash } else { - m.dash = pages.NewDashboardState(msg.State.RunID, msg.State.TaskID) + if m.dash == nil || m.dash.RunID != msg.State.RunID { + m.dash = pages.NewDashboardState(msg.State.RunID, msg.State.TaskID) + } m.dash.RunState = msg.State m.dash.BackNav = backNav m.view = viewDashboard } + return m, nil + } + if m.dash != nil && m.dash.RunID == msg.State.RunID { + m.dash.RunState = msg.State + } else if m.turboDash != nil && m.turboDash.RunID == msg.State.RunID { + m.turboDash.RunState = msg.State } return m, nil @@ -270,12 +275,20 @@ func (m *Model) handleNav(nav pages.NavAction) tea.Cmd { return m.client.LoadTasksCmd() case pages.NavTaskDetail: + backNav := m.taskDetailBackNav() task := m.findTask(nav.TaskID) if task != nil { - m.detail = pages.NewTaskDetailState(*task) + if m.detail != nil && m.detail.Task.ID == task.ID { + m.detail.Task = *task + } else { + m.detail = pages.NewTaskDetailState(*task) + } } else if m.detail == nil { return nil } + if m.detail != nil { + m.detail.BackNav = backNav + } // 若该任务有正在运行的实例,注入快照 if m.detail != nil && m.taskList != nil { if rs, ok := m.taskList.ActiveRuns[m.detail.Task.ID]; ok && rs != nil { @@ -522,6 +535,24 @@ func (m *Model) currentRunTaskID(isDash bool) string { return "" } +func (m *Model) taskDetailBackNav() pages.NavAction { + switch m.view { + case viewDashboard: + if m.dash != nil { + return pages.NavAction{To: pages.NavDashboard} + } + case viewTurboDash: + if m.turboDash != nil { + return pages.NavAction{To: pages.NavTurboDash} + } + case viewTaskDetail: + if m.detail != nil && m.detail.BackNav.To != pages.NavNone { + return m.detail.BackNav + } + } + return pages.NavAction{To: pages.NavTaskList} +} + func (m *Model) collectRequests() []*types.RequestMetrics { // 优先使用当前活跃视图的数据,避免两个面板均有 RunState 时取错 switch m.view { diff --git a/internal/tui/model_test.go b/internal/tui/model_test.go index 4b170aa..2385dcd 100644 --- a/internal/tui/model_test.go +++ b/internal/tui/model_test.go @@ -58,8 +58,14 @@ func TestOpenWizard_NewTask_Defaults(t *testing.T) { if m.wizard.Concurrency <= 0 { t.Errorf("default concurrency should be positive, got %d", m.wizard.Concurrency) } - if m.wizard.PromptMode != pages.PromptModeText { - t.Errorf("default PromptMode = %q, want %q", m.wizard.PromptMode, pages.PromptModeText) + if !m.wizard.Stream { + t.Error("default Stream = false, want true") + } + if m.wizard.PromptMode != pages.PromptModeGenerated { + t.Errorf("default PromptMode = %q, want %q", m.wizard.PromptMode, pages.PromptModeGenerated) + } + if m.wizard.PromptLength != 100 { + t.Errorf("default PromptLength = %d, want 100", m.wizard.PromptLength) } } @@ -71,6 +77,7 @@ func TestOpenWizard_EditTask_Populate(t *testing.T) { Input: types.Input{ Model: "gpt-4", Protocol: types.ProtocolOpenAICompletions, + ProxyURL: "http://proxy.internal:8080", ApiKey: "sk-test", Concurrency: 5, Count: 50, @@ -91,6 +98,9 @@ func TestOpenWizard_EditTask_Populate(t *testing.T) { if m.wizard.Concurrency != 5 { t.Errorf("Concurrency = %d, want 5", m.wizard.Concurrency) } + if m.wizard.ProxyURL != "http://proxy.internal:8080" { + t.Errorf("ProxyURL = %q, want %q", m.wizard.ProxyURL, "http://proxy.internal:8080") + } } func TestBuildTaskInput_Standard(t *testing.T) { @@ -99,6 +109,7 @@ func TestBuildTaskInput_Standard(t *testing.T) { wz := m.wizard wz.Model = "gpt-4.1" wz.APIKey = "sk-test" + wz.ProxyURL = "http://proxy.internal:8080" wz.Concurrency = 8 wz.Count = 120 wz.PromptMode = pages.PromptModeText @@ -115,6 +126,9 @@ func TestBuildTaskInput_Standard(t *testing.T) { if inp.Count != 120 { t.Errorf("count = %d, want 120", inp.Count) } + if inp.ProxyURL != "http://proxy.internal:8080" { + t.Errorf("proxy_url = %q, want %q", inp.ProxyURL, "http://proxy.internal:8080") + } if inp.PromptMode != pages.PromptModeText || inp.PromptText != "hello" { t.Errorf("unexpected prompt config: mode=%q text=%q", inp.PromptMode, inp.PromptText) } @@ -153,3 +167,49 @@ func TestBuildTaskInput_Turbo(t *testing.T) { t.Errorf("protocol = %q, want anthropic-messages", inp.Protocol) } } + +func TestRunStateMsg_FromHistory_ReopensExistingDashboard(t *testing.T) { + m := NewModel(&stubServer{}) + m.view = viewTaskDetail + m.dash = pages.NewDashboardState("run-1", "task-1") + + updated, _ := m.Update(RunStateMsg{ + State: &server.RunState{ + RunID: "run-1", + TaskID: "task-1", + Mode: "standard", + }, + FromHistory: true, + }) + + got := updated.(*Model) + if got.view != viewDashboard { + t.Fatalf("view = %q, want %q", got.view, viewDashboard) + } + if got.dash == nil || got.dash.RunState == nil { + t.Fatal("dashboard should hold loaded history run state") + } + if got.dash.BackNav.To != pages.NavTaskDetail { + t.Fatalf("dash.BackNav.To = %v, want %v", got.dash.BackNav.To, pages.NavTaskDetail) + } +} + +func TestOpenWizard_EditTask_InferLegacyPromptMode(t *testing.T) { + m := NewModel(&stubServer{}) + task := types.TaskDefinition{ + ID: "task-legacy", + Name: "legacy-task", + Input: types.Input{ + Protocol: types.ProtocolOpenAICompletions, + PromptText: "legacy prompt", + }, + } + + m.wizard = pages.NewWizardStateEdit(&task) + if m.wizard.PromptMode != pages.PromptModeText { + t.Errorf("PromptMode = %q, want %q", m.wizard.PromptMode, pages.PromptModeText) + } + if m.wizard.PromptText != "legacy prompt" { + t.Errorf("PromptText = %q, want %q", m.wizard.PromptText, "legacy prompt") + } +} diff --git a/internal/tui/pages/dashboard.go b/internal/tui/pages/dashboard.go index 32a2b33..512de60 100644 --- a/internal/tui/pages/dashboard.go +++ b/internal/tui/pages/dashboard.go @@ -123,11 +123,6 @@ func HandleDashboardKey(d *DashboardState, msg tea.KeyMsg, client Client) (*Dash } case "b", "esc": - if d.CancelFn != nil { - d.CancelFn() - } - d.EventCh = nil - d.CancelFn = nil if d.BackNav.To != NavNone { nav = d.BackNav } else { @@ -191,20 +186,22 @@ func RenderDashboard(d *DashboardState, taskName string, st Styles, width, heigh } l := PageLayout{ CtxItems: cbItems, - FooterParts: []string{"[b/Esc] 返回列表", "[q] 退出"}, + FooterParts: []string{"[b/Esc] 返回上一页", "[q] 退出"}, } + innerW := ContentWidth(width) + innerH := l.ContentHeight(height) // ── 计算高度 ── splitH := 9 // 双栏面板外部总高度(含面板边框) progressPanel := 3 // 进度条面板外部高度(1内容+2边框) - reqListH := height - l.ChromeHeight() - splitH - progressPanel - 2 // -2 for req panel border + reqListH := innerH - splitH - progressPanel - 2 // -2 for req panel border if reqListH < 3 { reqListH = 3 } // ── 双栏面板(任务参数 | 实时指标)── - leftW := width * 45 / 100 - rightW := width - leftW + leftW := innerW * 45 / 100 + rightW := innerW - leftW leftContent := buildDashParamsPanel(d, rs, st, splitH-2, leftW-2) rightContent := buildDashMetricsPanel(rs, st, splitH-2, rightW-2) leftPanel := st.Panel.Width(leftW - 2).Render(leftContent) @@ -212,15 +209,15 @@ func RenderDashboard(d *DashboardState, taskName string, st Styles, width, heigh split := lipgloss.JoinHorizontal(lipgloss.Top, leftPanel, rightPanel) // ── 进度条面板 ── - progressLine := buildProgressLine(rs, st, ContentWidth(width)) - progressPanelStr := wrapPanel(st, progressLine, width) + progressLine := buildProgressLine(rs, st, ContentWidth(innerW)) + progressPanelStr := wrapPanel(st, progressLine, innerW) // ── 请求列表面板 ── - reqList := buildRequestList(d, rs, st, ContentWidth(width), reqListH) - reqPanelStr := wrapPanel(st, reqList, width) + reqList := buildRequestList(d, rs, st, ContentWidth(innerW), reqListH) + reqPanelStr := wrapPanel(st, reqList, innerW) content := strings.Join([]string{split, progressPanelStr, reqPanelStr}, "\n") - return l.Assemble(content, st, width) + return l.Assemble(wrapPanel(st, content, width), st, width) } // buildDashParamsPanel 构建左侧任务参数面板。 diff --git a/internal/tui/pages/reqdetail.go b/internal/tui/pages/reqdetail.go index 1e2d23f..c8d2eaa 100644 --- a/internal/tui/pages/reqdetail.go +++ b/internal/tui/pages/reqdetail.go @@ -122,18 +122,20 @@ func RenderReqDetail(s *ReqDetailState, taskName string, st Styles, width, heigh CtxItems: CtxBar_ReqDetail(), FooterParts: []string{"[q] 退出"}, } + innerW := ContentWidth(width) + innerH := l.ContentHeight(height) // ── 计算高度 ── splitH := 9 inputH := 5 - outputH := height - l.ChromeHeight() - splitH - inputH - 2 - 2 // -2 for input panel border, -2 for output panel border + outputH := innerH - splitH - inputH - 2 - 2 // -2 for input panel border, -2 for output panel border if outputH < 4 { outputH = 4 } // ── 双栏面板(性能指标 | 网络指标)── - leftW := width * 50 / 100 - rightW := width - leftW + leftW := innerW * 50 / 100 + rightW := innerW - leftW leftContent := buildReqPerfPanel(r, st, splitH-2, leftW-2) rightContent := buildReqNetworkPanel(r, st, splitH-2, rightW-2) leftPanel := st.Panel.Width(leftW - 2).Render(leftContent) @@ -141,15 +143,15 @@ func RenderReqDetail(s *ReqDetailState, taskName string, st Styles, width, heigh split := lipgloss.JoinHorizontal(lipgloss.Top, leftPanel, rightPanel) // ── 输入区面板 ── - inputSection := buildInputSection(r, st, ContentWidth(width), inputH) - inputPanelStr := wrapPanel(st, inputSection, width) + inputSection := buildInputSection(r, st, ContentWidth(innerW), inputH) + inputPanelStr := wrapPanel(st, inputSection, innerW) // ── 输出区面板 ── - outputSection := buildOutputSection(r, s.ScrollY, st, ContentWidth(width), outputH) - outputPanelStr := wrapPanel(st, outputSection, width) + outputSection := buildOutputSection(r, s.ScrollY, st, ContentWidth(innerW), outputH) + outputPanelStr := wrapPanel(st, outputSection, innerW) content := strings.Join([]string{split, inputPanelStr, outputPanelStr}, "\n") - return l.Assemble(content, st, width) + return l.Assemble(wrapPanel(st, content, width), st, width) } // buildReqPerfPanel 构建请求左侧性能指标面板。 diff --git a/internal/tui/pages/taskdetail.go b/internal/tui/pages/taskdetail.go index 2bb6b28..be34db8 100644 --- a/internal/tui/pages/taskdetail.go +++ b/internal/tui/pages/taskdetail.go @@ -15,6 +15,7 @@ import ( type TaskDetailState struct { Task types.TaskDefinition History []types.TaskRunSummary + BackNav NavAction // HistorySel 当前选中的历史记录索引(0 = 最近一次;若有正在运行的实例,0 = 运行中条目) HistorySel int HistoryOff int @@ -25,7 +26,7 @@ type TaskDetailState struct { // NewTaskDetailState 创建初始任务详情状态。 func NewTaskDetailState(task types.TaskDefinition) *TaskDetailState { - return &TaskDetailState{Task: task} + return &TaskDetailState{Task: task, BackNav: NavAction{To: NavTaskList}} } func taskDetailHistoryEntries(s *TaskDetailState) []types.TaskRunSummary { @@ -83,7 +84,11 @@ func HandleTaskDetailKey(s *TaskDetailState, msg tea.KeyMsg, client Client) (*Ta } case "left", "esc", "b": - nav = NavAction{To: NavTaskList} + if s.BackNav.To != NavNone { + nav = s.BackNav + } else { + nav = NavAction{To: NavTaskList} + } case "r": if s.ActiveRun == nil { @@ -172,7 +177,7 @@ func RenderTaskDetail(s *TaskDetailState, st Styles, width, height int) string { } l := PageLayout{ CtxItems: cbItems, - FooterParts: []string{"[b/Esc] 返回列表", "◆ AIT v0.1"}, + FooterParts: []string{"[b/Esc] 返回上一页", "◆ AIT v0.1"}, } content := buildTaskDetailContent(s, st, t, inp, ContentWidth(width), l.ContentHeight(height)) @@ -198,6 +203,10 @@ func buildTaskDetailContent(s *TaskDetailState, st Styles, t types.TaskDefinitio leftLines = append(leftLines, padRight(" "+st.Label.Render("协议")+" "+st.Value.Render(proto), leftW)) endpoint := truncate(inp.ResolvedEndpointURL(), leftW-8) leftLines = append(leftLines, padRight(" "+st.Label.Render("接口")+" "+st.Value.Render(endpoint), leftW)) + if inp.ProxyURL != "" { + proxy := truncate(inp.ProxyURL, leftW-8) + leftLines = append(leftLines, padRight(" "+st.Label.Render("代理")+" "+st.Value.Render(proxy), leftW)) + } leftLines = append(leftLines, padRight("", leftW)) model := truncate(inp.Model, leftW-10) diff --git a/internal/tui/pages/taskdetail_test.go b/internal/tui/pages/taskdetail_test.go index 20e0cfb..979e938 100644 --- a/internal/tui/pages/taskdetail_test.go +++ b/internal/tui/pages/taskdetail_test.go @@ -3,6 +3,7 @@ package pages import ( "testing" + tea "github.com/charmbracelet/bubbletea" "github.com/yinxulai/ait/internal/server" "github.com/yinxulai/ait/internal/types" ) @@ -24,3 +25,15 @@ func TestTaskDetailHistoryEntries_SkipsActiveRunDuplicate(t *testing.T) { t.Fatalf("RunID: got %q, want %q", entries[0].RunID, "run-1") } } + +func TestHandleTaskDetailKey_UsesBackNav(t *testing.T) { + state := &TaskDetailState{ + Task: types.TaskDefinition{ID: "task-1", Name: "task"}, + BackNav: NavAction{To: NavDashboard}, + } + + _, _, nav := HandleTaskDetailKey(state, tea.KeyMsg{Type: tea.KeyEsc}, nil) + if nav.To != NavDashboard { + t.Fatalf("nav.To = %v, want %v", nav.To, NavDashboard) + } +} diff --git a/internal/tui/pages/turbodash.go b/internal/tui/pages/turbodash.go index 09c115b..74cdf48 100644 --- a/internal/tui/pages/turbodash.go +++ b/internal/tui/pages/turbodash.go @@ -101,11 +101,6 @@ func HandleTurboDashKey(d *TurboDashState, msg tea.KeyMsg, client Client) (*Turb } case "b", "esc": - if d.CancelFn != nil { - d.CancelFn() - } - d.EventCh = nil - d.CancelFn = nil if d.BackNav.To != NavNone { nav = d.BackNav } else { @@ -169,20 +164,22 @@ func RenderTurboDash(d *TurboDashState, taskName string, st Styles, width, heigh } l := PageLayout{ CtxItems: cbItems, - FooterParts: []string{"[b/Esc] 返回列表", "[q] 退出"}, + FooterParts: []string{"[b/Esc] 返回上一页", "[q] 退出"}, } + innerW := ContentWidth(width) + innerH := l.ContentHeight(height) // ── 计算高度 ── splitH := 9 progressPanel := 3 - levelListH := height - l.ChromeHeight() - splitH - progressPanel - 2 + levelListH := innerH - splitH - progressPanel - 2 if levelListH < 3 { levelListH = 3 } // ── 双栏面板(任务参数 | 当前级别指标)── - leftW := width * 45 / 100 - rightW := width - leftW + leftW := innerW * 45 / 100 + rightW := innerW - leftW leftContent := buildTurboDashParams(rs, st, splitH-2, leftW-2) rightContent := buildTurboDashMetrics(rs, st, splitH-2, rightW-2) leftPanel := st.Panel.Width(leftW - 2).Render(leftContent) @@ -190,15 +187,15 @@ func RenderTurboDash(d *TurboDashState, taskName string, st Styles, width, heigh split := lipgloss.JoinHorizontal(lipgloss.Top, leftPanel, rightPanel) // ── 进度条面板 ── - progressLine := buildTurboProgressLine(rs, st, ContentWidth(width)) - progressPanelStr := wrapPanel(st, progressLine, width) + progressLine := buildTurboProgressLine(rs, st, ContentWidth(innerW)) + progressPanelStr := wrapPanel(st, progressLine, innerW) // ── 级别列表面板 ── - levelList := buildLevelList(d, rs, st, ContentWidth(width), levelListH) - levelPanelStr := wrapPanel(st, levelList, width) + levelList := buildLevelList(d, rs, st, ContentWidth(innerW), levelListH) + levelPanelStr := wrapPanel(st, levelList, innerW) content := strings.Join([]string{split, progressPanelStr, levelPanelStr}, "\n") - return l.Assemble(content, st, width) + return l.Assemble(wrapPanel(st, content, width), st, width) } // buildTurboDashParams 构建 Turbo 仪表盘左侧任务参数面板。 diff --git a/internal/tui/pages/wizard.go b/internal/tui/pages/wizard.go index d2ef1f2..9de5b9a 100644 --- a/internal/tui/pages/wizard.go +++ b/internal/tui/pages/wizard.go @@ -37,6 +37,7 @@ type WizardState struct { Name string Protocol string // types.Protocol* 常量 EndpointURL string + ProxyURL string APIKey string Model string @@ -80,7 +81,9 @@ func NewWizardState() *WizardState { StepSize: 2, LevelRequests: 30, MinSuccessRate: 90, - PromptMode: PromptModeText, + Stream: true, + PromptMode: PromptModeGenerated, + PromptLength: 100, } } @@ -97,15 +100,24 @@ func NewWizardStateEdit(t *types.TaskDefinition) *WizardState { wz.Name = t.Name wz.Protocol = types.NormalizeProtocol(inp.Protocol) wz.EndpointURL = inp.EndpointURL + wz.ProxyURL = inp.ProxyURL wz.APIKey = inp.ApiKey wz.Model = inp.Model wz.Turbo = inp.Turbo wz.Stream = inp.Stream wz.PromptText = inp.PromptText wz.PromptFile = inp.PromptFile - wz.PromptLength = inp.PromptLength + if inp.PromptLength > 0 { + wz.PromptLength = inp.PromptLength + } if inp.PromptMode != "" { wz.PromptMode = inp.PromptMode + } else if inp.PromptFile != "" { + wz.PromptMode = PromptModeFile + } else if inp.PromptLength > 0 { + wz.PromptMode = PromptModeGenerated + } else { + wz.PromptMode = PromptModeText } if inp.Concurrency > 0 { wz.Concurrency = inp.Concurrency @@ -149,6 +161,7 @@ func (wz *WizardState) BuildTaskConfig() server.TaskConfig { Input: types.Input{ Protocol: wz.Protocol, EndpointURL: wz.EndpointURL, + ProxyURL: wz.ProxyURL, ApiKey: wz.APIKey, Model: wz.Model, Concurrency: wz.Concurrency, @@ -253,6 +266,11 @@ func step1Fields() []fieldDef { getRaw: func(wz *WizardState) string { return wz.EndpointURL }, set: func(wz *WizardState, v string) { wz.EndpointURL = v }, }, + { + kind: fieldText, label: "代理地址", + get: func(wz *WizardState) string { return wz.ProxyURL }, + set: func(wz *WizardState, v string) { wz.ProxyURL = v }, + }, { kind: fieldText, label: "API 密钥", get: func(wz *WizardState) string { return wz.APIKey }, @@ -343,6 +361,9 @@ func step2Fields(turbo bool) []fieldDef { idx = (idx - 1 + len(promptModes)) % len(promptModes) } wz.PromptMode = promptModes[idx] + if wz.PromptMode == PromptModeGenerated && wz.PromptLength <= 0 { + wz.PromptLength = 100 + } }, }, ) @@ -709,11 +730,12 @@ func renderWizardField(st Styles, f fieldDef, wz *WizardState, active bool, maxW } renderedValue := fieldStyle.Width(fieldW).Render(valueStyle.Render(valueStr)) - // 用 JoinHorizontal 而非字符串拼接:renderedValue 有 3 行(上边框/内容/下边框), - // 直接 + 只有第一行有 label 前缀,后两行会从列 0 开始,导致布局混乱。 - // JoinHorizontal(Top, ...) 会将 label 块和 field 块按顶部对齐水平拼接, - // label 块高度自动补齐到与 field 相同(3 行),布局整齐。 - labelBlock := lipgloss.NewStyle().Width(15).Render(st.Label.Render(wizardFieldLabel(f, wz))) + labelLines := []string{ + strings.Repeat(" ", 15), + lipgloss.NewStyle().Width(15).Render(st.Label.Render(wizardFieldLabel(f, wz))), + strings.Repeat(" ", 15), + } + labelBlock := strings.Join(labelLines, "\n") return lipgloss.JoinHorizontal(lipgloss.Top, labelBlock, renderedValue) } @@ -732,6 +754,7 @@ func renderStep3Summary(wz *WizardState, st Styles, innerW int) []string { endpointDisplay = types.DefaultEndpointURL(wz.Protocol) } addRow("接口地址", endpointDisplay, st.Value) + addRow("代理地址", wizardFallback(wz.ProxyURL, "直连"), st.Value) addRow("API 密钥", wizardFallback(maskAPIKey(wz.APIKey), "未填写"), st.Value) addRow("测试模型", wizardFallback(wz.Model, "未填写"), st.Value) diff --git a/internal/tui/pages/wizard_test.go b/internal/tui/pages/wizard_test.go new file mode 100644 index 0000000..b54fd6d --- /dev/null +++ b/internal/tui/pages/wizard_test.go @@ -0,0 +1,60 @@ +package pages + +import ( + "regexp" + "strings" + "testing" + + tea "github.com/charmbracelet/bubbletea" + "github.com/yinxulai/ait/internal/server" +) + +var ansiRE = regexp.MustCompile("\\x1b\\[[0-9;]*m") + +func stripANSI(s string) string { + return ansiRE.ReplaceAllString(s, "") +} + +func TestRenderWizardField_CentersLabelOnInputRow(t *testing.T) { + st := NewStyles() + wz := NewWizardState() + wz.Name = "demo" + field := step1Fields()[0] + + rendered := stripANSI(renderWizardField(st, field, wz, true, 80)) + lines := strings.Split(rendered, "\n") + if len(lines) < 3 { + t.Fatalf("expected at least 3 lines, got %d", len(lines)) + } + if strings.Contains(lines[0], "任务名称") { + t.Fatalf("label should not be rendered on top border line: %q", lines[0]) + } + if !strings.Contains(lines[1], "任务名称") { + t.Fatalf("label should be rendered on the input content line: %q", lines[1]) + } +} + +func TestHandleDashboardKey_BackPreservesSubscription(t *testing.T) { + called := false + ch := make(chan server.Event) + d := &DashboardState{ + EventCh: ch, + CancelFn: func() { called = true }, + BackNav: NavAction{To: NavTaskDetail}, + } + + _, _, nav := HandleDashboardKey(d, tea.KeyMsg{Type: tea.KeyEsc}, nil) + if nav.To != NavTaskDetail { + t.Fatalf("nav.To = %v, want %v", nav.To, NavTaskDetail) + } + if called { + t.Fatal("CancelFn should not be called when returning to previous page") + } + if d.EventCh != ch { + t.Fatal("EventCh should be preserved when returning to previous page") + } + if d.CancelFn == nil { + t.Fatal("CancelFn should remain set when returning to previous page") + } + close(ch) +} diff --git a/internal/types/types.go b/internal/types/types.go index 43c9406..e128f27 100644 --- a/internal/types/types.go +++ b/internal/types/types.go @@ -89,6 +89,7 @@ type Input struct { Protocol string `json:"protocol"` EndpointURL string `json:"endpoint_url,omitempty"` BaseUrl string `json:"base_url,omitempty"` + ProxyURL string `json:"proxy_url,omitempty"` ApiKey string `json:"api_key,omitempty"` Model string `json:"model"` Concurrency int `json:"concurrency,omitempty"` From a77cceb8bce909dcada85f6fb5c032044392bc22 Mon Sep 17 00:00:00 2001 From: Alain Date: Wed, 20 May 2026 09:05:12 +0800 Subject: [PATCH 20/52] =?UTF-8?q?feat:=20=E6=9B=B4=E6=96=B0=20PromptSource?= =?UTF-8?q?=20=E7=BB=93=E6=9E=84=EF=BC=8C=E4=BC=98=E5=8C=96=E7=B3=BB?= =?UTF-8?q?=E7=BB=9F=E6=B6=88=E6=81=AF=E5=A4=84=E7=90=86=E5=92=8C=E7=94=9F?= =?UTF-8?q?=E6=88=90=E5=86=85=E5=AE=B9=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/prompt/prompt.go | 43 +++++++++------------------------------ 1 file changed, 10 insertions(+), 33 deletions(-) diff --git a/internal/prompt/prompt.go b/internal/prompt/prompt.go index 9761f7d..27b6f53 100644 --- a/internal/prompt/prompt.go +++ b/internal/prompt/prompt.go @@ -16,7 +16,7 @@ type PromptSource struct { IsFile bool // 是否来自文件 FilePaths []string // 文件路径列表 Contents []string // prompt内容列表(仅用于非文件内容) - SystemContent string // 固定的系统消息内容(仅 generated 模式使用,用于触发前缀缓存) + SystemContent string // 可选的系统消息内容;为空时表示不额外发送 system 消息 DisplayText string // 用于显示的文本 ShouldTruncate bool // 是否需要截断显示(对于已经包含长度信息的内容,不需要再次处理) } @@ -99,8 +99,7 @@ func loadMultipleFiles(pattern string) (*PromptSource, error) { }, nil } -// GetSystemContent 返回系统消息内容(固定的大段上下文,用于前缀缓存)。 -// 非 generated 模式返回空字符串,不影响原有请求结构。 +// GetSystemContent 返回系统消息内容;为空时不发送额外的 system 消息。 func (ps *PromptSource) GetSystemContent() string { return ps.SystemContent } @@ -274,44 +273,22 @@ func GeneratePromptByLength(length int) string { // LoadPromptByLength 创建指定长度的 PromptSource。 // -// 为了让测试中部分请求满足前缀缓存条件(Prefix Cache),内容被拆分为两部分: -// - SystemContent(约 90% 长度):固定不变的大段上下文,作为 system 消息发送; -// 同一批次所有请求共享相同的 system 消息,API 侧命中前缀缓存后可大幅降低延迟。 -// - Contents(user 消息候选列表):多条短问题,每个请求按 index 取模轮流使用, -// 既保证请求内容有差异,又确保 system 前缀不变以触发缓存。 +// generated 模式的语义是“生成一条指定长度的固定内容”,因此这里返回单内容源。 +// 运行时每个请求都会拿到同一份生成文本;若后续需要 system prompt,应由上层显式建模, +// 而不是在这里隐式拆分 generated prompt 的含义。 func LoadPromptByLength(length int) (*PromptSource, error) { if length <= 0 { return nil, fmt.Errorf("prompt 长度必须大于 0") } - - // 90% 作为 system 消息(固定,供缓存命中) - systemLen := length * 9 / 10 - if systemLen < 1 { - systemLen = 1 - } - systemContent := GeneratePromptByLength(systemLen) - actualSystemLen := utf8.RuneCountInString(systemContent) - - // 短而多样的 user 消息,各请求轮流使用(保证差异 + 共享 system 前缀) - userQuestions := []string{ - "请帮我总结一下上述内容的核心要点。", - "根据以上信息,有什么值得特别关注的地方?", - "上述内容中最重要的信息是什么?", - "请对以上内容进行简短分析。", - "上述内容的主要主题是什么,请概括。", - "从以上内容中能得出哪些结论?", - "以上内容有哪些值得深入探讨的点?", - "请提炼上述内容的关键信息。", - "对以上内容你有什么看法?", - "上述内容对实际应用有什么启示?", - } + content := GeneratePromptByLength(length) + actualLen := utf8.RuneCountInString(content) return &PromptSource{ IsFile: false, FilePaths: nil, - Contents: userQuestions, - SystemContent: systemContent, - DisplayText: fmt.Sprintf("生成内容 (系统消息: %d 字符,轮换用户问题 x%d)", actualSystemLen, len(userQuestions)), + Contents: []string{content}, + SystemContent: "", + DisplayText: fmt.Sprintf("生成内容 (%d 字符)", actualLen), ShouldTruncate: false, }, nil } From 61c14b625940e2f4eefb4c8b2a2b2b600543a86b Mon Sep 17 00:00:00 2001 From: Alain Date: Wed, 20 May 2026 21:06:52 +0800 Subject: [PATCH 21/52] =?UTF-8?q?feat:=20=E5=A2=9E=E5=BC=BA=20Anthropic=20?= =?UTF-8?q?=E5=AE=A2=E6=88=B7=E7=AB=AF=EF=BC=8C=E6=94=AF=E6=8C=81=E7=BC=93?= =?UTF-8?q?=E5=AD=98=E6=8E=A7=E5=88=B6=E5=92=8C=E8=BE=93=E5=85=A5=20token?= =?UTF-8?q?=20=E7=BB=9F=E8=AE=A1=EF=BC=8C=E4=BC=98=E5=8C=96=E7=94=9F?= =?UTF-8?q?=E6=88=90=E6=8F=90=E7=A4=BA=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/client/anthropic.go | 102 ++++++++++++++++++--- internal/client/anthropic_test.go | 92 +++++++++++++++++++ internal/prompt/prompt.go | 147 ++++++++++++++++++++++++++++-- internal/prompt/prompt_test.go | 45 ++++++--- internal/task/input_test.go | 8 +- 5 files changed, 358 insertions(+), 36 deletions(-) diff --git a/internal/client/anthropic.go b/internal/client/anthropic.go index 7959fe3..568542a 100644 --- a/internal/client/anthropic.go +++ b/internal/client/anthropic.go @@ -29,6 +29,7 @@ type AnthropicResponse struct { Model string `json:"model"` Usage struct { InputTokens int `json:"input_tokens"` + CacheCreationInputTokens int `json:"cache_creation_input_tokens"` CacheReadInputTokens int `json:"cache_read_input_tokens"` OutputTokens int `json:"output_tokens"` } `json:"usage"` @@ -47,6 +48,14 @@ type AnthropicErrorResponse struct { type AnthropicStreamChunk struct { Type string `json:"type"` Index int `json:"index,omitempty"` + Message *struct { + Usage *struct { + InputTokens int `json:"input_tokens"` + CacheCreationInputTokens int `json:"cache_creation_input_tokens"` + CacheReadInputTokens int `json:"cache_read_input_tokens"` + OutputTokens int `json:"output_tokens"` + } `json:"usage,omitempty"` + } `json:"message,omitempty"` Delta struct { Type string `json:"type"` Text string `json:"text"` @@ -54,12 +63,43 @@ type AnthropicStreamChunk struct { PartialJSON *string `json:"partial_json,omitempty"` } `json:"delta,omitempty"` Usage *struct { - InputTokens int `json:"input_tokens"` - CacheReadInputTokens int `json:"cache_read_input_tokens"` - OutputTokens int `json:"output_tokens"` + InputTokens int `json:"input_tokens"` + CacheCreationInputTokens int `json:"cache_creation_input_tokens"` + CacheReadInputTokens int `json:"cache_read_input_tokens"` + OutputTokens int `json:"output_tokens"` } `json:"usage,omitempty"` } +func anthropicTextBlock(text string) map[string]interface{} { + return map[string]interface{}{ + "type": "text", + "text": text, + } +} + +func buildAnthropicSystemBlocks(systemPrompt string) []map[string]interface{} { + parts := strings.Split(systemPrompt, "\n\n") + blocks := make([]map[string]interface{}, 0, len(parts)) + for _, part := range parts { + trimmed := strings.TrimSpace(part) + if trimmed == "" { + continue + } + blocks = append(blocks, anthropicTextBlock(trimmed)) + } + if len(blocks) == 0 { + return nil + } + blocks[len(blocks)-1]["cache_control"] = map[string]interface{}{ + "type": "ephemeral", + } + return blocks +} + +func anthropicTotalInputTokens(inputTokens, cacheCreationInputTokens, cacheReadInputTokens int) int { + return inputTokens + cacheCreationInputTokens + cacheReadInputTokens +} + // AnthropicClient Anthropic 协议客户端 type AnthropicClient struct { EndpointURL string @@ -117,16 +157,21 @@ func (c *AnthropicClient) Request(systemPrompt, userPrompt string, stream bool) "model": c.Model, "messages": []map[string]interface{}{ { - "role": "user", - "content": userPrompt, + "role": "user", + "content": []map[string]interface{}{ + anthropicTextBlock(userPrompt), + }, }, }, "stream": stream, } - // 如果有 system prompt,添加顶层 system 字段(Anthropic API 规范) + // Anthropic 的缓存需要显式 cache_control,公共前缀应放在稳定的 system blocks 上。 if systemPrompt != "" { - requestBody["system"] = systemPrompt + systemBlocks := buildAnthropicSystemBlocks(systemPrompt) + if len(systemBlocks) > 0 { + requestBody["system"] = systemBlocks + } } // 如果启用了 thinking 模式,添加 thinking 配置 @@ -303,6 +348,7 @@ func (c *AnthropicClient) Request(systemPrompt, userPrompt string, stream bool) var fullContent strings.Builder var outputTokens int var inputTokens int + var cacheCreationInputTokens int var cachedInputTokens int var streamChunks []string // 用于记录所有流式数据块 var rawResponseLines strings.Builder @@ -339,6 +385,21 @@ func (c *AnthropicClient) Request(systemPrompt, userPrompt string, stream bool) if err := json.Unmarshal([]byte(data), &chunk); err != nil { continue // 跳过无法解析的行 } + + if chunk.Message != nil && chunk.Message.Usage != nil { + if chunk.Message.Usage.InputTokens > 0 { + inputTokens = chunk.Message.Usage.InputTokens + } + if chunk.Message.Usage.CacheCreationInputTokens > 0 { + cacheCreationInputTokens = chunk.Message.Usage.CacheCreationInputTokens + } + if chunk.Message.Usage.CacheReadInputTokens > 0 { + cachedInputTokens = chunk.Message.Usage.CacheReadInputTokens + } + if chunk.Message.Usage.OutputTokens > 0 { + outputTokens = chunk.Message.Usage.OutputTokens + } + } if chunk.Type == "content_block_delta" { // 检查是否有任何形式的内容输出(包括 Text、Thinking 或 PartialJSON) @@ -363,9 +424,18 @@ func (c *AnthropicClient) Request(systemPrompt, userPrompt string, stream bool) // 获取 token 统计信息 if chunk.Usage != nil { - inputTokens = chunk.Usage.InputTokens - cachedInputTokens = chunk.Usage.CacheReadInputTokens - outputTokens = chunk.Usage.OutputTokens + if chunk.Usage.InputTokens > 0 { + inputTokens = chunk.Usage.InputTokens + } + if chunk.Usage.CacheCreationInputTokens > 0 { + cacheCreationInputTokens = chunk.Usage.CacheCreationInputTokens + } + if chunk.Usage.CacheReadInputTokens > 0 { + cachedInputTokens = chunk.Usage.CacheReadInputTokens + } + if chunk.Usage.OutputTokens > 0 { + outputTokens = chunk.Usage.OutputTokens + } } } } @@ -391,11 +461,13 @@ func (c *AnthropicClient) Request(systemPrompt, userPrompt string, stream bool) "total_time": totalTime.String(), "time_to_first_token": firstTokenTime.String(), "input_tokens": inputTokens, + "cache_creation_input_tokens": cacheCreationInputTokens, "cached_input_tokens": cachedInputTokens, "output_tokens": outputTokens, "full_content": fullContent.String(), }) } + promptTokens := anthropicTotalInputTokens(inputTokens, cacheCreationInputTokens, cachedInputTokens) return &ResponseMetrics{ TimeToFirstToken: firstTokenTime, @@ -404,7 +476,7 @@ func (c *AnthropicClient) Request(systemPrompt, userPrompt string, stream bool) ConnectTime: connectTime, TLSHandshakeTime: tlsTime, TargetIP: targetIP, - PromptTokens: inputTokens, + PromptTokens: promptTokens, CachedInputTokens: cachedInputTokens, CompletionTokens: outputTokens, RequestBody: string(reqBodyBytes), @@ -494,11 +566,17 @@ func (c *AnthropicClient) Request(systemPrompt, userPrompt string, stream bool) "total_time": totalTime.String(), "output_tokens": anthropicResp.Usage.OutputTokens, "input_tokens": anthropicResp.Usage.InputTokens, + "cache_creation_input_tokens": anthropicResp.Usage.CacheCreationInputTokens, "cached_input_tokens": anthropicResp.Usage.CacheReadInputTokens, "response_id": anthropicResp.ID, "content_length": len(contentText), }) } + promptTokens := anthropicTotalInputTokens( + anthropicResp.Usage.InputTokens, + anthropicResp.Usage.CacheCreationInputTokens, + anthropicResp.Usage.CacheReadInputTokens, + ) return &ResponseMetrics{ TimeToFirstToken: totalTime, // 非流式模式下,所有token一次性返回,TTFT等于总时间 @@ -507,7 +585,7 @@ func (c *AnthropicClient) Request(systemPrompt, userPrompt string, stream bool) ConnectTime: connectTime, TLSHandshakeTime: tlsTime, TargetIP: targetIP, - PromptTokens: anthropicResp.Usage.InputTokens, + PromptTokens: promptTokens, CachedInputTokens: anthropicResp.Usage.CacheReadInputTokens, CompletionTokens: anthropicResp.Usage.OutputTokens, RequestBody: string(reqBodyBytes), diff --git a/internal/client/anthropic_test.go b/internal/client/anthropic_test.go index 63d3fb9..bd8ce3a 100644 --- a/internal/client/anthropic_test.go +++ b/internal/client/anthropic_test.go @@ -1,6 +1,7 @@ package client import ( + "encoding/json" "fmt" "net/http" "net/http/httptest" @@ -246,6 +247,97 @@ func TestAnthropicClient_Request_Stream(t *testing.T) { } } +func TestAnthropicClient_Request_SystemPromptUsesCacheControl(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var body map[string]interface{} + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Fatalf("decode request body: %v", err) + } + + system, ok := body["system"].([]interface{}) + if !ok || len(system) != 2 { + t.Fatalf("expected 2 system blocks, got %#v", body["system"]) + } + + lastBlock, ok := system[len(system)-1].(map[string]interface{}) + if !ok { + t.Fatalf("unexpected system block: %#v", system[len(system)-1]) + } + cacheControl, ok := lastBlock["cache_control"].(map[string]interface{}) + if !ok { + t.Fatalf("expected cache_control on last system block, got %#v", lastBlock) + } + if cacheControl["type"] != "ephemeral" { + t.Fatalf("cache_control.type = %#v, want %#v", cacheControl["type"], "ephemeral") + } + + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"id":"test","type":"message","role":"assistant","content":[{"type":"text","text":"ok"}],"model":"claude-3","usage":{"input_tokens":4,"output_tokens":1}}`) + })) + defer server.Close() + + client := NewAnthropicClient(createTestConfig(server.URL, "test-key", "claude-3-sonnet", 30*time.Second, false)) + if _, err := client.Request("公共消息1\n\n公共消息2", "user prompt", false); err != nil { + t.Fatalf("Request() error = %v", err) + } +} + +func TestAnthropicClient_Request_PromptTokensIncludeCachedAndCreatedInput(t *testing.T) { + t.Run("non-stream", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"id":"msg_test","type":"message","role":"assistant","content":[{"type":"text","text":"ok"}],"model":"claude-3-sonnet","usage":{"input_tokens":50,"cache_creation_input_tokens":100,"cache_read_input_tokens":900,"output_tokens":10}}`) + })) + defer server.Close() + + client := NewAnthropicClient(createTestConfig(server.URL, "test-key", "claude-3-sonnet", 30*time.Second, false)) + metrics, err := client.Request("shared system", "user prompt", false) + if err != nil { + t.Fatalf("Request() error = %v", err) + } + if metrics.PromptTokens != 1050 { + t.Fatalf("PromptTokens = %d, want %d", metrics.PromptTokens, 1050) + } + if metrics.CachedInputTokens != 900 { + t.Fatalf("CachedInputTokens = %d, want %d", metrics.CachedInputTokens, 900) + } + }) + + t.Run("stream", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Transfer-Encoding", "chunked") + + flusher, _ := w.(http.Flusher) + fmt.Fprint(w, "event: message_start\n") + fmt.Fprint(w, `data: {"type":"message_start","message":{"usage":{"input_tokens":40,"cache_creation_input_tokens":160,"cache_read_input_tokens":800,"output_tokens":0}}}`+"\n\n") + flusher.Flush() + fmt.Fprint(w, "event: content_block_delta\n") + fmt.Fprint(w, `data: {"type":"content_block_delta","delta":{"type":"text_delta","text":"hello"}}`+"\n\n") + flusher.Flush() + fmt.Fprint(w, "event: message_delta\n") + fmt.Fprint(w, `data: {"type":"message_delta","usage":{"output_tokens":12}}`+"\n\n") + flusher.Flush() + })) + defer server.Close() + + client := NewAnthropicClient(createTestConfig(server.URL, "test-key", "claude-3-sonnet", 30*time.Second, false)) + metrics, err := client.Request("shared system", "user prompt", true) + if err != nil { + t.Fatalf("Request() error = %v", err) + } + if metrics.PromptTokens != 1000 { + t.Fatalf("PromptTokens = %d, want %d", metrics.PromptTokens, 1000) + } + if metrics.CachedInputTokens != 800 { + t.Fatalf("CachedInputTokens = %d, want %d", metrics.CachedInputTokens, 800) + } + if metrics.CompletionTokens != 12 { + t.Fatalf("CompletionTokens = %d, want %d", metrics.CompletionTokens, 12) + } + }) +} + func TestAnthropicClient_Request_ServerError(t *testing.T) { server := createMockAnthropicServer(0, false, http.StatusInternalServerError) defer server.Close() diff --git a/internal/prompt/prompt.go b/internal/prompt/prompt.go index 27b6f53..53fd92f 100644 --- a/internal/prompt/prompt.go +++ b/internal/prompt/prompt.go @@ -11,6 +11,21 @@ import ( "unicode/utf8" ) +var generatedCommonSeeds = []string{ + "公共消息1:以下内容描述一个固定的评测背景,所有请求都共享这段上下文,以便模拟前缀缓存命中。", + "公共消息2:请基于同一组系统约束、相同的领域设定和一致的输出风格进行分析,不要改变整体语境。", + "公共消息3:你正在参与稳定负载测试,公共上下文应尽量保持一致,只有用户问题会发生变化。", +} + +var generatedUserSeeds = []string{ + "随机用户消息1:请提炼以上背景里的三个关键结论,并说明它们为什么重要。", + "随机用户消息2:基于上述共享信息,总结最值得关注的风险点与应对方向。", + "随机用户消息3:请用简洁结构归纳核心要点,并指出其中最有价值的一条。", + "随机用户消息4:结合以上上下文,说明该场景在实际落地时应优先验证哪些指标。", + "随机用户消息5:请从性能、稳定性和可观测性三个角度给出简短分析。", + "随机用户消息6:在不改变公共背景的前提下,概括最可能影响结果判断的因素。", +} + // PromptSource 表示prompt的来源信息 type PromptSource struct { IsFile bool // 是否来自文件 @@ -271,24 +286,140 @@ func GeneratePromptByLength(length int) string { return builder.String() } +func splitGeneratedPromptLengths(total int) (commonLen, userLen int) { + if total <= 0 { + return 0, 0 + } + + if total <= 24 { + return 0, total + } + + commonLen = total * 7 / 10 + userLen = total - commonLen + + if userLen < 12 { + userLen = 12 + if total < userLen { + userLen = total + } + commonLen = total - userLen + } + + if commonLen < 0 { + commonLen = 0 + } + if userLen < 0 { + userLen = 0 + } + + return commonLen, userLen +} + +func splitBudget(total, parts int) []int { + if parts <= 0 { + return nil + } + budgets := make([]int, parts) + base := total / parts + rest := total % parts + for i := 0; i < parts; i++ { + budgets[i] = base + if i < rest { + budgets[i]++ + } + } + return budgets +} + +func truncateToRunes(text string, length int) string { + if length <= 0 { + return "" + } + runes := []rune(text) + if len(runes) <= length { + return text + } + return string(runes[:length]) +} + +func composeSizedText(seed string, target int) string { + if target <= 0 { + return "" + } + + seed = strings.TrimSpace(seed) + if utf8.RuneCountInString(seed) >= target { + return truncateToRunes(seed, target) + } + + builder := strings.Builder{} + builder.WriteString(seed) + currentLen := utf8.RuneCountInString(seed) + + if currentLen < target { + remaining := target - currentLen + if remaining > 0 { + builder.WriteString(GeneratePromptByLength(remaining)) + } + } + + return truncateToRunes(builder.String(), target) +} + +func buildGeneratedCommonPrompt(target int) string { + if target <= 0 { + return "" + } + + messageCount := len(generatedCommonSeeds) + separatorLen := 2 * (messageCount - 1) + if target <= separatorLen+messageCount { + return composeSizedText(generatedCommonSeeds[0], target) + } + + bodyBudget := target - separatorLen + budgets := splitBudget(bodyBudget, messageCount) + parts := make([]string, 0, messageCount) + for i, budget := range budgets { + parts = append(parts, composeSizedText(generatedCommonSeeds[i], budget)) + } + + return truncateToRunes(strings.Join(parts, "\n\n"), target) +} + +func buildGeneratedUserPrompts(target int) []string { + if target <= 0 { + return []string{""} + } + + contents := make([]string, 0, len(generatedUserSeeds)) + for _, seed := range generatedUserSeeds { + contents = append(contents, composeSizedText(seed, target)) + } + return contents +} + // LoadPromptByLength 创建指定长度的 PromptSource。 // -// generated 模式的语义是“生成一条指定长度的固定内容”,因此这里返回单内容源。 -// 运行时每个请求都会拿到同一份生成文本;若后续需要 system prompt,应由上层显式建模, -// 而不是在这里隐式拆分 generated prompt 的含义。 +// generated 模式会构造一段共享公共前缀和多条用户问题变体: +// - SystemContent: 共享的公共消息,所有请求保持一致,用于模拟缓存命中前缀。 +// - Contents: 多条不同的用户消息,请求按索引轮换,模拟公共前缀下的随机用户提问。 +// 单次请求的总 prompt 长度仍与传入的 length 保持一致。 func LoadPromptByLength(length int) (*PromptSource, error) { if length <= 0 { return nil, fmt.Errorf("prompt 长度必须大于 0") } - content := GeneratePromptByLength(length) - actualLen := utf8.RuneCountInString(content) + commonLen, userLen := splitGeneratedPromptLengths(length) + systemContent := buildGeneratedCommonPrompt(commonLen) + contents := buildGeneratedUserPrompts(userLen) return &PromptSource{ IsFile: false, FilePaths: nil, - Contents: []string{content}, - SystemContent: "", - DisplayText: fmt.Sprintf("生成内容 (%d 字符)", actualLen), + Contents: contents, + SystemContent: systemContent, + DisplayText: fmt.Sprintf("生成内容 (公共消息 %d 字符, 用户变体 x%d, 单次总长 %d 字符)", utf8.RuneCountInString(systemContent), len(contents), length), ShouldTruncate: false, }, nil } diff --git a/internal/prompt/prompt_test.go b/internal/prompt/prompt_test.go index 8f11c26..915a1d0 100644 --- a/internal/prompt/prompt_test.go +++ b/internal/prompt/prompt_test.go @@ -117,15 +117,24 @@ func TestLoadPromptByLength(t *testing.T) { t.Errorf("LoadPromptByLength 不应该设置 IsFile = true") } - if len(source.Contents) != 1 { - t.Errorf("LoadPromptByLength 应该返回 1 个内容,实际返回 %d 个", len(source.Contents)) + if len(source.Contents) <= 1 { + t.Errorf("LoadPromptByLength 应该返回多条用户变体,实际返回 %d 个", len(source.Contents)) } - // 验证内容长度 - content := source.GetRandomContent() - actualLen := utf8.RuneCountInString(content) - if actualLen != tt.length { - t.Errorf("LoadPromptByLength(%d) 返回内容长度 = %d", tt.length, actualLen) + if source.GetSystemContent() == "" { + t.Errorf("LoadPromptByLength 应该生成共享公共消息") + } + + if strings.Count(source.GetSystemContent(), "\n\n") < 2 { + t.Errorf("LoadPromptByLength 应该包含多条公共消息") + } + + // 验证单次请求的总长度保持与 length 一致 + for i, content := range source.Contents { + actualLen := utf8.RuneCountInString(source.GetSystemContent()) + utf8.RuneCountInString(content) + if actualLen != tt.length { + t.Errorf("LoadPromptByLength(%d) 第 %d 条变体总长度 = %d", tt.length, i, actualLen) + } } // 验证 DisplayText @@ -143,17 +152,23 @@ func TestPromptSourceWithGeneratedContent(t *testing.T) { t.Fatalf("LoadPromptByLength 失败: %v", err) } - // 测试多次调用 GetRandomContent 应该返回相同的内容 - content1 := source.GetRandomContent() - content2 := source.GetRandomContent() + if source.GetSystemContent() == "" { + t.Fatal("generated prompt 应包含公共消息") + } + + if source.Count() <= 1 { + t.Fatalf("Count() = %d, 期望大于 1", source.Count()) + } - if content1 != content2 { - t.Errorf("GetRandomContent 在单内容源时应该返回相同内容") + content1 := source.GetContentByIndex(0) + content2 := source.GetContentByIndex(1) + if content1 == content2 { + t.Errorf("不同索引的用户消息应该不同,以模拟随机用户消息") } - // 测试 Count 方法 - if source.Count() != 1 { - t.Errorf("Count() = %d, 期望 1", source.Count()) + totalLen := utf8.RuneCountInString(source.GetSystemContent()) + utf8.RuneCountInString(content1) + if totalLen != length { + t.Errorf("generated prompt 总长度 = %d, 期望 %d", totalLen, length) } } diff --git a/internal/task/input_test.go b/internal/task/input_test.go index 8a58f16..6f77605 100644 --- a/internal/task/input_test.go +++ b/internal/task/input_test.go @@ -21,9 +21,15 @@ func TestHydrateInputGeneratedMode(t *testing.T) { if err != nil { t.Fatalf("HydrateInput(generated) returned unexpected error: %v", err) } - if input.PromptSource == nil || input.PromptSource.Count() != 1 { + if input.PromptSource == nil { t.Fatal("expected generated PromptSource to be created") } + if input.PromptSource.Count() <= 1 { + t.Fatal("expected generated PromptSource to include multiple user variants") + } + if input.PromptSource.GetSystemContent() == "" { + t.Fatal("expected generated PromptSource to include shared common messages") + } } func TestHydrateInputRejectsInvalidMode(t *testing.T) { From 4022621c84fbbd840a5b8066cb2535fb12cda2c0 Mon Sep 17 00:00:00 2001 From: Alain Date: Wed, 20 May 2026 22:26:24 +0800 Subject: [PATCH 22/52] Add tests and refactor TUI page layouts for improved structure and styling - Introduced new test cases for page layouts in layout_test.go and reqdetail_test.go to ensure proper rendering and height consistency for success and failure scenarios. - Refactored reqdetail.go to streamline rendering logic and improve layout handling, including dynamic height calculations for panels. - Updated styles in styles.go to enhance visual consistency across the TUI, including new styles for hotkeys and panel headers. - Modified taskdetail.go and tasklist.go to utilize the new layout structure, ensuring proper rendering of task details and lists. - Enhanced turbodash.go and wizard.go to adopt the new layout framework, improving readability and maintainability of the code. --- cmd/ait/ait.go | 1 + internal/tui/model.go | 3 + internal/tui/pages/contextbar.go | 245 +++++++------- internal/tui/pages/dashboard.go | 98 +++--- internal/tui/pages/helpers.go | 467 ++++++++++++++++++--------- internal/tui/pages/layout.go | 169 ++++++++-- internal/tui/pages/layout_test.go | 65 ++++ internal/tui/pages/reqdetail.go | 133 ++++---- internal/tui/pages/reqdetail_test.go | 56 ++++ internal/tui/pages/styles.go | 51 +-- internal/tui/pages/taskdetail.go | 77 ++--- internal/tui/pages/tasklist.go | 40 ++- internal/tui/pages/turbodash.go | 101 +++--- internal/tui/pages/wizard.go | 68 ++-- 14 files changed, 1014 insertions(+), 560 deletions(-) create mode 100644 internal/tui/pages/layout_test.go create mode 100644 internal/tui/pages/reqdetail_test.go diff --git a/cmd/ait/ait.go b/cmd/ait/ait.go index 2d57758..bcecb93 100644 --- a/cmd/ait/ait.go +++ b/cmd/ait/ait.go @@ -98,6 +98,7 @@ func main() { } // ── 启动 TUI ────────────────────────────────────────────────────────────── + tui.SetVersion(Version) if err := tui.Run(srv); err != nil { fmt.Fprintf(os.Stderr, "TUI 启动失败: %v\n", err) os.Exit(1) diff --git a/internal/tui/model.go b/internal/tui/model.go index 8c79ea0..b23415d 100644 --- a/internal/tui/model.go +++ b/internal/tui/model.go @@ -64,6 +64,9 @@ func Run(srv server.Server) error { return err } +// SetVersion 设置 AppHeader 中显示的版本字符串,应在 Run 之前调用。 +func SetVersion(v string) { pages.SetAppVersion(v) } + // ─── BubbleTea 接口 ─────────────────────────────────────────────────────────── func (m *Model) Init() tea.Cmd { diff --git a/internal/tui/pages/contextbar.go b/internal/tui/pages/contextbar.go index 8b8f2f3..ea5ca93 100644 --- a/internal/tui/pages/contextbar.go +++ b/internal/tui/pages/contextbar.go @@ -1,170 +1,189 @@ package pages -// ContextBarItem 是底栏中的一个可操作项。 -type ContextBarItem struct { +// HotkeyItem 是底部 Hotkeys 区中的一个展示项。 +type HotkeyItem struct { Key string // 如 "Enter"、"r"、"↑↓" Desc string // 操作描述 + Text string // 纯文本提示 } -// ─── 各页面底栏操作定义 ─────────────────────────────────────────────────────── +func HotkeyAction(key, desc string) HotkeyItem { + return HotkeyItem{Key: key, Desc: desc} +} + +func HotkeyText(text string) HotkeyItem { + return HotkeyItem{Text: text} +} + +func HotkeyTexts(texts ...string) []HotkeyItem { + items := make([]HotkeyItem, 0, len(texts)) + for _, text := range texts { + if text != "" { + items = append(items, HotkeyText(text)) + } + } + return items +} + +// ─── 各页面底部 Hotkeys 定义 ──────────────────────────────────────────────── -// CtxBar_TaskList_Normal 普通任务选中时的 Context Bar。 -func CtxBar_TaskList_Normal() []ContextBarItem { - return []ContextBarItem{ - {Key: "Enter", Desc: "查看详情"}, - {Key: "r", Desc: "运行"}, - {Key: "e", Desc: "编辑"}, - {Key: "d", Desc: "删除"}, - {Key: "y", Desc: "复制"}, +// Hotkeys_TaskList_Normal 普通任务选中时的 Hotkeys。 +func Hotkeys_TaskList_Normal() []HotkeyItem { + return []HotkeyItem{ + HotkeyAction("Enter", "查看详情"), + HotkeyAction("r", "运行"), + HotkeyAction("e", "编辑"), + HotkeyAction("d", "删除"), + HotkeyAction("y", "复制"), } } -// CtxBar_TaskList_Running 运行中任务选中时的 Context Bar。 -func CtxBar_TaskList_Running() []ContextBarItem { - return []ContextBarItem{ - {Key: "Enter", Desc: "查看详情"}, - {Key: "s", Desc: "停止"}, - {Key: "y", Desc: "复制"}, +// Hotkeys_TaskList_Running 运行中任务选中时的 Hotkeys。 +func Hotkeys_TaskList_Running() []HotkeyItem { + return []HotkeyItem{ + HotkeyAction("Enter", "查看详情"), + HotkeyAction("s", "停止"), + HotkeyAction("y", "复制"), } } -// CtxBar_TaskDetail_NoHistory 任务详情页,无运行记录时。 -func CtxBar_TaskDetail_NoHistory() []ContextBarItem { - return []ContextBarItem{ - {Key: "r", Desc: "运行"}, - {Key: "e", Desc: "编辑"}, - {Key: "y", Desc: "复制"}, - {Key: "d", Desc: "删除"}, +// Hotkeys_TaskDetail_NoHistory 任务详情页,无运行记录时。 +func Hotkeys_TaskDetail_NoHistory() []HotkeyItem { + return []HotkeyItem{ + HotkeyAction("r", "运行"), + HotkeyAction("e", "编辑"), + HotkeyAction("y", "复制"), + HotkeyAction("d", "删除"), } } -// CtxBar_TaskDetail_HasHistory 任务详情页,有运行记录且未运行时。 -func CtxBar_TaskDetail_HasHistory() []ContextBarItem { - return []ContextBarItem{ - {Key: "↑↓", Desc: "选择记录"}, - {Key: "Enter", Desc: "查看运行详情"}, - {Key: "r", Desc: "再次运行"}, - {Key: "g", Desc: "导出 JSON 报告"}, - {Key: "e", Desc: "编辑"}, - {Key: "y", Desc: "复制任务"}, - {Key: "d", Desc: "删除"}, +// Hotkeys_TaskDetail_HasHistory 任务详情页,有运行记录且未运行时。 +func Hotkeys_TaskDetail_HasHistory() []HotkeyItem { + return []HotkeyItem{ + HotkeyAction("↑↓", "选择记录"), + HotkeyAction("Enter", "查看运行详情"), + HotkeyAction("r", "再次运行"), + HotkeyAction("g", "导出 JSON 报告"), + HotkeyAction("e", "编辑"), + HotkeyAction("y", "复制任务"), + HotkeyAction("d", "删除"), } } -// CtxBar_TaskDetail_Running 任务详情页,任务正在运行时。 -func CtxBar_TaskDetail_Running() []ContextBarItem { - return []ContextBarItem{ - {Key: "↑↓", Desc: "选择记录"}, - {Key: "Enter", Desc: "进入运行中仓表盘"}, - {Key: "g", Desc: "导出历史 JSON"}, - {Key: "e", Desc: "编辑"}, - {Key: "y", Desc: "复制任务"}, +// Hotkeys_TaskDetail_Running 任务详情页,任务正在运行时。 +func Hotkeys_TaskDetail_Running() []HotkeyItem { + return []HotkeyItem{ + HotkeyAction("↑↓", "选择记录"), + HotkeyAction("Enter", "进入运行中仓表盘"), + HotkeyAction("g", "导出历史 JSON"), + HotkeyAction("e", "编辑"), + HotkeyAction("y", "复制任务"), } } -// CtxBar_Wizard_Step1 创建任务页,第 1 步。 -func CtxBar_Wizard_Step1() []ContextBarItem { - return []ContextBarItem{ - {Key: "Tab/↑↓", Desc: "切换字段"}, - {Key: "←→", Desc: "切换协议"}, - {Key: "Enter", Desc: "下一步"}, - {Key: "Esc", Desc: "返回列表"}, +// Hotkeys_Wizard_Step1 创建任务页,第 1 步。 +func Hotkeys_Wizard_Step1() []HotkeyItem { + return []HotkeyItem{ + HotkeyAction("Tab/↑↓", "切换字段"), + HotkeyAction("←→", "切换协议"), + HotkeyAction("Enter", "下一步"), + HotkeyAction("Esc", "返回列表"), } } -// CtxBar_Wizard_Step2 创建任务页,第 2 步。 -func CtxBar_Wizard_Step2() []ContextBarItem { - return []ContextBarItem{ - {Key: "Tab/↑↓", Desc: "切换字段"}, - {Key: "←→", Desc: "切换选项"}, - {Key: "Enter", Desc: "下一步"}, - {Key: "Esc", Desc: "返回上一步"}, +// Hotkeys_Wizard_Step2 创建任务页,第 2 步。 +func Hotkeys_Wizard_Step2() []HotkeyItem { + return []HotkeyItem{ + HotkeyAction("Tab/↑↓", "切换字段"), + HotkeyAction("←→", "切换选项"), + HotkeyAction("Enter", "下一步"), + HotkeyAction("Esc", "返回上一步"), } } -// CtxBar_Wizard_Step3 创建任务页,第 3 步。 -func CtxBar_Wizard_Step3() []ContextBarItem { - return []ContextBarItem{ - {Key: "↑↓", Desc: "滚动"}, - {Key: "PgUp/PgDn", Desc: "翻页"}, - {Key: "Enter", Desc: "保存"}, - {Key: "r", Desc: "保存并运行"}, - {Key: "Esc", Desc: "返回修改"}, +// Hotkeys_Wizard_Step3 创建任务页,第 3 步。 +func Hotkeys_Wizard_Step3() []HotkeyItem { + return []HotkeyItem{ + HotkeyAction("↑↓", "滚动"), + HotkeyAction("PgUp/PgDn", "翻页"), + HotkeyAction("Enter", "保存"), + HotkeyAction("r", "保存并运行"), + HotkeyAction("Esc", "返回修改"), } } -// CtxBar_Dashboard_Running_NoSel 标准仪表盘运行中,无选中请求时。 -func CtxBar_Dashboard_Running_NoSel() []ContextBarItem { - return []ContextBarItem{ - {Key: "s", Desc: "停止"}, - {Key: "b/Esc", Desc: "返回列表"}, +// Hotkeys_Dashboard_Running_NoSel 标准仪表盘运行中,无选中请求时。 +func Hotkeys_Dashboard_Running_NoSel() []HotkeyItem { + return []HotkeyItem{ + HotkeyAction("s", "停止"), + HotkeyAction("b/Esc", "返回列表"), } } -// CtxBar_Dashboard_Done_NoSel 标准仪表盘完成后,无选中请求时。 -func CtxBar_Dashboard_Done_NoSel() []ContextBarItem { - return []ContextBarItem{ - {Key: "r", Desc: "生成报告"}, - {Key: "b/Esc", Desc: "返回列表"}, +// Hotkeys_Dashboard_Done_NoSel 标准仪表盘完成后,无选中请求时。 +func Hotkeys_Dashboard_Done_NoSel() []HotkeyItem { + return []HotkeyItem{ + HotkeyAction("r", "生成报告"), + HotkeyAction("b/Esc", "返回列表"), } } -// CtxBar_Dashboard_Running_Sel 标准仪表盘运行中,已选中请求时。 -func CtxBar_Dashboard_Running_Sel() []ContextBarItem { - return []ContextBarItem{ - {Key: "Enter", Desc: "查看请求详情"}, - {Key: "↑↓", Desc: "选择请求"}, - {Key: "s", Desc: "停止"}, +// Hotkeys_Dashboard_Running_Sel 标准仪表盘运行中,已选中请求时。 +func Hotkeys_Dashboard_Running_Sel() []HotkeyItem { + return []HotkeyItem{ + HotkeyAction("Enter", "查看请求详情"), + HotkeyAction("↑↓", "选择请求"), + HotkeyAction("s", "停止"), } } -// CtxBar_Dashboard_Done_Sel 标准仪表盘完成后,已选中请求时。 -func CtxBar_Dashboard_Done_Sel() []ContextBarItem { - return []ContextBarItem{ - {Key: "Enter", Desc: "查看请求详情"}, - {Key: "↑↓", Desc: "选择请求"}, +// Hotkeys_Dashboard_Done_Sel 标准仪表盘完成后,已选中请求时。 +func Hotkeys_Dashboard_Done_Sel() []HotkeyItem { + return []HotkeyItem{ + HotkeyAction("Enter", "查看请求详情"), + HotkeyAction("↑↓", "选择请求"), } } -// CtxBar_TurboDash_Running_NoSel Turbo 仪表盘运行中,无选中级别时。 -func CtxBar_TurboDash_Running_NoSel() []ContextBarItem { - return []ContextBarItem{ - {Key: "s", Desc: "停止"}, - {Key: "m", Desc: "标记极限并停止"}, - {Key: "b/Esc", Desc: "返回列表"}, +// Hotkeys_TurboDash_Running_NoSel Turbo 仪表盘运行中,无选中级别时。 +func Hotkeys_TurboDash_Running_NoSel() []HotkeyItem { + return []HotkeyItem{ + HotkeyAction("s", "停止"), + HotkeyAction("m", "标记极限并停止"), + HotkeyAction("b/Esc", "返回列表"), } } -// CtxBar_TurboDash_Done_NoSel Turbo 仪表盘完成后,无选中级别时。 -func CtxBar_TurboDash_Done_NoSel() []ContextBarItem { - return []ContextBarItem{ - {Key: "r", Desc: "生成报告"}, - {Key: "b/Esc", Desc: "返回列表"}, +// Hotkeys_TurboDash_Done_NoSel Turbo 仪表盘完成后,无选中级别时。 +func Hotkeys_TurboDash_Done_NoSel() []HotkeyItem { + return []HotkeyItem{ + HotkeyAction("r", "生成报告"), + HotkeyAction("b/Esc", "返回列表"), } } -// CtxBar_TurboDash_Running_Sel Turbo 仪表盘运行中,已选中级别时。 -func CtxBar_TurboDash_Running_Sel() []ContextBarItem { - return []ContextBarItem{ - {Key: "Enter", Desc: "查看该级别请求"}, - {Key: "↑↓", Desc: "选择"}, - {Key: "s", Desc: "停止"}, +// Hotkeys_TurboDash_Running_Sel Turbo 仪表盘运行中,已选中级别时。 +func Hotkeys_TurboDash_Running_Sel() []HotkeyItem { + return []HotkeyItem{ + HotkeyAction("Enter", "查看该级别请求"), + HotkeyAction("↑↓", "选择"), + HotkeyAction("s", "停止"), } } -// CtxBar_TurboDash_Done_Sel Turbo 仪表盘完成后,已选中级别时。 -func CtxBar_TurboDash_Done_Sel() []ContextBarItem { - return []ContextBarItem{ - {Key: "Enter", Desc: "查看该级别请求"}, - {Key: "↑↓", Desc: "选择"}, +// Hotkeys_TurboDash_Done_Sel Turbo 仪表盘完成后,已选中级别时。 +func Hotkeys_TurboDash_Done_Sel() []HotkeyItem { + return []HotkeyItem{ + HotkeyAction("Enter", "查看该级别请求"), + HotkeyAction("↑↓", "选择"), } } -// CtxBar_ReqDetail 请求详情页。 -func CtxBar_ReqDetail() []ContextBarItem { - return []ContextBarItem{ - {Key: "b/Esc", Desc: "返回仪表盘"}, - {Key: "↑↓", Desc: "滚动"}, - {Key: "←→", Desc: "上/下一条请求"}, +// Hotkeys_ReqDetail 请求详情页。 +func Hotkeys_ReqDetail() []HotkeyItem { + return []HotkeyItem{ + HotkeyAction("b/Esc", "返回仪表盘"), + HotkeyAction("↑↓", "滚动"), + HotkeyAction("←→", "上/下一条请求"), } } diff --git a/internal/tui/pages/dashboard.go b/internal/tui/pages/dashboard.go index 512de60..2a7fa8f 100644 --- a/internal/tui/pages/dashboard.go +++ b/internal/tui/pages/dashboard.go @@ -173,58 +173,67 @@ func RenderDashboard(d *DashboardState, taskName string, st Styles, width, heigh isRunning := d.IsRunning() hasSel := d.ReqSel >= 0 && rs != nil && d.ReqSel < len(rs.Requests) - var cbItems []ContextBarItem + var cbItems []HotkeyItem switch { case hasSel && isRunning: - cbItems = CtxBar_Dashboard_Running_Sel() + cbItems = Hotkeys_Dashboard_Running_Sel() case hasSel && !isRunning: - cbItems = CtxBar_Dashboard_Done_Sel() + cbItems = Hotkeys_Dashboard_Done_Sel() case !hasSel && isRunning: - cbItems = CtxBar_Dashboard_Running_NoSel() + cbItems = Hotkeys_Dashboard_Running_NoSel() default: - cbItems = CtxBar_Dashboard_Done_NoSel() + cbItems = Hotkeys_Dashboard_Done_NoSel() + } + headerLeft := []string{"等待数据"} + headerRight := []string{} + if rs != nil { + headerLeft = []string{runStatusText(string(rs.Status)), fmt.Sprintf("完成 %d/%d", rs.DoneReqs, rs.TotalReqs)} + headerRight = []string{fmt.Sprintf("成功率 %.1f%%", rs.SuccessRate)} + if !rs.StartedAt.IsZero() { + headerRight = append(headerRight, "开始 "+fmtRelativeTime(rs.StartedAt)) + } + } + if d.TaskID != "" { + headerRight = append(headerRight, "任务 "+truncate(d.TaskID, 14)) } l := PageLayout{ - CtxItems: cbItems, - FooterParts: []string{"[b/Esc] 返回上一页", "[q] 退出"}, + HeaderTitle: "标准运行监控", + HeaderSubtitle: "实时查看运行进度、吞吐和单请求明细", + HeaderMeta: "标准模式", + HeaderInfoLeft: headerLeft, + HeaderInfoRight: headerRight, + Hotkeys: NewPageHotkeys(cbItems, "[b/Esc] 返回上一页", "[q] 退出"), } - innerW := ContentWidth(width) - innerH := l.ContentHeight(height) + frame := l.Frame(width, height) + bodyPanel := frame.InnerPanel() // ── 计算高度 ── - splitH := 9 // 双栏面板外部总高度(含面板边框) - progressPanel := 3 // 进度条面板外部高度(1内容+2边框) - reqListH := innerH - splitH - progressPanel - 2 // -2 for req panel border - if reqListH < 3 { - reqListH = 3 - } + splitOuterH := 9 // 双栏面板外部总高度(含面板边框) + progressOuterH := 3 // 进度条面板外部高度(1内容+2边框) + reqOuterH := RemainingStackOuterHeight(frame.InnerHeight, splitOuterH, progressOuterH) + reqListH := PanelContentHeight(reqOuterH) // ── 双栏面板(任务参数 | 实时指标)── - leftW := innerW * 45 / 100 - rightW := innerW - leftW - leftContent := buildDashParamsPanel(d, rs, st, splitH-2, leftW-2) - rightContent := buildDashMetricsPanel(rs, st, splitH-2, rightW-2) - leftPanel := st.Panel.Width(leftW - 2).Render(leftContent) - rightPanel := st.Panel.Width(rightW - 2).Render(rightContent) - split := lipgloss.JoinHorizontal(lipgloss.Top, leftPanel, rightPanel) + leftPanelFrame, rightPanelFrame := bodyPanel.Split(45, 24) + leftContent := buildDashParamsPanel(d, rs, st, PanelContentHeight(splitOuterH), leftPanelFrame.InnerWidth) + rightContent := buildDashMetricsPanel(rs, st, PanelContentHeight(splitOuterH), rightPanelFrame.InnerWidth) + split := renderSplitPanels(st, leftPanelFrame, rightPanelFrame, leftContent, rightContent) // ── 进度条面板 ── - progressLine := buildProgressLine(rs, st, ContentWidth(innerW)) - progressPanelStr := wrapPanel(st, progressLine, innerW) + progressLine := buildProgressLine(rs, st, bodyPanel.InnerWidth) + progressPanelStr := bodyPanel.Wrap(st, progressLine) // ── 请求列表面板 ── - reqList := buildRequestList(d, rs, st, ContentWidth(innerW), reqListH) - reqPanelStr := wrapPanel(st, reqList, innerW) + reqList := buildRequestList(d, rs, st, bodyPanel.InnerWidth, reqListH) + reqPanelStr := bodyPanel.Wrap(st, reqList) - content := strings.Join([]string{split, progressPanelStr, reqPanelStr}, "\n") - return l.Assemble(wrapPanel(st, content, width), st, width) + content := joinVerticalBlocks(split, progressPanelStr, reqPanelStr) + return l.Assemble(frame.Wrap(st, content), st, width) } // buildDashParamsPanel 构建左侧任务参数面板。 func buildDashParamsPanel(d *DashboardState, rs *server.RunState, st Styles, maxH, width int) string { - var lines []string - lines = append(lines, " "+st.SectionHead.Render("运行进度")) - lines = append(lines, "") + lines := panelTitleLines(st, "运行进度", width, false) if rs == nil { lines = append(lines, " "+st.Muted.Render("等待数据...")) @@ -235,17 +244,12 @@ func buildDashParamsPanel(d *DashboardState, rs *server.RunState, st Styles, max lines = append(lines, " "+labelValue(st, "失败", fmt.Sprintf("%d", rs.FailedReqs))) } - for len(lines) < maxH { - lines = append(lines, "") - } - return strings.Join(lines[:maxH], "\n") + return finishPanelLines(lines, maxH) } // buildDashMetricsPanel 构建右侧实时指标面板。 func buildDashMetricsPanel(rs *server.RunState, st Styles, maxH, width int) string { - var lines []string - lines = append(lines, " "+st.SectionHead.Render("实时指标")) - lines = append(lines, "") + lines := panelTitleLines(st, "实时指标", width, false) if rs == nil { lines = append(lines, " "+st.Muted.Render("等待数据...")) @@ -261,10 +265,7 @@ func buildDashMetricsPanel(rs *server.RunState, st Styles, maxH, width int) stri lines = append(lines, " "+st.Muted.Render(fmt.Sprintf(" 成功: %d 失败: %d", rs.SuccessReqs, rs.FailedReqs))) } - for len(lines) < maxH { - lines = append(lines, "") - } - return strings.Join(lines[:maxH], "\n") + return finishPanelLines(lines, maxH) } // buildProgressLine 构建进度条行。 @@ -303,8 +304,7 @@ func buildProgressLine(rs *server.RunState, st Styles, width int) string { // buildRequestList 构建请求列表区域。 func buildRequestList(d *DashboardState, rs *server.RunState, st Styles, width, maxH int) string { - var lines []string - lines = append(lines, " "+st.SectionHead.Render("请求列表")) + lines := panelTitleLines(st, "请求列表", width, true) if rs == nil || len(rs.Requests) == 0 { msg := "等待请求..." @@ -312,10 +312,7 @@ func buildRequestList(d *DashboardState, rs *server.RunState, st Styles, width, msg = "无请求详情数据" } lines = append(lines, " "+st.Muted.Render(msg)) - for len(lines) < maxH { - lines = append(lines, "") - } - return strings.Join(lines, "\n") + return finishPanelLines(lines, maxH) } // 列宽(header 与 content 行保持一致,前缀均为 2 字符) @@ -384,10 +381,7 @@ func buildRequestList(d *DashboardState, rs *server.RunState, st Styles, width, } } - for len(lines) < maxH { - lines = append(lines, "") - } - return strings.Join(lines[:maxH], "\n") + return finishPanelLines(lines, maxH) } func requestDisplayPos(reqIndex, total int) int { diff --git a/internal/tui/pages/helpers.go b/internal/tui/pages/helpers.go index 779f637..e041258 100644 --- a/internal/tui/pages/helpers.go +++ b/internal/tui/pages/helpers.go @@ -130,88 +130,344 @@ func fmtRelativeTime(t time.Time) string { // ─── 布局工具 ───────────────────────────────────────────────────────────────── -// renderHeader 渲染顶部双行标题栏。 -// 第一行:◆ brand(粉色)│ 页面标题(青色),深色背景 -// 第二行:infoLeft(左)/ infoRight(右),较暗色背景 -func renderHeader(st Styles, width int, titleLeft, titleRight, infoLeft, infoRight string) string { +// AppVersion 是展示在 AppHeader 中的版本字符串,由 SetAppVersion 在启动时设置。 +var AppVersion = "dev" + +// SetAppVersion 更新 AppHeader 中展示的版本字符串。 +func SetAppVersion(v string) { AppVersion = v } + +// renderHeader 渲染统一 AppHeader(三行)。 +// 第一行:AIT ASCII 字符画 + 页面标题 | meta 徽章 +// 第二行:AIT 字符画第二行 + 子标题 | 版本徽章 +// 第三行:AIT 字符画第三行 + 左信息 chips | 右信息 chips +func renderHeader(st Styles, width int, title, subtitle, meta string, infoLeft, infoRight []string) string { w := width if w < 1 { w = 80 } - // Line 1: avoid nested Render() fragments to prevent ANSI reset from breaking background. - brand := titleLeft - pageTitle := "" - if idx := strings.Index(titleLeft, " "); idx >= 0 { - brand = titleLeft[:idx] - pageTitle = strings.TrimSpace(titleLeft[idx:]) - } - - brandSeg := lipgloss.NewStyle(). - Background(colorHeaderBg). - Foreground(colorPink). - Bold(true). - Render(" ◆ " + brand) - sepSeg := lipgloss.NewStyle(). - Background(colorHeaderBg). - Foreground(colorDivider). - Render(" │ ") - titleSeg := lipgloss.NewStyle(). - Background(colorHeaderBg). - Foreground(colorCyan). - Bold(true). - Render(pageTitle) - - left1 := brandSeg - if pageTitle != "" { - left1 += sepSeg + titleSeg + title = truncate(strings.TrimSpace(title), maxInt(12, w/2)) + subtitle = truncate(strings.TrimSpace(subtitle), maxInt(16, w/2)) + meta = truncate(strings.TrimSpace(meta), maxInt(10, w/4)) + if title == "" { + title = "AIT" + } + + // AIT ASCII 字符画(三行,实心彩色渐变) + artA := [3]string{" ██ ", " █ █ ", "██████"} // 可视宽 6 + artI := [3]string{" █ ", " █ ", " █ "} // 可视宽 3 + artT := [3]string{"█████", " █ ", " █ "} // 可视宽 5 + + styleA := lipgloss.NewStyle().Foreground(colorPink).Bold(true) + styleI := lipgloss.NewStyle().Foreground(colorGold).Bold(true) + styleT := lipgloss.NewStyle().Foreground(colorCyan).Bold(true) + styleSep := lipgloss.NewStyle().Foreground(colorDivider) + + // artVisW = 6+1+3+1+5 = 16; plus " │ " (3) + leading space (1) = total art prefix 20 + artVisW := 6 + 1 + 3 + 1 + 5 + artSepW := artVisW + 1 + 3 // art + " " + "│" + " " + + artRow := func(i int) string { + return styleA.Render(artA[i]) + " " + styleI.Render(artI[i]) + " " + styleT.Render(artT[i]) } + vsep := styleSep.Render("│") + + wideEnough := w >= 48 // 宽屏才展示 ASCII art + + // ── Line 1: [art row 0] │ title [meta badge] ──────── right1 := "" - if titleRight != "" { + if meta != "" { right1 = lipgloss.NewStyle(). - Background(colorHeaderBg). - Foreground(colorMuted). - Render(titleRight + " ") - } - left1W := lipgloss.Width(left1) - right1W := lipgloss.Width(right1) - pad1 := w - left1W - right1W - if pad1 < 0 { - pad1 = 0 - } - padSeg := lipgloss.NewStyle(). - Background(colorHeaderBg). - Render(strings.Repeat(" ", pad1)) - line1 := left1 + padSeg + right1 - - // ─ Line 2: info bar ─ - il := " " + infoLeft - ir := infoRight + " " - ilW := lipgloss.Width(il) - irW := lipgloss.Width(ir) - pad2 := w - ilW - irW - if pad2 < 0 { - pad2 = 0 - } - line2 := st.HeaderInfo.Width(w).Render(il + strings.Repeat(" ", pad2) + ir) + Background(colorCyan).Foreground(colorHeaderBg).Bold(true).Padding(0, 1). + Render(meta) + " " + } + var left1 string + if wideEnough { + titleSeg := lipgloss.NewStyle().Foreground(colorWhite).Bold(true).Render(title) + left1 = " " + artRow(0) + " " + vsep + " " + titleSeg + } else { + brandPill := lipgloss.NewStyle().Background(colorPink).Foreground(colorHeaderBg).Bold(true).Padding(0, 1).Render("AIT") + titlePill := lipgloss.NewStyle().Foreground(colorWhite).Bold(true).Padding(0, 1).Render(title) + left1 = " " + lipgloss.JoinHorizontal(lipgloss.Center, brandPill, titlePill) + } + line1 := renderChromeLine(st.Header, w, left1, right1) + + // ── Line 2: [art row 1] │ subtitle [version badge] ────── + verBadge := lipgloss.NewStyle(). + Background(colorPurple).Foreground(colorWhite).Padding(0, 1). + Render("v" + AppVersion) + " " + var left2 string + if wideEnough { + subSeg := "" + if subtitle != "" { + subSeg = lipgloss.NewStyle(). + Background(colorHotkeysPrimaryBg).Foreground(colorWhite).Padding(0, 1). + Render(subtitle) + } + left2 = " " + artRow(1) + " " + vsep + " " + subSeg + } else { + if subtitle != "" { + left2 = " " + lipgloss.NewStyle(). + Background(colorHotkeysPrimaryBg).Foreground(colorWhite).Padding(0, 1). + Render(subtitle) + } + } + line2 := renderChromeLine(st.Header, w, left2, verBadge) + + // ── Line 3: [art row 2] │ infoLeft chips infoRight chips ─── + var left3 string + if wideEnough { + artPart := " " + artRow(2) + " " + vsep + availW := maxInt(8, w-artSepW-2-maxInt(10, w/3)) + if leftPills := renderInfoPills(infoLeft, availW); leftPills != "" { + left3 = artPart + " " + leftPills + } else { + left3 = artPart + } + } else { + if leftPills := renderInfoPills(infoLeft, maxInt(8, w/3)); leftPills != "" { + left3 = " " + leftPills + } + } + right3 := "" + if pills := renderInfoPills(infoRight, maxInt(10, w/3)); pills != "" { + right3 = pills + " " + } + line3 := renderChromeLine(st.HeaderInfo, w, left3, right3) - return line1 + "\n" + line2 + return line1 + "\n" + line2 + "\n" + line3 } -// renderFooter 渲染底部状态栏(单行,深色背景)。 -func renderFooter(st Styles, width int, parts ...string) string { +// renderHotkeys 渲染统一页面 Hotkeys。 +// 第一行展示当前页快捷操作,第二行展示返回/退出等全局上下文与应用标识。 +func renderHotkeys(st Styles, width int, hk PageHotkeys) string { w := width if w < 1 { w = 80 } + + hkLine := renderPrimaryHotkeyItems(hk.Hotkeys, maxInt(8, w-4)) + line1 := renderChromeLine(st.HotkeysPrimary, w, " "+hkLine, "") + + appStamp := lipgloss.NewStyle().Foreground(colorPink).Bold(true).Render("AIT") + + lipgloss.NewStyle().Foreground(colorMuted).Render(" 终端 · "+time.Now().Format("15:04")) + left2 := renderSecondaryHotkeyItems(hk.Hints, maxInt(8, w-lipgloss.Width(appStamp)-4)) + line2 := renderChromeLine(st.HotkeysSecondary, w, " "+left2, appStamp+" ") + + return line1 + "\n" + line2 +} + +func renderChromeLine(base lipgloss.Style, width int, left, right string) string { + leftW := lipgloss.Width(left) + rightW := lipgloss.Width(right) + pad := width - leftW - rightW + if pad < 0 { + pad = 0 + } + return base.Width(width).Render(left + strings.Repeat(" ", pad) + right) +} + +func renderInfoPills(parts []string, maxW int) string { + parts = nonEmptyParts(parts) + if len(parts) == 0 { + return "" + } + + var rendered []string + for _, part := range parts { + rendered = append(rendered, lipgloss.NewStyle(). + Background(lipgloss.Color("239")). + Foreground(colorHeaderFg). + Padding(0, 1). + Render(truncate(part, 28))) + } + return fitRenderedParts(rendered, " ", maxW) +} + +func renderPrimaryHotkeyItems(items []HotkeyItem, maxW int) string { + if len(items) == 0 { + return lipgloss.NewStyle(). + Background(lipgloss.Color("239")). + Foreground(colorMuted). + Padding(0, 1). + Render("当前页暂无快捷操作") + } + + var rendered []string + for _, item := range items { + if item.Key == "" && item.Desc == "" { + if item.Text == "" { + continue + } + rendered = append(rendered, lipgloss.NewStyle(). + Background(lipgloss.Color("239")). + Foreground(colorWhite). + Padding(0, 1). + Render(item.Text)) + continue + } + keySeg := lipgloss.NewStyle(). + Background(colorGold). + Foreground(colorHeaderBg). + Bold(true). + Padding(0, 1). + Render(item.Key) + descSeg := lipgloss.NewStyle(). + Background(lipgloss.Color("239")). + Foreground(colorWhite). + Padding(0, 1). + Render(item.Desc) + rendered = append(rendered, lipgloss.JoinHorizontal(lipgloss.Center, keySeg, descSeg)) + } + return fitRenderedParts(rendered, " ", maxW) +} + +func renderSecondaryHotkeyItems(items []HotkeyItem, maxW int) string { + var parts []string + for _, item := range items { + text := strings.TrimSpace(item.Text) + if text == "" && (item.Key != "" || item.Desc != "") { + switch { + case item.Key != "" && item.Desc != "": + text = "[" + item.Key + "] " + item.Desc + case item.Key != "": + text = item.Key + default: + text = item.Desc + } + } + if text != "" { + parts = append(parts, text) + } + } + return fitPlainParts(parts, " • ", maxW) +} + +func fitRenderedParts(parts []string, sep string, maxW int) string { + visible := nonEmptyParts(parts) + if len(visible) == 0 { + return "" + } + if maxW <= 0 { + return strings.Join(visible, sep) + } + + var chosen []string + used := 0 + sepW := lipgloss.Width(sep) + for _, part := range visible { + partW := lipgloss.Width(part) + extra := partW + if len(chosen) > 0 { + extra += sepW + } + if len(chosen) > 0 && used+extra > maxW { + break + } + chosen = append(chosen, part) + used += extra + } + if len(chosen) == 0 { + return visible[0] + } + return strings.Join(chosen, sep) +} + +func fitPlainParts(parts []string, sep string, maxW int) string { + visible := nonEmptyParts(parts) + if len(visible) == 0 { + return "" + } + if maxW <= 0 { + return strings.Join(visible, sep) + } + + var chosen []string + used := 0 + sepW := lipgloss.Width(sep) + for _, part := range visible { + part = truncate(part, maxW) + partW := lipgloss.Width(part) + extra := partW + if len(chosen) > 0 { + extra += sepW + } + if len(chosen) > 0 && used+extra > maxW { + break + } + chosen = append(chosen, part) + used += extra + } + if len(chosen) == 0 { + return truncate(visible[0], maxW) + } + return strings.Join(chosen, sep) +} + +func nonEmptyParts(parts []string) []string { var visible []string - for _, p := range parts { - if p != "" { - visible = append(visible, p) + for _, part := range parts { + part = strings.TrimSpace(part) + if part != "" { + visible = append(visible, part) } } - line := " " + strings.Join(visible, " ") - return st.Footer.Width(w).Render(line) + return visible +} + +func runStatusText(status string) string { + switch strings.ToLower(strings.TrimSpace(status)) { + case "running": + return "运行中" + case "completed": + return "已完成" + case "failed": + return "运行失败" + case "stopped": + return "已停止" + case "": + return "等待数据" + default: + return status + } +} + +func panelTitleLines(st Styles, title string, width int, compact bool) []string { + var rendered string + if width > 0 { + rendered = st.PanelHead.Width(width).Render(" " + title) + } else { + rendered = st.PanelHead.Render(" " + title) + } + lines := []string{rendered} + if !compact { + lines = append(lines, "") + } + return lines +} + +func finishPanelLines(lines []string, maxH int) string { + if maxH < 1 { + maxH = 1 + } + if len(lines) > maxH { + lines = lines[:maxH] + } + for len(lines) < maxH { + lines = append(lines, "") + } + return strings.Join(lines[:maxH], "\n") +} + +func renderSplitPanels(st Styles, leftFrame, rightFrame PanelFrame, leftContent, rightContent string) string { + leftPanel := leftFrame.Wrap(st, leftContent) + rightPanel := rightFrame.Wrap(st, rightContent) + return lipgloss.JoinHorizontal(lipgloss.Top, leftPanel, rightPanel) +} + +func normalizeInlineText(s string) string { + replacer := strings.NewReplacer("\r", " ", "\n", " ", "\t", " ") + return strings.Join(strings.Fields(replacer.Replace(s)), " ") } // renderTableHeader 统一渲染列表表头。 @@ -300,89 +556,6 @@ func styleWhenNotSelected(isSel bool, style lipgloss.Style, text string) string return style.Render(text) } -// renderWelcomeHero 渲染任务中心顶部的品牌欢迎区。 -func renderWelcomeHero(st Styles, width int) []string { - if width < 42 { - return nil - } - - art := []string{ - " _ ___ _____", - " / \\ |_ _|_ _|", - " / _ \\ | | | | ", - " / ___ \\ | | | | ", - "/_/ \\_\\___| |_| ", - } - artStyles := []lipgloss.Style{ - lipgloss.NewStyle().Foreground(colorPink).Bold(true), - lipgloss.NewStyle().Foreground(colorCyan).Bold(true), - lipgloss.NewStyle().Foreground(colorGold).Bold(true), - lipgloss.NewStyle().Foreground(colorTeal).Bold(true), - lipgloss.NewStyle().Foreground(colorPurple).Bold(true), - } - - type heroTextLine struct { - style lipgloss.Style - text string - } - intro := []heroTextLine{ - {style: st.SectionHead, text: "AI 模型性能测试工作台"}, - {style: st.Value, text: "批量压测 OpenAI / Anthropic 协议模型,聚焦 TTFT、TPS、缓存与网络指标。"}, - {style: st.Muted, text: "从任务中心出发:创建任务、直接运行、查看执行记录、导出报告。"}, - {style: st.Muted, text: "[a] 新建任务 [Enter] 查看详情/进入仪表盘 [r] 立即运行"}, - } - - artW := 0 - for _, line := range art { - artW = maxInt(artW, lipgloss.Width(line)) - } - - if width >= 76 { - gap := 3 - rightW := maxInt(18, width-artW-gap) - wrapped := make([]string, 0, 8) - for i, line := range intro { - segments := wrapText(line.text, rightW) - if len(segments) == 0 { - segments = []string{""} - } - for _, segment := range segments { - wrapped = append(wrapped, line.style.Render(segment)) - } - if i == 0 { - wrapped = append(wrapped, "") - } - } - - total := maxInt(len(art), len(wrapped)) - lines := make([]string, 0, total) - for i := 0; i < total; i++ { - left := strings.Repeat(" ", artW) - if i < len(art) { - left = artStyles[i].Render(art[i]) - } - right := "" - if i < len(wrapped) { - right = wrapped[i] - } - lines = append(lines, padRight(left, artW)+strings.Repeat(" ", gap)+right) - } - return lines - } - - lines := make([]string, 0, len(art)+len(intro)+1) - for i, line := range art { - lines = append(lines, artStyles[i].Render(line)) - } - lines = append(lines, "") - for _, line := range intro { - for _, segment := range wrapText(line.text, width) { - lines = append(lines, line.style.Render(segment)) - } - } - return lines -} - // wrapIndex 循环索引(保证 0 ≤ result < count)。 func wrapIndex(idx, count int) int { if count <= 0 { diff --git a/internal/tui/pages/layout.go b/internal/tui/pages/layout.go index 324f123..5dcdcb5 100644 --- a/internal/tui/pages/layout.go +++ b/internal/tui/pages/layout.go @@ -7,11 +7,11 @@ import "strings" const ( // MinWidth / MinHeight:低于此值时显示"窗口过小"提示而非正常页面。 MinWidth = 40 - MinHeight = 10 + MinHeight = 12 - // chrome 各组成部分的行数(仅保留单条合并底栏) - chromeHeaderH = 0 // 顶部 header 已移除 - chromeFooterH = 1 // 单行底部状态栏(含上下文操作 + 全局导航,已合并) + // chrome 各组成部分的行数(三行 AppHeader + 双行 Hotkeys) + chromeHeaderH = 3 + chromeHotkeysH = 2 // panelBorderV 是单个面板的上下边框行数之和。 panelBorderV = 2 @@ -19,16 +19,114 @@ const ( // ── PageLayout ──────────────────────────────────────────────────────────────── -// PageLayout 描述一个完整页面的 chrome(底部 ContextBar + Footer)。 -// 各页面 Render 函数先构造 PageLayout,再调用 Assemble 拼装最终输出。 +// PageLayout 描述一个完整页面的共享 chrome(AppHeader + Hotkeys)。 +// 各页面 Render 函数只提供标题、状态信息和底部 Hotkeys,Assemble 负责统一拼装。 type PageLayout struct { - CtxItems []ContextBarItem - FooterParts []string + HeaderTitle string + HeaderSubtitle string + HeaderMeta string + HeaderInfoLeft []string + HeaderInfoRight []string + Hotkeys PageHotkeys } -// ChromeHeight 返回 chrome 占用的总行数(当前仅包含合并底栏)。 +// PageHotkeys 描述页面底部统一的 Hotkeys 区域。 +// Hotkeys 用于当前页快捷操作,Hints 用于返回、退出等全局提示。 +type PageHotkeys struct { + Hotkeys []HotkeyItem + Hints []HotkeyItem +} + +// NewPageHotkeys 用于构建统一的页面 Hotkeys。 +func NewPageHotkeys(hotkeys []HotkeyItem, hints ...string) PageHotkeys { + return PageHotkeys{ + Hotkeys: hotkeys, + Hints: HotkeyTexts(hints...), + } +} + +// PageFrame 描述页面主内容区的统一尺寸。 +// OuterWidth 是最外层内容面板总宽度,InnerWidth/InnerHeight 是面板内部可用区域。 +type PageFrame struct { + OuterWidth int + InnerWidth int + InnerHeight int +} + +// PanelFrame 描述嵌套子面板的统一尺寸。 +// OuterWidth 是含边框总宽度,InnerWidth 是面板内可用宽度。 +type PanelFrame struct { + OuterWidth int + InnerWidth int +} + +// ChromeHeight 返回 chrome 占用的总行数。 func (l PageLayout) ChromeHeight() int { - return chromeHeaderH + chromeFooterH + return chromeHeaderH + chromeHotkeysH +} + +// Frame 统一计算页面主内容区的外层与内层尺寸。 +func (l PageLayout) Frame(totalW, totalH int) PageFrame { + if totalW < 1 { + totalW = 1 + } + return PageFrame{ + OuterWidth: totalW, + InnerWidth: ContentWidth(totalW), + InnerHeight: l.ContentHeight(totalH), + } +} + +// Wrap 用统一外层面板包裹页面内容。 +func (f PageFrame) Wrap(st Styles, content string) string { + return wrapPanel(st, content, f.OuterWidth) +} + +// InnerPanel 返回可用于嵌套子面板的统一 frame。 +func (f PageFrame) InnerPanel() PanelFrame { + return NewPanelFrame(f.InnerWidth) +} + +// NewPanelFrame 创建一个统一的子面板尺寸描述。 +func NewPanelFrame(outerW int) PanelFrame { + if outerW < 1 { + outerW = 1 + } + return PanelFrame{OuterWidth: outerW, InnerWidth: ContentWidth(outerW)} +} + +// Wrap 用统一子面板包裹内容。 +func (f PanelFrame) Wrap(st Styles, content string) string { + return wrapPanel(st, content, f.OuterWidth) +} + +// Split 按比例拆分左右子面板宽度,避免各页重复手写宽度和最小值逻辑。 +func (f PanelFrame) Split(leftPercent, minLeftOuter int) (PanelFrame, PanelFrame) { + total := f.OuterWidth + if total <= 1 { + return NewPanelFrame(total), NewPanelFrame(1) + } + if leftPercent <= 0 || leftPercent >= 100 { + leftPercent = 50 + } + if minLeftOuter < 1 { + minLeftOuter = 1 + } + + leftOuter := total * leftPercent / 100 + if leftOuter < minLeftOuter { + leftOuter = minLeftOuter + } + if leftOuter >= total { + leftOuter = total - 1 + } + rightOuter := total - leftOuter + if rightOuter < 1 { + rightOuter = 1 + leftOuter = total - rightOuter + } + + return NewPanelFrame(leftOuter), NewPanelFrame(rightOuter) } // ContentHeight 返回单面板页面主内容区的可用行数 @@ -51,23 +149,46 @@ func ContentWidth(totalW int) int { return w } -// Assemble 拼装完整页面输出: -// -// content -// 底栏(上下文操作 · 全局导航,合并为单行) -func (l PageLayout) Assemble(content string, st Styles, width int) string { - // 将上下文操作与全局导航合并为单条底栏,用 · 分隔 - var barParts []string - for _, item := range l.CtxItems { - barParts = append(barParts, "["+item.Key+"] "+item.Desc) +// PanelContentHeight 将含边框高度转换为子面板可用内容高度。 +func PanelContentHeight(outerH int) int { + h := outerH - panelBorderV + if h < 1 { + h = 1 } - if len(l.CtxItems) > 0 && len(l.FooterParts) > 0 { - barParts = append(barParts, "·") + return h +} + +// RemainingStackOuterHeight 计算纵向堆叠场景下,最后一个区块可用的外层高度。 +// 会统一扣除前置区块自身高度,以及区块之间的换行间隔,避免各页重复手写偏移逻辑。 +func RemainingStackOuterHeight(totalH int, fixedOuterHeights ...int) int { + remaining := totalH - len(fixedOuterHeights) + for _, h := range fixedOuterHeights { + remaining -= h + } + minOuterHeight := panelBorderV + 1 + if remaining < minOuterHeight { + remaining = minOuterHeight + } + return remaining +} + +// joinVerticalBlocks 统一拼接纵向区块,避免各页自行处理空块和换行。 +func joinVerticalBlocks(blocks ...string) string { + var visible []string + for _, block := range blocks { + if block != "" { + visible = append(visible, block) + } } - barParts = append(barParts, l.FooterParts...) - footer := renderFooter(st, width, barParts...) + return strings.Join(visible, "\n") +} + +// Assemble 拼装完整页面输出:header + content + hotkeys。 +func (l PageLayout) Assemble(content string, st Styles, width int) string { + header := renderHeader(st, width, l.HeaderTitle, l.HeaderSubtitle, l.HeaderMeta, l.HeaderInfoLeft, l.HeaderInfoRight) + hotkeys := renderHotkeys(st, width, l.Hotkeys) - return strings.Join([]string{content, footer}, "\n") + return strings.Join([]string{header, content, hotkeys}, "\n") } // ── 最小尺寸保护 ────────────────────────────────────────────────────────────── diff --git a/internal/tui/pages/layout_test.go b/internal/tui/pages/layout_test.go new file mode 100644 index 0000000..a28ed87 --- /dev/null +++ b/internal/tui/pages/layout_test.go @@ -0,0 +1,65 @@ +package pages + +import ( + "strings" + "testing" +) + +func TestPageLayoutAssembleRendersSharedChrome(t *testing.T) { + st := NewStyles() + l := PageLayout{ + HeaderTitle: "任务中心", + HeaderSubtitle: "创建任务、运行压测、查看执行记录与导出报告", + HeaderMeta: "2 个任务", + HeaderInfoLeft: []string{"运行中 1"}, + HeaderInfoRight: []string{"最近运行 1 分钟前"}, + Hotkeys: NewPageHotkeys([]HotkeyItem{ + HotkeyAction("Enter", "查看详情"), + }, "[q] 退出"), + } + + rendered := stripANSI(l.Assemble("content", st, 80)) + lines := strings.Split(rendered, "\n") + if len(lines) < 6 { + t.Fatalf("expected shared chrome to add header/hotkeys lines, got %d lines", len(lines)) + } + if !strings.Contains(rendered, "AIT") || !strings.Contains(rendered, "任务中心") { + t.Fatalf("expected header brand/title in output: %q", rendered) + } + if !strings.Contains(rendered, "创建任务、运行压测") { + t.Fatalf("expected header subtitle in output: %q", rendered) + } + if !strings.Contains(rendered, "查看详情") || !strings.Contains(rendered, "[q] 退出") { + t.Fatalf("expected hotkey actions and global hints in output: %q", rendered) + } +} + +func TestPageLayoutFrameCalculatesNestedPanelSizes(t *testing.T) { + l := PageLayout{} + frame := l.Frame(80, 30) + if frame.OuterWidth != 80 || frame.InnerWidth != 78 || frame.InnerHeight != 23 { + t.Fatalf("unexpected page frame: %#v", frame) + } + + body := frame.InnerPanel() + if body.OuterWidth != 78 || body.InnerWidth != 76 { + t.Fatalf("unexpected inner panel frame: %#v", body) + } +} + +func TestRemainingStackOuterHeightAccountsForJoinGaps(t *testing.T) { + totalHeight := 24 + remaining := RemainingStackOuterHeight(totalHeight, 9, 3) + if remaining != 10 { + t.Fatalf("expected remaining outer height 10, got %d", remaining) + } + + used := 9 + 1 + 3 + 1 + remaining + if used != totalHeight { + t.Fatalf("expected stacked blocks to fit exactly, used %d of %d", used, totalHeight) + } + + if PanelContentHeight(remaining) != 8 { + t.Fatalf("expected remaining content height 8, got %d", PanelContentHeight(remaining)) + } +} diff --git a/internal/tui/pages/reqdetail.go b/internal/tui/pages/reqdetail.go index c8d2eaa..cb92203 100644 --- a/internal/tui/pages/reqdetail.go +++ b/internal/tui/pages/reqdetail.go @@ -2,10 +2,8 @@ package pages import ( "fmt" - "strings" tea "github.com/charmbracelet/bubbletea" - "github.com/charmbracelet/lipgloss" "github.com/yinxulai/ait/internal/server" "github.com/yinxulai/ait/internal/types" ) @@ -117,94 +115,104 @@ func RenderReqDetail(s *ReqDetailState, taskName string, st Styles, width, heigh idx = len(s.Requests) - 1 } r := s.Requests[idx] + status := "失败" + if r.Success { + status = "成功" + } l := PageLayout{ - CtxItems: CtxBar_ReqDetail(), - FooterParts: []string{"[q] 退出"}, + HeaderTitle: "请求详情", + HeaderSubtitle: "查看单次请求的耗时、网络阶段和完整报文", + HeaderMeta: truncate(string(s.RunID), 18), + HeaderInfoLeft: []string{fmt.Sprintf("请求 %d/%d", idx+1, len(s.Requests)), status}, + HeaderInfoRight: []string{fmt.Sprintf("缓存 %.0f%%", r.CacheHitRate*100), "耗时 " + fmtDuration(r.TotalTime)}, + Hotkeys: NewPageHotkeys(Hotkeys_ReqDetail(), "[b/Esc] 返回上一页", "[q] 退出"), } - innerW := ContentWidth(width) - innerH := l.ContentHeight(height) + frame := l.Frame(width, height) + bodyPanel := frame.InnerPanel() // ── 计算高度 ── - splitH := 9 + splitOuterH := 9 inputH := 5 - outputH := innerH - splitH - inputH - 2 - 2 // -2 for input panel border, -2 for output panel border - if outputH < 4 { - outputH = 4 - } + inputOuterH := inputH + panelBorderV + outputOuterH := RemainingStackOuterHeight(frame.InnerHeight, splitOuterH, inputOuterH) + outputH := PanelContentHeight(outputOuterH) // ── 双栏面板(性能指标 | 网络指标)── - leftW := innerW * 50 / 100 - rightW := innerW - leftW - leftContent := buildReqPerfPanel(r, st, splitH-2, leftW-2) - rightContent := buildReqNetworkPanel(r, st, splitH-2, rightW-2) - leftPanel := st.Panel.Width(leftW - 2).Render(leftContent) - rightPanel := st.Panel.Width(rightW - 2).Render(rightContent) - split := lipgloss.JoinHorizontal(lipgloss.Top, leftPanel, rightPanel) + leftPanelFrame, rightPanelFrame := bodyPanel.Split(50, 24) + leftContent := buildReqPerfPanel(r, st, PanelContentHeight(splitOuterH), leftPanelFrame.InnerWidth) + rightContent := buildReqNetworkPanel(r, st, PanelContentHeight(splitOuterH), rightPanelFrame.InnerWidth) + split := renderSplitPanels(st, leftPanelFrame, rightPanelFrame, leftContent, rightContent) // ── 输入区面板 ── - inputSection := buildInputSection(r, st, ContentWidth(innerW), inputH) - inputPanelStr := wrapPanel(st, inputSection, innerW) + inputSection := buildInputSection(r, st, bodyPanel.InnerWidth, inputH) + inputPanelStr := bodyPanel.Wrap(st, inputSection) // ── 输出区面板 ── - outputSection := buildOutputSection(r, s.ScrollY, st, ContentWidth(innerW), outputH) - outputPanelStr := wrapPanel(st, outputSection, innerW) + outputSection := buildOutputSection(r, s.ScrollY, st, bodyPanel.InnerWidth, outputH) + outputPanelStr := bodyPanel.Wrap(st, outputSection) - content := strings.Join([]string{split, inputPanelStr, outputPanelStr}, "\n") - return l.Assemble(wrapPanel(st, content, width), st, width) + content := joinVerticalBlocks(split, inputPanelStr, outputPanelStr) + return l.Assemble(frame.Wrap(st, content), st, width) } // buildReqPerfPanel 构建请求左侧性能指标面板。 func buildReqPerfPanel(r *types.RequestMetrics, st Styles, maxH, width int) string { - var lines []string - lines = append(lines, " "+st.SectionHead.Render("性能指标")) - lines = append(lines, "") + lines := panelTitleLines(st, "性能指标", width, true) if r == nil { lines = append(lines, " "+st.Muted.Render("等待数据...")) - for len(lines) < maxH { - lines = append(lines, "") - } - return strings.Join(lines[:maxH], "\n") + return finishPanelLines(lines, maxH) } statusStr := st.Ok.Render("✓ 成功") if !r.Success { statusStr = st.ErrStyle.Render("✗ 失败") } - lines = append(lines, " "+labelValue(st, "状态 ", statusStr)) + totalTime := "─" + if r.TotalTime > 0 { + totalTime = fmtDuration(r.TotalTime) + } + ttft := "─" + if r.TTFT > 0 { + ttft = fmtDuration(r.TTFT) + } + tps := "─" + if r.TPS > 0 { + tps = fmt.Sprintf("%.1f tok/s", r.TPS) + } + tokenSummary := fmt.Sprintf("%d in / %d out", r.PromptTokens, r.CompletionTokens) + cacheSummary := fmt.Sprintf("%d tok (%.1f%%)", r.CachedTokens, r.CacheHitRate*100) + errorSummary := "—" + if !r.Success { + errorSummary = normalizeInlineText(r.ErrorMessage) + if errorSummary == "" { + errorSummary = "请求失败" + } + errorSummary = truncate(errorSummary, maxInt(8, width-8)) + } + lines = append(lines, " "+labelValue(st, "状态 ", statusStr)) + lines = append(lines, " "+labelValue(st, "总耗时 ", st.MetricVal.Render(totalTime))) + lines = append(lines, " "+labelValue(st, "TTFT ", st.MetricVal.Render(ttft))) + lines = append(lines, " "+labelValue(st, "输出TPS ", st.MetricVal.Render(tps))) + lines = append(lines, " "+labelValue(st, "令牌 ", tokenSummary)) if r.Success { - lines = append(lines, " "+labelValue(st, "总耗时 ", st.MetricVal.Render(fmtDuration(r.TotalTime)))) - lines = append(lines, " "+labelValue(st, "TTFT ", st.MetricVal.Render(fmtDuration(r.TTFT)))) - lines = append(lines, " "+labelValue(st, "输出TPS ", st.MetricVal.Render(fmt.Sprintf("%.1f tok/s", r.TPS)))) - lines = append(lines, " "+labelValue(st, "输入Token", fmt.Sprintf("%d", r.PromptTokens))) - lines = append(lines, " "+labelValue(st, "输出Token", fmt.Sprintf("%d", r.CompletionTokens))) - lines = append(lines, " "+labelValue(st, "缓存命中", fmt.Sprintf("%d tok (%.1f%%)", r.CachedTokens, r.CacheHitRate*100))) + lines = append(lines, " "+labelValue(st, "缓存 ", cacheSummary)) } else { - if r.ErrorMessage != "" { - lines = append(lines, " "+st.ErrStyle.Render("错误: "+truncate(r.ErrorMessage, width-8))) - } + lines = append(lines, " "+st.ErrStyle.Render("错误: "+errorSummary)) } - for len(lines) < maxH { - lines = append(lines, "") - } - return strings.Join(lines[:maxH], "\n") + return finishPanelLines(lines, maxH) } // buildReqNetworkPanel 构建请求右侧网络指标面板。 func buildReqNetworkPanel(r *types.RequestMetrics, st Styles, maxH, width int) string { - var lines []string - lines = append(lines, " "+st.SectionHead.Render("网络指标")) - lines = append(lines, "") + lines := panelTitleLines(st, "网络指标", width, true) if r == nil { lines = append(lines, " "+st.Muted.Render("等待数据...")) - for len(lines) < maxH { - lines = append(lines, "") - } - return strings.Join(lines[:maxH], "\n") + return finishPanelLines(lines, maxH) } lines = append(lines, " "+labelValue(st, "DNS ", fmtDuration(r.DNSTime))) @@ -214,16 +222,12 @@ func buildReqNetworkPanel(r *types.RequestMetrics, st Styles, maxH, width int) s lines = append(lines, " "+labelValue(st, "目标 IP ", r.TargetIP)) } - for len(lines) < maxH { - lines = append(lines, "") - } - return strings.Join(lines[:maxH], "\n") + return finishPanelLines(lines, maxH) } // buildInputSection 构建输入 (请求体) 区域。 func buildInputSection(r *types.RequestMetrics, st Styles, width, maxH int) string { - var lines []string - lines = append(lines, " "+st.SectionHead.Render("请求体 (Request Body)")) + lines := panelTitleLines(st, "请求体 (Request Body)", width, true) lines = append(lines, " "+dividerLine(st, width-2)) if r.RequestBody == "" { @@ -237,16 +241,12 @@ func buildInputSection(r *types.RequestMetrics, st Styles, width, maxH int) stri } } - for len(lines) < maxH { - lines = append(lines, "") - } - return strings.Join(lines[:maxH], "\n") + return finishPanelLines(lines, maxH) } // buildOutputSection 构建输出 (响应体) 区域。 func buildOutputSection(r *types.RequestMetrics, scrollY int, st Styles, width, maxH int) string { - var lines []string - lines = append(lines, " "+st.SectionHead.Render("响应体 (Response Body)")) + lines := panelTitleLines(st, "响应体 (Response Body)", width, true) lines = append(lines, " "+dividerLine(st, width-2)) if r.ResponseBody == "" { @@ -270,8 +270,5 @@ func buildOutputSection(r *types.RequestMetrics, scrollY int, st Styles, width, } } - for len(lines) < maxH { - lines = append(lines, "") - } - return strings.Join(lines[:maxH], "\n") + return finishPanelLines(lines, maxH) } diff --git a/internal/tui/pages/reqdetail_test.go b/internal/tui/pages/reqdetail_test.go new file mode 100644 index 0000000..39e670d --- /dev/null +++ b/internal/tui/pages/reqdetail_test.go @@ -0,0 +1,56 @@ +package pages + +import ( + "strings" + "testing" + "time" + + "github.com/yinxulai/ait/internal/server" + "github.com/yinxulai/ait/internal/types" +) + +func TestRenderReqDetailKeepsSameHeightForSuccessAndFailure(t *testing.T) { + st := NewStyles() + success := &ReqDetailState{ + RunID: server.RunID("run_success"), + Requests: []*types.RequestMetrics{{ + Success: true, + TotalTime: 250 * time.Millisecond, + TTFT: 80 * time.Millisecond, + TPS: 12.5, + PromptTokens: 64, + CompletionTokens: 128, + CachedTokens: 32, + CacheHitRate: 0.5, + DNSTime: 2 * time.Millisecond, + ConnectTime: 3 * time.Millisecond, + TLSTime: 4 * time.Millisecond, + TargetIP: "1.2.3.4", + RequestBody: "hello", + ResponseBody: strings.Repeat("ok ", 50), + }}, + } + failure := &ReqDetailState{ + RunID: server.RunID("run_failure"), + Requests: []*types.RequestMetrics{{ + Success: false, + TotalTime: 250 * time.Millisecond, + DNSTime: 2 * time.Millisecond, + ConnectTime: 3 * time.Millisecond, + TLSTime: 4 * time.Millisecond, + RequestBody: "hello", + ResponseBody: "", + ErrorMessage: "dial tcp:\nlookup api.example.com: no such host", + }}, + } + + successLines := strings.Split(stripANSI(RenderReqDetail(success, "示例任务", st, 96, 30)), "\n") + failureLines := strings.Split(stripANSI(RenderReqDetail(failure, "示例任务", st, 96, 30)), "\n") + if len(successLines) != len(failureLines) { + t.Fatalf("expected success/failure render heights to match, got %d vs %d", len(successLines), len(failureLines)) + } + + if strings.Contains(stripANSI(RenderReqDetail(failure, "示例任务", st, 96, 30)), "dial tcp:\n") { + t.Fatalf("expected failure error summary to be normalized into a single visual line") + } +} diff --git a/internal/tui/pages/styles.go b/internal/tui/pages/styles.go index 0b497b2..6551e71 100644 --- a/internal/tui/pages/styles.go +++ b/internal/tui/pages/styles.go @@ -4,22 +4,22 @@ import "github.com/charmbracelet/lipgloss" // Color palette const ( - colorHeaderBg = lipgloss.Color("17") // dark navy — refined header background - colorFooterBg = lipgloss.Color("235") // near-black footer background - colorCtxBarBg = lipgloss.Color("237") // slightly lighter than footer — context bar - colorPink = lipgloss.Color("205") // vivid hot pink/magenta — brand primary - colorCyan = lipgloss.Color("86") // bright aquamarine — table headers - colorPurple = lipgloss.Color("99") // medium violet - colorPurpleDim = lipgloss.Color("60") // slate purple — selected row bg - colorGreen = lipgloss.Color("78") // vivid spring green — ok/success - colorRed = lipgloss.Color("204") // vivid rose-red — error/fail - colorYellow = lipgloss.Color("221") // warm yellow — metric values - colorTeal = lipgloss.Color("111") // periwinkle-teal — labels - colorWhite = lipgloss.Color("255") // bright white - colorMuted = lipgloss.Color("245") // muted gray - colorGold = lipgloss.Color("214") // amber - colorHeaderFg = lipgloss.Color("248") // light gray — header info text - colorDivider = lipgloss.Color("241") // dim border gray — slightly more visible + colorHeaderBg = lipgloss.Color("17") // dark navy — refined header background + colorHotkeysSecondaryBg = lipgloss.Color("235") // near-black secondary hotkeys background + colorHotkeysPrimaryBg = lipgloss.Color("237") // slightly lighter primary hotkeys background + colorPink = lipgloss.Color("205") // vivid hot pink/magenta — brand primary + colorCyan = lipgloss.Color("86") // bright aquamarine — table headers + colorPurple = lipgloss.Color("99") // medium violet + colorPurpleDim = lipgloss.Color("60") // slate purple — selected row bg + colorGreen = lipgloss.Color("78") // vivid spring green — ok/success + colorRed = lipgloss.Color("204") // vivid rose-red — error/fail + colorYellow = lipgloss.Color("221") // warm yellow — metric values + colorTeal = lipgloss.Color("111") // periwinkle-teal — labels + colorWhite = lipgloss.Color("255") // bright white + colorMuted = lipgloss.Color("245") // muted gray + colorGold = lipgloss.Color("214") // amber + colorHeaderFg = lipgloss.Color("248") // light gray — header info text + colorDivider = lipgloss.Color("241") // dim border gray — slightly more visible ) // Styles 汇聚所有 TUI 样式,由 NewStyles() 初始化。 @@ -27,8 +27,9 @@ type Styles struct { Panel lipgloss.Style Header lipgloss.Style HeaderInfo lipgloss.Style - Footer lipgloss.Style - CtxBar lipgloss.Style + HotkeysSecondary lipgloss.Style + HotkeysPrimary lipgloss.Style + PanelHead lipgloss.Style SectionHead lipgloss.Style TableHead lipgloss.Style TableRow lipgloss.Style @@ -57,14 +58,18 @@ func NewStyles() Styles { Foreground(colorWhite). Bold(true), HeaderInfo: lipgloss.NewStyle(). - Background(colorHeaderBg). + Background(colorHotkeysPrimaryBg). Foreground(colorHeaderFg), - Footer: lipgloss.NewStyle(). - Background(colorFooterBg). + HotkeysSecondary: lipgloss.NewStyle(). + Background(colorHotkeysSecondaryBg). Foreground(colorMuted), - CtxBar: lipgloss.NewStyle(). - Background(colorCtxBarBg). + HotkeysPrimary: lipgloss.NewStyle(). + Background(colorHotkeysPrimaryBg). Foreground(colorWhite), + PanelHead: lipgloss.NewStyle(). + Background(lipgloss.Color("234")). + Foreground(colorPink). + Bold(true), SectionHead: lipgloss.NewStyle(). Foreground(colorPink). Bold(true), diff --git a/internal/tui/pages/taskdetail.go b/internal/tui/pages/taskdetail.go index be34db8..b69d3c4 100644 --- a/internal/tui/pages/taskdetail.go +++ b/internal/tui/pages/taskdetail.go @@ -161,7 +161,7 @@ func RenderTaskDetail(s *TaskDetailState, st Styles, width, height int) string { t := s.Task inp := t.Input - var cbItems []ContextBarItem + var cbItems []HotkeyItem hasActive := s.ActiveRun != nil effectiveLen := len(taskDetailHistoryEntries(s)) if hasActive { @@ -169,35 +169,48 @@ func RenderTaskDetail(s *TaskDetailState, st Styles, width, height int) string { } switch { case hasActive: - cbItems = CtxBar_TaskDetail_Running() + cbItems = Hotkeys_TaskDetail_Running() case effectiveLen > 0: - cbItems = CtxBar_TaskDetail_HasHistory() + cbItems = Hotkeys_TaskDetail_HasHistory() default: - cbItems = CtxBar_TaskDetail_NoHistory() + cbItems = Hotkeys_TaskDetail_NoHistory() + } + modeStr := "标准模式" + if inp.Turbo { + modeStr = "Turbo 模式" + } + headerRight := []string{"暂无运行记录"} + historyCount := len(taskDetailHistoryEntries(s)) + if historyCount > 0 { + headerRight = []string{fmt.Sprintf("历史 %d 条", historyCount)} + } + if hasActive { + headerRight = append([]string{"运行中"}, headerRight...) } l := PageLayout{ - CtxItems: cbItems, - FooterParts: []string{"[b/Esc] 返回上一页", "◆ AIT v0.1"}, + HeaderTitle: truncate(t.Name, 28), + HeaderSubtitle: "查看任务配置、当前运行状态与历史记录", + HeaderMeta: "任务详情", + HeaderInfoLeft: []string{modeStr, inp.NormalizedProtocol()}, + HeaderInfoRight: headerRight, + Hotkeys: NewPageHotkeys(cbItems, "[b/Esc] 返回上一页", "[q] 退出"), } + frame := l.Frame(width, height) - content := buildTaskDetailContent(s, st, t, inp, ContentWidth(width), l.ContentHeight(height)) - return l.Assemble(wrapPanel(st, content, width), st, width) + content := buildTaskDetailContent(s, st, t, inp, frame.InnerWidth, frame.InnerHeight) + return l.Assemble(frame.Wrap(st, content), st, width) } // buildTaskDetailContent 构建任务详情内容区(左右双栏布局)。 // 左栏(40%):配置摘要 右栏(60%):历史运行记录 func buildTaskDetailContent(s *TaskDetailState, st Styles, t types.TaskDefinition, inp types.Input, width, maxH int) string { - leftW := width * 4 / 10 - if leftW < 26 { - leftW = 26 - } - rightW := width - leftW - 1 // 1 列用于 │ 分隔符 + bodyPanel := NewPanelFrame(width) + leftPanelFrame, rightPanelFrame := bodyPanel.Split(40, 28) + panelContentH := PanelContentHeight(maxH) // ─── 左栏:配置摘要 ───────────────────────────────────────── - var leftLines []string - leftLines = append(leftLines, padRight(" "+st.SectionHead.Render("配置摘要"), leftW)) - leftLines = append(leftLines, padRight(st.Divider.Render(strings.Repeat("─", leftW)), leftW)) - leftLines = append(leftLines, padRight("", leftW)) + leftW := leftPanelFrame.InnerWidth + leftLines := panelTitleLines(st, "配置摘要", leftW, false) proto := inp.NormalizedProtocol() leftLines = append(leftLines, padRight(" "+st.Label.Render("协议")+" "+st.Value.Render(proto), leftW)) @@ -233,12 +246,12 @@ func buildTaskDetailContent(s *TaskDetailState, st Styles, t types.TaskDefinitio leftLines = append(leftLines, padRight(" "+st.Label.Render("流式")+" "+st.Value.Render(boolLabel(inp.Stream)), leftW)) prompt := promptSummary(inp.PromptMode, inp.PromptText, inp.PromptFile, inp.PromptLength) leftLines = append(leftLines, padRight(" "+st.Label.Render("Prompt")+" "+st.Value.Render(truncate(prompt, leftW-12)), leftW)) + leftContent := finishPanelLines(leftLines, panelContentH) // ─── 右栏:历史运行记录 ───────────────────────────────────── - var rightLines []string + rightW := rightPanelFrame.InnerWidth + rightLines := panelTitleLines(st, "历史运行记录", rightW, false) historyEntries := taskDetailHistoryEntries(s) - rightLines = append(rightLines, padRight(" "+st.SectionHead.Render("历史运行记录"), rightW)) - rightLines = append(rightLines, padRight(st.Divider.Render(strings.Repeat("─", rightW)), rightW)) hasActive := s.ActiveRun != nil effectiveLen := len(historyEntries) @@ -256,7 +269,7 @@ func buildTaskDetailContent(s *TaskDetailState, st Styles, t types.TaskDefinitio ) hdr := padRight("", markW) + padRight("", statW) + padRight("时间", timeW) + padRight("模式", modeW) + padRight("成功率", rateW) + padRight("TTFT", ttftW) + "TPS" - rightLines = append(rightLines, padRight(renderTableHeader(st, rightW, hdr), rightW)) + rightLines = append(rightLines, renderTableHeader(st, rightW, hdr)) rightLines = append(rightLines, padRight(st.Divider.Render(strings.Repeat("─", rightW)), rightW)) if effectiveLen == 0 { @@ -277,13 +290,13 @@ func buildTaskDetailContent(s *TaskDetailState, st Styles, t types.TaskDefinitio detailLines = buildTaskHistoryDetailLines(historyEntries, histIdx, st, rightW) } } - tableMaxH := maxH - len(detailLines) + tableMaxH := panelContentH - len(detailLines) if tableMaxH < 5 { - allowedDetail := maxInt(0, maxH-5) + allowedDetail := maxInt(0, panelContentH-5) if len(detailLines) > allowedDetail { detailLines = detailLines[:allowedDetail] } - tableMaxH = maxH - len(detailLines) + tableMaxH = panelContentH - len(detailLines) } s.HistoryVis = listVisibleItems(tableMaxH, 4) s.HistoryOff = ensureVisibleOffset(s.HistorySel, effectiveLen, s.HistoryOff, s.HistoryVis) @@ -354,20 +367,8 @@ func buildTaskDetailContent(s *TaskDetailState, st Styles, t types.TaskDefinitio } rightLines = append(rightLines, detailLines...) } - - // ─── 合并双栏 ────────────────────────────────────────────── - for len(leftLines) < maxH { - leftLines = append(leftLines, padRight("", leftW)) - } - for len(rightLines) < maxH { - rightLines = append(rightLines, padRight("", rightW)) - } - sep := st.Divider.Render("│") - var combined []string - for i := 0; i < maxH; i++ { - combined = append(combined, leftLines[i]+sep+rightLines[i]) - } - return strings.Join(combined, "\n") + rightContent := finishPanelLines(rightLines, panelContentH) + return renderSplitPanels(st, leftPanelFrame, rightPanelFrame, leftContent, rightContent) } // buildMetricRow 构建指标表格一行。 diff --git a/internal/tui/pages/tasklist.go b/internal/tui/pages/tasklist.go index a42b6c2..927b6da 100644 --- a/internal/tui/pages/tasklist.go +++ b/internal/tui/pages/tasklist.go @@ -138,32 +138,46 @@ func RenderTaskList(s *TaskListState, st Styles, width, height int) string { return renderTooSmall(st, width, height) } - var cbItems []ContextBarItem + var cbItems []HotkeyItem if t, ok := s.CurrentTask(); ok { if s.IsTaskRunning(t.ID) { - cbItems = CtxBar_TaskList_Running() + cbItems = Hotkeys_TaskList_Running() } else { - cbItems = CtxBar_TaskList_Normal() + cbItems = Hotkeys_TaskList_Normal() } + } else { + cbItems = []HotkeyItem{HotkeyAction("a", "新建任务")} + } + runningCount := 0 + for _, rs := range s.ActiveRuns { + if rs != nil && rs.Status == server.RunStatusRunning { + runningCount++ + } + } + headerRight := []string{"暂无运行历史"} + if latest := s.latestRunAt(); latest != nil { + headerRight = []string{"最近运行 " + fmtRelativeTime(*latest)} + } + if t, ok := s.CurrentTask(); ok { + headerRight = append([]string{"当前 " + truncate(t.Name, 22)}, headerRight...) } l := PageLayout{ - CtxItems: cbItems, - FooterParts: []string{"[↑↓] 选择", "[a] 新建", "[q] 退出", "◆ AIT v0.1"}, + HeaderTitle: "任务中心", + HeaderSubtitle: "创建任务、运行压测、查看执行记录与导出报告", + HeaderMeta: fmt.Sprintf("%d 个任务", len(s.Tasks)), + HeaderInfoLeft: []string{fmt.Sprintf("运行中 %d", runningCount)}, + HeaderInfoRight: headerRight, + Hotkeys: NewPageHotkeys(cbItems, "[↑↓] 选择", "[a] 新建", "[q] 退出"), } + frame := l.Frame(width, height) - content := buildTaskListContent(s, st, ContentWidth(width), l.ContentHeight(height)) - return l.Assemble(wrapPanel(st, content, width), st, width) + content := buildTaskListContent(s, st, frame.InnerWidth, frame.InnerHeight) + return l.Assemble(frame.Wrap(st, content), st, width) } // buildTaskListContent 构建任务列表内容区(含表头 + 任务条目)。 func buildTaskListContent(s *TaskListState, st Styles, width, maxH int) string { var lines []string - showHero := width >= 60 && maxH >= 14 - if showHero { - heroLines := renderWelcomeHero(st, width) - lines = append(lines, heroLines...) - lines = append(lines, dividerLine(st, width)) - } listTopLines := len(lines) // 列宽(gap=2 作为列间距内置到每个非末尾列的宽度中) diff --git a/internal/tui/pages/turbodash.go b/internal/tui/pages/turbodash.go index 74cdf48..b83f367 100644 --- a/internal/tui/pages/turbodash.go +++ b/internal/tui/pages/turbodash.go @@ -151,58 +151,71 @@ func RenderTurboDash(d *TurboDashState, taskName string, st Styles, width, heigh isRunning := d.IsRunning() hasSel := d.LevelSel >= 0 && rs != nil && d.LevelSel < len(rs.Levels) - var cbItems []ContextBarItem + var cbItems []HotkeyItem switch { case hasSel && isRunning: - cbItems = CtxBar_TurboDash_Running_Sel() + cbItems = Hotkeys_TurboDash_Running_Sel() case hasSel && !isRunning: - cbItems = CtxBar_TurboDash_Done_Sel() + cbItems = Hotkeys_TurboDash_Done_Sel() case !hasSel && isRunning: - cbItems = CtxBar_TurboDash_Running_NoSel() + cbItems = Hotkeys_TurboDash_Running_NoSel() default: - cbItems = CtxBar_TurboDash_Done_NoSel() + cbItems = Hotkeys_TurboDash_Done_NoSel() + } + headerLeft := []string{"等待数据"} + headerRight := []string{} + if rs != nil { + headerLeft = []string{runStatusText(string(rs.Status)), fmt.Sprintf("完成 %d/%d", rs.DoneReqs, rs.TotalReqs)} + currentLevel := rs.CurrentLevel + 1 + if currentLevel < 1 { + currentLevel = 1 + } + headerRight = []string{fmt.Sprintf("级别 %d", currentLevel)} + if len(rs.Levels) > 0 { + headerRight = append(headerRight, fmt.Sprintf("已探测 %d 档", len(rs.Levels))) + } + } + if d.TaskID != "" { + headerRight = append(headerRight, "任务 "+truncate(d.TaskID, 14)) } l := PageLayout{ - CtxItems: cbItems, - FooterParts: []string{"[b/Esc] 返回上一页", "[q] 退出"}, + HeaderTitle: "Turbo 探测监控", + HeaderSubtitle: "观察并发爬坡过程、级别指标与稳定区间", + HeaderMeta: "Turbo 模式", + HeaderInfoLeft: headerLeft, + HeaderInfoRight: headerRight, + Hotkeys: NewPageHotkeys(cbItems, "[b/Esc] 返回上一页", "[q] 退出"), } - innerW := ContentWidth(width) - innerH := l.ContentHeight(height) + frame := l.Frame(width, height) + bodyPanel := frame.InnerPanel() // ── 计算高度 ── - splitH := 9 - progressPanel := 3 - levelListH := innerH - splitH - progressPanel - 2 - if levelListH < 3 { - levelListH = 3 - } + splitOuterH := 9 + progressOuterH := 3 + levelOuterH := RemainingStackOuterHeight(frame.InnerHeight, splitOuterH, progressOuterH) + levelListH := PanelContentHeight(levelOuterH) // ── 双栏面板(任务参数 | 当前级别指标)── - leftW := innerW * 45 / 100 - rightW := innerW - leftW - leftContent := buildTurboDashParams(rs, st, splitH-2, leftW-2) - rightContent := buildTurboDashMetrics(rs, st, splitH-2, rightW-2) - leftPanel := st.Panel.Width(leftW - 2).Render(leftContent) - rightPanel := st.Panel.Width(rightW - 2).Render(rightContent) - split := lipgloss.JoinHorizontal(lipgloss.Top, leftPanel, rightPanel) + leftPanelFrame, rightPanelFrame := bodyPanel.Split(45, 24) + leftContent := buildTurboDashParams(rs, st, PanelContentHeight(splitOuterH), leftPanelFrame.InnerWidth) + rightContent := buildTurboDashMetrics(rs, st, PanelContentHeight(splitOuterH), rightPanelFrame.InnerWidth) + split := renderSplitPanels(st, leftPanelFrame, rightPanelFrame, leftContent, rightContent) // ── 进度条面板 ── - progressLine := buildTurboProgressLine(rs, st, ContentWidth(innerW)) - progressPanelStr := wrapPanel(st, progressLine, innerW) + progressLine := buildTurboProgressLine(rs, st, bodyPanel.InnerWidth) + progressPanelStr := bodyPanel.Wrap(st, progressLine) // ── 级别列表面板 ── - levelList := buildLevelList(d, rs, st, ContentWidth(innerW), levelListH) - levelPanelStr := wrapPanel(st, levelList, innerW) + levelList := buildLevelList(d, rs, st, bodyPanel.InnerWidth, levelListH) + levelPanelStr := bodyPanel.Wrap(st, levelList) - content := strings.Join([]string{split, progressPanelStr, levelPanelStr}, "\n") - return l.Assemble(wrapPanel(st, content, width), st, width) + content := joinVerticalBlocks(split, progressPanelStr, levelPanelStr) + return l.Assemble(frame.Wrap(st, content), st, width) } // buildTurboDashParams 构建 Turbo 仪表盘左侧任务参数面板。 func buildTurboDashParams(rs *server.RunState, st Styles, maxH, width int) string { - var lines []string - lines = append(lines, " "+st.SectionHead.Render("任务参数")) - lines = append(lines, "") + lines := panelTitleLines(st, "任务参数", width, false) if rs == nil { lines = append(lines, " "+st.Muted.Render("等待数据...")) @@ -215,10 +228,7 @@ func buildTurboDashParams(rs *server.RunState, st Styles, maxH, width int) strin } } - for len(lines) < maxH { - lines = append(lines, "") - } - return strings.Join(lines[:maxH], "\n") + return finishPanelLines(lines, maxH) } // buildTurboDashMetrics 构建 Turbo 仪表盘右侧当前级别实时指标面板。 @@ -229,8 +239,7 @@ func buildTurboDashMetrics(rs *server.RunState, st Styles, maxH, width int) stri if rs != nil { curLevel = rs.CurrentLevel } - lines = append(lines, " "+st.SectionHead.Render(fmt.Sprintf("当前级别实时指标 [并发 = %d]", curLevel))) - lines = append(lines, "") + lines = panelTitleLines(st, fmt.Sprintf("当前级别实时指标 [并发 = %d]", curLevel), width, false) if rs == nil { lines = append(lines, " "+st.Muted.Render("等待数据...")) @@ -241,10 +250,7 @@ func buildTurboDashMetrics(rs *server.RunState, st Styles, maxH, width int) stri lines = append(lines, " "+labelValue(st, "Cache ", st.MetricVal.Render(fmt.Sprintf("%.1f%%", rs.CacheHitRate*100)))) } - for len(lines) < maxH { - lines = append(lines, "") - } - return strings.Join(lines[:maxH], "\n") + return finishPanelLines(lines, maxH) } // buildTurboProgressLine 构建 Turbo 模式进度条行。 @@ -276,15 +282,11 @@ func buildTurboProgressLine(rs *server.RunState, st Styles, width int) string { // buildLevelList 构建 Turbo 级别列表区域。 func buildLevelList(d *TurboDashState, rs *server.RunState, st Styles, width, maxH int) string { - var lines []string - lines = append(lines, " "+st.SectionHead.Render("级别列表")) + lines := panelTitleLines(st, "级别列表", width, true) if rs == nil || len(rs.Levels) == 0 { lines = append(lines, " "+st.Muted.Render("等待第一个级别完成...")) - for len(lines) < maxH { - lines = append(lines, "") - } - return strings.Join(lines, "\n") + return finishPanelLines(lines, maxH) } // 列宽(header 与 content 行保持一致,前缀均为 2 字符) @@ -349,8 +351,5 @@ func buildLevelList(d *TurboDashState, rs *server.RunState, st Styles, width, ma } } - for len(lines) < maxH { - lines = append(lines, "") - } - return strings.Join(lines[:maxH], "\n") + return finishPanelLines(lines, maxH) } diff --git a/internal/tui/pages/wizard.go b/internal/tui/pages/wizard.go index 9de5b9a..e7753e6 100644 --- a/internal/tui/pages/wizard.go +++ b/internal/tui/pages/wizard.go @@ -558,44 +558,50 @@ func RenderWizard(wz *WizardState, st Styles, width, height int) string { if wz == nil { return renderTooSmall(st, width, height) } - - l := PageLayout{ - CtxItems: wizardContextItems(wz.Step), - FooterParts: []string{"[q] 退出", "◆ AIT v0.1"}, - } - - content := buildWizardPageContent(wz, st, ContentWidth(width), l.ContentHeight(height)) - return l.Assemble(wrapPanel(st, content, width), st, width) -} - -func buildWizardPageContent(wz *WizardState, st Styles, width, maxH int) string { stepTitles := []string{"基本信息", "测试参数", "确认保存"} stepDescs := []string{ "配置任务名称、模型协议和连接信息。", "选择压测模式,并补全并发与 Prompt 参数。", "保存前快速检查关键配置。", } + stepTitle := stepTitles[int(wz.Step)] + headerLeft := []string{stepTitle} + if wz.Protocol != "" { + headerLeft = append(headerLeft, strings.ToUpper(wz.Protocol)) + } + headerRight := []string{} + if wz.Step >= wizardStep2 { + if wz.Turbo { + headerRight = append(headerRight, "Turbo 模式") + } else { + headerRight = append(headerRight, "标准模式") + } + } + if wz.Model != "" { + headerRight = append(headerRight, "模型 "+truncate(wz.Model, 18)) + } action := "创建任务" if wz.EditingID != "" { action = "编辑任务" } - stepTitle := stepTitles[int(wz.Step)] - stepDesc := stepDescs[int(wz.Step)] - titleLeft := st.SectionHead.Render(action) - titleRight := st.Muted.Render(fmt.Sprintf("步骤 %d/3 · %s", int(wz.Step)+1, stepTitle)) - var topLines []string - if lipgloss.Width(titleLeft)+lipgloss.Width(titleRight)+2 <= width { - topLines = append(topLines, titleLeft+strings.Repeat(" ", width-lipgloss.Width(titleLeft)-lipgloss.Width(titleRight))+titleRight) - } else { - topLines = append(topLines, titleLeft, titleRight) - } - if maxH >= 8 { - for _, line := range wrapText(stepDesc, width) { - topLines = append(topLines, st.Muted.Render(line)) - } + l := PageLayout{ + HeaderTitle: action, + HeaderSubtitle: stepDescs[int(wz.Step)], + HeaderMeta: fmt.Sprintf("步骤 %d/3", int(wz.Step)+1), + HeaderInfoLeft: headerLeft, + HeaderInfoRight: headerRight, + Hotkeys: NewPageHotkeys(wizardHotkeyItems(wz.Step), "[q] 退出"), } - if maxH >= 10 && width >= 46 { + frame := l.Frame(width, height) + + content := buildWizardPageContent(wz, st, frame.InnerWidth, frame.InnerHeight) + return l.Assemble(frame.Wrap(st, content), st, width) +} + +func buildWizardPageContent(wz *WizardState, st Styles, width, maxH int) string { + var topLines []string + if maxH >= 8 && width >= 46 { topLines = append(topLines, renderWizardStepStrip(wz.Step)) } @@ -614,7 +620,7 @@ func buildWizardPageContent(wz *WizardState, st Styles, width, maxH int) string if len(topLines) > maxTopH { topLines = topLines[:maxTopH] } - if maxH >= 6 { + if len(topLines) > 0 && maxH >= 6 { topLines = append(topLines, dividerLine(st, width)) } @@ -871,14 +877,14 @@ func appendWizardSummaryRow(lines *[]string, st Styles, label, value string, wid } } -func wizardContextItems(step wizardStep) []ContextBarItem { +func wizardHotkeyItems(step wizardStep) []HotkeyItem { switch step { case wizardStep1: - return CtxBar_Wizard_Step1() + return Hotkeys_Wizard_Step1() case wizardStep2: - return CtxBar_Wizard_Step2() + return Hotkeys_Wizard_Step2() default: - return CtxBar_Wizard_Step3() + return Hotkeys_Wizard_Step3() } } From c479d163b4ca86d0b8c57bcd94de623fb2439cc4 Mon Sep 17 00:00:00 2001 From: Alain Date: Wed, 20 May 2026 22:47:56 +0800 Subject: [PATCH 23/52] =?UTF-8?q?feat:=20=E9=87=8D=E6=9E=84=E9=A1=B5?= =?UTF-8?q?=E9=9D=A2=E5=B8=83=E5=B1=80=E5=92=8C=E6=B8=B2=E6=9F=93=E9=80=BB?= =?UTF-8?q?=E8=BE=91=EF=BC=8C=E4=BC=98=E5=8C=96=E9=9D=A2=E6=9D=BF=E5=B0=BA?= =?UTF-8?q?=E5=AF=B8=E8=AE=A1=E7=AE=97=E5=92=8C=E5=86=85=E5=AE=B9=E5=8C=85?= =?UTF-8?q?=E8=A3=85?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/tui/pages/helpers.go | 38 ++++++++++++++++++++++--------- internal/tui/pages/layout.go | 15 ++++++------ internal/tui/pages/layout_test.go | 4 ++-- internal/tui/pages/tasklist.go | 6 ++--- internal/tui/pages/wizard.go | 6 ++--- 5 files changed, 43 insertions(+), 26 deletions(-) diff --git a/internal/tui/pages/helpers.go b/internal/tui/pages/helpers.go index e041258..2da4912 100644 --- a/internal/tui/pages/helpers.go +++ b/internal/tui/pages/helpers.go @@ -153,26 +153,42 @@ func renderHeader(st Styles, width int, title, subtitle, meta string, infoLeft, title = "AIT" } - // AIT ASCII 字符画(三行,实心彩色渐变) - artA := [3]string{" ██ ", " █ █ ", "██████"} // 可视宽 6 - artI := [3]string{" █ ", " █ ", " █ "} // 可视宽 3 - artT := [3]string{"█████", " █ ", " █ "} // 可视宽 5 + // AIT ASCII 字符画(三行,粗体像素字体,实心彩色) + // A (10) I (5) T (10) + // ████ █████ ██████████ + // ██ ██ █ ██ + // ████████ █████ ██ + artA := [3]string{ + " ████ ", // 可视宽 10 + " ██ ██ ", // 可视宽 10 + " ████████ ", // 可视宽 10 + } + artI := [3]string{ + "█████", // 可视宽 5 + " █ ", // 可视宽 5 + "█████", // 可视宽 5 + } + artT := [3]string{ + "██████████", // 可视宽 10 + " ██ ", // 可视宽 10 + " ██ ", // 可视宽 10 + } styleA := lipgloss.NewStyle().Foreground(colorPink).Bold(true) styleI := lipgloss.NewStyle().Foreground(colorGold).Bold(true) styleT := lipgloss.NewStyle().Foreground(colorCyan).Bold(true) - styleSep := lipgloss.NewStyle().Foreground(colorDivider) + styleSep := lipgloss.NewStyle().Foreground(colorPink) - // artVisW = 6+1+3+1+5 = 16; plus " │ " (3) + leading space (1) = total art prefix 20 - artVisW := 6 + 1 + 3 + 1 + 5 - artSepW := artVisW + 1 + 3 // art + " " + "│" + " " + // artVisW = 10+2+5+2+10 = 29; artSepW = " "(1) + art(29) + " "(1) + "┃"(1) + " "(1) = 33 + artVisW := 10 + 2 + 5 + 2 + 10 + artSepW := artVisW + 4 artRow := func(i int) string { - return styleA.Render(artA[i]) + " " + styleI.Render(artI[i]) + " " + styleT.Render(artT[i]) + return styleA.Render(artA[i]) + " " + styleI.Render(artI[i]) + " " + styleT.Render(artT[i]) } - vsep := styleSep.Render("│") + vsep := styleSep.Render("┃") - wideEnough := w >= 48 // 宽屏才展示 ASCII art + wideEnough := w >= 65 // 宽屏才展示 ASCII art // ── Line 1: [art row 0] │ title [meta badge] ──────── right1 := "" diff --git a/internal/tui/pages/layout.go b/internal/tui/pages/layout.go index 5dcdcb5..3cad1aa 100644 --- a/internal/tui/pages/layout.go +++ b/internal/tui/pages/layout.go @@ -66,20 +66,21 @@ func (l PageLayout) ChromeHeight() int { } // Frame 统一计算页面主内容区的外层与内层尺寸。 +// 外层无边框,InnerWidth 等于 OuterWidth(内层子面板自行管理各自的边框)。 func (l PageLayout) Frame(totalW, totalH int) PageFrame { if totalW < 1 { totalW = 1 } return PageFrame{ OuterWidth: totalW, - InnerWidth: ContentWidth(totalW), + InnerWidth: totalW, InnerHeight: l.ContentHeight(totalH), } } -// Wrap 用统一外层面板包裹页面内容。 -func (f PageFrame) Wrap(st Styles, content string) string { - return wrapPanel(st, content, f.OuterWidth) +// Wrap 保留接口兼容,外层不再添加边框,直接返回内容。 +func (f PageFrame) Wrap(_ Styles, content string) string { + return content } // InnerPanel 返回可用于嵌套子面板的统一 frame。 @@ -129,10 +130,10 @@ func (f PanelFrame) Split(leftPercent, minLeftOuter int) (PanelFrame, PanelFrame return NewPanelFrame(leftOuter), NewPanelFrame(rightOuter) } -// ContentHeight 返回单面板页面主内容区的可用行数 -// (总高度 - chrome 行数 - 面板上下边框)。 +// ContentHeight 返回页面主内容区的可用行数(总高度 - chrome 行数)。 +// 外层不再有边框,故不扣除 panelBorderV。 func (l PageLayout) ContentHeight(totalH int) int { - h := totalH - l.ChromeHeight() - panelBorderV + h := totalH - l.ChromeHeight() if h < 2 { h = 2 } diff --git a/internal/tui/pages/layout_test.go b/internal/tui/pages/layout_test.go index a28ed87..83601b7 100644 --- a/internal/tui/pages/layout_test.go +++ b/internal/tui/pages/layout_test.go @@ -37,12 +37,12 @@ func TestPageLayoutAssembleRendersSharedChrome(t *testing.T) { func TestPageLayoutFrameCalculatesNestedPanelSizes(t *testing.T) { l := PageLayout{} frame := l.Frame(80, 30) - if frame.OuterWidth != 80 || frame.InnerWidth != 78 || frame.InnerHeight != 23 { + if frame.OuterWidth != 80 || frame.InnerWidth != 80 || frame.InnerHeight != 25 { t.Fatalf("unexpected page frame: %#v", frame) } body := frame.InnerPanel() - if body.OuterWidth != 78 || body.InnerWidth != 76 { + if body.OuterWidth != 80 || body.InnerWidth != 78 { t.Fatalf("unexpected inner panel frame: %#v", body) } } diff --git a/internal/tui/pages/tasklist.go b/internal/tui/pages/tasklist.go index 927b6da..fdbfb45 100644 --- a/internal/tui/pages/tasklist.go +++ b/internal/tui/pages/tasklist.go @@ -170,9 +170,9 @@ func RenderTaskList(s *TaskListState, st Styles, width, height int) string { Hotkeys: NewPageHotkeys(cbItems, "[↑↓] 选择", "[a] 新建", "[q] 退出"), } frame := l.Frame(width, height) - - content := buildTaskListContent(s, st, frame.InnerWidth, frame.InnerHeight) - return l.Assemble(frame.Wrap(st, content), st, width) + panel := NewPanelFrame(frame.OuterWidth) + content := buildTaskListContent(s, st, panel.InnerWidth, PanelContentHeight(frame.InnerHeight)) + return l.Assemble(panel.Wrap(st, content), st, width) } // buildTaskListContent 构建任务列表内容区(含表头 + 任务条目)。 diff --git a/internal/tui/pages/wizard.go b/internal/tui/pages/wizard.go index e7753e6..3d0b3c8 100644 --- a/internal/tui/pages/wizard.go +++ b/internal/tui/pages/wizard.go @@ -594,9 +594,9 @@ func RenderWizard(wz *WizardState, st Styles, width, height int) string { Hotkeys: NewPageHotkeys(wizardHotkeyItems(wz.Step), "[q] 退出"), } frame := l.Frame(width, height) - - content := buildWizardPageContent(wz, st, frame.InnerWidth, frame.InnerHeight) - return l.Assemble(frame.Wrap(st, content), st, width) + panel := NewPanelFrame(frame.OuterWidth) + content := buildWizardPageContent(wz, st, panel.InnerWidth, PanelContentHeight(frame.InnerHeight)) + return l.Assemble(panel.Wrap(st, content), st, width) } func buildWizardPageContent(wz *WizardState, st Styles, width, maxH int) string { From 53870b876d5c8a091954baa992d673e30bd1bf42 Mon Sep 17 00:00:00 2001 From: Alain Date: Wed, 20 May 2026 23:04:39 +0800 Subject: [PATCH 24/52] =?UTF-8?q?feat:=20=E4=BC=98=E5=8C=96=E4=BB=AA?= =?UTF-8?q?=E8=A1=A8=E7=9B=98=E8=AF=B7=E6=B1=82=E5=88=97=E8=A1=A8=E5=92=8C?= =?UTF-8?q?=E4=BB=BB=E5=8A=A1=E8=AF=A6=E6=83=85=E5=AF=BC=E8=88=AA=E9=80=BB?= =?UTF-8?q?=E8=BE=91=EF=BC=8C=E9=81=BF=E5=85=8D=E5=8E=86=E5=8F=B2=E8=B7=AF?= =?UTF-8?q?=E5=BE=84=E5=BE=AA=E7=8E=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/tui/pages/dashboard.go | 2 +- internal/tui/pages/taskdetail.go | 9 +++++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/internal/tui/pages/dashboard.go b/internal/tui/pages/dashboard.go index 2a7fa8f..2133273 100644 --- a/internal/tui/pages/dashboard.go +++ b/internal/tui/pages/dashboard.go @@ -364,7 +364,7 @@ func buildRequestList(d *DashboardState, rs *server.RunState, st Styles, width, marker := selectionMarker(isSel) rowContent := padRight(marker, markW) + - padRight(fmt.Sprintf("#%d", r.Index+1), idW) + + padRight(fmt.Sprintf("#%d", len(reqs)-pos), idW) + padRight(statusStr, statW) + padRight(totalStr, timeW) + padRight(fmtDuration(r.TTFT), ttftW) + diff --git a/internal/tui/pages/taskdetail.go b/internal/tui/pages/taskdetail.go index b69d3c4..8282683 100644 --- a/internal/tui/pages/taskdetail.go +++ b/internal/tui/pages/taskdetail.go @@ -66,8 +66,13 @@ func HandleTaskDetailKey(s *TaskDetailState, msg tea.KeyMsg, client Client) (*Ta case "enter": if s.HistorySel >= 0 && s.HistorySel < effectiveLen { if hasActive && s.HistorySel == 0 { - // 进入正在运行的仪表盘 - nav = NavAction{To: NavRunDetail, RunID: s.ActiveRun.RunID} + // 进入正在运行的仪表盘,直接导航,避免走 FromHistory 路径 + // (FromHistory 路径会覆盖 dash.BackNav,导致循环:dashboard ↔ taskdetail) + if s.ActiveRun.Mode == "turbo" { + nav = NavAction{To: NavTurboDash} + } else { + nav = NavAction{To: NavDashboard} + } } else { histIdx := s.HistorySel if hasActive { From 968b7e1ba33cf376c39f370f8ac340c91d43a26d Mon Sep 17 00:00:00 2001 From: Alain Date: Wed, 20 May 2026 23:29:07 +0800 Subject: [PATCH 25/52] =?UTF-8?q?feat:=20=E6=9B=B4=E6=96=B0=E5=AF=BC?= =?UTF-8?q?=E8=88=AA=E9=80=BB=E8=BE=91=E4=BB=A5=E6=94=AF=E6=8C=81=E4=BB=BB?= =?UTF-8?q?=E5=8A=A1=E8=AF=A6=E6=83=85=E8=BF=94=E5=9B=9E=EF=BC=8C=E4=BC=98?= =?UTF-8?q?=E5=8C=96=20API=20=E5=AF=86=E9=92=A5=E6=8E=A9=E7=A0=81=E5=A4=84?= =?UTF-8?q?=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/tui/model.go | 17 +++-------------- internal/tui/pages/helpers.go | 6 +++--- 2 files changed, 6 insertions(+), 17 deletions(-) diff --git a/internal/tui/model.go b/internal/tui/model.go index b23415d..b9e497b 100644 --- a/internal/tui/model.go +++ b/internal/tui/model.go @@ -129,15 +129,18 @@ func (m *Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { case RunStartedMsg: ch, cancel, firstCmd := m.client.SubscribeCmd(msg.RunID) taskMode := m.getTaskMode(msg.TaskID) + backNav := pages.NavAction{To: pages.NavTaskDetail, TaskID: msg.TaskID} if taskMode == "turbo" { m.turboDash = pages.NewTurboDashState(msg.RunID, msg.TaskID) m.turboDash.EventCh = ch m.turboDash.CancelFn = cancel + m.turboDash.BackNav = backNav m.view = viewTurboDash } else { m.dash = pages.NewDashboardState(msg.RunID, msg.TaskID) m.dash.EventCh = ch m.dash.CancelFn = cancel + m.dash.BackNav = backNav m.view = viewDashboard } if m.taskList != nil { @@ -539,20 +542,6 @@ func (m *Model) currentRunTaskID(isDash bool) string { } func (m *Model) taskDetailBackNav() pages.NavAction { - switch m.view { - case viewDashboard: - if m.dash != nil { - return pages.NavAction{To: pages.NavDashboard} - } - case viewTurboDash: - if m.turboDash != nil { - return pages.NavAction{To: pages.NavTurboDash} - } - case viewTaskDetail: - if m.detail != nil && m.detail.BackNav.To != pages.NavNone { - return m.detail.BackNav - } - } return pages.NavAction{To: pages.NavTaskList} } diff --git a/internal/tui/pages/helpers.go b/internal/tui/pages/helpers.go index 2da4912..7b8ef82 100644 --- a/internal/tui/pages/helpers.go +++ b/internal/tui/pages/helpers.go @@ -595,11 +595,11 @@ func maskAPIKey(key string) string { func shortProtocol(p string) string { switch p { case "openai-completions": - return "completions" + return "openai-completions" case "openai-responses": - return "responses" + return "openai-responses" case "anthropic-messages": - return "messages" + return "anthropic-messages" default: return p } From 37fa0e38aa2d83fbf0a3fced397faa043a1697c9 Mon Sep 17 00:00:00 2001 From: Alain Date: Wed, 20 May 2026 23:32:41 +0800 Subject: [PATCH 26/52] =?UTF-8?q?feat:=20=E6=9B=B4=E6=96=B0=E4=BB=BB?= =?UTF-8?q?=E5=8A=A1=E5=88=97=E8=A1=A8=E6=B8=B2=E6=9F=93=E9=80=BB=E8=BE=91?= =?UTF-8?q?=EF=BC=8C=E8=B0=83=E6=95=B4=E5=88=97=E5=AE=BD=E4=BB=A5=E9=80=82?= =?UTF-8?q?=E5=BA=94=E5=86=85=E5=AE=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/tui/pages/tasklist.go | 24 +++++------------------- 1 file changed, 5 insertions(+), 19 deletions(-) diff --git a/internal/tui/pages/tasklist.go b/internal/tui/pages/tasklist.go index fdbfb45..0b24bf8 100644 --- a/internal/tui/pages/tasklist.go +++ b/internal/tui/pages/tasklist.go @@ -117,21 +117,7 @@ func HandleTaskListKey(s *TaskListState, msg tea.KeyMsg, client Client) (*TaskLi } // RenderTaskList 渲染任务列表页。 -// -// 设计稿布局: -// -// ╔══ AIT 任务中心 ══════════════╗ -// ║ ◆ AIT 已保存任务: N 最近运行: xxx ║ -// ╠══════════════════════════════╣ -// ║ 任务名称 模式 协议 上次结果 ║ -// ║ ─────────────────────────── ║ -// ║ ▶ ◉ name 标准 responses ✓ 98.5% ║ -// ║ model 并发10 请求200 ◉ 47/100 ║ -// ║ ║ -// ╠══════════════════════════════╣ -// ║ [Enter] 详情 [a] 新建 ... ║ ← context bar -// ╠══════════════════════════════╣ -// ║ [↑↓] 选择 [q] 退出 ◆ AIT ║ + // ╚══════════════════════════════╝ func RenderTaskList(s *TaskListState, st Styles, width, height int) string { if TooSmall(width, height) { @@ -183,10 +169,10 @@ func buildTaskListContent(s *TaskListState, st Styles, width, maxH int) string { // 列宽(gap=2 作为列间距内置到每个非末尾列的宽度中) const ( modeW = 9 // 7 + 2 gap - protoW = 12 // 10 + 2 gap - lastRunW = 13 // 11 + 2 gap - ttftW = 12 // 10 + 2 gap - tpsW = 9 // 末尾列,无需额外 gap + protoW = 20 // 10 + 2 gap + lastRunW = 16 // 11 + 2 gap + ttftW = 16 // 10 + 2 gap + tpsW = 16 // 末尾列,无需额外 gap ) fixedW := 2 + modeW + protoW + lastRunW + ttftW + tpsW nameW := maxInt(10, width-fixedW) From eddd5ac666f7ff2fbaaf0694baef520dc03b668a Mon Sep 17 00:00:00 2001 From: Alain Date: Wed, 20 May 2026 23:58:14 +0800 Subject: [PATCH 27/52] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E4=BB=A3?= =?UTF-8?q?=E7=90=86=E9=85=8D=E7=BD=AE=E9=A1=B5=E9=9D=A2=E5=8F=8A=E7=9B=B8?= =?UTF-8?q?=E5=85=B3=E9=80=BB=E8=BE=91=EF=BC=8C=E6=94=AF=E6=8C=81=E5=85=A8?= =?UTF-8?q?=E5=B1=80=20HTTP=20=E4=BB=A3=E7=90=86=E8=AE=BE=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/config/config.go | 1 + internal/server/run.go | 8 +++ internal/server/server.go | 8 +++ internal/server/task.go | 22 ++++++ internal/tui/client.go | 23 +++++++ internal/tui/messages.go | 10 +++ internal/tui/model.go | 28 +++++++- internal/tui/model_test.go | 3 + internal/tui/pages/contextbar.go | 10 +++ internal/tui/pages/nav.go | 5 ++ internal/tui/pages/proxy.go | 111 +++++++++++++++++++++++++++++++ internal/tui/pages/tasklist.go | 3 + internal/tui/pages/wizard.go | 6 -- 13 files changed, 231 insertions(+), 7 deletions(-) create mode 100644 internal/tui/pages/proxy.go diff --git a/internal/config/config.go b/internal/config/config.go index ecb2fae..c9940fe 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -21,6 +21,7 @@ type Config struct { SaveAPIKey bool `json:"save_api_key"` LastSelectedTaskID string `json:"last_selected_task_id,omitempty"` DefaultProtocol string `json:"default_protocol,omitempty"` + ProxyURL string `json:"proxy_url,omitempty"` } func Load() (*Config, error) { diff --git a/internal/server/run.go b/internal/server/run.go index 76a670e..e818080 100644 --- a/internal/server/run.go +++ b/internal/server/run.go @@ -8,6 +8,7 @@ import ( "time" "github.com/yinxulai/ait/internal/client" + "github.com/yinxulai/ait/internal/config" "github.com/yinxulai/ait/internal/report" "github.com/yinxulai/ait/internal/runner" "github.com/yinxulai/ait/internal/store" @@ -227,6 +228,13 @@ func (s *serverImpl) StartRun(taskID string) (RunID, error) { return "", fmt.Errorf("hydrate input: %w", err) } + // 若任务未单独配置代理,使用全局配置中的代理地址 + if hydratedInput.ProxyURL == "" { + if cfg, err := config.Load(); err == nil { + hydratedInput.ProxyURL = cfg.ProxyURL + } + } + runID := RunID(fmt.Sprintf("run_%d", time.Now().UnixNano())) now := time.Now() diff --git a/internal/server/server.go b/internal/server/server.go index 03d4b58..e1448fd 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -40,6 +40,14 @@ type Server interface { // GenerateReport 为已完成的运行生成报告文件,返回文件路径。 GenerateReport(runID RunID, format ReportFormat) (string, error) + + // --- 全局配置 --- + + // GetConfig 返回当前全局配置。 + GetConfig() (*config.Config, error) + + // SetProxyURL 更新并持久化全局代理 URL。 + SetProxyURL(proxyURL string) error } // serverImpl 是 Server 的具体实现。 diff --git a/internal/server/task.go b/internal/server/task.go index e0be2a8..cd191dc 100644 --- a/internal/server/task.go +++ b/internal/server/task.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" + "github.com/yinxulai/ait/internal/config" storepkg "github.com/yinxulai/ait/internal/store" "github.com/yinxulai/ait/internal/types" ) @@ -153,3 +154,24 @@ func (s *serverImpl) hasRunningTaskLocked(taskID string) bool { } return false } + +// ─── 全局配置 ───────────────────────────────────────────────────────────────── + +// GetConfig 返回当前全局配置。 +func (s *serverImpl) GetConfig() (*config.Config, error) { + cfg, err := config.Load() + if err != nil { + return &config.Config{}, nil // 文件不存在时返回空配置 + } + return cfg, nil +} + +// SetProxyURL 更新并持久化全局代理 URL。 +func (s *serverImpl) SetProxyURL(proxyURL string) error { + cfg, err := config.Load() + if err != nil { + cfg = &config.Config{} + } + cfg.ProxyURL = proxyURL + return cfg.Save() +} diff --git a/internal/tui/client.go b/internal/tui/client.go index 065555d..902e392 100644 --- a/internal/tui/client.go +++ b/internal/tui/client.go @@ -198,3 +198,26 @@ func (c *Client) GenerateReportCmd(runID server.RunID, format server.ReportForma return ReportGeneratedMsg{RunID: runID, Path: path} } } + +// ─── 全局配置 ───────────────────────────────────────────────────────────────── + +// LoadProxyConfigCmd 异步加载全局代理配置。 +func (c *Client) LoadProxyConfigCmd() tea.Cmd { + return func() tea.Msg { + cfg, err := c.srv.GetConfig() + if err != nil { + return ErrorMsg{Err: fmt.Errorf("加载配置失败: %w", err)} + } + return ProxyConfigLoadedMsg{ProxyURL: cfg.ProxyURL} + } +} + +// SaveProxyConfigCmd 异步保存全局代理配置。 +func (c *Client) SaveProxyConfigCmd(proxyURL string) tea.Cmd { + return func() tea.Msg { + if err := c.srv.SetProxyURL(proxyURL); err != nil { + return ErrorMsg{Err: fmt.Errorf("保存代理配置失败: %w", err)} + } + return ProxyConfigSavedMsg{ProxyURL: proxyURL} + } +} diff --git a/internal/tui/messages.go b/internal/tui/messages.go index c572f1b..fee072f 100644 --- a/internal/tui/messages.go +++ b/internal/tui/messages.go @@ -54,3 +54,13 @@ type ReportGeneratedMsg struct { type ErrorMsg struct { Err error } + +// ProxyConfigLoadedMsg 全局代理配置加载完成。 +type ProxyConfigLoadedMsg struct { + ProxyURL string +} + +// ProxyConfigSavedMsg 全局代理配置保存完成。 +type ProxyConfigSavedMsg struct { + ProxyURL string +} diff --git a/internal/tui/model.go b/internal/tui/model.go index b9e497b..9d236bd 100644 --- a/internal/tui/model.go +++ b/internal/tui/model.go @@ -22,6 +22,7 @@ const ( viewDashboard viewState = "dashboard" viewTurboDash viewState = "turbo-dash" viewReqDetail viewState = "req-detail" + viewProxy viewState = "proxy" ) // ─── 根 Model ───────────────────────────────────────────────────────────────── @@ -38,12 +39,13 @@ type Model struct { err error // 页面局部状态(由 pages 包管理) - taskList *pages.TaskListState + taskList *pages.TaskListState detail *pages.TaskDetailState wizard *pages.WizardState dash *pages.DashboardState turboDash *pages.TurboDashState reqDetail *pages.ReqDetailState + proxyConf *pages.ProxyConfigState } // NewModel 创建 Model。srv 不能为 nil。 @@ -193,6 +195,17 @@ func (m *Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { case ErrorMsg: m.err = msg.Err return m, nil + + // ── 代理配置 ── + case ProxyConfigLoadedMsg: + if m.proxyConf != nil { + m.proxyConf.URL = msg.ProxyURL + } + return m, nil + + case ProxyConfigSavedMsg: + m.status = "代理配置已保存" + return m, nil } return m, nil @@ -219,6 +232,8 @@ func (m *Model) View() string { content = pages.RenderTurboDash(m.turboDash, m.turboDashTaskName(), m.styles, innerW, innerH) case viewReqDetail: content = pages.RenderReqDetail(m.reqDetail, m.reqDetailTaskName(), m.styles, innerW, innerH) + case viewProxy: + content = pages.RenderProxyConfig(m.proxyConf, m.styles, innerW, innerH) default: content = "未知视图" } @@ -264,6 +279,12 @@ func (m *Model) handleKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) { newState, nav := pages.HandleReqDetailKey(m.reqDetail, msg) m.reqDetail = newState return m, m.handleNav(nav) + + case viewProxy: + newState, cmd, nav := pages.HandleProxyConfigKey(m.proxyConf, msg, m.client) + m.proxyConf = newState + navCmd := m.handleNav(nav) + return m, tea.Batch(cmd, navCmd) } return m, nil @@ -345,6 +366,11 @@ func (m *Model) handleNav(nav pages.NavAction) tea.Cmd { m.view = viewReqDetail return nil + case pages.NavProxy: + m.proxyConf = pages.NewProxyConfigState("") + m.view = viewProxy + return m.client.LoadProxyConfigCmd() + case pages.NavQuit: return tea.Quit } diff --git a/internal/tui/model_test.go b/internal/tui/model_test.go index 2385dcd..9c8cc37 100644 --- a/internal/tui/model_test.go +++ b/internal/tui/model_test.go @@ -3,6 +3,7 @@ package tui import ( "testing" + "github.com/yinxulai/ait/internal/config" "github.com/yinxulai/ait/internal/server" "github.com/yinxulai/ait/internal/tui/pages" "github.com/yinxulai/ait/internal/types" @@ -31,6 +32,8 @@ func (s *stubServer) GetHistory(taskID string, limit int) ([]types.TaskRunSummar func (s *stubServer) GenerateReport(runID server.RunID, fmt server.ReportFormat) (string, error) { return "", nil } +func (s *stubServer) GetConfig() (*config.Config, error) { return &config.Config{}, nil } +func (s *stubServer) SetProxyURL(proxyURL string) error { return nil } // ─── NewModel ───────────────────────────────────────────────────────────────── diff --git a/internal/tui/pages/contextbar.go b/internal/tui/pages/contextbar.go index ea5ca93..4f4c7f8 100644 --- a/internal/tui/pages/contextbar.go +++ b/internal/tui/pages/contextbar.go @@ -35,6 +35,7 @@ func Hotkeys_TaskList_Normal() []HotkeyItem { HotkeyAction("e", "编辑"), HotkeyAction("d", "删除"), HotkeyAction("y", "复制"), + HotkeyAction("p", "代理配置"), } } @@ -44,6 +45,7 @@ func Hotkeys_TaskList_Running() []HotkeyItem { HotkeyAction("Enter", "查看详情"), HotkeyAction("s", "停止"), HotkeyAction("y", "复制"), + HotkeyAction("p", "代理配置"), } } @@ -187,3 +189,11 @@ func Hotkeys_ReqDetail() []HotkeyItem { HotkeyAction("←→", "上/下一条请求"), } } + +// Hotkeys_ProxyConfig 代理配置页。 +func Hotkeys_ProxyConfig() []HotkeyItem { + return []HotkeyItem{ + HotkeyAction("Enter", "保存"), + HotkeyAction("Ctrl+U", "清空"), + } +} diff --git a/internal/tui/pages/nav.go b/internal/tui/pages/nav.go index a698508..7f70876 100644 --- a/internal/tui/pages/nav.go +++ b/internal/tui/pages/nav.go @@ -21,6 +21,7 @@ const ( NavTurboDash // 进入 Turbo 仪表盘(需 RunID + TaskID) NavRunDetail // 从历史记录进入某次运行的仪表盘(需 RunID) NavReqDetail // 进入请求详情(需 ReqIndex) + NavProxy // 进入代理配置页 NavQuit // 退出程序 ) @@ -53,4 +54,8 @@ type Client interface { GetRunStateCmd(runID server.RunID) tea.Cmd GetRunStateForHistoryCmd(runID server.RunID, summary *types.TaskRunSummary) tea.Cmd GenerateReportCmd(runID server.RunID, format server.ReportFormat) tea.Cmd + + // 全局配置 + SaveProxyConfigCmd(proxyURL string) tea.Cmd + LoadProxyConfigCmd() tea.Cmd } diff --git a/internal/tui/pages/proxy.go b/internal/tui/pages/proxy.go new file mode 100644 index 0000000..48ac68c --- /dev/null +++ b/internal/tui/pages/proxy.go @@ -0,0 +1,111 @@ +package pages + +import ( + "strings" + + tea "github.com/charmbracelet/bubbletea" + "github.com/charmbracelet/lipgloss" +) + +// ProxyConfigState 代理配置页面状态。 +type ProxyConfigState struct { + URL string // 当前编辑中的代理 URL +} + +// NewProxyConfigState 创建代理配置页面状态,传入当前已保存的代理 URL。 +func NewProxyConfigState(currentURL string) *ProxyConfigState { + return &ProxyConfigState{URL: currentURL} +} + +// HandleProxyConfigKey 处理代理配置页面的按键。 +func HandleProxyConfigKey(s *ProxyConfigState, msg tea.KeyMsg, client Client) (*ProxyConfigState, tea.Cmd, NavAction) { + nav := NavAction{} + if s == nil { + return s, nil, NavAction{To: NavTaskList} + } + + switch msg.String() { + case "esc": + nav = NavAction{To: NavTaskList} + + case "enter": + cmd := client.SaveProxyConfigCmd(s.URL) + nav = NavAction{To: NavTaskList} + return s, cmd, nav + + case "backspace": + r := []rune(s.URL) + if len(r) > 0 { + s.URL = string(r[:len(r)-1]) + } + + case "ctrl+u": + s.URL = "" + + case "q", "ctrl+c": + nav = NavAction{To: NavQuit} + + default: + if len(msg.Runes) > 0 { + s.URL += string(msg.Runes) + } + } + + return s, nil, nav +} + +// RenderProxyConfig 渲染代理配置页面。 +func RenderProxyConfig(s *ProxyConfigState, st Styles, width, height int) string { + if TooSmall(width, height) { + return renderTooSmall(st, width, height) + } + if s == nil { + return renderTooSmall(st, width, height) + } + + l := PageLayout{ + HeaderTitle: "代理配置", + HeaderSubtitle: "设置全局 HTTP 代理,适用于所有任务的请求。留空则使用系统环境变量或直连。", + HeaderMeta: "全局配置", + Hotkeys: NewPageHotkeys(Hotkeys_ProxyConfig(), "[Esc] 返回", "[q] 退出"), + } + frame := l.Frame(width, height) + panel := NewPanelFrame(frame.OuterWidth) + + content := buildProxyConfigContent(s, st, panel.InnerWidth, PanelContentHeight(frame.InnerHeight)) + return l.Assemble(panel.Wrap(st, content), st, width) +} + +func buildProxyConfigContent(s *ProxyConfigState, st Styles, contentW, maxH int) string { + var lines []string + + lines = append(lines, st.SectionHead.Render("代理地址")) + lines = append(lines, "") + + // 字段宽度(与 wizard renderWizardField 保持一致) + fieldW := maxInt(10, contentW-19) + displayURL := fitTail(s.URL, maxInt(1, fieldW-1)) + "█" + renderedField := st.FieldActive.Width(fieldW).Render(st.Value.Render(displayURL)) + + labelBlock := strings.Join([]string{ + strings.Repeat(" ", 15), + lipgloss.NewStyle().Width(15).Render(st.Label.Render("代理地址")), + strings.Repeat(" ", 15), + }, "\n") + lines = append(lines, lipgloss.JoinHorizontal(lipgloss.Top, labelBlock, renderedField)) + lines = append(lines, "") + + hint := "示例: http://127.0.0.1:7890 或留空以直连" + lines = append(lines, st.Muted.Render(truncate(hint, contentW))) + lines = append(lines, "") + lines = append(lines, st.Muted.Render(truncate("配置保存至 ~/.ait/config.json,重启无需重新输入。", contentW))) + + // 填充至 maxH + for len(lines) < maxH { + lines = append(lines, "") + } + if len(lines) > maxH { + lines = lines[:maxH] + } + return strings.Join(lines, "\n") +} diff --git a/internal/tui/pages/tasklist.go b/internal/tui/pages/tasklist.go index 0b24bf8..78d0986 100644 --- a/internal/tui/pages/tasklist.go +++ b/internal/tui/pages/tasklist.go @@ -75,6 +75,9 @@ func HandleTaskListKey(s *TaskListState, msg tea.KeyMsg, client Client) (*TaskLi case "a": nav = NavAction{To: NavWizard, EditTask: nil} + case "p": + nav = NavAction{To: NavProxy} + case "e": if t, ok := s.CurrentTask(); ok { nav = NavAction{To: NavWizard, EditTask: &t} diff --git a/internal/tui/pages/wizard.go b/internal/tui/pages/wizard.go index 3d0b3c8..9ce9ea4 100644 --- a/internal/tui/pages/wizard.go +++ b/internal/tui/pages/wizard.go @@ -266,11 +266,6 @@ func step1Fields() []fieldDef { getRaw: func(wz *WizardState) string { return wz.EndpointURL }, set: func(wz *WizardState, v string) { wz.EndpointURL = v }, }, - { - kind: fieldText, label: "代理地址", - get: func(wz *WizardState) string { return wz.ProxyURL }, - set: func(wz *WizardState, v string) { wz.ProxyURL = v }, - }, { kind: fieldText, label: "API 密钥", get: func(wz *WizardState) string { return wz.APIKey }, @@ -760,7 +755,6 @@ func renderStep3Summary(wz *WizardState, st Styles, innerW int) []string { endpointDisplay = types.DefaultEndpointURL(wz.Protocol) } addRow("接口地址", endpointDisplay, st.Value) - addRow("代理地址", wizardFallback(wz.ProxyURL, "直连"), st.Value) addRow("API 密钥", wizardFallback(maskAPIKey(wz.APIKey), "未填写"), st.Value) addRow("测试模型", wizardFallback(wz.Model, "未填写"), st.Value) From 2033134610055c4a61e7f3a45d95c66a5661084d Mon Sep 17 00:00:00 2001 From: Alain Date: Thu, 21 May 2026 00:26:07 +0800 Subject: [PATCH 28/52] =?UTF-8?q?feat:=20=E6=9B=B4=E6=96=B0=20ResponsesAPI?= =?UTF-8?q?Request=20=E7=BB=93=E6=9E=84=EF=BC=8C=E6=B7=BB=E5=8A=A0=20Instr?= =?UTF-8?q?uctions=20=E5=92=8C=20Store=20=E5=AD=97=E6=AE=B5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/client/openai.go | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/internal/client/openai.go b/internal/client/openai.go index 993a520..fbfa384 100644 --- a/internal/client/openai.go +++ b/internal/client/openai.go @@ -58,10 +58,12 @@ type ChatCompletionRequest struct { } type ResponsesAPIRequest struct { - Model string `json:"model"` - Input string `json:"input"` - Stream bool `json:"stream,omitempty"` - Reasoning *ResponsesReasoningOptions `json:"reasoning,omitempty"` + Model string `json:"model"` + Input string `json:"input"` + Instructions string `json:"instructions,omitempty"` + Store bool `json:"store,omitempty"` + Stream bool `json:"stream,omitempty"` + Reasoning *ResponsesReasoningOptions `json:"reasoning,omitempty"` } // ChatCompletionResponse represents the response from chat completion @@ -174,14 +176,12 @@ func extractCachedInputTokens(details *PromptTokensDetails) int { func (c *OpenAIClient) buildRequestBody(systemPrompt, userPrompt string, stream bool) ([]byte, error) { if c.Provider == types.ProtocolOpenAIResponses { - input := userPrompt - if systemPrompt != "" { - input = systemPrompt + "\n\n" + userPrompt - } reqBody := ResponsesAPIRequest{ - Model: c.Model, - Input: input, - Stream: stream, + Model: c.Model, + Input: userPrompt, + Instructions: systemPrompt, + Store: true, + Stream: stream, } if c.Thinking { reqBody.Reasoning = &ResponsesReasoningOptions{Effort: "medium"} From 3da552abf49e22ea1bdf26552d84dc596744d58b Mon Sep 17 00:00:00 2001 From: Alain Date: Thu, 21 May 2026 08:51:12 +0800 Subject: [PATCH 29/52] =?UTF-8?q?feat:=20=E6=9B=B4=E6=96=B0=E4=BB=BB?= =?UTF-8?q?=E5=8A=A1=E5=88=97=E8=A1=A8=E5=88=B7=E6=96=B0=E9=80=BB=E8=BE=91?= =?UTF-8?q?=EF=BC=8C=E4=BF=9D=E6=8C=81=E5=85=89=E6=A0=87=E6=8C=87=E5=90=91?= =?UTF-8?q?=E5=BD=93=E5=89=8D=E9=80=89=E4=B8=AD=E4=BB=BB=E5=8A=A1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/tui/model.go | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/internal/tui/model.go b/internal/tui/model.go index 9d236bd..eb48a23 100644 --- a/internal/tui/model.go +++ b/internal/tui/model.go @@ -93,7 +93,22 @@ func (m *Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { if m.taskList == nil { m.taskList = pages.NewTaskListState() } + // 刷新前记录当前选中任务的 ID,刷新后保持光标指向同一任务。 + // 任务列表按 UpdatedAt 排序,编辑/复制任务后顺序会变化,若只靠下标 + // 定位会导致光标悄悄滑到别的任务上,进入错误任务的详情页。 + var prevID string + if t, ok := m.taskList.CurrentTask(); ok { + prevID = t.ID + } m.taskList.Tasks = msg.Tasks + if prevID != "" { + for i, t := range msg.Tasks { + if t.ID == prevID { + m.taskList.Selected = i + break + } + } + } if m.taskList.Selected >= len(msg.Tasks) { m.taskList.Selected = max(len(msg.Tasks)-1, 0) } @@ -310,7 +325,8 @@ func (m *Model) handleNav(nav pages.NavAction) tea.Cmd { } else { m.detail = pages.NewTaskDetailState(*task) } - } else if m.detail == nil { + } else { + // 目标任务不在列表中(已删除或列表尚未加载),中止导航 return nil } if m.detail != nil { From caea313c28eab74a7915373f5a51f244a2a06762 Mon Sep 17 00:00:00 2001 From: Alain Date: Thu, 21 May 2026 23:25:45 +0800 Subject: [PATCH 30/52] =?UTF-8?q?feat:=20=E6=9B=B4=E6=96=B0=E8=AF=B7?= =?UTF-8?q?=E6=B1=82=E4=BD=93=E7=BB=93=E6=9E=84=EF=BC=8C=E6=B7=BB=E5=8A=A0?= =?UTF-8?q?=E8=AF=B7=E6=B1=82=E5=92=8C=E5=93=8D=E5=BA=94=E4=BD=93=E5=AD=97?= =?UTF-8?q?=E6=AE=B5=EF=BC=8C=E4=BC=98=E5=8C=96=E4=BB=A3=E7=90=86=E9=85=8D?= =?UTF-8?q?=E7=BD=AE=E9=A1=B5=E9=9D=A2=E8=BE=93=E5=85=A5=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/client/anthropic.go | 4 + internal/client/openai.go | 23 +++- internal/client/openai_test.go | 2 +- internal/tui/model.go | 2 +- internal/tui/pages/dashboard.go | 2 +- internal/tui/pages/proxy.go | 34 +++--- internal/tui/pages/turbodash.go | 189 ++++++++++++++++++-------------- internal/tui/pages/wizard.go | 115 ++++++++++++++----- 8 files changed, 237 insertions(+), 134 deletions(-) diff --git a/internal/client/anthropic.go b/internal/client/anthropic.go index 568542a..19d743f 100644 --- a/internal/client/anthropic.go +++ b/internal/client/anthropic.go @@ -215,6 +215,7 @@ func (c *AnthropicClient) Request(systemPrompt, userPrompt string, stream bool) TLSHandshakeTime: 0, TargetIP: "", CompletionTokens: 0, + RequestBody: string(reqBodyBytes), ErrorMessage: fmt.Sprintf("Request creation error: %s", err.Error()), }, err } @@ -293,6 +294,7 @@ func (c *AnthropicClient) Request(systemPrompt, userPrompt string, stream bool) TLSHandshakeTime: tlsTime, TargetIP: targetIP, CompletionTokens: 0, + RequestBody: string(reqBodyBytes), ErrorMessage: fmt.Sprintf("Network error: %s", err.Error()), }, err } @@ -336,6 +338,8 @@ func (c *AnthropicClient) Request(systemPrompt, userPrompt string, stream bool) TLSHandshakeTime: tlsTime, TargetIP: targetIP, CompletionTokens: 0, + RequestBody: string(reqBodyBytes), + ResponseBody: responseBody, ErrorMessage: errorMessage, }, fmt.Errorf(errorMessage) } diff --git a/internal/client/openai.go b/internal/client/openai.go index fbfa384..0e71216 100644 --- a/internal/client/openai.go +++ b/internal/client/openai.go @@ -57,9 +57,14 @@ type ChatCompletionRequest struct { Thinking *ThinkingOptions `json:"thinking,omitempty"` } +type ResponsesAPIInputItem struct { + Role string `json:"role"` + Content string `json:"content"` +} + type ResponsesAPIRequest struct { Model string `json:"model"` - Input string `json:"input"` + Input []ResponsesAPIInputItem `json:"input"` Instructions string `json:"instructions,omitempty"` Store bool `json:"store,omitempty"` Stream bool `json:"stream,omitempty"` @@ -177,8 +182,10 @@ func extractCachedInputTokens(details *PromptTokensDetails) int { func (c *OpenAIClient) buildRequestBody(systemPrompt, userPrompt string, stream bool) ([]byte, error) { if c.Provider == types.ProtocolOpenAIResponses { reqBody := ResponsesAPIRequest{ - Model: c.Model, - Input: userPrompt, + Model: c.Model, + Input: []ResponsesAPIInputItem{ + {Role: "user", Content: userPrompt}, + }, Instructions: systemPrompt, Store: true, Stream: stream, @@ -419,6 +426,7 @@ func (c *OpenAIClient) Request(systemPrompt, userPrompt string, stream bool) (*R TLSHandshakeTime: 0, TargetIP: "", CompletionTokens: 0, + RequestBody: string(jsonData), ErrorMessage: fmt.Sprintf("Request creation error: %s", err.Error()), }, err } @@ -499,6 +507,7 @@ func (c *OpenAIClient) Request(systemPrompt, userPrompt string, stream bool) (*R TLSHandshakeTime: tlsTime, TargetIP: targetIP, CompletionTokens: 0, + RequestBody: string(jsonData), ErrorMessage: fmt.Sprintf("Network error: %s", err.Error()), }, err } @@ -541,6 +550,8 @@ func (c *OpenAIClient) Request(systemPrompt, userPrompt string, stream bool) (*R TLSHandshakeTime: tlsTime, TargetIP: targetIP, CompletionTokens: 0, + RequestBody: string(jsonData), + ResponseBody: responseBody, ErrorMessage: errorMessage, }, fmt.Errorf(errorMessage) } @@ -672,8 +683,7 @@ func (c *OpenAIClient) Request(systemPrompt, userPrompt string, stream bool) (*R ConnectTime: connectTime, TLSHandshakeTime: tlsTime, TargetIP: targetIP, - CompletionTokens: 0, - ErrorMessage: fmt.Sprintf("Network error: %s", err.Error()), + CompletionTokens: 0, RequestBody: string(jsonData), ErrorMessage: fmt.Sprintf("Network error: %s", err.Error()), }, err } defer resp.Body.Close() @@ -699,6 +709,8 @@ func (c *OpenAIClient) Request(systemPrompt, userPrompt string, stream bool) (*R TLSHandshakeTime: tlsTime, TargetIP: targetIP, CompletionTokens: 0, + RequestBody: string(jsonData), + ResponseBody: string(responseData), ErrorMessage: errorMessage, }, fmt.Errorf(errorMessage) } @@ -717,6 +729,7 @@ func (c *OpenAIClient) Request(systemPrompt, userPrompt string, stream bool) (*R TLSHandshakeTime: tlsTime, TargetIP: targetIP, CompletionTokens: 0, + RequestBody: string(jsonData), ErrorMessage: fmt.Sprintf("Response body read error: %s", err.Error()), }, err } diff --git a/internal/client/openai_test.go b/internal/client/openai_test.go index 20a8a39..41f4211 100644 --- a/internal/client/openai_test.go +++ b/internal/client/openai_test.go @@ -465,7 +465,7 @@ func TestOpenAIClient_Request_OpenAIResponses_NonStream(t *testing.T) { if strings.Contains(requestBody, "messages") { t.Fatalf("responses request should not use chat-completions payload: %s", requestBody) } - if !strings.Contains(requestBody, `"input":"hello from responses"`) { + if !strings.Contains(requestBody, `"input":[{"role":"user","content":"hello from responses"}]`) { t.Fatalf("responses request body missing input field: %s", requestBody) } if metrics.PromptTokens != 12 || metrics.CachedInputTokens != 3 || metrics.CompletionTokens != 7 || metrics.ThinkingTokens != 2 { diff --git a/internal/tui/model.go b/internal/tui/model.go index eb48a23..afdff2f 100644 --- a/internal/tui/model.go +++ b/internal/tui/model.go @@ -214,7 +214,7 @@ func (m *Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { // ── 代理配置 ── case ProxyConfigLoadedMsg: if m.proxyConf != nil { - m.proxyConf.URL = msg.ProxyURL + m.proxyConf = pages.NewProxyConfigState(msg.ProxyURL) } return m, nil diff --git a/internal/tui/pages/dashboard.go b/internal/tui/pages/dashboard.go index 2133273..29d23d5 100644 --- a/internal/tui/pages/dashboard.go +++ b/internal/tui/pages/dashboard.go @@ -347,7 +347,7 @@ func buildRequestList(d *DashboardState, rs *server.RunState, st Styles, width, } totalText := fmtDuration(r.TotalTime) if !r.Success && r.ErrorMessage != "" { - totalText = "timeout" + totalText = r.ErrorMessage } statusStr := statusText diff --git a/internal/tui/pages/proxy.go b/internal/tui/pages/proxy.go index 48ac68c..729a7a8 100644 --- a/internal/tui/pages/proxy.go +++ b/internal/tui/pages/proxy.go @@ -3,18 +3,26 @@ package pages import ( "strings" + "github.com/charmbracelet/bubbles/cursor" + "github.com/charmbracelet/bubbles/textinput" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" ) // ProxyConfigState 代理配置页面状态。 type ProxyConfigState struct { - URL string // 当前编辑中的代理 URL + input textinput.Model // 代理 URL 输入框 } // NewProxyConfigState 创建代理配置页面状态,传入当前已保存的代理 URL。 func NewProxyConfigState(currentURL string) *ProxyConfigState { - return &ProxyConfigState{URL: currentURL} + ti := textinput.New() + ti.Prompt = "" + ti.Cursor.SetMode(cursor.CursorStatic) + ti.SetValue(currentURL) + ti.CursorEnd() + ti.Focus() + return &ProxyConfigState{input: ti} } // HandleProxyConfigKey 处理代理配置页面的按键。 @@ -29,26 +37,17 @@ func HandleProxyConfigKey(s *ProxyConfigState, msg tea.KeyMsg, client Client) (* nav = NavAction{To: NavTaskList} case "enter": - cmd := client.SaveProxyConfigCmd(s.URL) + cmd := client.SaveProxyConfigCmd(s.input.Value()) nav = NavAction{To: NavTaskList} return s, cmd, nav - case "backspace": - r := []rune(s.URL) - if len(r) > 0 { - s.URL = string(r[:len(r)-1]) - } - - case "ctrl+u": - s.URL = "" - case "q", "ctrl+c": nav = NavAction{To: NavQuit} default: - if len(msg.Runes) > 0 { - s.URL += string(msg.Runes) - } + var cmd tea.Cmd + s.input, cmd = s.input.Update(msg) + return s, cmd, nav } return s, nil, nav @@ -84,8 +83,9 @@ func buildProxyConfigContent(s *ProxyConfigState, st Styles, contentW, maxH int) // 字段宽度(与 wizard renderWizardField 保持一致) fieldW := maxInt(10, contentW-19) - displayURL := fitTail(s.URL, maxInt(1, fieldW-1)) + "█" - renderedField := st.FieldActive.Width(fieldW).Render(st.Value.Render(displayURL)) + s.input.Width = fieldW + s.input.TextStyle = st.Value + renderedField := st.FieldActive.Width(fieldW).Render(s.input.View()) labelBlock := strings.Join([]string{ strings.Repeat(" ", 15), diff --git a/internal/tui/pages/turbodash.go b/internal/tui/pages/turbodash.go index b83f367..d5e81af 100644 --- a/internal/tui/pages/turbodash.go +++ b/internal/tui/pages/turbodash.go @@ -17,18 +17,40 @@ type TurboDashState struct { EventCh <-chan server.Event CancelFn server.CancelFunc RunState *server.RunState - LevelSel int // 选中的级别索引(-1 = 无选中) - LevelOff int - LevelVis int + ReqSel int // 选中请求索引(-1 = 无选中) + ReqOff int // 滚动偏移 + ReqVis int // 当前可见请求数 BackNav NavAction // 按 b/esc 时的返回目标;Zero = 返回任务列表 } +// AdjustReqOffset 根据屏幕显示顺序调整列表可见窗口。 +func (d *TurboDashState) AdjustReqOffset(visH, total int) { + if d == nil { + return + } + if visH < 3 { + visH = 3 + } + if total <= 0 || d.ReqSel < 0 { + d.ReqOff = 0 + return + } + sel := requestDisplayPos(d.ReqSel, total) + off := d.ReqOff + if sel < off { + off = sel + } else if sel >= off+visH { + off = sel - visH + 1 + } + d.ReqOff = clampInt(off, 0, maxInt(0, total-visH)) +} + // NewTurboDashState 创建 Turbo 仪表盘初始状态。 func NewTurboDashState(runID server.RunID, taskID string) *TurboDashState { return &TurboDashState{ - RunID: runID, - TaskID: taskID, - LevelSel: -1, + RunID: runID, + TaskID: taskID, + ReqSel: -1, } } @@ -47,46 +69,47 @@ func HandleTurboDashKey(d *TurboDashState, msg tea.KeyMsg, client Client) (*Turb return d, nil, NavAction{To: NavTaskList} } - var levels []types.TurboLevelResult + var reqs []*types.RequestMetrics if d.RunState != nil { - levels = d.RunState.Levels + reqs = d.RunState.Requests } switch msg.String() { case "up", "k": - if len(levels) == 0 { + if len(reqs) == 0 { break } - if d.LevelSel < 0 { - // 首次按键:跳到最后一级(最新/最高并发),与 ↓ 保持一致 - d.LevelSel = len(levels) - 1 - } else if d.LevelSel <= 0 { - d.LevelSel = len(levels) - 1 + if d.ReqSel < 0 { + d.ReqSel = requestIndexFromDisplayPos(0, len(reqs)) } else { - d.LevelSel-- + selPos := requestDisplayPos(d.ReqSel, len(reqs)) + if selPos > 0 { + selPos-- + } else { + selPos = len(reqs) - 1 + } + d.ReqSel = requestIndexFromDisplayPos(selPos, len(reqs)) } case "down", "j": - if len(levels) == 0 { + if len(reqs) == 0 { break } - if d.LevelSel < 0 { - // 首次按键:跳到最后一级(最新/最高并发),与 ↑ 保持一致 - d.LevelSel = len(levels) - 1 - } else if d.LevelSel < len(levels)-1 { - d.LevelSel++ + if d.ReqSel < 0 { + d.ReqSel = requestIndexFromDisplayPos(0, len(reqs)) } else { - d.LevelSel = 0 + selPos := requestDisplayPos(d.ReqSel, len(reqs)) + if selPos < len(reqs)-1 { + selPos++ + } else { + selPos = 0 + } + d.ReqSel = requestIndexFromDisplayPos(selPos, len(reqs)) } case "enter": - // 进入该级别的请求列表,定位到该级别第一条请求 - if d.LevelSel >= 0 && d.LevelSel < len(levels) { - startIdx := 0 - for j := 0; j < d.LevelSel; j++ { - startIdx += levels[j].TotalRequests - } - nav = NavAction{To: NavReqDetail, ReqIndex: startIdx} + if d.ReqSel >= 0 && d.ReqSel < len(reqs) { + nav = NavAction{To: NavReqDetail, ReqIndex: d.ReqSel} } case "s": @@ -115,7 +138,7 @@ func HandleTurboDashKey(d *TurboDashState, msg tea.KeyMsg, client Client) (*Turb case "q", "ctrl+c": nav = NavAction{To: NavQuit} } - d.LevelOff = ensureVisibleOffset(d.LevelSel, len(levels), d.LevelOff, d.LevelVis) + d.AdjustReqOffset(d.ReqVis, len(reqs)) return d, nil, nav } @@ -150,7 +173,7 @@ func RenderTurboDash(d *TurboDashState, taskName string, st Styles, width, heigh rs := d.RunState isRunning := d.IsRunning() - hasSel := d.LevelSel >= 0 && rs != nil && d.LevelSel < len(rs.Levels) + hasSel := d.ReqSel >= 0 && rs != nil && d.ReqSel < len(rs.Requests) var cbItems []HotkeyItem switch { case hasSel && isRunning: @@ -205,9 +228,9 @@ func RenderTurboDash(d *TurboDashState, taskName string, st Styles, width, heigh progressLine := buildTurboProgressLine(rs, st, bodyPanel.InnerWidth) progressPanelStr := bodyPanel.Wrap(st, progressLine) - // ── 级别列表面板 ── - levelList := buildLevelList(d, rs, st, bodyPanel.InnerWidth, levelListH) - levelPanelStr := bodyPanel.Wrap(st, levelList) + // ── 请求列表面板 ── + requestList := buildTurboRequestList(d, rs, st, bodyPanel.InnerWidth, levelListH) + levelPanelStr := bodyPanel.Wrap(st, requestList) content := joinVerticalBlocks(split, progressPanelStr, levelPanelStr) return l.Assemble(frame.Wrap(st, content), st, width) @@ -280,73 +303,77 @@ func buildTurboProgressLine(rs *server.RunState, st Styles, width int) string { return prefix + barRendered + suffix } -// buildLevelList 构建 Turbo 级别列表区域。 -func buildLevelList(d *TurboDashState, rs *server.RunState, st Styles, width, maxH int) string { - lines := panelTitleLines(st, "级别列表", width, true) +// buildTurboRequestList 构建 Turbo 模式请求列表区域。 +func buildTurboRequestList(d *TurboDashState, rs *server.RunState, st Styles, width, maxH int) string { + lines := panelTitleLines(st, "请求列表", width, true) - if rs == nil || len(rs.Levels) == 0 { - lines = append(lines, " "+st.Muted.Render("等待第一个级别完成...")) + if rs == nil || len(rs.Requests) == 0 { + msg := "等待请求..." + if rs != nil && rs.Status != server.RunStatusRunning { + msg = "无请求详情数据" + } + lines = append(lines, " "+st.Muted.Render(msg)) return finishPanelLines(lines, maxH) } - // 列宽(header 与 content 行保持一致,前缀均为 2 字符) const ( - markW = 2 // 选择标记列 - concW = 6 // 并发数 - rateW = 8 // 成功率 - tpsW = 10 // TPS - ttftW = 10 // TTFT - cacheW = 8 // Cache - totW = 9 // 总耗时 - // 结论: 余量 + markW = 2 + idW = 6 + statW = 5 + timeW = 10 + ttftW = 10 + cacheW = 8 + tokW = 10 ) - hdr := padRight("", markW) + padRight("并发", concW) + padRight("成功率", rateW) + padRight("TPS", tpsW) + - padRight("TTFT", ttftW) + padRight("Cache", cacheW) + padRight("总耗时", totW) + "结论" + hdr := padRight("", markW) + padRight("#", idW) + padRight("状态", statW) + padRight("总耗时", timeW) + + padRight("TTFT", ttftW) + padRight("Cache", cacheW) + padRight("Token", tokW) + "TPS" lines = append(lines, renderTableHeader(st, width, hdr)) lines = append(lines, dividerLine(st, width)) - d.LevelVis = listVisibleItems(maxH, 3) - d.LevelOff = ensureVisibleOffset(d.LevelSel, len(rs.Levels), d.LevelOff, d.LevelVis) - start := d.LevelOff - end := minInt(len(rs.Levels), start+d.LevelVis) - - for i := start; i < end; i++ { - lv := rs.Levels[i] - isSel := i == d.LevelSel - - conclusionText := "✓ 稳定" - if !lv.Stable { - conclusionText = "✗ 降级" + d.ReqVis = listVisibleItems(maxH, 3) + d.AdjustReqOffset(d.ReqVis, len(rs.Requests)) + + reqs := rs.Requests + start := d.ReqOff + end := minInt(len(reqs), start+d.ReqVis) + for pos := start; pos < end; pos++ { + i := requestIndexFromDisplayPos(pos, len(reqs)) + r := reqs[i] + isSel := i == d.ReqSel + + statusText := "✓" + if !r.Success { + statusText = "✗" } - isCurrent := (i == len(rs.Levels)-1) && rs.Status == server.RunStatusRunning - if isCurrent { - conclusionText = "🔄 进行中" + totalText := fmtDuration(r.TotalTime) + if !r.Success && r.ErrorMessage != "" { + totalText = r.ErrorMessage } - conclusion := conclusionText - if isCurrent { - conclusion = styleWhenNotSelected(isSel, st.MetricVal, conclusionText) - } else if lv.Stable { - conclusion = styleWhenNotSelected(isSel, st.Ok, conclusionText) + statusStr := statusText + if r.Success { + statusStr = styleWhenNotSelected(isSel, st.Ok, statusText) } else { - conclusion = styleWhenNotSelected(isSel, st.ErrStyle, conclusionText) + statusStr = styleWhenNotSelected(isSel, st.ErrStyle, statusText) + } + totalStr := totalText + if !r.Success && r.ErrorMessage != "" { + totalStr = styleWhenNotSelected(isSel, st.ErrStyle, totalText) } marker := selectionMarker(isSel) - rowContent := padRight(marker, markW) + - padRight(fmt.Sprintf("%d", lv.Concurrency), concW) + - padRight(fmt.Sprintf("%.1f%%", lv.SuccessRate*100), rateW) + - padRight(fmt.Sprintf("%.1f", lv.AvgTPS), tpsW) + - padRight(fmtDuration(lv.AvgTTFT), ttftW) + - padRight(fmt.Sprintf("%.1f%%", lv.CacheHitRate*100), cacheW) + - padRight(fmtDuration(lv.AvgTotalTime), totW) + - conclusion + padRight(fmt.Sprintf("#%d", len(reqs)-pos), idW) + + padRight(statusStr, statW) + + padRight(totalStr, timeW) + + padRight(fmtDuration(r.TTFT), ttftW) + + padRight(fmt.Sprintf("%.0f%%", r.CacheHitRate*100), cacheW) + + padRight(fmt.Sprintf("%dtok", r.CompletionTokens), tokW) + + fmt.Sprintf("%.1f/s", r.TPS) rendered := renderTableRow(st, width, isSel, rowContent) lines = append(lines, rendered) - // 行间分隔线 - if i < end-1 && len(lines) < maxH-1 { + if pos < end-1 && len(lines) < maxH-1 { lines = append(lines, dividerLine(st, width)) } } diff --git a/internal/tui/pages/wizard.go b/internal/tui/pages/wizard.go index 9ce9ea4..54afce1 100644 --- a/internal/tui/pages/wizard.go +++ b/internal/tui/pages/wizard.go @@ -6,6 +6,8 @@ import ( "strings" "time" + "github.com/charmbracelet/bubbles/cursor" + "github.com/charmbracelet/bubbles/textinput" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" "github.com/yinxulai/ait/internal/server" @@ -66,11 +68,58 @@ type WizardState struct { // 当前活跃字段索引(Tab 切换) FieldIndex int ScrollOff int + input textinput.Model // 当前活跃文本字段的光标与编辑状态 +} + +// newWizardTextInput 创建向导使用的 textinput,禁用光标闪烁。 +func newWizardTextInput() textinput.Model { + ti := textinput.New() + ti.Prompt = "" + ti.Cursor.SetMode(cursor.CursorStatic) + ti.Focus() + return ti +} + +// loadInputForField 将字段的当前值加载到 wz.input,并将光标移到末尾。 +func loadInputForField(wz *WizardState, f fieldDef) { + rawVal := "" + if f.getRaw != nil { + rawVal = f.getRaw(wz) + } else if f.get != nil { + rawVal = f.get(wz) + } + if f.label == "API 密钥" { + wz.input.EchoMode = textinput.EchoPassword + } else { + wz.input.EchoMode = textinput.EchoNormal + } + wz.input.SetValue(rawVal) + wz.input.CursorEnd() +} + +// loadCurrentFieldInput 根据当前 Step/FieldIndex 重新加载 input。 +// 在字段切换或步骤切换后调用。 +func loadCurrentFieldInput(wz *WizardState) { + var fields []fieldDef + switch wz.Step { + case wizardStep1: + fields = step1Fields() + case wizardStep2: + fields = step2Fields(wz.Turbo) + default: + return + } + if wz.FieldIndex < len(fields) { + f := fields[wz.FieldIndex] + if f.kind == fieldText || f.kind == fieldNumber { + loadInputForField(wz, f) + } + } } // NewWizardState 创建新建任务向导状态(使用默认值)。 func NewWizardState() *WizardState { - return &WizardState{ + wz := &WizardState{ Step: wizardStep1, Protocol: types.ProtocolOpenAICompletions, Concurrency: 10, @@ -84,7 +133,10 @@ func NewWizardState() *WizardState { Stream: true, PromptMode: PromptModeGenerated, PromptLength: 100, + input: newWizardTextInput(), } + loadCurrentFieldInput(wz) + return wz } // NewWizardStateEdit 创建编辑任务向导状态(预填任务数据,零值字段沿用默认值)。 @@ -143,6 +195,8 @@ func NewWizardStateEdit(t *types.TaskDefinition) *WizardState { if tc.MinSuccessRate > 0 { wz.MinSuccessRate = tc.MinSuccessRate * 100 } + // 数据字段全部填充完毕后,重新加载当前字段(Name)到 input + loadCurrentFieldInput(wz) return wz } @@ -463,16 +517,19 @@ func HandleWizardKey(wz *WizardState, msg tea.KeyMsg, client Client) (*WizardSta wz.Step-- wz.FieldIndex = 0 wz.ScrollOff = 0 + loadCurrentFieldInput(wz) } case "tab", "down", "j": if wz.FieldIndex < maxField { wz.FieldIndex++ + loadCurrentFieldInput(wz) } case "shift+tab", "up", "k": if wz.FieldIndex > 0 { wz.FieldIndex-- + loadCurrentFieldInput(wz) } case "left": @@ -480,11 +537,15 @@ func HandleWizardKey(wz *WizardState, msg tea.KeyMsg, client Client) (*WizardSta f := fields[wz.FieldIndex] if f.toggle != nil { f.toggle(wz, false) - // 如果切换了 turbo 模式,重置 fieldIndex if f.label == "测试模式" { wz.FieldIndex = 0 wz.ScrollOff = 0 + loadCurrentFieldInput(wz) } + } else if f.kind == fieldText || f.kind == fieldNumber { + var cmd tea.Cmd + wz.input, cmd = wz.input.Update(msg) + return wz, cmd, nav } } @@ -496,7 +557,12 @@ func HandleWizardKey(wz *WizardState, msg tea.KeyMsg, client Client) (*WizardSta if f.label == "测试模式" { wz.FieldIndex = 0 wz.ScrollOff = 0 + loadCurrentFieldInput(wz) } + } else if f.kind == fieldText || f.kind == fieldNumber { + var cmd tea.Cmd + wz.input, cmd = wz.input.Update(msg) + return wz, cmd, nav } } @@ -508,36 +574,22 @@ func HandleWizardKey(wz *WizardState, msg tea.KeyMsg, client Client) (*WizardSta } else if wz.FieldIndex < maxField { wz.FieldIndex++ } - - case "backspace": - if wz.FieldIndex < len(fields) { - f := fields[wz.FieldIndex] - if f.set != nil && f.kind == fieldText { - getEdit := f.get - if f.getRaw != nil { - getEdit = f.getRaw - } - v := getEdit(wz) - r := []rune(v) - if len(r) > 0 { - f.set(wz, string(r[:len(r)-1])) - } - } - } + loadCurrentFieldInput(wz) case "q", "ctrl+c": nav = NavAction{To: NavQuit} default: - // 字符输入 - if len(msg.Runes) > 0 && wz.FieldIndex < len(fields) { + // 所有非导航键转发给 textinput 处理(退格、ctrl+u/a/e/w/k、字符输入等) + if wz.FieldIndex < len(fields) { f := fields[wz.FieldIndex] - if f.set != nil && (f.kind == fieldText || f.kind == fieldNumber) { - getEdit := f.get - if f.getRaw != nil { - getEdit = f.getRaw + if f.kind == fieldText || f.kind == fieldNumber { + var cmd tea.Cmd + wz.input, cmd = wz.input.Update(msg) + if f.set != nil { + f.set(wz, wz.input.Value()) } - f.set(wz, getEdit(wz)+string(msg.Runes)) + return wz, cmd, nav } } } @@ -711,7 +763,6 @@ func renderWizardField(st Styles, f fieldDef, wz *WizardState, active bool, maxW } // Width(fieldW) 是内容区宽度,padding 在其外侧叠加,文字区即为 fieldW - // 激活时保留 1 列给光标 █,非激活可用满 fieldW if f.kind == fieldEnum || f.kind == fieldBool { if active { valueStr = "‹ " + valueStr + " ›" @@ -719,7 +770,8 @@ func renderWizardField(st Styles, f fieldDef, wz *WizardState, active bool, maxW valueStr = truncate(valueStr, maxInt(4, fieldW)) } else { if active { - valueStr = fitTail(valueStr, maxInt(1, fieldW-1)) + "█" + wz.input.Width = fieldW + valueStr = wz.input.View() } else { valueStr = fitTail(valueStr, maxInt(1, fieldW)) } @@ -730,7 +782,14 @@ func renderWizardField(st Styles, f fieldDef, wz *WizardState, active bool, maxW fieldStyle = st.FieldActive } - renderedValue := fieldStyle.Width(fieldW).Render(valueStyle.Render(valueStr)) + var renderedValue string + if active && (f.kind == fieldText || f.kind == fieldNumber) { + // textinput 自带光标和滚动,设置文本样式后直接渲染,不再二次包裹 valueStyle + wz.input.TextStyle = valueStyle + renderedValue = fieldStyle.Width(fieldW).Render(wz.input.View()) + } else { + renderedValue = fieldStyle.Width(fieldW).Render(valueStyle.Render(valueStr)) + } labelLines := []string{ strings.Repeat(" ", 15), lipgloss.NewStyle().Width(15).Render(st.Label.Render(wizardFieldLabel(f, wz))), From 9fdb75ad242f0f0cfe792c2af33e1b8285e4f6e7 Mon Sep 17 00:00:00 2001 From: Alain Date: Sat, 23 May 2026 00:09:35 +0800 Subject: [PATCH 31/52] refactor: migrate to lipgloss v2 and improve layout handling - Updated import paths from "github.com/charmbracelet/lipgloss" to "charm.land/lipgloss/v2" across multiple files. - Refactored progress line and request list rendering to use lipgloss.JoinHorizontal for better layout management. - Introduced tableCol helper function for consistent table cell rendering with automatic width handling. - Adjusted panel and field rendering to account for new width calculations in lipgloss v2. - Enhanced height calculations in layout functions to ensure correct rendering of components. - Added tests for height correctness in various components, including ReqDetail, Dashboard, and TaskList. --- cmd/ait/ait.go | 81 +++++++++++++++++------------- go.mod | 34 ++++++------- go.sum | 74 +++++++++++---------------- internal/tui/pages/dashboard.go | 37 +++++++++----- internal/tui/pages/helpers.go | 31 +++++++++--- internal/tui/pages/layout.go | 4 +- internal/tui/pages/layout_test.go | 11 ++-- internal/tui/pages/proxy.go | 6 +-- internal/tui/pages/reqdetail.go | 2 +- internal/tui/pages/styles.go | 4 +- internal/tui/pages/taskdetail.go | 2 +- internal/tui/pages/tasklist.go | 48 +++++++++--------- internal/tui/pages/turbodash.go | 37 +++++++++----- internal/tui/pages/verify_test.go | 83 +++++++++++++++++++++++++++++++ internal/tui/pages/wizard.go | 16 +++--- 15 files changed, 298 insertions(+), 172 deletions(-) create mode 100644 internal/tui/pages/verify_test.go diff --git a/cmd/ait/ait.go b/cmd/ait/ait.go index bcecb93..587fb00 100644 --- a/cmd/ait/ait.go +++ b/cmd/ait/ait.go @@ -24,7 +24,8 @@ func main() { versionFlag := flag.Bool("version", false, "显示版本信息") baseURL := flag.String("baseUrl", "", "服务基础地址(可选,留空使用协议默认地址)") apiKey := flag.String("apiKey", "", "API 密钥") - model := flag.String("model", "", "模型名称") + model := flag.String("model", "", "模型名称(单个模型,与 -models 二选一)") + models := flag.String("models", "", "模型名称列表,逗号分隔,可批量创建任务(如 gpt-4,gpt-4o,gpt-3.5-turbo)") protocol := flag.String("protocol", "", "协议类型: openai / anthropic") promptText := flag.String("prompt", "", "Prompt 文本(可选)") promptFile := flag.String("prompt-file", "", "从文件读取 Prompt") @@ -53,47 +54,59 @@ func main() { } // ── 若提供了足够参数则预建任务并自动启动 ──────────────────────────────────── - if *model != "" { + // 合并 --model 和 --models,去重,保持顺序 + var modelList []string + seen := map[string]bool{} + for _, m := range append(strings.Split(*models, ","), *model) { + m = strings.TrimSpace(m) + if m != "" && !seen[m] { + seen[m] = true + modelList = append(modelList, m) + } + } + + if len(modelList) > 0 { finalProtocol, finalBaseURL, finalAPIKey := resolveConfig(*protocol, *baseURL, *apiKey) if finalAPIKey == "" { fmt.Fprintln(os.Stderr, "错误: 缺少 API Key(-apiKey 或环境变量)") os.Exit(1) } - inp := types.Input{ - Protocol: finalProtocol, - BaseUrl: finalBaseURL, - ApiKey: finalAPIKey, - Model: *model, - Stream: *stream, - Thinking: *thinking, - Concurrency: *concurrency, - Count: *count, - Turbo: *turboFlag, - Timeout: time.Duration(*timeout) * time.Second, - } + for _, m := range modelList { + inp := types.Input{ + Protocol: finalProtocol, + BaseUrl: finalBaseURL, + ApiKey: finalAPIKey, + Model: m, + Stream: *stream, + Thinking: *thinking, + Concurrency: *concurrency, + Count: *count, + Turbo: *turboFlag, + Timeout: time.Duration(*timeout) * time.Second, + } - // Prompt 配置 - switch { - case *promptLen > 0: - inp.PromptMode = "generated" - inp.PromptLength = *promptLen - case *promptFile != "": - inp.PromptMode = "file" - inp.PromptFile = *promptFile - case *promptText != "": - inp.PromptMode = "text" - inp.PromptText = *promptText - default: - inp.PromptMode = "text" - inp.PromptText = "你好,介绍一下你自己。" - } + // Prompt 配置 + switch { + case *promptLen > 0: + inp.PromptMode = "generated" + inp.PromptLength = *promptLen + case *promptFile != "": + inp.PromptMode = "file" + inp.PromptFile = *promptFile + case *promptText != "": + inp.PromptMode = "text" + inp.PromptText = *promptText + default: + inp.PromptMode = "text" + inp.PromptText = "你好,介绍一下你自己。" + } - taskName := fmt.Sprintf("%s@%s", *model, strings.TrimRight(finalBaseURL, "/")) - _, err := srv.CreateTask(server.TaskConfig{Name: taskName, Input: inp}) - if err != nil { - fmt.Fprintf(os.Stderr, "创建任务失败: %v\n", err) - os.Exit(1) + taskName := fmt.Sprintf("%s@%s", m, strings.TrimRight(finalBaseURL, "/")) + if _, err := srv.CreateTask(server.TaskConfig{Name: taskName, Input: inp}); err != nil { + fmt.Fprintf(os.Stderr, "创建任务失败 [%s]: %v\n", m, err) + os.Exit(1) + } } } diff --git a/go.mod b/go.mod index 9e6973a..ceaf26a 100644 --- a/go.mod +++ b/go.mod @@ -1,38 +1,36 @@ module github.com/yinxulai/ait -go 1.22 +go 1.25.0 require ( + charm.land/lipgloss/v2 v2.0.3 github.com/charmbracelet/bubbles v0.20.0 github.com/charmbracelet/bubbletea v1.2.1 - github.com/charmbracelet/lipgloss v1.0.0 - github.com/olekukonko/tablewriter v1.0.9 - github.com/schollz/progressbar/v3 v3.18.0 ) require ( github.com/atotto/clipboard v0.1.4 // indirect github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect - github.com/charmbracelet/x/ansi v0.8.0 // indirect - github.com/charmbracelet/x/term v0.2.0 // indirect - github.com/clipperhouse/stringish v0.1.1 // indirect - github.com/clipperhouse/uax29/v2 v2.5.0 // indirect + github.com/charmbracelet/colorprofile v0.4.3 // indirect + github.com/charmbracelet/lipgloss v1.0.0 // indirect + github.com/charmbracelet/ultraviolet v0.0.0-20251205161215-1948445e3318 // indirect + github.com/charmbracelet/x/ansi v0.11.7 // indirect + github.com/charmbracelet/x/term v0.2.2 // indirect + github.com/charmbracelet/x/termios v0.1.1 // indirect + github.com/charmbracelet/x/windows v0.2.2 // indirect + github.com/clipperhouse/displaywidth v0.11.0 // indirect + github.com/clipperhouse/uax29/v2 v2.7.0 // indirect github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect - github.com/fatih/color v1.18.0 // indirect - github.com/lucasb-eyer/go-colorful v1.3.0 // indirect - github.com/mattn/go-colorable v0.1.13 // indirect + github.com/lucasb-eyer/go-colorful v1.4.0 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-localereader v0.0.1 // indirect - github.com/mattn/go-runewidth v0.0.19 // indirect - github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db // indirect + github.com/mattn/go-runewidth v0.0.23 // indirect github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect github.com/muesli/cancelreader v0.2.2 // indirect github.com/muesli/termenv v0.15.2 // indirect - github.com/olekukonko/errors v1.1.0 // indirect - github.com/olekukonko/ll v0.0.9 // indirect github.com/rivo/uniseg v0.4.7 // indirect - golang.org/x/sync v0.9.0 // indirect - golang.org/x/sys v0.29.0 // indirect - golang.org/x/term v0.28.0 // indirect + github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect + golang.org/x/sync v0.18.0 // indirect + golang.org/x/sys v0.43.0 // indirect golang.org/x/text v0.3.8 // indirect ) diff --git a/go.sum b/go.sum index 35c3f09..95ee307 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +charm.land/lipgloss/v2 v2.0.3 h1:yM2zJ4Cf5Y51b7RHIwioil4ApI/aypFXXVHSwlM6RzU= +charm.land/lipgloss/v2 v2.0.3/go.mod h1:7myLU9iG/3xluAWzpY/fSxYYHCgoKTie7laxk6ATwXA= github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4= github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI= github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k= @@ -6,67 +8,51 @@ github.com/charmbracelet/bubbles v0.20.0 h1:jSZu6qD8cRQ6k9OMfR1WlM+ruM8fkPWkHvQW github.com/charmbracelet/bubbles v0.20.0/go.mod h1:39slydyswPy+uVOHZ5x/GjwVAFkCsV8IIVy+4MhzwwU= github.com/charmbracelet/bubbletea v1.2.1 h1:J041h57zculJKEKf/O2pS4edXGIz+V0YvojvfGXePIk= github.com/charmbracelet/bubbletea v1.2.1/go.mod h1:viLoDL7hG4njLJSKU2gw7kB3LSEmWsrM80rO1dBJWBI= +github.com/charmbracelet/colorprofile v0.4.3 h1:QPa1IWkYI+AOB+fE+mg/5/4HRMZcaXex9t5KX76i20Q= +github.com/charmbracelet/colorprofile v0.4.3/go.mod h1:/zT4BhpD5aGFpqQQqw7a+VtHCzu+zrQtt1zhMt9mR4Q= github.com/charmbracelet/lipgloss v1.0.0 h1:O7VkGDvqEdGi93X+DeqsQ7PKHDgtQfF8j8/O2qFMQNg= github.com/charmbracelet/lipgloss v1.0.0/go.mod h1:U5fy9Z+C38obMs+T+tJqst9VGzlOYGj4ri9reL3qUlo= -github.com/charmbracelet/x/ansi v0.8.0 h1:9GTq3xq9caJW8ZrBTe0LIe2fvfLR/bYXKTx2llXn7xE= -github.com/charmbracelet/x/ansi v0.8.0/go.mod h1:wdYl/ONOLHLIVmQaxbIYEC/cRKOQyjTkowiI4blgS9Q= -github.com/charmbracelet/x/term v0.2.0 h1:cNB9Ot9q8I711MyZ7myUR5HFWL/lc3OpU8jZ4hwm0x0= -github.com/charmbracelet/x/term v0.2.0/go.mod h1:GVxgxAbjUrmpvIINHIQnJJKpMlHiZ4cktEQCN6GWyF0= -github.com/chengxilo/virtualterm v1.0.4 h1:Z6IpERbRVlfB8WkOmtbHiDbBANU7cimRIof7mk9/PwM= -github.com/chengxilo/virtualterm v1.0.4/go.mod h1:DyxxBZz/x1iqJjFxTFcr6/x+jSpqN0iwWCOK1q10rlY= -github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs= -github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA= -github.com/clipperhouse/uax29/v2 v2.5.0 h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w94cO8U= -github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/charmbracelet/ultraviolet v0.0.0-20251205161215-1948445e3318 h1:OqDqxQZliC7C8adA7KjelW3OjtAxREfeHkNcd66wpeI= +github.com/charmbracelet/ultraviolet v0.0.0-20251205161215-1948445e3318/go.mod h1:Y6kE2GzHfkyQQVCSL9r2hwokSrIlHGzZG+71+wDYSZI= +github.com/charmbracelet/x/ansi v0.11.7 h1:kzv1kJvjg2S3r9KHo8hDdHFQLEqn4RBCb39dAYC84jI= +github.com/charmbracelet/x/ansi v0.11.7/go.mod h1:9qGpnAVYz+8ACONkZBUWPtL7lulP9No6p1epAihUZwQ= +github.com/charmbracelet/x/term v0.2.2 h1:xVRT/S2ZcKdhhOuSP4t5cLi5o+JxklsoEObBSgfgZRk= +github.com/charmbracelet/x/term v0.2.2/go.mod h1:kF8CY5RddLWrsgVwpw4kAa6TESp6EB5y3uxGLeCqzAI= +github.com/charmbracelet/x/termios v0.1.1 h1:o3Q2bT8eqzGnGPOYheoYS8eEleT5ZVNYNy8JawjaNZY= +github.com/charmbracelet/x/termios v0.1.1/go.mod h1:rB7fnv1TgOPOyyKRJ9o+AsTU/vK5WHJ2ivHeut/Pcwo= +github.com/charmbracelet/x/windows v0.2.2 h1:IofanmuvaxnKHuV04sC0eBy/smG6kIKrWG2/jYn2GuM= +github.com/charmbracelet/x/windows v0.2.2/go.mod h1:/8XtdKZzedat74NQFn0NGlGL4soHB0YQZrETF96h75k= +github.com/clipperhouse/displaywidth v0.11.0 h1:lBc6kY44VFw+TDx4I8opi/EtL9m20WSEFgwIwO+UVM8= +github.com/clipperhouse/displaywidth v0.11.0/go.mod h1:bkrFNkf81G8HyVqmKGxsPufD3JhNl3dSqnGhOoSD/o0= +github.com/clipperhouse/uax29/v2 v2.7.0 h1:+gs4oBZ2gPfVrKPthwbMzWZDaAFPGYK72F0NJv2v7Vk= +github.com/clipperhouse/uax29/v2 v2.7.0/go.mod h1:EFJ2TJMRUaplDxHKj1qAEhCtQPW2tJSwu5BF98AuoVM= github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4= github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM= -github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= -github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= -github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag= -github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= -github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= -github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= -github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/lucasb-eyer/go-colorful v1.4.0 h1:UtrWVfLdarDgc44HcS7pYloGHJUjHV/4FwW4TvVgFr4= +github.com/lucasb-eyer/go-colorful v1.4.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4= github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88= -github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw= -github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs= -github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db h1:62I3jR2EmQ4l5rM/4FEfDWcRD+abF5XlKShorW5LRoQ= -github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db/go.mod h1:l0dey0ia/Uv7NcFFVbCLtqEBQbrT4OCwCSKTEv6enCw= +github.com/mattn/go-runewidth v0.0.23 h1:7ykA0T0jkPpzSvMS5i9uoNn2Xy3R383f9HDx3RybWcw= +github.com/mattn/go-runewidth v0.0.23/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs= github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI= github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo= github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA= github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo= github.com/muesli/termenv v0.15.2 h1:GohcuySI0QmI3wN8Ok9PtKGkgkFIk7y6Vpb5PvrY+Wo= github.com/muesli/termenv v0.15.2/go.mod h1:Epx+iuz8sNs7mNKhxzH4fWXGNpZwUaJKRS1noLXviQ8= -github.com/olekukonko/errors v1.1.0 h1:RNuGIh15QdDenh+hNvKrJkmxxjV4hcS50Db478Ou5sM= -github.com/olekukonko/errors v1.1.0/go.mod h1:ppzxA5jBKcO1vIpCXQ9ZqgDh8iwODz6OXIGKU8r5m4Y= -github.com/olekukonko/ll v0.0.9 h1:Y+1YqDfVkqMWuEQMclsF9HUR5+a82+dxJuL1HHSRpxI= -github.com/olekukonko/ll v0.0.9/go.mod h1:En+sEW0JNETl26+K8eZ6/W4UQ7CYSrrgg/EdIYT2H8g= -github.com/olekukonko/tablewriter v1.0.9 h1:XGwRsYLC2bY7bNd93Dk51bcPZksWZmLYuaTHR0FqfL8= -github.com/olekukonko/tablewriter v1.0.9/go.mod h1:5c+EBPeSqvXnLLgkm9isDdzR3wjfBkHR9Nhfp3NWrzo= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= -github.com/schollz/progressbar/v3 v3.18.0 h1:uXdoHABRFmNIjUfte/Ex7WtuyVslrw2wVPQmCN62HpA= -github.com/schollz/progressbar/v3 v3.18.0/go.mod h1:IsO3lpbaGuzh8zIMzgY3+J8l4C8GjO0Y9S69eFvNsec= -github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= -github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -golang.org/x/sync v0.9.0 h1:fEo0HyrW1GIgZdpbhCRO0PkJajUS5H9IFUztCgEo2jQ= -golang.org/x/sync v0.9.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= +github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= +golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI= +golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo= +golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I= +golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= -golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/term v0.28.0 h1:/Ts8HFuMR2E6IP/jlo7QVLZHggjKQbhu/7H0LJFr3Gg= -golang.org/x/term v0.28.0/go.mod h1:Sw/lC2IAUZ92udQNf3WodGtn4k/XoLyZoh8v/8uiwek= +golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI= +golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/text v0.3.8 h1:nAL+RVCQ9uMn3vJZbV+MRnydTJFPf8qqY42YiA6MrqY= golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= -gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= -gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/tui/pages/dashboard.go b/internal/tui/pages/dashboard.go index 29d23d5..26605ac 100644 --- a/internal/tui/pages/dashboard.go +++ b/internal/tui/pages/dashboard.go @@ -6,7 +6,7 @@ import ( "time" tea "github.com/charmbracelet/bubbletea" - "github.com/charmbracelet/lipgloss" + "charm.land/lipgloss/v2" "github.com/yinxulai/ait/internal/server" "github.com/yinxulai/ait/internal/types" ) @@ -293,13 +293,16 @@ func buildProgressLine(rs *server.RunState, st Styles, width int) string { barW := width - lipgloss.Width(prefix) - lipgloss.Width(suffix) if barW < 5 { barW = 5 + // 压缩 suffix 确保进度行总宽度不超过 width,防止 lipgloss 折行 + maxSuffixW := maxInt(0, width-lipgloss.Width(prefix)-barW) + suffix = truncate(suffix, maxSuffixW) } filled := int(ratio * float64(barW)) barRendered := st.Ok.Render(strings.Repeat("█", filled)) + st.Muted.Render(strings.Repeat("░", barW-filled)) - return prefix + barRendered + suffix + return lipgloss.JoinHorizontal(lipgloss.Top, prefix, barRendered, suffix) } // buildRequestList 构建请求列表区域。 @@ -326,8 +329,16 @@ func buildRequestList(d *DashboardState, rs *server.RunState, st Styles, width, tokW = 10 // Token // TPS: 余量 ) - hdr := padRight("", markW) + padRight("#", idW) + padRight("状态", statW) + padRight("总耗时", timeW) + - padRight("TTFT", ttftW) + padRight("Cache", cacheW) + padRight("Token", tokW) + "TPS" + hdr := lipgloss.JoinHorizontal(lipgloss.Top, + tableCol(markW, ""), + tableCol(idW, "#"), + tableCol(statW, "状态"), + tableCol(timeW, "总耗时"), + tableCol(ttftW, "TTFT"), + tableCol(cacheW, "Cache"), + tableCol(tokW, "Token"), + "TPS", + ) lines = append(lines, renderTableHeader(st, width, hdr)) lines = append(lines, dividerLine(st, width)) d.ReqVis = listVisibleItems(maxH, 3) @@ -363,14 +374,16 @@ func buildRequestList(d *DashboardState, rs *server.RunState, st Styles, width, marker := selectionMarker(isSel) - rowContent := padRight(marker, markW) + - padRight(fmt.Sprintf("#%d", len(reqs)-pos), idW) + - padRight(statusStr, statW) + - padRight(totalStr, timeW) + - padRight(fmtDuration(r.TTFT), ttftW) + - padRight(fmt.Sprintf("%.0f%%", r.CacheHitRate*100), cacheW) + - padRight(fmt.Sprintf("%dtok", r.CompletionTokens), tokW) + - fmt.Sprintf("%.1f/s", r.TPS) + rowContent := lipgloss.JoinHorizontal(lipgloss.Top, + tableCol(markW, marker), + tableCol(idW, fmt.Sprintf("#%d", len(reqs)-pos)), + tableCol(statW, statusStr), + tableCol(timeW, totalStr), + tableCol(ttftW, fmtDuration(r.TTFT)), + tableCol(cacheW, fmt.Sprintf("%.0f%%", r.CacheHitRate*100)), + tableCol(tokW, fmt.Sprintf("%dtok", r.CompletionTokens)), + fmt.Sprintf("%.1f/s", r.TPS), + ) rendered := renderTableRow(st, width, isSel, rowContent) lines = append(lines, rendered) diff --git a/internal/tui/pages/helpers.go b/internal/tui/pages/helpers.go index 7b8ef82..6499bfc 100644 --- a/internal/tui/pages/helpers.go +++ b/internal/tui/pages/helpers.go @@ -5,7 +5,7 @@ import ( "strings" "time" - "github.com/charmbracelet/lipgloss" + "charm.land/lipgloss/v2" ) // ─── 文本工具 ───────────────────────────────────────────────────────────────── @@ -41,7 +41,12 @@ func padRight(s string, width int) string { return s + strings.Repeat(" ", width-w) } -// wrapText 将文本按 maxW 宽度折行,返回行切片。 +// tableCol 返回固定宽度的表格单元格(自动填充空格并截断,始终单行)。 +func tableCol(w int, text string) string { + return lipgloss.NewStyle().Width(w).Render(truncate(text, w)) +} + +// wrapText 将文本按 maxW 列宽折行,返回行切片(CJK 字符按 2 列宽计算)。 func wrapText(s string, maxW int) []string { if maxW <= 0 { return []string{s} @@ -54,9 +59,19 @@ func wrapText(s string, maxW int) []string { continue } for len(runes) > 0 { - end := maxW - if end > len(runes) { - end = len(runes) + colW := 0 + end := 0 + for end < len(runes) { + rw := lipgloss.Width(string(runes[end])) + if colW+rw > maxW { + break + } + colW += rw + end++ + } + if end == 0 { + // 单个字符超宽(如极窄终端)——强制取一个避免死循环 + end = 1 } result = append(result, string(runes[:end])) runes = runes[end:] @@ -451,7 +466,8 @@ func runStatusText(status string) string { func panelTitleLines(st Styles, title string, width int, compact bool) []string { var rendered string if width > 0 { - rendered = st.PanelHead.Width(width).Render(" " + title) + // 截断标题防止超宽后被 lipgloss 折行 + rendered = st.PanelHead.Width(width).Render(" " + truncate(title, maxInt(1, width-1))) } else { rendered = st.PanelHead.Render(" " + title) } @@ -638,9 +654,10 @@ func labelValue(st Styles, label, value string) string { } // wrapPanel 用带边框的 Panel 包裹内容,outerW 为包含边框的总宽度。 +// lipgloss v2: Width(n) = 外部总宽度(含 border),Panel 有 1 字符宽边框。 func wrapPanel(st Styles, content string, outerW int) string { if outerW < 4 { return content } - return st.Panel.Width(outerW - 2).Render(content) + return st.Panel.Width(outerW).Render(content) } diff --git a/internal/tui/pages/layout.go b/internal/tui/pages/layout.go index 3cad1aa..e29504c 100644 --- a/internal/tui/pages/layout.go +++ b/internal/tui/pages/layout.go @@ -160,9 +160,9 @@ func PanelContentHeight(outerH int) int { } // RemainingStackOuterHeight 计算纵向堆叠场景下,最后一个区块可用的外层高度。 -// 会统一扣除前置区块自身高度,以及区块之间的换行间隔,避免各页重复手写偏移逻辑。 +// 会统一扣除前置区块自身高度。各块通过 strings.Join 拼接,总行数 = 各块行数之和,无需额外扣减分隔符。 func RemainingStackOuterHeight(totalH int, fixedOuterHeights ...int) int { - remaining := totalH - len(fixedOuterHeights) + remaining := totalH for _, h := range fixedOuterHeights { remaining -= h } diff --git a/internal/tui/pages/layout_test.go b/internal/tui/pages/layout_test.go index 83601b7..34f2253 100644 --- a/internal/tui/pages/layout_test.go +++ b/internal/tui/pages/layout_test.go @@ -50,16 +50,17 @@ func TestPageLayoutFrameCalculatesNestedPanelSizes(t *testing.T) { func TestRemainingStackOuterHeightAccountsForJoinGaps(t *testing.T) { totalHeight := 24 remaining := RemainingStackOuterHeight(totalHeight, 9, 3) - if remaining != 10 { - t.Fatalf("expected remaining outer height 10, got %d", remaining) + if remaining != 12 { + t.Fatalf("expected remaining outer height 12, got %d", remaining) } - used := 9 + 1 + 3 + 1 + remaining + // strings.Join 拼接时总行数 = 各块行数之和,分隔符 \n 不额外增加行数 + used := 9 + 3 + remaining if used != totalHeight { t.Fatalf("expected stacked blocks to fit exactly, used %d of %d", used, totalHeight) } - if PanelContentHeight(remaining) != 8 { - t.Fatalf("expected remaining content height 8, got %d", PanelContentHeight(remaining)) + if PanelContentHeight(remaining) != 10 { + t.Fatalf("expected remaining content height 10, got %d", PanelContentHeight(remaining)) } } diff --git a/internal/tui/pages/proxy.go b/internal/tui/pages/proxy.go index 729a7a8..62ff30a 100644 --- a/internal/tui/pages/proxy.go +++ b/internal/tui/pages/proxy.go @@ -6,7 +6,7 @@ import ( "github.com/charmbracelet/bubbles/cursor" "github.com/charmbracelet/bubbles/textinput" tea "github.com/charmbracelet/bubbletea" - "github.com/charmbracelet/lipgloss" + "charm.land/lipgloss/v2" ) // ProxyConfigState 代理配置页面状态。 @@ -82,10 +82,10 @@ func buildProxyConfigContent(s *ProxyConfigState, st Styles, contentW, maxH int) lines = append(lines, "") // 字段宽度(与 wizard renderWizardField 保持一致) + // lipgloss v2: Width(fieldW+4) 使内容区 = fieldW,与 input.Width 对齐 fieldW := maxInt(10, contentW-19) s.input.Width = fieldW - s.input.TextStyle = st.Value - renderedField := st.FieldActive.Width(fieldW).Render(s.input.View()) + renderedField := st.FieldActive.Width(fieldW + 4).Render(s.input.View()) labelBlock := strings.Join([]string{ strings.Repeat(" ", 15), diff --git a/internal/tui/pages/reqdetail.go b/internal/tui/pages/reqdetail.go index cb92203..bfabbbc 100644 --- a/internal/tui/pages/reqdetail.go +++ b/internal/tui/pages/reqdetail.go @@ -219,7 +219,7 @@ func buildReqNetworkPanel(r *types.RequestMetrics, st Styles, maxH, width int) s lines = append(lines, " "+labelValue(st, "TCP 连接 ", fmtDuration(r.ConnectTime))) lines = append(lines, " "+labelValue(st, "TLS 握手 ", fmtDuration(r.TLSTime))) if r.TargetIP != "" { - lines = append(lines, " "+labelValue(st, "目标 IP ", r.TargetIP)) + lines = append(lines, " "+labelValue(st, "目标 IP ", truncate(r.TargetIP, maxInt(4, width-12)))) } return finishPanelLines(lines, maxH) diff --git a/internal/tui/pages/styles.go b/internal/tui/pages/styles.go index 6551e71..736b3cb 100644 --- a/internal/tui/pages/styles.go +++ b/internal/tui/pages/styles.go @@ -1,9 +1,9 @@ package pages -import "github.com/charmbracelet/lipgloss" +import "charm.land/lipgloss/v2" // Color palette -const ( +var ( colorHeaderBg = lipgloss.Color("17") // dark navy — refined header background colorHotkeysSecondaryBg = lipgloss.Color("235") // near-black secondary hotkeys background colorHotkeysPrimaryBg = lipgloss.Color("237") // slightly lighter primary hotkeys background diff --git a/internal/tui/pages/taskdetail.go b/internal/tui/pages/taskdetail.go index 8282683..acaa23a 100644 --- a/internal/tui/pages/taskdetail.go +++ b/internal/tui/pages/taskdetail.go @@ -5,7 +5,7 @@ import ( "strings" "time" - "github.com/charmbracelet/lipgloss" + "charm.land/lipgloss/v2" tea "github.com/charmbracelet/bubbletea" "github.com/yinxulai/ait/internal/server" "github.com/yinxulai/ait/internal/types" diff --git a/internal/tui/pages/tasklist.go b/internal/tui/pages/tasklist.go index 78d0986..9d5f430 100644 --- a/internal/tui/pages/tasklist.go +++ b/internal/tui/pages/tasklist.go @@ -6,7 +6,7 @@ import ( "time" tea "github.com/charmbracelet/bubbletea" - "github.com/charmbracelet/lipgloss" + "charm.land/lipgloss/v2" "github.com/yinxulai/ait/internal/server" "github.com/yinxulai/ait/internal/types" ) @@ -182,12 +182,15 @@ func buildTaskListContent(s *TaskListState, st Styles, width, maxH int) string { // 表头:2 空格前缀与正文行对齐(cursor=2) header := renderTableHeader(st, width, - " "+padRight("任务名称", nameW)+ - padRight("模式", modeW)+ - padRight("协议", protoW)+ - padRight("上次运行", lastRunW)+ - padRight("TTFT", ttftW)+ - "TPS") + lipgloss.JoinHorizontal(lipgloss.Top, + tableCol(2, ""), + tableCol(nameW, "任务名称"), + tableCol(modeW, "模式"), + tableCol(protoW, "协议"), + tableCol(lastRunW, "上次运行"), + tableCol(ttftW, "TTFT"), + "TPS", + )) lines = append(lines, header) lines = append(lines, dividerLine(st, width)) listMaxH := maxInt(3, maxH-listTopLines) @@ -197,6 +200,10 @@ func buildTaskListContent(s *TaskListState, st Styles, width, maxH int) string { if len(s.Tasks) == 0 { lines = append(lines, "") lines = append(lines, " "+st.Muted.Render("暂无任务 按 [a] 新建第一个任务")) + // 补齐剩余行 + for len(lines) < maxH { + lines = append(lines, "") + } return strings.Join(lines, "\n") } @@ -210,28 +217,23 @@ func buildTaskListContent(s *TaskListState, st Styles, width, maxH int) string { _, hasActiveRun := s.ActiveRuns[t.ID] // ── 指示符 ── - prefix := padRight(selectionMarker(isSel), 2) + prefix := tableCol(2, selectionMarker(isSel)) // ── 模式(选中行禁用嵌套样式,避免重置整行背景)── modeText := "标准" - modeCol := padRight(modeText, modeW) + var modeCol string if t.Input.Turbo { modeText = "Turbo" - modeCol = padRight(styleWhenNotSelected(isSel, lipgloss.NewStyle().Foreground(colorGold).Bold(true), modeText), modeW) + modeCol = tableCol(modeW, styleWhenNotSelected(isSel, lipgloss.NewStyle().Foreground(colorGold).Bold(true), modeText)) } else { - modeCol = padRight(styleWhenNotSelected(isSel, lipgloss.NewStyle().Foreground(colorPurple), modeText), modeW) + modeCol = tableCol(modeW, styleWhenNotSelected(isSel, lipgloss.NewStyle().Foreground(colorPurple), modeText)) } // ── 协议 ── - proto := padRight(shortProtocol(t.Input.NormalizedProtocol()), protoW) + proto := tableCol(protoW, shortProtocol(t.Input.NormalizedProtocol())) - // ── 任务名称(裁剪)── - name := truncate(t.Name, nameW) - namePad := nameW - lipgloss.Width(name) - if namePad < 0 { - namePad = 0 - } - nameCol := name + strings.Repeat(" ", namePad) + // ── 任务名称 ── + nameCol := tableCol(nameW, t.Name) // ── 上次运行时间 ── lastRunText := "─" @@ -244,7 +246,7 @@ func buildTaskListContent(s *TaskListState, st Styles, width, maxH int) string { if hasActiveRun || (t.LatestRun != nil && t.LatestRun.Status == string(server.RunStatusRunning)) { lastRunStyle = st.Ok } - lastRunCol := padRight(styleWhenNotSelected(isSel, lastRunStyle, lastRunText), lastRunW) + lastRunCol := tableCol(lastRunW, styleWhenNotSelected(isSel, lastRunStyle, lastRunText)) // ── TTFT ── ttftText := "─" @@ -253,7 +255,7 @@ func buildTaskListContent(s *TaskListState, st Styles, width, maxH int) string { } else if !hasActiveRun && t.LatestRun != nil { ttftText = fmtDuration(t.LatestRun.AvgTTFT) } - ttftCol := padRight(styleWhenNotSelected(isSel, st.Value, ttftText), ttftW) + ttftCol := tableCol(ttftW, styleWhenNotSelected(isSel, st.Value, ttftText)) // ── TPS ── tpsText := "─" @@ -269,8 +271,8 @@ func buildTaskListContent(s *TaskListState, st Styles, width, maxH int) string { tpsCol := styleWhenNotSelected(isSel, st.Value, tpsText) // ── 单行:名称 | 模式 | 协议 | 上次运行 | TTFT | TPS ── - rowContent := nameCol + modeCol + proto + lastRunCol + ttftCol + tpsCol - lines = append(lines, renderTableRow(st, width, isSel, prefix+rowContent)) + lines = append(lines, renderTableRow(st, width, isSel, lipgloss.JoinHorizontal(lipgloss.Top, + prefix, nameCol, modeCol, proto, lastRunCol, ttftCol, tpsCol))) // ── 分隔线 ── if i < end-1 && len(lines) < maxH-1 { diff --git a/internal/tui/pages/turbodash.go b/internal/tui/pages/turbodash.go index d5e81af..91b356f 100644 --- a/internal/tui/pages/turbodash.go +++ b/internal/tui/pages/turbodash.go @@ -5,7 +5,7 @@ import ( "strings" tea "github.com/charmbracelet/bubbletea" - "github.com/charmbracelet/lipgloss" + "charm.land/lipgloss/v2" "github.com/yinxulai/ait/internal/server" "github.com/yinxulai/ait/internal/types" ) @@ -294,13 +294,16 @@ func buildTurboProgressLine(rs *server.RunState, st Styles, width int) string { barW := width - lipgloss.Width(prefix) - lipgloss.Width(suffix) if barW < 5 { barW = 5 + // 压缩 suffix 确保进度行总宽度不超过 width,防止 lipgloss 折行 + maxSuffixW := maxInt(0, width-lipgloss.Width(prefix)-barW) + suffix = truncate(suffix, maxSuffixW) } filled := int(ratio * float64(barW)) barRendered := st.Ok.Render(strings.Repeat("█", filled)) + st.Muted.Render(strings.Repeat("░", barW-filled)) - return prefix + barRendered + suffix + return lipgloss.JoinHorizontal(lipgloss.Top, prefix, barRendered, suffix) } // buildTurboRequestList 构建 Turbo 模式请求列表区域。 @@ -325,8 +328,16 @@ func buildTurboRequestList(d *TurboDashState, rs *server.RunState, st Styles, wi cacheW = 8 tokW = 10 ) - hdr := padRight("", markW) + padRight("#", idW) + padRight("状态", statW) + padRight("总耗时", timeW) + - padRight("TTFT", ttftW) + padRight("Cache", cacheW) + padRight("Token", tokW) + "TPS" + hdr := lipgloss.JoinHorizontal(lipgloss.Top, + tableCol(markW, ""), + tableCol(idW, "#"), + tableCol(statW, "状态"), + tableCol(timeW, "总耗时"), + tableCol(ttftW, "TTFT"), + tableCol(cacheW, "Cache"), + tableCol(tokW, "Token"), + "TPS", + ) lines = append(lines, renderTableHeader(st, width, hdr)) lines = append(lines, dividerLine(st, width)) d.ReqVis = listVisibleItems(maxH, 3) @@ -361,14 +372,16 @@ func buildTurboRequestList(d *TurboDashState, rs *server.RunState, st Styles, wi } marker := selectionMarker(isSel) - rowContent := padRight(marker, markW) + - padRight(fmt.Sprintf("#%d", len(reqs)-pos), idW) + - padRight(statusStr, statW) + - padRight(totalStr, timeW) + - padRight(fmtDuration(r.TTFT), ttftW) + - padRight(fmt.Sprintf("%.0f%%", r.CacheHitRate*100), cacheW) + - padRight(fmt.Sprintf("%dtok", r.CompletionTokens), tokW) + - fmt.Sprintf("%.1f/s", r.TPS) + rowContent := lipgloss.JoinHorizontal(lipgloss.Top, + tableCol(markW, marker), + tableCol(idW, fmt.Sprintf("#%d", len(reqs)-pos)), + tableCol(statW, statusStr), + tableCol(timeW, totalStr), + tableCol(ttftW, fmtDuration(r.TTFT)), + tableCol(cacheW, fmt.Sprintf("%.0f%%", r.CacheHitRate*100)), + tableCol(tokW, fmt.Sprintf("%dtok", r.CompletionTokens)), + fmt.Sprintf("%.1f/s", r.TPS), + ) rendered := renderTableRow(st, width, isSel, rowContent) lines = append(lines, rendered) diff --git a/internal/tui/pages/verify_test.go b/internal/tui/pages/verify_test.go new file mode 100644 index 0000000..15b8797 --- /dev/null +++ b/internal/tui/pages/verify_test.go @@ -0,0 +1,83 @@ +package pages + +import ( + "fmt" + "strings" + "testing" + "time" + + "github.com/yinxulai/ait/internal/server" + "github.com/yinxulai/ait/internal/types" +) + +func TestHeightCorrectness(t *testing.T) { + st := NewStyles() + + // ReqDetail + s := &ReqDetailState{ + RunID: server.RunID("test"), + Requests: []*types.RequestMetrics{{ + Success: true, TotalTime: 250 * time.Millisecond, + RequestBody: "hello", ResponseBody: strings.Repeat("ok ", 50), + }}, + } + fmt.Println("--- ReqDetail ---") + for _, h := range []int{24, 26, 30, 40} { + out := RenderReqDetail(s, "task", st, 80, h) + got := strings.Count(out, "\n") + 1 + diff := got - h + marker := "✓" + if diff != 0 { marker = fmt.Sprintf("FAIL diff=%+d", diff) } + fmt.Printf("height=%d → rendered=%d %s\n", h, got, marker) + if diff != 0 { t.Errorf("ReqDetail height=%d: want %d lines, got %d", h, h, got) } + } + + // Dashboard + fmt.Println("--- Dashboard ---") + ds := NewDashboardState("run1", "task1") + for _, h := range []int{22, 26, 30, 40} { + out := RenderDashboard(ds, "task", st, 80, h) + got := strings.Count(out, "\n") + 1 + diff := got - h + marker := "✓" + if diff != 0 { marker = fmt.Sprintf("FAIL diff=%+d", diff) } + fmt.Printf("height=%d → rendered=%d %s\n", h, got, marker) + if diff != 0 { t.Errorf("Dashboard height=%d: want %d lines, got %d", h, h, got) } + } + + // TaskList (empty tasks) + fmt.Println("--- TaskList (empty) ---") + ts := &TaskListState{} + for _, h := range []int{24, 26, 30, 40} { + out := RenderTaskList(ts, st, 80, h) + got := strings.Count(out, "\n") + 1 + diff := got - h + marker := "✓" + if diff != 0 { marker = fmt.Sprintf("FAIL diff=%+d", diff) } + fmt.Printf("height=%d → rendered=%d %s\n", h, got, marker) + if diff != 0 { t.Errorf("TaskList height=%d: want %d lines, got %d", h, h, got) } + } +} + +func TestHeightWithCJKContent(t *testing.T) { + st := NewStyles() + // 模拟真实 LLM 响应:纯中文内容 + cjkBody := strings.Repeat("你好,我是一个大型语言模型,很高兴为你服务。", 10) + s := &ReqDetailState{ + RunID: server.RunID("test"), + Requests: []*types.RequestMetrics{{ + Success: true, TotalTime: 500 * time.Millisecond, + RequestBody: "请介绍一下自己", + ResponseBody: cjkBody, + }}, + } + for _, h := range []int{24, 30, 40} { + out := RenderReqDetail(s, "task", st, 80, h) + got := strings.Count(out, "\n") + 1 + diff := got - h + marker := "✓" + if diff != 0 { marker = fmt.Sprintf("FAIL diff=%+d", diff) } + t.Logf("CJK ReqDetail height=%d → rendered=%d %s", h, got, marker) + if diff != 0 { t.Errorf("CJK content: height=%d want %d lines, got %d", h, h, got) } + } +} diff --git a/internal/tui/pages/wizard.go b/internal/tui/pages/wizard.go index 54afce1..8ce045c 100644 --- a/internal/tui/pages/wizard.go +++ b/internal/tui/pages/wizard.go @@ -9,7 +9,7 @@ import ( "github.com/charmbracelet/bubbles/cursor" "github.com/charmbracelet/bubbles/textinput" tea "github.com/charmbracelet/bubbletea" - "github.com/charmbracelet/lipgloss" + "charm.land/lipgloss/v2" "github.com/yinxulai/ait/internal/server" "github.com/yinxulai/ait/internal/types" ) @@ -752,9 +752,10 @@ func renderWizardField(st Styles, f fieldDef, wz *WizardState, active bool, maxW valueStr = maskAPIKey(valueStr) } - // FieldActive/Idle: Width(n) = 内容区宽度(在 padding/border 之内) - // 总渲染宽度 = n + padding(2) + border(2) = n + 4 - // Line1 = label(14) + space(1) + (n+4) = n + 19 ≤ maxW → n = maxW - 19 + // lipgloss v2: Width(n) = 外部总宽度(含 border+padding) + // 内容区 = n - border(2) - padding(2) = n - 4 + // fieldW 为内容区目标宽度,渲染时传 fieldW+4 作为 Width 参数 + // 总宽 = label(15) + (fieldW+4) = fieldW + 19 ≤ maxW → fieldW = maxW - 19 fieldW := maxInt(10, maxW-19) valueStyle := st.Value if valueStr == "" && !active { @@ -784,11 +785,10 @@ func renderWizardField(st Styles, f fieldDef, wz *WizardState, active bool, maxW var renderedValue string if active && (f.kind == fieldText || f.kind == fieldNumber) { - // textinput 自带光标和滚动,设置文本样式后直接渲染,不再二次包裹 valueStyle - wz.input.TextStyle = valueStyle - renderedValue = fieldStyle.Width(fieldW).Render(wz.input.View()) + // textinput 自带光标和滚动;Width(fieldW+4) 使内容区 = fieldW,与 input.Width 对齐 + renderedValue = fieldStyle.Width(fieldW + 4).Render(wz.input.View()) } else { - renderedValue = fieldStyle.Width(fieldW).Render(valueStyle.Render(valueStr)) + renderedValue = fieldStyle.Width(fieldW + 4).Render(valueStyle.Render(valueStr)) } labelLines := []string{ strings.Repeat(" ", 15), From a1d64c6d04e51ab9800a43a5e7a7702092ecb854 Mon Sep 17 00:00:00 2001 From: Alain Date: Sat, 23 May 2026 09:50:23 +0800 Subject: [PATCH 32/52] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E6=89=B9?= =?UTF-8?q?=E9=87=8F=E5=88=9B=E5=BB=BA=E4=BB=BB=E5=8A=A1=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=E5=8F=8A=E7=9B=B8=E5=85=B3=E7=95=8C=E9=9D=A2=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/tui/client.go | 12 ++ internal/tui/messages.go | 5 + internal/tui/model.go | 19 +- internal/tui/model_test.go | 18 +- internal/tui/pages/contextbar.go | 15 +- internal/tui/pages/nav.go | 3 + internal/tui/pages/proxy.go | 186 ++++++++++++++++++-- internal/tui/pages/tasklist.go | 61 ++++++- internal/tui/pages/wizard.go | 292 ++++++++++++++++++++++++------- 9 files changed, 509 insertions(+), 102 deletions(-) diff --git a/internal/tui/client.go b/internal/tui/client.go index 902e392..e109500 100644 --- a/internal/tui/client.go +++ b/internal/tui/client.go @@ -76,6 +76,18 @@ func (c *Client) CopyTaskCmd(id string) tea.Cmd { } } +// CreateBatchTasksCmd 批量异步创建任务。 +func (c *Client) CreateBatchTasksCmd(cfgs []server.TaskConfig) tea.Cmd { + return func() tea.Msg { + for _, cfg := range cfgs { + if _, err := c.srv.CreateTask(cfg); err != nil { + return ErrorMsg{Err: fmt.Errorf("批量创建任务 %q 失败: %w", cfg.Name, err)} + } + } + return BatchTasksSavedMsg{Count: len(cfgs)} + } +} + // ─── 运行管理 ───────────────────────────────────────────────────────────────── // StartRunCmd 异步启动运行。 diff --git a/internal/tui/messages.go b/internal/tui/messages.go index fee072f..00f4900 100644 --- a/internal/tui/messages.go +++ b/internal/tui/messages.go @@ -21,6 +21,11 @@ type TaskDeletedMsg struct { TaskID string } +// BatchTasksSavedMsg 批量创建任务完成。 +type BatchTasksSavedMsg struct { + Count int +} + // HistoryLoadedMsg 任务历史记录加载完成。 type HistoryLoadedMsg struct { TaskID string diff --git a/internal/tui/model.go b/internal/tui/model.go index afdff2f..48837f1 100644 --- a/internal/tui/model.go +++ b/internal/tui/model.go @@ -39,13 +39,13 @@ type Model struct { err error // 页面局部状态(由 pages 包管理) - taskList *pages.TaskListState - detail *pages.TaskDetailState - wizard *pages.WizardState - dash *pages.DashboardState - turboDash *pages.TurboDashState - reqDetail *pages.ReqDetailState - proxyConf *pages.ProxyConfigState + taskList *pages.TaskListState + detail *pages.TaskDetailState + wizard *pages.WizardState + dash *pages.DashboardState + turboDash *pages.TurboDashState + reqDetail *pages.ReqDetailState + proxyConf *pages.ProxyConfigState } // NewModel 创建 Model。srv 不能为 nil。 @@ -221,6 +221,11 @@ func (m *Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { case ProxyConfigSavedMsg: m.status = "代理配置已保存" return m, nil + + case BatchTasksSavedMsg: + m.status = fmt.Sprintf("已批量创建 %d 个任务", msg.Count) + m.view = viewTaskList + return m, m.client.LoadTasksCmd() } return m, nil diff --git a/internal/tui/model_test.go b/internal/tui/model_test.go index 9c8cc37..2bc04e3 100644 --- a/internal/tui/model_test.go +++ b/internal/tui/model_test.go @@ -95,24 +95,20 @@ func TestOpenWizard_EditTask_Populate(t *testing.T) { if m.wizard.EditingID != "task-123" { t.Errorf("EditingID = %q, want %q", m.wizard.EditingID, "task-123") } - if m.wizard.Model != "gpt-4" { - t.Errorf("Model = %q, want %q", m.wizard.Model, "gpt-4") + if m.wizard.ModelsText != "gpt-4" { + t.Errorf("ModelsText = %q, want %q", m.wizard.ModelsText, "gpt-4") } if m.wizard.Concurrency != 5 { t.Errorf("Concurrency = %d, want 5", m.wizard.Concurrency) } - if m.wizard.ProxyURL != "http://proxy.internal:8080" { - t.Errorf("ProxyURL = %q, want %q", m.wizard.ProxyURL, "http://proxy.internal:8080") - } } func TestBuildTaskInput_Standard(t *testing.T) { m := NewModel(&stubServer{}) m.wizard = pages.NewWizardState() wz := m.wizard - wz.Model = "gpt-4.1" + wz.ModelsText = "gpt-4.1" wz.APIKey = "sk-test" - wz.ProxyURL = "http://proxy.internal:8080" wz.Concurrency = 8 wz.Count = 120 wz.PromptMode = pages.PromptModeText @@ -129,9 +125,6 @@ func TestBuildTaskInput_Standard(t *testing.T) { if inp.Count != 120 { t.Errorf("count = %d, want 120", inp.Count) } - if inp.ProxyURL != "http://proxy.internal:8080" { - t.Errorf("proxy_url = %q, want %q", inp.ProxyURL, "http://proxy.internal:8080") - } if inp.PromptMode != pages.PromptModeText || inp.PromptText != "hello" { t.Errorf("unexpected prompt config: mode=%q text=%q", inp.PromptMode, inp.PromptText) } @@ -144,9 +137,10 @@ func TestBuildTaskInput_Turbo(t *testing.T) { m := NewModel(&stubServer{}) m.wizard = pages.NewWizardState() wz := m.wizard - wz.Model = "claude-3-7-sonnet" + wz.ModelsText = "claude-3-7-sonnet" wz.APIKey = "sk-ant" - wz.Protocol = types.ProtocolAnthropicMessages + wz.SelOpenAICompletions = false + wz.SelAnthropic = true wz.Turbo = true wz.InitConcurrency = 1 wz.MaxConcurrency = 12 diff --git a/internal/tui/pages/contextbar.go b/internal/tui/pages/contextbar.go index 4f4c7f8..72f566a 100644 --- a/internal/tui/pages/contextbar.go +++ b/internal/tui/pages/contextbar.go @@ -32,6 +32,7 @@ func Hotkeys_TaskList_Normal() []HotkeyItem { return []HotkeyItem{ HotkeyAction("Enter", "查看详情"), HotkeyAction("r", "运行"), + HotkeyAction("a", "新建任务"), HotkeyAction("e", "编辑"), HotkeyAction("d", "删除"), HotkeyAction("y", "复制"), @@ -87,7 +88,7 @@ func Hotkeys_TaskDetail_Running() []HotkeyItem { func Hotkeys_Wizard_Step1() []HotkeyItem { return []HotkeyItem{ HotkeyAction("Tab/↑↓", "切换字段"), - HotkeyAction("←→", "切换协议"), + HotkeyAction("Space/←→", "切换选项"), HotkeyAction("Enter", "下一步"), HotkeyAction("Esc", "返回列表"), } @@ -114,6 +115,16 @@ func Hotkeys_Wizard_Step3() []HotkeyItem { } } +// Hotkeys_Wizard_Step3_Batch 批量创建确认页热键。 +func Hotkeys_Wizard_Step3_Batch() []HotkeyItem { + return []HotkeyItem{ + HotkeyAction("↑↓", "滚动"), + HotkeyAction("PgUp/PgDn", "翻页"), + HotkeyAction("Enter", "批量创建"), + HotkeyAction("Esc", "返回修改"), + } +} + // Hotkeys_Dashboard_Running_NoSel 标准仪表盘运行中,无选中请求时。 func Hotkeys_Dashboard_Running_NoSel() []HotkeyItem { return []HotkeyItem{ @@ -193,6 +204,8 @@ func Hotkeys_ReqDetail() []HotkeyItem { // Hotkeys_ProxyConfig 代理配置页。 func Hotkeys_ProxyConfig() []HotkeyItem { return []HotkeyItem{ + HotkeyAction("Tab/↑↓", "切换字段"), + HotkeyAction("←→/Space", "切换类型"), HotkeyAction("Enter", "保存"), HotkeyAction("Ctrl+U", "清空"), } diff --git a/internal/tui/pages/nav.go b/internal/tui/pages/nav.go index 7f70876..40c025f 100644 --- a/internal/tui/pages/nav.go +++ b/internal/tui/pages/nav.go @@ -55,6 +55,9 @@ type Client interface { GetRunStateForHistoryCmd(runID server.RunID, summary *types.TaskRunSummary) tea.Cmd GenerateReportCmd(runID server.RunID, format server.ReportFormat) tea.Cmd + // 批量创建任务 + CreateBatchTasksCmd(cfgs []server.TaskConfig) tea.Cmd + // 全局配置 SaveProxyConfigCmd(proxyURL string) tea.Cmd LoadProxyConfigCmd() tea.Cmd diff --git a/internal/tui/pages/proxy.go b/internal/tui/pages/proxy.go index 62ff30a..ba2e4e4 100644 --- a/internal/tui/pages/proxy.go +++ b/internal/tui/pages/proxy.go @@ -9,20 +9,97 @@ import ( "charm.land/lipgloss/v2" ) +// 代理类型常量 +const ( + ProxyTypeHTTP = "http" + ProxyTypeSOCKS5 = "socks5" + ProxyTypeSSH = "ssh" +) + +var proxyTypes = []string{ProxyTypeHTTP, ProxyTypeSOCKS5, ProxyTypeSSH} + // ProxyConfigState 代理配置页面状态。 type ProxyConfigState struct { - input textinput.Model // 代理 URL 输入框 + ProxyType string // "http" | "socks5" | "ssh" + FieldIndex int // 0=代理类型, 1=代理地址 + input textinput.Model } // NewProxyConfigState 创建代理配置页面状态,传入当前已保存的代理 URL。 func NewProxyConfigState(currentURL string) *ProxyConfigState { + proxyType := ProxyTypeHTTP + switch { + case strings.HasPrefix(currentURL, "socks5://"): + proxyType = ProxyTypeSOCKS5 + case strings.HasPrefix(currentURL, "ssh://"): + proxyType = ProxyTypeSSH + } + ti := textinput.New() ti.Prompt = "" ti.Cursor.SetMode(cursor.CursorStatic) ti.SetValue(currentURL) ti.CursorEnd() ti.Focus() - return &ProxyConfigState{input: ti} + return &ProxyConfigState{ProxyType: proxyType, FieldIndex: 1, input: ti} +} + +// proxyTypeLabel 返回代理类型的显示名。 +func proxyTypeLabel(t string) string { + switch t { + case ProxyTypeSOCKS5: + return "SOCKS5" + case ProxyTypeSSH: + return "SSH" + default: + return "HTTP" + } +} + +// proxyTypeHint 返回类型对应的示例 URL 提示。 +func proxyTypeHint(t string) string { + switch t { + case ProxyTypeSOCKS5: + return "示例: socks5://127.0.0.1:1080" + case ProxyTypeSSH: + return "示例: ssh://user@host:22" + default: + return "示例: http://127.0.0.1:7890" + } +} + +// cycleProxyType 循环切换代理类型,同时更新 URL 的 scheme 前缀。 +func cycleProxyType(s *ProxyConfigState, forward bool) { + idx := 0 + for i, t := range proxyTypes { + if t == s.ProxyType { + idx = i + break + } + } + if forward { + idx = (idx + 1) % len(proxyTypes) + } else { + idx = (idx - 1 + len(proxyTypes)) % len(proxyTypes) + } + newType := proxyTypes[idx] + + // 更新 URL scheme 前缀 + url := s.input.Value() + for _, t := range proxyTypes { + scheme := t + "://" + if strings.HasPrefix(url, scheme) { + url = strings.TrimPrefix(url, scheme) + break + } + } + if url != "" { + url = newType + "://" + url + } + s.input.SetValue(url) + s.input.CursorEnd() + + s.ProxyType = newType } // HandleProxyConfigKey 处理代理配置页面的按键。 @@ -37,17 +114,61 @@ func HandleProxyConfigKey(s *ProxyConfigState, msg tea.KeyMsg, client Client) (* nav = NavAction{To: NavTaskList} case "enter": + if s.FieldIndex == 0 { + // 在类型字段上按 Enter,切换到 URL 字段 + s.FieldIndex = 1 + return s, nil, nav + } cmd := client.SaveProxyConfigCmd(s.input.Value()) nav = NavAction{To: NavTaskList} return s, cmd, nav + case "tab", "down", "j": + if s.FieldIndex < 1 { + s.FieldIndex++ + } + + case "shift+tab", "up", "k": + if s.FieldIndex > 0 { + s.FieldIndex-- + } + + case "left": + if s.FieldIndex == 0 { + cycleProxyType(s, false) + } else { + var cmd tea.Cmd + s.input, cmd = s.input.Update(msg) + return s, cmd, nav + } + + case "right": + if s.FieldIndex == 0 { + cycleProxyType(s, true) + } else { + var cmd tea.Cmd + s.input, cmd = s.input.Update(msg) + return s, cmd, nav + } + + case " ": + if s.FieldIndex == 0 { + cycleProxyType(s, true) + } else { + var cmd tea.Cmd + s.input, cmd = s.input.Update(msg) + return s, cmd, nav + } + case "q", "ctrl+c": nav = NavAction{To: NavQuit} default: - var cmd tea.Cmd - s.input, cmd = s.input.Update(msg) - return s, cmd, nav + if s.FieldIndex == 1 { + var cmd tea.Cmd + s.input, cmd = s.input.Update(msg) + return s, cmd, nav + } } return s, nil, nav @@ -78,25 +199,58 @@ func RenderProxyConfig(s *ProxyConfigState, st Styles, width, height int) string func buildProxyConfigContent(s *ProxyConfigState, st Styles, contentW, maxH int) string { var lines []string - lines = append(lines, st.SectionHead.Render("代理地址")) - lines = append(lines, "") + appendBlock := func(block string) { + for _, l := range strings.Split(block, "\n") { + lines = append(lines, l) + } + } - // 字段宽度(与 wizard renderWizardField 保持一致) - // lipgloss v2: Width(fieldW+4) 使内容区 = fieldW,与 input.Width 对齐 fieldW := maxInt(10, contentW-19) - s.input.Width = fieldW - renderedField := st.FieldActive.Width(fieldW + 4).Render(s.input.View()) - labelBlock := strings.Join([]string{ + // 代理类型字段 + typeLabel := proxyTypeLabel(s.ProxyType) + if s.FieldIndex == 0 { + typeLabel = "‹ " + typeLabel + " ›" + } + typeLabel = truncate(typeLabel, maxInt(4, fieldW)) + typeFieldStyle := st.FieldIdle + if s.FieldIndex == 0 { + typeFieldStyle = st.FieldActive + } + typeLabelBlock := strings.Join([]string{ + strings.Repeat(" ", 15), + lipgloss.NewStyle().Width(15).Render(st.Label.Render("代理类型")), + strings.Repeat(" ", 15), + }, "\n") + typeRendered := typeFieldStyle.Width(fieldW + 4).Render(st.Value.Render(typeLabel)) + appendBlock(lipgloss.JoinHorizontal(lipgloss.Top, typeLabelBlock, typeRendered)) + + // 代理地址字段 + s.input.Width = fieldW + urlFieldStyle := st.FieldIdle + if s.FieldIndex == 1 { + urlFieldStyle = st.FieldActive + } + var urlRendered string + if s.FieldIndex == 1 { + urlRendered = urlFieldStyle.Width(fieldW + 4).Render(s.input.View()) + } else { + v := s.input.Value() + if v == "" { + urlRendered = urlFieldStyle.Width(fieldW + 4).Render(st.Muted.Render("未填写")) + } else { + urlRendered = urlFieldStyle.Width(fieldW + 4).Render(st.Value.Render(fitTail(v, fieldW))) + } + } + urlLabelBlock := strings.Join([]string{ strings.Repeat(" ", 15), lipgloss.NewStyle().Width(15).Render(st.Label.Render("代理地址")), strings.Repeat(" ", 15), }, "\n") - lines = append(lines, lipgloss.JoinHorizontal(lipgloss.Top, labelBlock, renderedField)) - lines = append(lines, "") + appendBlock(lipgloss.JoinHorizontal(lipgloss.Top, urlLabelBlock, urlRendered)) - hint := "示例: http://127.0.0.1:7890 或留空以直连" - lines = append(lines, st.Muted.Render(truncate(hint, contentW))) + lines = append(lines, "") + lines = append(lines, st.Muted.Render(truncate(proxyTypeHint(s.ProxyType), contentW))) lines = append(lines, "") lines = append(lines, st.Muted.Render(truncate("配置保存至 ~/.ait/config.json,重启无需重新输入。", contentW))) diff --git a/internal/tui/pages/tasklist.go b/internal/tui/pages/tasklist.go index 9d5f430..282d0b7 100644 --- a/internal/tui/pages/tasklist.go +++ b/internal/tui/pages/tasklist.go @@ -19,6 +19,8 @@ type TaskListState struct { Visible int // 运行中任务的进度(runID -> RunState 快照,由 Model 注入) ActiveRuns map[string]*server.RunState // taskID -> RunState + // 删除二次确认 + ConfirmDelete bool } // NewTaskListState 创建初始任务列表状态。 @@ -62,6 +64,22 @@ func (s *TaskListState) latestRunAt() *time.Time { func HandleTaskListKey(s *TaskListState, msg tea.KeyMsg, client Client) (*TaskListState, tea.Cmd, NavAction) { nav := NavAction{} + // 删除确认模式:拦截所有按键,只处理确认/取消 + if s.ConfirmDelete { + switch msg.String() { + case "y", "enter": + s.ConfirmDelete = false + if t, ok := s.CurrentTask(); ok { + return s, client.DeleteTaskCmd(t.ID), nav + } + case "n", "esc", "q": + s.ConfirmDelete = false + case "ctrl+c": + nav = NavAction{To: NavQuit} + } + return s, nil, nav + } + switch msg.String() { case "up", "k": if s.Selected > 0 { @@ -89,8 +107,8 @@ func HandleTaskListKey(s *TaskListState, msg tea.KeyMsg, client Client) (*TaskLi } case "d": - if t, ok := s.CurrentTask(); ok { - return s, client.DeleteTaskCmd(t.ID), nav + if t, ok := s.CurrentTask(); ok && !s.IsTaskRunning(t.ID) { + s.ConfirmDelete = true } case "enter": @@ -128,7 +146,9 @@ func RenderTaskList(s *TaskListState, st Styles, width, height int) string { } var cbItems []HotkeyItem - if t, ok := s.CurrentTask(); ok { + if s.ConfirmDelete { + cbItems = []HotkeyItem{HotkeyAction("y/Enter", "确认删除"), HotkeyAction("n/Esc", "取消")} + } else if t, ok := s.CurrentTask(); ok { if s.IsTaskRunning(t.ID) { cbItems = Hotkeys_TaskList_Running() } else { @@ -160,7 +180,15 @@ func RenderTaskList(s *TaskListState, st Styles, width, height int) string { } frame := l.Frame(width, height) panel := NewPanelFrame(frame.OuterWidth) - content := buildTaskListContent(s, st, panel.InnerWidth, PanelContentHeight(frame.InnerHeight)) + innerW := panel.InnerWidth + innerH := PanelContentHeight(frame.InnerHeight) + + var content string + if s.ConfirmDelete { + content = buildTaskListConfirmContent(s, st, innerW, innerH) + } else { + content = buildTaskListContent(s, st, innerW, innerH) + } return l.Assemble(panel.Wrap(st, content), st, width) } @@ -288,6 +316,31 @@ func buildTaskListContent(s *TaskListState, st Styles, width, maxH int) string { return strings.Join(lines, "\n") } +// buildTaskListConfirmContent 渲染删除确认对话框内容。 +func buildTaskListConfirmContent(s *TaskListState, st Styles, width, maxH int) string { + var lines []string + task, ok := s.CurrentTask() + if !ok { + return strings.Repeat("\n", maxH-1) + } + lines = append(lines, "") + lines = append(lines, st.ErrStyle.Render(" 确认删除任务?")) + lines = append(lines, "") + lines = append(lines, " "+st.Label.Render("任务名称")+" "+st.Value.Render(truncate(task.Name, maxInt(8, width-14)))) + lines = append(lines, " "+st.Label.Render("任务 ID ")+" "+st.Muted.Render(task.ID)) + lines = append(lines, "") + lines = append(lines, " "+st.Muted.Render("此操作不可恢复,任务的历史运行记录将一并删除。")) + lines = append(lines, "") + lines = append(lines, " "+st.Value.Render("[y / Enter]")+" 确认删除 "+st.Value.Render("[n / Esc]")+" 取消") + for len(lines) < maxH { + lines = append(lines, "") + } + if len(lines) > maxH { + lines = lines[:maxH] + } + return strings.Join(lines, "\n") +} + func max(a, b int) int { if a > b { return a diff --git a/internal/tui/pages/wizard.go b/internal/tui/pages/wizard.go index 8ce045c..c462027 100644 --- a/internal/tui/pages/wizard.go +++ b/internal/tui/pages/wizard.go @@ -37,11 +37,13 @@ type WizardState struct { // Step 1: 基本信息 Name string - Protocol string // types.Protocol* 常量 + // 协议多选(至少选一个) + SelOpenAICompletions bool + SelOpenAIResponses bool + SelAnthropic bool EndpointURL string - ProxyURL string APIKey string - Model string + ModelsText string // 逗号分隔,支持多个模型 // Step 2: 测试参数 Turbo bool @@ -120,8 +122,8 @@ func loadCurrentFieldInput(wz *WizardState) { // NewWizardState 创建新建任务向导状态(使用默认值)。 func NewWizardState() *WizardState { wz := &WizardState{ - Step: wizardStep1, - Protocol: types.ProtocolOpenAICompletions, + Step: wizardStep1, + SelOpenAICompletions: true, Concurrency: 10, Count: 100, Timeout: 30, @@ -150,11 +152,19 @@ func NewWizardStateEdit(t *types.TaskDefinition) *WizardState { wz.EditingID = t.ID wz.Name = t.Name - wz.Protocol = types.NormalizeProtocol(inp.Protocol) + switch types.NormalizeProtocol(inp.Protocol) { + case types.ProtocolOpenAIResponses: + wz.SelOpenAICompletions = false + wz.SelOpenAIResponses = true + case types.ProtocolAnthropicMessages: + wz.SelOpenAICompletions = false + wz.SelAnthropic = true + default: + wz.SelOpenAICompletions = true + } wz.EndpointURL = inp.EndpointURL - wz.ProxyURL = inp.ProxyURL wz.APIKey = inp.ApiKey - wz.Model = inp.Model + wz.ModelsText = inp.Model wz.Turbo = inp.Turbo wz.Stream = inp.Stream wz.PromptText = inp.PromptText @@ -200,6 +210,53 @@ func NewWizardStateEdit(t *types.TaskDefinition) *WizardState { return wz } +// ParseModels 解析 ModelsText,返回去重后的模型名列表。 +func (wz *WizardState) ParseModels() []string { + raw := strings.NewReplacer("\n", ",", ";", ",").Replace(wz.ModelsText) + seen := map[string]bool{} + var result []string + for _, m := range strings.Split(raw, ",") { + m = strings.TrimSpace(m) + if m != "" && !seen[m] { + seen[m] = true + result = append(result, m) + } + } + return result +} + +// SelectedProtocols 返回已勾选的协议列表(保持固定顺序)。 +func (wz *WizardState) SelectedProtocols() []string { + var protos []string + if wz.SelOpenAICompletions { + protos = append(protos, types.ProtocolOpenAICompletions) + } + if wz.SelOpenAIResponses { + protos = append(protos, types.ProtocolOpenAIResponses) + } + if wz.SelAnthropic { + protos = append(protos, types.ProtocolAnthropicMessages) + } + return protos +} + +// SingleProtocol 返回单协议场景下的协议名(优先第一个勾选;默认 OpenAI Chat)。 +func (wz *WizardState) SingleProtocol() string { + protos := wz.SelectedProtocols() + if len(protos) > 0 { + return protos[0] + } + return types.ProtocolOpenAICompletions +} + +// IsBatch 判断当前配置是否需要批量创建(多模型或多协议,且非编辑模式)。 +func (wz *WizardState) IsBatch() bool { + if wz.EditingID != "" { + return false + } + return len(wz.ParseModels()) > 1 || len(wz.SelectedProtocols()) > 1 +} + // BuildTaskConfig 将向导状态转换为 server.TaskConfig。 func (wz *WizardState) BuildTaskConfig() server.TaskConfig { turboRate := wz.MinSuccessRate / 100 // 转回小数 @@ -210,14 +267,17 @@ func (wz *WizardState) BuildTaskConfig() server.TaskConfig { if wz.Timeout > 0 { timeout = time.Duration(wz.Timeout) * time.Second } + model := "" + if models := wz.ParseModels(); len(models) > 0 { + model = models[0] + } return server.TaskConfig{ - Name: wz.Name, + Name: wizardFallback(wz.Name, "未命名任务"), Input: types.Input{ - Protocol: wz.Protocol, + Protocol: wz.SingleProtocol(), EndpointURL: wz.EndpointURL, - ProxyURL: wz.ProxyURL, ApiKey: wz.APIKey, - Model: wz.Model, + Model: model, Concurrency: wz.Concurrency, Count: wz.Count, Timeout: timeout, @@ -238,6 +298,64 @@ func (wz *WizardState) BuildTaskConfig() server.TaskConfig { } } +// BuildBatchTaskConfigs 返回 (模型 × 协议) 笛卡尔积的任务配置列表。 +func (wz *WizardState) BuildBatchTaskConfigs() []server.TaskConfig { + protos := wz.SelectedProtocols() + models := wz.ParseModels() + if len(protos) == 0 || len(models) == 0 { + return nil + } + turboRate := wz.MinSuccessRate / 100 + if turboRate <= 0 { + turboRate = 0.9 + } + var timeout time.Duration + if wz.Timeout > 0 { + timeout = time.Duration(wz.Timeout) * time.Second + } + multiProto := len(protos) > 1 + var cfgs []server.TaskConfig + for _, model := range models { + for _, proto := range protos { + endpointURL := wz.EndpointURL + displayBase := endpointURL + if displayBase == "" { + displayBase = types.DefaultEndpointURL(proto) + } + taskName := fmt.Sprintf("%s@%s", model, strings.TrimRight(displayBase, "/")) + if multiProto { + taskName = fmt.Sprintf("%s (%s)@%s", model, proto, strings.TrimRight(displayBase, "/")) + } + cfgs = append(cfgs, server.TaskConfig{ + Name: taskName, + Input: types.Input{ + Protocol: proto, + EndpointURL: endpointURL, + ApiKey: wz.APIKey, + Model: model, + Concurrency: wz.Concurrency, + Count: wz.Count, + Timeout: timeout, + Stream: wz.Stream, + Turbo: wz.Turbo, + TurboConfig: types.TurboConfig{ + InitConcurrency: wz.InitConcurrency, + MaxConcurrency: wz.MaxConcurrency, + StepSize: wz.StepSize, + LevelRequests: wz.LevelRequests, + MinSuccessRate: turboRate, + }, + PromptMode: wz.PromptMode, + PromptText: wz.PromptText, + PromptFile: wz.PromptFile, + PromptLength: wz.PromptLength, + }, + }) + } + } + return cfgs +} + // fieldDef 向导字段定义 type fieldDef struct { kind fieldKind @@ -277,11 +395,6 @@ func intField(label string, get func(*WizardState) int, set func(*WizardState, i // step1Fields 返回步骤1的字段列表。 func step1Fields() []fieldDef { - protocols := []string{ - types.ProtocolOpenAICompletions, - types.ProtocolOpenAIResponses, - types.ProtocolAnthropicMessages, - } return []fieldDef{ { kind: fieldText, label: "任务名称", @@ -289,25 +402,19 @@ func step1Fields() []fieldDef { set: func(wz *WizardState, v string) { wz.Name = v }, }, { - kind: fieldEnum, label: "协议类型", - get: func(wz *WizardState) string { return wz.Protocol }, - toggle: func(wz *WizardState, forward bool) { - idx := 0 - for i, p := range protocols { - if p == wz.Protocol { - idx = i - break - } - } - if forward { - idx = (idx + 1) % len(protocols) - } else { - idx = (idx - 1 + len(protocols)) % len(protocols) - } - wz.Protocol = protocols[idx] - // 清空 endpoint,使其跟随协议默认值 - wz.EndpointURL = "" - }, + kind: fieldBool, label: "OpenAI Chat", + get: func(wz *WizardState) string { return boolLabel(wz.SelOpenAICompletions) }, + toggle: func(wz *WizardState, _ bool) { wz.SelOpenAICompletions = !wz.SelOpenAICompletions }, + }, + { + kind: fieldBool, label: "OpenAI Resp", + get: func(wz *WizardState) string { return boolLabel(wz.SelOpenAIResponses) }, + toggle: func(wz *WizardState, _ bool) { wz.SelOpenAIResponses = !wz.SelOpenAIResponses }, + }, + { + kind: fieldBool, label: "Anthropic", + get: func(wz *WizardState) string { return boolLabel(wz.SelAnthropic) }, + toggle: func(wz *WizardState, _ bool) { wz.SelAnthropic = !wz.SelAnthropic }, }, { kind: fieldText, label: "接口地址", @@ -315,10 +422,10 @@ func step1Fields() []fieldDef { if wz.EndpointURL != "" { return wz.EndpointURL } - return types.DefaultEndpointURL(wz.Protocol) + return types.DefaultEndpointURL(wz.SingleProtocol()) }, getRaw: func(wz *WizardState) string { return wz.EndpointURL }, - set: func(wz *WizardState, v string) { wz.EndpointURL = v }, + set: func(wz *WizardState, v string) { wz.EndpointURL = v }, }, { kind: fieldText, label: "API 密钥", @@ -326,9 +433,9 @@ func step1Fields() []fieldDef { set: func(wz *WizardState, v string) { wz.APIKey = v }, }, { - kind: fieldText, label: "测试模型", - get: func(wz *WizardState) string { return wz.Model }, - set: func(wz *WizardState, v string) { wz.Model = v }, + kind: fieldText, label: "模型列表", + get: func(wz *WizardState) string { return wz.ModelsText }, + set: func(wz *WizardState, v string) { wz.ModelsText = v }, }, } } @@ -482,6 +589,14 @@ func HandleWizardKey(wz *WizardState, msg tea.KeyMsg, client Client) (*WizardSta case "end": wz.ScrollOff = 1 << 30 case "enter": + if wz.IsBatch() { + cfgs := wz.BuildBatchTaskConfigs() + if len(cfgs) == 0 { + return wz, nil, nav + } + nav = NavAction{To: NavTaskList} + return wz, client.CreateBatchTasksCmd(cfgs), nav + } cfg := wz.BuildTaskConfig() var cmd tea.Cmd if wz.EditingID != "" { @@ -492,15 +607,17 @@ func HandleWizardKey(wz *WizardState, msg tea.KeyMsg, client Client) (*WizardSta nav = NavAction{To: NavTaskList} return wz, cmd, nav case "r": - cfg := wz.BuildTaskConfig() - var cmd tea.Cmd - if wz.EditingID != "" { - cmd = client.UpdateTaskCmd(wz.EditingID, cfg) - } else { - cmd = client.CreateTaskCmd(cfg, true) // 保存并运行 + if !wz.IsBatch() { + cfg := wz.BuildTaskConfig() + var cmd tea.Cmd + if wz.EditingID != "" { + cmd = client.UpdateTaskCmd(wz.EditingID, cfg) + } else { + cmd = client.CreateTaskCmd(cfg, true) // 保存并运行 + } + nav = NavAction{To: NavTaskList} + return wz, cmd, nav } - nav = NavAction{To: NavTaskList} - return wz, cmd, nav case "q", "ctrl+c": nav = NavAction{To: NavQuit} } @@ -613,8 +730,12 @@ func RenderWizard(wz *WizardState, st Styles, width, height int) string { } stepTitle := stepTitles[int(wz.Step)] headerLeft := []string{stepTitle} - if wz.Protocol != "" { - headerLeft = append(headerLeft, strings.ToUpper(wz.Protocol)) + if protos := wz.SelectedProtocols(); len(protos) > 0 && wz.Step >= wizardStep2 { + if len(protos) == 1 { + headerLeft = append(headerLeft, strings.ToUpper(protos[0])) + } else { + headerLeft = append(headerLeft, fmt.Sprintf("%d 协议", len(protos))) + } } headerRight := []string{} if wz.Step >= wizardStep2 { @@ -624,12 +745,16 @@ func RenderWizard(wz *WizardState, st Styles, width, height int) string { headerRight = append(headerRight, "标准模式") } } - if wz.Model != "" { - headerRight = append(headerRight, "模型 "+truncate(wz.Model, 18)) + if n := len(wz.ParseModels()); n == 1 { + headerRight = append(headerRight, "模型 "+truncate(wz.ParseModels()[0], 18)) + } else if n > 1 { + headerRight = append(headerRight, fmt.Sprintf("%d 个模型", n)) } action := "创建任务" if wz.EditingID != "" { action = "编辑任务" + } else if wz.IsBatch() { + action = "批量创建" } l := PageLayout{ @@ -638,7 +763,7 @@ func RenderWizard(wz *WizardState, st Styles, width, height int) string { HeaderMeta: fmt.Sprintf("步骤 %d/3", int(wz.Step)+1), HeaderInfoLeft: headerLeft, HeaderInfoRight: headerRight, - Hotkeys: NewPageHotkeys(wizardHotkeyItems(wz.Step), "[q] 退出"), + Hotkeys: NewPageHotkeys(wizardHotkeyItems(wz), "[q] 退出"), } frame := l.Frame(width, height) panel := NewPanelFrame(frame.OuterWidth) @@ -649,7 +774,7 @@ func RenderWizard(wz *WizardState, st Styles, width, height int) string { func buildWizardPageContent(wz *WizardState, st Styles, width, maxH int) string { var topLines []string if maxH >= 8 && width >= 46 { - topLines = append(topLines, renderWizardStepStrip(wz.Step)) + topLines = append(topLines, renderWizardStepStrip(wz.Step, wz.IsBatch())) } bottomCount := 1 @@ -801,6 +926,9 @@ func renderWizardField(st Styles, f fieldDef, wz *WizardState, active bool, maxW // renderStep3Summary 渲染步骤3的确认内容。 func renderStep3Summary(wz *WizardState, st Styles, innerW int) []string { + if wz.IsBatch() { + return renderBatchConfirmLines(wz, st, innerW) + } var lines []string addRow := func(label, value string, valueStyle lipgloss.Style) { appendWizardSummaryRow(&lines, st, label, value, innerW, valueStyle) @@ -808,14 +936,19 @@ func renderStep3Summary(wz *WizardState, st Styles, innerW int) []string { lines = append(lines, st.SectionHead.Render("配置概览")) addRow("任务名称", wizardFallback(wz.Name, "未命名任务"), st.Value) - addRow("协议", wz.Protocol, st.Value) + proto := wz.SingleProtocol() + addRow("协议", proto, st.Value) endpointDisplay := wz.EndpointURL if endpointDisplay == "" { - endpointDisplay = types.DefaultEndpointURL(wz.Protocol) + endpointDisplay = types.DefaultEndpointURL(proto) } addRow("接口地址", endpointDisplay, st.Value) addRow("API 密钥", wizardFallback(maskAPIKey(wz.APIKey), "未填写"), st.Value) - addRow("测试模型", wizardFallback(wz.Model, "未填写"), st.Value) + model := "" + if models := wz.ParseModels(); len(models) > 0 { + model = models[0] + } + addRow("测试模型", wizardFallback(model, "未填写"), st.Value) lines = append(lines, "", st.SectionHead.Render("执行参数")) if wz.Turbo { @@ -846,11 +979,43 @@ func renderStep3Summary(wz *WizardState, st Styles, innerW int) []string { return lines } -func renderWizardStepStrip(step wizardStep) string { +// renderBatchConfirmLines 渲染批量创建的确认内容。 +func renderBatchConfirmLines(wz *WizardState, st Styles, _ int) []string { + var lines []string + cfgs := wz.BuildBatchTaskConfigs() + if len(cfgs) == 0 { + lines = append(lines, st.Muted.Render("请至少选择一个协议并填写一个模型名称。")) + return lines + } + lines = append(lines, st.SectionHead.Render(fmt.Sprintf("批量任务预览(共 %d 个)", len(cfgs)))) + lines = append(lines, "") + for i, cfg := range cfgs { + lines = append(lines, fmt.Sprintf(" %d. %s", i+1, cfg.Name)) + lines = append(lines, fmt.Sprintf(" 协议: %s 模型: %s", cfg.Input.Protocol, cfg.Input.Model)) + if cfg.Input.Turbo { + lines = append(lines, fmt.Sprintf(" 并发: %d→%d 请求: %d/级", + cfg.Input.TurboConfig.InitConcurrency, cfg.Input.TurboConfig.MaxConcurrency, cfg.Input.TurboConfig.LevelRequests)) + } else { + lines = append(lines, fmt.Sprintf(" 并发: %d 请求: %d", cfg.Input.Concurrency, cfg.Input.Count)) + } + if i < len(cfgs)-1 { + lines = append(lines, "") + } + } + lines = append(lines, "") + lines = append(lines, st.Muted.Render("保存位置: ~/.ait/tasks/.json")) + return lines +} + +func renderWizardStepStrip(step wizardStep, isBatch bool) string { active := lipgloss.NewStyle().Background(colorPink).Foreground(colorWhite).Bold(true).Padding(0, 1) done := lipgloss.NewStyle().Background(colorCyan).Foreground(lipgloss.Color("233")).Bold(true).Padding(0, 1) idle := lipgloss.NewStyle().Background(lipgloss.Color("238")).Foreground(colorMuted).Padding(0, 1) - labels := []string{"1 基本信息", "2 测试参数", "3 确认保存"} + label3 := "3 确认保存" + if isBatch { + label3 = "3 确认创建" + } + labels := []string{"1 基本信息", "2 测试参数", label3} parts := make([]string, 0, len(labels)) for i, label := range labels { switch { @@ -930,13 +1095,16 @@ func appendWizardSummaryRow(lines *[]string, st Styles, label, value string, wid } } -func wizardHotkeyItems(step wizardStep) []HotkeyItem { - switch step { +func wizardHotkeyItems(wz *WizardState) []HotkeyItem { + switch wz.Step { case wizardStep1: return Hotkeys_Wizard_Step1() case wizardStep2: return Hotkeys_Wizard_Step2() default: + if wz.IsBatch() { + return Hotkeys_Wizard_Step3_Batch() + } return Hotkeys_Wizard_Step3() } } From 98a3b040a843fd98ac3152e844a4f0fb437f638a Mon Sep 17 00:00:00 2001 From: Alain Date: Sat, 23 May 2026 09:55:43 +0800 Subject: [PATCH 33/52] =?UTF-8?q?refactor:=20=E7=A7=BB=E9=99=A4=E6=89=B9?= =?UTF-8?q?=E9=87=8F=E5=88=9B=E5=BB=BA=E4=BB=BB=E5=8A=A1=E7=9B=B8=E5=85=B3?= =?UTF-8?q?=E4=BB=A3=E7=A0=81=EF=BC=8C=E7=AE=80=E5=8C=96=E4=BB=BB=E5=8A=A1?= =?UTF-8?q?=E5=88=9B=E5=BB=BA=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/tui/client.go | 12 -- internal/tui/messages.go | 5 - internal/tui/model.go | 5 - internal/tui/model_test.go | 11 +- internal/tui/pages/contextbar.go | 12 +- internal/tui/pages/nav.go | 3 - internal/tui/pages/wizard.go | 285 +++++++------------------------ 7 files changed, 63 insertions(+), 270 deletions(-) diff --git a/internal/tui/client.go b/internal/tui/client.go index e109500..902e392 100644 --- a/internal/tui/client.go +++ b/internal/tui/client.go @@ -76,18 +76,6 @@ func (c *Client) CopyTaskCmd(id string) tea.Cmd { } } -// CreateBatchTasksCmd 批量异步创建任务。 -func (c *Client) CreateBatchTasksCmd(cfgs []server.TaskConfig) tea.Cmd { - return func() tea.Msg { - for _, cfg := range cfgs { - if _, err := c.srv.CreateTask(cfg); err != nil { - return ErrorMsg{Err: fmt.Errorf("批量创建任务 %q 失败: %w", cfg.Name, err)} - } - } - return BatchTasksSavedMsg{Count: len(cfgs)} - } -} - // ─── 运行管理 ───────────────────────────────────────────────────────────────── // StartRunCmd 异步启动运行。 diff --git a/internal/tui/messages.go b/internal/tui/messages.go index 00f4900..fee072f 100644 --- a/internal/tui/messages.go +++ b/internal/tui/messages.go @@ -21,11 +21,6 @@ type TaskDeletedMsg struct { TaskID string } -// BatchTasksSavedMsg 批量创建任务完成。 -type BatchTasksSavedMsg struct { - Count int -} - // HistoryLoadedMsg 任务历史记录加载完成。 type HistoryLoadedMsg struct { TaskID string diff --git a/internal/tui/model.go b/internal/tui/model.go index 48837f1..7772421 100644 --- a/internal/tui/model.go +++ b/internal/tui/model.go @@ -221,11 +221,6 @@ func (m *Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { case ProxyConfigSavedMsg: m.status = "代理配置已保存" return m, nil - - case BatchTasksSavedMsg: - m.status = fmt.Sprintf("已批量创建 %d 个任务", msg.Count) - m.view = viewTaskList - return m, m.client.LoadTasksCmd() } return m, nil diff --git a/internal/tui/model_test.go b/internal/tui/model_test.go index 2bc04e3..59d27df 100644 --- a/internal/tui/model_test.go +++ b/internal/tui/model_test.go @@ -95,8 +95,8 @@ func TestOpenWizard_EditTask_Populate(t *testing.T) { if m.wizard.EditingID != "task-123" { t.Errorf("EditingID = %q, want %q", m.wizard.EditingID, "task-123") } - if m.wizard.ModelsText != "gpt-4" { - t.Errorf("ModelsText = %q, want %q", m.wizard.ModelsText, "gpt-4") + if m.wizard.Model != "gpt-4" { + t.Errorf("Model = %q, want %q", m.wizard.Model, "gpt-4") } if m.wizard.Concurrency != 5 { t.Errorf("Concurrency = %d, want 5", m.wizard.Concurrency) @@ -107,7 +107,7 @@ func TestBuildTaskInput_Standard(t *testing.T) { m := NewModel(&stubServer{}) m.wizard = pages.NewWizardState() wz := m.wizard - wz.ModelsText = "gpt-4.1" + wz.Model = "gpt-4.1" wz.APIKey = "sk-test" wz.Concurrency = 8 wz.Count = 120 @@ -137,10 +137,9 @@ func TestBuildTaskInput_Turbo(t *testing.T) { m := NewModel(&stubServer{}) m.wizard = pages.NewWizardState() wz := m.wizard - wz.ModelsText = "claude-3-7-sonnet" + wz.Model = "claude-3-7-sonnet" wz.APIKey = "sk-ant" - wz.SelOpenAICompletions = false - wz.SelAnthropic = true + wz.Protocol = types.ProtocolAnthropicMessages wz.Turbo = true wz.InitConcurrency = 1 wz.MaxConcurrency = 12 diff --git a/internal/tui/pages/contextbar.go b/internal/tui/pages/contextbar.go index 72f566a..2cb02fa 100644 --- a/internal/tui/pages/contextbar.go +++ b/internal/tui/pages/contextbar.go @@ -88,7 +88,7 @@ func Hotkeys_TaskDetail_Running() []HotkeyItem { func Hotkeys_Wizard_Step1() []HotkeyItem { return []HotkeyItem{ HotkeyAction("Tab/↑↓", "切换字段"), - HotkeyAction("Space/←→", "切换选项"), + HotkeyAction("←→", "切换协议"), HotkeyAction("Enter", "下一步"), HotkeyAction("Esc", "返回列表"), } @@ -115,16 +115,6 @@ func Hotkeys_Wizard_Step3() []HotkeyItem { } } -// Hotkeys_Wizard_Step3_Batch 批量创建确认页热键。 -func Hotkeys_Wizard_Step3_Batch() []HotkeyItem { - return []HotkeyItem{ - HotkeyAction("↑↓", "滚动"), - HotkeyAction("PgUp/PgDn", "翻页"), - HotkeyAction("Enter", "批量创建"), - HotkeyAction("Esc", "返回修改"), - } -} - // Hotkeys_Dashboard_Running_NoSel 标准仪表盘运行中,无选中请求时。 func Hotkeys_Dashboard_Running_NoSel() []HotkeyItem { return []HotkeyItem{ diff --git a/internal/tui/pages/nav.go b/internal/tui/pages/nav.go index 40c025f..7f70876 100644 --- a/internal/tui/pages/nav.go +++ b/internal/tui/pages/nav.go @@ -55,9 +55,6 @@ type Client interface { GetRunStateForHistoryCmd(runID server.RunID, summary *types.TaskRunSummary) tea.Cmd GenerateReportCmd(runID server.RunID, format server.ReportFormat) tea.Cmd - // 批量创建任务 - CreateBatchTasksCmd(cfgs []server.TaskConfig) tea.Cmd - // 全局配置 SaveProxyConfigCmd(proxyURL string) tea.Cmd LoadProxyConfigCmd() tea.Cmd diff --git a/internal/tui/pages/wizard.go b/internal/tui/pages/wizard.go index c462027..f8e7bdb 100644 --- a/internal/tui/pages/wizard.go +++ b/internal/tui/pages/wizard.go @@ -37,13 +37,10 @@ type WizardState struct { // Step 1: 基本信息 Name string - // 协议多选(至少选一个) - SelOpenAICompletions bool - SelOpenAIResponses bool - SelAnthropic bool + Protocol string // types.Protocol* 常量 EndpointURL string APIKey string - ModelsText string // 逗号分隔,支持多个模型 + Model string // Step 2: 测试参数 Turbo bool @@ -122,8 +119,8 @@ func loadCurrentFieldInput(wz *WizardState) { // NewWizardState 创建新建任务向导状态(使用默认值)。 func NewWizardState() *WizardState { wz := &WizardState{ - Step: wizardStep1, - SelOpenAICompletions: true, + Step: wizardStep1, + Protocol: types.ProtocolOpenAICompletions, Concurrency: 10, Count: 100, Timeout: 30, @@ -152,19 +149,10 @@ func NewWizardStateEdit(t *types.TaskDefinition) *WizardState { wz.EditingID = t.ID wz.Name = t.Name - switch types.NormalizeProtocol(inp.Protocol) { - case types.ProtocolOpenAIResponses: - wz.SelOpenAICompletions = false - wz.SelOpenAIResponses = true - case types.ProtocolAnthropicMessages: - wz.SelOpenAICompletions = false - wz.SelAnthropic = true - default: - wz.SelOpenAICompletions = true - } + wz.Protocol = types.NormalizeProtocol(inp.Protocol) wz.EndpointURL = inp.EndpointURL wz.APIKey = inp.ApiKey - wz.ModelsText = inp.Model + wz.Model = inp.Model wz.Turbo = inp.Turbo wz.Stream = inp.Stream wz.PromptText = inp.PromptText @@ -210,53 +198,6 @@ func NewWizardStateEdit(t *types.TaskDefinition) *WizardState { return wz } -// ParseModels 解析 ModelsText,返回去重后的模型名列表。 -func (wz *WizardState) ParseModels() []string { - raw := strings.NewReplacer("\n", ",", ";", ",").Replace(wz.ModelsText) - seen := map[string]bool{} - var result []string - for _, m := range strings.Split(raw, ",") { - m = strings.TrimSpace(m) - if m != "" && !seen[m] { - seen[m] = true - result = append(result, m) - } - } - return result -} - -// SelectedProtocols 返回已勾选的协议列表(保持固定顺序)。 -func (wz *WizardState) SelectedProtocols() []string { - var protos []string - if wz.SelOpenAICompletions { - protos = append(protos, types.ProtocolOpenAICompletions) - } - if wz.SelOpenAIResponses { - protos = append(protos, types.ProtocolOpenAIResponses) - } - if wz.SelAnthropic { - protos = append(protos, types.ProtocolAnthropicMessages) - } - return protos -} - -// SingleProtocol 返回单协议场景下的协议名(优先第一个勾选;默认 OpenAI Chat)。 -func (wz *WizardState) SingleProtocol() string { - protos := wz.SelectedProtocols() - if len(protos) > 0 { - return protos[0] - } - return types.ProtocolOpenAICompletions -} - -// IsBatch 判断当前配置是否需要批量创建(多模型或多协议,且非编辑模式)。 -func (wz *WizardState) IsBatch() bool { - if wz.EditingID != "" { - return false - } - return len(wz.ParseModels()) > 1 || len(wz.SelectedProtocols()) > 1 -} - // BuildTaskConfig 将向导状态转换为 server.TaskConfig。 func (wz *WizardState) BuildTaskConfig() server.TaskConfig { turboRate := wz.MinSuccessRate / 100 // 转回小数 @@ -267,17 +208,13 @@ func (wz *WizardState) BuildTaskConfig() server.TaskConfig { if wz.Timeout > 0 { timeout = time.Duration(wz.Timeout) * time.Second } - model := "" - if models := wz.ParseModels(); len(models) > 0 { - model = models[0] - } return server.TaskConfig{ Name: wizardFallback(wz.Name, "未命名任务"), Input: types.Input{ - Protocol: wz.SingleProtocol(), + Protocol: wz.Protocol, EndpointURL: wz.EndpointURL, ApiKey: wz.APIKey, - Model: model, + Model: wz.Model, Concurrency: wz.Concurrency, Count: wz.Count, Timeout: timeout, @@ -298,64 +235,6 @@ func (wz *WizardState) BuildTaskConfig() server.TaskConfig { } } -// BuildBatchTaskConfigs 返回 (模型 × 协议) 笛卡尔积的任务配置列表。 -func (wz *WizardState) BuildBatchTaskConfigs() []server.TaskConfig { - protos := wz.SelectedProtocols() - models := wz.ParseModels() - if len(protos) == 0 || len(models) == 0 { - return nil - } - turboRate := wz.MinSuccessRate / 100 - if turboRate <= 0 { - turboRate = 0.9 - } - var timeout time.Duration - if wz.Timeout > 0 { - timeout = time.Duration(wz.Timeout) * time.Second - } - multiProto := len(protos) > 1 - var cfgs []server.TaskConfig - for _, model := range models { - for _, proto := range protos { - endpointURL := wz.EndpointURL - displayBase := endpointURL - if displayBase == "" { - displayBase = types.DefaultEndpointURL(proto) - } - taskName := fmt.Sprintf("%s@%s", model, strings.TrimRight(displayBase, "/")) - if multiProto { - taskName = fmt.Sprintf("%s (%s)@%s", model, proto, strings.TrimRight(displayBase, "/")) - } - cfgs = append(cfgs, server.TaskConfig{ - Name: taskName, - Input: types.Input{ - Protocol: proto, - EndpointURL: endpointURL, - ApiKey: wz.APIKey, - Model: model, - Concurrency: wz.Concurrency, - Count: wz.Count, - Timeout: timeout, - Stream: wz.Stream, - Turbo: wz.Turbo, - TurboConfig: types.TurboConfig{ - InitConcurrency: wz.InitConcurrency, - MaxConcurrency: wz.MaxConcurrency, - StepSize: wz.StepSize, - LevelRequests: wz.LevelRequests, - MinSuccessRate: turboRate, - }, - PromptMode: wz.PromptMode, - PromptText: wz.PromptText, - PromptFile: wz.PromptFile, - PromptLength: wz.PromptLength, - }, - }) - } - } - return cfgs -} - // fieldDef 向导字段定义 type fieldDef struct { kind fieldKind @@ -395,6 +274,11 @@ func intField(label string, get func(*WizardState) int, set func(*WizardState, i // step1Fields 返回步骤1的字段列表。 func step1Fields() []fieldDef { + protocols := []string{ + types.ProtocolOpenAICompletions, + types.ProtocolOpenAIResponses, + types.ProtocolAnthropicMessages, + } return []fieldDef{ { kind: fieldText, label: "任务名称", @@ -402,19 +286,25 @@ func step1Fields() []fieldDef { set: func(wz *WizardState, v string) { wz.Name = v }, }, { - kind: fieldBool, label: "OpenAI Chat", - get: func(wz *WizardState) string { return boolLabel(wz.SelOpenAICompletions) }, - toggle: func(wz *WizardState, _ bool) { wz.SelOpenAICompletions = !wz.SelOpenAICompletions }, - }, - { - kind: fieldBool, label: "OpenAI Resp", - get: func(wz *WizardState) string { return boolLabel(wz.SelOpenAIResponses) }, - toggle: func(wz *WizardState, _ bool) { wz.SelOpenAIResponses = !wz.SelOpenAIResponses }, - }, - { - kind: fieldBool, label: "Anthropic", - get: func(wz *WizardState) string { return boolLabel(wz.SelAnthropic) }, - toggle: func(wz *WizardState, _ bool) { wz.SelAnthropic = !wz.SelAnthropic }, + kind: fieldEnum, label: "协议类型", + get: func(wz *WizardState) string { return wz.Protocol }, + toggle: func(wz *WizardState, forward bool) { + idx := 0 + for i, p := range protocols { + if p == wz.Protocol { + idx = i + break + } + } + if forward { + idx = (idx + 1) % len(protocols) + } else { + idx = (idx - 1 + len(protocols)) % len(protocols) + } + wz.Protocol = protocols[idx] + // 清空 endpoint,使其跟随协议默认值 + wz.EndpointURL = "" + }, }, { kind: fieldText, label: "接口地址", @@ -422,7 +312,7 @@ func step1Fields() []fieldDef { if wz.EndpointURL != "" { return wz.EndpointURL } - return types.DefaultEndpointURL(wz.SingleProtocol()) + return types.DefaultEndpointURL(wz.Protocol) }, getRaw: func(wz *WizardState) string { return wz.EndpointURL }, set: func(wz *WizardState, v string) { wz.EndpointURL = v }, @@ -433,9 +323,9 @@ func step1Fields() []fieldDef { set: func(wz *WizardState, v string) { wz.APIKey = v }, }, { - kind: fieldText, label: "模型列表", - get: func(wz *WizardState) string { return wz.ModelsText }, - set: func(wz *WizardState, v string) { wz.ModelsText = v }, + kind: fieldText, label: "测试模型", + get: func(wz *WizardState) string { return wz.Model }, + set: func(wz *WizardState, v string) { wz.Model = v }, }, } } @@ -589,14 +479,6 @@ func HandleWizardKey(wz *WizardState, msg tea.KeyMsg, client Client) (*WizardSta case "end": wz.ScrollOff = 1 << 30 case "enter": - if wz.IsBatch() { - cfgs := wz.BuildBatchTaskConfigs() - if len(cfgs) == 0 { - return wz, nil, nav - } - nav = NavAction{To: NavTaskList} - return wz, client.CreateBatchTasksCmd(cfgs), nav - } cfg := wz.BuildTaskConfig() var cmd tea.Cmd if wz.EditingID != "" { @@ -607,17 +489,15 @@ func HandleWizardKey(wz *WizardState, msg tea.KeyMsg, client Client) (*WizardSta nav = NavAction{To: NavTaskList} return wz, cmd, nav case "r": - if !wz.IsBatch() { - cfg := wz.BuildTaskConfig() - var cmd tea.Cmd - if wz.EditingID != "" { - cmd = client.UpdateTaskCmd(wz.EditingID, cfg) - } else { - cmd = client.CreateTaskCmd(cfg, true) // 保存并运行 - } - nav = NavAction{To: NavTaskList} - return wz, cmd, nav + cfg := wz.BuildTaskConfig() + var cmd tea.Cmd + if wz.EditingID != "" { + cmd = client.UpdateTaskCmd(wz.EditingID, cfg) + } else { + cmd = client.CreateTaskCmd(cfg, true) // 保存并运行 } + nav = NavAction{To: NavTaskList} + return wz, cmd, nav case "q", "ctrl+c": nav = NavAction{To: NavQuit} } @@ -730,12 +610,8 @@ func RenderWizard(wz *WizardState, st Styles, width, height int) string { } stepTitle := stepTitles[int(wz.Step)] headerLeft := []string{stepTitle} - if protos := wz.SelectedProtocols(); len(protos) > 0 && wz.Step >= wizardStep2 { - if len(protos) == 1 { - headerLeft = append(headerLeft, strings.ToUpper(protos[0])) - } else { - headerLeft = append(headerLeft, fmt.Sprintf("%d 协议", len(protos))) - } + if wz.Protocol != "" && wz.Step >= wizardStep2 { + headerLeft = append(headerLeft, strings.ToUpper(wz.Protocol)) } headerRight := []string{} if wz.Step >= wizardStep2 { @@ -745,16 +621,12 @@ func RenderWizard(wz *WizardState, st Styles, width, height int) string { headerRight = append(headerRight, "标准模式") } } - if n := len(wz.ParseModels()); n == 1 { - headerRight = append(headerRight, "模型 "+truncate(wz.ParseModels()[0], 18)) - } else if n > 1 { - headerRight = append(headerRight, fmt.Sprintf("%d 个模型", n)) + if wz.Model != "" { + headerRight = append(headerRight, "模型 "+truncate(wz.Model, 18)) } action := "创建任务" if wz.EditingID != "" { action = "编辑任务" - } else if wz.IsBatch() { - action = "批量创建" } l := PageLayout{ @@ -763,7 +635,7 @@ func RenderWizard(wz *WizardState, st Styles, width, height int) string { HeaderMeta: fmt.Sprintf("步骤 %d/3", int(wz.Step)+1), HeaderInfoLeft: headerLeft, HeaderInfoRight: headerRight, - Hotkeys: NewPageHotkeys(wizardHotkeyItems(wz), "[q] 退出"), + Hotkeys: NewPageHotkeys(wizardHotkeyItems(wz.Step), "[q] 退出"), } frame := l.Frame(width, height) panel := NewPanelFrame(frame.OuterWidth) @@ -774,7 +646,7 @@ func RenderWizard(wz *WizardState, st Styles, width, height int) string { func buildWizardPageContent(wz *WizardState, st Styles, width, maxH int) string { var topLines []string if maxH >= 8 && width >= 46 { - topLines = append(topLines, renderWizardStepStrip(wz.Step, wz.IsBatch())) + topLines = append(topLines, renderWizardStepStrip(wz.Step)) } bottomCount := 1 @@ -926,9 +798,6 @@ func renderWizardField(st Styles, f fieldDef, wz *WizardState, active bool, maxW // renderStep3Summary 渲染步骤3的确认内容。 func renderStep3Summary(wz *WizardState, st Styles, innerW int) []string { - if wz.IsBatch() { - return renderBatchConfirmLines(wz, st, innerW) - } var lines []string addRow := func(label, value string, valueStyle lipgloss.Style) { appendWizardSummaryRow(&lines, st, label, value, innerW, valueStyle) @@ -936,19 +805,14 @@ func renderStep3Summary(wz *WizardState, st Styles, innerW int) []string { lines = append(lines, st.SectionHead.Render("配置概览")) addRow("任务名称", wizardFallback(wz.Name, "未命名任务"), st.Value) - proto := wz.SingleProtocol() - addRow("协议", proto, st.Value) + addRow("协议", wz.Protocol, st.Value) endpointDisplay := wz.EndpointURL if endpointDisplay == "" { - endpointDisplay = types.DefaultEndpointURL(proto) + endpointDisplay = types.DefaultEndpointURL(wz.Protocol) } addRow("接口地址", endpointDisplay, st.Value) addRow("API 密钥", wizardFallback(maskAPIKey(wz.APIKey), "未填写"), st.Value) - model := "" - if models := wz.ParseModels(); len(models) > 0 { - model = models[0] - } - addRow("测试模型", wizardFallback(model, "未填写"), st.Value) + addRow("测试模型", wizardFallback(wz.Model, "未填写"), st.Value) lines = append(lines, "", st.SectionHead.Render("执行参数")) if wz.Turbo { @@ -979,43 +843,11 @@ func renderStep3Summary(wz *WizardState, st Styles, innerW int) []string { return lines } -// renderBatchConfirmLines 渲染批量创建的确认内容。 -func renderBatchConfirmLines(wz *WizardState, st Styles, _ int) []string { - var lines []string - cfgs := wz.BuildBatchTaskConfigs() - if len(cfgs) == 0 { - lines = append(lines, st.Muted.Render("请至少选择一个协议并填写一个模型名称。")) - return lines - } - lines = append(lines, st.SectionHead.Render(fmt.Sprintf("批量任务预览(共 %d 个)", len(cfgs)))) - lines = append(lines, "") - for i, cfg := range cfgs { - lines = append(lines, fmt.Sprintf(" %d. %s", i+1, cfg.Name)) - lines = append(lines, fmt.Sprintf(" 协议: %s 模型: %s", cfg.Input.Protocol, cfg.Input.Model)) - if cfg.Input.Turbo { - lines = append(lines, fmt.Sprintf(" 并发: %d→%d 请求: %d/级", - cfg.Input.TurboConfig.InitConcurrency, cfg.Input.TurboConfig.MaxConcurrency, cfg.Input.TurboConfig.LevelRequests)) - } else { - lines = append(lines, fmt.Sprintf(" 并发: %d 请求: %d", cfg.Input.Concurrency, cfg.Input.Count)) - } - if i < len(cfgs)-1 { - lines = append(lines, "") - } - } - lines = append(lines, "") - lines = append(lines, st.Muted.Render("保存位置: ~/.ait/tasks/.json")) - return lines -} - -func renderWizardStepStrip(step wizardStep, isBatch bool) string { +func renderWizardStepStrip(step wizardStep) string { active := lipgloss.NewStyle().Background(colorPink).Foreground(colorWhite).Bold(true).Padding(0, 1) done := lipgloss.NewStyle().Background(colorCyan).Foreground(lipgloss.Color("233")).Bold(true).Padding(0, 1) idle := lipgloss.NewStyle().Background(lipgloss.Color("238")).Foreground(colorMuted).Padding(0, 1) - label3 := "3 确认保存" - if isBatch { - label3 = "3 确认创建" - } - labels := []string{"1 基本信息", "2 测试参数", label3} + labels := []string{"1 基本信息", "2 测试参数", "3 确认保存"} parts := make([]string, 0, len(labels)) for i, label := range labels { switch { @@ -1095,16 +927,13 @@ func appendWizardSummaryRow(lines *[]string, st Styles, label, value string, wid } } -func wizardHotkeyItems(wz *WizardState) []HotkeyItem { - switch wz.Step { +func wizardHotkeyItems(step wizardStep) []HotkeyItem { + switch step { case wizardStep1: return Hotkeys_Wizard_Step1() case wizardStep2: return Hotkeys_Wizard_Step2() default: - if wz.IsBatch() { - return Hotkeys_Wizard_Step3_Batch() - } return Hotkeys_Wizard_Step3() } } From 56067accb240a85ec85c266455314758e5502728 Mon Sep 17 00:00:00 2001 From: Alain Date: Sat, 23 May 2026 10:40:21 +0800 Subject: [PATCH 34/52] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E5=8E=9F?= =?UTF-8?q?=E5=A7=8B=E8=AF=B7=E6=B1=82=E4=BD=93=E6=94=AF=E6=8C=81=EF=BC=8C?= =?UTF-8?q?=E6=9B=B4=E6=96=B0=E7=9B=B8=E5=85=B3=E9=80=BB=E8=BE=91=E4=BB=A5?= =?UTF-8?q?=E5=A4=84=E7=90=86=E4=B8=8D=E5=90=8C=E7=9A=84=E6=8F=90=E7=A4=BA?= =?UTF-8?q?=E6=A8=A1=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/client/anthropic.go | 14 ++++++++ internal/client/client.go | 2 ++ internal/client/openai.go | 14 ++++++++ internal/prompt/prompt.go | 6 ++-- internal/runner/runner.go | 41 +++++++++++++++++------- internal/task/input.go | 9 ++++++ internal/tui/model_test.go | 4 +-- internal/tui/pages/helpers.go | 9 ++++++ internal/tui/pages/wizard.go | 60 +++++++++++++++++++++++++++++++---- 9 files changed, 136 insertions(+), 23 deletions(-) diff --git a/internal/client/anthropic.go b/internal/client/anthropic.go index 19d743f..a94a2d9 100644 --- a/internal/client/anthropic.go +++ b/internal/client/anthropic.go @@ -200,6 +200,20 @@ func (c *AnthropicClient) Request(systemPrompt, userPrompt string, stream bool) }, err } + return c.doRequest(reqBodyBytes, stream) +} + +// RawRequest 使用原始 JSON 请求体发送请求,stream 从请求体中的 stream 字段自动检测。 +func (c *AnthropicClient) RawRequest(rawBody string) (*ResponseMetrics, error) { + var tmp struct { + Stream bool `json:"stream"` + } + _ = json.Unmarshal([]byte(rawBody), &tmp) + return c.doRequest([]byte(rawBody), tmp.Stream) +} + +// doRequest 执行 HTTP 请求并解析响应(支持流式和非流式) +func (c *AnthropicClient) doRequest(reqBodyBytes []byte, stream bool) (*ResponseMetrics, error) { req, err := http.NewRequest("POST", c.EndpointURL, bytes.NewBuffer(reqBodyBytes)) if err != nil { // 记录错误日志 diff --git a/internal/client/client.go b/internal/client/client.go index 9bada71..f361462 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -38,6 +38,8 @@ type ResponseMetrics struct { type ModelClient interface { // Request 发送请求。systemPrompt 为空时行为与原来相同(不添加 system 消息)。 Request(systemPrompt, userPrompt string, stream bool) (*ResponseMetrics, error) + // RawRequest 使用原始 JSON 请求体发送请求,stream 从请求体中的 stream 字段自动检测。 + RawRequest(rawBody string) (*ResponseMetrics, error) GetProtocol() string GetModel() string SetLogger(logger *logger.Logger) // 设置日志记录器 diff --git a/internal/client/openai.go b/internal/client/openai.go index 0e71216..cd5994b 100644 --- a/internal/client/openai.go +++ b/internal/client/openai.go @@ -411,6 +411,20 @@ func (c *OpenAIClient) Request(systemPrompt, userPrompt string, stream bool) (*R return nil, err } + return c.doRequest(jsonData, stream) +} + +// RawRequest 使用原始 JSON 请求体发送请求,stream 从请求体中的 stream 字段自动检测。 +func (c *OpenAIClient) RawRequest(rawBody string) (*ResponseMetrics, error) { + var tmp struct { + Stream bool `json:"stream"` + } + _ = json.Unmarshal([]byte(rawBody), &tmp) + return c.doRequest([]byte(rawBody), tmp.Stream) +} + +// doRequest 执行 HTTP 请求并解析响应(支持流式和非流式) +func (c *OpenAIClient) doRequest(jsonData []byte, stream bool) (*ResponseMetrics, error) { req, err := http.NewRequestWithContext(context.Background(), "POST", c.endpointURL, bytes.NewBuffer(jsonData)) if err != nil { // 记录错误日志 diff --git a/internal/prompt/prompt.go b/internal/prompt/prompt.go index 53fd92f..e814721 100644 --- a/internal/prompt/prompt.go +++ b/internal/prompt/prompt.go @@ -295,11 +295,11 @@ func splitGeneratedPromptLengths(total int) (commonLen, userLen int) { return 0, total } - commonLen = total * 7 / 10 + commonLen = total * 9 / 10 userLen = total - commonLen - if userLen < 12 { - userLen = 12 + if userLen < 24 { + userLen = 24 if total < userLen { userLen = total } diff --git a/internal/runner/runner.go b/internal/runner/runner.go index 01d209c..479fd71 100644 --- a/internal/runner/runner.go +++ b/internal/runner/runner.go @@ -87,10 +87,16 @@ func (r *Runner) Run() (*types.ReportData, error) { defer func() { <-ch }() // 获取当前请求使用的prompt - systemPrompt := r.input.PromptSource.GetSystemContent() - userPrompt := r.input.PromptSource.GetContentByIndex(idx) - - metrics, err := r.client.Request(systemPrompt, userPrompt, r.input.Stream) + var metrics *client.ResponseMetrics + var err error + if r.input.PromptMode == "raw" { + rawBody := r.input.PromptSource.GetContentByIndex(idx) + metrics, err = r.client.RawRequest(rawBody) + } else { + systemPrompt := r.input.PromptSource.GetSystemContent() + userPrompt := r.input.PromptSource.GetContentByIndex(idx) + metrics, err = r.client.Request(systemPrompt, userPrompt, r.input.Stream) + } if err != nil { // 即使有错误,也尝试保存 metrics(如果有的话) if metrics != nil { @@ -130,9 +136,16 @@ func (r *Runner) RunWithCallback(cb RequestDoneCallback) (*types.ReportData, err defer wg.Done() defer func() { <-ch }() - systemPrompt := r.input.PromptSource.GetSystemContent() - userPrompt := r.input.PromptSource.GetContentByIndex(idx) - metrics, err := r.client.Request(systemPrompt, userPrompt, r.input.Stream) + var metrics *client.ResponseMetrics + var err error + if r.input.PromptMode == "raw" { + rawBody := r.input.PromptSource.GetContentByIndex(idx) + metrics, err = r.client.RawRequest(rawBody) + } else { + systemPrompt := r.input.PromptSource.GetSystemContent() + userPrompt := r.input.PromptSource.GetContentByIndex(idx) + metrics, err = r.client.Request(systemPrompt, userPrompt, r.input.Stream) + } if metrics != nil { results[idx] = metrics } @@ -233,10 +246,16 @@ func (r *Runner) RunWithProgress(progressCallback func(types.StatsData)) (*types defer func() { <-ch }() // 获取当前请求使用的prompt - systemPrompt := r.input.PromptSource.GetSystemContent() - userPrompt := r.input.PromptSource.GetContentByIndex(idx) - - metrics, err := r.client.Request(systemPrompt, userPrompt, r.input.Stream) + var metrics *client.ResponseMetrics + var err error + if r.input.PromptMode == "raw" { + rawBody := r.input.PromptSource.GetContentByIndex(idx) + metrics, err = r.client.RawRequest(rawBody) + } else { + systemPrompt := r.input.PromptSource.GetSystemContent() + userPrompt := r.input.PromptSource.GetContentByIndex(idx) + metrics, err = r.client.Request(systemPrompt, userPrompt, r.input.Stream) + } if err != nil { ttftsMutex.Lock() errorMessages = append(errorMessages, err.Error()) diff --git a/internal/task/input.go b/internal/task/input.go index bd8f3f5..6ed510b 100644 --- a/internal/task/input.go +++ b/internal/task/input.go @@ -40,6 +40,15 @@ func HydrateInput(input types.Input) (types.Input, error) { return input, err } input.PromptSource = source + case "raw": + if input.PromptText == "" { + return input, fmt.Errorf("prompt_text is required for prompt_mode=raw (paste the raw JSON request body)") + } + source, err := prompt.LoadPrompts(input.PromptText) + if err != nil { + return input, err + } + input.PromptSource = source default: return input, fmt.Errorf("unsupported prompt_mode: %s", input.PromptMode) } diff --git a/internal/tui/model_test.go b/internal/tui/model_test.go index 59d27df..33c8a80 100644 --- a/internal/tui/model_test.go +++ b/internal/tui/model_test.go @@ -67,8 +67,8 @@ func TestOpenWizard_NewTask_Defaults(t *testing.T) { if m.wizard.PromptMode != pages.PromptModeGenerated { t.Errorf("default PromptMode = %q, want %q", m.wizard.PromptMode, pages.PromptModeGenerated) } - if m.wizard.PromptLength != 100 { - t.Errorf("default PromptLength = %d, want 100", m.wizard.PromptLength) + if m.wizard.PromptLength != 4096 { + t.Errorf("default PromptLength = %d, want 4096", m.wizard.PromptLength) } } diff --git a/internal/tui/pages/helpers.go b/internal/tui/pages/helpers.go index 6499bfc..12567cb 100644 --- a/internal/tui/pages/helpers.go +++ b/internal/tui/pages/helpers.go @@ -628,6 +628,15 @@ func promptSummary(promptMode, promptText, promptFile string, promptLength int) return "文件: " + promptFile case "generated": return fmt.Sprintf("生成 %d 字符", promptLength) + case "raw": + if promptText != "" { + r := []rune(promptText) + if len(r) > 20 { + return "RAW: " + string(r[:20]) + "…" + } + return "RAW: " + promptText + } + return "(未设置)" default: if promptText != "" { r := []rune(promptText) diff --git a/internal/tui/pages/wizard.go b/internal/tui/pages/wizard.go index f8e7bdb..3c80ed3 100644 --- a/internal/tui/pages/wizard.go +++ b/internal/tui/pages/wizard.go @@ -19,6 +19,7 @@ const ( PromptModeText = "text" PromptModeFile = "file" PromptModeGenerated = "generated" + PromptModeRaw = "raw" ) // wizardStep 步骤枚举 @@ -131,7 +132,7 @@ func NewWizardState() *WizardState { MinSuccessRate: 90, Stream: true, PromptMode: PromptModeGenerated, - PromptLength: 100, + PromptLength: 4096, input: newWizardTextInput(), } loadCurrentFieldInput(wz) @@ -379,7 +380,7 @@ func step2Fields(turbo bool) []fieldDef { }) // Prompt 字段(共用) - promptModes := []string{PromptModeText, PromptModeFile, PromptModeGenerated} + promptModes := []string{PromptModeText, PromptModeFile, PromptModeGenerated, PromptModeRaw} fields = append(fields, fieldDef{ kind: fieldEnum, label: "输入方式", @@ -389,6 +390,8 @@ func step2Fields(turbo bool) []fieldDef { return "文件" case PromptModeGenerated: return "按长度生成" + case PromptModeRaw: + return "RAW 请求体" default: return "直接输入" } @@ -408,7 +411,7 @@ func step2Fields(turbo bool) []fieldDef { } wz.PromptMode = promptModes[idx] if wz.PromptMode == PromptModeGenerated && wz.PromptLength <= 0 { - wz.PromptLength = 100 + wz.PromptLength = 4096 } }, }, @@ -668,13 +671,17 @@ func buildWizardPageContent(wz *WizardState, st Styles, width, maxH int) string topLines = append(topLines, dividerLine(st, width)) } - bodyLines, focusLine := buildWizardBody(wz, st, width) + bodyLines, focusLine, focusEndLine := buildWizardBody(wz, st, width) bodyH := maxInt(1, availableForContent-len(topLines)) offset := 0 if wz.Step == wizardStep3 { offset = clampInt(wz.ScrollOff, 0, maxInt(0, len(bodyLines)-bodyH)) } else if focusLine >= 0 { - offset = ensureVisibleOffset(focusLine, len(bodyLines), 0, bodyH) + // 先确保聚焦块末尾(含提示行)可见,再保证起始行不滚出视口顶部 + offset = ensureVisibleOffset(focusEndLine, len(bodyLines), 0, bodyH) + if focusLine < offset { + offset = focusLine + } } end := minInt(len(bodyLines), offset+bodyH) visibleBody := append([]string{}, bodyLines[offset:end]...) @@ -698,9 +705,10 @@ func buildWizardPageContent(wz *WizardState, st Styles, width, maxH int) string return strings.Join(lines, "\n") } -func buildWizardBody(wz *WizardState, st Styles, contentW int) ([]string, int) { +func buildWizardBody(wz *WizardState, st Styles, contentW int) ([]string, int, int) { var lines []string focusLine := -1 + focusEndLine := -1 // appendField 将字段渲染结果按行展开追加,因为 FieldActive/FieldIdle 带 Border // 会产生 3 行输出(顶部边框 + 内容 + 底部边框),必须逐行记录才能正确计算高度。 @@ -711,6 +719,9 @@ func buildWizardBody(wz *WizardState, st Styles, contentW int) ([]string, int) { for _, l := range strings.Split(rendered, "\n") { lines = append(lines, l) } + if focused { + focusEndLine = len(lines) - 1 + } } switch wz.Step { @@ -727,13 +738,42 @@ func buildWizardBody(wz *WizardState, st Styles, contentW int) ([]string, int) { lines = append(lines, "", st.Muted.Render("Prompt 配置")) } appendField(renderWizardField(st, f, wz, i == wz.FieldIndex, contentW), i == wz.FieldIndex) + if f.label == "测试模式" { + if wz.Turbo { + lines = append(lines, st.Muted.Render(" 自动从低并发起步,逐级加压,找到最大稳定吞吐点")) + } else { + lines = append(lines, st.Muted.Render(" 固定并发数和请求总数,测量在指定负载下的延迟与成功率")) + } + } + if f.label == "输入方式" { + switch wz.PromptMode { + case PromptModeText: + lines = append(lines, st.Muted.Render(" 直接粘贴或输入 Prompt 文本,所有请求共享同一段内容")) + case PromptModeFile: + lines = append(lines, st.Muted.Render(" 从文件读取 Prompt,支持通配符匹配多个文件(请求按文件轮换)")) + case PromptModeGenerated: + lines = append(lines, st.Muted.Render(" 按指定字符数自动生成测试文本,内容含大量公共前缀以模拟缓存命中")) + case PromptModeRaw: + lines = append(lines, st.Muted.Render(" 粘贴完整的 HTTP 请求 JSON Body,将跳过参数组装直接发送")) + } + } + if f.label == "内容" && (wz.PromptMode == PromptModeText || wz.PromptMode == PromptModeFile || wz.PromptMode == PromptModeGenerated) { + lines = append(lines, st.Muted.Render(" 提示:大多数服务需要 ≥ 1024 tokens 才能命中缓存")) + } + if f.label == "内容" && wz.PromptMode == PromptModeRaw { + lines = append(lines, st.Muted.Render(" 提示:粘贴 API 请求的完整 JSON Body,将直接作为 HTTP 请求体发送")) + } + // 提示行追加完毕后,更新聚焦块的末尾行(含提示) + if i == wz.FieldIndex { + focusEndLine = len(lines) - 1 + } } case wizardStep3: lines = append(lines, renderStep3Summary(wz, st, contentW)...) } - return lines, focusLine + return lines, focusLine, focusEndLine } // renderWizardField 渲染向导的一个字段行。 @@ -836,6 +876,8 @@ func renderStep3Summary(wz *WizardState, st Styles, innerW int) []string { addRow("字符数", strconv.Itoa(len([]rune(wz.PromptText))), st.Muted) } else if wz.PromptMode == PromptModeGenerated { addRow("目标长度", strconv.Itoa(wz.PromptLength), st.Muted) + } else if wz.PromptMode == PromptModeRaw { + addRow("Body 字节数", strconv.Itoa(len(wz.PromptText)), st.Muted) } lines = append(lines, "", st.Muted.Render("保存位置: ~/.ait/tasks/.json")) @@ -871,6 +913,8 @@ func wizardFieldLabel(f fieldDef, wz *WizardState) string { return "文件路径" case PromptModeGenerated: return "生成长度" + case PromptModeRaw: + return "JSON Body" default: return "Prompt" } @@ -882,6 +926,8 @@ func wizardPromptModeLabel(mode string) string { return "文件" case PromptModeGenerated: return "按长度生成" + case PromptModeRaw: + return "RAW 请求体" default: return "直接输入" } From c50a781f0f11475817069284c53c7470ab90f1f1 Mon Sep 17 00:00:00 2001 From: Alain Date: Sat, 23 May 2026 12:42:20 +0800 Subject: [PATCH 35/52] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=20TurboConfig?= =?UTF-8?q?=20=E6=94=AF=E6=8C=81=EF=BC=8C=E6=9B=B4=E6=96=B0=E8=BF=90?= =?UTF-8?q?=E8=A1=8C=E7=8A=B6=E6=80=81=E5=92=8C=E7=95=8C=E9=9D=A2=E4=BB=A5?= =?UTF-8?q?=E5=8F=8D=E6=98=A0=E5=B9=B6=E5=8F=91=E7=BA=A7=E5=88=AB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/server/run.go | 16 ++++ internal/server/types.go | 1 + internal/tui/pages/turbodash.go | 38 ++++++--- internal/turbo/engine.go | 40 +++++++-- internal/turbo/engine_test.go | 138 ++++++++++++++++++++++++++++++++ internal/types/types.go | 1 + 6 files changed, 214 insertions(+), 20 deletions(-) diff --git a/internal/server/run.go b/internal/server/run.go index e818080..6a62d19 100644 --- a/internal/server/run.go +++ b/internal/server/run.go @@ -182,6 +182,7 @@ func buildRunStateFromStoredRun(run *store.StoredRun, requests []types.RequestMe state.StandardResult = run.Result.StandardResult state.TurboResult = run.Result.TurboResult if run.Result.TurboResult != nil { + state.TurboConfig = run.Result.TurboResult.Config state.Levels = run.Result.TurboResult.Levels state.CurrentLevel = run.Result.TurboResult.MaxStableConcurrency } @@ -253,6 +254,8 @@ func (s *serverImpl) StartRun(taskID string) (RunID, error) { if hydratedInput.Turbo { // turbo 模式:跨多个并发级别探测,请求总数不固定,动态追加 state.TotalReqs = 0 + // 规范化并存储 TurboConfig,供 TUI 在运行开始时即可显示任务参数 + state.TurboConfig = turbo.NormalizeConfig(hydratedInput.TurboConfig, hydratedInput.Count) } else { // standard 模式:请求数固定,动态追加(按完成顺序) state.TotalReqs = hydratedInput.Count @@ -351,6 +354,11 @@ func (s *serverImpl) runTurbo(ar *activeRun, runID RunID, taskDef types.TaskDefi var globalIdx int64 factory := func(levelInput types.Input) (turbo.LevelRunner, error) { + // 每级别开始时更新 CurrentLevel,TUI 实时反映当前探测的并发度 + ar.mu.Lock() + ar.state.CurrentLevel = levelInput.Concurrency + ar.mu.Unlock() + r, err := runner.NewRunner(taskDef.ID, levelInput) if err != nil { return nil, err @@ -360,6 +368,7 @@ func (s *serverImpl) runTurbo(ar *activeRun, runID RunID, taskDef types.TaskDefi cb: func(metrics *client.ResponseMetrics, _ int, cbErr error) { gIdx := int(atomic.AddInt64(&globalIdx, 1)) - 1 rm := mapRequestMetrics(metrics, gIdx, cbErr) + rm.Level = levelInput.Concurrency _ = runStore.AppendRequest(taskDef.ID, string(runID), *rm) ar.mu.Lock() @@ -391,6 +400,13 @@ func (s *serverImpl) runTurbo(ar *activeRun, runID RunID, taskDef types.TaskDefi } engine := turbo.New(factory) + engine.SetOnLevelDone(func(level types.TurboLevelResult) { + ar.mu.Lock() + ar.state.Levels = append(ar.state.Levels, level) + snap := ar.snapshotState() + ar.mu.Unlock() + s.bus.Publish(Event{RunID: runID, Kind: EventLevelDone, Payload: snap}) + }) ar.mu.Lock() ar.turboEngine = engine diff --git a/internal/server/types.go b/internal/server/types.go index 1feda74..635a549 100644 --- a/internal/server/types.go +++ b/internal/server/types.go @@ -63,6 +63,7 @@ type RunState struct { Requests []*types.RequestMetrics // Turbo 专用 + TurboConfig types.TurboConfig // 规范化后的 Turbo 配置(运行开始时填充) Levels []types.TurboLevelResult CurrentLevel int diff --git a/internal/tui/pages/turbodash.go b/internal/tui/pages/turbodash.go index 91b356f..d242357 100644 --- a/internal/tui/pages/turbodash.go +++ b/internal/tui/pages/turbodash.go @@ -189,11 +189,16 @@ func RenderTurboDash(d *TurboDashState, taskName string, st Styles, width, heigh headerRight := []string{} if rs != nil { headerLeft = []string{runStatusText(string(rs.Status)), fmt.Sprintf("完成 %d/%d", rs.DoneReqs, rs.TotalReqs)} - currentLevel := rs.CurrentLevel + 1 - if currentLevel < 1 { - currentLevel = 1 + var levelNum int + if d.IsRunning() { + levelNum = len(rs.Levels) + 1 + } else { + levelNum = len(rs.Levels) + } + if levelNum < 1 { + levelNum = 1 } - headerRight = []string{fmt.Sprintf("级别 %d", currentLevel)} + headerRight = []string{fmt.Sprintf("级别 %d", levelNum)} if len(rs.Levels) > 0 { headerRight = append(headerRight, fmt.Sprintf("已探测 %d 档", len(rs.Levels))) } @@ -243,12 +248,10 @@ func buildTurboDashParams(rs *server.RunState, st Styles, maxH, width int) strin if rs == nil { lines = append(lines, " "+st.Muted.Render("等待数据...")) } else { - if rs.TurboResult != nil { - tc := rs.TurboResult.Config - lines = append(lines, " "+labelValue(st, "爬坡 ", fmt.Sprintf("%d→%d 步进+%d", tc.InitConcurrency, tc.MaxConcurrency, tc.StepSize))) - lines = append(lines, " "+labelValue(st, "每级 ", fmt.Sprintf("%d 请求", tc.LevelRequests))) - lines = append(lines, " "+labelValue(st, "停止 ", fmt.Sprintf("成功率 < %.0f%%", tc.MinSuccessRate*100))) - } + tc := rs.TurboConfig + lines = append(lines, " "+labelValue(st, "爬坡 ", fmt.Sprintf("%d→%d 步进+%d", tc.InitConcurrency, tc.MaxConcurrency, tc.StepSize))) + lines = append(lines, " "+labelValue(st, "每级 ", fmt.Sprintf("%d 请求", tc.LevelRequests))) + lines = append(lines, " "+labelValue(st, "停止 ", fmt.Sprintf("成功率 < %.0f%%", tc.MinSuccessRate*100))) } return finishPanelLines(lines, maxH) @@ -287,9 +290,17 @@ func buildTurboProgressLine(rs *server.RunState, st Styles, width int) string { if total > 0 { ratio = float64(done) / float64(total) } - levelTotal := len(rs.Levels) + levelDone := len(rs.Levels) + var levelTotalStr string + cfg := rs.TurboConfig + if cfg.StepSize > 0 { + expected := (cfg.MaxConcurrency-cfg.InitConcurrency)/cfg.StepSize + 1 + levelTotalStr = fmt.Sprintf("%d/%d", levelDone, expected) + } else { + levelTotalStr = fmt.Sprintf("%d", levelDone) + } prefix := " 进度 " - suffix := fmt.Sprintf(" %d/%d 当前并发 %d 总进度: 已完成 %d/~? 级", done, total, rs.CurrentLevel, levelTotal) + suffix := fmt.Sprintf(" %d/%d 当前并发 %d 总进度: %s 级", done, total, rs.CurrentLevel, levelTotalStr) barW := width - lipgloss.Width(prefix) - lipgloss.Width(suffix) if barW < 5 { @@ -323,6 +334,7 @@ func buildTurboRequestList(d *TurboDashState, rs *server.RunState, st Styles, wi markW = 2 idW = 6 statW = 5 + levelW = 6 timeW = 10 ttftW = 10 cacheW = 8 @@ -332,6 +344,7 @@ func buildTurboRequestList(d *TurboDashState, rs *server.RunState, st Styles, wi tableCol(markW, ""), tableCol(idW, "#"), tableCol(statW, "状态"), + tableCol(levelW, "级别"), tableCol(timeW, "总耗时"), tableCol(ttftW, "TTFT"), tableCol(cacheW, "Cache"), @@ -376,6 +389,7 @@ func buildTurboRequestList(d *TurboDashState, rs *server.RunState, st Styles, wi tableCol(markW, marker), tableCol(idW, fmt.Sprintf("#%d", len(reqs)-pos)), tableCol(statW, statusStr), + tableCol(levelW, fmt.Sprintf("%d", r.Level)), tableCol(timeW, totalStr), tableCol(ttftW, fmtDuration(r.TTFT)), tableCol(cacheW, fmt.Sprintf("%.0f%%", r.CacheHitRate*100)), diff --git a/internal/turbo/engine.go b/internal/turbo/engine.go index e13bc15..4ebae86 100644 --- a/internal/turbo/engine.go +++ b/internal/turbo/engine.go @@ -30,6 +30,12 @@ type Engine struct { currentRunner LevelRunner stopCh chan struct{} stopOnce sync.Once + onLevelDone func(types.TurboLevelResult) +} + +// SetOnLevelDone 设置每个并发级别探测完成后的回调(含稳定与不稳定级别)。 +func (e *Engine) SetOnLevelDone(fn func(types.TurboLevelResult)) { + e.onLevelDone = fn } func New(factory RunnerFactory) *Engine { @@ -108,26 +114,39 @@ func (e *Engine) Run(input types.Input) (*types.TurboResult, error) { } level := buildLevelResult(report, concurrency) - result.Levels = append(result.Levels, level) + // 先判断停止条件,确保 Stable/StopReason 在回调前已填充。 + unstable := false select { case <-e.stopCh: result.StopReason = StopReasonManual + result.Levels = append(result.Levels, level) + if e.onLevelDone != nil { + e.onLevelDone(level) + } result.ProbeDuration = time.Since(startedAt) return result, nil default: } if level.SuccessRate < cfg.MinSuccessRate { - result.Levels[len(result.Levels)-1].Stable = false - result.Levels[len(result.Levels)-1].StopReason = StopReasonLowSuccessRate + level.Stable = false + level.StopReason = StopReasonLowSuccessRate result.StopReason = StopReasonLowSuccessRate - break - } - if level.AvgTotalTime > cfg.MaxLatency { - result.Levels[len(result.Levels)-1].Stable = false - result.Levels[len(result.Levels)-1].StopReason = StopReasonHighLatency + unstable = true + } else if level.AvgTotalTime > cfg.MaxLatency { + level.Stable = false + level.StopReason = StopReasonHighLatency result.StopReason = StopReasonHighLatency + unstable = true + } + + result.Levels = append(result.Levels, level) + if e.onLevelDone != nil { + e.onLevelDone(level) + } + + if unstable { break } @@ -165,6 +184,11 @@ func buildLevelResult(report *types.ReportData, concurrency int) types.TurboLeve } } +// NormalizeConfig 对 TurboConfig 应用默认值,供外部包在构建 RunState 时复用。 +func NormalizeConfig(cfg types.TurboConfig, fallbackLevelRequests int) types.TurboConfig { + return normalizeConfig(cfg, fallbackLevelRequests) +} + func normalizeConfig(cfg types.TurboConfig, fallbackLevelRequests int) types.TurboConfig { if cfg.InitConcurrency <= 0 { cfg.InitConcurrency = 1 diff --git a/internal/turbo/engine_test.go b/internal/turbo/engine_test.go index 52c9ae8..8c23b43 100644 --- a/internal/turbo/engine_test.go +++ b/internal/turbo/engine_test.go @@ -145,3 +145,141 @@ func TestNormalizeConfigUsesDefaults(t *testing.T) { t.Fatalf("unexpected threshold defaults: %+v", cfg) } } + +func TestNormalizeConfigExported(t *testing.T) { + cfg := NormalizeConfig(types.TurboConfig{}, 20) + if cfg.LevelRequests != 20 { + t.Fatalf("expected LevelRequests 20, got %d", cfg.LevelRequests) + } + if cfg.InitConcurrency != 1 || cfg.MaxConcurrency != 50 { + t.Fatalf("unexpected defaults from exported NormalizeConfig: %+v", cfg) + } +} + +func TestEngineOnLevelDone_CalledPerLevel(t *testing.T) { + var doneLevels []types.TurboLevelResult + engine := New(func(input types.Input) (LevelRunner, error) { + report := &types.ReportData{ + TotalRequests: 10, SuccessRate: 100, + AvgTPS: float64(input.Concurrency) * 5, MaxTPS: float64(input.Concurrency) * 6, + AvgTTFT: 50 * time.Millisecond, AvgTotalTime: 200 * time.Millisecond, + } + return &fakeRunner{report: report}, nil + }) + engine.SetOnLevelDone(func(level types.TurboLevelResult) { + doneLevels = append(doneLevels, level) + }) + + result, err := engine.Run(types.Input{ + Protocol: types.ProtocolOpenAIResponses, + EndpointURL: "https://api.example.com", + Model: "test-model", + TurboConfig: types.TurboConfig{InitConcurrency: 1, MaxConcurrency: 3, StepSize: 1, LevelRequests: 10, MinSuccessRate: 0.9, MaxLatency: 10 * time.Second}, + }) + if err != nil { + t.Fatalf("Run() returned unexpected error: %v", err) + } + if len(doneLevels) != 3 { + t.Fatalf("expected OnLevelDone called 3 times, got %d", len(doneLevels)) + } + if len(doneLevels) != len(result.Levels) { + t.Fatalf("OnLevelDone calls (%d) != result.Levels (%d)", len(doneLevels), len(result.Levels)) + } + for i, level := range doneLevels { + if level.Concurrency != result.Levels[i].Concurrency { + t.Fatalf("level[%d] concurrency mismatch: callback=%d result=%d", i, level.Concurrency, result.Levels[i].Concurrency) + } + } +} + +func TestEngineOnLevelDone_UnstableLevelIncluded(t *testing.T) { + var doneLevels []types.TurboLevelResult + reports := map[int]*types.ReportData{ + 1: {TotalRequests: 10, SuccessRate: 100, AvgTPS: 10, MaxTPS: 12, AvgTTFT: 50 * time.Millisecond, AvgTotalTime: 200 * time.Millisecond}, + 2: {TotalRequests: 10, SuccessRate: 70, AvgTPS: 8, MaxTPS: 10, AvgTTFT: 80 * time.Millisecond, AvgTotalTime: 300 * time.Millisecond}, + } + engine := New(func(input types.Input) (LevelRunner, error) { + return &fakeRunner{report: reports[input.Concurrency]}, nil + }) + engine.SetOnLevelDone(func(level types.TurboLevelResult) { + doneLevels = append(doneLevels, level) + }) + + result, err := engine.Run(types.Input{ + TurboConfig: types.TurboConfig{InitConcurrency: 1, MaxConcurrency: 3, StepSize: 1, LevelRequests: 10, MinSuccessRate: 0.9, MaxLatency: 10 * time.Second}, + }) + if err != nil { + t.Fatalf("Run() returned unexpected error: %v", err) + } + if len(doneLevels) != 2 { + t.Fatalf("expected 2 OnLevelDone calls (stable+unstable), got %d", len(doneLevels)) + } + // 最后一级应是不稳定的 + lastLevel := doneLevels[len(doneLevels)-1] + if lastLevel.Stable { + t.Fatal("expected last level to be unstable in OnLevelDone callback") + } + if lastLevel.StopReason != StopReasonLowSuccessRate { + t.Fatalf("expected stop reason %s, got %s", StopReasonLowSuccessRate, lastLevel.StopReason) + } + if result.StopReason != StopReasonLowSuccessRate { + t.Fatalf("expected result stop reason %s, got %s", StopReasonLowSuccessRate, result.StopReason) + } +} + +func TestEngineRunAllLevelsStable(t *testing.T) { + engine := New(func(input types.Input) (LevelRunner, error) { + report := &types.ReportData{ + TotalRequests: 5, SuccessRate: 100, + AvgTPS: 10, MaxTPS: 12, + AvgTTFT: 30 * time.Millisecond, AvgTotalTime: 100 * time.Millisecond, + } + return &fakeRunner{report: report}, nil + }) + + result, err := engine.Run(types.Input{ + TurboConfig: types.TurboConfig{InitConcurrency: 2, MaxConcurrency: 4, StepSize: 2, LevelRequests: 5, MinSuccessRate: 0.9, MaxLatency: 5 * time.Second}, + }) + if err != nil { + t.Fatalf("Run() error: %v", err) + } + if len(result.Levels) != 2 { + t.Fatalf("expected 2 levels, got %d", len(result.Levels)) + } + if result.MaxStableConcurrency != 4 { + t.Fatalf("expected MaxStableConcurrency 4, got %d", result.MaxStableConcurrency) + } + if result.StopReason != StopReasonMaxConcurrency { + t.Fatalf("expected stop reason %s, got %s", StopReasonMaxConcurrency, result.StopReason) + } + for i, level := range result.Levels { + if !level.Stable { + t.Fatalf("level[%d] expected stable but got unstable", i) + } + } +} + +func TestEngineFactoryReceivesCorrectConcurrency(t *testing.T) { + var concurrencies []int + engine := New(func(input types.Input) (LevelRunner, error) { + concurrencies = append(concurrencies, input.Concurrency) + report := &types.ReportData{TotalRequests: 5, SuccessRate: 100, AvgTPS: 10, MaxTPS: 12, AvgTotalTime: 100 * time.Millisecond} + return &fakeRunner{report: report}, nil + }) + + _, err := engine.Run(types.Input{ + TurboConfig: types.TurboConfig{InitConcurrency: 2, MaxConcurrency: 6, StepSize: 2, LevelRequests: 5, MinSuccessRate: 0.9, MaxLatency: 5 * time.Second}, + }) + if err != nil { + t.Fatalf("Run() error: %v", err) + } + expected := []int{2, 4, 6} + if len(concurrencies) != len(expected) { + t.Fatalf("expected factory called with %v, got %v", expected, concurrencies) + } + for i, c := range expected { + if concurrencies[i] != c { + t.Fatalf("concurrencies[%d]: expected %d, got %d", i, c, concurrencies[i]) + } + } +} diff --git a/internal/types/types.go b/internal/types/types.go index e128f27..f03c88a 100644 --- a/internal/types/types.go +++ b/internal/types/types.go @@ -277,6 +277,7 @@ type RequestMetrics struct { ErrorMessage string `json:"error_message,omitempty"` RequestBody string `json:"request_body,omitempty"` ResponseBody string `json:"response_body,omitempty"` + Level int `json:"level,omitempty"` } type TurboConfig struct { From 9a32fa8a29f10661cc5c403d0ded21b98e148916 Mon Sep 17 00:00:00 2001 From: Alain Date: Sat, 23 May 2026 15:28:15 +0800 Subject: [PATCH 36/52] =?UTF-8?q?feat:=20=E6=9B=B4=E6=96=B0=E7=8A=B6?= =?UTF-8?q?=E6=80=81=E6=8F=90=E7=A4=BA=E9=80=BB=E8=BE=91=EF=BC=8C=E4=BC=98?= =?UTF-8?q?=E5=8C=96=E4=BB=BB=E5=8A=A1=E9=80=89=E6=8B=A9=E5=92=8C=E8=BE=93?= =?UTF-8?q?=E5=85=A5=E5=AD=97=E6=AE=B5=E5=A4=84=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/tui/model.go | 23 ++++++++++++++++++----- internal/tui/pages/contextbar.go | 3 +-- internal/tui/pages/dashboard.go | 4 ++-- internal/tui/pages/helpers.go | 17 ----------------- internal/tui/pages/turbodash.go | 8 +------- internal/tui/pages/wizard.go | 24 +++++++++++++++--------- 6 files changed, 37 insertions(+), 42 deletions(-) diff --git a/internal/tui/model.go b/internal/tui/model.go index 7772421..560abfe 100644 --- a/internal/tui/model.go +++ b/internal/tui/model.go @@ -112,8 +112,6 @@ func (m *Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { if m.taskList.Selected >= len(msg.Tasks) { m.taskList.Selected = max(len(msg.Tasks)-1, 0) } - m.status = "" - m.err = nil return m, nil // ── 任务保存完成(新建或更新) ── @@ -213,9 +211,7 @@ func (m *Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { // ── 代理配置 ── case ProxyConfigLoadedMsg: - if m.proxyConf != nil { - m.proxyConf = pages.NewProxyConfigState(msg.ProxyURL) - } + m.proxyConf = pages.NewProxyConfigState(msg.ProxyURL) return m, nil case ProxyConfigSavedMsg: @@ -233,6 +229,16 @@ func (m *Model) View() string { innerW := m.width innerH := m.height + // 状态/错误提示条占用一行 + var banner string + if m.err != nil { + banner = m.styles.ErrStyle.Width(innerW).Render(" ✗ " + m.err.Error()) + innerH-- + } else if m.status != "" { + banner = m.styles.Ok.Width(innerW).Render(" ✓ " + m.status) + innerH-- + } + var content string switch m.view { case viewTaskList: @@ -253,12 +259,19 @@ func (m *Model) View() string { content = "未知视图" } + if banner != "" { + return banner + "\n" + content + } return content } // ─── 键盘分发 ───────────────────────────────────────────────────────────────── func (m *Model) handleKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) { + // 任意按键清除状态提示 + m.status = "" + m.err = nil + switch m.view { case viewTaskList: newState, cmd, nav := pages.HandleTaskListKey(m.taskList, msg, m.client) diff --git a/internal/tui/pages/contextbar.go b/internal/tui/pages/contextbar.go index 2cb02fa..51bd4e2 100644 --- a/internal/tui/pages/contextbar.go +++ b/internal/tui/pages/contextbar.go @@ -77,7 +77,7 @@ func Hotkeys_TaskDetail_HasHistory() []HotkeyItem { func Hotkeys_TaskDetail_Running() []HotkeyItem { return []HotkeyItem{ HotkeyAction("↑↓", "选择记录"), - HotkeyAction("Enter", "进入运行中仓表盘"), + HotkeyAction("Enter", "进入运行中仪表盘"), HotkeyAction("g", "导出历史 JSON"), HotkeyAction("e", "编辑"), HotkeyAction("y", "复制任务"), @@ -152,7 +152,6 @@ func Hotkeys_Dashboard_Done_Sel() []HotkeyItem { func Hotkeys_TurboDash_Running_NoSel() []HotkeyItem { return []HotkeyItem{ HotkeyAction("s", "停止"), - HotkeyAction("m", "标记极限并停止"), HotkeyAction("b/Esc", "返回列表"), } } diff --git a/internal/tui/pages/dashboard.go b/internal/tui/pages/dashboard.go index 26605ac..486a73d 100644 --- a/internal/tui/pages/dashboard.go +++ b/internal/tui/pages/dashboard.go @@ -81,8 +81,8 @@ func HandleDashboardKey(d *DashboardState, msg tea.KeyMsg, client Client) (*Dash break } if d.ReqSel < 0 { - // 无选中项:首次按键选中最新一条(显示列表顶部) - d.ReqSel = requestIndexFromDisplayPos(0, len(reqs)) + // 无选中项:首次向上按键选中最旧一条(显示列表底部) + d.ReqSel = requestIndexFromDisplayPos(len(reqs)-1, len(reqs)) } else { selPos := requestDisplayPos(d.ReqSel, len(reqs)) if selPos <= 0 { diff --git a/internal/tui/pages/helpers.go b/internal/tui/pages/helpers.go index 12567cb..2c64862 100644 --- a/internal/tui/pages/helpers.go +++ b/internal/tui/pages/helpers.go @@ -88,23 +88,6 @@ func dividerLine(st Styles, width int) string { return st.Divider.Render(strings.Repeat("─", width)) } -// ─── 时间格式化 ─────────────────────────────────────────────────────────────── - -// timeAgo 将时间转换为"N分钟前"/"刚刚"等人性化描述。 -func timeAgo(t time.Time) string { - d := time.Since(t) - switch { - case d < time.Minute: - return "刚刚" - case d < time.Hour: - return fmt.Sprintf("%d 分钟前", int(d.Minutes())) - case d < 24*time.Hour: - return fmt.Sprintf("%d 小时前", int(d.Hours())) - default: - return t.Format("2006-01-02 15:04") - } -} - // fmtDuration 格式化 Duration 为简短字符串(ms/s/min)。 func fmtDuration(d time.Duration) string { ms := d.Milliseconds() diff --git a/internal/tui/pages/turbodash.go b/internal/tui/pages/turbodash.go index d242357..c8ae759 100644 --- a/internal/tui/pages/turbodash.go +++ b/internal/tui/pages/turbodash.go @@ -80,7 +80,7 @@ func HandleTurboDashKey(d *TurboDashState, msg tea.KeyMsg, client Client) (*Turb break } if d.ReqSel < 0 { - d.ReqSel = requestIndexFromDisplayPos(0, len(reqs)) + d.ReqSel = requestIndexFromDisplayPos(len(reqs)-1, len(reqs)) } else { selPos := requestDisplayPos(d.ReqSel, len(reqs)) if selPos > 0 { @@ -117,12 +117,6 @@ func HandleTurboDashKey(d *TurboDashState, msg tea.KeyMsg, client Client) (*Turb return d, client.StopRunCmd(d.RunID), nav } - case "m": - // 手动标记极限并停止 - if d.IsRunning() { - return d, client.StopRunCmd(d.RunID), nav - } - case "b", "esc": if d.BackNav.To != NavNone { nav = d.BackNav diff --git a/internal/tui/pages/wizard.go b/internal/tui/pages/wizard.go index 3c80ed3..646d9f4 100644 --- a/internal/tui/pages/wizard.go +++ b/internal/tui/pages/wizard.go @@ -88,7 +88,7 @@ func loadInputForField(wz *WizardState, f fieldDef) { } else if f.get != nil { rawVal = f.get(wz) } - if f.label == "API 密钥" { + if f.password { wz.input.EchoMode = textinput.EchoPassword } else { wz.input.EchoMode = textinput.EchoNormal @@ -248,6 +248,10 @@ type fieldDef struct { set func(wz *WizardState, v string) // 枚举/布尔切换 toggle func(wz *WizardState, forward bool) + // password 为 true 时以密码模式显示输入 + password bool + // triggersFieldReset 为 true 时切换后重置字段列表索引 + triggersFieldReset bool } type fieldKind int @@ -320,8 +324,9 @@ func step1Fields() []fieldDef { }, { kind: fieldText, label: "API 密钥", - get: func(wz *WizardState) string { return wz.APIKey }, - set: func(wz *WizardState, v string) { wz.APIKey = v }, + get: func(wz *WizardState) string { return wz.APIKey }, + set: func(wz *WizardState, v string) { wz.APIKey = v }, + password: true, }, { kind: fieldText, label: "测试模型", @@ -342,7 +347,8 @@ func step2Fields(turbo bool) []fieldDef { } return "标准模式" }, - toggle: func(wz *WizardState, _ bool) { wz.Turbo = !wz.Turbo }, + toggle: func(wz *WizardState, _ bool) { wz.Turbo = !wz.Turbo }, + triggersFieldReset: true, }, } @@ -501,7 +507,7 @@ func HandleWizardKey(wz *WizardState, msg tea.KeyMsg, client Client) (*WizardSta } nav = NavAction{To: NavTaskList} return wz, cmd, nav - case "q", "ctrl+c": + case "ctrl+c": nav = NavAction{To: NavQuit} } return wz, nil, nav @@ -537,7 +543,7 @@ func HandleWizardKey(wz *WizardState, msg tea.KeyMsg, client Client) (*WizardSta f := fields[wz.FieldIndex] if f.toggle != nil { f.toggle(wz, false) - if f.label == "测试模式" { + if f.triggersFieldReset { wz.FieldIndex = 0 wz.ScrollOff = 0 loadCurrentFieldInput(wz) @@ -554,7 +560,7 @@ func HandleWizardKey(wz *WizardState, msg tea.KeyMsg, client Client) (*WizardSta f := fields[wz.FieldIndex] if f.toggle != nil { f.toggle(wz, true) - if f.label == "测试模式" { + if f.triggersFieldReset { wz.FieldIndex = 0 wz.ScrollOff = 0 loadCurrentFieldInput(wz) @@ -576,7 +582,7 @@ func HandleWizardKey(wz *WizardState, msg tea.KeyMsg, client Client) (*WizardSta } loadCurrentFieldInput(wz) - case "q", "ctrl+c": + case "ctrl+c": nav = NavAction{To: NavQuit} default: @@ -638,7 +644,7 @@ func RenderWizard(wz *WizardState, st Styles, width, height int) string { HeaderMeta: fmt.Sprintf("步骤 %d/3", int(wz.Step)+1), HeaderInfoLeft: headerLeft, HeaderInfoRight: headerRight, - Hotkeys: NewPageHotkeys(wizardHotkeyItems(wz.Step), "[q] 退出"), + Hotkeys: NewPageHotkeys(wizardHotkeyItems(wz.Step), "[Ctrl+C] 退出"), } frame := l.Frame(width, height) panel := NewPanelFrame(frame.OuterWidth) From b5aa6f1e12a02f1699316f412464745699d73555 Mon Sep 17 00:00:00 2001 From: Alain Date: Sat, 23 May 2026 19:14:35 +0800 Subject: [PATCH 37/52] =?UTF-8?q?feat:=20=E6=9B=B4=E6=96=B0=E4=BB=BB?= =?UTF-8?q?=E5=8A=A1=E5=92=8C=E4=BB=AA=E8=A1=A8=E6=9D=BF=E7=95=8C=E9=9D=A2?= =?UTF-8?q?=EF=BC=8C=E4=BC=98=E5=8C=96=E5=88=97=E5=AE=BD=E5=92=8C=E6=98=BE?= =?UTF-8?q?=E7=A4=BA=E5=86=85=E5=AE=B9=EF=BC=8C=E6=B7=BB=E5=8A=A0=E6=88=90?= =?UTF-8?q?=E5=8A=9F=E7=8E=87=E5=92=8C=E8=80=97=E6=97=B6=E4=BF=A1=E6=81=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/tui/pages/dashboard.go | 9 ++++++--- internal/tui/pages/taskdetail.go | 11 +++++++++-- internal/tui/pages/tasklist.go | 21 ++++++++++++++++----- internal/tui/pages/turbodash.go | 9 ++++++--- 4 files changed, 37 insertions(+), 13 deletions(-) diff --git a/internal/tui/pages/dashboard.go b/internal/tui/pages/dashboard.go index 486a73d..6ae9cdf 100644 --- a/internal/tui/pages/dashboard.go +++ b/internal/tui/pages/dashboard.go @@ -326,7 +326,8 @@ func buildRequestList(d *DashboardState, rs *server.RunState, st Styles, width, timeW = 10 // 总耗时 ttftW = 10 // TTFT cacheW = 8 // Cache - tokW = 10 // Token + ptokW = 9 // 提示 tok + ctokW = 9 // 完成 tok // TPS: 余量 ) hdr := lipgloss.JoinHorizontal(lipgloss.Top, @@ -336,7 +337,8 @@ func buildRequestList(d *DashboardState, rs *server.RunState, st Styles, width, tableCol(timeW, "总耗时"), tableCol(ttftW, "TTFT"), tableCol(cacheW, "Cache"), - tableCol(tokW, "Token"), + tableCol(ptokW, "提示tok"), + tableCol(ctokW, "完成tok"), "TPS", ) lines = append(lines, renderTableHeader(st, width, hdr)) @@ -381,7 +383,8 @@ func buildRequestList(d *DashboardState, rs *server.RunState, st Styles, width, tableCol(timeW, totalStr), tableCol(ttftW, fmtDuration(r.TTFT)), tableCol(cacheW, fmt.Sprintf("%.0f%%", r.CacheHitRate*100)), - tableCol(tokW, fmt.Sprintf("%dtok", r.CompletionTokens)), + tableCol(ptokW, fmt.Sprintf("%dtok", r.PromptTokens)), + tableCol(ctokW, fmt.Sprintf("%dtok", r.CompletionTokens)), fmt.Sprintf("%.1f/s", r.TPS), ) diff --git a/internal/tui/pages/taskdetail.go b/internal/tui/pages/taskdetail.go index acaa23a..fedc80a 100644 --- a/internal/tui/pages/taskdetail.go +++ b/internal/tui/pages/taskdetail.go @@ -267,13 +267,14 @@ func buildTaskDetailContent(s *TaskDetailState, st Styles, t types.TaskDefinitio const ( markW = 2 statW = 2 - timeW = 17 + timeW = 15 modeW = 7 rateW = 8 + durW = 9 // 耗时 ttftW = 10 ) hdr := padRight("", markW) + padRight("", statW) + padRight("时间", timeW) + padRight("模式", modeW) + - padRight("成功率", rateW) + padRight("TTFT", ttftW) + "TPS" + padRight("成功率", rateW) + padRight("耗时", durW) + padRight("TTFT", ttftW) + "TPS" rightLines = append(rightLines, renderTableHeader(st, rightW, hdr)) rightLines = append(rightLines, padRight(st.Divider.Render(strings.Repeat("─", rightW)), rightW)) @@ -332,6 +333,7 @@ func buildTaskDetailContent(s *TaskDetailState, st Styles, t types.TaskDefinitio padRight(rs.StartedAt.Format("2006-01-02 15:04"), timeW) + padRight(modeShort, modeW) + padRight(rateStr, rateW) + + padRight("─", durW) + styleWhenNotSelected(isSel, st.Ok, progStr) } else { histIdx := idx @@ -357,11 +359,16 @@ func buildTaskDetailContent(s *TaskDetailState, st Styles, t types.TaskDefinitio modeShort = "Turbo" } statusIcon := styleWhenNotSelected(isSel, statusStyle, statusText) + durText := "─" + if !run.FinishedAt.IsZero() { + durText = fmtDuration(run.FinishedAt.Sub(run.StartedAt)) + } row = padRight(marker, markW) + padRight(statusIcon, statW) + padRight(run.StartedAt.Format("2006-01-02 15:04"), timeW) + padRight(modeShort, modeW) + padRight(fmt.Sprintf("%.1f%%", run.SuccessRate), rateW) + + padRight(durText, durW) + padRight(fmtDuration(run.AvgTTFT), ttftW) + fmt.Sprintf("%.1f", run.AvgTPS) } diff --git a/internal/tui/pages/tasklist.go b/internal/tui/pages/tasklist.go index 282d0b7..0b56c06 100644 --- a/internal/tui/pages/tasklist.go +++ b/internal/tui/pages/tasklist.go @@ -200,12 +200,13 @@ func buildTaskListContent(s *TaskListState, st Styles, width, maxH int) string { // 列宽(gap=2 作为列间距内置到每个非末尾列的宽度中) const ( modeW = 9 // 7 + 2 gap - protoW = 20 // 10 + 2 gap + protoW = 18 // 8 + 2 gap lastRunW = 16 // 11 + 2 gap - ttftW = 16 // 10 + 2 gap + rateW = 10 // 6 + 2 gap -- 成功率 + ttftW = 10 // 6 + 2 gap tpsW = 16 // 末尾列,无需额外 gap ) - fixedW := 2 + modeW + protoW + lastRunW + ttftW + tpsW + fixedW := 2 + modeW + protoW + lastRunW + rateW + ttftW + tpsW nameW := maxInt(10, width-fixedW) // 表头:2 空格前缀与正文行对齐(cursor=2) @@ -216,6 +217,7 @@ func buildTaskListContent(s *TaskListState, st Styles, width, maxH int) string { tableCol(modeW, "模式"), tableCol(protoW, "协议"), tableCol(lastRunW, "上次运行"), + tableCol(rateW, "成功率"), tableCol(ttftW, "TTFT"), "TPS", )) @@ -285,6 +287,15 @@ func buildTaskListContent(s *TaskListState, st Styles, width, maxH int) string { } ttftCol := tableCol(ttftW, styleWhenNotSelected(isSel, st.Value, ttftText)) + // ── 成功率 ── + rateText := "─" + if hasActiveRun && rs != nil && rs.TotalReqs > 0 { + rateText = fmt.Sprintf("%.1f%%", rs.SuccessRate) + } else if !hasActiveRun && t.LatestRun != nil { + rateText = fmt.Sprintf("%.1f%%", t.LatestRun.SuccessRate) + } + rateCol := tableCol(rateW, styleWhenNotSelected(isSel, st.Value, rateText)) + // ── TPS ── tpsText := "─" if hasActiveRun && rs != nil && rs.AvgTPS > 0 { @@ -298,9 +309,9 @@ func buildTaskListContent(s *TaskListState, st Styles, width, maxH int) string { } tpsCol := styleWhenNotSelected(isSel, st.Value, tpsText) - // ── 单行:名称 | 模式 | 协议 | 上次运行 | TTFT | TPS ── + // ── 单行:名称 | 模式 | 协议 | 上次运行 | 成功率 | TTFT | TPS ── lines = append(lines, renderTableRow(st, width, isSel, lipgloss.JoinHorizontal(lipgloss.Top, - prefix, nameCol, modeCol, proto, lastRunCol, ttftCol, tpsCol))) + prefix, nameCol, modeCol, proto, lastRunCol, rateCol, ttftCol, tpsCol))) // ── 分隔线 ── if i < end-1 && len(lines) < maxH-1 { diff --git a/internal/tui/pages/turbodash.go b/internal/tui/pages/turbodash.go index c8ae759..5bbcb4f 100644 --- a/internal/tui/pages/turbodash.go +++ b/internal/tui/pages/turbodash.go @@ -332,7 +332,8 @@ func buildTurboRequestList(d *TurboDashState, rs *server.RunState, st Styles, wi timeW = 10 ttftW = 10 cacheW = 8 - tokW = 10 + ptokW = 9 // 提示 tok + ctokW = 9 // 完成 tok ) hdr := lipgloss.JoinHorizontal(lipgloss.Top, tableCol(markW, ""), @@ -342,7 +343,8 @@ func buildTurboRequestList(d *TurboDashState, rs *server.RunState, st Styles, wi tableCol(timeW, "总耗时"), tableCol(ttftW, "TTFT"), tableCol(cacheW, "Cache"), - tableCol(tokW, "Token"), + tableCol(ptokW, "提示tok"), + tableCol(ctokW, "完成tok"), "TPS", ) lines = append(lines, renderTableHeader(st, width, hdr)) @@ -387,7 +389,8 @@ func buildTurboRequestList(d *TurboDashState, rs *server.RunState, st Styles, wi tableCol(timeW, totalStr), tableCol(ttftW, fmtDuration(r.TTFT)), tableCol(cacheW, fmt.Sprintf("%.0f%%", r.CacheHitRate*100)), - tableCol(tokW, fmt.Sprintf("%dtok", r.CompletionTokens)), + tableCol(ptokW, fmt.Sprintf("%dtok", r.PromptTokens)), + tableCol(ctokW, fmt.Sprintf("%dtok", r.CompletionTokens)), fmt.Sprintf("%.1f/s", r.TPS), ) From d5e937c395df5963635353b8a22d678fa37a06de Mon Sep 17 00:00:00 2001 From: Alain Date: Sat, 23 May 2026 21:00:20 +0800 Subject: [PATCH 38/52] Refactor metrics mapping and enhance TUI table rendering - Update `mapRequestMetrics` to set CacheHitRate to 1 when CachedInputTokens > 0. - Modify unit tests for `mapRequestMetrics` to reflect new CacheHitRate logic. - Refactor TUI dashboard and task detail pages to utilize `lgtable` for rendering tables, improving layout and readability. - Consolidate table rendering logic in task list and turbo request list, enhancing maintainability. - Remove redundant functions related to table rendering and selection markers. - Ensure consistent styling and formatting across various TUI components. --- internal/runner/runner.go | 4 +- internal/server/run.go | 4 +- internal/server/server_test.go | 21 ++- internal/tui/pages/dashboard.go | 161 +++++++++++-------- internal/tui/pages/helpers.go | 44 ------ internal/tui/pages/taskdetail.go | 262 ++++++++++++++++++------------- internal/tui/pages/tasklist.go | 207 +++++++++++++----------- internal/tui/pages/turbodash.go | 161 +++++++++++-------- 8 files changed, 477 insertions(+), 387 deletions(-) diff --git a/internal/runner/runner.go b/internal/runner/runner.go index 479fd71..42dcb52 100644 --- a/internal/runner/runner.go +++ b/internal/runner/runner.go @@ -62,10 +62,10 @@ func (r *Runner) acquireSlot(ch chan int) bool { } func calculateCacheHitRate(metrics *client.ResponseMetrics) float64 { - if metrics == nil || metrics.PromptTokens <= 0 { + if metrics == nil || metrics.CachedInputTokens <= 0 { return 0 } - return float64(metrics.CachedInputTokens) / float64(metrics.PromptTokens) + return 1 } // Run 执行性能测试,返回结果数据 diff --git a/internal/server/run.go b/internal/server/run.go index 6a62d19..578d9b6 100644 --- a/internal/server/run.go +++ b/internal/server/run.go @@ -92,8 +92,8 @@ func mapRequestMetrics(m *client.ResponseMetrics, idx int, err error) *types.Req if m.TotalTime > 0 && m.CompletionTokens > 0 { rm.TPS = float64(m.CompletionTokens) / m.TotalTime.Seconds() } - if m.PromptTokens > 0 { - rm.CacheHitRate = float64(m.CachedInputTokens) / float64(m.PromptTokens) + if m.CachedInputTokens > 0 { + rm.CacheHitRate = 1 } return rm } diff --git a/internal/server/server_test.go b/internal/server/server_test.go index 4a9521d..476bc5f 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -229,9 +229,9 @@ func TestMapRequestMetrics_SuccessFields(t *testing.T) { if rm.TPS != 50.0 { t.Errorf("TPS: got %v, want 50", rm.TPS) } - // CacheHitRate = CachedInputTokens / PromptTokens = 50 / 200 = 0.25 - if rm.CacheHitRate != 0.25 { - t.Errorf("CacheHitRate: got %v, want 0.25", rm.CacheHitRate) + // CacheHitRate = 1 if CachedInputTokens > 0 else 0 + if rm.CacheHitRate != 1.0 { + t.Errorf("CacheHitRate: got %v, want 1.0", rm.CacheHitRate) } if rm.TargetIP != "1.2.3.4" { t.Errorf("TargetIP: got %q, want %q", rm.TargetIP, "1.2.3.4") @@ -281,11 +281,18 @@ func TestMapRequestMetrics_ZeroTotalTimeSkipsTPS(t *testing.T) { } } -func TestMapRequestMetrics_ZeroPromptTokensSkipsCacheHitRate(t *testing.T) { - m := &client.ResponseMetrics{CachedInputTokens: 10} // PromptTokens == 0 +func TestMapRequestMetrics_CacheHitRateBinary(t *testing.T) { + // 有缓存命中 → 1 + m := &client.ResponseMetrics{CachedInputTokens: 10, PromptTokens: 100} rm := mapRequestMetrics(m, 0, nil) - if rm.CacheHitRate != 0 { - t.Errorf("expected CacheHitRate=0 when PromptTokens=0, got %v", rm.CacheHitRate) + if rm.CacheHitRate != 1.0 { + t.Errorf("expected CacheHitRate=1 when CachedInputTokens>0, got %v", rm.CacheHitRate) + } + // 无缓存命中 → 0 + m2 := &client.ResponseMetrics{CachedInputTokens: 0, PromptTokens: 100} + rm2 := mapRequestMetrics(m2, 0, nil) + if rm2.CacheHitRate != 0 { + t.Errorf("expected CacheHitRate=0 when CachedInputTokens=0, got %v", rm2.CacheHitRate) } } diff --git a/internal/tui/pages/dashboard.go b/internal/tui/pages/dashboard.go index 6ae9cdf..c16d2b7 100644 --- a/internal/tui/pages/dashboard.go +++ b/internal/tui/pages/dashboard.go @@ -7,6 +7,7 @@ import ( tea "github.com/charmbracelet/bubbletea" "charm.land/lipgloss/v2" + lgtable "charm.land/lipgloss/v2/table" "github.com/yinxulai/ait/internal/server" "github.com/yinxulai/ait/internal/types" ) @@ -307,53 +308,35 @@ func buildProgressLine(rs *server.RunState, st Styles, width int) string { // buildRequestList 构建请求列表区域。 func buildRequestList(d *DashboardState, rs *server.RunState, st Styles, width, maxH int) string { - lines := panelTitleLines(st, "请求列表", width, true) + titleLines := panelTitleLines(st, "请求列表", width, true) if rs == nil || len(rs.Requests) == 0 { msg := "等待请求..." if rs != nil && rs.Status != server.RunStatusRunning { msg = "无请求详情数据" } - lines = append(lines, " "+st.Muted.Render(msg)) - return finishPanelLines(lines, maxH) + titleLines = append(titleLines, " "+st.Muted.Render(msg)) + return finishPanelLines(titleLines, maxH) } - // 列宽(header 与 content 行保持一致,前缀均为 2 字符) - const ( - markW = 2 // 选择标记列 - idW = 6 // "#1" 等 - statW = 5 // "✓" / "✗" 加空白 - timeW = 10 // 总耗时 - ttftW = 10 // TTFT - cacheW = 8 // Cache - ptokW = 9 // 提示 tok - ctokW = 9 // 完成 tok - // TPS: 余量 - ) - hdr := lipgloss.JoinHorizontal(lipgloss.Top, - tableCol(markW, ""), - tableCol(idW, "#"), - tableCol(statW, "状态"), - tableCol(timeW, "总耗时"), - tableCol(ttftW, "TTFT"), - tableCol(cacheW, "Cache"), - tableCol(ptokW, "提示tok"), - tableCol(ctokW, "完成tok"), - "TPS", - ) - lines = append(lines, renderTableHeader(st, width, hdr)) - lines = append(lines, dividerLine(st, width)) - d.ReqVis = listVisibleItems(maxH, 3) - d.AdjustReqOffset(d.ReqVis, len(rs.Requests)) - + // ── 预计算每行数据(按展示顺序,最新在前)── + type reqRow struct { + success bool + errMsg string + id string + status string + total string + ttft string + cache string + ptok string + ctok string + tps string + } reqs := rs.Requests - start := d.ReqOff - end := minInt(len(reqs), start+d.ReqVis) - for pos := start; pos < end; pos++ { + reqRows := make([]reqRow, len(reqs)) + for pos := 0; pos < len(reqs); pos++ { i := requestIndexFromDisplayPos(pos, len(reqs)) r := reqs[i] - isSel := i == d.ReqSel - statusText := "✓" if !r.Success { statusText = "✗" @@ -362,42 +345,84 @@ func buildRequestList(d *DashboardState, rs *server.RunState, st Styles, width, if !r.Success && r.ErrorMessage != "" { totalText = r.ErrorMessage } - - statusStr := statusText - if r.Success { - statusStr = styleWhenNotSelected(isSel, st.Ok, statusText) - } else { - statusStr = styleWhenNotSelected(isSel, st.ErrStyle, statusText) - } - totalStr := totalText - if !r.Success && r.ErrorMessage != "" { - totalStr = styleWhenNotSelected(isSel, st.ErrStyle, totalText) + reqRows[pos] = reqRow{ + success: r.Success, + errMsg: r.ErrorMessage, + id: fmt.Sprintf("#%d", len(reqs)-pos), + status: statusText, + total: totalText, + ttft: fmtDuration(r.TTFT), + cache: fmt.Sprintf("%dtok", r.CachedTokens), + ptok: fmt.Sprintf("%dtok", r.PromptTokens), + ctok: fmt.Sprintf("%dtok", r.CompletionTokens), + tps: fmt.Sprintf("%.1f/s", r.TPS), } + } - marker := selectionMarker(isSel) - - rowContent := lipgloss.JoinHorizontal(lipgloss.Top, - tableCol(markW, marker), - tableCol(idW, fmt.Sprintf("#%d", len(reqs)-pos)), - tableCol(statW, statusStr), - tableCol(timeW, totalStr), - tableCol(ttftW, fmtDuration(r.TTFT)), - tableCol(cacheW, fmt.Sprintf("%.0f%%", r.CacheHitRate*100)), - tableCol(ptokW, fmt.Sprintf("%dtok", r.PromptTokens)), - tableCol(ctokW, fmt.Sprintf("%dtok", r.CompletionTokens)), - fmt.Sprintf("%.1f/s", r.TPS), - ) - - rendered := renderTableRow(st, width, isSel, rowContent) - lines = append(lines, rendered) - - // 行间分隔线 - if pos < end-1 && len(lines) < maxH-1 { - lines = append(lines, dividerLine(st, width)) - } + // 将 d.ReqSel(绝对索引)转换为展示位置 + selDisplayPos := requestDisplayPos(d.ReqSel, len(reqs)) + + // colWidths: 0 = 弹性列(占用剩余宽度),>0 = 固定总宽 + colWidths := []int{6, 8, 0, 8, 7, 10, 10, 10} // #, 状态, 总耗时=flex, TTFT, Cache, 提示tok, 完成tok, TPS + tableH := maxH - len(titleLines) + tbl := lgtable.New(). + Headers("#", "状态", "总耗时", "TTFT", "Cache", "提示tok", "完成tok", "TPS"). + Width(width). + Height(tableH). + YOffset(d.ReqOff). + BorderTop(false).BorderBottom(false). + BorderLeft(false).BorderRight(false). + BorderHeader(true).BorderColumn(true).BorderRow(true). + BorderStyle(lipgloss.NewStyle().Foreground(colorDivider)). + StyleFunc(func(row, col int) lipgloss.Style { + aw := func(s lipgloss.Style) lipgloss.Style { + if col < len(colWidths) && colWidths[col] > 0 { + return s.Width(colWidths[col]).Padding(0, 1) + } + return s.Padding(0, 1) + } + if row == lgtable.HeaderRow { + return aw(st.TableHead) + } + if row < 0 || row >= len(reqRows) { + return aw(st.TableRow) + } + r := reqRows[row] + if row == selDisplayPos { + return aw(st.TableRowSel) + } + switch col { + case 1: // status + if r.success { + return aw(st.Ok) + } + return aw(st.ErrStyle) + case 2: // total + if !r.success && r.errMsg != "" { + return aw(st.ErrStyle) + } + return aw(st.Value) + case 3, 4, 5, 6, 7: // ttft, cache, ptok, ctok, tps + return aw(st.Value) + default: + return aw(st.TableRow) + } + }) + + for _, r := range reqRows { + tbl.Row(r.id, r.status, r.total, r.ttft, r.cache, r.ptok, r.ctok, r.tps) } - return finishPanelLines(lines, maxH) + tableStr := tbl.String() + d.ReqVis = tbl.VisibleRows() + if d.ReqVis < 1 { + d.ReqVis = 1 + } + d.AdjustReqOffset(d.ReqVis, len(reqs)) + + tableLines := strings.Split(tableStr, "\n") + result := append(titleLines, tableLines...) + return finishPanelLines(result, maxH) } func requestDisplayPos(reqIndex, total int) int { diff --git a/internal/tui/pages/helpers.go b/internal/tui/pages/helpers.go index 2c64862..e7d0868 100644 --- a/internal/tui/pages/helpers.go +++ b/internal/tui/pages/helpers.go @@ -41,11 +41,6 @@ func padRight(s string, width int) string { return s + strings.Repeat(" ", width-w) } -// tableCol 返回固定宽度的表格单元格(自动填充空格并截断,始终单行)。 -func tableCol(w int, text string) string { - return lipgloss.NewStyle().Width(w).Render(truncate(text, w)) -} - // wrapText 将文本按 maxW 列宽折行,返回行切片(CJK 字符按 2 列宽计算)。 func wrapText(s string, maxW int) []string { if maxW <= 0 { @@ -485,19 +480,6 @@ func normalizeInlineText(s string) string { return strings.Join(strings.Fields(replacer.Replace(s)), " ") } -// renderTableHeader 统一渲染列表表头。 -func renderTableHeader(st Styles, width int, row string) string { - return st.TableHead.Width(width).Render(row) -} - -// renderTableRow 统一渲染列表行(选中/未选中)。 -func renderTableRow(st Styles, width int, isSel bool, row string) string { - if isSel { - return st.TableRowSel.Width(width).Render(row) - } - return st.TableRow.Width(width).Render(row) -} - // minInt 返回两个整数中的较小值。 func minInt(a, b int) int { if a < b { @@ -525,16 +507,6 @@ func clampInt(v, low, high int) int { return v } -// listVisibleItems 计算在给定高度下可自然滚动的列表项数量。 -// staticLines 是列表项区域前的固定行数(如 section/header/divider)。 -func listVisibleItems(maxLines, staticLines int) int { - visible := (maxLines - staticLines + 1) / 2 - if visible < 1 { - return 1 - } - return visible -} - // ensureVisibleOffset 让 selected 始终位于 offset/visible 定义的可视窗口内。 func ensureVisibleOffset(selected, count, offset, visible int) int { if count <= 0 { @@ -555,22 +527,6 @@ func ensureVisibleOffset(selected, count, offset, visible int) int { return clampInt(offset, 0, maxOffset) } -// selectionMarker 返回统一的选中标记列内容。 -func selectionMarker(isSel bool) string { - if isSel { - return "▶" - } - return "" -} - -// styleWhenNotSelected 仅在未选中时应用局部样式,避免重置选中行背景。 -func styleWhenNotSelected(isSel bool, style lipgloss.Style, text string) string { - if isSel { - return text - } - return style.Render(text) -} - // wrapIndex 循环索引(保证 0 ≤ result < count)。 func wrapIndex(idx, count int) int { if count <= 0 { diff --git a/internal/tui/pages/taskdetail.go b/internal/tui/pages/taskdetail.go index fedc80a..77d9480 100644 --- a/internal/tui/pages/taskdetail.go +++ b/internal/tui/pages/taskdetail.go @@ -6,6 +6,7 @@ import ( "time" "charm.land/lipgloss/v2" + lgtable "charm.land/lipgloss/v2/table" tea "github.com/charmbracelet/bubbletea" "github.com/yinxulai/ait/internal/server" "github.com/yinxulai/ait/internal/types" @@ -255,7 +256,7 @@ func buildTaskDetailContent(s *TaskDetailState, st Styles, t types.TaskDefinitio // ─── 右栏:历史运行记录 ───────────────────────────────────── rightW := rightPanelFrame.InnerWidth - rightLines := panelTitleLines(st, "历史运行记录", rightW, false) + rightTitle := panelTitleLines(st, "历史运行记录", rightW, false) // 2 行 historyEntries := taskDetailHistoryEntries(s) hasActive := s.ActiveRun != nil @@ -264,121 +265,170 @@ func buildTaskDetailContent(s *TaskDetailState, st Styles, t types.TaskDefinitio effectiveLen++ } - const ( - markW = 2 - statW = 2 - timeW = 15 - modeW = 7 - rateW = 8 - durW = 9 // 耗时 - ttftW = 10 - ) - hdr := padRight("", markW) + padRight("", statW) + padRight("时间", timeW) + padRight("模式", modeW) + - padRight("成功率", rateW) + padRight("耗时", durW) + padRight("TTFT", ttftW) + "TPS" - rightLines = append(rightLines, renderTableHeader(st, rightW, hdr)) - rightLines = append(rightLines, padRight(st.Divider.Render(strings.Repeat("─", rightW)), rightW)) - if effectiveLen == 0 { - rightLines = append(rightLines, padRight(" "+st.Muted.Render("暂无运行记录"), rightW)) - } else { - // 始终为当前选中的历史条目显示详情面板 - var detailLines []string - { - histIdx := s.HistorySel - if hasActive { - if s.HistorySel == 0 { - histIdx = -1 // 运行中条目无详情 - } else { - histIdx-- - } - } - if histIdx >= 0 { - detailLines = buildTaskHistoryDetailLines(historyEntries, histIdx, st, rightW) + rightLines := append(rightTitle, padRight(" "+st.Muted.Render("暂无运行记录"), rightW)) + rightContent := finishPanelLines(rightLines, panelContentH) + return renderSplitPanels(st, leftPanelFrame, rightPanelFrame, leftContent, rightContent) + } + + // ── 计算 detailLines (选中条目详情)── + var detailLines []string + { + histIdx := s.HistorySel + if hasActive { + if s.HistorySel == 0 { + histIdx = -1 // 运行中条目无详情 + } else { + histIdx-- } } - tableMaxH := panelContentH - len(detailLines) - if tableMaxH < 5 { - allowedDetail := maxInt(0, panelContentH-5) - if len(detailLines) > allowedDetail { - detailLines = detailLines[:allowedDetail] - } - tableMaxH = panelContentH - len(detailLines) + if histIdx >= 0 { + detailLines = buildTaskHistoryDetailLines(historyEntries, histIdx, st, rightW) } - s.HistoryVis = listVisibleItems(tableMaxH, 4) - s.HistoryOff = ensureVisibleOffset(s.HistorySel, effectiveLen, s.HistoryOff, s.HistoryVis) - start := s.HistoryOff - end := minInt(effectiveLen, start+s.HistoryVis) - - // ── 历史列表 ── - for idx := start; idx < end; idx++ { - isSel := idx == s.HistorySel - marker := selectionMarker(isSel) - var row string - - if hasActive && idx == 0 { - // 正在运行中的条目 - rs := s.ActiveRun - modeShort := "标准" - if rs.Mode == "turbo" { - modeShort = "Turbo" - } - statusIcon := styleWhenNotSelected(isSel, st.Ok, "●") - rateStr := "─" - if rs.TotalReqs > 0 { - rateStr = fmt.Sprintf("%.0f%%", rs.SuccessRate) - } - progStr := fmt.Sprintf("%d/%d 正在运行...", rs.DoneReqs, rs.TotalReqs) - row = padRight(marker, markW) + - padRight(statusIcon, statW) + - padRight(rs.StartedAt.Format("2006-01-02 15:04"), timeW) + - padRight(modeShort, modeW) + - padRight(rateStr, rateW) + - padRight("─", durW) + - styleWhenNotSelected(isSel, st.Ok, progStr) - } else { - histIdx := idx - if hasActive { - histIdx-- - } - run := historyEntries[histIdx] - statusText := "✗" - statusStyle := st.ErrStyle - switch run.Status { - case string(server.RunStatusRunning): - statusText = "●" - statusStyle = st.Ok - case string(server.RunStatusCompleted): - statusText = "✓" - statusStyle = st.Ok - case string(server.RunStatusStopped): - statusText = "■" - statusStyle = st.Muted + } + tableMaxH := panelContentH - len(detailLines) + if tableMaxH < 5 { + allowedDetail := maxInt(0, panelContentH-5) + if len(detailLines) > allowedDetail { + detailLines = detailLines[:allowedDetail] + } + tableMaxH = panelContentH - len(detailLines) + } + + // ── 预计算每行数据── + type histRow struct { + statusText string + statusIsOk bool + statusIsMut bool + time string + mode string + rate string + dur string + ttft string + tps string + } + rowData := make([]histRow, effectiveLen) + if hasActive { + rs := s.ActiveRun + modeShort := "标准" + if rs.Mode == "turbo" { + modeShort = "Turbo" + } + rateStr := "─" + if rs.TotalReqs > 0 { + rateStr = fmt.Sprintf("%.0f%%", rs.SuccessRate) + } + rowData[0] = histRow{ + statusText: "●", + statusIsOk: true, + time: rs.StartedAt.Format("2006-01-02 15:04"), + mode: modeShort, + rate: rateStr, + dur: "─", + ttft: "─", + tps: fmt.Sprintf("%d/%d 正在运行...", rs.DoneReqs, rs.TotalReqs), + } + } + for histIdx := 0; histIdx < len(historyEntries); histIdx++ { + rowIdx := histIdx + if hasActive { + rowIdx++ + } + run := historyEntries[histIdx] + statusText := "✗" + statusIsOk := false + statusMut := false + switch run.Status { + case string(server.RunStatusRunning): + statusText = "●" + statusIsOk = true + case string(server.RunStatusCompleted): + statusText = "✓" + statusIsOk = true + case string(server.RunStatusStopped): + statusText = "■" + statusMut = true + } + modeShort := "标准" + if run.Mode == "turbo" { + modeShort = "Turbo" + } + durText := "─" + if !run.FinishedAt.IsZero() { + durText = fmtDuration(run.FinishedAt.Sub(run.StartedAt)) + } + rowData[rowIdx] = histRow{ + statusText: statusText, + statusIsOk: statusIsOk, + statusIsMut: statusMut, + time: run.StartedAt.Format("2006-01-02 15:04"), + mode: modeShort, + rate: fmt.Sprintf("%.1f%%", run.SuccessRate), + dur: durText, + ttft: fmtDuration(run.AvgTTFT), + tps: fmt.Sprintf("%.1f", run.AvgTPS), + } + } + + // colWidths: 0 = 弹性列,>0 = 固定总宽 + colWidths := []int{4, 0, 7, 8, 7, 7, 7} // 状态图标, 时间=flex, 模式, 成功率, 耗时, TTFT, TPS + sel := s.HistorySel + tableH := tableMaxH - len(rightTitle) + tbl := lgtable.New(). + Headers("", "时间", "模式", "成功率", "耗时", "TTFT", "TPS"). + Width(rightW). + Height(tableH). + YOffset(s.HistoryOff). + BorderTop(false).BorderBottom(false). + BorderLeft(false).BorderRight(false). + BorderHeader(true).BorderColumn(true).BorderRow(true). + BorderStyle(lipgloss.NewStyle().Foreground(colorDivider)). + StyleFunc(func(row, col int) lipgloss.Style { + aw := func(s lipgloss.Style) lipgloss.Style { + if col < len(colWidths) && colWidths[col] > 0 { + return s.Width(colWidths[col]).Padding(0, 1) } - modeShort := "标准" - if run.Mode == "turbo" { - modeShort = "Turbo" + return s.Padding(0, 1) + } + if row == lgtable.HeaderRow { + return aw(st.TableHead) + } + if row < 0 || row >= len(rowData) { + return aw(st.TableRow) + } + r := rowData[row] + if row == sel { + return aw(st.TableRowSel) + } + if col == 0 { // status icon + if r.statusIsOk { + return aw(st.Ok) } - statusIcon := styleWhenNotSelected(isSel, statusStyle, statusText) - durText := "─" - if !run.FinishedAt.IsZero() { - durText = fmtDuration(run.FinishedAt.Sub(run.StartedAt)) + if r.statusIsMut { + return aw(st.Muted) } - row = padRight(marker, markW) + - padRight(statusIcon, statW) + - padRight(run.StartedAt.Format("2006-01-02 15:04"), timeW) + - padRight(modeShort, modeW) + - padRight(fmt.Sprintf("%.1f%%", run.SuccessRate), rateW) + - padRight(durText, durW) + - padRight(fmtDuration(run.AvgTTFT), ttftW) + - fmt.Sprintf("%.1f", run.AvgTPS) + return aw(st.ErrStyle) } - rightLines = append(rightLines, padRight(renderTableRow(st, rightW, isSel, row), rightW)) - if idx < end-1 { - rightLines = append(rightLines, padRight(st.Divider.Render(strings.Repeat("─", rightW)), rightW)) + if col >= 3 { // rate, dur, ttft, tps + return aw(st.Value) } - } - rightLines = append(rightLines, detailLines...) + return aw(st.TableRow) + }) + + for _, r := range rowData { + tbl.Row(r.statusText, r.time, r.mode, r.rate, r.dur, r.ttft, r.tps) + } + + tableStr := tbl.String() + s.HistoryVis = tbl.VisibleRows() + if s.HistoryVis < 1 { + s.HistoryVis = 1 } + s.HistoryOff = ensureVisibleOffset(s.HistorySel, effectiveLen, s.HistoryOff, s.HistoryVis) + + tableLines := strings.Split(tableStr, "\n") + rightLines := append(rightTitle, tableLines...) + rightLines = append(rightLines, detailLines...) rightContent := finishPanelLines(rightLines, panelContentH) return renderSplitPanels(st, leftPanelFrame, rightPanelFrame, leftContent, rightContent) } diff --git a/internal/tui/pages/tasklist.go b/internal/tui/pages/tasklist.go index 0b56c06..027f5f6 100644 --- a/internal/tui/pages/tasklist.go +++ b/internal/tui/pages/tasklist.go @@ -7,6 +7,7 @@ import ( tea "github.com/charmbracelet/bubbletea" "charm.land/lipgloss/v2" + lgtable "charm.land/lipgloss/v2/table" "github.com/yinxulai/ait/internal/server" "github.com/yinxulai/ait/internal/types" ) @@ -194,109 +195,54 @@ func RenderTaskList(s *TaskListState, st Styles, width, height int) string { // buildTaskListContent 构建任务列表内容区(含表头 + 任务条目)。 func buildTaskListContent(s *TaskListState, st Styles, width, maxH int) string { - var lines []string - listTopLines := len(lines) - - // 列宽(gap=2 作为列间距内置到每个非末尾列的宽度中) - const ( - modeW = 9 // 7 + 2 gap - protoW = 18 // 8 + 2 gap - lastRunW = 16 // 11 + 2 gap - rateW = 10 // 6 + 2 gap -- 成功率 - ttftW = 10 // 6 + 2 gap - tpsW = 16 // 末尾列,无需额外 gap - ) - fixedW := 2 + modeW + protoW + lastRunW + rateW + ttftW + tpsW - nameW := maxInt(10, width-fixedW) - - // 表头:2 空格前缀与正文行对齐(cursor=2) - header := renderTableHeader(st, width, - lipgloss.JoinHorizontal(lipgloss.Top, - tableCol(2, ""), - tableCol(nameW, "任务名称"), - tableCol(modeW, "模式"), - tableCol(protoW, "协议"), - tableCol(lastRunW, "上次运行"), - tableCol(rateW, "成功率"), - tableCol(ttftW, "TTFT"), - "TPS", - )) - lines = append(lines, header) - lines = append(lines, dividerLine(st, width)) - listMaxH := maxInt(3, maxH-listTopLines) - s.Visible = listVisibleItems(listMaxH, 2) - s.Offset = ensureVisibleOffset(s.Selected, len(s.Tasks), s.Offset, s.Visible) - - if len(s.Tasks) == 0 { - lines = append(lines, "") - lines = append(lines, " "+st.Muted.Render("暂无任务 按 [a] 新建第一个任务")) - // 补齐剩余行 - for len(lines) < maxH { - lines = append(lines, "") - } - return strings.Join(lines, "\n") + // ── 预计算每行数据(供 StyleFunc 闭包引用)── + type taskRowData struct { + name string + mode string + isTurbo bool + proto string + lastRun string + isRunning bool + rate string + ttft string + tps string } - start := s.Offset - end := minInt(len(s.Tasks), start+s.Visible) - for i := start; i < end; i++ { - t := s.Tasks[i] - - isSel := i == s.Selected + sel := s.Selected + rowData := make([]taskRowData, len(s.Tasks)) + for i, t := range s.Tasks { rs := s.ActiveRuns[t.ID] _, hasActiveRun := s.ActiveRuns[t.ID] - // ── 指示符 ── - prefix := tableCol(2, selectionMarker(isSel)) - - // ── 模式(选中行禁用嵌套样式,避免重置整行背景)── modeText := "标准" - var modeCol string + isTurbo := false if t.Input.Turbo { modeText = "Turbo" - modeCol = tableCol(modeW, styleWhenNotSelected(isSel, lipgloss.NewStyle().Foreground(colorGold).Bold(true), modeText)) - } else { - modeCol = tableCol(modeW, styleWhenNotSelected(isSel, lipgloss.NewStyle().Foreground(colorPurple), modeText)) + isTurbo = true } - // ── 协议 ── - proto := tableCol(protoW, shortProtocol(t.Input.NormalizedProtocol())) - - // ── 任务名称 ── - nameCol := tableCol(nameW, t.Name) - - // ── 上次运行时间 ── + isRunning := hasActiveRun || (t.LatestRun != nil && t.LatestRun.Status == string(server.RunStatusRunning)) lastRunText := "─" - if hasActiveRun || (t.LatestRun != nil && t.LatestRun.Status == string(server.RunStatusRunning)) { + if isRunning { lastRunText = "运行中" } else if t.LatestRun != nil && !t.LatestRun.FinishedAt.IsZero() { lastRunText = fmtRelativeTime(t.LatestRun.FinishedAt) } - lastRunStyle := st.Muted - if hasActiveRun || (t.LatestRun != nil && t.LatestRun.Status == string(server.RunStatusRunning)) { - lastRunStyle = st.Ok + + rateText := "─" + if hasActiveRun && rs != nil && rs.TotalReqs > 0 { + rateText = fmt.Sprintf("%.1f%%", rs.SuccessRate) + } else if !hasActiveRun && t.LatestRun != nil { + rateText = fmt.Sprintf("%.1f%%", t.LatestRun.SuccessRate) } - lastRunCol := tableCol(lastRunW, styleWhenNotSelected(isSel, lastRunStyle, lastRunText)) - // ── TTFT ── ttftText := "─" if hasActiveRun && rs != nil && rs.AvgTTFT > 0 { ttftText = fmtDuration(rs.AvgTTFT) } else if !hasActiveRun && t.LatestRun != nil { ttftText = fmtDuration(t.LatestRun.AvgTTFT) } - ttftCol := tableCol(ttftW, styleWhenNotSelected(isSel, st.Value, ttftText)) - // ── 成功率 ── - rateText := "─" - if hasActiveRun && rs != nil && rs.TotalReqs > 0 { - rateText = fmt.Sprintf("%.1f%%", rs.SuccessRate) - } else if !hasActiveRun && t.LatestRun != nil { - rateText = fmt.Sprintf("%.1f%%", t.LatestRun.SuccessRate) - } - rateCol := tableCol(rateW, styleWhenNotSelected(isSel, st.Value, rateText)) - - // ── TPS ── tpsText := "─" if hasActiveRun && rs != nil && rs.AvgTPS > 0 { tpsText = fmt.Sprintf("%.1f", rs.AvgTPS) @@ -307,24 +253,103 @@ func buildTaskListContent(s *TaskListState, st Styles, width, maxH int) string { tpsText = fmt.Sprintf("%.1f", t.LatestRun.AvgTPS) } } - tpsCol := styleWhenNotSelected(isSel, st.Value, tpsText) - - // ── 单行:名称 | 模式 | 协议 | 上次运行 | 成功率 | TTFT | TPS ── - lines = append(lines, renderTableRow(st, width, isSel, lipgloss.JoinHorizontal(lipgloss.Top, - prefix, nameCol, modeCol, proto, lastRunCol, rateCol, ttftCol, tpsCol))) - // ── 分隔线 ── - if i < end-1 && len(lines) < maxH-1 { - lines = append(lines, dividerLine(st, width)) + rowData[i] = taskRowData{ + name: t.Name, + mode: modeText, + isTurbo: isTurbo, + proto: shortProtocol(t.Input.NormalizedProtocol()), + lastRun: lastRunText, + isRunning: isRunning, + rate: rateText, + ttft: ttftText, + tps: tpsText, } } - // 补齐剩余行 - for len(lines) < maxH { - lines = append(lines, "") + // ── 构建 lipgloss/table ── + // colWidths: 0 = 弹性列(占用剩余宽度),>0 = 固定总宽(包括两端各 1 字符 padding) + colWidths := []int{0, 8, 22, 12, 8, 10, 10} // 任务名称=flex, 模式, 协议, 上次运行, 成功率, TTFT, TPS + t := lgtable.New(). + Headers("任务名称", "模式", "协议", "上次运行", "成功率", "TTFT", "TPS"). + Width(width). + Height(maxH). + YOffset(s.Offset). + BorderTop(false).BorderBottom(false). + BorderLeft(false).BorderRight(false). + BorderHeader(true).BorderColumn(true).BorderRow(true). + BorderStyle(lipgloss.NewStyle().Foreground(colorDivider)). + StyleFunc(func(row, col int) lipgloss.Style { + aw := func(s lipgloss.Style) lipgloss.Style { + if col < len(colWidths) && colWidths[col] > 0 { + return s.Width(colWidths[col]).Padding(0, 1) + } + return s.Padding(0, 1) + } + if row == lgtable.HeaderRow { + return aw(st.TableHead) + } + if row < 0 || row >= len(rowData) { + return aw(st.TableRow) + } + r := rowData[row] + if row == sel { + return aw(st.TableRowSel) + } + switch col { + case 1: // mode + if r.isTurbo { + return aw(lipgloss.NewStyle().Foreground(colorGold).Bold(true)) + } + return aw(lipgloss.NewStyle().Foreground(colorPurple)) + case 3: // lastRun + if r.isRunning { + return aw(st.Ok) + } + return aw(st.Muted) + case 4, 5, 6: // rate, ttft, tps + return aw(st.Value) + default: + return aw(st.TableRow) + } + }) + + for _, r := range rowData { + t.Row(r.name, r.mode, r.proto, r.lastRun, r.rate, r.ttft, r.tps) } - return strings.Join(lines, "\n") + tableStr := t.String() + s.Visible = t.VisibleRows() + if s.Visible < 1 { + s.Visible = 1 + } + s.Offset = ensureVisibleOffset(s.Selected, len(s.Tasks), s.Offset, s.Visible) + + // 空任务状态:在表头下方显示提示 + if len(s.Tasks) == 0 { + tableLines := strings.Split(tableStr, "\n") + for len(tableLines) < maxH-1 { + tableLines = append(tableLines, "") + } + tableLines = append(tableLines, " "+st.Muted.Render("暂无任务 按 [a] 新建第一个任务")) + if len(tableLines) > maxH { + tableLines = tableLines[:maxH] + } + for len(tableLines) < maxH { + tableLines = append(tableLines, "") + } + return strings.Join(tableLines, "\n") + } + + // 补齐至 maxH + tableLines := strings.Split(tableStr, "\n") + for len(tableLines) < maxH { + tableLines = append(tableLines, "") + } + if len(tableLines) > maxH { + tableLines = tableLines[:maxH] + } + return strings.Join(tableLines, "\n") } // buildTaskListConfirmContent 渲染删除确认对话框内容。 diff --git a/internal/tui/pages/turbodash.go b/internal/tui/pages/turbodash.go index 5bbcb4f..2751352 100644 --- a/internal/tui/pages/turbodash.go +++ b/internal/tui/pages/turbodash.go @@ -6,6 +6,7 @@ import ( tea "github.com/charmbracelet/bubbletea" "charm.land/lipgloss/v2" + lgtable "charm.land/lipgloss/v2/table" "github.com/yinxulai/ait/internal/server" "github.com/yinxulai/ait/internal/types" ) @@ -313,53 +314,36 @@ func buildTurboProgressLine(rs *server.RunState, st Styles, width int) string { // buildTurboRequestList 构建 Turbo 模式请求列表区域。 func buildTurboRequestList(d *TurboDashState, rs *server.RunState, st Styles, width, maxH int) string { - lines := panelTitleLines(st, "请求列表", width, true) + titleLines := panelTitleLines(st, "请求列表", width, true) if rs == nil || len(rs.Requests) == 0 { msg := "等待请求..." if rs != nil && rs.Status != server.RunStatusRunning { msg = "无请求详情数据" } - lines = append(lines, " "+st.Muted.Render(msg)) - return finishPanelLines(lines, maxH) + titleLines = append(titleLines, " "+st.Muted.Render(msg)) + return finishPanelLines(titleLines, maxH) } - const ( - markW = 2 - idW = 6 - statW = 5 - levelW = 6 - timeW = 10 - ttftW = 10 - cacheW = 8 - ptokW = 9 // 提示 tok - ctokW = 9 // 完成 tok - ) - hdr := lipgloss.JoinHorizontal(lipgloss.Top, - tableCol(markW, ""), - tableCol(idW, "#"), - tableCol(statW, "状态"), - tableCol(levelW, "级别"), - tableCol(timeW, "总耗时"), - tableCol(ttftW, "TTFT"), - tableCol(cacheW, "Cache"), - tableCol(ptokW, "提示tok"), - tableCol(ctokW, "完成tok"), - "TPS", - ) - lines = append(lines, renderTableHeader(st, width, hdr)) - lines = append(lines, dividerLine(st, width)) - d.ReqVis = listVisibleItems(maxH, 3) - d.AdjustReqOffset(d.ReqVis, len(rs.Requests)) - + // ── 预计算每行数据(按展示顺序,最新在前)── + type reqRow struct { + success bool + errMsg string + id string + status string + level string + total string + ttft string + cache string + ptok string + ctok string + tps string + } reqs := rs.Requests - start := d.ReqOff - end := minInt(len(reqs), start+d.ReqVis) - for pos := start; pos < end; pos++ { + reqRows := make([]reqRow, len(reqs)) + for pos := 0; pos < len(reqs); pos++ { i := requestIndexFromDisplayPos(pos, len(reqs)) r := reqs[i] - isSel := i == d.ReqSel - statusText := "✓" if !r.Success { statusText = "✗" @@ -368,39 +352,82 @@ func buildTurboRequestList(d *TurboDashState, rs *server.RunState, st Styles, wi if !r.Success && r.ErrorMessage != "" { totalText = r.ErrorMessage } - - statusStr := statusText - if r.Success { - statusStr = styleWhenNotSelected(isSel, st.Ok, statusText) - } else { - statusStr = styleWhenNotSelected(isSel, st.ErrStyle, statusText) - } - totalStr := totalText - if !r.Success && r.ErrorMessage != "" { - totalStr = styleWhenNotSelected(isSel, st.ErrStyle, totalText) + reqRows[pos] = reqRow{ + success: r.Success, + errMsg: r.ErrorMessage, + id: fmt.Sprintf("#%d", len(reqs)-pos), + status: statusText, + level: fmt.Sprintf("%d", r.Level), + total: totalText, + ttft: fmtDuration(r.TTFT), + cache: fmt.Sprintf("%dtok", r.CachedTokens), + ptok: fmt.Sprintf("%dtok", r.PromptTokens), + ctok: fmt.Sprintf("%dtok", r.CompletionTokens), + tps: fmt.Sprintf("%.1f/s", r.TPS), } + } - marker := selectionMarker(isSel) - rowContent := lipgloss.JoinHorizontal(lipgloss.Top, - tableCol(markW, marker), - tableCol(idW, fmt.Sprintf("#%d", len(reqs)-pos)), - tableCol(statW, statusStr), - tableCol(levelW, fmt.Sprintf("%d", r.Level)), - tableCol(timeW, totalStr), - tableCol(ttftW, fmtDuration(r.TTFT)), - tableCol(cacheW, fmt.Sprintf("%.0f%%", r.CacheHitRate*100)), - tableCol(ptokW, fmt.Sprintf("%dtok", r.PromptTokens)), - tableCol(ctokW, fmt.Sprintf("%dtok", r.CompletionTokens)), - fmt.Sprintf("%.1f/s", r.TPS), - ) - - rendered := renderTableRow(st, width, isSel, rowContent) - lines = append(lines, rendered) - - if pos < end-1 && len(lines) < maxH-1 { - lines = append(lines, dividerLine(st, width)) - } + selDisplayPos := requestDisplayPos(d.ReqSel, len(reqs)) + + // colWidths: 0 = 弹性列(占用剩余宽度),>0 = 固定总宽 + colWidths := []int{6, 5, 6, 0, 8, 7, 10, 10, 10} // #, 状态, 级别, 总耗时=flex, TTFT, Cache, 提示tok, 完成tok, TPS + tableH := maxH - len(titleLines) + tbl := lgtable.New(). + Headers("#", "状态", "级别", "总耗时", "TTFT", "Cache", "提示tok", "完成tok", "TPS"). + Width(width). + Height(tableH). + YOffset(d.ReqOff). + BorderTop(false).BorderBottom(false). + BorderLeft(false).BorderRight(false). + BorderHeader(true).BorderColumn(true).BorderRow(true). + BorderStyle(lipgloss.NewStyle().Foreground(colorDivider)). + StyleFunc(func(row, col int) lipgloss.Style { + aw := func(s lipgloss.Style) lipgloss.Style { + if col < len(colWidths) && colWidths[col] > 0 { + return s.Width(colWidths[col]).Padding(0, 1) + } + return s.Padding(0, 1) + } + if row == lgtable.HeaderRow { + return aw(st.TableHead) + } + if row < 0 || row >= len(reqRows) { + return aw(st.TableRow) + } + r := reqRows[row] + if row == selDisplayPos { + return aw(st.TableRowSel) + } + switch col { + case 1: // status + if r.success { + return aw(st.Ok) + } + return aw(st.ErrStyle) + case 3: // total + if !r.success && r.errMsg != "" { + return aw(st.ErrStyle) + } + return aw(st.Value) + case 4, 5, 6, 7, 8: // ttft, cache, ptok, ctok, tps + return aw(st.Value) + default: + return aw(st.TableRow) + } + }) + + for _, r := range reqRows { + tbl.Row(r.id, r.status, r.level, r.total, r.ttft, r.cache, r.ptok, r.ctok, r.tps) } - return finishPanelLines(lines, maxH) + tableStr := tbl.String() + d.ReqVis = tbl.VisibleRows() + if d.ReqVis < 1 { + d.ReqVis = 1 + } + d.AdjustReqOffset(d.ReqVis, len(reqs)) + + tableLines := strings.Split(tableStr, "\n") + result := append(titleLines, tableLines...) + return finishPanelLines(result, maxH) } From 6ad91f5da881d8eb34a3ea23e181c5421741a310 Mon Sep 17 00:00:00 2001 From: Alain Date: Sat, 23 May 2026 21:11:21 +0800 Subject: [PATCH 39/52] =?UTF-8?q?feat:=20=E8=B0=83=E6=95=B4=E4=BB=AA?= =?UTF-8?q?=E8=A1=A8=E6=9D=BF=E5=92=8C=20TurboDash=20=E5=88=97=E5=AE=BD?= =?UTF-8?q?=EF=BC=8C=E6=9B=B4=E6=96=B0=E6=98=BE=E7=A4=BA=E5=86=85=E5=AE=B9?= =?UTF-8?q?=E4=BB=A5=E6=94=AF=E6=8C=81=E8=BE=93=E5=85=A5=E5=92=8C=E8=BE=93?= =?UTF-8?q?=E5=87=BA=E4=BF=A1=E6=81=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/tui/pages/dashboard.go | 7 +++---- internal/tui/pages/turbodash.go | 6 +++--- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/internal/tui/pages/dashboard.go b/internal/tui/pages/dashboard.go index c16d2b7..6f8eadb 100644 --- a/internal/tui/pages/dashboard.go +++ b/internal/tui/pages/dashboard.go @@ -209,7 +209,7 @@ func RenderDashboard(d *DashboardState, taskName string, st Styles, width, heigh bodyPanel := frame.InnerPanel() // ── 计算高度 ── - splitOuterH := 9 // 双栏面板外部总高度(含面板边框) + splitOuterH := 7 // 双栏面板外部总高度(含面板边框) progressOuterH := 3 // 进度条面板外部高度(1内容+2边框) reqOuterH := RemainingStackOuterHeight(frame.InnerHeight, splitOuterH, progressOuterH) reqListH := PanelContentHeight(reqOuterH) @@ -263,7 +263,6 @@ func buildDashMetricsPanel(rs *server.RunState, st Styles, maxH, width int) stri st.MetricVal.Render(fmtDuration(rs.AvgTTFT)))) lines = append(lines, " "+labelValue(st, "缓存命中", st.MetricVal.Render(fmt.Sprintf("%.1f%%", rs.CacheHitRate*100)))) - lines = append(lines, " "+st.Muted.Render(fmt.Sprintf(" 成功: %d 失败: %d", rs.SuccessReqs, rs.FailedReqs))) } return finishPanelLines(lines, maxH) @@ -363,10 +362,10 @@ func buildRequestList(d *DashboardState, rs *server.RunState, st Styles, width, selDisplayPos := requestDisplayPos(d.ReqSel, len(reqs)) // colWidths: 0 = 弹性列(占用剩余宽度),>0 = 固定总宽 - colWidths := []int{6, 8, 0, 8, 7, 10, 10, 10} // #, 状态, 总耗时=flex, TTFT, Cache, 提示tok, 完成tok, TPS + colWidths := []int{6, 8, 0, 8, 10, 12, 12, 10} // #, 状态, 总耗时=flex, TTFT, Cache, 输入, 输出, TPS tableH := maxH - len(titleLines) tbl := lgtable.New(). - Headers("#", "状态", "总耗时", "TTFT", "Cache", "提示tok", "完成tok", "TPS"). + Headers("#", "状态", "总耗时", "TTFT", "Cache", "输入", "输出", "TPS"). Width(width). Height(tableH). YOffset(d.ReqOff). diff --git a/internal/tui/pages/turbodash.go b/internal/tui/pages/turbodash.go index 2751352..a19065d 100644 --- a/internal/tui/pages/turbodash.go +++ b/internal/tui/pages/turbodash.go @@ -213,7 +213,7 @@ func RenderTurboDash(d *TurboDashState, taskName string, st Styles, width, heigh bodyPanel := frame.InnerPanel() // ── 计算高度 ── - splitOuterH := 9 + splitOuterH := 7 progressOuterH := 3 levelOuterH := RemainingStackOuterHeight(frame.InnerHeight, splitOuterH, progressOuterH) levelListH := PanelContentHeight(levelOuterH) @@ -370,10 +370,10 @@ func buildTurboRequestList(d *TurboDashState, rs *server.RunState, st Styles, wi selDisplayPos := requestDisplayPos(d.ReqSel, len(reqs)) // colWidths: 0 = 弹性列(占用剩余宽度),>0 = 固定总宽 - colWidths := []int{6, 5, 6, 0, 8, 7, 10, 10, 10} // #, 状态, 级别, 总耗时=flex, TTFT, Cache, 提示tok, 完成tok, TPS + colWidths := []int{6, 5, 6, 0, 8, 10, 12, 12, 10} // #, 状态, 级别, 总耗时=flex, TTFT, Cache, 输入, 输出, TPS tableH := maxH - len(titleLines) tbl := lgtable.New(). - Headers("#", "状态", "级别", "总耗时", "TTFT", "Cache", "提示tok", "完成tok", "TPS"). + Headers("#", "状态", "级别", "总耗时", "TTFT", "Cache", "输入", "输出", "TPS"). Width(width). Height(tableH). YOffset(d.ReqOff). From d99f53d7c107eb46fe701239bcf4ab8ea3bf722c Mon Sep 17 00:00:00 2001 From: Alain Date: Sat, 23 May 2026 21:12:14 +0800 Subject: [PATCH 40/52] =?UTF-8?q?fix:=20=E4=BF=AE=E6=AD=A3=E6=80=A7?= =?UTF-8?q?=E8=83=BD=E9=9D=A2=E6=9D=BF=E4=B8=AD=E2=80=9C=E4=BB=A4=E7=89=8C?= =?UTF-8?q?=E2=80=9D=E6=A0=87=E7=AD=BE=E4=B8=BA=E2=80=9CToken=E2=80=9D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/tui/pages/reqdetail.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/tui/pages/reqdetail.go b/internal/tui/pages/reqdetail.go index bfabbbc..626ae44 100644 --- a/internal/tui/pages/reqdetail.go +++ b/internal/tui/pages/reqdetail.go @@ -196,7 +196,7 @@ func buildReqPerfPanel(r *types.RequestMetrics, st Styles, maxH, width int) stri lines = append(lines, " "+labelValue(st, "总耗时 ", st.MetricVal.Render(totalTime))) lines = append(lines, " "+labelValue(st, "TTFT ", st.MetricVal.Render(ttft))) lines = append(lines, " "+labelValue(st, "输出TPS ", st.MetricVal.Render(tps))) - lines = append(lines, " "+labelValue(st, "令牌 ", tokenSummary)) + lines = append(lines, " "+labelValue(st, "Token ", tokenSummary)) if r.Success { lines = append(lines, " "+labelValue(st, "缓存 ", cacheSummary)) } else { From 4d95e5785e586b1774037535f973d375142bee2b Mon Sep 17 00:00:00 2001 From: Alain Date: Sat, 23 May 2026 22:30:08 +0800 Subject: [PATCH 41/52] =?UTF-8?q?feat:=20=E9=87=8D=E6=9E=84=E8=BF=9B?= =?UTF-8?q?=E5=BA=A6=E6=9D=A1=E6=B8=B2=E6=9F=93=E9=80=BB=E8=BE=91=EF=BC=8C?= =?UTF-8?q?=E7=AE=80=E5=8C=96=E4=BB=A3=E7=A0=81=E5=B9=B6=E6=8F=90=E5=8D=87?= =?UTF-8?q?=E5=8F=AF=E8=AF=BB=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/tui/pages/dashboard.go | 16 +------------- internal/tui/pages/helpers.go | 38 ++++++++++++++++++--------------- internal/tui/pages/proxy.go | 16 ++++++-------- internal/tui/pages/turbodash.go | 16 +------------- internal/tui/pages/wizard.go | 9 +++----- 5 files changed, 32 insertions(+), 63 deletions(-) diff --git a/internal/tui/pages/dashboard.go b/internal/tui/pages/dashboard.go index 6f8eadb..e98a28b 100644 --- a/internal/tui/pages/dashboard.go +++ b/internal/tui/pages/dashboard.go @@ -279,7 +279,6 @@ func buildProgressLine(rs *server.RunState, st Styles, width int) string { if total > 0 { ratio = float64(done) / float64(total) } - prefix := " 进度 " elapsed := "─" if !rs.StartedAt.IsZero() { if rs.FinishedAt != nil { @@ -289,20 +288,7 @@ func buildProgressLine(rs *server.RunState, st Styles, width int) string { } } suffix := fmt.Sprintf(" %d / %d %s", done, total, elapsed) - - barW := width - lipgloss.Width(prefix) - lipgloss.Width(suffix) - if barW < 5 { - barW = 5 - // 压缩 suffix 确保进度行总宽度不超过 width,防止 lipgloss 折行 - maxSuffixW := maxInt(0, width-lipgloss.Width(prefix)-barW) - suffix = truncate(suffix, maxSuffixW) - } - - filled := int(ratio * float64(barW)) - barRendered := st.Ok.Render(strings.Repeat("█", filled)) + - st.Muted.Render(strings.Repeat("░", barW-filled)) - - return lipgloss.JoinHorizontal(lipgloss.Top, prefix, barRendered, suffix) + return renderProgressBar(st, " 进度 ", suffix, ratio, width) } // buildRequestList 构建请求列表区域。 diff --git a/internal/tui/pages/helpers.go b/internal/tui/pages/helpers.go index e7d0868..ead1262 100644 --- a/internal/tui/pages/helpers.go +++ b/internal/tui/pages/helpers.go @@ -34,11 +34,7 @@ func truncate(s string, maxW int) string { // padRight 右侧补空格至 width(按可见列宽)。 func padRight(s string, width int) string { - w := lipgloss.Width(s) - if w >= width { - return s - } - return s + strings.Repeat(" ", width-w) + return lipgloss.NewStyle().Width(width).Render(s) } // wrapText 将文本按 maxW 列宽折行,返回行切片(CJK 字符按 2 列宽计算)。 @@ -75,6 +71,20 @@ func wrapText(s string, maxW int) []string { return result } +// renderProgressBar 渲染进度条行:prefix 固定在左,suffix 固定在右,中间弹性进度条。 +func renderProgressBar(st Styles, prefix, suffix string, ratio float64, totalW int) string { + barW := totalW - lipgloss.Width(prefix) - lipgloss.Width(suffix) + if barW < 5 { + barW = 5 + maxSuffixW := maxInt(0, totalW-lipgloss.Width(prefix)-barW) + suffix = truncate(suffix, maxSuffixW) + } + filled := int(ratio * float64(barW)) + barRendered := st.Ok.Render(strings.Repeat("█", filled)) + + st.Muted.Render(strings.Repeat("░", barW-filled)) + return lipgloss.JoinHorizontal(lipgloss.Top, prefix, barRendered, suffix) +} + // dividerLine 生成全宽水平分隔线。 func dividerLine(st Styles, width int) string { if width <= 0 { @@ -172,10 +182,6 @@ func renderHeader(st Styles, width int, title, subtitle, meta string, infoLeft, styleT := lipgloss.NewStyle().Foreground(colorCyan).Bold(true) styleSep := lipgloss.NewStyle().Foreground(colorPink) - // artVisW = 10+2+5+2+10 = 29; artSepW = " "(1) + art(29) + " "(1) + "┃"(1) + " "(1) = 33 - artVisW := 10 + 2 + 5 + 2 + 10 - artSepW := artVisW + 4 - artRow := func(i int) string { return styleA.Render(artA[i]) + " " + styleI.Render(artI[i]) + " " + styleT.Render(artT[i]) } @@ -227,7 +233,8 @@ func renderHeader(st Styles, width int, title, subtitle, meta string, infoLeft, var left3 string if wideEnough { artPart := " " + artRow(2) + " " + vsep - availW := maxInt(8, w-artSepW-2-maxInt(10, w/3)) + artPartW := lipgloss.Width(artPart) + 1 // +1 为 art 与 pills 之间的分隔空格 + availW := maxInt(8, w-artPartW-2-maxInt(10, w/3)) if leftPills := renderInfoPills(infoLeft, availW); leftPills != "" { left3 = artPart + " " + leftPills } else { @@ -267,13 +274,10 @@ func renderHotkeys(st Styles, width int, hk PageHotkeys) string { } func renderChromeLine(base lipgloss.Style, width int, left, right string) string { - leftW := lipgloss.Width(left) rightW := lipgloss.Width(right) - pad := width - leftW - rightW - if pad < 0 { - pad = 0 - } - return base.Width(width).Render(left + strings.Repeat(" ", pad) + right) + spacerW := maxInt(0, width-lipgloss.Width(left)-rightW) + spacer := lipgloss.NewStyle().Width(spacerW).Render("") + return base.Width(width).Render(left + spacer + right) } func renderInfoPills(parts []string, maxW int) string { @@ -445,7 +449,7 @@ func panelTitleLines(st Styles, title string, width int, compact bool) []string var rendered string if width > 0 { // 截断标题防止超宽后被 lipgloss 折行 - rendered = st.PanelHead.Width(width).Render(" " + truncate(title, maxInt(1, width-1))) + rendered = st.PanelHead.Width(width).Padding(0, 0, 0, 1).Render(truncate(title, maxInt(1, width-1))) } else { rendered = st.PanelHead.Render(" " + title) } diff --git a/internal/tui/pages/proxy.go b/internal/tui/pages/proxy.go index ba2e4e4..47b8a20 100644 --- a/internal/tui/pages/proxy.go +++ b/internal/tui/pages/proxy.go @@ -217,11 +217,9 @@ func buildProxyConfigContent(s *ProxyConfigState, st Styles, contentW, maxH int) if s.FieldIndex == 0 { typeFieldStyle = st.FieldActive } - typeLabelBlock := strings.Join([]string{ - strings.Repeat(" ", 15), - lipgloss.NewStyle().Width(15).Render(st.Label.Render("代理类型")), - strings.Repeat(" ", 15), - }, "\n") + typeLabelBlock := lipgloss.NewStyle().Width(15).Height(3). + AlignVertical(lipgloss.Center). + Render(st.Label.Render("代理类型")) typeRendered := typeFieldStyle.Width(fieldW + 4).Render(st.Value.Render(typeLabel)) appendBlock(lipgloss.JoinHorizontal(lipgloss.Top, typeLabelBlock, typeRendered)) @@ -242,11 +240,9 @@ func buildProxyConfigContent(s *ProxyConfigState, st Styles, contentW, maxH int) urlRendered = urlFieldStyle.Width(fieldW + 4).Render(st.Value.Render(fitTail(v, fieldW))) } } - urlLabelBlock := strings.Join([]string{ - strings.Repeat(" ", 15), - lipgloss.NewStyle().Width(15).Render(st.Label.Render("代理地址")), - strings.Repeat(" ", 15), - }, "\n") + urlLabelBlock := lipgloss.NewStyle().Width(15).Height(3). + AlignVertical(lipgloss.Center). + Render(st.Label.Render("代理地址")) appendBlock(lipgloss.JoinHorizontal(lipgloss.Top, urlLabelBlock, urlRendered)) lines = append(lines, "") diff --git a/internal/tui/pages/turbodash.go b/internal/tui/pages/turbodash.go index a19065d..265edd0 100644 --- a/internal/tui/pages/turbodash.go +++ b/internal/tui/pages/turbodash.go @@ -294,22 +294,8 @@ func buildTurboProgressLine(rs *server.RunState, st Styles, width int) string { } else { levelTotalStr = fmt.Sprintf("%d", levelDone) } - prefix := " 进度 " suffix := fmt.Sprintf(" %d/%d 当前并发 %d 总进度: %s 级", done, total, rs.CurrentLevel, levelTotalStr) - - barW := width - lipgloss.Width(prefix) - lipgloss.Width(suffix) - if barW < 5 { - barW = 5 - // 压缩 suffix 确保进度行总宽度不超过 width,防止 lipgloss 折行 - maxSuffixW := maxInt(0, width-lipgloss.Width(prefix)-barW) - suffix = truncate(suffix, maxSuffixW) - } - - filled := int(ratio * float64(barW)) - barRendered := st.Ok.Render(strings.Repeat("█", filled)) + - st.Muted.Render(strings.Repeat("░", barW-filled)) - - return lipgloss.JoinHorizontal(lipgloss.Top, prefix, barRendered, suffix) + return renderProgressBar(st, " 进度 ", suffix, ratio, width) } // buildTurboRequestList 构建 Turbo 模式请求列表区域。 diff --git a/internal/tui/pages/wizard.go b/internal/tui/pages/wizard.go index 646d9f4..5271dd5 100644 --- a/internal/tui/pages/wizard.go +++ b/internal/tui/pages/wizard.go @@ -833,12 +833,9 @@ func renderWizardField(st Styles, f fieldDef, wz *WizardState, active bool, maxW } else { renderedValue = fieldStyle.Width(fieldW + 4).Render(valueStyle.Render(valueStr)) } - labelLines := []string{ - strings.Repeat(" ", 15), - lipgloss.NewStyle().Width(15).Render(st.Label.Render(wizardFieldLabel(f, wz))), - strings.Repeat(" ", 15), - } - labelBlock := strings.Join(labelLines, "\n") + labelBlock := lipgloss.NewStyle().Width(15).Height(3). + AlignVertical(lipgloss.Center). + Render(st.Label.Render(wizardFieldLabel(f, wz))) return lipgloss.JoinHorizontal(lipgloss.Top, labelBlock, renderedValue) } From 3dc70de3ff9cbf7ec251da2bd7b6e98d201c6057 Mon Sep 17 00:00:00 2001 From: Alain Date: Sat, 23 May 2026 22:48:03 +0800 Subject: [PATCH 42/52] =?UTF-8?q?feat:=20=E9=87=8D=E6=9E=84=E4=BB=AA?= =?UTF-8?q?=E8=A1=A8=E7=9B=98=E5=92=8C=20TurboDash=20=E8=BF=90=E8=A1=8C?= =?UTF-8?q?=E7=8A=B6=E6=80=81=E9=80=BB=E8=BE=91=EF=BC=8C=E7=AE=80=E5=8C=96?= =?UTF-8?q?=E4=BB=A3=E7=A0=81=E5=B9=B6=E5=A2=9E=E5=BC=BA=E5=8F=AF=E8=AF=BB?= =?UTF-8?q?=E6=80=A7=EF=BC=9B=E6=B7=BB=E5=8A=A0=E8=BF=90=E8=A1=8C=E6=8C=87?= =?UTF-8?q?=E6=A0=87=E6=98=BE=E7=A4=BA=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/tui/pages/dashboard.go | 20 ++++---------------- internal/tui/pages/helpers.go | 32 ++++++++++++++++++++++++++++++++ internal/tui/pages/taskdetail.go | 30 ++++++------------------------ internal/tui/pages/tasklist.go | 31 ++++++++++++++----------------- internal/tui/pages/turbodash.go | 16 ++++------------ 5 files changed, 60 insertions(+), 69 deletions(-) diff --git a/internal/tui/pages/dashboard.go b/internal/tui/pages/dashboard.go index e98a28b..1b3e6d3 100644 --- a/internal/tui/pages/dashboard.go +++ b/internal/tui/pages/dashboard.go @@ -36,10 +36,10 @@ func NewDashboardState(runID server.RunID, taskID string) *DashboardState { // IsRunning 判断运行是否仍在进行。 func (d *DashboardState) IsRunning() bool { - if d == nil || d.RunState == nil { + if d == nil { return false } - return d.RunState.Status == server.RunStatusRunning + return isRunStateRunning(d.RunState) } // AdjustReqOffset 根据屏幕显示顺序调整列表可见窗口。 @@ -255,14 +255,7 @@ func buildDashMetricsPanel(rs *server.RunState, st Styles, maxH, width int) stri if rs == nil { lines = append(lines, " "+st.Muted.Render("等待数据...")) } else { - lines = append(lines, " "+labelValue(st, "成功率 ", - st.MetricVal.Render(fmt.Sprintf("%.1f%%", rs.SuccessRate)))) - lines = append(lines, " "+labelValue(st, "avg TPS ", - st.MetricVal.Render(fmt.Sprintf("%.1f tok/s", rs.AvgTPS)))) - lines = append(lines, " "+labelValue(st, "avg TTFT", - st.MetricVal.Render(fmtDuration(rs.AvgTTFT)))) - lines = append(lines, " "+labelValue(st, "缓存命中", - st.MetricVal.Render(fmt.Sprintf("%.1f%%", rs.CacheHitRate*100)))) + lines = appendRunMetricLines(lines, st, rs) } return finishPanelLines(lines, maxH) @@ -360,12 +353,7 @@ func buildRequestList(d *DashboardState, rs *server.RunState, st Styles, width, BorderHeader(true).BorderColumn(true).BorderRow(true). BorderStyle(lipgloss.NewStyle().Foreground(colorDivider)). StyleFunc(func(row, col int) lipgloss.Style { - aw := func(s lipgloss.Style) lipgloss.Style { - if col < len(colWidths) && colWidths[col] > 0 { - return s.Width(colWidths[col]).Padding(0, 1) - } - return s.Padding(0, 1) - } + aw := func(s lipgloss.Style) lipgloss.Style { return applyColWidth(s, col, colWidths) } if row == lgtable.HeaderRow { return aw(st.TableHead) } diff --git a/internal/tui/pages/helpers.go b/internal/tui/pages/helpers.go index ead1262..568814b 100644 --- a/internal/tui/pages/helpers.go +++ b/internal/tui/pages/helpers.go @@ -6,6 +6,7 @@ import ( "time" "charm.land/lipgloss/v2" + "github.com/yinxulai/ait/internal/server" ) // ─── 文本工具 ───────────────────────────────────────────────────────────────── @@ -445,6 +446,37 @@ func runStatusText(status string) string { } } +// modeShortLabel 将运行模式字符串转换为短标签。 +func modeShortLabel(mode string) string { + if mode == "turbo" { + return "Turbo" + } + return "标准" +} + +// isRunStateRunning 判断 RunState 是否处于运行状态。 +func isRunStateRunning(rs *server.RunState) bool { + return rs != nil && rs.Status == server.RunStatusRunning +} + +// applyColWidth 按列宽定义应用固定宽度或仅 padding,用于 lgtable StyleFunc 中的 aw 闭包。 +// colWidths[col] > 0 时设固定总宽(含 padding),否则仅添加 padding。 +func applyColWidth(s lipgloss.Style, col int, colWidths []int) lipgloss.Style { + if col < len(colWidths) && colWidths[col] > 0 { + return s.Width(colWidths[col]).Padding(0, 1) + } + return s.Padding(0, 1) +} + +// appendRunMetricLines 向 lines 追加 4 行运行指标(成功率/TPS/TTFT/缓存命中)。 +func appendRunMetricLines(lines []string, st Styles, rs *server.RunState) []string { + lines = append(lines, " "+labelValue(st, "成功率 ", st.MetricVal.Render(fmt.Sprintf("%.1f%%", rs.SuccessRate)))) + lines = append(lines, " "+labelValue(st, "TPS ", st.MetricVal.Render(fmt.Sprintf("%.1f tok/s", rs.AvgTPS)))) + lines = append(lines, " "+labelValue(st, "TTFT ", st.MetricVal.Render(fmtDuration(rs.AvgTTFT)))) + lines = append(lines, " "+labelValue(st, "缓存命中", st.MetricVal.Render(fmt.Sprintf("%.1f%%", rs.CacheHitRate*100)))) + return lines +} + func panelTitleLines(st Styles, title string, width int, compact bool) []string { var rendered string if width > 0 { diff --git a/internal/tui/pages/taskdetail.go b/internal/tui/pages/taskdetail.go index 77d9480..74646c9 100644 --- a/internal/tui/pages/taskdetail.go +++ b/internal/tui/pages/taskdetail.go @@ -310,13 +310,10 @@ func buildTaskDetailContent(s *TaskDetailState, st Styles, t types.TaskDefinitio rowData := make([]histRow, effectiveLen) if hasActive { rs := s.ActiveRun - modeShort := "标准" - if rs.Mode == "turbo" { - modeShort = "Turbo" - } + modeShort := modeShortLabel(rs.Mode) rateStr := "─" if rs.TotalReqs > 0 { - rateStr = fmt.Sprintf("%.0f%%", rs.SuccessRate) + rateStr = fmt.Sprintf("%.1f%%", rs.SuccessRate) } rowData[0] = histRow{ statusText: "●", @@ -349,10 +346,7 @@ func buildTaskDetailContent(s *TaskDetailState, st Styles, t types.TaskDefinitio statusText = "■" statusMut = true } - modeShort := "标准" - if run.Mode == "turbo" { - modeShort = "Turbo" - } + modeShort := modeShortLabel(run.Mode) durText := "─" if !run.FinishedAt.IsZero() { durText = fmtDuration(run.FinishedAt.Sub(run.StartedAt)) @@ -384,12 +378,7 @@ func buildTaskDetailContent(s *TaskDetailState, st Styles, t types.TaskDefinitio BorderHeader(true).BorderColumn(true).BorderRow(true). BorderStyle(lipgloss.NewStyle().Foreground(colorDivider)). StyleFunc(func(row, col int) lipgloss.Style { - aw := func(s lipgloss.Style) lipgloss.Style { - if col < len(colWidths) && colWidths[col] > 0 { - return s.Width(colWidths[col]).Padding(0, 1) - } - return s.Padding(0, 1) - } + aw := func(s lipgloss.Style) lipgloss.Style { return applyColWidth(s, col, colWidths) } if row == lgtable.HeaderRow { return aw(st.TableHead) } @@ -490,27 +479,20 @@ func buildTaskHistoryDetailLines(history []types.TaskRunSummary, histIdx int, st contentW := maxInt(12, width-lipgloss.Width(indent)) useTwoCols := contentW >= 48 - statusText := sel.Status + statusText := runStatusText(sel.Status) statusStyle := st.Value switch sel.Status { case "running": - statusText = "运行中" statusStyle = st.Ok case "completed": - statusText = "完成" statusStyle = st.Ok case "failed": - statusText = "失败" statusStyle = st.ErrStyle case "stopped": - statusText = "已停止" statusStyle = st.Muted } - modeText := "标准" - if sel.Mode == "turbo" { - modeText = "Turbo" - } + modeText := modeShortLabel(sel.Mode) renderCell := func(label, value string, valueStyle lipgloss.Style, cellW int) string { prefix := st.Label.Render(padRight(label, labelW)) diff --git a/internal/tui/pages/tasklist.go b/internal/tui/pages/tasklist.go index 027f5f6..29ca764 100644 --- a/internal/tui/pages/tasklist.go +++ b/internal/tui/pages/tasklist.go @@ -204,6 +204,7 @@ func buildTaskListContent(s *TaskListState, st Styles, width, maxH int) string { lastRun string isRunning bool rate string + cache string ttft string tps string } @@ -254,6 +255,13 @@ func buildTaskListContent(s *TaskListState, st Styles, width, maxH int) string { } } + cacheText := "─" + if hasActiveRun && rs != nil && rs.CacheHitRate > 0 { + cacheText = fmt.Sprintf("%.1f%%", rs.CacheHitRate*100) + } else if !hasActiveRun && t.LatestRun != nil && t.LatestRun.CacheHitRate > 0 { + cacheText = fmt.Sprintf("%.1f%%", t.LatestRun.CacheHitRate*100) + } + rowData[i] = taskRowData{ name: t.Name, mode: modeText, @@ -262,6 +270,7 @@ func buildTaskListContent(s *TaskListState, st Styles, width, maxH int) string { lastRun: lastRunText, isRunning: isRunning, rate: rateText, + cache: cacheText, ttft: ttftText, tps: tpsText, } @@ -269,9 +278,9 @@ func buildTaskListContent(s *TaskListState, st Styles, width, maxH int) string { // ── 构建 lipgloss/table ── // colWidths: 0 = 弹性列(占用剩余宽度),>0 = 固定总宽(包括两端各 1 字符 padding) - colWidths := []int{0, 8, 22, 12, 8, 10, 10} // 任务名称=flex, 模式, 协议, 上次运行, 成功率, TTFT, TPS + colWidths := []int{0, 8, 22, 12, 8, 8, 10, 10} // 任务名称=flex, 模式, 协议, 上次运行, 成功率, 缓存, TTFT, TPS t := lgtable.New(). - Headers("任务名称", "模式", "协议", "上次运行", "成功率", "TTFT", "TPS"). + Headers("任务名称", "模式", "协议", "上次运行", "成功率", "缓存", "TTFT", "TPS"). Width(width). Height(maxH). YOffset(s.Offset). @@ -280,12 +289,7 @@ func buildTaskListContent(s *TaskListState, st Styles, width, maxH int) string { BorderHeader(true).BorderColumn(true).BorderRow(true). BorderStyle(lipgloss.NewStyle().Foreground(colorDivider)). StyleFunc(func(row, col int) lipgloss.Style { - aw := func(s lipgloss.Style) lipgloss.Style { - if col < len(colWidths) && colWidths[col] > 0 { - return s.Width(colWidths[col]).Padding(0, 1) - } - return s.Padding(0, 1) - } + aw := func(s lipgloss.Style) lipgloss.Style { return applyColWidth(s, col, colWidths) } if row == lgtable.HeaderRow { return aw(st.TableHead) } @@ -307,7 +311,7 @@ func buildTaskListContent(s *TaskListState, st Styles, width, maxH int) string { return aw(st.Ok) } return aw(st.Muted) - case 4, 5, 6: // rate, ttft, tps + case 4, 5, 6, 7: // rate, cache, ttft, tps return aw(st.Value) default: return aw(st.TableRow) @@ -315,7 +319,7 @@ func buildTaskListContent(s *TaskListState, st Styles, width, maxH int) string { }) for _, r := range rowData { - t.Row(r.name, r.mode, r.proto, r.lastRun, r.rate, r.ttft, r.tps) + t.Row(r.name, r.mode, r.proto, r.lastRun, r.rate, r.cache, r.ttft, r.tps) } tableStr := t.String() @@ -376,10 +380,3 @@ func buildTaskListConfirmContent(s *TaskListState, st Styles, width, maxH int) s } return strings.Join(lines, "\n") } - -func max(a, b int) int { - if a > b { - return a - } - return b -} diff --git a/internal/tui/pages/turbodash.go b/internal/tui/pages/turbodash.go index 265edd0..448a1b8 100644 --- a/internal/tui/pages/turbodash.go +++ b/internal/tui/pages/turbodash.go @@ -57,10 +57,10 @@ func NewTurboDashState(runID server.RunID, taskID string) *TurboDashState { // IsRunning 判断是否仍在运行。 func (d *TurboDashState) IsRunning() bool { - if d == nil || d.RunState == nil { + if d == nil { return false } - return d.RunState.Status == server.RunStatusRunning + return isRunStateRunning(d.RunState) } // HandleTurboDashKey 处理 Turbo 仪表盘按键。 @@ -265,10 +265,7 @@ func buildTurboDashMetrics(rs *server.RunState, st Styles, maxH, width int) stri if rs == nil { lines = append(lines, " "+st.Muted.Render("等待数据...")) } else { - lines = append(lines, " "+labelValue(st, "成功率 ", st.MetricVal.Render(fmt.Sprintf("%.1f%%", rs.SuccessRate)))) - lines = append(lines, " "+labelValue(st, "TPS ", st.MetricVal.Render(fmt.Sprintf("%.1f tok/s", rs.AvgTPS)))) - lines = append(lines, " "+labelValue(st, "TTFT ", st.MetricVal.Render(fmtDuration(rs.AvgTTFT)))) - lines = append(lines, " "+labelValue(st, "Cache ", st.MetricVal.Render(fmt.Sprintf("%.1f%%", rs.CacheHitRate*100)))) + lines = appendRunMetricLines(lines, st, rs) } return finishPanelLines(lines, maxH) @@ -368,12 +365,7 @@ func buildTurboRequestList(d *TurboDashState, rs *server.RunState, st Styles, wi BorderHeader(true).BorderColumn(true).BorderRow(true). BorderStyle(lipgloss.NewStyle().Foreground(colorDivider)). StyleFunc(func(row, col int) lipgloss.Style { - aw := func(s lipgloss.Style) lipgloss.Style { - if col < len(colWidths) && colWidths[col] > 0 { - return s.Width(colWidths[col]).Padding(0, 1) - } - return s.Padding(0, 1) - } + aw := func(s lipgloss.Style) lipgloss.Style { return applyColWidth(s, col, colWidths) } if row == lgtable.HeaderRow { return aw(st.TableHead) } From 4829c97128fea6e395cbb1aae0c33e611898e9cd Mon Sep 17 00:00:00 2001 From: Alain Date: Sat, 23 May 2026 23:03:58 +0800 Subject: [PATCH 43/52] =?UTF-8?q?feat:=20=E6=9B=B4=E6=96=B0=E8=BF=90?= =?UTF-8?q?=E8=A1=8C=E6=8C=87=E6=A0=87=E6=98=BE=E7=A4=BA=EF=BC=8C=E4=BF=AE?= =?UTF-8?q?=E6=94=B9=E5=88=97=E6=A0=87=E9=A2=98=E4=B8=BA=E2=80=9C=E5=9D=87?= =?UTF-8?q?=E5=80=BCTTFT=E2=80=9D=E5=92=8C=E2=80=9C=E5=9D=87=E5=80=BCTPS?= =?UTF-8?q?=E2=80=9D=EF=BC=8C=E5=A2=9E=E5=BC=BA=E5=8F=AF=E8=AF=BB=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/tui/pages/helpers.go | 4 ++-- internal/tui/pages/tasklist.go | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/internal/tui/pages/helpers.go b/internal/tui/pages/helpers.go index 568814b..d41b466 100644 --- a/internal/tui/pages/helpers.go +++ b/internal/tui/pages/helpers.go @@ -471,8 +471,8 @@ func applyColWidth(s lipgloss.Style, col int, colWidths []int) lipgloss.Style { // appendRunMetricLines 向 lines 追加 4 行运行指标(成功率/TPS/TTFT/缓存命中)。 func appendRunMetricLines(lines []string, st Styles, rs *server.RunState) []string { lines = append(lines, " "+labelValue(st, "成功率 ", st.MetricVal.Render(fmt.Sprintf("%.1f%%", rs.SuccessRate)))) - lines = append(lines, " "+labelValue(st, "TPS ", st.MetricVal.Render(fmt.Sprintf("%.1f tok/s", rs.AvgTPS)))) - lines = append(lines, " "+labelValue(st, "TTFT ", st.MetricVal.Render(fmtDuration(rs.AvgTTFT)))) + lines = append(lines, " "+labelValue(st, "TPS均值 ", st.MetricVal.Render(fmt.Sprintf("%.1f tok/s", rs.AvgTPS)))) + lines = append(lines, " "+labelValue(st, "TTFT均值", st.MetricVal.Render(fmtDuration(rs.AvgTTFT)))) lines = append(lines, " "+labelValue(st, "缓存命中", st.MetricVal.Render(fmt.Sprintf("%.1f%%", rs.CacheHitRate*100)))) return lines } diff --git a/internal/tui/pages/tasklist.go b/internal/tui/pages/tasklist.go index 29ca764..c999501 100644 --- a/internal/tui/pages/tasklist.go +++ b/internal/tui/pages/tasklist.go @@ -278,9 +278,9 @@ func buildTaskListContent(s *TaskListState, st Styles, width, maxH int) string { // ── 构建 lipgloss/table ── // colWidths: 0 = 弹性列(占用剩余宽度),>0 = 固定总宽(包括两端各 1 字符 padding) - colWidths := []int{0, 8, 22, 12, 8, 8, 10, 10} // 任务名称=flex, 模式, 协议, 上次运行, 成功率, 缓存, TTFT, TPS + colWidths := []int{0, 8, 22, 12, 8, 10, 10, 10} // 任务名称=flex, 模式, 协议, 上次运行, 成功率, 缓存命中, TTFT均值, TPS均值 t := lgtable.New(). - Headers("任务名称", "模式", "协议", "上次运行", "成功率", "缓存", "TTFT", "TPS"). + Headers("任务名称", "模式", "协议", "上次运行", "成功率", "缓存命中", "均值TTFT", "均值TPS"). Width(width). Height(maxH). YOffset(s.Offset). From a8a902548c7e015a2c8073e0d0c1b86bbcc50625 Mon Sep 17 00:00:00 2001 From: Alain Date: Sat, 23 May 2026 23:19:19 +0800 Subject: [PATCH 44/52] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E5=B8=AE?= =?UTF-8?q?=E5=8A=A9=E9=A1=B5=E5=8A=9F=E8=83=BD=EF=BC=8C=E6=94=AF=E6=8C=81?= =?UTF-8?q?=E5=B8=AE=E5=8A=A9=E5=86=85=E5=AE=B9=E7=9A=84=E6=B8=B2=E6=9F=93?= =?UTF-8?q?=E4=B8=8E=E5=AF=BC=E8=88=AA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/tui/model.go | 36 +++++ internal/tui/pages/contextbar.go | 9 ++ internal/tui/pages/dashboard.go | 5 +- internal/tui/pages/help.go | 258 +++++++++++++++++++++++++++++++ internal/tui/pages/layout.go | 11 ++ internal/tui/pages/nav.go | 1 + internal/tui/pages/proxy.go | 5 +- internal/tui/pages/reqdetail.go | 5 +- internal/tui/pages/taskdetail.go | 5 +- internal/tui/pages/tasklist.go | 5 +- internal/tui/pages/turbodash.go | 5 +- internal/tui/pages/wizard.go | 5 +- 12 files changed, 343 insertions(+), 7 deletions(-) create mode 100644 internal/tui/pages/help.go diff --git a/internal/tui/model.go b/internal/tui/model.go index 560abfe..68f1d22 100644 --- a/internal/tui/model.go +++ b/internal/tui/model.go @@ -23,6 +23,7 @@ const ( viewTurboDash viewState = "turbo-dash" viewReqDetail viewState = "req-detail" viewProxy viewState = "proxy" + viewHelp viewState = "help" ) // ─── 根 Model ───────────────────────────────────────────────────────────────── @@ -46,6 +47,7 @@ type Model struct { turboDash *pages.TurboDashState reqDetail *pages.ReqDetailState proxyConf *pages.ProxyConfigState + help *pages.HelpState } // NewModel 创建 Model。srv 不能为 nil。 @@ -255,6 +257,8 @@ func (m *Model) View() string { content = pages.RenderReqDetail(m.reqDetail, m.reqDetailTaskName(), m.styles, innerW, innerH) case viewProxy: content = pages.RenderProxyConfig(m.proxyConf, m.styles, innerW, innerH) + case viewHelp: + content = pages.RenderHelp(m.help, m.styles, innerW, innerH) default: content = "未知视图" } @@ -313,6 +317,11 @@ func (m *Model) handleKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) { m.proxyConf = newState navCmd := m.handleNav(nav) return m, tea.Batch(cmd, navCmd) + + case viewHelp: + newState, nav := pages.HandleHelpKey(m.help, msg) + m.help = newState + return m, m.handleNav(nav) } return m, nil @@ -400,12 +409,39 @@ func (m *Model) handleNav(nav pages.NavAction) tea.Cmd { m.view = viewProxy return m.client.LoadProxyConfigCmd() + case pages.NavHelp: + m.help = pages.NewHelpState(pages.NavAction{To: m.currentNavTarget()}) + m.view = viewHelp + return nil + case pages.NavQuit: return tea.Quit } return nil } +// currentNavTarget 返回当前视图对应的 NavTarget,用于帮助页的返回导航。 +func (m *Model) currentNavTarget() pages.NavTarget { + switch m.view { + case viewTaskList: + return pages.NavTaskList + case viewTaskDetail: + return pages.NavTaskDetail + case viewWizard: + return pages.NavWizard + case viewDashboard: + return pages.NavDashboard + case viewTurboDash: + return pages.NavTurboDash + case viewReqDetail: + return pages.NavReqDetail + case viewProxy: + return pages.NavProxy + default: + return pages.NavTaskList + } +} + // ─── Server 事件处理 ────────────────────────────────────────────────────────── func (m *Model) handleServerEvent(msg ServerEventMsg) (tea.Model, tea.Cmd) { diff --git a/internal/tui/pages/contextbar.go b/internal/tui/pages/contextbar.go index 51bd4e2..3c27c07 100644 --- a/internal/tui/pages/contextbar.go +++ b/internal/tui/pages/contextbar.go @@ -199,3 +199,12 @@ func Hotkeys_ProxyConfig() []HotkeyItem { HotkeyAction("Ctrl+U", "清空"), } } + +// Hotkeys_Help 帮助页。 +func Hotkeys_Help() []HotkeyItem { + return []HotkeyItem{ + HotkeyAction("↑↓", "滚动"), + HotkeyAction("PgUp/PgDn", "翻页"), + HotkeyAction("g/G", "顶部/底部"), + } +} diff --git a/internal/tui/pages/dashboard.go b/internal/tui/pages/dashboard.go index 1b3e6d3..27fd6b3 100644 --- a/internal/tui/pages/dashboard.go +++ b/internal/tui/pages/dashboard.go @@ -135,6 +135,9 @@ func HandleDashboardKey(d *DashboardState, msg tea.KeyMsg, client Client) (*Dash return d, client.GenerateReportCmd(d.RunID, server.ReportFormatJSON), nav } + case "?": + nav = NavAction{To: NavHelp} + case "q", "ctrl+c": nav = NavAction{To: NavQuit} } @@ -203,7 +206,7 @@ func RenderDashboard(d *DashboardState, taskName string, st Styles, width, heigh HeaderMeta: "标准模式", HeaderInfoLeft: headerLeft, HeaderInfoRight: headerRight, - Hotkeys: NewPageHotkeys(cbItems, "[b/Esc] 返回上一页", "[q] 退出"), + Hotkeys: NewPageHotkeysWithHelp(cbItems, "[b/Esc] 返回上一页", "[q] 退出"), } frame := l.Frame(width, height) bodyPanel := frame.InnerPanel() diff --git a/internal/tui/pages/help.go b/internal/tui/pages/help.go new file mode 100644 index 0000000..46809a2 --- /dev/null +++ b/internal/tui/pages/help.go @@ -0,0 +1,258 @@ +package pages + +import ( + "strings" + + tea "github.com/charmbracelet/bubbletea" + "charm.land/lipgloss/v2" +) + +// HelpState 帮助页状态。 +type HelpState struct { + ScrollY int + BackNav NavAction // 按 b/Esc 时的返回目标 +} + +// NewHelpState 创建帮助页状态。 +func NewHelpState(backNav NavAction) *HelpState { + return &HelpState{BackNav: backNav} +} + +// HandleHelpKey 处理帮助页按键。 +func HandleHelpKey(s *HelpState, msg tea.KeyMsg) (*HelpState, NavAction) { + nav := NavAction{} + if s == nil { + return s, NavAction{To: NavTaskList} + } + + lines := buildHelpLines(s, 9999, 9999) // 仅用于计算总行数 + totalLines := len(lines) + + switch msg.String() { + case "b", "esc", "q", "?": + if s.BackNav.To != NavNone { + nav = s.BackNav + } else { + nav = NavAction{To: NavTaskList} + } + + case "ctrl+c": + nav = NavAction{To: NavQuit} + + case "up", "k": + if s.ScrollY > 0 { + s.ScrollY-- + } + + case "down", "j": + if s.ScrollY < totalLines-1 { + s.ScrollY++ + } + + case "pgup": + s.ScrollY -= 10 + if s.ScrollY < 0 { + s.ScrollY = 0 + } + + case "pgdown": + s.ScrollY += 10 + if s.ScrollY >= totalLines { + s.ScrollY = maxInt(0, totalLines-1) + } + + case "home", "g": + s.ScrollY = 0 + + case "end", "G": + s.ScrollY = maxInt(0, totalLines-1) + } + + return s, nav +} + +// RenderHelp 渲染帮助页面。 +func RenderHelp(s *HelpState, st Styles, width, height int) string { + if TooSmall(width, height) { + return renderTooSmall(st, width, height) + } + if s == nil { + s = &HelpState{} + } + + l := PageLayout{ + HeaderTitle: "帮助", + HeaderSubtitle: "AIT — AI 接口压测工具概念说明与操作指南", + HeaderMeta: "帮助", + Hotkeys: NewPageHotkeys(Hotkeys_Help(), "[b/Esc] 返回", "[q] 退出"), + } + frame := l.Frame(width, height) + panel := NewPanelFrame(frame.OuterWidth) + content := buildHelpContent(s, st, panel.InnerWidth, PanelContentHeight(frame.InnerHeight)) + return l.Assemble(panel.Wrap(st, content), st, width) +} + +// ─── 内容构建 ───────────────────────────────────────────────────────────────── + +// helpSection 表示帮助页的一个章节。 +type helpSection struct { + title string + items []helpItem +} + +type helpItem struct { + term string // 概念名称或快捷键 + desc string // 说明 +} + +func helpContent() []helpSection { + return []helpSection{ + { + title: "核心概念", + items: []helpItem{ + {"任务 (Task)", "一组压测配置的集合,包含目标接口、模型、并发数、请求数等参数。任务可多次运行,每次运行独立记录结果。"}, + {"运行 (Run)", "任务的一次具体执行。每次运行产生独立的指标数据和请求记录,可导出为 JSON/CSV 报告。"}, + {"标准模式", "以固定并发数执行所有请求,适合衡量稳定负载下的接口性能。"}, + {"Turbo 模式", "自动从低并发逐步爬坡,找出接口在保持成功率要求下能承受的最大稳定并发数。"}, + }, + }, + { + title: "性能指标", + items: []helpItem{ + {"TPS", "Tokens Per Second,每秒输出 Token 数,衡量模型的文本生成速率。"}, + {"均值TPS", "本次运行中所有请求的 TPS 均值,反映整体吞吐水平。"}, + {"TTFT", "Time To First Token,从发送请求到收到第一个 Token 的耗时,衡量模型响应延迟。"}, + {"均值TTFT", "本次运行中所有请求的 TTFT 均值。"}, + {"成功率", "成功完成的请求数占总请求数的百分比。失败包括超时、HTTP 错误、模型返回错误等。"}, + {"缓存命中", "请求中使用了 KV 缓存(Prompt Cache)的比例。命中缓存可显著降低 TTFT 和推理成本。该指标为二值统计:单次请求若有任何 Token 命中缓存则计为命中。"}, + {"并发(Turbo)", "Turbo 模式下找到的最大稳定并发数,即在满足最低成功率要求的前提下能同时维持的请求数。"}, + }, + }, + { + title: "协议支持", + items: []helpItem{ + {"OpenAI", "兼容 OpenAI Chat Completions API(/v1/chat/completions),支持流式和非流式响应。"}, + {"Anthropic", "兼容 Anthropic Messages API(/v1/messages),支持流式和非流式响应。"}, + }, + }, + { + title: "快捷键 — 全局", + items: []helpItem{ + {"q / Ctrl+C", "退出程序。"}, + {"?", "打开此帮助页。"}, + {"b / Esc", "返回上一页。"}, + }, + }, + { + title: "快捷键 — 任务列表", + items: []helpItem{ + {"↑↓ / j k", "选择任务。"}, + {"Enter", "进入任务详情页。"}, + {"r", "立即运行选中任务。"}, + {"s", "停止正在运行的任务(仅任务运行中可用)。"}, + {"a", "新建任务(打开向导)。"}, + {"e", "编辑选中任务配置。"}, + {"d", "删除选中任务(需确认)。"}, + {"y", "复制选中任务(生成副本)。"}, + {"p", "打开代理配置页。"}, + }, + }, + { + title: "快捷键 — 任务详情", + items: []helpItem{ + {"↑↓ / j k", "在历史运行记录中选择条目。"}, + {"Enter", "查看选中运行的仪表盘;若任务正在运行,进入实时仪表盘。"}, + {"r", "再次运行该任务(无正在运行的实例时可用)。"}, + {"g", "将选中的历史运行导出为 JSON 报告。"}, + {"e", "编辑任务配置。"}, + {"y", "复制任务。"}, + {"d", "删除任务。"}, + }, + }, + { + title: "快捷键 — 运行仪表盘", + items: []helpItem{ + {"↑↓ / j k", "选择请求条目。"}, + {"Enter", "查看选中请求的详情(耗时、Token、响应体等)。"}, + {"s", "停止正在运行的任务。"}, + {"r", "生成 JSON 报告(运行结束后可用)。"}, + {"b / Esc", "返回任务详情页。"}, + }, + }, + { + title: "报告导出", + items: []helpItem{ + {"JSON 报告", "完整记录每次请求的所有指标、请求/响应体,适合程序化分析。"}, + {"CSV 报告", "表格形式的汇总数据,可直接在电子表格中打开。报告默认保存在当前工作目录。"}, + }, + }, + } +} + +func buildHelpLines(s *HelpState, contentW, _ int) []string { + sections := helpContent() + var lines []string + for _, sec := range sections { + lines = append(lines, " "+sec.title) + lines = append(lines, "") + for _, item := range sec.items { + lines = append(lines, " "+item.term) + // 简单 wrap desc + wrapped := wrapText(item.desc, maxInt(20, contentW-6)) + for _, l := range wrapped { + lines = append(lines, " "+l) + } + lines = append(lines, "") + } + } + return lines +} + +func buildHelpContent(s *HelpState, st Styles, contentW, maxH int) string { + sections := helpContent() + termW := 16 // 概念名/快捷键列宽 + + var rawLines []string + for _, sec := range sections { + // 章节标题 + rawLines = append(rawLines, st.SectionHead.Render(" "+sec.title)) + rawLines = append(rawLines, "") + for _, item := range sec.items { + // term 列 + termStr := st.Label.Render(padRight(item.term, termW)) + // desc 第一行与 term 同行,后续行缩进 + descW := maxInt(20, contentW-termW-4) + wrapped := wrapText(item.desc, descW) + if len(wrapped) == 0 { + wrapped = []string{""} + } + // 第一行:term + desc[0] + firstLine := " " + termStr + " " + wrapped[0] + rawLines = append(rawLines, firstLine) + // 后续行缩进对齐 + indent := strings.Repeat(" ", 2+lipgloss.Width(termStr)+2) + for _, seg := range wrapped[1:] { + rawLines = append(rawLines, indent+seg) + } + } + rawLines = append(rawLines, "") + } + + // 应用滚动 + if s.ScrollY >= len(rawLines) { + s.ScrollY = maxInt(0, len(rawLines)-1) + } + visible := rawLines + if s.ScrollY > 0 { + visible = rawLines[s.ScrollY:] + } + + // 填充至 maxH + if len(visible) > maxH { + visible = visible[:maxH] + } + for len(visible) < maxH { + visible = append(visible, "") + } + return strings.Join(visible, "\n") +} diff --git a/internal/tui/pages/layout.go b/internal/tui/pages/layout.go index e29504c..e252b2d 100644 --- a/internal/tui/pages/layout.go +++ b/internal/tui/pages/layout.go @@ -38,6 +38,7 @@ type PageHotkeys struct { } // NewPageHotkeys 用于构建统一的页面 Hotkeys。 +// 所有页面(帮助页除外)会自动在 Hotkeys 末尾追加 [?] 帮助提示。 func NewPageHotkeys(hotkeys []HotkeyItem, hints ...string) PageHotkeys { return PageHotkeys{ Hotkeys: hotkeys, @@ -45,6 +46,16 @@ func NewPageHotkeys(hotkeys []HotkeyItem, hints ...string) PageHotkeys { } } +// NewPageHotkeysWithHelp 在 NewPageHotkeys 基础上自动追加 [?] 帮助。 +// 非帮助页使用此函数以统一显示帮助入口。 +func NewPageHotkeysWithHelp(hotkeys []HotkeyItem, hints ...string) PageHotkeys { + withHelp := append(hotkeys, HotkeyAction("?", "帮助")) + return PageHotkeys{ + Hotkeys: withHelp, + Hints: HotkeyTexts(hints...), + } +} + // PageFrame 描述页面主内容区的统一尺寸。 // OuterWidth 是最外层内容面板总宽度,InnerWidth/InnerHeight 是面板内部可用区域。 type PageFrame struct { diff --git a/internal/tui/pages/nav.go b/internal/tui/pages/nav.go index 7f70876..83b079b 100644 --- a/internal/tui/pages/nav.go +++ b/internal/tui/pages/nav.go @@ -22,6 +22,7 @@ const ( NavRunDetail // 从历史记录进入某次运行的仪表盘(需 RunID) NavReqDetail // 进入请求详情(需 ReqIndex) NavProxy // 进入代理配置页 + NavHelp // 打开帮助页 NavQuit // 退出程序 ) diff --git a/internal/tui/pages/proxy.go b/internal/tui/pages/proxy.go index 47b8a20..f8b04f9 100644 --- a/internal/tui/pages/proxy.go +++ b/internal/tui/pages/proxy.go @@ -160,6 +160,9 @@ func HandleProxyConfigKey(s *ProxyConfigState, msg tea.KeyMsg, client Client) (* return s, cmd, nav } + case "?": + nav = NavAction{To: NavHelp} + case "q", "ctrl+c": nav = NavAction{To: NavQuit} @@ -187,7 +190,7 @@ func RenderProxyConfig(s *ProxyConfigState, st Styles, width, height int) string HeaderTitle: "代理配置", HeaderSubtitle: "设置全局 HTTP 代理,适用于所有任务的请求。留空则使用系统环境变量或直连。", HeaderMeta: "全局配置", - Hotkeys: NewPageHotkeys(Hotkeys_ProxyConfig(), "[Esc] 返回", "[q] 退出"), + Hotkeys: NewPageHotkeysWithHelp(Hotkeys_ProxyConfig(), "[Esc] 返回", "[q] 退出"), } frame := l.Frame(width, height) panel := NewPanelFrame(frame.OuterWidth) diff --git a/internal/tui/pages/reqdetail.go b/internal/tui/pages/reqdetail.go index 626ae44..bce5bcd 100644 --- a/internal/tui/pages/reqdetail.go +++ b/internal/tui/pages/reqdetail.go @@ -65,6 +65,9 @@ func HandleReqDetailKey(s *ReqDetailState, msg tea.KeyMsg) (*ReqDetailState, Nav nav = NavAction{To: NavDashboard} } + case "?": + nav = NavAction{To: NavHelp} + case "q", "ctrl+c": nav = NavAction{To: NavQuit} } @@ -126,7 +129,7 @@ func RenderReqDetail(s *ReqDetailState, taskName string, st Styles, width, heigh HeaderMeta: truncate(string(s.RunID), 18), HeaderInfoLeft: []string{fmt.Sprintf("请求 %d/%d", idx+1, len(s.Requests)), status}, HeaderInfoRight: []string{fmt.Sprintf("缓存 %.0f%%", r.CacheHitRate*100), "耗时 " + fmtDuration(r.TotalTime)}, - Hotkeys: NewPageHotkeys(Hotkeys_ReqDetail(), "[b/Esc] 返回上一页", "[q] 退出"), + Hotkeys: NewPageHotkeysWithHelp(Hotkeys_ReqDetail(), "[b/Esc] 返回上一页", "[q] 退出"), } frame := l.Frame(width, height) bodyPanel := frame.InnerPanel() diff --git a/internal/tui/pages/taskdetail.go b/internal/tui/pages/taskdetail.go index 74646c9..e7ea2c5 100644 --- a/internal/tui/pages/taskdetail.go +++ b/internal/tui/pages/taskdetail.go @@ -131,6 +131,9 @@ func HandleTaskDetailKey(s *TaskDetailState, msg tea.KeyMsg, client Client) (*Ta case "d": return s, client.DeleteTaskCmd(s.Task.ID), nav + case "?": + nav = NavAction{To: NavHelp} + case "q", "ctrl+c": nav = NavAction{To: NavQuit} } @@ -199,7 +202,7 @@ func RenderTaskDetail(s *TaskDetailState, st Styles, width, height int) string { HeaderMeta: "任务详情", HeaderInfoLeft: []string{modeStr, inp.NormalizedProtocol()}, HeaderInfoRight: headerRight, - Hotkeys: NewPageHotkeys(cbItems, "[b/Esc] 返回上一页", "[q] 退出"), + Hotkeys: NewPageHotkeysWithHelp(cbItems, "[b/Esc] 返回上一页", "[q] 退出"), } frame := l.Frame(width, height) diff --git a/internal/tui/pages/tasklist.go b/internal/tui/pages/tasklist.go index c999501..1245228 100644 --- a/internal/tui/pages/tasklist.go +++ b/internal/tui/pages/tasklist.go @@ -129,6 +129,9 @@ func HandleTaskListKey(s *TaskListState, msg tea.KeyMsg, client Client) (*TaskLi } } + case "?": + nav = NavAction{To: NavHelp} + case "q", "ctrl+c": nav = NavAction{To: NavQuit} } @@ -177,7 +180,7 @@ func RenderTaskList(s *TaskListState, st Styles, width, height int) string { HeaderMeta: fmt.Sprintf("%d 个任务", len(s.Tasks)), HeaderInfoLeft: []string{fmt.Sprintf("运行中 %d", runningCount)}, HeaderInfoRight: headerRight, - Hotkeys: NewPageHotkeys(cbItems, "[↑↓] 选择", "[a] 新建", "[q] 退出"), + Hotkeys: NewPageHotkeysWithHelp(cbItems, "[↑↓] 选择", "[a] 新建", "[q] 退出"), } frame := l.Frame(width, height) panel := NewPanelFrame(frame.OuterWidth) diff --git a/internal/tui/pages/turbodash.go b/internal/tui/pages/turbodash.go index 448a1b8..62cedf7 100644 --- a/internal/tui/pages/turbodash.go +++ b/internal/tui/pages/turbodash.go @@ -130,6 +130,9 @@ func HandleTurboDashKey(d *TurboDashState, msg tea.KeyMsg, client Client) (*Turb return d, client.GenerateReportCmd(d.RunID, server.ReportFormatJSON), nav } + case "?": + nav = NavAction{To: NavHelp} + case "q", "ctrl+c": nav = NavAction{To: NavQuit} } @@ -207,7 +210,7 @@ func RenderTurboDash(d *TurboDashState, taskName string, st Styles, width, heigh HeaderMeta: "Turbo 模式", HeaderInfoLeft: headerLeft, HeaderInfoRight: headerRight, - Hotkeys: NewPageHotkeys(cbItems, "[b/Esc] 返回上一页", "[q] 退出"), + Hotkeys: NewPageHotkeysWithHelp(cbItems, "[b/Esc] 返回上一页", "[q] 退出"), } frame := l.Frame(width, height) bodyPanel := frame.InnerPanel() diff --git a/internal/tui/pages/wizard.go b/internal/tui/pages/wizard.go index 5271dd5..45a0323 100644 --- a/internal/tui/pages/wizard.go +++ b/internal/tui/pages/wizard.go @@ -582,6 +582,9 @@ func HandleWizardKey(wz *WizardState, msg tea.KeyMsg, client Client) (*WizardSta } loadCurrentFieldInput(wz) + case "?": + nav = NavAction{To: NavHelp} + case "ctrl+c": nav = NavAction{To: NavQuit} @@ -644,7 +647,7 @@ func RenderWizard(wz *WizardState, st Styles, width, height int) string { HeaderMeta: fmt.Sprintf("步骤 %d/3", int(wz.Step)+1), HeaderInfoLeft: headerLeft, HeaderInfoRight: headerRight, - Hotkeys: NewPageHotkeys(wizardHotkeyItems(wz.Step), "[Ctrl+C] 退出"), + Hotkeys: NewPageHotkeysWithHelp(wizardHotkeyItems(wz.Step), "[Ctrl+C] 退出"), } frame := l.Frame(width, height) panel := NewPanelFrame(frame.OuterWidth) From c4fe89daa69c4a9f9dc3b08264f5770324e79dc7 Mon Sep 17 00:00:00 2001 From: Alain Date: Sat, 23 May 2026 23:48:44 +0800 Subject: [PATCH 45/52] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=20RPM=20?= =?UTF-8?q?=E5=92=8C=20TPM=20=E6=8C=87=E6=A0=87=E8=AE=A1=E7=AE=97=EF=BC=8C?= =?UTF-8?q?=E5=A2=9E=E5=BC=BA=E6=80=A7=E8=83=BD=E6=8A=A5=E5=91=8A=E5=92=8C?= =?UTF-8?q?=E4=BB=AA=E8=A1=A8=E7=9B=98=E6=98=BE=E7=A4=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/runner/runner.go | 8 ++++++++ internal/server/run.go | 35 ++++++++++++++++++++++++++++++++ internal/server/types.go | 5 +++++ internal/store/run.go | 21 +++++++++++++++++++ internal/tui/pages/dashboard.go | 2 +- internal/tui/pages/helpers.go | 4 +++- internal/tui/pages/taskdetail.go | 19 +++++++++++++---- internal/tui/pages/tasklist.go | 26 ++++++++++++++++++++---- internal/tui/pages/turbodash.go | 2 +- internal/types/types.go | 8 ++++++++ 10 files changed, 119 insertions(+), 11 deletions(-) diff --git a/internal/runner/runner.go b/internal/runner/runner.go index 42dcb52..e7fdfea 100644 --- a/internal/runner/runner.go +++ b/internal/runner/runner.go @@ -682,6 +682,12 @@ func (r *Runner) calculateResult(results []*client.ResponseMetrics, totalTime ti stdDevTPS := math.Sqrt(varianceSumTPS / float64(validCount)) stdDevTotalThroughputTPS := math.Sqrt(varianceSumTotalThroughputTPS / float64(validCount)) + var rpm, tpm float64 + if totalTime.Minutes() > 0 { + rpm = float64(successCount) / totalTime.Minutes() + tpm = float64(sumOutputTokens) / totalTime.Minutes() + } + return &types.ReportData{ TotalRequests: requestCount, Concurrency: r.input.Concurrency, @@ -731,6 +737,8 @@ func (r *Runner) calculateResult(results []*client.ResponseMetrics, totalTime ti AvgTotalThroughputTPS: avgTotalThroughputTPS, MinTotalThroughputTPS: minTotalThroughputTPS, MaxTotalThroughputTPS: maxTotalThroughputTPS, + RPM: rpm, + TPM: tpm, StdDevTotalTime: stdDevTotalTime, StdDevTTFT: stdDevTTFT, StdDevTPOT: stdDevTPOT, diff --git a/internal/server/run.go b/internal/server/run.go index 578d9b6..989adf6 100644 --- a/internal/server/run.go +++ b/internal/server/run.go @@ -27,6 +27,7 @@ type activeRun struct { tpsSum float64 ttftSum time.Duration cacheSum float64 + tokenSum int64 // 累计成功请求的输出 Token 数,用于计算 TPM doneCount int // 与 state.DoneReqs 保持同步,方便不加锁时计算 } @@ -175,6 +176,24 @@ func buildRunStateFromStoredRun(run *store.StoredRun, requests []types.RequestMe } state.FailedReqs = state.DoneReqs - state.SuccessReqs state.TotalReqs = run.TotalReqs(requests) + + // 从存储数据重建 RPM/TPM + end := time.Now() + if run.Metadata.FinishedAt != nil { + end = *run.Metadata.FinishedAt + } + if !run.Metadata.StartedAt.IsZero() { + if elapsed := end.Sub(run.Metadata.StartedAt).Minutes(); elapsed > 0 { + var tokenSum int64 + for _, r := range requests { + if r.Success { + tokenSum += int64(r.CompletionTokens) + } + } + state.RPM = float64(state.DoneReqs) / elapsed + state.TPM = float64(tokenSum) / elapsed + } + } if run.Result == nil { return state } @@ -380,6 +399,7 @@ func (s *serverImpl) runTurbo(ar *activeRun, runID RunID, taskDef types.TaskDefi ar.tpsSum += rm.TPS ar.ttftSum += rm.TTFT ar.cacheSum += rm.CacheHitRate + ar.tokenSum += int64(rm.CompletionTokens) } else { ar.state.FailedReqs++ } @@ -391,6 +411,11 @@ func (s *serverImpl) runTurbo(ar *activeRun, runID RunID, taskDef types.TaskDefi if ar.state.DoneReqs > 0 { ar.state.SuccessRate = float64(ar.state.SuccessReqs) / float64(ar.state.DoneReqs) * 100 } + // 更新 RPM/TPM(基于运行时长) + if elapsed := time.Since(ar.state.StartedAt).Minutes(); elapsed > 0 { + ar.state.RPM = float64(ar.state.DoneReqs) / elapsed + ar.state.TPM = float64(ar.tokenSum) / elapsed + } snap := ar.snapshotState() ar.mu.Unlock() @@ -435,6 +460,11 @@ func (s *serverImpl) completeStandardRun(ar *activeRun, runID RunID, taskDef typ ar.state.SuccessRate = data.SuccessRate ar.state.CacheHitRate = data.AvgCacheHitRate } + // 使用完整运行时长计算最终稳定的 RPM/TPM + if elapsed := finishedAt.Sub(ar.state.StartedAt).Minutes(); elapsed > 0 { + ar.state.RPM = float64(ar.state.DoneReqs) / elapsed + ar.state.TPM = float64(ar.tokenSum) / elapsed + } snap := ar.snapshotState() ar.mu.Unlock() @@ -457,6 +487,11 @@ func (s *serverImpl) completeTurboRun(ar *activeRun, runID RunID, taskDef types. ar.state.Levels = result.Levels ar.state.CurrentLevel = result.MaxStableConcurrency } + // 使用完整运行时长计算最终稳定的 RPM/TPM + if elapsed := finishedAt.Sub(ar.state.StartedAt).Minutes(); elapsed > 0 { + ar.state.RPM = float64(ar.state.DoneReqs) / elapsed + ar.state.TPM = float64(ar.tokenSum) / elapsed + } snap := ar.snapshotState() ar.mu.Unlock() diff --git a/internal/server/types.go b/internal/server/types.go index 635a549..2ad470a 100644 --- a/internal/server/types.go +++ b/internal/server/types.go @@ -59,6 +59,11 @@ type RunState struct { SuccessRate float64 CacheHitRate float64 + // 吞吐量指标(基于整体运行时长,最终稳定值) + // RPM = 每分钟完成请求数;TPM = 每分钟输出 Token 数 + RPM float64 + TPM float64 + // 详细请求列表(按 index 排序) Requests []*types.RequestMetrics diff --git a/internal/store/run.go b/internal/store/run.go index 4321e0d..a1d617d 100644 --- a/internal/store/run.go +++ b/internal/store/run.go @@ -302,6 +302,25 @@ func (r StoredRun) Summary(requests []types.RequestMetrics) types.TaskRunSummary summary.AvgTTFT = derived.AvgTTFT summary.AvgTPS = derived.AvgTPS summary.CacheHitRate = derived.CacheHitRate + + // 从时间信息计算 RPM/TPM + if !r.Metadata.StartedAt.IsZero() { + end := time.Now() + if r.Metadata.FinishedAt != nil { + end = *r.Metadata.FinishedAt + } + if elapsed := end.Sub(r.Metadata.StartedAt).Minutes(); elapsed > 0 { + var totalTokens int64 + for _, req := range requests { + if req.Success { + totalTokens += int64(req.CompletionTokens) + } + } + summary.RPM = float64(len(requests)) / elapsed + summary.TPM = float64(totalTokens) / elapsed + } + } + if r.Result != nil { summary.ErrorSummary = r.Result.ErrorSummary summary.MaxStableConcurrency = r.Result.MaxStableConcurrency @@ -310,6 +329,8 @@ func (r StoredRun) Summary(requests []types.RequestMetrics) types.TaskRunSummary summary.AvgTTFT = r.Result.StandardResult.AvgTTFT summary.AvgTPS = r.Result.StandardResult.AvgTPS summary.CacheHitRate = r.Result.StandardResult.AvgCacheHitRate + summary.RPM = r.Result.StandardResult.RPM + summary.TPM = r.Result.StandardResult.TPM } if r.Result.TurboResult != nil { summary.MaxStableConcurrency = r.Result.TurboResult.MaxStableConcurrency diff --git a/internal/tui/pages/dashboard.go b/internal/tui/pages/dashboard.go index 27fd6b3..5f4a52d 100644 --- a/internal/tui/pages/dashboard.go +++ b/internal/tui/pages/dashboard.go @@ -212,7 +212,7 @@ func RenderDashboard(d *DashboardState, taskName string, st Styles, width, heigh bodyPanel := frame.InnerPanel() // ── 计算高度 ── - splitOuterH := 7 // 双栏面板外部总高度(含面板边框) + splitOuterH := 9 // 双栏面板外部总高度(含面板边框) progressOuterH := 3 // 进度条面板外部高度(1内容+2边框) reqOuterH := RemainingStackOuterHeight(frame.InnerHeight, splitOuterH, progressOuterH) reqListH := PanelContentHeight(reqOuterH) diff --git a/internal/tui/pages/helpers.go b/internal/tui/pages/helpers.go index d41b466..9680ba6 100644 --- a/internal/tui/pages/helpers.go +++ b/internal/tui/pages/helpers.go @@ -468,12 +468,14 @@ func applyColWidth(s lipgloss.Style, col int, colWidths []int) lipgloss.Style { return s.Padding(0, 1) } -// appendRunMetricLines 向 lines 追加 4 行运行指标(成功率/TPS/TTFT/缓存命中)。 +// appendRunMetricLines 向 lines 追加 6 行运行指标(成功率/TPS/TTFT/缓存命中/RPM/TPM)。 func appendRunMetricLines(lines []string, st Styles, rs *server.RunState) []string { lines = append(lines, " "+labelValue(st, "成功率 ", st.MetricVal.Render(fmt.Sprintf("%.1f%%", rs.SuccessRate)))) lines = append(lines, " "+labelValue(st, "TPS均值 ", st.MetricVal.Render(fmt.Sprintf("%.1f tok/s", rs.AvgTPS)))) lines = append(lines, " "+labelValue(st, "TTFT均值", st.MetricVal.Render(fmtDuration(rs.AvgTTFT)))) lines = append(lines, " "+labelValue(st, "缓存命中", st.MetricVal.Render(fmt.Sprintf("%.1f%%", rs.CacheHitRate*100)))) + lines = append(lines, " "+labelValue(st, "RPM ", st.MetricVal.Render(fmt.Sprintf("%.0f req/min", rs.RPM)))) + lines = append(lines, " "+labelValue(st, "TPM ", st.MetricVal.Render(fmt.Sprintf("%.0f tok/min", rs.TPM)))) return lines } diff --git a/internal/tui/pages/taskdetail.go b/internal/tui/pages/taskdetail.go index e7ea2c5..9340bb2 100644 --- a/internal/tui/pages/taskdetail.go +++ b/internal/tui/pages/taskdetail.go @@ -302,6 +302,7 @@ func buildTaskDetailContent(s *TaskDetailState, st Styles, t types.TaskDefinitio type histRow struct { statusText string statusIsOk bool + statusMut bool statusIsMut bool time string mode string @@ -309,6 +310,8 @@ func buildTaskDetailContent(s *TaskDetailState, st Styles, t types.TaskDefinitio dur string ttft string tps string + rpm string + tpm string } rowData := make([]histRow, effectiveLen) if hasActive { @@ -327,6 +330,8 @@ func buildTaskDetailContent(s *TaskDetailState, st Styles, t types.TaskDefinitio dur: "─", ttft: "─", tps: fmt.Sprintf("%d/%d 正在运行...", rs.DoneReqs, rs.TotalReqs), + rpm: "─", + tpm: "─", } } for histIdx := 0; histIdx < len(historyEntries); histIdx++ { @@ -364,15 +369,17 @@ func buildTaskDetailContent(s *TaskDetailState, st Styles, t types.TaskDefinitio dur: durText, ttft: fmtDuration(run.AvgTTFT), tps: fmt.Sprintf("%.1f", run.AvgTPS), + rpm: fmt.Sprintf("%.0f", run.RPM), + tpm: fmt.Sprintf("%.0f", run.TPM), } } // colWidths: 0 = 弹性列,>0 = 固定总宽 - colWidths := []int{4, 0, 7, 8, 7, 7, 7} // 状态图标, 时间=flex, 模式, 成功率, 耗时, TTFT, TPS + colWidths := []int{4, 0, 7, 8, 7, 7, 7, 6, 6} // 状态图标, 时间=flex, 模式, 成功率, 耗时, TTFT, TPS, RPM, TPM sel := s.HistorySel tableH := tableMaxH - len(rightTitle) tbl := lgtable.New(). - Headers("", "时间", "模式", "成功率", "耗时", "TTFT", "TPS"). + Headers("", "时间", "模式", "成功率", "耗时", "TTFT", "TPS", "RPM", "TPM"). Width(rightW). Height(tableH). YOffset(s.HistoryOff). @@ -401,14 +408,14 @@ func buildTaskDetailContent(s *TaskDetailState, st Styles, t types.TaskDefinitio } return aw(st.ErrStyle) } - if col >= 3 { // rate, dur, ttft, tps + if col >= 3 { // rate, dur, ttft, tps, rpm, tpm return aw(st.Value) } return aw(st.TableRow) }) for _, r := range rowData { - tbl.Row(r.statusText, r.time, r.mode, r.rate, r.dur, r.ttft, r.tps) + tbl.Row(r.statusText, r.time, r.mode, r.rate, r.dur, r.ttft, r.tps, r.rpm, r.tpm) } tableStr := tbl.String() @@ -550,6 +557,10 @@ func buildTaskHistoryDetailLines(history []types.TaskRunSummary, histIdx int, st "TTFT", fmtDuration(sel.AvgTTFT), st.Value, "TPS", fmt.Sprintf("%.1f", sel.AvgTPS), st.MetricVal, ) + lines = appendPairRow(lines, + "RPM", fmt.Sprintf("%.0f req/min", sel.RPM), st.MetricVal, + "TPM", fmt.Sprintf("%.0f tok/min", sel.TPM), st.MetricVal, + ) lines = appendSingleField(lines, "协议", shortProtocol(sel.Protocol), st.Value) lines = appendSingleField(lines, "模型", sel.Model, st.Value) if sel.CacheHitRate > 0 { diff --git a/internal/tui/pages/tasklist.go b/internal/tui/pages/tasklist.go index 1245228..c6e63b1 100644 --- a/internal/tui/pages/tasklist.go +++ b/internal/tui/pages/tasklist.go @@ -210,6 +210,8 @@ func buildTaskListContent(s *TaskListState, st Styles, width, maxH int) string { cache string ttft string tps string + rpm string + tpm string } sel := s.Selected @@ -265,6 +267,20 @@ func buildTaskListContent(s *TaskListState, st Styles, width, maxH int) string { cacheText = fmt.Sprintf("%.1f%%", t.LatestRun.CacheHitRate*100) } + rpmText := "─" + if hasActiveRun && rs != nil && rs.RPM > 0 { + rpmText = fmt.Sprintf("%.0f", rs.RPM) + } else if !hasActiveRun && t.LatestRun != nil && t.LatestRun.RPM > 0 { + rpmText = fmt.Sprintf("%.0f", t.LatestRun.RPM) + } + + tpmText := "─" + if hasActiveRun && rs != nil && rs.TPM > 0 { + tpmText = fmt.Sprintf("%.0f", rs.TPM) + } else if !hasActiveRun && t.LatestRun != nil && t.LatestRun.TPM > 0 { + tpmText = fmt.Sprintf("%.0f", t.LatestRun.TPM) + } + rowData[i] = taskRowData{ name: t.Name, mode: modeText, @@ -276,14 +292,16 @@ func buildTaskListContent(s *TaskListState, st Styles, width, maxH int) string { cache: cacheText, ttft: ttftText, tps: tpsText, + rpm: rpmText, + tpm: tpmText, } } // ── 构建 lipgloss/table ── // colWidths: 0 = 弹性列(占用剩余宽度),>0 = 固定总宽(包括两端各 1 字符 padding) - colWidths := []int{0, 8, 22, 12, 8, 10, 10, 10} // 任务名称=flex, 模式, 协议, 上次运行, 成功率, 缓存命中, TTFT均值, TPS均值 + colWidths := []int{0, 8, 22, 12, 8, 10, 10, 10, 8, 8} // 任务名称=flex, 模式, 协议, 上次运行, 成功率, 缓存命中, TTFT均值, TPS均值, RPM, TPM t := lgtable.New(). - Headers("任务名称", "模式", "协议", "上次运行", "成功率", "缓存命中", "均值TTFT", "均值TPS"). + Headers("任务名称", "模式", "协议", "上次运行", "成功率", "缓存命中", "均值TTFT", "均值TPS", "RPM", "TPM"). Width(width). Height(maxH). YOffset(s.Offset). @@ -314,7 +332,7 @@ func buildTaskListContent(s *TaskListState, st Styles, width, maxH int) string { return aw(st.Ok) } return aw(st.Muted) - case 4, 5, 6, 7: // rate, cache, ttft, tps + case 4, 5, 6, 7, 8, 9: // rate, cache, ttft, tps, rpm, tpm return aw(st.Value) default: return aw(st.TableRow) @@ -322,7 +340,7 @@ func buildTaskListContent(s *TaskListState, st Styles, width, maxH int) string { }) for _, r := range rowData { - t.Row(r.name, r.mode, r.proto, r.lastRun, r.rate, r.cache, r.ttft, r.tps) + t.Row(r.name, r.mode, r.proto, r.lastRun, r.rate, r.cache, r.ttft, r.tps, r.rpm, r.tpm) } tableStr := t.String() diff --git a/internal/tui/pages/turbodash.go b/internal/tui/pages/turbodash.go index 62cedf7..a482d8f 100644 --- a/internal/tui/pages/turbodash.go +++ b/internal/tui/pages/turbodash.go @@ -216,7 +216,7 @@ func RenderTurboDash(d *TurboDashState, taskName string, st Styles, width, heigh bodyPanel := frame.InnerPanel() // ── 计算高度 ── - splitOuterH := 7 + splitOuterH := 9 progressOuterH := 3 levelOuterH := RemainingStackOuterHeight(frame.InnerHeight, splitOuterH, progressOuterH) levelListH := PanelContentHeight(levelOuterH) diff --git a/internal/types/types.go b/internal/types/types.go index f03c88a..4815e6b 100644 --- a/internal/types/types.go +++ b/internal/types/types.go @@ -208,6 +208,10 @@ type ReportData struct { MinTPS float64 `json:"min_tps"` // 最小输出 TPS MaxTPS float64 `json:"max_tps"` // 最大输出 TPS + // 分钟吩吐量(基于整体运行时长,最终稳定值) + RPM float64 `json:"rpm"` // 每分钟完成请求数 + TPM float64 `json:"tpm"` // 每分钟输出 Token 数 + // 吞吐量指标 - 统计结果 AvgTotalThroughputTPS float64 `json:"avg_total_throughput_tps"` // 平均吞吐 TPS (输入+输出 tokens per second) MinTotalThroughputTPS float64 `json:"min_total_throughput_tps"` // 最小吞吐 TPS @@ -256,6 +260,8 @@ type TaskRunSummary struct { AvgTTFT time.Duration `json:"avg_ttft"` AvgTPS float64 `json:"avg_tps"` CacheHitRate float64 `json:"cache_hit_rate"` + RPM float64 `json:"rpm,omitempty"` + TPM float64 `json:"tpm,omitempty"` MaxStableConcurrency int `json:"max_stable_concurrency,omitempty"` ErrorSummary string `json:"error_summary,omitempty"` } @@ -298,6 +304,8 @@ type TurboLevelResult struct { PeakTPS float64 `json:"peak_tps"` AvgTTFT time.Duration `json:"avg_ttft"` CacheHitRate float64 `json:"cache_hit_rate"` + RPM float64 `json:"rpm,omitempty"` + TPM float64 `json:"tpm,omitempty"` AvgTotalTime time.Duration `json:"avg_total_time"` StdDevTPS float64 `json:"stddev_tps"` Stable bool `json:"stable"` From 71a5c60c2e89ed4762f94932c2644c2cee392b49 Mon Sep 17 00:00:00 2001 From: Alain Date: Sun, 24 May 2026 00:07:55 +0800 Subject: [PATCH 46/52] =?UTF-8?q?feat:=20=E6=9B=B4=E6=96=B0=E7=83=AD?= =?UTF-8?q?=E9=94=AE=E6=B8=B2=E6=9F=93=EF=BC=8C=E6=B7=BB=E5=8A=A0=E6=97=B6?= =?UTF-8?q?=E9=97=B4=E6=88=B3=E5=92=8C=E9=A1=B9=E7=9B=AE=E9=93=BE=E6=8E=A5?= =?UTF-8?q?=EF=BC=8C=E5=A2=9E=E5=BC=BA=E7=95=8C=E9=9D=A2=E4=BF=A1=E6=81=AF?= =?UTF-8?q?=E5=B1=95=E7=A4=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/tui/pages/helpers.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/internal/tui/pages/helpers.go b/internal/tui/pages/helpers.go index 9680ba6..a5fe560 100644 --- a/internal/tui/pages/helpers.go +++ b/internal/tui/pages/helpers.go @@ -266,8 +266,9 @@ func renderHotkeys(st Styles, width int, hk PageHotkeys) string { hkLine := renderPrimaryHotkeyItems(hk.Hotkeys, maxInt(8, w-4)) line1 := renderChromeLine(st.HotkeysPrimary, w, " "+hkLine, "") - appStamp := lipgloss.NewStyle().Foreground(colorPink).Bold(true).Render("AIT") + - lipgloss.NewStyle().Foreground(colorMuted).Render(" 终端 · "+time.Now().Format("15:04")) + appStamp := lipgloss.NewStyle().Background(colorHotkeysSecondaryBg).Foreground(colorMuted).Render(time.Now().Format("2006-01-02 15:04:05")+" ") + + lipgloss.NewStyle().Background(colorHotkeysSecondaryBg).Foreground(colorPink).Bold(true).Render("github.com/yinxulai/ait") + + lipgloss.NewStyle().Background(colorHotkeysSecondaryBg).Foreground(colorMuted).Render(" Powered by Alain") left2 := renderSecondaryHotkeyItems(hk.Hints, maxInt(8, w-lipgloss.Width(appStamp)-4)) line2 := renderChromeLine(st.HotkeysSecondary, w, " "+left2, appStamp+" ") From 191818027174cec7a0fa7409371307f8e738a90f Mon Sep 17 00:00:00 2001 From: Alain Date: Sun, 24 May 2026 08:49:23 +0800 Subject: [PATCH 47/52] feat(i18n): add Chinese/English UI string translations with support for dynamic language switching - Introduced a new package `i18n` for managing translations. - Added support for Chinese (default) and English languages. - Implemented a translation key system for UI elements and messages. - Created a function `T` to retrieve translations based on the active language. - Included helper functions for display width calculations to accommodate CJK characters. --- internal/i18n/i18n.go | 1052 ++++++++++++++++++++++++++++++ internal/tui/pages/contextbar.go | 138 ++-- internal/tui/pages/dashboard.go | 48 +- internal/tui/pages/help.go | 103 +-- internal/tui/pages/helpers.go | 91 ++- internal/tui/pages/layout.go | 10 +- internal/tui/pages/proxy.go | 23 +- internal/tui/pages/reqdetail.go | 70 +- internal/tui/pages/taskdetail.go | 87 +-- internal/tui/pages/tasklist.go | 41 +- internal/tui/pages/turbodash.go | 49 +- internal/tui/pages/wizard.go | 181 ++--- 12 files changed, 1509 insertions(+), 384 deletions(-) create mode 100644 internal/i18n/i18n.go diff --git a/internal/i18n/i18n.go b/internal/i18n/i18n.go new file mode 100644 index 0000000..70044b8 --- /dev/null +++ b/internal/i18n/i18n.go @@ -0,0 +1,1052 @@ +// Package i18n provides simple Chinese/English UI string translations. +// Default language is ZH (Chinese). Call SetLang(EN) to switch to English. +// +// Alignment problem: TUI label groups (e.g. metrics panel) pad labels to equal +// display width. Because CJK chars are 2 display columns wide, use +// DisplayWidth() to measure and maxLabelWidth() in helpers.go to auto-compute. +package i18n + +import ( + "sync/atomic" + + "github.com/mattn/go-runewidth" +) + +// Lang represents a supported language. +type Lang int32 + +const ( + ZH Lang = iota // Chinese (default) + EN // English +) + +var active atomic.Int32 + +// SetLang switches the active UI language. +func SetLang(l Lang) { active.Store(int32(l)) } + +// Active returns the currently active language. +func Active() Lang { return Lang(active.Load()) } + +// Key is a typed string resource key. +type Key int + +const ( + // ─── Hotkeys ───────────────────────────────────────────────────────────── + KHelp Key = iota + KViewDetails + KRun + KNewTask + KEdit + KDelete + KCopy + KCopyTask + KProxyConfig + KStop + KSelectRecord + KViewRunDetails + KRunAgain + KExportJSONReport + KExportHistoryJSON + KGoToLiveDash + KSwitchField + KSwitchProtocol + KNextStep + KBackToList + KSwitchOption + KGoBack + KScroll + KPageTurn + KSave + KSaveAndRun + KBackToEdit + KGenerateReport + KViewRequest + KSelectRequest + KViewLevelReqs + KSelectItem + KBackToDash + KPrevNextReq + KSwitchType + KClear + KTopBottom + KConfirmDelete + KCancel + + // ─── Hint texts (used in NewPageHotkeys hints) ──────────────────────────── + KHintQuit // [q] 退出 + KHintCtrlCQuit // [Ctrl+C] 退出 + KHintBack // [b/Esc] 返回 + KHintEscBack // [Esc] 返回 + KHintGoBack // [b/Esc] 返回上一页 + KHintSelect // [↑↓] 选择 + KHintNew // [a] 新建 + + // ─── Metric labels ─────────────────────────────────────────────────────── + KSuccessRate + KAvgTPS + KAvgTTFT + KCacheHit + KRPM + KTPM + KStatus + KTotalTime + KTTFT + KOutputTPS + KToken + KCache + KError + KDNS + KTCPConnect + KTLSHandshake + KTargetIP + + // ─── Status values ─────────────────────────────────────────────────────── + KRunning + KCompleted + KRunFailed + KStopped + KWaitingStatus + + // ─── Dashboard / progress ───────────────────────────────────────────────── + KProgress + KSuccessCount + KFailureCount + KWaitingDots + KRamp + KPerLevel + KStopCondLabel + KTurboMode + KStandardMode + KTurboMonitor + KTurboModeMeta + KSuccessRateFmt // "成功率 %.1f%%" + KTurboCurLevelFmt // "当前级别实时指标 [并发 = %d]" + KTurboDashSuffix // " %d/%d 并发 %d 进度 %s 级" + + // ─── Layout ────────────────────────────────────────────────────────────── + KWindowTooSmall + KWaitingData + KNotRecorded + KScrollMore + KTerminalLabel + KInProgress + KNoHotkeys // "当前页暂无快捷操作" + + // ─── Page subtitles ─────────────────────────────────────────────────────── + KStdMonitorTitle // "标准运行监控" + KStdMonitorSubtitle // "实时查看运行进度、吞吐和单请求明细" + KTurboSubtitle // "观察并发爬坡过程、级别指标与稳定区间" + KTaskListSubtitle // "创建任务、运行压测、查看执行记录与导出报告" + KTaskDetailSubtitle // "查看任务配置、当前运行状态与历史记录" + KReqDetailSubtitle // "查看单次请求的耗时、网络阶段和完整报文" + KConcFmt // "并发%d" + + // ─── TaskDetail / ReqDetail fields ─────────────────────────────────────── + KProtocol + KEndpoint + KProxy + KModel + KMode + KConcurrency + KStepLabel + KRequests + KTimeout + KStream + KPromptLabel + KNoRunRecords + KRecordDetails + KStart + KEnd + KElapsed + KErrorSummary + KRequestBody + KResponseBody + + // ─── Table column headers ───────────────────────────────────────────────── + KColTime + KTime // alias for table display + KColMode + KColProtocol + KColSuccessRate + KColCacheHit + KColAvgTTFT + KColAvgTPS + KColLevel + KColInput + KColOutput + + // ─── TaskList ──────────────────────────────────────────────────────────── + KTaskName + KTaskID + KLastRun + KIrreversible + KNoTasks + KRunHistory + KTaskCenter + KNoRunHistory + KConfirmDeletePrompt + + // ─── Proxy ─────────────────────────────────────────────────────────────── + KExSOCKS5 + KExSSH + KExHTTP + KProxySubtitle + KProxySaveHint + KGlobalConfig + KProxyType + KProxyURL + + // ─── Help page ─────────────────────────────────────────────────────────── + KHelpTitle + KHelpSubtitle + KHelpMeta + KHelpSecConcepts + KHelpSecMetrics + KHelpSecProtocols + KHelpSecGlobal + KHelpSecTaskList + KHelpSecTaskDetail + KHelpSecDashboard + KHelpSecExport + KHelpTermTask + KHelpDescTask + KHelpTermRun + KHelpDescRun + KHelpTermStandard + KHelpDescStandard + KHelpTermTurboMode + KHelpDescTurboMode + KHelpTermTPS + KHelpDescTPS + KHelpTermAvgTPS + KHelpDescAvgTPS + KHelpTermTTFT + KHelpDescTTFT + KHelpTermAvgTTFT + KHelpDescAvgTTFT + KHelpTermSuccessRate + KHelpDescSuccessRate + KHelpTermCacheHit + KHelpDescCacheHit + KHelpTermConcurrencyTurbo + KHelpDescConcurrencyTurbo + KHelpTermOpenAI + KHelpDescOpenAI + KHelpTermAnthropic + KHelpDescAnthropic + KHelpTermQuit + KHelpDescQuit + KHelpTermQuestionMark + KHelpDescQuestionMark + KHelpTermBack + KHelpDescBack + KHelpTermSelectTask + KHelpDescSelectTask + KHelpTermEnterDetail + KHelpDescEnterDetail + KHelpTermRunTask + KHelpDescRunTask + KHelpTermStopTask + KHelpDescStopTask + KHelpTermNewTask + KHelpDescNewTask + KHelpTermEditTask + KHelpDescEditTask + KHelpTermDeleteTask + KHelpDescDeleteTask + KHelpTermCopyTask + KHelpDescCopyTask + KHelpTermProxy + KHelpDescProxy + KHelpTermSelectHistory + KHelpDescSelectHistory + KHelpTermEnterDash + KHelpDescEnterDash + KHelpTermRunAgain + KHelpDescRunAgain + KHelpTermExport + KHelpDescExport + KHelpTermEditConfig + KHelpDescEditConfig + KHelpTermCopyTask2 + KHelpDescCopyTask2 + KHelpTermDeleteTask2 + KHelpDescDeleteTask2 + KHelpTermSelectReq + KHelpDescSelectReq + KHelpTermViewReq + KHelpDescViewReq + KHelpTermStopDash + KHelpDescStopDash + KHelpTermGenerateReport + KHelpDescGenerateReport + KHelpTermBackDash + KHelpDescBackDash + KHelpTermJSONReport + KHelpDescJSONReport + KHelpTermCSVReport + KHelpDescCSVReport + + // ─── Wizard ────────────────────────────────────────────────────────────── + KWzTaskName + KWzProtocol + KWzEndpoint + KWzAPIKey + KWzTestModel + KWzTestMode + KWzTurboMode + KWzStandardMode + KWzConcurrency + KWzTotalRequests + KWzTimeoutSecs + KWzInitConc + KWzMaxConc + KWzStepSize + KWzLevelReqs + KWzMinSuccessRate + KWzStreamMode + KWzInputMode + KWzInputDirect + KWzInputFile + KWzInputGenerated + KWzInputRaw + KWzPromptConfig + KWzSelectModeHint + KWzTurboModeLabel + KWzStepFmt // "步骤 %d/3" + KWzStep1Label + KWzStep2Label + KWzStep3Label + KWzStep1Desc + KWzStep2Desc + KWzStep3Desc + KWzUntitled + KWzNotFilled + KWzExecParams + KWzConcurrencyRamp + KWzStopCondition + KWzTimeoutLabel + KWzContentSummary + KWzBodyBytes + KWzSaveLocation + KWzPromptSection + KWzHintDirect + KWzHintFile + KWzHintRaw + KWzHintCacheToken + KWzHintRawBody + KWzJSONBody + KWzPromptLabelShort + KWzRAWBody + KWzFileSummary + KWzGeneratedFmt // "生成 %d 字符" + KWzPromptContent // field label for prompt content input + KWzNoConfirmItems // "暂无确认项" + KWzConfirmRange // "确认项 %d-%d/%d" + KWzConfirmTotal // "共 %d 项待确认" + KWzNoFields // "暂无配置项" + KWzFieldProgress // "当前字段 %d/%d" + + // ─── Misc ──────────────────────────────────────────────────────────────── + KEnabled + KDisabled + KFileSummaryPfx // "文件: " + KNotSet // "(未设置)" + KJustNow // "刚刚" + KMinutesAgoFmt // "%d 分钟前" + KHoursAgoFmt // "%d 小时前" + KDaysAgoFmt // "%d 天前" +) + +var translations = [2]map[Key]string{ + ZH: { + // Hotkeys + KHelp: "帮助", + KViewDetails: "查看详情", + KRun: "运行", + KNewTask: "新建任务", + KEdit: "编辑", + KDelete: "删除", + KCopy: "复制", + KCopyTask: "复制任务", + KProxyConfig: "代理配置", + KStop: "停止", + KSelectRecord: "选择记录", + KViewRunDetails: "查看运行详情", + KRunAgain: "再次运行", + KExportJSONReport: "导出 JSON 报告", + KExportHistoryJSON: "导出历史 JSON", + KGoToLiveDash: "进入运行中仪表盘", + KSwitchField: "切换字段", + KSwitchProtocol: "切换协议", + KNextStep: "下一步", + KBackToList: "返回列表", + KSwitchOption: "切换选项", + KGoBack: "返回上一步", + KScroll: "滚动", + KPageTurn: "翻页", + KSave: "保存", + KSaveAndRun: "保存并运行", + KBackToEdit: "返回修改", + KGenerateReport: "生成报告", + KViewRequest: "查看请求详情", + KSelectRequest: "选择请求", + KViewLevelReqs: "查看该级别请求", + KSelectItem: "选择", + KBackToDash: "返回仪表盘", + KPrevNextReq: "上/下一条请求", + KSwitchType: "切换类型", + KClear: "清空", + KTopBottom: "顶部/底部", + KConfirmDelete: "确认删除", + KCancel: "取消", + + // Hints + KHintQuit: "[q] 退出", + KHintCtrlCQuit: "[Ctrl+C] 退出", + KHintBack: "[b/Esc] 返回", + KHintEscBack: "[Esc] 返回", + KHintGoBack: "[b/Esc] 返回上一页", + KHintSelect: "[↑↓] 选择", + KHintNew: "[a] 新建", + + // Metric labels + KSuccessRate: "成功率", + KAvgTPS: "TPS均值", + KAvgTTFT: "TTFT均值", + KCacheHit: "缓存命中", + KRPM: "RPM", + KTPM: "TPM", + KStatus: "状态", + KTotalTime: "总耗时", + KTTFT: "TTFT", + KOutputTPS: "输出TPS", + KToken: "Token", + KCache: "缓存", + KError: "错误", + KDNS: "DNS", + KTCPConnect: "TCP 连接", + KTLSHandshake: "TLS 握手", + KTargetIP: "目标 IP", + + // Status values + KRunning: "运行中", + KCompleted: "已完成", + KRunFailed: "运行失败", + KStopped: "已停止", + KWaitingStatus: "等待数据", + + // Dashboard / progress + KProgress: "进度", + KSuccessCount: "成功", + KFailureCount: "失败", + KWaitingDots: "等待中...", + KRamp: "爬坡", + KPerLevel: "每级", + KStopCondLabel: "停止", + KTurboMode: "Turbo 模式", + KStandardMode: "标准", + KTurboMonitor: "Turbo 探测监控", + KTurboModeMeta: "Turbo 模式", + KSuccessRateFmt: "成功率 %.1f%%", + KTurboCurLevelFmt: "当前级别实时指标 [并发 = %d]", + KTurboDashSuffix: " %d/%d 并发 %d 进度 %s 级", + + // Layout + KWindowTooSmall: "窗口过小 ↔ 请放大终端", + KWaitingData: "等待数据...", + KNotRecorded: "(未记录)", + KScrollMore: "↑↓ 滚动查看完整内容", + KTerminalLabel: "终端", + KInProgress: "进行中", + KNoHotkeys: "当前页暂无快捷操作", + KStdMonitorTitle: "标准运行监控", + KStdMonitorSubtitle: "实时查看运行进度、吸吐和单请求明细", + KTurboSubtitle: "观察并发爬坡过程、级别指标与稳定区间", + KTaskListSubtitle: "创建任务、运行压测、查看执行记录与导出报告", + KTaskDetailSubtitle: "查看任务配置、当前运行状态与历史记录", + KReqDetailSubtitle: "查看单次请求的耗时、网络阶段和完整报文", + KConcFmt: "并发%d", + + // Fields + KProtocol: "协议", + KEndpoint: "接口", + KProxy: "代理", + KModel: "模型", + KMode: "模式", + KConcurrency: "并发", + KStepLabel: "步进", + KRequests: "请求", + KTimeout: "超时", + KStream: "流式", + KPromptLabel: "Prompt", + KNoRunRecords: "暂无运行记录", + KRecordDetails: "记录详情", + KStart: "开始", + KEnd: "结束", + KElapsed: "耗时", + KErrorSummary: "错误摘要", + KRequestBody: "请求体 (Request Body)", + KResponseBody: "响应体 (Response Body)", + + // Table column headers + KColTime: "时间", + KTime: "时间", + KColMode: "模式", + KColProtocol: "协议", + KColSuccessRate: "成功率", + KColCacheHit: "缓存命中", + KColAvgTTFT: "均值TTFT", + KColAvgTPS: "均值TPS", + KColLevel: "级别", + KColInput: "输入", + KColOutput: "输出", + + // TaskList + KTaskName: "任务名称", + KTaskID: "任务 ID", + KLastRun: "上次运行", + KIrreversible: "此操作不可恢复,任务的历史运行记录将一并删除。", + KNoTasks: "暂无任务 按 [a] 新建第一个任务", + KRunHistory: "历史运行记录", + KTaskCenter: "任务中心", + KNoRunHistory: "暂无运行历史", + KConfirmDeletePrompt: "确认删除任务?", + + // Proxy + KExSOCKS5: "示例: socks5://127.0.0.1:1080", + KExSSH: "示例: ssh://user@host:22", + KExHTTP: "示例: http://127.0.0.1:7890", + KProxySubtitle: "设置全局 HTTP 代理,适用于所有任务的请求。留空则使用系统环境变量或直连。", + KProxySaveHint: "配置保存至 ~/.ait/config.json,重启无需重新输入。", + KGlobalConfig: "全局配置", + KProxyType: "代理类型", + KProxyURL: "代理地址", + + // Help page + KHelpTitle: "帮助", + KHelpSubtitle: "AIT — AI 接口压测工具概念说明与操作指南", + KHelpMeta: "帮助", + KHelpSecConcepts: "核心概念", + KHelpSecMetrics: "性能指标", + KHelpSecProtocols: "协议支持", + KHelpSecGlobal: "快捷键 — 全局", + KHelpSecTaskList: "快捷键 — 任务列表", + KHelpSecTaskDetail: "快捷键 — 任务详情", + KHelpSecDashboard: "快捷键 — 运行仪表盘", + KHelpSecExport: "报告导出", + + KHelpTermTask: "任务 (Task)", + KHelpDescTask: "一组压测配置的集合,包含目标接口、模型、并发数、请求数等参数。任务可多次运行,每次运行独立记录结果。", + KHelpTermRun: "运行 (Run)", + KHelpDescRun: "任务的一次具体执行。每次运行产生独立的指标数据和请求记录,可导出为 JSON/CSV 报告。", + KHelpTermStandard: "标准模式", + KHelpDescStandard: "以固定并发数执行所有请求,适合衡量稳定负载下的接口性能。", + KHelpTermTurboMode: "Turbo 模式", + KHelpDescTurboMode: "自动从低并发逐步爬坡,找出接口在保持成功率要求下能承受的最大稳定并发数。", + + KHelpTermTPS: "TPS", + KHelpDescTPS: "Tokens Per Second,每秒输出 Token 数,衡量模型的文本生成速率。", + KHelpTermAvgTPS: "均值TPS", + KHelpDescAvgTPS: "本次运行中所有请求的 TPS 均值,反映整体吞吐水平。", + KHelpTermTTFT: "TTFT", + KHelpDescTTFT: "Time To First Token,从发送请求到收到第一个 Token 的耗时,衡量模型响应延迟。", + KHelpTermAvgTTFT: "均值TTFT", + KHelpDescAvgTTFT: "本次运行中所有请求的 TTFT 均值。", + + KHelpTermSuccessRate: "成功率", + KHelpDescSuccessRate: "成功完成的请求数占总请求数的百分比。失败包括超时、HTTP 错误、模型返回错误等。", + KHelpTermCacheHit: "缓存命中", + KHelpDescCacheHit: "请求中使用了 KV 缓存(Prompt Cache)的比例。命中缓存可显著降低 TTFT 和推理成本。该指标为二值统计:单次请求若有任何 Token 命中缓存则计为命中。", + KHelpTermConcurrencyTurbo: "并发(Turbo)", + KHelpDescConcurrencyTurbo: "Turbo 模式下找到的最大稳定并发数,即在满足最低成功率要求的前提下能同时维持的请求数。", + + KHelpTermOpenAI: "OpenAI", + KHelpDescOpenAI: "兼容 OpenAI Chat Completions API(/v1/chat/completions),支持流式和非流式响应。", + KHelpTermAnthropic: "Anthropic", + KHelpDescAnthropic: "兼容 Anthropic Messages API(/v1/messages),支持流式和非流式响应。", + + KHelpTermQuit: "q / Ctrl+C", + KHelpDescQuit: "退出程序。", + KHelpTermQuestionMark: "?", + KHelpDescQuestionMark: "打开此帮助页。", + KHelpTermBack: "b / Esc", + KHelpDescBack: "返回上一页。", + + KHelpTermSelectTask: "↑↓ / j k", + KHelpDescSelectTask: "选择任务。", + KHelpTermEnterDetail: "Enter", + KHelpDescEnterDetail: "进入任务详情页。", + KHelpTermRunTask: "r", + KHelpDescRunTask: "立即运行选中任务。", + KHelpTermStopTask: "s", + KHelpDescStopTask: "停止正在运行的任务(仅任务运行中可用)。", + KHelpTermNewTask: "a", + KHelpDescNewTask: "新建任务(打开向导)。", + KHelpTermEditTask: "e", + KHelpDescEditTask: "编辑选中任务配置。", + KHelpTermDeleteTask: "d", + KHelpDescDeleteTask: "删除选中任务(需确认)。", + KHelpTermCopyTask: "y", + KHelpDescCopyTask: "复制选中任务(生成副本)。", + KHelpTermProxy: "p", + KHelpDescProxy: "打开代理配置页。", + + KHelpTermSelectHistory: "↑↓ / j k", + KHelpDescSelectHistory: "在历史运行记录中选择条目。", + KHelpTermEnterDash: "Enter", + KHelpDescEnterDash: "查看选中运行的仪表盘;若任务正在运行,进入实时仪表盘。", + KHelpTermRunAgain: "r", + KHelpDescRunAgain: "再次运行该任务(无正在运行的实例时可用)。", + KHelpTermExport: "g", + KHelpDescExport: "将选中的历史运行导出为 JSON 报告。", + KHelpTermEditConfig: "e", + KHelpDescEditConfig: "编辑任务配置。", + KHelpTermCopyTask2: "y", + KHelpDescCopyTask2: "复制任务。", + KHelpTermDeleteTask2: "d", + KHelpDescDeleteTask2: "删除任务。", + + KHelpTermSelectReq: "↑↓ / j k", + KHelpDescSelectReq: "选择请求条目。", + KHelpTermViewReq: "Enter", + KHelpDescViewReq: "查看选中请求的详情(耗时、Token、响应体等)。", + KHelpTermStopDash: "s", + KHelpDescStopDash: "停止正在运行的任务。", + KHelpTermGenerateReport: "r", + KHelpDescGenerateReport: "生成 JSON 报告(运行结束后可用)。", + KHelpTermBackDash: "b / Esc", + KHelpDescBackDash: "返回任务详情页。", + + KHelpTermJSONReport: "JSON 报告", + KHelpDescJSONReport: "完整记录每次请求的所有指标、请求/响应体,适合程序化分析。", + KHelpTermCSVReport: "CSV 报告", + KHelpDescCSVReport: "表格形式的汇总数据,可直接在电子表格中打开。报告默认保存在当前工作目录。", + + // Wizard fields + KWzTaskName: "任务名称", + KWzProtocol: "协议类型", + KWzEndpoint: "接口地址", + KWzAPIKey: "API 密钥", + KWzTestModel: "测试模型", + KWzTestMode: "测试模式", + KWzTurboMode: "Turbo 模式", + KWzStandardMode: "标准模式", + KWzConcurrency: "并发数", + KWzTotalRequests: "请求总数", + KWzTimeoutSecs: "超时(秒)", + KWzInitConc: "初始并发", + KWzMaxConc: "最大并发", + KWzStepSize: "步进值", + KWzLevelReqs: "每级请求数", + KWzMinSuccessRate: "最低成功率", + KWzStreamMode: "流式模式", + KWzInputMode: "输入方式", + KWzInputDirect: "直接输入", + KWzInputFile: "文件", + KWzInputGenerated: "按长度生成", + KWzInputRaw: "RAW 请求体", + KWzPromptConfig: "Prompt 配置", + KWzSelectModeHint: "选择压测模式,并补全并发与 Prompt 参数。", + KWzTurboModeLabel: "Turbo 模式", + KWzStepFmt: "步骤 %d/3", + KWzStep1Label: "1 基本信息", + KWzStep2Label: "2 测试参数", + KWzStep3Label: "3 确认保存", + KWzStep1Desc: "配置任务名称、模型协议和连接信息。", + KWzStep2Desc: "选择压测模式,并补全并发与 Prompt 参数。", + KWzStep3Desc: "保存前快速检查关键配置。", + KWzUntitled: "未命名任务", + KWzNotFilled: "未填写", + KWzExecParams: "执行参数", + KWzConcurrencyRamp: "并发爬坡", + KWzStopCondition: "停止条件", + KWzTimeoutLabel: "超时", + KWzContentSummary: "内容摘要", + KWzBodyBytes: "Body 字节数", + KWzSaveLocation: "保存位置", + KWzPromptSection: "Prompt", + KWzHintDirect: "直接粘贴或输入 Prompt 文本,所有请求共享同一段内容", + KWzHintFile: "从文件读取 Prompt,支持通配符匹配多个文件(请求按文件轮换)", + KWzHintRaw: "粘贴完整的 HTTP 请求 JSON Body,将跳过参数组装直接发送", + KWzHintCacheToken: "提示:大多数服务需要 ≥ 1024 tokens 才能命中缓存", + KWzHintRawBody: "提示:粘贴 API 请求的完整 JSON Body,将直接作为 HTTP 请求体发送", + KWzJSONBody: "JSON Body", + KWzPromptLabelShort: "Prompt", + KWzRAWBody: "RAW 请求体", + KWzFileSummary: "文件", + KWzGeneratedFmt: "生成 %d 字符", + KWzPromptContent: "内容", + KWzNoConfirmItems: "暂无确认项", + KWzConfirmRange: "确认项 %d-%d/%d", + KWzConfirmTotal: "共 %d 项待确认", + KWzNoFields: "暂无配置项", + KWzFieldProgress: "当前字段 %d/%d", + + // Misc + KEnabled: "开启", + KDisabled: "关闭", + KFileSummaryPfx: "文件: ", + KNotSet: "(未设置)", + KJustNow: "刚刚", + KMinutesAgoFmt: "%d 分钟前", + KHoursAgoFmt: "%d 小时前", + KDaysAgoFmt: "%d 天前", + }, + EN: { + // Hotkeys + KHelp: "Help", + KViewDetails: "View Details", + KRun: "Run", + KNewTask: "New Task", + KEdit: "Edit", + KDelete: "Delete", + KCopy: "Copy", + KCopyTask: "Copy Task", + KProxyConfig: "Proxy Config", + KStop: "Stop", + KSelectRecord: "Select Record", + KViewRunDetails: "View Run Details", + KRunAgain: "Run Again", + KExportJSONReport: "Export JSON Report", + KExportHistoryJSON: "Export History JSON", + KGoToLiveDash: "Live Dashboard", + KSwitchField: "Switch Field", + KSwitchProtocol: "Switch Protocol", + KNextStep: "Next", + KBackToList: "Back to List", + KSwitchOption: "Switch Option", + KGoBack: "Go Back", + KScroll: "Scroll", + KPageTurn: "Page", + KSave: "Save", + KSaveAndRun: "Save & Run", + KBackToEdit: "Back to Edit", + KGenerateReport: "Generate Report", + KViewRequest: "View Request", + KSelectRequest: "Select Request", + KViewLevelReqs: "View Level Requests", + KSelectItem: "Select", + KBackToDash: "Back to Dashboard", + KPrevNextReq: "Prev/Next Request", + KSwitchType: "Switch Type", + KClear: "Clear", + KTopBottom: "Top/Bottom", + KConfirmDelete: "Confirm Delete", + KCancel: "Cancel", + + // Hints + KHintQuit: "[q] Quit", + KHintCtrlCQuit: "[Ctrl+C] Quit", + KHintBack: "[b/Esc] Back", + KHintEscBack: "[Esc] Back", + KHintGoBack: "[b/Esc] Go Back", + KHintSelect: "[↑↓] Select", + KHintNew: "[a] New", + + // Metric labels + KSuccessRate: "Success Rate", + KAvgTPS: "Avg TPS", + KAvgTTFT: "Avg TTFT", + KCacheHit: "Cache Hit", + KRPM: "RPM", + KTPM: "TPM", + KStatus: "Status", + KTotalTime: "Total Time", + KTTFT: "TTFT", + KOutputTPS: "Output TPS", + KToken: "Token", + KCache: "Cache", + KError: "Error", + KDNS: "DNS", + KTCPConnect: "TCP Connect", + KTLSHandshake: "TLS Handshake", + KTargetIP: "Target IP", + + // Status values + KRunning: "Running", + KCompleted: "Done", + KRunFailed: "Failed", + KStopped: "Stopped", + KWaitingStatus: "Waiting", + + // Dashboard / progress + KProgress: "Progress", + KSuccessCount: "Success", + KFailureCount: "Failed", + KWaitingDots: "Waiting...", + KRamp: "Ramp", + KPerLevel: "Per Level", + KStopCondLabel: "Stop", + KTurboMode: "Turbo Mode", + KStandardMode: "Standard", + KTurboMonitor: "Turbo Probe Monitor", + KTurboModeMeta: "Turbo Mode", + KSuccessRateFmt: "Success %.1f%%", + KTurboCurLevelFmt: "Current Level Metrics [Concurrency = %d]", + KTurboDashSuffix: " %d/%d Level %d Progress %s", + + // Layout + KWindowTooSmall: "Terminal too small ↔ please resize", + KWaitingData: "Waiting for data...", + KNotRecorded: "(not recorded)", + KScrollMore: "↑↓ scroll to view full content", + KTerminalLabel: "Terminal", + KInProgress: "In Progress", + KNoHotkeys: "No shortcuts on this page", + KStdMonitorTitle: "Standard Run Monitor", + KStdMonitorSubtitle: "Live view of run progress, throughput and per-request details", + KTurboSubtitle: "Observe concurrency ramp, level metrics and stable range", + KTaskListSubtitle: "Create tasks, run benchmarks, view run history and export reports", + KTaskDetailSubtitle: "View task configuration, current run state and history", + KReqDetailSubtitle: "View latency, network phases and full payload of a single request", + KConcFmt: "Conc %d", + + // Fields + KProtocol: "Protocol", + KEndpoint: "Endpoint", + KProxy: "Proxy", + KModel: "Model", + KMode: "Mode", + KConcurrency: "Concurrency", + KStepLabel: "Step", + KRequests: "Requests", + KTimeout: "Timeout", + KStream: "Stream", + KPromptLabel: "Prompt", + KNoRunRecords: "No run records", + KRecordDetails: "Run Details", + KStart: "Start", + KEnd: "End", + KElapsed: "Elapsed", + KErrorSummary: "Error Summary", + KRequestBody: "Request Body", + KResponseBody: "Response Body", + + // Table column headers + KColTime: "Time", + KTime: "Time", + KColMode: "Mode", + KColProtocol: "Protocol", + KColSuccessRate: "Success%", + KColCacheHit: "Cache Hit", + KColAvgTTFT: "Avg TTFT", + KColAvgTPS: "Avg TPS", + KColLevel: "Level", + KColInput: "Input", + KColOutput: "Output", + + // TaskList + KTaskName: "Task Name", + KTaskID: "Task ID", + KLastRun: "Last Run", + KIrreversible: "This action is irreversible. All run history will also be deleted.", + KNoTasks: "No tasks · Press [a] to create the first task", + KRunHistory: "Run History", + KTaskCenter: "Tasks", + KNoRunHistory: "No run history", + KConfirmDeletePrompt: "Delete this task?", + + // Proxy + KExSOCKS5: "Example: socks5://127.0.0.1:1080", + KExSSH: "Example: ssh://user@host:22", + KExHTTP: "Example: http://127.0.0.1:7890", + KProxySubtitle: "Set a global HTTP proxy for all task requests. Leave blank to use system env or direct connection.", + KProxySaveHint: "Config saved to ~/.ait/config.json. No need to re-enter after restart.", + KGlobalConfig: "Global Config", + KProxyType: "Proxy Type", + KProxyURL: "Proxy URL", + + // Help page + KHelpTitle: "Help", + KHelpSubtitle: "AIT — AI Load Testing Tool: Concepts & Usage Guide", + KHelpMeta: "Help", + KHelpSecConcepts: "Core Concepts", + KHelpSecMetrics: "Performance Metrics", + KHelpSecProtocols: "Protocol Support", + KHelpSecGlobal: "Hotkeys — Global", + KHelpSecTaskList: "Hotkeys — Task List", + KHelpSecTaskDetail: "Hotkeys — Task Detail", + KHelpSecDashboard: "Hotkeys — Dashboard", + KHelpSecExport: "Report Export", + + KHelpTermTask: "Task", + KHelpDescTask: "A set of load test configurations including target endpoint, model, concurrency, and request count. A task can be run multiple times, each run recorded independently.", + KHelpTermRun: "Run", + KHelpDescRun: "A single execution of a task. Each run produces independent metric data and request records, exportable as JSON/CSV reports.", + KHelpTermStandard: "Standard Mode", + KHelpDescStandard: "Executes all requests at a fixed concurrency level, ideal for measuring interface performance under steady load.", + KHelpTermTurboMode: "Turbo Mode", + KHelpDescTurboMode: "Automatically ramps up concurrency to find the maximum stable concurrency the interface can sustain while meeting the success rate requirement.", + + KHelpTermTPS: "TPS", + KHelpDescTPS: "Tokens Per Second — output token generation rate of the model.", + KHelpTermAvgTPS: "Avg TPS", + KHelpDescAvgTPS: "Mean TPS across all requests in this run, reflecting overall throughput.", + KHelpTermTTFT: "TTFT", + KHelpDescTTFT: "Time To First Token — latency from sending the request to receiving the first token.", + KHelpTermAvgTTFT: "Avg TTFT", + KHelpDescAvgTTFT: "Mean TTFT across all requests in this run.", + + KHelpTermSuccessRate: "Success Rate", + KHelpDescSuccessRate: "Percentage of requests that completed successfully. Failures include timeouts, HTTP errors, and model errors.", + KHelpTermCacheHit: "Cache Hit", + KHelpDescCacheHit: "Ratio of requests that used KV cache (Prompt Cache). Cache hits significantly reduce TTFT and inference cost. Binary metric: a request counts as a hit if any tokens were served from cache.", + KHelpTermConcurrencyTurbo: "Concurrency (Turbo)", + KHelpDescConcurrencyTurbo: "Maximum stable concurrency found by Turbo mode — the number of simultaneous requests sustainable while meeting the minimum success rate.", + + KHelpTermOpenAI: "OpenAI", + KHelpDescOpenAI: "Compatible with OpenAI Chat Completions API (/v1/chat/completions), supporting both streaming and non-streaming responses.", + KHelpTermAnthropic: "Anthropic", + KHelpDescAnthropic: "Compatible with Anthropic Messages API (/v1/messages), supporting both streaming and non-streaming responses.", + + KHelpTermQuit: "q / Ctrl+C", + KHelpDescQuit: "Quit the program.", + KHelpTermQuestionMark: "?", + KHelpDescQuestionMark: "Open this help page.", + KHelpTermBack: "b / Esc", + KHelpDescBack: "Go back to the previous page.", + + KHelpTermSelectTask: "↑↓ / j k", + KHelpDescSelectTask: "Select a task.", + KHelpTermEnterDetail: "Enter", + KHelpDescEnterDetail: "Enter task detail page.", + KHelpTermRunTask: "r", + KHelpDescRunTask: "Run the selected task immediately.", + KHelpTermStopTask: "s", + KHelpDescStopTask: "Stop the running task (only available while running).", + KHelpTermNewTask: "a", + KHelpDescNewTask: "Create a new task (opens wizard).", + KHelpTermEditTask: "e", + KHelpDescEditTask: "Edit the selected task configuration.", + KHelpTermDeleteTask: "d", + KHelpDescDeleteTask: "Delete the selected task (requires confirmation).", + KHelpTermCopyTask: "y", + KHelpDescCopyTask: "Copy the selected task (creates a duplicate).", + KHelpTermProxy: "p", + KHelpDescProxy: "Open proxy configuration page.", + + KHelpTermSelectHistory: "↑↓ / j k", + KHelpDescSelectHistory: "Select an entry in the run history.", + KHelpTermEnterDash: "Enter", + KHelpDescEnterDash: "View the dashboard for the selected run; enters live dashboard if the task is running.", + KHelpTermRunAgain: "r", + KHelpDescRunAgain: "Run the task again (available when no instance is running).", + KHelpTermExport: "g", + KHelpDescExport: "Export the selected run as a JSON report.", + KHelpTermEditConfig: "e", + KHelpDescEditConfig: "Edit the task configuration.", + KHelpTermCopyTask2: "y", + KHelpDescCopyTask2: "Copy the task.", + KHelpTermDeleteTask2: "d", + KHelpDescDeleteTask2: "Delete the task.", + + KHelpTermSelectReq: "↑↓ / j k", + KHelpDescSelectReq: "Select a request entry.", + KHelpTermViewReq: "Enter", + KHelpDescViewReq: "View request details (latency, tokens, response body, etc.).", + KHelpTermStopDash: "s", + KHelpDescStopDash: "Stop the running task.", + KHelpTermGenerateReport: "r", + KHelpDescGenerateReport: "Generate a JSON report (available after run completes).", + KHelpTermBackDash: "b / Esc", + KHelpDescBackDash: "Return to task detail page.", + + KHelpTermJSONReport: "JSON Report", + KHelpDescJSONReport: "Complete record of all metrics, request/response bodies for each request. Suitable for programmatic analysis.", + KHelpTermCSVReport: "CSV Report", + KHelpDescCSVReport: "Summary data in tabular form, openable directly in spreadsheets. Reports are saved in the current working directory by default.", + + // Wizard fields + KWzTaskName: "Task Name", + KWzProtocol: "Protocol", + KWzEndpoint: "Endpoint URL", + KWzAPIKey: "API Key", + KWzTestModel: "Model", + KWzTestMode: "Test Mode", + KWzTurboMode: "Turbo Mode", + KWzStandardMode: "Standard Mode", + KWzConcurrency: "Concurrency", + KWzTotalRequests: "Total Requests", + KWzTimeoutSecs: "Timeout (s)", + KWzInitConc: "Init Concurrency", + KWzMaxConc: "Max Concurrency", + KWzStepSize: "Step Size", + KWzLevelReqs: "Requests/Level", + KWzMinSuccessRate: "Min Success Rate", + KWzStreamMode: "Stream Mode", + KWzInputMode: "Input Mode", + KWzInputDirect: "Direct Input", + KWzInputFile: "File", + KWzInputGenerated: "Generated", + KWzInputRaw: "RAW Body", + KWzPromptConfig: "Prompt Config", + KWzSelectModeHint: "Select load test mode, then fill in concurrency and Prompt parameters.", + KWzTurboModeLabel: "Turbo Mode", + KWzStepFmt: "Step %d/3", + KWzStep1Label: "1 Basic Info", + KWzStep2Label: "2 Parameters", + KWzStep3Label: "3 Confirm", + KWzStep1Desc: "Configure task name, protocol, and connection info.", + KWzStep2Desc: "Choose test mode and fill in concurrency and prompt parameters.", + KWzStep3Desc: "Quick review before saving.", + KWzUntitled: "Untitled Task", + KWzNotFilled: "(empty)", + KWzExecParams: "Execution Parameters", + KWzConcurrencyRamp: "Concurrency Ramp", + KWzStopCondition: "Stop Condition", + KWzTimeoutLabel: "Timeout", + KWzContentSummary: "Content Summary", + KWzBodyBytes: "Body Bytes", + KWzSaveLocation: "Save Location", + KWzPromptSection: "Prompt", + KWzHintDirect: "Paste or type Prompt text directly. All requests share the same content.", + KWzHintFile: "Read Prompt from file(s). Supports glob patterns; requests rotate through matching files.", + KWzHintRaw: "Paste a complete HTTP request JSON body. Parameter assembly is skipped and the body is sent as-is.", + KWzHintCacheToken: "Tip: most services require ≥ 1024 tokens to trigger cache hits.", + KWzHintRawBody: "Tip: paste the full JSON body of an API request. It will be sent directly as the HTTP request body.", + KWzJSONBody: "JSON Body", + KWzPromptLabelShort: "Prompt", + KWzRAWBody: "RAW Body", + KWzFileSummary: "File", + KWzGeneratedFmt: "%d chars", + KWzPromptContent: "Content", + KWzNoConfirmItems: "No confirm items", + KWzConfirmRange: "Items %d-%d/%d", + KWzConfirmTotal: "%d items to confirm", + KWzNoFields: "No fields", + KWzFieldProgress: "Field %d/%d", + + // Misc + KEnabled: "On", + KDisabled: "Off", + KFileSummaryPfx: "File: ", + KNotSet: "(empty)", + KJustNow: "just now", + KMinutesAgoFmt: "%d min ago", + KHoursAgoFmt: "%d hr ago", + KDaysAgoFmt: "%d days ago", + }, +} + +// T returns the translation for key k in the active language. +// Falls back to ZH if the key is missing in the active language. +func T(k Key) string { + l := Active() + if m := translations[l]; m != nil { + if s, ok := m[k]; ok { + return s + } + } + if s, ok := translations[ZH][k]; ok { + return s + } + return "" +} + +// DisplayWidth returns the display column width of s, +// counting CJK characters as 2 columns and ASCII as 1. +func DisplayWidth(s string) int { + return runewidth.StringWidth(s) +} diff --git a/internal/tui/pages/contextbar.go b/internal/tui/pages/contextbar.go index 3c27c07..e5ab2a7 100644 --- a/internal/tui/pages/contextbar.go +++ b/internal/tui/pages/contextbar.go @@ -1,5 +1,7 @@ package pages +import "github.com/yinxulai/ait/internal/i18n" + // HotkeyItem 是底部 Hotkeys 区中的一个展示项。 type HotkeyItem struct { Key string // 如 "Enter"、"r"、"↑↓" @@ -30,181 +32,181 @@ func HotkeyTexts(texts ...string) []HotkeyItem { // Hotkeys_TaskList_Normal 普通任务选中时的 Hotkeys。 func Hotkeys_TaskList_Normal() []HotkeyItem { return []HotkeyItem{ - HotkeyAction("Enter", "查看详情"), - HotkeyAction("r", "运行"), - HotkeyAction("a", "新建任务"), - HotkeyAction("e", "编辑"), - HotkeyAction("d", "删除"), - HotkeyAction("y", "复制"), - HotkeyAction("p", "代理配置"), + HotkeyAction("Enter", i18n.T(i18n.KViewDetails)), + HotkeyAction("r", i18n.T(i18n.KRun)), + HotkeyAction("a", i18n.T(i18n.KNewTask)), + HotkeyAction("e", i18n.T(i18n.KEdit)), + HotkeyAction("d", i18n.T(i18n.KDelete)), + HotkeyAction("y", i18n.T(i18n.KCopy)), + HotkeyAction("p", i18n.T(i18n.KProxyConfig)), } } // Hotkeys_TaskList_Running 运行中任务选中时的 Hotkeys。 func Hotkeys_TaskList_Running() []HotkeyItem { return []HotkeyItem{ - HotkeyAction("Enter", "查看详情"), - HotkeyAction("s", "停止"), - HotkeyAction("y", "复制"), - HotkeyAction("p", "代理配置"), + HotkeyAction("Enter", i18n.T(i18n.KViewDetails)), + HotkeyAction("s", i18n.T(i18n.KStop)), + HotkeyAction("y", i18n.T(i18n.KCopy)), + HotkeyAction("p", i18n.T(i18n.KProxyConfig)), } } // Hotkeys_TaskDetail_NoHistory 任务详情页,无运行记录时。 func Hotkeys_TaskDetail_NoHistory() []HotkeyItem { return []HotkeyItem{ - HotkeyAction("r", "运行"), - HotkeyAction("e", "编辑"), - HotkeyAction("y", "复制"), - HotkeyAction("d", "删除"), + HotkeyAction("r", i18n.T(i18n.KRun)), + HotkeyAction("e", i18n.T(i18n.KEdit)), + HotkeyAction("y", i18n.T(i18n.KCopyTask)), + HotkeyAction("d", i18n.T(i18n.KDelete)), } } // Hotkeys_TaskDetail_HasHistory 任务详情页,有运行记录且未运行时。 func Hotkeys_TaskDetail_HasHistory() []HotkeyItem { return []HotkeyItem{ - HotkeyAction("↑↓", "选择记录"), - HotkeyAction("Enter", "查看运行详情"), - HotkeyAction("r", "再次运行"), - HotkeyAction("g", "导出 JSON 报告"), - HotkeyAction("e", "编辑"), - HotkeyAction("y", "复制任务"), - HotkeyAction("d", "删除"), + HotkeyAction("↑↓", i18n.T(i18n.KSelectRecord)), + HotkeyAction("Enter", i18n.T(i18n.KViewRunDetails)), + HotkeyAction("r", i18n.T(i18n.KRunAgain)), + HotkeyAction("g", i18n.T(i18n.KExportJSONReport)), + HotkeyAction("e", i18n.T(i18n.KEdit)), + HotkeyAction("y", i18n.T(i18n.KCopyTask)), + HotkeyAction("d", i18n.T(i18n.KDelete)), } } // Hotkeys_TaskDetail_Running 任务详情页,任务正在运行时。 func Hotkeys_TaskDetail_Running() []HotkeyItem { return []HotkeyItem{ - HotkeyAction("↑↓", "选择记录"), - HotkeyAction("Enter", "进入运行中仪表盘"), - HotkeyAction("g", "导出历史 JSON"), - HotkeyAction("e", "编辑"), - HotkeyAction("y", "复制任务"), + HotkeyAction("↑↓", i18n.T(i18n.KSelectRecord)), + HotkeyAction("Enter", i18n.T(i18n.KGoToLiveDash)), + HotkeyAction("g", i18n.T(i18n.KExportHistoryJSON)), + HotkeyAction("e", i18n.T(i18n.KEdit)), + HotkeyAction("y", i18n.T(i18n.KCopyTask)), } } // Hotkeys_Wizard_Step1 创建任务页,第 1 步。 func Hotkeys_Wizard_Step1() []HotkeyItem { return []HotkeyItem{ - HotkeyAction("Tab/↑↓", "切换字段"), - HotkeyAction("←→", "切换协议"), - HotkeyAction("Enter", "下一步"), - HotkeyAction("Esc", "返回列表"), + HotkeyAction("Tab/↑↓", i18n.T(i18n.KSwitchField)), + HotkeyAction("←→", i18n.T(i18n.KSwitchProtocol)), + HotkeyAction("Enter", i18n.T(i18n.KNextStep)), + HotkeyAction("Esc", i18n.T(i18n.KBackToList)), } } // Hotkeys_Wizard_Step2 创建任务页,第 2 步。 func Hotkeys_Wizard_Step2() []HotkeyItem { return []HotkeyItem{ - HotkeyAction("Tab/↑↓", "切换字段"), - HotkeyAction("←→", "切换选项"), - HotkeyAction("Enter", "下一步"), - HotkeyAction("Esc", "返回上一步"), + HotkeyAction("Tab/↑↓", i18n.T(i18n.KSwitchField)), + HotkeyAction("←→", i18n.T(i18n.KSwitchOption)), + HotkeyAction("Enter", i18n.T(i18n.KNextStep)), + HotkeyAction("Esc", i18n.T(i18n.KGoBack)), } } // Hotkeys_Wizard_Step3 创建任务页,第 3 步。 func Hotkeys_Wizard_Step3() []HotkeyItem { return []HotkeyItem{ - HotkeyAction("↑↓", "滚动"), - HotkeyAction("PgUp/PgDn", "翻页"), - HotkeyAction("Enter", "保存"), - HotkeyAction("r", "保存并运行"), - HotkeyAction("Esc", "返回修改"), + HotkeyAction("↑↓", i18n.T(i18n.KScroll)), + HotkeyAction("PgUp/PgDn", i18n.T(i18n.KPageTurn)), + HotkeyAction("Enter", i18n.T(i18n.KSave)), + HotkeyAction("r", i18n.T(i18n.KSaveAndRun)), + HotkeyAction("Esc", i18n.T(i18n.KBackToEdit)), } } // Hotkeys_Dashboard_Running_NoSel 标准仪表盘运行中,无选中请求时。 func Hotkeys_Dashboard_Running_NoSel() []HotkeyItem { return []HotkeyItem{ - HotkeyAction("s", "停止"), - HotkeyAction("b/Esc", "返回列表"), + HotkeyAction("s", i18n.T(i18n.KStop)), + HotkeyAction("b/Esc", i18n.T(i18n.KBackToList)), } } // Hotkeys_Dashboard_Done_NoSel 标准仪表盘完成后,无选中请求时。 func Hotkeys_Dashboard_Done_NoSel() []HotkeyItem { return []HotkeyItem{ - HotkeyAction("r", "生成报告"), - HotkeyAction("b/Esc", "返回列表"), + HotkeyAction("r", i18n.T(i18n.KGenerateReport)), + HotkeyAction("b/Esc", i18n.T(i18n.KBackToList)), } } // Hotkeys_Dashboard_Running_Sel 标准仪表盘运行中,已选中请求时。 func Hotkeys_Dashboard_Running_Sel() []HotkeyItem { return []HotkeyItem{ - HotkeyAction("Enter", "查看请求详情"), - HotkeyAction("↑↓", "选择请求"), - HotkeyAction("s", "停止"), + HotkeyAction("Enter", i18n.T(i18n.KViewRequest)), + HotkeyAction("↑↓", i18n.T(i18n.KSelectRequest)), + HotkeyAction("s", i18n.T(i18n.KStop)), } } // Hotkeys_Dashboard_Done_Sel 标准仪表盘完成后,已选中请求时。 func Hotkeys_Dashboard_Done_Sel() []HotkeyItem { return []HotkeyItem{ - HotkeyAction("Enter", "查看请求详情"), - HotkeyAction("↑↓", "选择请求"), + HotkeyAction("Enter", i18n.T(i18n.KViewRequest)), + HotkeyAction("↑↓", i18n.T(i18n.KSelectRequest)), } } // Hotkeys_TurboDash_Running_NoSel Turbo 仪表盘运行中,无选中级别时。 func Hotkeys_TurboDash_Running_NoSel() []HotkeyItem { return []HotkeyItem{ - HotkeyAction("s", "停止"), - HotkeyAction("b/Esc", "返回列表"), + HotkeyAction("s", i18n.T(i18n.KStop)), + HotkeyAction("b/Esc", i18n.T(i18n.KBackToList)), } } // Hotkeys_TurboDash_Done_NoSel Turbo 仪表盘完成后,无选中级别时。 func Hotkeys_TurboDash_Done_NoSel() []HotkeyItem { return []HotkeyItem{ - HotkeyAction("r", "生成报告"), - HotkeyAction("b/Esc", "返回列表"), + HotkeyAction("r", i18n.T(i18n.KGenerateReport)), + HotkeyAction("b/Esc", i18n.T(i18n.KBackToList)), } } // Hotkeys_TurboDash_Running_Sel Turbo 仪表盘运行中,已选中级别时。 func Hotkeys_TurboDash_Running_Sel() []HotkeyItem { return []HotkeyItem{ - HotkeyAction("Enter", "查看该级别请求"), - HotkeyAction("↑↓", "选择"), - HotkeyAction("s", "停止"), + HotkeyAction("Enter", i18n.T(i18n.KViewLevelReqs)), + HotkeyAction("↑↓", i18n.T(i18n.KSelectItem)), + HotkeyAction("s", i18n.T(i18n.KStop)), } } // Hotkeys_TurboDash_Done_Sel Turbo 仪表盘完成后,已选中级别时。 func Hotkeys_TurboDash_Done_Sel() []HotkeyItem { return []HotkeyItem{ - HotkeyAction("Enter", "查看该级别请求"), - HotkeyAction("↑↓", "选择"), + HotkeyAction("Enter", i18n.T(i18n.KViewLevelReqs)), + HotkeyAction("↑↓", i18n.T(i18n.KSelectItem)), } } // Hotkeys_ReqDetail 请求详情页。 func Hotkeys_ReqDetail() []HotkeyItem { return []HotkeyItem{ - HotkeyAction("b/Esc", "返回仪表盘"), - HotkeyAction("↑↓", "滚动"), - HotkeyAction("←→", "上/下一条请求"), + HotkeyAction("b/Esc", i18n.T(i18n.KBackToDash)), + HotkeyAction("↑↓", i18n.T(i18n.KScroll)), + HotkeyAction("←→", i18n.T(i18n.KPrevNextReq)), } } // Hotkeys_ProxyConfig 代理配置页。 func Hotkeys_ProxyConfig() []HotkeyItem { return []HotkeyItem{ - HotkeyAction("Tab/↑↓", "切换字段"), - HotkeyAction("←→/Space", "切换类型"), - HotkeyAction("Enter", "保存"), - HotkeyAction("Ctrl+U", "清空"), + HotkeyAction("Tab/↑↓", i18n.T(i18n.KSwitchField)), + HotkeyAction("←→/Space", i18n.T(i18n.KSwitchType)), + HotkeyAction("Enter", i18n.T(i18n.KSave)), + HotkeyAction("Ctrl+U", i18n.T(i18n.KClear)), } } // Hotkeys_Help 帮助页。 func Hotkeys_Help() []HotkeyItem { return []HotkeyItem{ - HotkeyAction("↑↓", "滚动"), - HotkeyAction("PgUp/PgDn", "翻页"), - HotkeyAction("g/G", "顶部/底部"), + HotkeyAction("↑↓", i18n.T(i18n.KScroll)), + HotkeyAction("PgUp/PgDn", i18n.T(i18n.KPageTurn)), + HotkeyAction("g/G", i18n.T(i18n.KTopBottom)), } } diff --git a/internal/tui/pages/dashboard.go b/internal/tui/pages/dashboard.go index 5f4a52d..60d8237 100644 --- a/internal/tui/pages/dashboard.go +++ b/internal/tui/pages/dashboard.go @@ -8,6 +8,7 @@ import ( tea "github.com/charmbracelet/bubbletea" "charm.land/lipgloss/v2" lgtable "charm.land/lipgloss/v2/table" + "github.com/yinxulai/ait/internal/i18n" "github.com/yinxulai/ait/internal/server" "github.com/yinxulai/ait/internal/types" ) @@ -188,25 +189,25 @@ func RenderDashboard(d *DashboardState, taskName string, st Styles, width, heigh default: cbItems = Hotkeys_Dashboard_Done_NoSel() } - headerLeft := []string{"等待数据"} + headerLeft := []string{i18n.T(i18n.KWaitingStatus)} headerRight := []string{} if rs != nil { - headerLeft = []string{runStatusText(string(rs.Status)), fmt.Sprintf("完成 %d/%d", rs.DoneReqs, rs.TotalReqs)} - headerRight = []string{fmt.Sprintf("成功率 %.1f%%", rs.SuccessRate)} + headerLeft = []string{runStatusText(string(rs.Status)), fmt.Sprintf("%d/%d", rs.DoneReqs, rs.TotalReqs)} + headerRight = []string{fmt.Sprintf(i18n.T(i18n.KSuccessRateFmt), rs.SuccessRate)} if !rs.StartedAt.IsZero() { - headerRight = append(headerRight, "开始 "+fmtRelativeTime(rs.StartedAt)) + headerRight = append(headerRight, i18n.T(i18n.KStart)+" "+fmtRelativeTime(rs.StartedAt)) } } if d.TaskID != "" { - headerRight = append(headerRight, "任务 "+truncate(d.TaskID, 14)) + headerRight = append(headerRight, truncate(d.TaskID, 14)) } l := PageLayout{ - HeaderTitle: "标准运行监控", - HeaderSubtitle: "实时查看运行进度、吞吐和单请求明细", - HeaderMeta: "标准模式", + HeaderTitle: i18n.T(i18n.KStdMonitorTitle), + HeaderSubtitle: i18n.T(i18n.KStdMonitorSubtitle), + HeaderMeta: i18n.T(i18n.KStandardMode), HeaderInfoLeft: headerLeft, HeaderInfoRight: headerRight, - Hotkeys: NewPageHotkeysWithHelp(cbItems, "[b/Esc] 返回上一页", "[q] 退出"), + Hotkeys: NewPageHotkeysWithHelp(cbItems, i18n.T(i18n.KHintGoBack), i18n.T(i18n.KHintQuit)), } frame := l.Frame(width, height) bodyPanel := frame.InnerPanel() @@ -237,15 +238,16 @@ func RenderDashboard(d *DashboardState, taskName string, st Styles, width, heigh // buildDashParamsPanel 构建左侧任务参数面板。 func buildDashParamsPanel(d *DashboardState, rs *server.RunState, st Styles, maxH, width int) string { - lines := panelTitleLines(st, "运行进度", width, false) + lines := panelTitleLines(st, i18n.T(i18n.KProgress), width, false) if rs == nil { - lines = append(lines, " "+st.Muted.Render("等待数据...")) + lines = append(lines, " "+st.Muted.Render(i18n.T(i18n.KWaitingData))) } else { - // 参数从 RunState 读取(实际可从 task 传入,此处用 RunState 已知信息展示) - lines = append(lines, " "+labelValue(st, "进度", fmt.Sprintf("%d/%d", rs.DoneReqs, rs.TotalReqs))) - lines = append(lines, " "+labelValue(st, "成功", fmt.Sprintf("%d", rs.SuccessReqs))) - lines = append(lines, " "+labelValue(st, "失败", fmt.Sprintf("%d", rs.FailedReqs))) + lbls := []string{i18n.T(i18n.KProgress), i18n.T(i18n.KSuccessCount), i18n.T(i18n.KFailureCount)} + lw := maxLabelWidth(lbls) + lines = append(lines, " "+labelValue(st, lbls[0], fmt.Sprintf("%d/%d", rs.DoneReqs, rs.TotalReqs), lw)) + lines = append(lines, " "+labelValue(st, lbls[1], fmt.Sprintf("%d", rs.SuccessReqs), lw)) + lines = append(lines, " "+labelValue(st, lbls[2], fmt.Sprintf("%d", rs.FailedReqs), lw)) } return finishPanelLines(lines, maxH) @@ -253,10 +255,10 @@ func buildDashParamsPanel(d *DashboardState, rs *server.RunState, st Styles, max // buildDashMetricsPanel 构建右侧实时指标面板。 func buildDashMetricsPanel(rs *server.RunState, st Styles, maxH, width int) string { - lines := panelTitleLines(st, "实时指标", width, false) + lines := panelTitleLines(st, i18n.T(i18n.KInProgress), width, false) if rs == nil { - lines = append(lines, " "+st.Muted.Render("等待数据...")) + lines = append(lines, " "+st.Muted.Render(i18n.T(i18n.KWaitingData))) } else { lines = appendRunMetricLines(lines, st, rs) } @@ -267,7 +269,7 @@ func buildDashMetricsPanel(rs *server.RunState, st Styles, maxH, width int) stri // buildProgressLine 构建进度条行。 func buildProgressLine(rs *server.RunState, st Styles, width int) string { if rs == nil { - return " 进度 " + st.Muted.Render("等待中...") + return " " + padToDisplayWidth(i18n.T(i18n.KProgress), 4) + " " + st.Muted.Render(i18n.T(i18n.KWaitingDots)) } total := rs.TotalReqs done := rs.DoneReqs @@ -284,17 +286,17 @@ func buildProgressLine(rs *server.RunState, st Styles, width int) string { } } suffix := fmt.Sprintf(" %d / %d %s", done, total, elapsed) - return renderProgressBar(st, " 进度 ", suffix, ratio, width) + return renderProgressBar(st, " "+padToDisplayWidth(i18n.T(i18n.KProgress), 4)+" ", suffix, ratio, width) } // buildRequestList 构建请求列表区域。 func buildRequestList(d *DashboardState, rs *server.RunState, st Styles, width, maxH int) string { - titleLines := panelTitleLines(st, "请求列表", width, true) + titleLines := panelTitleLines(st, i18n.T(i18n.KRequests), width, true) if rs == nil || len(rs.Requests) == 0 { - msg := "等待请求..." + msg := i18n.T(i18n.KWaitingData) if rs != nil && rs.Status != server.RunStatusRunning { - msg = "无请求详情数据" + msg = i18n.T(i18n.KNoRunRecords) } titleLines = append(titleLines, " "+st.Muted.Render(msg)) return finishPanelLines(titleLines, maxH) @@ -347,7 +349,7 @@ func buildRequestList(d *DashboardState, rs *server.RunState, st Styles, width, colWidths := []int{6, 8, 0, 8, 10, 12, 12, 10} // #, 状态, 总耗时=flex, TTFT, Cache, 输入, 输出, TPS tableH := maxH - len(titleLines) tbl := lgtable.New(). - Headers("#", "状态", "总耗时", "TTFT", "Cache", "输入", "输出", "TPS"). + Headers("#", i18n.T(i18n.KStatus), i18n.T(i18n.KTotalTime), "TTFT", "Cache", i18n.T(i18n.KColInput), i18n.T(i18n.KColOutput), "TPS"). Width(width). Height(tableH). YOffset(d.ReqOff). diff --git a/internal/tui/pages/help.go b/internal/tui/pages/help.go index 46809a2..526e967 100644 --- a/internal/tui/pages/help.go +++ b/internal/tui/pages/help.go @@ -5,6 +5,7 @@ import ( tea "github.com/charmbracelet/bubbletea" "charm.land/lipgloss/v2" + "github.com/yinxulai/ait/internal/i18n" ) // HelpState 帮助页状态。 @@ -81,10 +82,10 @@ func RenderHelp(s *HelpState, st Styles, width, height int) string { } l := PageLayout{ - HeaderTitle: "帮助", - HeaderSubtitle: "AIT — AI 接口压测工具概念说明与操作指南", - HeaderMeta: "帮助", - Hotkeys: NewPageHotkeys(Hotkeys_Help(), "[b/Esc] 返回", "[q] 退出"), + HeaderTitle: i18n.T(i18n.KHelpTitle), + HeaderSubtitle: i18n.T(i18n.KHelpSubtitle), + HeaderMeta: i18n.T(i18n.KHelpMeta), + Hotkeys: NewPageHotkeys(Hotkeys_Help(), i18n.T(i18n.KHintEscBack), i18n.T(i18n.KHintQuit)), } frame := l.Frame(width, height) panel := NewPanelFrame(frame.OuterWidth) @@ -108,82 +109,82 @@ type helpItem struct { func helpContent() []helpSection { return []helpSection{ { - title: "核心概念", + title: i18n.T(i18n.KHelpSecConcepts), items: []helpItem{ - {"任务 (Task)", "一组压测配置的集合,包含目标接口、模型、并发数、请求数等参数。任务可多次运行,每次运行独立记录结果。"}, - {"运行 (Run)", "任务的一次具体执行。每次运行产生独立的指标数据和请求记录,可导出为 JSON/CSV 报告。"}, - {"标准模式", "以固定并发数执行所有请求,适合衡量稳定负载下的接口性能。"}, - {"Turbo 模式", "自动从低并发逐步爬坡,找出接口在保持成功率要求下能承受的最大稳定并发数。"}, + {i18n.T(i18n.KHelpTermTask), i18n.T(i18n.KHelpDescTask)}, + {i18n.T(i18n.KHelpTermRun), i18n.T(i18n.KHelpDescRun)}, + {i18n.T(i18n.KHelpTermStandard), i18n.T(i18n.KHelpDescStandard)}, + {i18n.T(i18n.KHelpTermTurboMode), i18n.T(i18n.KHelpDescTurboMode)}, }, }, { - title: "性能指标", + title: i18n.T(i18n.KHelpSecMetrics), items: []helpItem{ - {"TPS", "Tokens Per Second,每秒输出 Token 数,衡量模型的文本生成速率。"}, - {"均值TPS", "本次运行中所有请求的 TPS 均值,反映整体吞吐水平。"}, - {"TTFT", "Time To First Token,从发送请求到收到第一个 Token 的耗时,衡量模型响应延迟。"}, - {"均值TTFT", "本次运行中所有请求的 TTFT 均值。"}, - {"成功率", "成功完成的请求数占总请求数的百分比。失败包括超时、HTTP 错误、模型返回错误等。"}, - {"缓存命中", "请求中使用了 KV 缓存(Prompt Cache)的比例。命中缓存可显著降低 TTFT 和推理成本。该指标为二值统计:单次请求若有任何 Token 命中缓存则计为命中。"}, - {"并发(Turbo)", "Turbo 模式下找到的最大稳定并发数,即在满足最低成功率要求的前提下能同时维持的请求数。"}, + {i18n.T(i18n.KHelpTermTPS), i18n.T(i18n.KHelpDescTPS)}, + {i18n.T(i18n.KHelpTermAvgTPS), i18n.T(i18n.KHelpDescAvgTPS)}, + {i18n.T(i18n.KHelpTermTTFT), i18n.T(i18n.KHelpDescTTFT)}, + {i18n.T(i18n.KHelpTermAvgTTFT), i18n.T(i18n.KHelpDescAvgTTFT)}, + {i18n.T(i18n.KHelpTermSuccessRate), i18n.T(i18n.KHelpDescSuccessRate)}, + {i18n.T(i18n.KHelpTermCacheHit), i18n.T(i18n.KHelpDescCacheHit)}, + {i18n.T(i18n.KHelpTermConcurrencyTurbo), i18n.T(i18n.KHelpDescConcurrencyTurbo)}, }, }, { - title: "协议支持", + title: i18n.T(i18n.KHelpSecProtocols), items: []helpItem{ - {"OpenAI", "兼容 OpenAI Chat Completions API(/v1/chat/completions),支持流式和非流式响应。"}, - {"Anthropic", "兼容 Anthropic Messages API(/v1/messages),支持流式和非流式响应。"}, + {i18n.T(i18n.KHelpTermOpenAI), i18n.T(i18n.KHelpDescOpenAI)}, + {i18n.T(i18n.KHelpTermAnthropic), i18n.T(i18n.KHelpDescAnthropic)}, }, }, { - title: "快捷键 — 全局", + title: i18n.T(i18n.KHelpSecGlobal), items: []helpItem{ - {"q / Ctrl+C", "退出程序。"}, - {"?", "打开此帮助页。"}, - {"b / Esc", "返回上一页。"}, + {i18n.T(i18n.KHelpTermQuit), i18n.T(i18n.KHelpDescQuit)}, + {i18n.T(i18n.KHelpTermQuestionMark), i18n.T(i18n.KHelpDescQuestionMark)}, + {i18n.T(i18n.KHelpTermBack), i18n.T(i18n.KHelpDescBack)}, }, }, { - title: "快捷键 — 任务列表", + title: i18n.T(i18n.KHelpSecTaskList), items: []helpItem{ - {"↑↓ / j k", "选择任务。"}, - {"Enter", "进入任务详情页。"}, - {"r", "立即运行选中任务。"}, - {"s", "停止正在运行的任务(仅任务运行中可用)。"}, - {"a", "新建任务(打开向导)。"}, - {"e", "编辑选中任务配置。"}, - {"d", "删除选中任务(需确认)。"}, - {"y", "复制选中任务(生成副本)。"}, - {"p", "打开代理配置页。"}, + {i18n.T(i18n.KHelpTermSelectTask), i18n.T(i18n.KHelpDescSelectTask)}, + {i18n.T(i18n.KHelpTermEnterDetail), i18n.T(i18n.KHelpDescEnterDetail)}, + {i18n.T(i18n.KHelpTermRunTask), i18n.T(i18n.KHelpDescRunTask)}, + {i18n.T(i18n.KHelpTermStopTask), i18n.T(i18n.KHelpDescStopTask)}, + {i18n.T(i18n.KHelpTermNewTask), i18n.T(i18n.KHelpDescNewTask)}, + {i18n.T(i18n.KHelpTermEditTask), i18n.T(i18n.KHelpDescEditTask)}, + {i18n.T(i18n.KHelpTermDeleteTask), i18n.T(i18n.KHelpDescDeleteTask)}, + {i18n.T(i18n.KHelpTermCopyTask), i18n.T(i18n.KHelpDescCopyTask)}, + {i18n.T(i18n.KHelpTermProxy), i18n.T(i18n.KHelpDescProxy)}, }, }, { - title: "快捷键 — 任务详情", + title: i18n.T(i18n.KHelpSecTaskDetail), items: []helpItem{ - {"↑↓ / j k", "在历史运行记录中选择条目。"}, - {"Enter", "查看选中运行的仪表盘;若任务正在运行,进入实时仪表盘。"}, - {"r", "再次运行该任务(无正在运行的实例时可用)。"}, - {"g", "将选中的历史运行导出为 JSON 报告。"}, - {"e", "编辑任务配置。"}, - {"y", "复制任务。"}, - {"d", "删除任务。"}, + {i18n.T(i18n.KHelpTermSelectHistory), i18n.T(i18n.KHelpDescSelectHistory)}, + {i18n.T(i18n.KHelpTermEnterDash), i18n.T(i18n.KHelpDescEnterDash)}, + {i18n.T(i18n.KHelpTermRunAgain), i18n.T(i18n.KHelpDescRunAgain)}, + {i18n.T(i18n.KHelpTermExport), i18n.T(i18n.KHelpDescExport)}, + {i18n.T(i18n.KHelpTermEditConfig), i18n.T(i18n.KHelpDescEditConfig)}, + {i18n.T(i18n.KHelpTermCopyTask2), i18n.T(i18n.KHelpDescCopyTask2)}, + {i18n.T(i18n.KHelpTermDeleteTask2), i18n.T(i18n.KHelpDescDeleteTask2)}, }, }, { - title: "快捷键 — 运行仪表盘", + title: i18n.T(i18n.KHelpSecDashboard), items: []helpItem{ - {"↑↓ / j k", "选择请求条目。"}, - {"Enter", "查看选中请求的详情(耗时、Token、响应体等)。"}, - {"s", "停止正在运行的任务。"}, - {"r", "生成 JSON 报告(运行结束后可用)。"}, - {"b / Esc", "返回任务详情页。"}, + {i18n.T(i18n.KHelpTermSelectReq), i18n.T(i18n.KHelpDescSelectReq)}, + {i18n.T(i18n.KHelpTermViewReq), i18n.T(i18n.KHelpDescViewReq)}, + {i18n.T(i18n.KHelpTermStopDash), i18n.T(i18n.KHelpDescStopDash)}, + {i18n.T(i18n.KHelpTermGenerateReport), i18n.T(i18n.KHelpDescGenerateReport)}, + {i18n.T(i18n.KHelpTermBackDash), i18n.T(i18n.KHelpDescBackDash)}, }, }, { - title: "报告导出", + title: i18n.T(i18n.KHelpSecExport), items: []helpItem{ - {"JSON 报告", "完整记录每次请求的所有指标、请求/响应体,适合程序化分析。"}, - {"CSV 报告", "表格形式的汇总数据,可直接在电子表格中打开。报告默认保存在当前工作目录。"}, + {i18n.T(i18n.KHelpTermJSONReport), i18n.T(i18n.KHelpDescJSONReport)}, + {i18n.T(i18n.KHelpTermCSVReport), i18n.T(i18n.KHelpDescCSVReport)}, }, }, } diff --git a/internal/tui/pages/helpers.go b/internal/tui/pages/helpers.go index a5fe560..f04a055 100644 --- a/internal/tui/pages/helpers.go +++ b/internal/tui/pages/helpers.go @@ -6,6 +6,8 @@ import ( "time" "charm.land/lipgloss/v2" + "github.com/mattn/go-runewidth" + "github.com/yinxulai/ait/internal/i18n" "github.com/yinxulai/ait/internal/server" ) @@ -38,6 +40,27 @@ func padRight(s string, width int) string { return lipgloss.NewStyle().Width(width).Render(s) } +// padToDisplayWidth 用空格将 s 右侧填充至 w 显示列宽。 +// 使用 runewidth 精确测量 CJK(2列)与 ASCII(1列)字符。 +func padToDisplayWidth(s string, w int) string { + cur := runewidth.StringWidth(s) + if cur >= w { + return s + } + return s + strings.Repeat(" ", w-cur) +} + +// maxLabelWidth 返回一组标签中最大的显示列宽。 +func maxLabelWidth(labels []string) int { + max := 0 + for _, l := range labels { + if w := runewidth.StringWidth(l); w > max { + max = w + } + } + return max +} + // wrapText 将文本按 maxW 列宽折行,返回行切片(CJK 字符按 2 列宽计算)。 func wrapText(s string, maxW int) []string { if maxW <= 0 { @@ -117,17 +140,17 @@ func fmtRelativeTime(t time.Time) string { } d := time.Since(t) if d < time.Minute { - return "刚刚" + return i18n.T(i18n.KJustNow) } if d < time.Hour { - return fmt.Sprintf("%d 分钟前", int(d.Minutes())) + return fmt.Sprintf(i18n.T(i18n.KMinutesAgoFmt), int(d.Minutes())) } if d < 24*time.Hour { - return fmt.Sprintf("%d 小时前", int(d.Hours())) + return fmt.Sprintf(i18n.T(i18n.KHoursAgoFmt), int(d.Hours())) } days := int(d.Hours() / 24) if days < 30 { - return fmt.Sprintf("%d 天前", days) + return fmt.Sprintf(i18n.T(i18n.KDaysAgoFmt), days) } return t.Format("2006-01-02") } @@ -305,7 +328,7 @@ func renderPrimaryHotkeyItems(items []HotkeyItem, maxW int) string { Background(lipgloss.Color("239")). Foreground(colorMuted). Padding(0, 1). - Render("当前页暂无快捷操作") + Render(i18n.T(i18n.KNoHotkeys)) } var rendered []string @@ -433,15 +456,15 @@ func nonEmptyParts(parts []string) []string { func runStatusText(status string) string { switch strings.ToLower(strings.TrimSpace(status)) { case "running": - return "运行中" + return i18n.T(i18n.KRunning) case "completed": - return "已完成" + return i18n.T(i18n.KCompleted) case "failed": - return "运行失败" + return i18n.T(i18n.KRunFailed) case "stopped": - return "已停止" + return i18n.T(i18n.KStopped) case "": - return "等待数据" + return i18n.T(i18n.KWaitingStatus) default: return status } @@ -452,7 +475,7 @@ func modeShortLabel(mode string) string { if mode == "turbo" { return "Turbo" } - return "标准" + return i18n.T(i18n.KStandardMode) } // isRunStateRunning 判断 RunState 是否处于运行状态。 @@ -470,13 +493,23 @@ func applyColWidth(s lipgloss.Style, col int, colWidths []int) lipgloss.Style { } // appendRunMetricLines 向 lines 追加 6 行运行指标(成功率/TPS/TTFT/缓存命中/RPM/TPM)。 +// 标签宽度由当前语言的最大标签宽度自动计算,无需手动添加空格对齐。 func appendRunMetricLines(lines []string, st Styles, rs *server.RunState) []string { - lines = append(lines, " "+labelValue(st, "成功率 ", st.MetricVal.Render(fmt.Sprintf("%.1f%%", rs.SuccessRate)))) - lines = append(lines, " "+labelValue(st, "TPS均值 ", st.MetricVal.Render(fmt.Sprintf("%.1f tok/s", rs.AvgTPS)))) - lines = append(lines, " "+labelValue(st, "TTFT均值", st.MetricVal.Render(fmtDuration(rs.AvgTTFT)))) - lines = append(lines, " "+labelValue(st, "缓存命中", st.MetricVal.Render(fmt.Sprintf("%.1f%%", rs.CacheHitRate*100)))) - lines = append(lines, " "+labelValue(st, "RPM ", st.MetricVal.Render(fmt.Sprintf("%.0f req/min", rs.RPM)))) - lines = append(lines, " "+labelValue(st, "TPM ", st.MetricVal.Render(fmt.Sprintf("%.0f tok/min", rs.TPM)))) + lbls := []string{ + i18n.T(i18n.KSuccessRate), + i18n.T(i18n.KAvgTPS), + i18n.T(i18n.KAvgTTFT), + i18n.T(i18n.KCacheHit), + i18n.T(i18n.KRPM), + i18n.T(i18n.KTPM), + } + lw := maxLabelWidth(lbls) + lines = append(lines, " "+labelValue(st, lbls[0], st.MetricVal.Render(fmt.Sprintf("%.1f%%", rs.SuccessRate)), lw)) + lines = append(lines, " "+labelValue(st, lbls[1], st.MetricVal.Render(fmt.Sprintf("%.1f tok/s", rs.AvgTPS)), lw)) + lines = append(lines, " "+labelValue(st, lbls[2], st.MetricVal.Render(fmtDuration(rs.AvgTTFT)), lw)) + lines = append(lines, " "+labelValue(st, lbls[3], st.MetricVal.Render(fmt.Sprintf("%.1f%%", rs.CacheHitRate*100)), lw)) + lines = append(lines, " "+labelValue(st, lbls[4], st.MetricVal.Render(fmt.Sprintf("%.0f req/min", rs.RPM)), lw)) + lines = append(lines, " "+labelValue(st, lbls[5], st.MetricVal.Render(fmt.Sprintf("%.0f tok/min", rs.TPM)), lw)) return lines } @@ -603,9 +636,9 @@ func shortProtocol(p string) string { func promptSummary(promptMode, promptText, promptFile string, promptLength int) string { switch promptMode { case "file": - return "文件: " + promptFile + return i18n.T(i18n.KFileSummaryPfx) + promptFile case "generated": - return fmt.Sprintf("生成 %d 字符", promptLength) + return fmt.Sprintf(i18n.T(i18n.KWzGeneratedFmt), promptLength) case "raw": if promptText != "" { r := []rune(promptText) @@ -614,7 +647,7 @@ func promptSummary(promptMode, promptText, promptFile string, promptLength int) } return "RAW: " + promptText } - return "(未设置)" + return i18n.T(i18n.KNotSet) default: if promptText != "" { r := []rune(promptText) @@ -623,21 +656,27 @@ func promptSummary(promptMode, promptText, promptFile string, promptLength int) } return promptText } - return "(未设置)" + return i18n.T(i18n.KNotSet) } } -// boolLabel 将 bool 值转换为"开启"/"关闭"。 +// boolLabel 将 bool 值转换为当前语言的开启/关闭标签。 func boolLabel(b bool) string { if b { - return "开启" + return i18n.T(i18n.KEnabled) } - return "关闭" + return i18n.T(i18n.KDisabled) } // labelValue 渲染一个 label:value 对。 -func labelValue(st Styles, label, value string) string { - return st.Label.Render(label) + " " + st.Value.Render(value) +// 可选参数 labelW 指定标签显示列宽,由 padToDisplayWidth 填充至指定宽度。 +// 同组标签调用时传入 maxLabelWidth(labels) 即可实现自动对齐。 +func labelValue(st Styles, label, value string, labelW ...int) string { + l := label + if len(labelW) > 0 && labelW[0] > 0 { + l = padToDisplayWidth(label, labelW[0]) + } + return st.Label.Render(l) + " " + st.Value.Render(value) } // wrapPanel 用带边框的 Panel 包裹内容,outerW 为包含边框的总宽度。 diff --git a/internal/tui/pages/layout.go b/internal/tui/pages/layout.go index e252b2d..c201df7 100644 --- a/internal/tui/pages/layout.go +++ b/internal/tui/pages/layout.go @@ -1,6 +1,10 @@ package pages -import "strings" +import ( + "strings" + + "github.com/yinxulai/ait/internal/i18n" +) // ── 尺寸常量 ────────────────────────────────────────────────────────────────── @@ -49,7 +53,7 @@ func NewPageHotkeys(hotkeys []HotkeyItem, hints ...string) PageHotkeys { // NewPageHotkeysWithHelp 在 NewPageHotkeys 基础上自动追加 [?] 帮助。 // 非帮助页使用此函数以统一显示帮助入口。 func NewPageHotkeysWithHelp(hotkeys []HotkeyItem, hints ...string) PageHotkeys { - withHelp := append(hotkeys, HotkeyAction("?", "帮助")) + withHelp := append(hotkeys, HotkeyAction("?", i18n.T(i18n.KHelp))) return PageHotkeys{ Hotkeys: withHelp, Hints: HotkeyTexts(hints...), @@ -215,5 +219,5 @@ func renderTooSmall(st Styles, width, _ int) string { if width < 4 { return "..." } - return st.Muted.Render(truncate("窗口过小 ↔ 请放大终端", width)) + return st.Muted.Render(truncate(i18n.T(i18n.KWindowTooSmall), width)) } diff --git a/internal/tui/pages/proxy.go b/internal/tui/pages/proxy.go index f8b04f9..d576a14 100644 --- a/internal/tui/pages/proxy.go +++ b/internal/tui/pages/proxy.go @@ -7,6 +7,7 @@ import ( "github.com/charmbracelet/bubbles/textinput" tea "github.com/charmbracelet/bubbletea" "charm.land/lipgloss/v2" + "github.com/yinxulai/ait/internal/i18n" ) // 代理类型常量 @@ -60,11 +61,11 @@ func proxyTypeLabel(t string) string { func proxyTypeHint(t string) string { switch t { case ProxyTypeSOCKS5: - return "示例: socks5://127.0.0.1:1080" + return i18n.T(i18n.KExSOCKS5) case ProxyTypeSSH: - return "示例: ssh://user@host:22" + return i18n.T(i18n.KExSSH) default: - return "示例: http://127.0.0.1:7890" + return i18n.T(i18n.KExHTTP) } } @@ -187,10 +188,10 @@ func RenderProxyConfig(s *ProxyConfigState, st Styles, width, height int) string } l := PageLayout{ - HeaderTitle: "代理配置", - HeaderSubtitle: "设置全局 HTTP 代理,适用于所有任务的请求。留空则使用系统环境变量或直连。", - HeaderMeta: "全局配置", - Hotkeys: NewPageHotkeysWithHelp(Hotkeys_ProxyConfig(), "[Esc] 返回", "[q] 退出"), + HeaderTitle: i18n.T(i18n.KProxyConfig), + HeaderSubtitle: i18n.T(i18n.KProxySubtitle), + HeaderMeta: i18n.T(i18n.KGlobalConfig), + Hotkeys: NewPageHotkeysWithHelp(Hotkeys_ProxyConfig(), i18n.T(i18n.KHintGoBack), i18n.T(i18n.KHintQuit)), } frame := l.Frame(width, height) panel := NewPanelFrame(frame.OuterWidth) @@ -222,7 +223,7 @@ func buildProxyConfigContent(s *ProxyConfigState, st Styles, contentW, maxH int) } typeLabelBlock := lipgloss.NewStyle().Width(15).Height(3). AlignVertical(lipgloss.Center). - Render(st.Label.Render("代理类型")) + Render(st.Label.Render(i18n.T(i18n.KProxyType))) typeRendered := typeFieldStyle.Width(fieldW + 4).Render(st.Value.Render(typeLabel)) appendBlock(lipgloss.JoinHorizontal(lipgloss.Top, typeLabelBlock, typeRendered)) @@ -238,20 +239,20 @@ func buildProxyConfigContent(s *ProxyConfigState, st Styles, contentW, maxH int) } else { v := s.input.Value() if v == "" { - urlRendered = urlFieldStyle.Width(fieldW + 4).Render(st.Muted.Render("未填写")) + urlRendered = urlFieldStyle.Width(fieldW + 4).Render(st.Muted.Render(i18n.T(i18n.KNotSet))) } else { urlRendered = urlFieldStyle.Width(fieldW + 4).Render(st.Value.Render(fitTail(v, fieldW))) } } urlLabelBlock := lipgloss.NewStyle().Width(15).Height(3). AlignVertical(lipgloss.Center). - Render(st.Label.Render("代理地址")) + Render(st.Label.Render(i18n.T(i18n.KProxyURL))) appendBlock(lipgloss.JoinHorizontal(lipgloss.Top, urlLabelBlock, urlRendered)) lines = append(lines, "") lines = append(lines, st.Muted.Render(truncate(proxyTypeHint(s.ProxyType), contentW))) lines = append(lines, "") - lines = append(lines, st.Muted.Render(truncate("配置保存至 ~/.ait/config.json,重启无需重新输入。", contentW))) + lines = append(lines, st.Muted.Render(truncate(i18n.T(i18n.KProxySaveHint), contentW))) // 填充至 maxH for len(lines) < maxH { diff --git a/internal/tui/pages/reqdetail.go b/internal/tui/pages/reqdetail.go index bce5bcd..0be943b 100644 --- a/internal/tui/pages/reqdetail.go +++ b/internal/tui/pages/reqdetail.go @@ -4,6 +4,7 @@ import ( "fmt" tea "github.com/charmbracelet/bubbletea" + "github.com/yinxulai/ait/internal/i18n" "github.com/yinxulai/ait/internal/server" "github.com/yinxulai/ait/internal/types" ) @@ -118,18 +119,18 @@ func RenderReqDetail(s *ReqDetailState, taskName string, st Styles, width, heigh idx = len(s.Requests) - 1 } r := s.Requests[idx] - status := "失败" + status := i18n.T(i18n.KRunFailed) if r.Success { - status = "成功" + status = i18n.T(i18n.KCompleted) } l := PageLayout{ - HeaderTitle: "请求详情", - HeaderSubtitle: "查看单次请求的耗时、网络阶段和完整报文", + HeaderTitle: i18n.T(i18n.KViewRequest), + HeaderSubtitle: i18n.T(i18n.KReqDetailSubtitle), HeaderMeta: truncate(string(s.RunID), 18), - HeaderInfoLeft: []string{fmt.Sprintf("请求 %d/%d", idx+1, len(s.Requests)), status}, - HeaderInfoRight: []string{fmt.Sprintf("缓存 %.0f%%", r.CacheHitRate*100), "耗时 " + fmtDuration(r.TotalTime)}, - Hotkeys: NewPageHotkeysWithHelp(Hotkeys_ReqDetail(), "[b/Esc] 返回上一页", "[q] 退出"), + HeaderInfoLeft: []string{fmt.Sprintf("%s %d/%d", i18n.T(i18n.KRequests), idx+1, len(s.Requests)), status}, + HeaderInfoRight: []string{fmt.Sprintf("%.0f%%", r.CacheHitRate*100), fmtDuration(r.TotalTime)}, + Hotkeys: NewPageHotkeysWithHelp(Hotkeys_ReqDetail(), i18n.T(i18n.KHintGoBack), i18n.T(i18n.KHintQuit)), } frame := l.Frame(width, height) bodyPanel := frame.InnerPanel() @@ -161,16 +162,16 @@ func RenderReqDetail(s *ReqDetailState, taskName string, st Styles, width, heigh // buildReqPerfPanel 构建请求左侧性能指标面板。 func buildReqPerfPanel(r *types.RequestMetrics, st Styles, maxH, width int) string { - lines := panelTitleLines(st, "性能指标", width, true) + lines := panelTitleLines(st, i18n.T(i18n.KStatus), width, true) if r == nil { - lines = append(lines, " "+st.Muted.Render("等待数据...")) + lines = append(lines, " "+st.Muted.Render(i18n.T(i18n.KWaitingData))) return finishPanelLines(lines, maxH) } - statusStr := st.Ok.Render("✓ 成功") + statusStr := st.Ok.Render("✓ " + i18n.T(i18n.KCompleted)) if !r.Success { - statusStr = st.ErrStyle.Render("✗ 失败") + statusStr = st.ErrStyle.Render("✗ " + i18n.T(i18n.KRunFailed)) } totalTime := "─" if r.TotalTime > 0 { @@ -190,20 +191,25 @@ func buildReqPerfPanel(r *types.RequestMetrics, st Styles, maxH, width int) stri if !r.Success { errorSummary = normalizeInlineText(r.ErrorMessage) if errorSummary == "" { - errorSummary = "请求失败" + errorSummary = i18n.T(i18n.KRunFailed) } errorSummary = truncate(errorSummary, maxInt(8, width-8)) } - lines = append(lines, " "+labelValue(st, "状态 ", statusStr)) - lines = append(lines, " "+labelValue(st, "总耗时 ", st.MetricVal.Render(totalTime))) - lines = append(lines, " "+labelValue(st, "TTFT ", st.MetricVal.Render(ttft))) - lines = append(lines, " "+labelValue(st, "输出TPS ", st.MetricVal.Render(tps))) - lines = append(lines, " "+labelValue(st, "Token ", tokenSummary)) + lbls := []string{ + i18n.T(i18n.KStatus), i18n.T(i18n.KTotalTime), "TTFT", + i18n.T(i18n.KOutputTPS), i18n.T(i18n.KToken), i18n.T(i18n.KCache), + } + lw := maxLabelWidth(lbls) + lines = append(lines, " "+labelValue(st, lbls[0], statusStr, lw)) + lines = append(lines, " "+labelValue(st, lbls[1], st.MetricVal.Render(totalTime), lw)) + lines = append(lines, " "+labelValue(st, lbls[2], st.MetricVal.Render(ttft), lw)) + lines = append(lines, " "+labelValue(st, lbls[3], st.MetricVal.Render(tps), lw)) + lines = append(lines, " "+labelValue(st, lbls[4], tokenSummary, lw)) if r.Success { - lines = append(lines, " "+labelValue(st, "缓存 ", cacheSummary)) + lines = append(lines, " "+labelValue(st, lbls[5], cacheSummary, lw)) } else { - lines = append(lines, " "+st.ErrStyle.Render("错误: "+errorSummary)) + lines = append(lines, " "+st.ErrStyle.Render(i18n.T(i18n.KError)+": "+errorSummary)) } return finishPanelLines(lines, maxH) @@ -211,18 +217,22 @@ func buildReqPerfPanel(r *types.RequestMetrics, st Styles, maxH, width int) stri // buildReqNetworkPanel 构建请求右侧网络指标面板。 func buildReqNetworkPanel(r *types.RequestMetrics, st Styles, maxH, width int) string { - lines := panelTitleLines(st, "网络指标", width, true) + lines := panelTitleLines(st, i18n.T(i18n.KTCPConnect), width, true) if r == nil { - lines = append(lines, " "+st.Muted.Render("等待数据...")) + lines = append(lines, " "+st.Muted.Render(i18n.T(i18n.KWaitingData))) return finishPanelLines(lines, maxH) } - lines = append(lines, " "+labelValue(st, "DNS ", fmtDuration(r.DNSTime))) - lines = append(lines, " "+labelValue(st, "TCP 连接 ", fmtDuration(r.ConnectTime))) - lines = append(lines, " "+labelValue(st, "TLS 握手 ", fmtDuration(r.TLSTime))) + lbls := []string{ + i18n.T(i18n.KDNS), i18n.T(i18n.KTCPConnect), i18n.T(i18n.KTLSHandshake), i18n.T(i18n.KTargetIP), + } + lw := maxLabelWidth(lbls) + lines = append(lines, " "+labelValue(st, lbls[0], fmtDuration(r.DNSTime), lw)) + lines = append(lines, " "+labelValue(st, lbls[1], fmtDuration(r.ConnectTime), lw)) + lines = append(lines, " "+labelValue(st, lbls[2], fmtDuration(r.TLSTime), lw)) if r.TargetIP != "" { - lines = append(lines, " "+labelValue(st, "目标 IP ", truncate(r.TargetIP, maxInt(4, width-12)))) + lines = append(lines, " "+labelValue(st, lbls[3], truncate(r.TargetIP, maxInt(4, width-12)), lw)) } return finishPanelLines(lines, maxH) @@ -230,11 +240,11 @@ func buildReqNetworkPanel(r *types.RequestMetrics, st Styles, maxH, width int) s // buildInputSection 构建输入 (请求体) 区域。 func buildInputSection(r *types.RequestMetrics, st Styles, width, maxH int) string { - lines := panelTitleLines(st, "请求体 (Request Body)", width, true) + lines := panelTitleLines(st, i18n.T(i18n.KRequestBody), width, true) lines = append(lines, " "+dividerLine(st, width-2)) if r.RequestBody == "" { - lines = append(lines, " "+st.Muted.Render("(未记录)")) + lines = append(lines, " "+st.Muted.Render(i18n.T(i18n.KNotRecorded))) } else { for _, l := range wrapText(r.RequestBody, width-3) { if len(lines) >= maxH-1 { @@ -249,11 +259,11 @@ func buildInputSection(r *types.RequestMetrics, st Styles, width, maxH int) stri // buildOutputSection 构建输出 (响应体) 区域。 func buildOutputSection(r *types.RequestMetrics, scrollY int, st Styles, width, maxH int) string { - lines := panelTitleLines(st, "响应体 (Response Body)", width, true) + lines := panelTitleLines(st, i18n.T(i18n.KResponseBody), width, true) lines = append(lines, " "+dividerLine(st, width-2)) if r.ResponseBody == "" { - lines = append(lines, " "+st.Muted.Render("(未记录)")) + lines = append(lines, " "+st.Muted.Render(i18n.T(i18n.KNotRecorded))) } else { allLines := wrapText(r.ResponseBody, width-3) if scrollY >= len(allLines) { @@ -269,7 +279,7 @@ func buildOutputSection(r *types.RequestMetrics, scrollY int, st Styles, width, lines = append(lines, " "+l) } if len(allLines) > maxH-3 { - lines = append(lines, " "+st.Muted.Render("(↑↓ 滚动查看完整内容)")) + lines = append(lines, " "+st.Muted.Render("("+i18n.T(i18n.KScrollMore)+")")) } } diff --git a/internal/tui/pages/taskdetail.go b/internal/tui/pages/taskdetail.go index 9340bb2..df63e15 100644 --- a/internal/tui/pages/taskdetail.go +++ b/internal/tui/pages/taskdetail.go @@ -8,6 +8,7 @@ import ( "charm.land/lipgloss/v2" lgtable "charm.land/lipgloss/v2/table" tea "github.com/charmbracelet/bubbletea" + "github.com/yinxulai/ait/internal/i18n" "github.com/yinxulai/ait/internal/server" "github.com/yinxulai/ait/internal/types" ) @@ -184,25 +185,25 @@ func RenderTaskDetail(s *TaskDetailState, st Styles, width, height int) string { default: cbItems = Hotkeys_TaskDetail_NoHistory() } - modeStr := "标准模式" + modeStr := i18n.T(i18n.KStandardMode) if inp.Turbo { - modeStr = "Turbo 模式" + modeStr = i18n.T(i18n.KTurboMode) } - headerRight := []string{"暂无运行记录"} + headerRight := []string{i18n.T(i18n.KNoRunRecords)} historyCount := len(taskDetailHistoryEntries(s)) if historyCount > 0 { - headerRight = []string{fmt.Sprintf("历史 %d 条", historyCount)} + headerRight = []string{fmt.Sprintf("%d", historyCount)} } if hasActive { - headerRight = append([]string{"运行中"}, headerRight...) + headerRight = append([]string{i18n.T(i18n.KRunning)}, headerRight...) } l := PageLayout{ HeaderTitle: truncate(t.Name, 28), - HeaderSubtitle: "查看任务配置、当前运行状态与历史记录", - HeaderMeta: "任务详情", + HeaderSubtitle: i18n.T(i18n.KTaskDetailSubtitle), + HeaderMeta: i18n.T(i18n.KRecordDetails), HeaderInfoLeft: []string{modeStr, inp.NormalizedProtocol()}, HeaderInfoRight: headerRight, - Hotkeys: NewPageHotkeysWithHelp(cbItems, "[b/Esc] 返回上一页", "[q] 退出"), + Hotkeys: NewPageHotkeysWithHelp(cbItems, i18n.T(i18n.KHintGoBack), i18n.T(i18n.KHintQuit)), } frame := l.Frame(width, height) @@ -219,47 +220,47 @@ func buildTaskDetailContent(s *TaskDetailState, st Styles, t types.TaskDefinitio // ─── 左栏:配置摘要 ───────────────────────────────────────── leftW := leftPanelFrame.InnerWidth - leftLines := panelTitleLines(st, "配置摘要", leftW, false) + leftLines := panelTitleLines(st, i18n.T(i18n.KProtocol), leftW, false) proto := inp.NormalizedProtocol() - leftLines = append(leftLines, padRight(" "+st.Label.Render("协议")+" "+st.Value.Render(proto), leftW)) + leftLines = append(leftLines, padRight(" "+st.Label.Render(i18n.T(i18n.KProtocol))+" "+st.Value.Render(proto), leftW)) endpoint := truncate(inp.ResolvedEndpointURL(), leftW-8) - leftLines = append(leftLines, padRight(" "+st.Label.Render("接口")+" "+st.Value.Render(endpoint), leftW)) + leftLines = append(leftLines, padRight(" "+st.Label.Render(i18n.T(i18n.KEndpoint))+" "+st.Value.Render(endpoint), leftW)) if inp.ProxyURL != "" { proxy := truncate(inp.ProxyURL, leftW-8) - leftLines = append(leftLines, padRight(" "+st.Label.Render("代理")+" "+st.Value.Render(proxy), leftW)) + leftLines = append(leftLines, padRight(" "+st.Label.Render(i18n.T(i18n.KProxy))+" "+st.Value.Render(proxy), leftW)) } leftLines = append(leftLines, padRight("", leftW)) model := truncate(inp.Model, leftW-10) - leftLines = append(leftLines, padRight(" "+st.Label.Render("模型")+" "+st.Value.Render(model), leftW)) - modeStr := "标准模式" + leftLines = append(leftLines, padRight(" "+st.Label.Render(i18n.T(i18n.KModel))+" "+st.Value.Render(model), leftW)) + modeStr := i18n.T(i18n.KStandardMode) if inp.Turbo { - modeStr = "Turbo 模式" + modeStr = i18n.T(i18n.KTurboMode) } - leftLines = append(leftLines, padRight(" "+st.Label.Render("模式")+" "+st.Value.Render(modeStr), leftW)) + leftLines = append(leftLines, padRight(" "+st.Label.Render(i18n.T(i18n.KMode))+" "+st.Value.Render(modeStr), leftW)) if inp.Turbo { tc := inp.TurboConfig - leftLines = append(leftLines, padRight(" "+st.Label.Render("并发")+" "+st.Value.Render( + leftLines = append(leftLines, padRight(" "+st.Label.Render(i18n.T(i18n.KConcurrency))+" "+st.Value.Render( fmt.Sprintf("%d → %d", tc.InitConcurrency, tc.MaxConcurrency)), leftW)) - leftLines = append(leftLines, padRight(" "+st.Label.Render("步进")+" "+st.Value.Render( - fmt.Sprintf("+%d 每级%d请求", tc.StepSize, tc.LevelRequests)), leftW)) + leftLines = append(leftLines, padRight(" "+st.Label.Render(i18n.T(i18n.KStepLabel))+" "+st.Value.Render( + fmt.Sprintf("+%d %d req", tc.StepSize, tc.LevelRequests)), leftW)) } else { - leftLines = append(leftLines, padRight(" "+st.Label.Render("并发")+" "+st.Value.Render( + leftLines = append(leftLines, padRight(" "+st.Label.Render(i18n.T(i18n.KConcurrency))+" "+st.Value.Render( fmt.Sprintf("%d", inp.Concurrency)), leftW)) - leftLines = append(leftLines, padRight(" "+st.Label.Render("请求")+" "+st.Value.Render( + leftLines = append(leftLines, padRight(" "+st.Label.Render(i18n.T(i18n.KRequests))+" "+st.Value.Render( fmt.Sprintf("%d", inp.Count)), leftW)) } leftLines = append(leftLines, padRight("", leftW)) - leftLines = append(leftLines, padRight(" "+st.Label.Render("超时")+" "+st.Value.Render(fmtDuration(inp.Timeout)), leftW)) - leftLines = append(leftLines, padRight(" "+st.Label.Render("流式")+" "+st.Value.Render(boolLabel(inp.Stream)), leftW)) + leftLines = append(leftLines, padRight(" "+st.Label.Render(i18n.T(i18n.KTimeout))+" "+st.Value.Render(fmtDuration(inp.Timeout)), leftW)) + leftLines = append(leftLines, padRight(" "+st.Label.Render(i18n.T(i18n.KStream))+" "+st.Value.Render(boolLabel(inp.Stream)), leftW)) prompt := promptSummary(inp.PromptMode, inp.PromptText, inp.PromptFile, inp.PromptLength) - leftLines = append(leftLines, padRight(" "+st.Label.Render("Prompt")+" "+st.Value.Render(truncate(prompt, leftW-12)), leftW)) + leftLines = append(leftLines, padRight(" "+st.Label.Render(i18n.T(i18n.KPromptLabel))+" "+st.Value.Render(truncate(prompt, leftW-12)), leftW)) leftContent := finishPanelLines(leftLines, panelContentH) // ─── 右栏:历史运行记录 ───────────────────────────────────── rightW := rightPanelFrame.InnerWidth - rightTitle := panelTitleLines(st, "历史运行记录", rightW, false) // 2 行 + rightTitle := panelTitleLines(st, i18n.T(i18n.KRunHistory), rightW, false) // 2 行 historyEntries := taskDetailHistoryEntries(s) hasActive := s.ActiveRun != nil @@ -269,7 +270,7 @@ func buildTaskDetailContent(s *TaskDetailState, st Styles, t types.TaskDefinitio } if effectiveLen == 0 { - rightLines := append(rightTitle, padRight(" "+st.Muted.Render("暂无运行记录"), rightW)) + rightLines := append(rightTitle, padRight(" "+st.Muted.Render(i18n.T(i18n.KNoRunRecords)), rightW)) rightContent := finishPanelLines(rightLines, panelContentH) return renderSplitPanels(st, leftPanelFrame, rightPanelFrame, leftContent, rightContent) } @@ -329,7 +330,7 @@ func buildTaskDetailContent(s *TaskDetailState, st Styles, t types.TaskDefinitio rate: rateStr, dur: "─", ttft: "─", - tps: fmt.Sprintf("%d/%d 正在运行...", rs.DoneReqs, rs.TotalReqs), + tps: fmt.Sprintf("%d/%d %s", rs.DoneReqs, rs.TotalReqs, i18n.T(i18n.KRunning)), rpm: "─", tpm: "─", } @@ -379,7 +380,7 @@ func buildTaskDetailContent(s *TaskDetailState, st Styles, t types.TaskDefinitio sel := s.HistorySel tableH := tableMaxH - len(rightTitle) tbl := lgtable.New(). - Headers("", "时间", "模式", "成功率", "耗时", "TTFT", "TPS", "RPM", "TPM"). + Headers("", i18n.T(i18n.KTime), i18n.T(i18n.KMode), i18n.T(i18n.KSuccessRate), i18n.T(i18n.KElapsed), "TTFT", "TPS", "RPM", "TPM"). Width(rightW). Height(tableH). YOffset(s.HistoryOff). @@ -481,9 +482,13 @@ func buildTaskHistoryDetailLines(history []types.TaskRunSummary, histIdx int, st finishedText := sel.FinishedAt.Format("2006-01-02 15:04") if sel.FinishedAt.IsZero() { elapsedText = fmtDuration(time.Since(sel.StartedAt)) - finishedText = "进行中" + finishedText = i18n.T(i18n.KRunning) } - labelW := 8 + labelW := maxLabelWidth([]string{ + i18n.T(i18n.KStatus), i18n.T(i18n.KMode), i18n.T(i18n.KStart), i18n.T(i18n.KEnd), + i18n.T(i18n.KElapsed), i18n.T(i18n.KSuccessRate), "TTFT", "TPS", "RPM", "TPM", + i18n.T(i18n.KProtocol), i18n.T(i18n.KModel), i18n.T(i18n.KCache), i18n.T(i18n.KErrorSummary), + }) indent := " " gap := 4 contentW := maxInt(12, width-lipgloss.Width(indent)) @@ -538,20 +543,20 @@ func buildTaskHistoryDetailLines(history []types.TaskRunSummary, histIdx int, st lines := []string{ padRight(st.Divider.Render(strings.Repeat("─", width)), width), - padRight(" "+st.SectionHead.Render("记录详情"), width), + padRight(" "+st.SectionHead.Render(i18n.T(i18n.KRecordDetails)), width), } lines = appendPairRow(lines, - "状态", statusText, statusStyle, - "模式", modeText, st.Value, + i18n.T(i18n.KStatus), statusText, statusStyle, + i18n.T(i18n.KMode), modeText, st.Value, ) lines = appendPairRow(lines, - "开始", sel.StartedAt.Format("2006-01-02 15:04"), st.Value, - "结束", finishedText, st.Value, + i18n.T(i18n.KStart), sel.StartedAt.Format("2006-01-02 15:04"), st.Value, + i18n.T(i18n.KEnd), finishedText, st.Value, ) lines = appendPairRow(lines, - "耗时", elapsedText, st.Value, - "成功率", fmt.Sprintf("%.1f%%", sel.SuccessRate), st.Value, + i18n.T(i18n.KElapsed), elapsedText, st.Value, + i18n.T(i18n.KSuccessRate), fmt.Sprintf("%.1f%%", sel.SuccessRate), st.Value, ) lines = appendPairRow(lines, "TTFT", fmtDuration(sel.AvgTTFT), st.Value, @@ -561,13 +566,13 @@ func buildTaskHistoryDetailLines(history []types.TaskRunSummary, histIdx int, st "RPM", fmt.Sprintf("%.0f req/min", sel.RPM), st.MetricVal, "TPM", fmt.Sprintf("%.0f tok/min", sel.TPM), st.MetricVal, ) - lines = appendSingleField(lines, "协议", shortProtocol(sel.Protocol), st.Value) - lines = appendSingleField(lines, "模型", sel.Model, st.Value) + lines = appendSingleField(lines, i18n.T(i18n.KProtocol), shortProtocol(sel.Protocol), st.Value) + lines = appendSingleField(lines, i18n.T(i18n.KModel), sel.Model, st.Value) if sel.CacheHitRate > 0 { - lines = appendSingleField(lines, "缓存", fmt.Sprintf("%.1f%%", sel.CacheHitRate*100), st.Value) + lines = appendSingleField(lines, i18n.T(i18n.KCache), fmt.Sprintf("%.1f%%", sel.CacheHitRate*100), st.Value) } if sel.ErrorSummary != "" { - lines = append(lines, indent+st.Label.Render("错误摘要")) + lines = append(lines, indent+st.Label.Render(i18n.T(i18n.KErrorSummary))) for _, seg := range wrapText(sel.ErrorSummary, maxInt(10, contentW-2)) { lines = append(lines, indent+" "+st.ErrStyle.Render(seg)) } diff --git a/internal/tui/pages/tasklist.go b/internal/tui/pages/tasklist.go index c6e63b1..8e0eb6b 100644 --- a/internal/tui/pages/tasklist.go +++ b/internal/tui/pages/tasklist.go @@ -8,6 +8,7 @@ import ( tea "github.com/charmbracelet/bubbletea" "charm.land/lipgloss/v2" lgtable "charm.land/lipgloss/v2/table" + "github.com/yinxulai/ait/internal/i18n" "github.com/yinxulai/ait/internal/server" "github.com/yinxulai/ait/internal/types" ) @@ -151,7 +152,7 @@ func RenderTaskList(s *TaskListState, st Styles, width, height int) string { var cbItems []HotkeyItem if s.ConfirmDelete { - cbItems = []HotkeyItem{HotkeyAction("y/Enter", "确认删除"), HotkeyAction("n/Esc", "取消")} + cbItems = []HotkeyItem{HotkeyAction("y/Enter", i18n.T(i18n.KConfirmDelete)), HotkeyAction("n/Esc", i18n.T(i18n.KCancel))} } else if t, ok := s.CurrentTask(); ok { if s.IsTaskRunning(t.ID) { cbItems = Hotkeys_TaskList_Running() @@ -159,7 +160,7 @@ func RenderTaskList(s *TaskListState, st Styles, width, height int) string { cbItems = Hotkeys_TaskList_Normal() } } else { - cbItems = []HotkeyItem{HotkeyAction("a", "新建任务")} + cbItems = []HotkeyItem{HotkeyAction("a", i18n.T(i18n.KNewTask))} } runningCount := 0 for _, rs := range s.ActiveRuns { @@ -167,20 +168,20 @@ func RenderTaskList(s *TaskListState, st Styles, width, height int) string { runningCount++ } } - headerRight := []string{"暂无运行历史"} + headerRight := []string{i18n.T(i18n.KNoRunHistory)} if latest := s.latestRunAt(); latest != nil { - headerRight = []string{"最近运行 " + fmtRelativeTime(*latest)} + headerRight = []string{fmtRelativeTime(*latest)} } if t, ok := s.CurrentTask(); ok { - headerRight = append([]string{"当前 " + truncate(t.Name, 22)}, headerRight...) + headerRight = append([]string{truncate(t.Name, 22)}, headerRight...) } l := PageLayout{ - HeaderTitle: "任务中心", - HeaderSubtitle: "创建任务、运行压测、查看执行记录与导出报告", - HeaderMeta: fmt.Sprintf("%d 个任务", len(s.Tasks)), - HeaderInfoLeft: []string{fmt.Sprintf("运行中 %d", runningCount)}, + HeaderTitle: i18n.T(i18n.KTaskCenter), + HeaderSubtitle: i18n.T(i18n.KTaskListSubtitle), + HeaderMeta: fmt.Sprintf("%d", len(s.Tasks)), + HeaderInfoLeft: []string{fmt.Sprintf("%s %d", i18n.T(i18n.KRunning), runningCount)}, HeaderInfoRight: headerRight, - Hotkeys: NewPageHotkeysWithHelp(cbItems, "[↑↓] 选择", "[a] 新建", "[q] 退出"), + Hotkeys: NewPageHotkeysWithHelp(cbItems, "[↑↓]", "[a]", i18n.T(i18n.KHintQuit)), } frame := l.Frame(width, height) panel := NewPanelFrame(frame.OuterWidth) @@ -220,7 +221,7 @@ func buildTaskListContent(s *TaskListState, st Styles, width, maxH int) string { rs := s.ActiveRuns[t.ID] _, hasActiveRun := s.ActiveRuns[t.ID] - modeText := "标准" + modeText := i18n.T(i18n.KStandardMode) isTurbo := false if t.Input.Turbo { modeText = "Turbo" @@ -230,7 +231,7 @@ func buildTaskListContent(s *TaskListState, st Styles, width, maxH int) string { isRunning := hasActiveRun || (t.LatestRun != nil && t.LatestRun.Status == string(server.RunStatusRunning)) lastRunText := "─" if isRunning { - lastRunText = "运行中" + lastRunText = i18n.T(i18n.KRunning) } else if t.LatestRun != nil && !t.LatestRun.FinishedAt.IsZero() { lastRunText = fmtRelativeTime(t.LatestRun.FinishedAt) } @@ -254,7 +255,7 @@ func buildTaskListContent(s *TaskListState, st Styles, width, maxH int) string { tpsText = fmt.Sprintf("%.1f", rs.AvgTPS) } else if !hasActiveRun && t.LatestRun != nil { if t.Input.Turbo && t.LatestRun.MaxStableConcurrency > 0 { - tpsText = fmt.Sprintf("并发%d", t.LatestRun.MaxStableConcurrency) + tpsText = fmt.Sprintf(i18n.T(i18n.KConcFmt), t.LatestRun.MaxStableConcurrency) } else if !t.Input.Turbo { tpsText = fmt.Sprintf("%.1f", t.LatestRun.AvgTPS) } @@ -301,7 +302,7 @@ func buildTaskListContent(s *TaskListState, st Styles, width, maxH int) string { // colWidths: 0 = 弹性列(占用剩余宽度),>0 = 固定总宽(包括两端各 1 字符 padding) colWidths := []int{0, 8, 22, 12, 8, 10, 10, 10, 8, 8} // 任务名称=flex, 模式, 协议, 上次运行, 成功率, 缓存命中, TTFT均值, TPS均值, RPM, TPM t := lgtable.New(). - Headers("任务名称", "模式", "协议", "上次运行", "成功率", "缓存命中", "均值TTFT", "均值TPS", "RPM", "TPM"). + Headers(i18n.T(i18n.KTaskName), i18n.T(i18n.KMode), i18n.T(i18n.KProtocol), i18n.T(i18n.KLastRun), i18n.T(i18n.KSuccessRate), i18n.T(i18n.KColCacheHit), i18n.T(i18n.KColAvgTTFT), i18n.T(i18n.KColAvgTPS), "RPM", "TPM"). Width(width). Height(maxH). YOffset(s.Offset). @@ -356,7 +357,7 @@ func buildTaskListContent(s *TaskListState, st Styles, width, maxH int) string { for len(tableLines) < maxH-1 { tableLines = append(tableLines, "") } - tableLines = append(tableLines, " "+st.Muted.Render("暂无任务 按 [a] 新建第一个任务")) + tableLines = append(tableLines, " "+st.Muted.Render(i18n.T(i18n.KNoTasks))) if len(tableLines) > maxH { tableLines = tableLines[:maxH] } @@ -385,14 +386,14 @@ func buildTaskListConfirmContent(s *TaskListState, st Styles, width, maxH int) s return strings.Repeat("\n", maxH-1) } lines = append(lines, "") - lines = append(lines, st.ErrStyle.Render(" 确认删除任务?")) + lines = append(lines, st.ErrStyle.Render(" "+i18n.T(i18n.KConfirmDeletePrompt))) lines = append(lines, "") - lines = append(lines, " "+st.Label.Render("任务名称")+" "+st.Value.Render(truncate(task.Name, maxInt(8, width-14)))) - lines = append(lines, " "+st.Label.Render("任务 ID ")+" "+st.Muted.Render(task.ID)) + lines = append(lines, " "+st.Label.Render(i18n.T(i18n.KTaskName))+" "+st.Value.Render(truncate(task.Name, maxInt(8, width-14)))) + lines = append(lines, " "+st.Label.Render(i18n.T(i18n.KTaskID))+" "+st.Muted.Render(task.ID)) lines = append(lines, "") - lines = append(lines, " "+st.Muted.Render("此操作不可恢复,任务的历史运行记录将一并删除。")) + lines = append(lines, " "+st.Muted.Render(i18n.T(i18n.KIrreversible))) lines = append(lines, "") - lines = append(lines, " "+st.Value.Render("[y / Enter]")+" 确认删除 "+st.Value.Render("[n / Esc]")+" 取消") + lines = append(lines, " "+st.Value.Render("[y / Enter]")+" "+i18n.T(i18n.KConfirmDelete)+" "+st.Value.Render("[n / Esc]")+" "+i18n.T(i18n.KCancel)) for len(lines) < maxH { lines = append(lines, "") } diff --git a/internal/tui/pages/turbodash.go b/internal/tui/pages/turbodash.go index a482d8f..4970e70 100644 --- a/internal/tui/pages/turbodash.go +++ b/internal/tui/pages/turbodash.go @@ -7,6 +7,7 @@ import ( tea "github.com/charmbracelet/bubbletea" "charm.land/lipgloss/v2" lgtable "charm.land/lipgloss/v2/table" + "github.com/yinxulai/ait/internal/i18n" "github.com/yinxulai/ait/internal/server" "github.com/yinxulai/ait/internal/types" ) @@ -183,10 +184,10 @@ func RenderTurboDash(d *TurboDashState, taskName string, st Styles, width, heigh default: cbItems = Hotkeys_TurboDash_Done_NoSel() } - headerLeft := []string{"等待数据"} + headerLeft := []string{i18n.T(i18n.KWaitingStatus)} headerRight := []string{} if rs != nil { - headerLeft = []string{runStatusText(string(rs.Status)), fmt.Sprintf("完成 %d/%d", rs.DoneReqs, rs.TotalReqs)} + headerLeft = []string{runStatusText(string(rs.Status)), fmt.Sprintf("%d/%d", rs.DoneReqs, rs.TotalReqs)} var levelNum int if d.IsRunning() { levelNum = len(rs.Levels) + 1 @@ -196,21 +197,21 @@ func RenderTurboDash(d *TurboDashState, taskName string, st Styles, width, heigh if levelNum < 1 { levelNum = 1 } - headerRight = []string{fmt.Sprintf("级别 %d", levelNum)} + headerRight = []string{fmt.Sprintf("%s %d", i18n.T(i18n.KColLevel), levelNum)} if len(rs.Levels) > 0 { - headerRight = append(headerRight, fmt.Sprintf("已探测 %d 档", len(rs.Levels))) + headerRight = append(headerRight, fmt.Sprintf("%d", len(rs.Levels))) } } if d.TaskID != "" { - headerRight = append(headerRight, "任务 "+truncate(d.TaskID, 14)) + headerRight = append(headerRight, truncate(d.TaskID, 14)) } l := PageLayout{ - HeaderTitle: "Turbo 探测监控", - HeaderSubtitle: "观察并发爬坡过程、级别指标与稳定区间", - HeaderMeta: "Turbo 模式", + HeaderTitle: i18n.T(i18n.KTurboMonitor), + HeaderSubtitle: i18n.T(i18n.KTurboSubtitle), + HeaderMeta: i18n.T(i18n.KTurboModeMeta), HeaderInfoLeft: headerLeft, HeaderInfoRight: headerRight, - Hotkeys: NewPageHotkeysWithHelp(cbItems, "[b/Esc] 返回上一页", "[q] 退出"), + Hotkeys: NewPageHotkeysWithHelp(cbItems, i18n.T(i18n.KHintGoBack), i18n.T(i18n.KHintQuit)), } frame := l.Frame(width, height) bodyPanel := frame.InnerPanel() @@ -241,15 +242,17 @@ func RenderTurboDash(d *TurboDashState, taskName string, st Styles, width, heigh // buildTurboDashParams 构建 Turbo 仪表盘左侧任务参数面板。 func buildTurboDashParams(rs *server.RunState, st Styles, maxH, width int) string { - lines := panelTitleLines(st, "任务参数", width, false) + lines := panelTitleLines(st, i18n.T(i18n.KConcurrency), width, false) if rs == nil { - lines = append(lines, " "+st.Muted.Render("等待数据...")) + lines = append(lines, " "+st.Muted.Render(i18n.T(i18n.KWaitingData))) } else { tc := rs.TurboConfig - lines = append(lines, " "+labelValue(st, "爬坡 ", fmt.Sprintf("%d→%d 步进+%d", tc.InitConcurrency, tc.MaxConcurrency, tc.StepSize))) - lines = append(lines, " "+labelValue(st, "每级 ", fmt.Sprintf("%d 请求", tc.LevelRequests))) - lines = append(lines, " "+labelValue(st, "停止 ", fmt.Sprintf("成功率 < %.0f%%", tc.MinSuccessRate*100))) + lbls := []string{i18n.T(i18n.KRamp), i18n.T(i18n.KPerLevel), i18n.T(i18n.KStopCondLabel)} + lw := maxLabelWidth(lbls) + lines = append(lines, " "+labelValue(st, lbls[0], fmt.Sprintf("%d→%d +%d", tc.InitConcurrency, tc.MaxConcurrency, tc.StepSize), lw)) + lines = append(lines, " "+labelValue(st, lbls[1], fmt.Sprintf("%d req", tc.LevelRequests), lw)) + lines = append(lines, " "+labelValue(st, lbls[2], fmt.Sprintf("%.0f%%", tc.MinSuccessRate*100), lw)) } return finishPanelLines(lines, maxH) @@ -263,10 +266,10 @@ func buildTurboDashMetrics(rs *server.RunState, st Styles, maxH, width int) stri if rs != nil { curLevel = rs.CurrentLevel } - lines = panelTitleLines(st, fmt.Sprintf("当前级别实时指标 [并发 = %d]", curLevel), width, false) + lines = panelTitleLines(st, fmt.Sprintf(i18n.T(i18n.KTurboCurLevelFmt), curLevel), width, false) if rs == nil { - lines = append(lines, " "+st.Muted.Render("等待数据...")) + lines = append(lines, " "+st.Muted.Render(i18n.T(i18n.KWaitingData))) } else { lines = appendRunMetricLines(lines, st, rs) } @@ -277,7 +280,7 @@ func buildTurboDashMetrics(rs *server.RunState, st Styles, maxH, width int) stri // buildTurboProgressLine 构建 Turbo 模式进度条行。 func buildTurboProgressLine(rs *server.RunState, st Styles, width int) string { if rs == nil { - return " 进度 " + st.Muted.Render("等待中...") + return " " + padToDisplayWidth(i18n.T(i18n.KProgress), 4) + " " + st.Muted.Render(i18n.T(i18n.KWaitingDots)) } total := rs.TotalReqs done := rs.DoneReqs @@ -294,18 +297,18 @@ func buildTurboProgressLine(rs *server.RunState, st Styles, width int) string { } else { levelTotalStr = fmt.Sprintf("%d", levelDone) } - suffix := fmt.Sprintf(" %d/%d 当前并发 %d 总进度: %s 级", done, total, rs.CurrentLevel, levelTotalStr) - return renderProgressBar(st, " 进度 ", suffix, ratio, width) + suffix := fmt.Sprintf(i18n.T(i18n.KTurboDashSuffix), done, total, rs.CurrentLevel, levelTotalStr) + return renderProgressBar(st, " "+padToDisplayWidth(i18n.T(i18n.KProgress), 4)+" ", suffix, ratio, width) } // buildTurboRequestList 构建 Turbo 模式请求列表区域。 func buildTurboRequestList(d *TurboDashState, rs *server.RunState, st Styles, width, maxH int) string { - titleLines := panelTitleLines(st, "请求列表", width, true) + titleLines := panelTitleLines(st, i18n.T(i18n.KRequests), width, true) if rs == nil || len(rs.Requests) == 0 { - msg := "等待请求..." + msg := i18n.T(i18n.KWaitingData) if rs != nil && rs.Status != server.RunStatusRunning { - msg = "无请求详情数据" + msg = i18n.T(i18n.KNoRunRecords) } titleLines = append(titleLines, " "+st.Muted.Render(msg)) return finishPanelLines(titleLines, maxH) @@ -359,7 +362,7 @@ func buildTurboRequestList(d *TurboDashState, rs *server.RunState, st Styles, wi colWidths := []int{6, 5, 6, 0, 8, 10, 12, 12, 10} // #, 状态, 级别, 总耗时=flex, TTFT, Cache, 输入, 输出, TPS tableH := maxH - len(titleLines) tbl := lgtable.New(). - Headers("#", "状态", "级别", "总耗时", "TTFT", "Cache", "输入", "输出", "TPS"). + Headers("#", i18n.T(i18n.KStatus), i18n.T(i18n.KColLevel), i18n.T(i18n.KTotalTime), "TTFT", "Cache", i18n.T(i18n.KColInput), i18n.T(i18n.KColOutput), "TPS"). Width(width). Height(tableH). YOffset(d.ReqOff). diff --git a/internal/tui/pages/wizard.go b/internal/tui/pages/wizard.go index 45a0323..863006d 100644 --- a/internal/tui/pages/wizard.go +++ b/internal/tui/pages/wizard.go @@ -10,6 +10,7 @@ import ( "github.com/charmbracelet/bubbles/textinput" tea "github.com/charmbracelet/bubbletea" "charm.land/lipgloss/v2" + "github.com/yinxulai/ait/internal/i18n" "github.com/yinxulai/ait/internal/server" "github.com/yinxulai/ait/internal/types" ) @@ -210,7 +211,7 @@ func (wz *WizardState) BuildTaskConfig() server.TaskConfig { timeout = time.Duration(wz.Timeout) * time.Second } return server.TaskConfig{ - Name: wizardFallback(wz.Name, "未命名任务"), + Name: wizardFallback(wz.Name, i18n.T(i18n.KWzUntitled)), Input: types.Input{ Protocol: wz.Protocol, EndpointURL: wz.EndpointURL, @@ -286,12 +287,12 @@ func step1Fields() []fieldDef { } return []fieldDef{ { - kind: fieldText, label: "任务名称", + kind: fieldText, label: i18n.T(i18n.KWzTaskName), get: func(wz *WizardState) string { return wz.Name }, set: func(wz *WizardState, v string) { wz.Name = v }, }, { - kind: fieldEnum, label: "协议类型", + kind: fieldEnum, label: i18n.T(i18n.KWzProtocol), get: func(wz *WizardState) string { return wz.Protocol }, toggle: func(wz *WizardState, forward bool) { idx := 0 @@ -312,7 +313,7 @@ func step1Fields() []fieldDef { }, }, { - kind: fieldText, label: "接口地址", + kind: fieldText, label: i18n.T(i18n.KWzEndpoint), get: func(wz *WizardState) string { if wz.EndpointURL != "" { return wz.EndpointURL @@ -323,13 +324,13 @@ func step1Fields() []fieldDef { set: func(wz *WizardState, v string) { wz.EndpointURL = v }, }, { - kind: fieldText, label: "API 密钥", + kind: fieldText, label: i18n.T(i18n.KWzAPIKey), get: func(wz *WizardState) string { return wz.APIKey }, set: func(wz *WizardState, v string) { wz.APIKey = v }, password: true, }, { - kind: fieldText, label: "测试模型", + kind: fieldText, label: i18n.T(i18n.KWzTestModel), get: func(wz *WizardState) string { return wz.Model }, set: func(wz *WizardState, v string) { wz.Model = v }, }, @@ -340,12 +341,12 @@ func step1Fields() []fieldDef { func step2Fields(turbo bool) []fieldDef { fields := []fieldDef{ { - kind: fieldBool, label: "测试模式", + kind: fieldBool, label: i18n.T(i18n.KWzTestMode), get: func(wz *WizardState) string { if wz.Turbo { - return "Turbo 模式" + return i18n.T(i18n.KWzTurboMode) } - return "标准模式" + return i18n.T(i18n.KWzStandardMode) }, toggle: func(wz *WizardState, _ bool) { wz.Turbo = !wz.Turbo }, triggersFieldReset: true, @@ -354,19 +355,19 @@ func step2Fields(turbo bool) []fieldDef { if !turbo { fields = append(fields, - intField("并发数", func(wz *WizardState) int { return wz.Concurrency }, func(wz *WizardState, n int) { wz.Concurrency = n }), - intField("请求总数", func(wz *WizardState) int { return wz.Count }, func(wz *WizardState, n int) { wz.Count = n }), - intField("超时(秒)", func(wz *WizardState) int { return wz.Timeout }, func(wz *WizardState, n int) { wz.Timeout = n }), + intField(i18n.T(i18n.KWzConcurrency), func(wz *WizardState) int { return wz.Concurrency }, func(wz *WizardState, n int) { wz.Concurrency = n }), + intField(i18n.T(i18n.KWzTotalRequests), func(wz *WizardState) int { return wz.Count }, func(wz *WizardState, n int) { wz.Count = n }), + intField(i18n.T(i18n.KWzTimeoutSecs), func(wz *WizardState) int { return wz.Timeout }, func(wz *WizardState, n int) { wz.Timeout = n }), ) } else { fields = append(fields, - intField("初始并发", func(wz *WizardState) int { return wz.InitConcurrency }, func(wz *WizardState, n int) { wz.InitConcurrency = n }), - intField("最大并发", func(wz *WizardState) int { return wz.MaxConcurrency }, func(wz *WizardState, n int) { wz.MaxConcurrency = n }), - intField("步进值", func(wz *WizardState) int { return wz.StepSize }, func(wz *WizardState, n int) { wz.StepSize = n }), - intField("每级请求数", func(wz *WizardState) int { return wz.LevelRequests }, func(wz *WizardState, n int) { wz.LevelRequests = n }), + intField(i18n.T(i18n.KWzInitConc), func(wz *WizardState) int { return wz.InitConcurrency }, func(wz *WizardState, n int) { wz.InitConcurrency = n }), + intField(i18n.T(i18n.KWzMaxConc), func(wz *WizardState) int { return wz.MaxConcurrency }, func(wz *WizardState, n int) { wz.MaxConcurrency = n }), + intField(i18n.T(i18n.KWzStepSize), func(wz *WizardState) int { return wz.StepSize }, func(wz *WizardState, n int) { wz.StepSize = n }), + intField(i18n.T(i18n.KWzLevelReqs), func(wz *WizardState) int { return wz.LevelRequests }, func(wz *WizardState, n int) { wz.LevelRequests = n }), fieldDef{ kind: fieldNumber, - label: "最低成功率", + label: i18n.T(i18n.KWzMinSuccessRate), get: func(wz *WizardState) string { return fmt.Sprintf("%.0f", wz.MinSuccessRate) }, set: func(wz *WizardState, v string) { if f, err := strconv.ParseFloat(v, 64); err == nil && f > 0 && f <= 100 { @@ -380,7 +381,7 @@ func step2Fields(turbo bool) []fieldDef { // 流式模式:与测试模式无关,两种模式均可配置 fields = append(fields, fieldDef{ kind: fieldBool, - label: "流式模式", + label: i18n.T(i18n.KWzStreamMode), get: func(wz *WizardState) string { return boolLabel(wz.Stream) }, toggle: func(wz *WizardState, _ bool) { wz.Stream = !wz.Stream }, }) @@ -389,17 +390,17 @@ func step2Fields(turbo bool) []fieldDef { promptModes := []string{PromptModeText, PromptModeFile, PromptModeGenerated, PromptModeRaw} fields = append(fields, fieldDef{ - kind: fieldEnum, label: "输入方式", + kind: fieldEnum, label: i18n.T(i18n.KWzInputMode), get: func(wz *WizardState) string { switch wz.PromptMode { case PromptModeFile: - return "文件" + return i18n.T(i18n.KWzInputFile) case PromptModeGenerated: - return "按长度生成" + return i18n.T(i18n.KWzInputGenerated) case PromptModeRaw: - return "RAW 请求体" + return i18n.T(i18n.KWzInputRaw) default: - return "直接输入" + return i18n.T(i18n.KWzInputDirect) } }, toggle: func(wz *WizardState, forward bool) { @@ -426,7 +427,7 @@ func step2Fields(turbo bool) []fieldDef { // 根据 prompt 模式添加对应字段(在渲染时动态决定) fields = append(fields, fieldDef{ - kind: fieldText, label: "内容", + kind: fieldText, label: i18n.T(i18n.KWzPromptContent), get: func(wz *WizardState) string { switch wz.PromptMode { case PromptModeFile: @@ -614,11 +615,11 @@ func RenderWizard(wz *WizardState, st Styles, width, height int) string { if wz == nil { return renderTooSmall(st, width, height) } - stepTitles := []string{"基本信息", "测试参数", "确认保存"} + stepTitles := []string{i18n.T(i18n.KWzStep1Label), i18n.T(i18n.KWzStep2Label), i18n.T(i18n.KWzStep3Label)} stepDescs := []string{ - "配置任务名称、模型协议和连接信息。", - "选择压测模式,并补全并发与 Prompt 参数。", - "保存前快速检查关键配置。", + i18n.T(i18n.KWzStep1Desc), + i18n.T(i18n.KWzStep2Desc), + i18n.T(i18n.KWzStep3Desc), } stepTitle := stepTitles[int(wz.Step)] headerLeft := []string{stepTitle} @@ -628,26 +629,26 @@ func RenderWizard(wz *WizardState, st Styles, width, height int) string { headerRight := []string{} if wz.Step >= wizardStep2 { if wz.Turbo { - headerRight = append(headerRight, "Turbo 模式") + headerRight = append(headerRight, i18n.T(i18n.KTurboMode)) } else { - headerRight = append(headerRight, "标准模式") + headerRight = append(headerRight, i18n.T(i18n.KStandardMode)) } } if wz.Model != "" { - headerRight = append(headerRight, "模型 "+truncate(wz.Model, 18)) + headerRight = append(headerRight, truncate(wz.Model, 18)) } - action := "创建任务" + action := i18n.T(i18n.KNewTask) if wz.EditingID != "" { - action = "编辑任务" + action = i18n.T(i18n.KEdit) } l := PageLayout{ HeaderTitle: action, HeaderSubtitle: stepDescs[int(wz.Step)], - HeaderMeta: fmt.Sprintf("步骤 %d/3", int(wz.Step)+1), + HeaderMeta: fmt.Sprintf(i18n.T(i18n.KWzStepFmt), int(wz.Step)+1), HeaderInfoLeft: headerLeft, HeaderInfoRight: headerRight, - Hotkeys: NewPageHotkeysWithHelp(wizardHotkeyItems(wz.Step), "[Ctrl+C] 退出"), + Hotkeys: NewPageHotkeysWithHelp(wizardHotkeyItems(wz.Step), i18n.T(i18n.KHintQuit)), } frame := l.Frame(width, height) panel := NewPanelFrame(frame.OuterWidth) @@ -743,34 +744,34 @@ func buildWizardBody(wz *WizardState, st Styles, contentW int) ([]string, int, i case wizardStep2: fields := step2Fields(wz.Turbo) for i, f := range fields { - if f.label == "输入方式" { - lines = append(lines, "", st.Muted.Render("Prompt 配置")) + if f.label == i18n.T(i18n.KWzInputMode) { + lines = append(lines, "", st.Muted.Render(i18n.T(i18n.KWzPromptConfig))) } appendField(renderWizardField(st, f, wz, i == wz.FieldIndex, contentW), i == wz.FieldIndex) - if f.label == "测试模式" { + if f.label == i18n.T(i18n.KWzTestMode) { if wz.Turbo { - lines = append(lines, st.Muted.Render(" 自动从低并发起步,逐级加压,找到最大稳定吞吐点")) + lines = append(lines, st.Muted.Render(" "+i18n.T(i18n.KWzTurboModeLabel))) } else { - lines = append(lines, st.Muted.Render(" 固定并发数和请求总数,测量在指定负载下的延迟与成功率")) + lines = append(lines, st.Muted.Render(" "+i18n.T(i18n.KWzSelectModeHint))) } } - if f.label == "输入方式" { + if f.label == i18n.T(i18n.KWzInputMode) { switch wz.PromptMode { case PromptModeText: - lines = append(lines, st.Muted.Render(" 直接粘贴或输入 Prompt 文本,所有请求共享同一段内容")) + lines = append(lines, st.Muted.Render(" "+i18n.T(i18n.KWzHintDirect))) case PromptModeFile: - lines = append(lines, st.Muted.Render(" 从文件读取 Prompt,支持通配符匹配多个文件(请求按文件轮换)")) + lines = append(lines, st.Muted.Render(" "+i18n.T(i18n.KWzHintFile))) case PromptModeGenerated: - lines = append(lines, st.Muted.Render(" 按指定字符数自动生成测试文本,内容含大量公共前缀以模拟缓存命中")) + lines = append(lines, st.Muted.Render(" "+i18n.T(i18n.KWzHintCacheToken))) case PromptModeRaw: - lines = append(lines, st.Muted.Render(" 粘贴完整的 HTTP 请求 JSON Body,将跳过参数组装直接发送")) + lines = append(lines, st.Muted.Render(" "+i18n.T(i18n.KWzHintRaw))) } } - if f.label == "内容" && (wz.PromptMode == PromptModeText || wz.PromptMode == PromptModeFile || wz.PromptMode == PromptModeGenerated) { - lines = append(lines, st.Muted.Render(" 提示:大多数服务需要 ≥ 1024 tokens 才能命中缓存")) + if f.label == i18n.T(i18n.KWzPromptContent) && (wz.PromptMode == PromptModeText || wz.PromptMode == PromptModeFile || wz.PromptMode == PromptModeGenerated) { + lines = append(lines, st.Muted.Render(" "+i18n.T(i18n.KWzHintCacheToken))) } - if f.label == "内容" && wz.PromptMode == PromptModeRaw { - lines = append(lines, st.Muted.Render(" 提示:粘贴 API 请求的完整 JSON Body,将直接作为 HTTP 请求体发送")) + if f.label == i18n.T(i18n.KWzPromptContent) && wz.PromptMode == PromptModeRaw { + lines = append(lines, st.Muted.Render(" "+i18n.T(i18n.KWzHintRawBody))) } // 提示行追加完毕后,更新聚焦块的末尾行(含提示) if i == wz.FieldIndex { @@ -794,7 +795,7 @@ func renderWizardField(st Styles, f fieldDef, wz *WizardState, active bool, maxW } // API key 遮蔽显示 - if f.label == "API 密钥" && valueStr != "" { + if f.password && valueStr != "" { valueStr = maskAPIKey(valueStr) } @@ -805,7 +806,7 @@ func renderWizardField(st Styles, f fieldDef, wz *WizardState, active bool, maxW fieldW := maxInt(10, maxW-19) valueStyle := st.Value if valueStr == "" && !active { - valueStr = "未填写" + valueStr = i18n.T(i18n.KWzNotFilled) valueStyle = st.Muted } @@ -849,44 +850,44 @@ func renderStep3Summary(wz *WizardState, st Styles, innerW int) []string { appendWizardSummaryRow(&lines, st, label, value, innerW, valueStyle) } - lines = append(lines, st.SectionHead.Render("配置概览")) - addRow("任务名称", wizardFallback(wz.Name, "未命名任务"), st.Value) - addRow("协议", wz.Protocol, st.Value) + lines = append(lines, st.SectionHead.Render(i18n.T(i18n.KProtocol))) + addRow(i18n.T(i18n.KWzTaskName), wizardFallback(wz.Name, i18n.T(i18n.KWzUntitled)), st.Value) + addRow(i18n.T(i18n.KWzProtocol), wz.Protocol, st.Value) endpointDisplay := wz.EndpointURL if endpointDisplay == "" { endpointDisplay = types.DefaultEndpointURL(wz.Protocol) } - addRow("接口地址", endpointDisplay, st.Value) - addRow("API 密钥", wizardFallback(maskAPIKey(wz.APIKey), "未填写"), st.Value) - addRow("测试模型", wizardFallback(wz.Model, "未填写"), st.Value) + addRow(i18n.T(i18n.KWzEndpoint), endpointDisplay, st.Value) + addRow(i18n.T(i18n.KWzAPIKey), wizardFallback(maskAPIKey(wz.APIKey), i18n.T(i18n.KWzNotFilled)), st.Value) + addRow(i18n.T(i18n.KWzTestModel), wizardFallback(wz.Model, i18n.T(i18n.KWzNotFilled)), st.Value) - lines = append(lines, "", st.SectionHead.Render("执行参数")) + lines = append(lines, "", st.SectionHead.Render(i18n.T(i18n.KWzExecParams))) if wz.Turbo { - addRow("测试模式", "Turbo 模式", st.Value) - addRow("并发爬坡", fmt.Sprintf("%d → %d · 步进 +%d · 每级 %d 请求", + addRow(i18n.T(i18n.KWzTestMode), i18n.T(i18n.KWzTurboMode), st.Value) + addRow(i18n.T(i18n.KWzConcurrencyRamp), fmt.Sprintf("%d → %d · +%d · %d req", wz.InitConcurrency, wz.MaxConcurrency, wz.StepSize, wz.LevelRequests), st.Value) - addRow("停止条件", fmt.Sprintf("成功率 < %.0f%%", wz.MinSuccessRate), st.Value) + addRow(i18n.T(i18n.KWzStopCondition), fmt.Sprintf("< %.0f%%", wz.MinSuccessRate), st.Value) } else { - addRow("测试模式", "标准模式", st.Value) - addRow("并发数", strconv.Itoa(wz.Concurrency), st.Value) - addRow("请求总数", strconv.Itoa(wz.Count), st.Value) - addRow("超时", fmt.Sprintf("%ds", wz.Timeout), st.Value) + addRow(i18n.T(i18n.KWzTestMode), i18n.T(i18n.KWzStandardMode), st.Value) + addRow(i18n.T(i18n.KWzConcurrency), strconv.Itoa(wz.Concurrency), st.Value) + addRow(i18n.T(i18n.KWzTotalRequests), strconv.Itoa(wz.Count), st.Value) + addRow(i18n.T(i18n.KWzTimeoutLabel), fmt.Sprintf("%ds", wz.Timeout), st.Value) } - addRow("流式模式", boolLabel(wz.Stream), st.Value) + addRow(i18n.T(i18n.KWzStreamMode), boolLabel(wz.Stream), st.Value) - lines = append(lines, "", st.SectionHead.Render("Prompt")) - addRow("输入方式", wizardPromptModeLabel(wz.PromptMode), st.Value) + lines = append(lines, "", st.SectionHead.Render(i18n.T(i18n.KWzPromptSection))) + addRow(i18n.T(i18n.KWzInputMode), wizardPromptModeLabel(wz.PromptMode), st.Value) promptDesc := promptSummary(wz.PromptMode, wz.PromptText, wz.PromptFile, wz.PromptLength) - addRow("内容摘要", wizardFallback(promptDesc, "未填写"), st.Value) + addRow(i18n.T(i18n.KWzContentSummary), wizardFallback(promptDesc, i18n.T(i18n.KWzNotFilled)), st.Value) if wz.PromptMode == PromptModeText { - addRow("字符数", strconv.Itoa(len([]rune(wz.PromptText))), st.Muted) + addRow(i18n.T(i18n.KWzContentSummary), strconv.Itoa(len([]rune(wz.PromptText))), st.Muted) } else if wz.PromptMode == PromptModeGenerated { - addRow("目标长度", strconv.Itoa(wz.PromptLength), st.Muted) + addRow(i18n.T(i18n.KWzLevelReqs), strconv.Itoa(wz.PromptLength), st.Muted) } else if wz.PromptMode == PromptModeRaw { - addRow("Body 字节数", strconv.Itoa(len(wz.PromptText)), st.Muted) + addRow(i18n.T(i18n.KWzBodyBytes), strconv.Itoa(len(wz.PromptText)), st.Muted) } - lines = append(lines, "", st.Muted.Render("保存位置: ~/.ait/tasks/.json")) + lines = append(lines, "", st.Muted.Render(i18n.T(i18n.KWzSaveLocation))) return lines } @@ -895,7 +896,11 @@ func renderWizardStepStrip(step wizardStep) string { active := lipgloss.NewStyle().Background(colorPink).Foreground(colorWhite).Bold(true).Padding(0, 1) done := lipgloss.NewStyle().Background(colorCyan).Foreground(lipgloss.Color("233")).Bold(true).Padding(0, 1) idle := lipgloss.NewStyle().Background(lipgloss.Color("238")).Foreground(colorMuted).Padding(0, 1) - labels := []string{"1 基本信息", "2 测试参数", "3 确认保存"} + labels := []string{ + "1 " + i18n.T(i18n.KWzStep1Label), + "2 " + i18n.T(i18n.KWzStep2Label), + "3 " + i18n.T(i18n.KWzStep3Label), + } parts := make([]string, 0, len(labels)) for i, label := range labels { switch { @@ -911,31 +916,31 @@ func renderWizardStepStrip(step wizardStep) string { } func wizardFieldLabel(f fieldDef, wz *WizardState) string { - if f.label != "内容" { + if f.label != i18n.T(i18n.KWzPromptContent) { return f.label } switch wz.PromptMode { case PromptModeFile: - return "文件路径" + return i18n.T(i18n.KWzFileSummary) case PromptModeGenerated: - return "生成长度" + return i18n.T(i18n.KWzRAWBody) case PromptModeRaw: - return "JSON Body" + return i18n.T(i18n.KWzJSONBody) default: - return "Prompt" + return i18n.T(i18n.KWzPromptLabelShort) } } func wizardPromptModeLabel(mode string) string { switch mode { case PromptModeFile: - return "文件" + return i18n.T(i18n.KWzInputFile) case PromptModeGenerated: - return "按长度生成" + return i18n.T(i18n.KWzInputGenerated) case PromptModeRaw: - return "RAW 请求体" + return i18n.T(i18n.KWzInputRaw) default: - return "直接输入" + return i18n.T(i18n.KWzInputDirect) } } @@ -993,12 +998,12 @@ func wizardHotkeyItems(step wizardStep) []HotkeyItem { func wizardStatusText(wz *WizardState, offset, end, scrollTotal, visible int) string { if wz.Step == wizardStep3 { if scrollTotal <= 0 { - return "暂无确认项" + return i18n.T(i18n.KWzNoConfirmItems) } if scrollTotal > visible { - return fmt.Sprintf("确认项 %d-%d/%d", offset+1, end, scrollTotal) + return fmt.Sprintf(i18n.T(i18n.KWzConfirmRange), offset+1, end, scrollTotal) } - return fmt.Sprintf("共 %d 项待确认", scrollTotal) + return fmt.Sprintf(i18n.T(i18n.KWzConfirmTotal), scrollTotal) } var fieldTotal int switch wz.Step { @@ -1008,7 +1013,7 @@ func wizardStatusText(wz *WizardState, offset, end, scrollTotal, visible int) st fieldTotal = len(step2Fields(wz.Turbo)) } if fieldTotal <= 0 { - return "暂无配置项" + return i18n.T(i18n.KWzNoFields) } - return fmt.Sprintf("当前字段 %d/%d", wz.FieldIndex+1, fieldTotal) + return fmt.Sprintf(i18n.T(i18n.KWzFieldProgress), wz.FieldIndex+1, fieldTotal) } From 865a4d81da051d539a42e30fd7a8295f047d77bf Mon Sep 17 00:00:00 2001 From: Alain Date: Sun, 24 May 2026 23:11:06 +0800 Subject: [PATCH 48/52] =?UTF-8?q?feat(i18n):=20=E6=B7=BB=E5=8A=A0=E8=AF=AD?= =?UTF-8?q?=E8=A8=80=E5=88=87=E6=8D=A2=E5=8A=9F=E8=83=BD=EF=BC=8C=E6=94=AF?= =?UTF-8?q?=E6=8C=81=E9=80=9A=E8=BF=87=E5=91=BD=E4=BB=A4=E8=A1=8C=E5=8F=82?= =?UTF-8?q?=E6=95=B0=E5=92=8C=E5=BF=AB=E6=8D=B7=E9=94=AE=E5=88=87=E6=8D=A2?= =?UTF-8?q?=E7=95=8C=E9=9D=A2=E8=AF=AD=E8=A8=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cmd/ait/ait.go | 12 ++++++++++++ internal/config/config.go | 1 + internal/i18n/i18n.go | 9 +++++++++ internal/tui/model.go | 29 +++++++++++++++++++++++++++++ internal/tui/pages/help.go | 1 + internal/tui/pages/layout.go | 4 ++-- 6 files changed, 54 insertions(+), 2 deletions(-) diff --git a/cmd/ait/ait.go b/cmd/ait/ait.go index 587fb00..2b4fdb8 100644 --- a/cmd/ait/ait.go +++ b/cmd/ait/ait.go @@ -10,6 +10,8 @@ import ( "github.com/yinxulai/ait/internal/server" "github.com/yinxulai/ait/internal/tui" "github.com/yinxulai/ait/internal/types" + "github.com/yinxulai/ait/internal/config" + "github.com/yinxulai/ait/internal/i18n" ) // 版本信息,通过 ldflags 在构建时注入。 @@ -36,6 +38,7 @@ func main() { count := flag.Int("count", 100, "请求总数") timeout := flag.Int("timeout", 300, "请求超时时间(秒)") turboFlag := flag.Bool("turbo", false, "是否启用 Turbo 并发探测模式") + langFlag := flag.String("lang", "", "界面语言:zh 或 en") flag.Parse() // ── 版本输出 ────────────────────────────────────────────────────────────── @@ -53,6 +56,15 @@ func main() { os.Exit(1) } + // ── 初始化界面语言(flag > 配置文件 > 默认 ZH)──────────────────────────── + if *langFlag == "en" { + i18n.SetLang(i18n.EN) + } else if *langFlag == "zh" { + i18n.SetLang(i18n.ZH) + } else if cfg, err := config.Load(); err == nil && cfg.Lang == "en" { + i18n.SetLang(i18n.EN) + } + // ── 若提供了足够参数则预建任务并自动启动 ──────────────────────────────────── // 合并 --model 和 --models,去重,保持顺序 var modelList []string diff --git a/internal/config/config.go b/internal/config/config.go index c9940fe..8d579f7 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -22,6 +22,7 @@ type Config struct { LastSelectedTaskID string `json:"last_selected_task_id,omitempty"` DefaultProtocol string `json:"default_protocol,omitempty"` ProxyURL string `json:"proxy_url,omitempty"` + Lang string `json:"lang,omitempty"` // "zh" or "en", empty = zh } func Load() (*Config, error) { diff --git a/internal/i18n/i18n.go b/internal/i18n/i18n.go index 70044b8..aee7cb5 100644 --- a/internal/i18n/i18n.go +++ b/internal/i18n/i18n.go @@ -241,6 +241,8 @@ const ( KHelpDescQuestionMark KHelpTermBack KHelpDescBack + KHelpTermLangToggle + KHelpDescLangToggle KHelpTermSelectTask KHelpDescSelectTask KHelpTermEnterDetail @@ -357,6 +359,7 @@ const ( KMinutesAgoFmt // "%d 分钟前" KHoursAgoFmt // "%d 小时前" KDaysAgoFmt // "%d 天前" + KToggleLang // "切换语言" / "Toggle Lang" ) var translations = [2]map[Key]string{ @@ -573,6 +576,8 @@ var translations = [2]map[Key]string{ KHelpDescQuestionMark: "打开此帮助页。", KHelpTermBack: "b / Esc", KHelpDescBack: "返回上一页。", + KHelpTermLangToggle: "F2", + KHelpDescLangToggle: "切换界面语言(中文 / 英文)。", KHelpTermSelectTask: "↑↓ / j k", KHelpDescSelectTask: "选择任务。", @@ -693,6 +698,7 @@ var translations = [2]map[Key]string{ KMinutesAgoFmt: "%d 分钟前", KHoursAgoFmt: "%d 小时前", KDaysAgoFmt: "%d 天前", + KToggleLang: "切换语言", }, EN: { // Hotkeys @@ -907,6 +913,8 @@ var translations = [2]map[Key]string{ KHelpDescQuestionMark: "Open this help page.", KHelpTermBack: "b / Esc", KHelpDescBack: "Go back to the previous page.", + KHelpTermLangToggle: "F2", + KHelpDescLangToggle: "Switch UI language (ZH / EN).", KHelpTermSelectTask: "↑↓ / j k", KHelpDescSelectTask: "Select a task.", @@ -1027,6 +1035,7 @@ var translations = [2]map[Key]string{ KMinutesAgoFmt: "%d min ago", KHoursAgoFmt: "%d hr ago", KDaysAgoFmt: "%d days ago", + KToggleLang: "Toggle Lang", }, } diff --git a/internal/tui/model.go b/internal/tui/model.go index 68f1d22..1f60654 100644 --- a/internal/tui/model.go +++ b/internal/tui/model.go @@ -6,6 +6,8 @@ import ( "fmt" tea "github.com/charmbracelet/bubbletea" + "github.com/yinxulai/ait/internal/config" + "github.com/yinxulai/ait/internal/i18n" "github.com/yinxulai/ait/internal/server" "github.com/yinxulai/ait/internal/tui/pages" "github.com/yinxulai/ait/internal/types" @@ -276,6 +278,16 @@ func (m *Model) handleKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) { m.status = "" m.err = nil + // ── 全局快捷键(所有页面层共享)── + if msg.String() == "F2" { + if i18n.Active() == i18n.ZH { + i18n.SetLang(i18n.EN) + } else { + i18n.SetLang(i18n.ZH) + } + return m, saveLangConfigCmd(i18n.Active()) + } + switch m.view { case viewTaskList: newState, cmd, nav := pages.HandleTaskListKey(m.taskList, msg, m.client) @@ -444,6 +456,23 @@ func (m *Model) currentNavTarget() pages.NavTarget { // ─── Server 事件处理 ────────────────────────────────────────────────────────── +// saveLangConfigCmd 将语言设置异步保存到配置文件(尽力而为)。 +func saveLangConfigCmd(lang i18n.Lang) tea.Cmd { + return func() tea.Msg { + cfg, err := config.Load() + if err != nil { + cfg = &config.Config{} + } + if lang == i18n.EN { + cfg.Lang = "en" + } else { + cfg.Lang = "zh" + } + _ = cfg.Save() + return nil + } +} + func (m *Model) handleServerEvent(msg ServerEventMsg) (tea.Model, tea.Cmd) { e := msg.Event diff --git a/internal/tui/pages/help.go b/internal/tui/pages/help.go index 526e967..eb8e7d5 100644 --- a/internal/tui/pages/help.go +++ b/internal/tui/pages/help.go @@ -141,6 +141,7 @@ func helpContent() []helpSection { items: []helpItem{ {i18n.T(i18n.KHelpTermQuit), i18n.T(i18n.KHelpDescQuit)}, {i18n.T(i18n.KHelpTermQuestionMark), i18n.T(i18n.KHelpDescQuestionMark)}, + {i18n.T(i18n.KHelpTermLangToggle), i18n.T(i18n.KHelpDescLangToggle)}, {i18n.T(i18n.KHelpTermBack), i18n.T(i18n.KHelpDescBack)}, }, }, diff --git a/internal/tui/pages/layout.go b/internal/tui/pages/layout.go index c201df7..3dc642f 100644 --- a/internal/tui/pages/layout.go +++ b/internal/tui/pages/layout.go @@ -50,10 +50,10 @@ func NewPageHotkeys(hotkeys []HotkeyItem, hints ...string) PageHotkeys { } } -// NewPageHotkeysWithHelp 在 NewPageHotkeys 基础上自动追加 [?] 帮助。 +// NewPageHotkeysWithHelp 在 NewPageHotkeys 基础上自动追加 [?] 帮助和 [F2] 切换语言。 // 非帮助页使用此函数以统一显示帮助入口。 func NewPageHotkeysWithHelp(hotkeys []HotkeyItem, hints ...string) PageHotkeys { - withHelp := append(hotkeys, HotkeyAction("?", i18n.T(i18n.KHelp))) + withHelp := append(hotkeys, HotkeyAction("F2", i18n.T(i18n.KToggleLang)), HotkeyAction("?", i18n.T(i18n.KHelp))) return PageHotkeys{ Hotkeys: withHelp, Hints: HotkeyTexts(hints...), From bd8b4d33612f895848a3acac1c9888f92974164d Mon Sep 17 00:00:00 2001 From: Alain Date: Sun, 24 May 2026 23:19:17 +0800 Subject: [PATCH 49/52] =?UTF-8?q?feat(i18n):=20=E6=9B=B4=E6=96=B0=E5=BF=AB?= =?UTF-8?q?=E6=8D=B7=E9=94=AE=E6=8F=90=E7=A4=BA=EF=BC=8C=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E7=94=A8=E6=88=B7=E7=95=8C=E9=9D=A2=E4=BA=A4=E4=BA=92=E4=BD=93?= =?UTF-8?q?=E9=AA=8C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/i18n/i18n.go | 8 ++++---- internal/tui/model.go | 2 +- internal/tui/pages/layout.go | 2 +- internal/tui/pages/tasklist.go | 2 +- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/internal/i18n/i18n.go b/internal/i18n/i18n.go index aee7cb5..854df13 100644 --- a/internal/i18n/i18n.go +++ b/internal/i18n/i18n.go @@ -411,8 +411,8 @@ var translations = [2]map[Key]string{ KHintBack: "[b/Esc] 返回", KHintEscBack: "[Esc] 返回", KHintGoBack: "[b/Esc] 返回上一页", - KHintSelect: "[↑↓] 选择", - KHintNew: "[a] 新建", + KHintSelect: "[↑↓] 上下切换", + KHintNew: "[a] 创建任务", // Metric labels KSuccessRate: "成功率", @@ -748,8 +748,8 @@ var translations = [2]map[Key]string{ KHintBack: "[b/Esc] Back", KHintEscBack: "[Esc] Back", KHintGoBack: "[b/Esc] Go Back", - KHintSelect: "[↑↓] Select", - KHintNew: "[a] New", + KHintSelect: "[↑↓] Navigate", + KHintNew: "[a] New Task", // Metric labels KSuccessRate: "Success Rate", diff --git a/internal/tui/model.go b/internal/tui/model.go index 1f60654..edeb12b 100644 --- a/internal/tui/model.go +++ b/internal/tui/model.go @@ -279,7 +279,7 @@ func (m *Model) handleKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) { m.err = nil // ── 全局快捷键(所有页面层共享)── - if msg.String() == "F2" { + if msg.String() == "f2" { if i18n.Active() == i18n.ZH { i18n.SetLang(i18n.EN) } else { diff --git a/internal/tui/pages/layout.go b/internal/tui/pages/layout.go index 3dc642f..446862b 100644 --- a/internal/tui/pages/layout.go +++ b/internal/tui/pages/layout.go @@ -53,7 +53,7 @@ func NewPageHotkeys(hotkeys []HotkeyItem, hints ...string) PageHotkeys { // NewPageHotkeysWithHelp 在 NewPageHotkeys 基础上自动追加 [?] 帮助和 [F2] 切换语言。 // 非帮助页使用此函数以统一显示帮助入口。 func NewPageHotkeysWithHelp(hotkeys []HotkeyItem, hints ...string) PageHotkeys { - withHelp := append(hotkeys, HotkeyAction("F2", i18n.T(i18n.KToggleLang)), HotkeyAction("?", i18n.T(i18n.KHelp))) + withHelp := append(hotkeys, HotkeyAction("f2", i18n.T(i18n.KToggleLang)), HotkeyAction("?", i18n.T(i18n.KHelp))) return PageHotkeys{ Hotkeys: withHelp, Hints: HotkeyTexts(hints...), diff --git a/internal/tui/pages/tasklist.go b/internal/tui/pages/tasklist.go index 8e0eb6b..49927b0 100644 --- a/internal/tui/pages/tasklist.go +++ b/internal/tui/pages/tasklist.go @@ -181,7 +181,7 @@ func RenderTaskList(s *TaskListState, st Styles, width, height int) string { HeaderMeta: fmt.Sprintf("%d", len(s.Tasks)), HeaderInfoLeft: []string{fmt.Sprintf("%s %d", i18n.T(i18n.KRunning), runningCount)}, HeaderInfoRight: headerRight, - Hotkeys: NewPageHotkeysWithHelp(cbItems, "[↑↓]", "[a]", i18n.T(i18n.KHintQuit)), + Hotkeys: NewPageHotkeysWithHelp(cbItems, i18n.T(i18n.KHintSelect), i18n.T(i18n.KHintNew), i18n.T(i18n.KHintQuit)), } frame := l.Frame(width, height) panel := NewPanelFrame(frame.OuterWidth) From 94011946e30f5b8ad32e7b1d5581628f60a69b17 Mon Sep 17 00:00:00 2001 From: Alain Date: Sun, 24 May 2026 23:25:50 +0800 Subject: [PATCH 50/52] =?UTF-8?q?feat(i18n):=20=E5=8A=A8=E6=80=81=E8=B0=83?= =?UTF-8?q?=E6=95=B4=E4=BB=BB=E5=8A=A1=E8=AF=A6=E6=83=85=E5=92=8C=E4=BB=BB?= =?UTF-8?q?=E5=8A=A1=E5=88=97=E8=A1=A8=E7=9A=84=E5=88=97=E5=AE=BD=EF=BC=8C?= =?UTF-8?q?=E7=A1=AE=E4=BF=9D=E5=A4=9A=E8=AF=AD=E8=A8=80=E6=94=AF=E6=8C=81?= =?UTF-8?q?=E4=B8=8B=E7=9A=84=E7=95=8C=E9=9D=A2=E9=80=82=E9=85=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/tui/pages/taskdetail.go | 19 +++++++++++++++++-- internal/tui/pages/tasklist.go | 24 ++++++++++++++++++++++-- 2 files changed, 39 insertions(+), 4 deletions(-) diff --git a/internal/tui/pages/taskdetail.go b/internal/tui/pages/taskdetail.go index df63e15..9e3f097 100644 --- a/internal/tui/pages/taskdetail.go +++ b/internal/tui/pages/taskdetail.go @@ -376,11 +376,26 @@ func buildTaskDetailContent(s *TaskDetailState, st Styles, t types.TaskDefinitio } // colWidths: 0 = 弹性列,>0 = 固定总宽 - colWidths := []int{4, 0, 7, 8, 7, 7, 7, 6, 6} // 状态图标, 时间=flex, 模式, 成功率, 耗时, TTFT, TPS, RPM, TPM + // 动态列宽:取数据最小需求与表头显示宽+2的较大值,确保切换语言后不溢出 + hw := func(s string) int { return lipgloss.Width(s) + 2 } + h3 := i18n.T(i18n.KMode) + h4 := i18n.T(i18n.KSuccessRate) + h5 := i18n.T(i18n.KElapsed) + colWidths := []int{ + 4, // 状态图标 + 0, // 时间=flex + maxInt(7, hw(h3)), // 模式 + maxInt(8, hw(h4)), // 成功率 + maxInt(7, hw(h5)), // 耗时 + maxInt(7, hw("TTFT")), // TTFT + maxInt(7, hw("TPS")), // TPS + maxInt(6, hw("RPM")), // RPM + maxInt(6, hw("TPM")), // TPM + } sel := s.HistorySel tableH := tableMaxH - len(rightTitle) tbl := lgtable.New(). - Headers("", i18n.T(i18n.KTime), i18n.T(i18n.KMode), i18n.T(i18n.KSuccessRate), i18n.T(i18n.KElapsed), "TTFT", "TPS", "RPM", "TPM"). + Headers("", i18n.T(i18n.KTime), h3, h4, h5, "TTFT", "TPS", "RPM", "TPM"). Width(rightW). Height(tableH). YOffset(s.HistoryOff). diff --git a/internal/tui/pages/tasklist.go b/internal/tui/pages/tasklist.go index 49927b0..7d3a181 100644 --- a/internal/tui/pages/tasklist.go +++ b/internal/tui/pages/tasklist.go @@ -300,9 +300,29 @@ func buildTaskListContent(s *TaskListState, st Styles, width, maxH int) string { // ── 构建 lipgloss/table ── // colWidths: 0 = 弹性列(占用剩余宽度),>0 = 固定总宽(包括两端各 1 字符 padding) - colWidths := []int{0, 8, 22, 12, 8, 10, 10, 10, 8, 8} // 任务名称=flex, 模式, 协议, 上次运行, 成功率, 缓存命中, TTFT均值, TPS均值, RPM, TPM + // 动态列宽:取数据最小需求与表头显示宽+2的较大值,确保切换语言后不溢出 + hw := func(s string) int { return lipgloss.Width(s) + 2 } + h1 := i18n.T(i18n.KMode) + h2 := i18n.T(i18n.KProtocol) + h3 := i18n.T(i18n.KLastRun) + h4 := i18n.T(i18n.KSuccessRate) + h5 := i18n.T(i18n.KColCacheHit) + h6 := i18n.T(i18n.KColAvgTTFT) + h7 := i18n.T(i18n.KColAvgTPS) + colWidths := []int{ + 0, // 任务名称=flex + maxInt(8, hw(h1)), // 模式 + maxInt(22, hw(h2)), // 协议(数据可能较长) + maxInt(12, hw(h3)), // 上次运行 + maxInt(8, hw(h4)), // 成功率 + maxInt(10, hw(h5)), // 缓存命中 + maxInt(10, hw(h6)), // TTFT均值 + maxInt(10, hw(h7)), // TPS均值 + maxInt(8, hw("RPM")), // RPM + maxInt(8, hw("TPM")), // TPM + } t := lgtable.New(). - Headers(i18n.T(i18n.KTaskName), i18n.T(i18n.KMode), i18n.T(i18n.KProtocol), i18n.T(i18n.KLastRun), i18n.T(i18n.KSuccessRate), i18n.T(i18n.KColCacheHit), i18n.T(i18n.KColAvgTTFT), i18n.T(i18n.KColAvgTPS), "RPM", "TPM"). + Headers(i18n.T(i18n.KTaskName), h1, h2, h3, h4, h5, h6, h7, "RPM", "TPM"). Width(width). Height(maxH). YOffset(s.Offset). From 04cae00064a26565830c8d61eb3b799a6b4d53b8 Mon Sep 17 00:00:00 2001 From: yinxulai Date: Mon, 25 May 2026 13:05:40 +0800 Subject: [PATCH 51/52] =?UTF-8?q?feat:=20=E6=9B=B4=E6=96=B0=E6=96=87?= =?UTF-8?q?=E6=A1=A3=E6=A0=BC=E5=BC=8F=EF=BC=8C=E4=BC=98=E5=8C=96=E5=91=BD?= =?UTF-8?q?=E4=BB=A4=E8=A1=8C=E5=8F=82=E6=95=B0=E5=92=8C=E7=A4=BA=E4=BE=8B?= =?UTF-8?q?=E8=BE=93=E5=87=BA=EF=BC=8C=E5=A2=9E=E5=BC=BA=E5=8F=AF=E8=AF=BB?= =?UTF-8?q?=E6=80=A7=20fix:=20=E4=BF=AE=E5=A4=8D=E9=94=99=E8=AF=AF?= =?UTF-8?q?=E4=BF=A1=E6=81=AF=E6=A0=BC=E5=BC=8F=E5=8C=96=EF=BC=8C=E7=A1=AE?= =?UTF-8?q?=E4=BF=9D=E9=94=99=E8=AF=AF=E4=BF=A1=E6=81=AF=E8=BE=93=E5=87=BA?= =?UTF-8?q?=E4=B8=80=E8=87=B4=E6=80=A7=20fix:=20=E4=BF=AE=E6=AD=A3?= =?UTF-8?q?=E7=BC=93=E5=AD=98=E5=91=BD=E4=B8=AD=E7=8E=87=E8=AE=A1=E7=AE=97?= =?UTF-8?q?=E9=80=BB=E8=BE=91=EF=BC=8C=E9=81=BF=E5=85=8D=E9=99=A4=E4=BB=A5?= =?UTF-8?q?=E9=9B=B6=E9=94=99=E8=AF=AF=20feat:=20=E6=B7=BB=E5=8A=A0=20Mock?= =?UTF-8?q?Client=20=E7=9A=84=20RawRequest=20=E6=96=B9=E6=B3=95=EF=BC=8C?= =?UTF-8?q?=E6=94=AF=E6=8C=81=E5=8E=9F=E5=A7=8B=20JSON=20=E8=AF=B7?= =?UTF-8?q?=E6=B1=82=E4=BD=93=20test:=20=E6=9B=B4=E6=96=B0=E5=B8=83?= =?UTF-8?q?=E5=B1=80=E6=B5=8B=E8=AF=95=EF=BC=8C=E5=A2=9E=E5=BC=BA=E5=AF=B9?= =?UTF-8?q?=E8=BE=93=E5=87=BA=E5=86=85=E5=AE=B9=E7=9A=84=E9=AA=8C=E8=AF=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 55 ++++++++++++++++--------------- internal/client/anthropic.go | 2 +- internal/client/openai.go | 4 +-- internal/runner/runner.go | 5 ++- internal/runner/runner_test.go | 9 +++++ internal/tui/pages/layout_test.go | 2 +- 6 files changed, 45 insertions(+), 32 deletions(-) diff --git a/README.md b/README.md index 0496aac..d39a97c 100644 --- a/README.md +++ b/README.md @@ -306,26 +306,27 @@ ait --models=gpt-4,claude-3-sonnet --prompt-file="test_prompts/*.txt" --count=20 ## 📋 命令行参数 -| 参数 | 描述 | 默认值 | 必填 | -|:---------------|:-------------------------------------------------------------|:--------------------------|:----:| -| `--version` | 显示版本信息(包括 Git Commit 和构建时间) | - | ❌ | -| `--protocol` | 协议类型 (`openai`/`anthropic`) | 根据环境变量自动推断 | ❌ | -| `--baseUrl` | 服务地址
支持环境变量:`OPENAI_BASE_URL` 或 `ANTHROPIC_BASE_URL` | - | ✅ | -| `--apiKey` | API 密钥
支持环境变量:`OPENAI_API_KEY` 或 `ANTHROPIC_API_KEY` | - | ✅ | -| `--model` | 单个模型名称
如:`gpt-4`(不支持多个模型) | - | ❌ | -| `--models` | 模型名称,支持多个模型用逗号分割
如:`gpt-4,claude-3-sonnet` | - | ✅ | -| `--concurrency`| 并发数 | `3` | ❌ | -| `--count` | 请求总数 | `10` | ❌ | -| `--timeout` | 请求超时时间(秒) | `300` | ❌ | -| `--prompt` | 测试提示语(直接输入字符串)
如:`"分析人工智能的发展前景"` | `"你好,介绍一下你自己。"` | ❌ | -| `--prompt-file`| 从文件读取 prompt
**支持多种模式**:
• 单文件:`"prompts/test.txt"`
• 通配符:`"prompts/*.txt"`
• 相对/绝对路径均可 | - | ❌ | -| `--prompt-length`| 生成指定字符长度的测试 prompt
**快速测试功能**:无需准备文件即可生成测试内容
• 优先级高于其他 prompt 参数
• 生成有意义的中文文本片段 | `0`(不启用) | ❌ | -| `--stream` | 是否开启流模式 | `true` | ❌ | -| `--thinking` | 是否开启思考模式(仅 OpenAI 协议支持) | `false` | ❌ | -| `--log` | 是否开启详细日志记录 | `false` | ❌ | -| `--report` | 是否生成报告文件(同时生成 JSON 和 CSV) | `false` | ❌ | +| 参数 | 描述 | 默认值 | 必填 | +| :--- | :--- | :--- | :---: | +| `--version` | 显示版本信息(包括 Git Commit 和构建时间) | - | ❌ | +| `--protocol` | 协议类型 (`openai`/`anthropic`) | 根据环境变量自动推断 | ❌ | +| `--baseUrl` | 服务地址
支持环境变量:`OPENAI_BASE_URL` 或 `ANTHROPIC_BASE_URL` | - | ✅ | +| `--apiKey` | API 密钥
支持环境变量:`OPENAI_API_KEY` 或 `ANTHROPIC_API_KEY` | - | ✅ | +| `--model` | 单个模型名称
如:`gpt-4`(不支持多个模型) | - | ❌ | +| `--models` | 模型名称,支持多个模型用逗号分割
如:`gpt-4,claude-3-sonnet` | - | ✅ | +| `--concurrency` | 并发数 | `3` | ❌ | +| `--count` | 请求总数 | `10` | ❌ | +| `--timeout` | 请求超时时间(秒) | `300` | ❌ | +| `--prompt` | 测试提示语(直接输入字符串)
如:`"分析人工智能的发展前景"` | `"你好,介绍一下你自己。"` | ❌ | +| `--prompt-file` | 从文件读取 prompt
**支持多种模式**:
• 单文件:`"prompts/test.txt"`
• 通配符:`"prompts/*.txt"`
• 相对/绝对路径均可 | - | ❌ | +| `--prompt-length` | 生成指定字符长度的测试 prompt
**快速测试功能**:无需准备文件即可生成测试内容
• 优先级高于其他 prompt 参数
• 生成有意义的中文文本片段 | `0`(不启用) | ❌ | +| `--stream` | 是否开启流模式 | `true` | ❌ | +| `--thinking` | 是否开启思考模式(仅 OpenAI 协议支持) | `false` | ❌ | +| `--log` | 是否开启详细日志记录 | `false` | ❌ | +| `--report` | 是否生成报告文件(同时生成 JSON 和 CSV) | `false` | ❌ | **注意**: + - `--model` 和 `--models` 不能同时使用。使用 `--model` 测试单个模型,使用 `--models` 测试多个模型 - prompt 参数优先级:`--prompt-length` > `--prompt-file` > `--prompt` > 管道输入 > 默认值 @@ -343,7 +344,7 @@ ait --models=gpt-4,claude-3-sonnet --prompt-file="test_prompts/*.txt" --count=20 ### 单模型详细报告示例 -``` +```text ┌──────────────────┬──────────┬──────────┬──────────┬──────────┬────────┬────────────────────────────┐ │ 指标 │ 最小值 │ 平均值 │ 标准差 │ 最大值 │ 单位 │ 采样方式说明 │ ├──────────────────┼──────────┼──────────┼──────────┼──────────┼────────┼────────────────────────────┤ @@ -356,7 +357,7 @@ ait --models=gpt-4,claude-3-sonnet --prompt-file="test_prompts/*.txt" --count=20 ### 多模型对比报告示例 -``` +```text ┌────────────────┬──────────┬────────┬──────────┬────────────┬─────────────┬────────────────┬────────────────┬──────────────────┐ │ 🤖 模型 │ 📊 请求数│ ⚡ 并发│ ✅ 成功率│ 🕐 平均总耗时│ ⚡ 平均 TTFT │ 🚀 平均输出 TPS│ 🌐 平均吞吐 TPS│ 🎲 平均输出Token数│ ├────────────────┼──────────┼────────┼──────────┼────────────┼─────────────┼────────────────┼────────────────┼──────────────────┤ @@ -581,13 +582,13 @@ ait \ ### tpg 参数说明 -| 参数 | 描述 | 默认值 | -|:-----------|:------------------------------------------|:-----------| -| `-count` | 生成的 prompt 数量 | `10` | -| `-length` | 每个 prompt 的近似长度(字符数) | `50` | -| `-output` | 输出目录 | `prompts` | -| `-template`| 模板字符串,支持占位符 | 无 | -| `-help` | 显示帮助信息 | - | +| 参数 | 描述 | 默认值 | +| :--- | :--- | :--- | +| `-count` | 生成的 prompt 数量 | `10` | +| `-length` | 每个 prompt 的近似长度(字符数) | `50` | +| `-output` | 输出目录 | `prompts` | +| `-template` | 模板字符串,支持占位符 | 无 | +| `-help` | 显示帮助信息 | - | ### 模板占位符 diff --git a/internal/client/anthropic.go b/internal/client/anthropic.go index a94a2d9..4bb50cd 100644 --- a/internal/client/anthropic.go +++ b/internal/client/anthropic.go @@ -355,7 +355,7 @@ func (c *AnthropicClient) doRequest(reqBodyBytes []byte, stream bool) (*Response RequestBody: string(reqBodyBytes), ResponseBody: responseBody, ErrorMessage: errorMessage, - }, fmt.Errorf(errorMessage) + }, fmt.Errorf("%s", errorMessage) } if stream { diff --git a/internal/client/openai.go b/internal/client/openai.go index cd5994b..b24141a 100644 --- a/internal/client/openai.go +++ b/internal/client/openai.go @@ -567,7 +567,7 @@ func (c *OpenAIClient) doRequest(jsonData []byte, stream bool) (*ResponseMetrics RequestBody: string(jsonData), ResponseBody: responseBody, ErrorMessage: errorMessage, - }, fmt.Errorf(errorMessage) + }, fmt.Errorf("%s", errorMessage) } if c.Provider == types.ProtocolOpenAIResponses { @@ -726,7 +726,7 @@ func (c *OpenAIClient) doRequest(jsonData []byte, stream bool) (*ResponseMetrics RequestBody: string(jsonData), ResponseBody: string(responseData), ErrorMessage: errorMessage, - }, fmt.Errorf(errorMessage) + }, fmt.Errorf("%s", errorMessage) } responseData, err := io.ReadAll(resp.Body) diff --git a/internal/runner/runner.go b/internal/runner/runner.go index e7fdfea..82e0ad8 100644 --- a/internal/runner/runner.go +++ b/internal/runner/runner.go @@ -65,7 +65,10 @@ func calculateCacheHitRate(metrics *client.ResponseMetrics) float64 { if metrics == nil || metrics.CachedInputTokens <= 0 { return 0 } - return 1 + if metrics.PromptTokens <= 0 { + return 0 + } + return float64(metrics.CachedInputTokens) / float64(metrics.PromptTokens) } // Run 执行性能测试,返回结果数据 diff --git a/internal/runner/runner_test.go b/internal/runner/runner_test.go index 1dd42d7..b472ec9 100644 --- a/internal/runner/runner_test.go +++ b/internal/runner/runner_test.go @@ -92,6 +92,11 @@ func (m *MockClient) ResetCallCount() { atomic.StoreInt64(&m.callCount, 0) } +// RawRequest 使用原始 JSON 请求体发送请求(mock 实现) +func (m *MockClient) RawRequest(rawBody string) (*client.ResponseMetrics, error) { + return m.Request("", rawBody, false) +} + // SetLogger 设置日志记录器 func (m *MockClient) SetLogger(logger *logger.Logger) { // MockClient 不需要实际的日志记录器,所以这里是空实现 @@ -1546,3 +1551,7 @@ func (m *MockClientWithErrorMetrics) GetModel() string { func (m *MockClientWithErrorMetrics) SetLogger(logger *logger.Logger) { // Mock实现,不需要实际功能 } + +func (m *MockClientWithErrorMetrics) RawRequest(rawBody string) (*client.ResponseMetrics, error) { + return m.Request("", rawBody, false) +} diff --git a/internal/tui/pages/layout_test.go b/internal/tui/pages/layout_test.go index 34f2253..298f326 100644 --- a/internal/tui/pages/layout_test.go +++ b/internal/tui/pages/layout_test.go @@ -23,7 +23,7 @@ func TestPageLayoutAssembleRendersSharedChrome(t *testing.T) { if len(lines) < 6 { t.Fatalf("expected shared chrome to add header/hotkeys lines, got %d lines", len(lines)) } - if !strings.Contains(rendered, "AIT") || !strings.Contains(rendered, "任务中心") { + if (!strings.Contains(rendered, "AIT") && !strings.Contains(rendered, "████")) || !strings.Contains(rendered, "任务中心") { t.Fatalf("expected header brand/title in output: %q", rendered) } if !strings.Contains(rendered, "创建任务、运行压测") { From 06ae14febc7ab2f0c00f5c0ede7fe3ab4a687511 Mon Sep 17 00:00:00 2001 From: yinxulai Date: Mon, 25 May 2026 13:16:51 +0800 Subject: [PATCH 52/52] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=20Go=20?= =?UTF-8?q?=E7=8E=AF=E5=A2=83=E8=AE=BE=E7=BD=AE=EF=BC=8C=E4=BD=BF=E7=94=A8?= =?UTF-8?q?=20go.mod=20=E6=96=87=E4=BB=B6=E6=8C=87=E5=AE=9A=20Go=20?= =?UTF-8?q?=E7=89=88=E6=9C=AC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .github/workflows/test.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 417323b..0030d47 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -12,7 +12,7 @@ jobs: - name: 设置 Go 环境 uses: actions/setup-go@v4 with: - go-version: '1.22' + go-version-file: 'go.mod' - name: 运行测试并生成覆盖率报告 run: go test -v -race -coverprofile=coverage.txt ./...