From ce6804c97e2b9420236a7ff36370028db9ab5cad Mon Sep 17 00:00:00 2001 From: liut Date: Thu, 26 Mar 2026 15:54:07 +0800 Subject: [PATCH] feat: add platform adapter layer for WeCom and Feishu integration MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduce a comprehensive platform adapter layer enabling AI chat integration with enterprise messaging platforms. This implementation supports both WeChat Work (WeCom) and Feishu through WebSocket and HTTP webhook modes. Core architecture includes: - Unified channel interface (Channel) with platform-agnostic message handling - Registry pattern for dynamic channel registration and lifecycle management - Message deduplication using Redis with 60-second TTL - Session-to-conversation mapping for persistent chat history across reconnections Platform implementations: - WeCom: WebSocket long-connection and HTTP webhook modes with AES decryption - Feishu: WebSocket and HTTP webhook with signature verification Supporting infrastructure: - Extracted ToolExecutor to eliminate duplicate tool call loop logic between handlers - Added allowlist filtering for authorized user access control - Conversation binding via Redis mapping (platform:csid:{sessionKey} → OID) - Unified message format supporting text, images, audio, and file attachments Configuration via preset.yaml with multi-instance support per platform type. 
--- data/preset.example.yaml | 43 ++ ...-25-001-feat-channel-adapter-layer-plan.md | 259 +++++++++ ...26-03-26-001-fix-wecom-session-key-plan.md | 181 ++++++ ...02-refactor-handle-duplicate-logic-plan.md | 237 ++++++++ .../executeToolCallLoop-deduplication.md | 147 +++++ go.mod | 3 + go.sum | 27 + main.go | 4 + pkg/models/aigc/preset.go | 10 + pkg/models/channel/channel.go | 50 ++ pkg/models/channel/message.go | 39 ++ pkg/services/channels/allowlist.go | 21 + pkg/services/channels/dedup.go | 80 +++ pkg/services/channels/feishu/feishu.go | 25 + pkg/services/channels/feishu/webhook.go | 533 ++++++++++++++++++ pkg/services/channels/feishu/websocket.go | 479 ++++++++++++++++ pkg/services/channels/registry.go | 82 +++ pkg/services/channels/wecom/websocket.go | 521 +++++++++++++++++ pkg/services/channels/wecom/wecom.go | 25 + pkg/services/channels/wecom/wecom_http.go | 514 +++++++++++++++++ pkg/services/channels/wecom/wecom_msg.go | 28 + pkg/services/stores/conversation.go | 18 + pkg/web/api/api.go | 15 + pkg/web/api/handle_convo.go | 59 +- pkg/web/api/handle_platform.go | 194 +++++++ pkg/web/api/tool_executor.go | 83 +++ 26 files changed, 3619 insertions(+), 58 deletions(-) create mode 100644 docs/plans/2026-03-25-001-feat-channel-adapter-layer-plan.md create mode 100644 docs/plans/2026-03-26-001-fix-wecom-session-key-plan.md create mode 100644 docs/plans/2026-03-26-002-refactor-handle-duplicate-logic-plan.md create mode 100644 docs/solutions/logic-errors/executeToolCallLoop-deduplication.md create mode 100644 pkg/models/channel/channel.go create mode 100644 pkg/models/channel/message.go create mode 100644 pkg/services/channels/allowlist.go create mode 100644 pkg/services/channels/dedup.go create mode 100644 pkg/services/channels/feishu/feishu.go create mode 100644 pkg/services/channels/feishu/webhook.go create mode 100644 pkg/services/channels/feishu/websocket.go create mode 100644 pkg/services/channels/registry.go create mode 100644 
pkg/services/channels/wecom/websocket.go create mode 100644 pkg/services/channels/wecom/wecom.go create mode 100644 pkg/services/channels/wecom/wecom_http.go create mode 100644 pkg/services/channels/wecom/wecom_msg.go create mode 100644 pkg/web/api/handle_platform.go create mode 100644 pkg/web/api/tool_executor.go diff --git a/data/preset.example.yaml b/data/preset.example.yaml index c2dcc80..84dd5b9 100644 --- a/data/preset.example.yaml +++ b/data/preset.example.yaml @@ -23,3 +23,46 @@ tools: kb_search: "在知识库中搜索相关内容。当遇到未知或不确定的问题时,优先查阅知识库。" kb_create: "创建新的知识库文档,所有参数必填。注意:除非用户明确要求补充内容,否则不要调用。" fetch: "从互联网获取 URL 内容并可选地提取为 markdown 格式" + +# 平台适配器配置(支持多实例) +channels: + # WeCom WebSocket 长连接模式 + wecom: + enable: true + mode: websocket + config: + bot_id: "YOUR_BOT_ID" + bot_secret: "YOUR_BOT_SECRET" + allow_from: "" # 可选,限制来源 UserID + + # WeCom Webhook 回调模式 + # wecom: + # enable: true + # mode: webhook + # config: + # corp_id: "YOUR_CORP_ID" + # corp_secret: "YOUR_CORP_SECRET" + # agent_id: "YOUR_AGENT_ID" + # callback_token: "YOUR_CALLBACK_TOKEN" + # callback_aes_key: "YOUR_CALLBACK_AES_KEY" + # callback_path: "/wecom/callback" + + # 飞书 WebSocket 长连接模式 + # feishu: + # enable: true + # mode: websocket + # config: + # app_id: "YOUR_APP_ID" + # app_secret: "YOUR_APP_SECRET" + # allow_from: "" # 可选,限制来源 UserID + + # 飞书 Webhook 回调模式 + # feishu: + # enable: true + # mode: webhook + # config: + # app_id: "YOUR_APP_ID" + # app_secret: "YOUR_APP_SECRET" + # encrypt_key: "YOUR_ENCRYPT_KEY" # 可选,加密密钥 + # callback_path: "/feishu/callback" + # allow_from: "" # 可选,限制来源 UserID diff --git a/docs/plans/2026-03-25-001-feat-channel-adapter-layer-plan.md b/docs/plans/2026-03-25-001-feat-channel-adapter-layer-plan.md new file mode 100644 index 0000000..58c0479 --- /dev/null +++ b/docs/plans/2026-03-25-001-feat-channel-adapter-layer-plan.md @@ -0,0 +1,259 @@ +--- +title: Add Platform Adapter Layer for Chat Platform Integration +type: feat +status: completed +date: 2026-03-25 +--- + +# Add 
Platform Adapter Layer for Chat Platform Integration + +## Overview + +在 morrigan 中引入平台适配器层,使 AI 对话能力可以对接微信企业版(WeCom)、飞书等聊天平台。以 WeCom 为突破点,验证架构设计后扩展至其他平台。 + +## Problem Statement + +当前 morrigan 只支持 HTTP API 方式的对话接入(通过前端)。需要支持将 AI 对话能力以 Bot 形式接入到企业常用的聊天平台(WeCom、飞书等),让用户可以在这些平台中直接与 AI 对话。 + +## Proposed Solution + +### 架构设计 + +``` +┌─────────────────────────────────────────────────────────────┐ +│ Chat Platforms │ +│ WeCom │ 飞书 │ 钉钉 │ Telegram ... │ +└──────┬──────┴─────┬─────┴─────┬────┴────────┬──────────────┘ + │ │ │ │ + ▼ ▼ ▼ ▼ +┌─────────────────────────────────────────────────────────────┐ +│ Channel Adapter Layer (pkg/channels/) │ +│ ┌─────────┐ ┌─────────┐ ┌─────────┐ ┌─────────┐ │ +│ │ wecom │ │ feishu │ │dingtalk │ │ ... │ (可扩展) │ +│ └────┬────┘ └────┬────┘ └────┬────┘ └────┬────┘ │ +│ │ │ │ │ │ +│ └────────────┴────────────┴────────────┘ │ +│ │ │ +│ ┌──────────┴──────────┐ │ +│ │ Platform Bridge │ (统一消息格式) │ +│ │ (pkg/channels/) │ │ +│ └──────────┬──────────┘ │ +└─────────────────────────┼───────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ Morrigan Core (Existing) │ +│ ┌─────────────┐ ┌─────────────┐ ┌─────────────────────┐ │ +│ │ handle_convo│ │ LLM │ │ tools/registry │ │ +│ │ (chat) │◄─┤ Client │◄─┤ (MCP tools) │ │ +│ └─────────────┘ └─────────────┘ └─────────────────────┘ │ +└─────────────────────────────────────────────────────────────┘ +``` + + +## Technical Considerations + +### 核心接口设计 + +参考 `cc-connect/core/interfaces.go`,定义以下接口: + +```go +// pkg/models/channel/channel.go + +// Channel 平台适配器必须实现的接口 +type Channel interface { + Name() string // 平台名称: "wecom", "feishu" + Start(handler MessageHandler) error // 启动平台连接 + Reply(ctx context.Context, replyCtx any, content string) error // 回复消息 + Send(ctx context.Context, replyCtx any, content string) error // 发送消息 + Stop() error // 停止平台连接 +} + +// MessageHandler 消息处理函数类型 +type MessageHandler func(p Channel, msg *Message) + +// 
ReplyContextReconstructor 可选接口:支持从 sessionKey 重建回复上下文 +type ReplyContextReconstructor interface { + ReconstructReplyCtx(sessionKey string) (any, error) +} + +// ImageSender 可选接口:支持发送图片 +type ImageSender interface { + SendImage(ctx context.Context, replyCtx any, img ImageAttachment) error +} + +// 统一消息结构 +type Message struct { + Channel string // 平台名称 + SessionKey string // 唯一标识: "{platform}:{chatID}:{userID}" + MessageID string // 平台原始消息ID (用于去重) + UserID string + UserName string + ChatName string + Content string // 消息内容 + Images []ImageAttachment + Files []FileAttachment + Audio *AudioAttachment + ReplyCtx any // 平台特定回复上下文 + FromVoice bool // 语音转文字 +} +``` + +### 平台注册机制 + +```go +// pkg/channels/registry.go + +type Registry struct { + channels map[string]Factory +} + +type PlatformFactory func(opts map[string]any) (Channel, error) + +// 全局注册表 +var registry *PlatformRegistry + +func RegisterPlatform(name string, factory PlatformFactory) +func NewPlatform(name string, opts map[string]any) (Platform, error) +``` + +每个平台在 `init()` 中注册: + +```go +// pkg/channels/wecom/wecom.go +func init() { + channels.RegisterPlatform("wecom", New) +} +``` + +### WeCom 适配器设计 + +**HTTP Webhook 模式:** +- 接收 GET 请求验证回调 URL +- 接收 POST 请求处理加密消息(XML + AES-256-CBC) +- 消息类型:文本、图片、语音 +- 回复:POST XML 消息 +- Access token 缓存(提前 60 秒刷新) + +**WebSocket 长连接模式:** +- 独立进程运行,通过 bridge 与主进程通信 +- 支持重连(指数退避:1s -> 30s 最大) +- 心跳保活(30s ping/pong) +- 流式响应支持 + +### 与 Morrigan Core 的集成 + +消息流程: + +1. WeCom 适配器接收消息 +2. 解析并构建统一 `Message` 结构 +3. 检查去重、过滤老消息、白名单 +4. 构建 sessionKey: `wecom:{chatID}:{userID}` +5. 调用 `MessageHandler` → 转发给现有 `handle_convo.go` 的对话处理逻辑 +6. AI 回复通过适配器的 `Reply()` 方法发送回平台 + +**配置扩展:** + +在现有 `settings` 中添加平台配置: + +```go +// pkg/settings/settings.go +type PlatformConfig struct { + Enable bool + Type string // "wecom", "feishu" + Config map[string]any +} +``` + +### 关键实现细节 + +1. **去重机制**:使用 Redis 缓存 MsgId,60 秒 TTL +2. **Token 缓存**:Access Token 缓存,提前刷新 +3. 
**消息分片**:大消息按 UTF-8 拆分(WeCom 限制 2000 字符) +4. **异步处理**:`go handler(p, msg)` 非阻塞处理 +5. **优雅关闭**:Context 取消和连接 draining + +## System-Wide Impact + +### Interaction Graph + +``` +WeCom Message Received + ↓ +[wecom.go: message handler] + ↓ +Parse XML + AES decrypt + ↓ +[bridge.go: dispatch] + ↓ +Check dedup (Redis) + ↓ +Check allowlist + ↓ +[handle_convo.go: postChat] + ↓ +[llm.Client: Chat] + ↓ +AI Response + ↓ +[wecom.go: Reply] + ↓ +POST XML to WeCom API +``` + +### 错误处理 + +- 平台 API 调用失败:重试 + 告警 +- LLM 调用失败:返回友好错误消息 +- 消息处理超时:平台一般有超时限制,考虑异步回复 + +## Acceptance Criteria + +- [ ] `pkg/channels/` 目录创建完成,核心接口定义完成 +- [ ] WeCom HTTP Webhook 适配器可接收消息并回复 +- [ ] 消息正确路由到现有 `handle_convo.go` 对话处理 +- [ ] 集成测试通过(WeCom 模拟消息) +- [ ] 配置可通过 `settings` 管理 +- [ ] 文档说明如何添加新平台 + +## Dependencies & Risks + +**依赖:** +- 现有 Redis 存储(去重、Token 缓存) +- 现有 LLM 客户端 +- 现有对话处理逻辑 + +**风险:** +- WeCom API 变更需同步更新 +- 消息加密/签名验证需严格实现 +- 平台限流需处理 + +## Implementation Phases + +### Phase 1: Core Platform Layer (基础设施) + +- 创建 `pkg/channels/` 目录结构 +- 实现 `channel.go` 核心接口 +- 实现 `registry.go` 平台注册表 +- 实现 `message.go` 统一消息结构 +- 实现 `dedup.go` 去重机制 + +### Phase 2: WeCom HTTP Adapter (WeCom HTTP 适配器) + +- 实现 `pkg/channels/wecom/wecom.go` +- 实现消息解析(AES 解密) +- 实现回复发送 +- 实现 Access Token 管理 + +### Phase 3: Integration (集成) + +- 创建 `pkg/web/api/handle_platform.go` +- 集成到路由系统 +- 对接现有 `handle_convo.go` +- 配置管理 + +### Phase 4: Testing & Polish (测试完善) + +- 单元测试 +- 集成测试 +- 飞书适配器(扩展) diff --git a/docs/plans/2026-03-26-001-fix-wecom-session-key-plan.md b/docs/plans/2026-03-26-001-fix-wecom-session-key-plan.md new file mode 100644 index 0000000..cf419dc --- /dev/null +++ b/docs/plans/2026-03-26-001-fix-wecom-session-key-plan.md @@ -0,0 +1,181 @@ +--- +title: "fix: 平台会话与 Conversation 映射及历史保持" +type: fix +status: active +date: 2026-03-26 +origin: docs/plans/2026-03-25-001-feat-channel-adapter-layer-plan.md +--- + +# fix: 平台会话与 Conversation 映射及历史保持 + +## Overview + +将平台会话(sessionKey = `{platform}:{chatID}:{userID}`)与 Conversation
(OID) 建立持久映射,解决: +1. WeCom WebSocket 当前使用 `reqID` 导致每条消息都是独立会话 +2. 群聊/单聊应绑定为一个固定 Conversation,历史累积 +3. 重连后可恢复之前的会话上下文 + +## Problem Statement + +### 根因分析 + +**WeCom 当前问题** (`pkg/platform/wecom/websocket.go:329`): +```go +sessionKey := reqID // reqID 是每次消息都不同的 UUID +``` + +**期望语义**: 一个聊天模式(单聊或群聊)= 一个固定 Conversation,历史累积 + +**问题**: +- `reqID` 每次消息都变化,导致每次创建新 Conversation +- `sessionKey` 格式 `wecom:{chatID}:{userID}` 与 `ReconstructReplyCtx` 期望不一致 + +### 平台会话语义 + +| 平台 | sessionKey 格式 | 语义 | +|------|-----------------|------| +| Feishu | `feishu:{chatID}:{userID}` | ✅ 正确 | +| WeCom (当前) | `wecom:{reqID}` | ❌ 每消息新会话 | +| WeCom (修复后) | `wecom:{chatID}:{userID}` | ✅ 同 Feishu | + +## Proposed Solution + +### 核心设计 + +**Redis 映射表**: +``` +Key: platform:csid:{platform}:{chatID}:{userID} +Value: conversation OID 字符串 (如 "cs-12345") +TTL: 30 天(可调整) +``` + +**查找顺序**: +1. 先查 Redis:`GET platform:csid:wecom:{chatID}:{userID}` +2. 找到 → 复用该 Conversation OID +3. 未找到 → 创建新 Conversation,写入 Redis 映射 + +### 实现位置 + +**`stores/conversation.go`**:新增 `GetOrCreateConversationBySessionKey` 函数 + +### 修改点 + +#### 1. stores/conversation.go - 新增映射函数 + +```go +const sessionKeyCSIDPrefix = "platform:csid:" + +func sessionKeyToCSIDKey(sessionKey string) string { + return sessionKeyCSIDPrefix + sessionKey +} + +// GetOrCreateConversationBySessionKey 从 Redis 查找或创建 Conversation +func GetOrCreateConversationBySessionKey(ctx context.Context, sessionKey string) Conversation { + // 1. 尝试从 Redis 获取已映射的 OID + key := sessionKeyToCSIDKey(sessionKey) + oidStr, err := SgtRC().Get(ctx, key).Result() + if err == nil && oidStr != "" { + // 2a. 找到映射,直接使用该 OID 创建 Conversation + return NewConversation(ctx, oidStr) + } + + // 2b. 未找到,创建新 Conversation + cs := NewConversation(ctx, nil) // 会生成新 OID + csid := cs.GetID() + + // 3. 写入 Redis 映射,TTL 30 天 + SgtRC().Set(ctx, key, csid, 30*24*time.Hour) + + return cs +} +``` + +**关键点**:找到 OID 后直接 `NewConversation(ctx, oidStr)`,不管数据库是否有记录。`NewConversation` 内部会处理查找或创建。 + +#### 2. 
handle_platform.go - 使用新的映射函数 + +```go +// 修改前 +csid := extractPlatformCSID(msg.SessionKey, msg.SessionKey) +cs := stores.NewConversation(ctx, csid) + +// 修改后 +cs := stores.GetOrCreateConversationBySessionKey(ctx, msg.SessionKey) +``` + +#### 3. wecom/websocket.go - 修复 sessionKey 格式 + +```go +// 修改前 +sessionKey := reqID + +// 修改后 +sessionKey := fmt.Sprintf("wecom:%s:%s", chatID, body.From.UserID) +``` + +### 重连后上下文恢复 + +`ReconstructReplyCtx` 依赖 `sessionKey` 格式 `wecom:{chatID}:{userID}`,而该格式在重连后不变,所以: +1. 平台适配器仍用相同 sessionKey 构造 Message +2. `GetOrCreateConversationBySessionKey` 查找 Redis 映射,找到则复用 +3. 找到对应 Conversation,加载历史消息 + +### 单聊新会话指令(可选扩展) + +未来可支持 `/new` 或 `/clear` 指令来主动创建新 Conversation: +- 检测到指令时,删除 Redis 中的映射 Key +- 下次消息将创建新的 Conversation + +## Acceptance Criteria + +- [ ] WeCom WebSocket sessionKey 格式改为 `wecom:{chatID}:{userID}` +- [ ] 同一用户/群的消息共享会话历史(Redis 映射生效) +- [ ] 重连后能恢复之前的会话上下文 +- [ ] Feishu 平台不受影响(使用相同逻辑) +- [ ] `make vet lint` 通过 +- [ ] 不破坏现有 Reply/Send 功能 + +## Technical Considerations + +### Redis 映射 TTL + +- 设为 30 天,与会话历史 TTL(24 小时)分离 +- 30 天无活动后,映射自动清除,下次消息创建新会话 + +### 去重机制 + +- 当前通过 `MsgID` 去重,不依赖 sessionKey +- 修复后同一用户发送的相同内容会被正确识别为重复 + +### 并发处理 + +- 同一 sessionKey 可能同时收到多条消息 +- 需要考虑 Redis SetNX 或类似机制避免重复创建 + +### 平台兼容性 + +| 平台 | sessionKey | 映射支持 | +|------|------------|----------| +| Feishu WS | `feishu:{chatID}:{userID}` | ✅ 直接复用 | +| Feishu HTTP | `feishu:{chatID}:{userID}` | ✅ 直接复用 | +| WeCom WS | `wecom:{chatID}:{userID}` | ✅ 修复后复用 | +| WeCom HTTP | `wecom:{userID}` | ✅ 直接复用 | + +## Files to Change + +| File | Change | +|------|--------| +| `pkg/services/stores/conversation.go` | 新增 `GetOrCreateConversationBySessionKey` | +| `pkg/web/api/handle_platform.go` | 使用新的映射函数 | +| `pkg/services/channels/wecom/websocket.go:329` | `sessionKey := fmt.Sprintf("wecom:%s:%s", chatID, body.From.UserID)` | + +## Related + +- **Origin**: `docs/plans/2026-03-25-001-feat-channel-adapter-layer-plan.md` +- **Issue**: WebSocket 重连后会话历史断裂 +- **Log Sample**: +
``` + subReqID=ev-557gu7xyc9aj # 正常 + subReqID=ev-557gwnxqblla # 重连后变化 + subReqID=ev-557gz3ws309d # 再次重连 + ``` diff --git a/docs/plans/2026-03-26-002-refactor-handle-duplicate-logic-plan.md b/docs/plans/2026-03-26-002-refactor-handle-duplicate-logic-plan.md new file mode 100644 index 0000000..17cc063 --- /dev/null +++ b/docs/plans/2026-03-26-002-refactor-handle-duplicate-logic-plan.md @@ -0,0 +1,237 @@ +--- +title: "refactor: 抽取 handle_convo 与 handle_platform 重复逻辑" +type: refactor +status: completed +date: 2026-03-26 +--- + +# refactor: 抽取 handle_convo 与 handle_platform 重复逻辑 + +## Overview + +`handle_convo.go` 和 `handle_platform.go` 中存在重复的工具调用循环逻辑,需要抽取为共享代码,消除代码冗余并统一日志规范。 + +## Problem Statement + +### 重复的 `executeToolCallLoop` + +两个文件都有独立的 `executeToolCallLoop` 实现: + +| 文件 | 行号 | 日志方式 | 日志详细程度 | +|------|------|----------|-------------| +| `handle_platform.go` | 183-232 | `slog.Warn` | 最小化 | +| `handle_convo.go` | 751-805 | `logger().Infow` | 详细 | + +**`handle_platform.go` 版本特点**: +- 使用 `slog.Warn` 记录错误 +- 仅在解析参数失败和调用工具失败时记录日志 +- 无成功调用的日志 + +**`handle_convo.go` 版本特点**: +- 使用 `logger().Infow`(项目规范的日志方式) +- 详细记录 `toolCallID`、`toolCallType`、`toolCallName` +- 成功调用时记录 `invokeTool ok` 及返回内容 + +### 其他共享元素(无需修改) + +| 元素 | 位置 | 说明 | +|------|------|------| +| `formatToolResult` | handle_convo.go:685-718 | 同一包内,可直接调用 | +| `chatExecutor` 类型 | handle_convo.go:743-744 | 已共享 | +| `convertToolCallsForJSON` | handle_convo.go:720-741 | 仅 convo 使用 | + +## Proposed Solution + +### 方案:提取共享的 `ToolExecutor` + +创建 `pkg/web/api/tool_executor.go`,包含: + +1. **`ToolExecutor` 结构体**:封装通用的工具调用循环逻辑 +2. **`ExecuteToolCallLoop` 方法**:处理工具调用循环直到无 tool calls +3. 
**`chatExecutor` 类型**:作为函数参数传入,支持流式/非流式执行器 + +### 架构设计 + +``` +pkg/web/api/ +├── handle_convo.go # api.executeToolCallLoop → 使用 ToolExecutor +├── handle_platform.go # channelHandler.executeToolCallLoop → 使用 ToolExecutor +├── tool_executor.go # NEW: 共享的 ToolExecutor +``` + +### 核心实现 + +```go +// tool_executor.go + +// ToolExecutor 封装工具调用循环逻辑 +type ToolExecutor struct { + toolreg *tools.Registry +} + +// NewToolExecutor 创建 ToolExecutor +func NewToolExecutor(toolreg *tools.Registry) *ToolExecutor { + return &ToolExecutor{toolreg: toolreg} +} + +// ExecuteToolCallLoop 执行工具调用循环,直到无 tool calls +func (e *ToolExecutor) ExecuteToolCallLoop( + ctx context.Context, + messages []llm.Message, + tools []llm.ToolDefinition, + exec chatExecutor, +) (string, []llm.ToolCall, *llm.Usage, error) { + for { + answer, toolCalls, usage, err := exec(ctx, messages, tools) + if err != nil { + return "", nil, nil, err + } + + if len(toolCalls) == 0 { + return answer, nil, usage, nil + } + + // 添加 assistant 消息(带 tool calls) + messages = append(messages, llm.Message{ + Role: llm.RoleAssistant, + ToolCalls: toolCalls, + }) + + // 执行工具调用 + for _, tc := range toolCalls { + if tc.Type != "function" { + continue + } + + var parameters map[string]any + args := string(tc.Function.Arguments) + if args != "" && args != "{}" { + if err := json.Unmarshal([]byte(args), &parameters); err != nil { + logger().Infow("chat", "toolCallID", tc.ID, "args", args, "err", err) + continue + } + } + if parameters == nil { + parameters = make(map[string]any) + } + + content, err := e.toolreg.Invoke(ctx, tc.Function.Name, parameters) + if err != nil { + logger().Infow("invokeTool fail", "toolCallName", tc.Function.Name, "err", err) + continue + } + + logger().Infow("invokeTool ok", "toolCallName", tc.Function.Name, + "content", toolsvc.ResultLogs(content)) + messages = append(messages, llm.Message{ + Role: llm.RoleTool, + Content: formatToolResult(content), + ToolCallID: tc.ID, + }) + } + } +} +``` + +### 修改点 + +#### 1.
创建 `tool_executor.go` + +新文件,包含: +- `ToolExecutor` 结构体 +- `NewToolExecutor` 构造函数 +- `ExecuteToolCallLoop` 方法 +- 从 `handle_convo.go` 移动 `chatExecutor` 类型定义 + +#### 2. 修改 `handle_convo.go` + +```go +// 添加字段 +type api struct { + // ... existing fields ... + toolExec *ToolExecutor +} + +// 在 newapi() 中初始化 +a.toolExec = NewToolExecutor(toolreg) + +// 修改 executeToolCallLoop 为委托调用 +func (a *api) executeToolCallLoop(ctx context.Context, messages []llm.Message, tools []llm.ToolDefinition, exec chatExecutor) (string, []llm.ToolCall, *llm.Usage, error) { + return a.toolExec.ExecuteToolCallLoop(ctx, messages, tools, exec) +} +``` + +#### 3. 修改 `handle_platform.go` + +```go +// channelHandler 添加 toolExec 字段 +type channelHandler struct { + toolExec *ToolExecutor + // ... existing fields ... +} + +// 修改 InitChannels 接收 toolExec +func InitChannels(r chi.Router, preset *aigc.Preset, sto stores.Storage, llmClient llm.Client, toolreg *tools.Registry) error { + phandler = &channelHandler{ + toolExec: NewToolExecutor(toolreg), + // ... existing fields ... + } + // ... 
+} + +// 修改 executeToolCallLoop 为委托调用 +func (phandler *channelHandler) executeToolCallLoop(ctx context.Context, messages []llm.Message, tools []llm.ToolDefinition, exec chatExecutor) (string, []llm.ToolCall, *llm.Usage, error) { + return phandler.toolExec.ExecuteToolCallLoop(ctx, messages, tools, exec) +} +``` + +## Technical Considerations + +### 日志标准化 + +统一使用 `logger().Infow` 而非 `slog.Warn`,符合 AGENTS.md 中的日志规范: +```go +// ✅ 规范写法 +logger().Infow("invokeTool fail", "toolCallName", tc.Function.Name, "err", err) + +// ❌ 避免 +slog.Warn("platform: invoke tool failed", "tool", tc.Function.Name, "err", err) +``` + +### 错误处理策略 + +两个原实现都使用 `continue` 处理单个工具调用失败,不阻塞循环。保持此行为。 + +### 依赖注入 + +`ToolExecutor` 通过 `*tools.Registry` 初始化,遵循现有依赖注入模式。 + +## Files to Change + +| 文件 | 变更 | +|------|------| +| `pkg/web/api/tool_executor.go` | 新增 - 共享的 ToolExecutor | +| `pkg/web/api/handle_convo.go` | 使用 ToolExecutor,移除重复的 executeToolCallLoop | +| `pkg/web/api/handle_platform.go` | 使用 ToolExecutor,移除重复的 executeToolCallLoop | +| `pkg/web/api/api.go` | 初始化 api.toolExec 字段 | + +## Acceptance Criteria + +- [ ] `executeToolCallLoop` 仅在 `tool_executor.go` 中有一个实现 +- [ ] `handle_convo.go` 和 `handle_platform.go` 都使用共享的 `ToolExecutor` +- [ ] 日志统一使用 `logger().Infow`(符合项目规范) +- [ ] 工具调用行为保持一致(continue on single failure) +- [ ] `make vet lint` 通过 +- [ ] 单元测试(如有)仍然通过 + +## Verification + +```bash +# 验证 vet 和 lint +make vet lint + +# 验证构建 +go build ./... 
+ +# 手动测试:发送聊天消息,验证工具调用正常 +``` diff --git a/docs/solutions/logic-errors/executeToolCallLoop-deduplication.md b/docs/solutions/logic-errors/executeToolCallLoop-deduplication.md new file mode 100644 index 0000000..a74bf7c --- /dev/null +++ b/docs/solutions/logic-errors/executeToolCallLoop-deduplication.md @@ -0,0 +1,147 @@ +--- +title: "Extract duplicate executeToolCallLoop into shared ToolExecutor" +category: logic-errors +date: 2026-03-26 +tags: [refactor, code-duplication, golang] +related: + - docs/plans/2026-03-26-002-refactor-handle-duplicate-logic-plan.md +--- + +# Extract duplicate executeToolCallLoop into shared ToolExecutor + +## Problem Description + +Two separate files (`handle_convo.go` and `handle_platform.go`) implemented nearly identical `executeToolCallLoop` logic with inconsistent logging packages (`slog` vs `logger().Infow`). This duplication risked divergence over time and inconsistent behavior between API and platform handler code paths. + +## Root Cause + +The `executeToolCallLoop` method was copy-pasted into both handlers with different logging implementations: +- `handle_convo.go`: used `logger().Infow` (project-standard custom logger) +- `handle_platform.go`: used `log/slog` package directly + +When a tool call failed in `handle_platform.go`, the error was logged via `slog.Warn` but success was not logged at all. Meanwhile `handle_convo.go` logged both failures and successes with full context (`toolCallID`, `toolCallType`, `toolCallName`). 
+ +## Solution + +**Created `pkg/web/api/tool_executor.go`** - unified executor: + +```go +// chatExecutor defines the chat execution function type (streaming or non-streaming) +type chatExecutor func(ctx context.Context, messages []llm.Message, tools []llm.ToolDefinition) (string, []llm.ToolCall, *llm.Usage, error) + +// ToolExecutor encapsulates tool call loop logic +type ToolExecutor struct { + toolreg *tools.Registry +} + +func NewToolExecutor(toolreg *tools.Registry) *ToolExecutor { + return &ToolExecutor{toolreg: toolreg} +} + +func (e *ToolExecutor) ExecuteToolCallLoop( + ctx context.Context, + messages []llm.Message, + tools []llm.ToolDefinition, + exec chatExecutor, +) (string, []llm.ToolCall, *llm.Usage, error) { + for { + answer, toolCalls, usage, err := exec(ctx, messages, tools) + if err != nil { + return "", nil, nil, err + } + if len(toolCalls) == 0 { + return answer, nil, usage, nil + } + messages = append(messages, llm.Message{ + Role: llm.RoleAssistant, + ToolCalls: toolCalls, + }) + for _, tc := range toolCalls { + logger().Infow("chat", "toolCallID", tc.ID, "toolCallType", tc.Type, "toolCallName", tc.Function.Name) + // ... tool execution with unified logging + } + } +} +``` + +**Modified `api.go`** - added `toolExec` field and initialization: + +```go +type api struct { + // ... + toolExec *ToolExecutor +} + +// In newapi(): +return &api{ + // ... + toolExec: NewToolExecutor(toolreg), +} +``` + +**Simplified `handle_convo.go`** - delegation pattern: + +```go +func (a *api) executeToolCallLoop(ctx context.Context, messages []llm.Message, tools []llm.ToolDefinition, exec chatExecutor) (string, []llm.ToolCall, *llm.Usage, error) { + return a.toolExec.ExecuteToolCallLoop(ctx, messages, tools, exec) +} +``` + +**Simplified `handle_platform.go`** - same delegation pattern: + +```go +type channelHandler struct { + // ... + toolExec *ToolExecutor +} + +// In InitChannels(): +phandler = &channelHandler{ + // ... 
+ toolExec: NewToolExecutor(toolreg), +} + +// Simplified wrapper: +func (phandler *channelHandler) executeToolCallLoop(ctx context.Context, messages []llm.Message, tools []llm.ToolDefinition, exec chatExecutor) (string, []llm.ToolCall, *llm.Usage, error) { + return phandler.toolExec.ExecuteToolCallLoop(ctx, messages, tools, exec) +} +``` + +## Verification + +- `make vet lint` passed (0 issues) +- `go build ./...` succeeded + +## Files Changed + +| File | Change | +|------|--------| +| `pkg/web/api/tool_executor.go` | NEW - shared ToolExecutor | +| `pkg/web/api/handle_convo.go` | Removed duplicate, added delegation | +| `pkg/web/api/handle_platform.go` | Removed duplicate, added delegation | +| `pkg/web/api/api.go` | Added toolExec field | + +## Prevention Strategies + +1. **Extract common patterns early** - when two implementations diverge slightly, immediately extract to a shared location +2. **Unified logging interface** - use `logger().Infow` across all packages rather than mixing `slog` and custom loggers +3. **Code review for duplication** - require reviewers to explicitly verify the PR does not duplicate existing logic elsewhere +4. 
**Consider facade/delegation pattern** when handlers need different contexts but similar core logic + +## Recommended Tests + +Add unit tests for `ToolExecutor`: + +```go +func TestExecuteToolCallLoop_NoToolCalls(t *testing.T) { + // Verify: when LLM returns no tool calls, returns immediately +} + +func TestExecuteToolCallLoop_SingleToolCall(t *testing.T) { + // Verify: single tool call executes and adds result to messages +} + +func TestExecuteToolCallLoop_ContinueOnFailure(t *testing.T) { + // Verify: if one tool fails, remaining tools still execute +} +``` diff --git a/go.mod b/go.mod index d8c258c..33891fa 100644 --- a/go.mod +++ b/go.mod @@ -10,8 +10,10 @@ require ( github.com/cupogo/andvari v0.0.0-20260314102041-168adc9ab3a6 github.com/go-chi/chi/v5 v5.2.5 github.com/go-chi/render v1.0.3 + github.com/gorilla/websocket v1.5.3 github.com/jpillora/eventsource v1.2.0 github.com/kelseyhightower/envconfig v1.4.0 + github.com/larksuite/oapi-sdk-go/v3 v3.5.3 github.com/liut/simpauth v0.1.18 github.com/liut/staffio-client v0.2.10 github.com/marcsv/go-binder v0.0.0-20160121205837-a8bae0b66e09 @@ -42,6 +44,7 @@ require ( github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-shiori/dom v0.0.0-20230515143342-73569d674e1c // indirect + github.com/gogo/protobuf v1.3.2 // indirect github.com/gogs/chardet v0.0.0-20211120154057-b7413eaefb8f // indirect github.com/google/uuid v1.6.0 // indirect github.com/invopop/jsonschema v0.13.0 // indirect diff --git a/go.sum b/go.sum index e00d1e9..9e3c65c 100644 --- a/go.sum +++ b/go.sum @@ -47,6 +47,8 @@ github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/go-shiori/dom v0.0.0-20230515143342-73569d674e1c h1:wpkoddUomPfHiOziHZixGO5ZBS73cKqVzZipfrLmO1w= github.com/go-shiori/dom v0.0.0-20230515143342-73569d674e1c/go.mod h1:oVDCh3qjJMLVUSILBRwrm+Bc6RNXGZYtoh9xdvf1ffM= 
+github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= +github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/gogs/chardet v0.0.0-20211120154057-b7413eaefb8f h1:3BSP1Tbs2djlpprl7wCLuiqMaUh5SJkkzI2gDs+FgLs= github.com/gogs/chardet v0.0.0-20211120154057-b7413eaefb8f/go.mod h1:Pcatq5tYkCW2Q6yrR2VRHlbHpZ/R4/7qyL1TCF7vl14= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= @@ -54,6 +56,9 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= @@ -62,6 +67,8 @@ github.com/jpillora/eventsource v1.2.0 h1:UNvcC7v/4aq7xgRZiD3uOSdmfcq08r2k5WCwzA github.com/jpillora/eventsource v1.2.0/go.mod h1:K3tRq8cBJgDqIQ8L5wKk9Fe5aeLgKfrRg1XF3zAO2lA= github.com/kelseyhightower/envconfig v1.4.0 h1:Im6hONhd3pLkfDFsbRgu68RDNkGF1r3dvMUtDTo2cv8= github.com/kelseyhightower/envconfig v1.4.0/go.mod h1:cccZRl6mQpaq41TPp5QxidR+Sa3axMbJDNb//FQX6Gg= +github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= +github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/klauspost/cpuid/v2 v2.0.9 
h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= @@ -71,6 +78,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/larksuite/oapi-sdk-go/v3 v3.5.3 h1:xvf8Dv29kBXC5/DNDCLhHkAFW8l/0LlQJimO5Zn+JUk= +github.com/larksuite/oapi-sdk-go/v3 v3.5.3/go.mod h1:ZEplY+kwuIrj/nqw5uSCINNATcH3KdxSN7y+UxYY5fI= github.com/liut/simpauth v0.1.18 h1:uPAeefL8mBXpOv1Cn5r3cHqzv/wuIyFFHGtpFJai8l4= github.com/liut/simpauth v0.1.18/go.mod h1:7DBCXACVqthUCo3T8rLX7v3DpMwQWctEhWpInmaphc4= github.com/liut/staffio-client v0.2.10 h1:CR7I7jJmozpNVuI4R82oePl6GM+pcZNbvVBsVp+wZPk= @@ -152,6 +161,8 @@ github.com/yalue/merged_fs v1.3.0 h1:qCeh9tMPNy/i8cwDsQTJ5bLr6IRxbs6meakNE5O+wyY github.com/yalue/merged_fs v1.3.0/go.mod h1:WqqchfVYQyclV2tnR7wtRhBddzBvLVR83Cjw9BKQw0M= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= +github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= github.com/yuin/goldmark v1.7.1 h1:3bajkSilaCbjdKVsKdZjZCLBNPL9pYzrCakKaf4U49U= github.com/yuin/goldmark v1.7.1/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= @@ -176,6 +187,8 @@ go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN8 go.uber.org/zap v1.27.1 h1:08RqriUEv8+ArZRYSTXy1LeBScaMpVSTBhCeaZYfMYc= 
go.uber.org/zap v1.27.1/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= @@ -184,12 +197,17 @@ golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4= golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA= +golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net 
v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= @@ -205,6 +223,8 @@ golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw= golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs= golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= @@ -212,6 +232,8 @@ golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod 
h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -250,11 +272,16 @@ golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8= golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= +golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58= golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= diff --git a/main.go b/main.go index 1762015..e73a0a0 100644 --- a/main.go +++ b/main.go @@ -4,6 +4,7 @@ 
// Channel abstracts a messaging channel (WeCom, Feishu, DingTalk, etc.).
//
// Adapters turn platform events into *Message values and hand them to the
// MessageHandler given to Start; the handler answers through Reply or Send.
type Channel interface {
	// Name returns the channel identifier (the Feishu adapters return "feishu").
	Name() string
	// Start installs handler as the receiver of incoming messages and brings
	// up whatever the adapter needs in order to receive them.
	Start(handler MessageHandler) error
	// Reply answers the message that replyCtx was created for. replyCtx is
	// the adapter-specific value carried in Message.ReplyCtx.
	Reply(ctx context.Context, replyCtx any, content string) error
	// Send delivers content to the conversation described by replyCtx
	// without referring to a specific incoming message.
	Send(ctx context.Context, replyCtx any, content string) error
	// Stop releases the resources acquired by Start.
	Stop() error
}
// MessageHandler is called by channels when a new message arrives.
//
// p is the channel that received the message and is the one to call
// Reply/Send on; msg is the normalized message, whose ReplyCtx field holds
// the value to pass back to p.
//
// NOTE(review): the Feishu webhook adapter calls the handler on a fresh
// goroutine per message, so handlers presumably need to be safe for
// concurrent use — confirm for the other adapters.
type MessageHandler func(p Channel, msg *Message)
"audio/amr", "audio/ogg", "audio/mp4" + Data []byte // raw audio bytes + Format string // short format hint: "amr", "ogg", "m4a", "mp3", "wav", etc. + Duration int // duration in seconds (if known) +} + +// Message represents a unified incoming message from any channel. +type Message struct { + SessionKey string // unique key for user context, e.g. "wecom:{userID}" + Channel string + MessageID string // channel message ID for tracing/dedup + UserID string + UserName string + ChatName string // human-readable chat/group name (optional) + Content string + Images []ImageAttachment // attached images (if any) + Files []FileAttachment // attached files (if any) + Audio *AudioAttachment // voice message (if any) + ReplyCtx any // channel-specific context needed for replying + FromVoice bool // true if message originated from voice transcription +} diff --git a/pkg/services/channels/allowlist.go b/pkg/services/channels/allowlist.go new file mode 100644 index 0000000..14eb999 --- /dev/null +++ b/pkg/services/channels/allowlist.go @@ -0,0 +1,21 @@ +package channels + +import ( + "strings" +) + +// AllowList checks whether a user ID is permitted based on a comma-separated +// allow_from string. Returns true if allowFrom is empty or "*" (allow all), +// or if the userID is in the list. Comparison is case-insensitive. 
func AllowList(allowFrom, userID string) bool {
	allowFrom = strings.TrimSpace(allowFrom)
	// No restriction configured, or an explicit wildcard: everyone is allowed.
	if allowFrom == "" || allowFrom == "*" {
		return true
	}
	for _, entry := range strings.Split(allowFrom, ",") {
		entry = strings.TrimSpace(entry)
		// A trailing or doubled comma yields an empty entry. Skip it so that
		// an empty userID (sender could not be identified) is never matched
		// against it and let through a restricted list.
		if entry == "" {
			continue
		}
		if strings.EqualFold(entry, userID) {
			return true
		}
	}
	return false
}
+func (d *msgDedup) IsDuplicate(msgID string) bool { + if msgID == "" { + return false + } + d.mu.Lock() + defer d.mu.Unlock() + + now := time.Now() + for k, t := range d.seen { + if now.Sub(t) > dedupTTL { + delete(d.seen, k) + } + } + + if _, exists := d.seen[msgID]; exists { + return true + } + d.seen[msgID] = now + return false +} diff --git a/pkg/services/channels/feishu/feishu.go b/pkg/services/channels/feishu/feishu.go new file mode 100644 index 0000000..2ff3927 --- /dev/null +++ b/pkg/services/channels/feishu/feishu.go @@ -0,0 +1,25 @@ +package feishu + +import ( + "fmt" + + "github.com/liut/morign/pkg/models/channel" + "github.com/liut/morign/pkg/services/channels" +) + +func init() { + channels.RegisterChannel("feishu", New) +} + +// New creates a Feishu channel adapter. +func New(opts map[string]any) (channel.Channel, error) { + mode, _ := opts["mode"].(string) + switch mode { + case "websocket": + return newWebSocket(opts) + case "webhook": + return newWebhook(opts) + default: + return nil, fmt.Errorf("feishu: unsupported mode %q (supported: websocket, webhook)", mode) + } +} diff --git a/pkg/services/channels/feishu/webhook.go b/pkg/services/channels/feishu/webhook.go new file mode 100644 index 0000000..e0d49e5 --- /dev/null +++ b/pkg/services/channels/feishu/webhook.go @@ -0,0 +1,533 @@ +package feishu + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "log/slog" + "net/http" + "sort" + "strings" + "sync" + "time" + + "github.com/go-chi/chi/v5" + + lark "github.com/larksuite/oapi-sdk-go/v3" + larkcore "github.com/larksuite/oapi-sdk-go/v3/core" + larkcontact "github.com/larksuite/oapi-sdk-go/v3/service/contact/v3" + larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1" + + "github.com/liut/morign/pkg/models/channel" + "github.com/liut/morign/pkg/services/channels" + "github.com/liut/morign/pkg/services/stores" +) + +// HTTPChannel implements channel.Channel for Feishu HTTP webhook callback. 
type HTTPChannel struct {
	appID     string // Feishu app credentials; both are required by newWebhook
	appSecret string
	allowFrom string // comma-separated sender allowlist, checked with channels.AllowList
	// encryptKey, when non-empty, makes handleVerify pass echostr through decrypt.
	encryptKey string
	// callbackPath defaults to "/feishu/callback" and is overwritten by RegisterHTTPRoutes.
	callbackPath string
	handler      channel.MessageHandler // receiver of incoming messages, set by Start
	// server is only shut down in Stop; nothing visible in this file assigns it.
	// NOTE(review): presumably reserved for a standalone listener — confirm or remove.
	server        *http.Server
	dedup         *channels.Dedup // Redis-backed message-ID deduplication
	botOpenID     string          // the bot's own open_id, filled by fetchBotOpenID
	userNameCache sync.Map        // open_id -> display name, filled by resolveUserName
	apiClient     *http.Client    // created with a 30s timeout in newWebhook
}
allowFrom: allowFrom, + encryptKey: encryptKey, + callbackPath: callbackPath, + dedup: channels.NewDedup(stores.SgtRC()), + apiClient: &http.Client{ + Timeout: 30 * time.Second, + }, + }, nil +} + +func (p *HTTPChannel) Name() string { return "feishu" } + +func (p *HTTPChannel) Start(handler channel.MessageHandler) error { + p.handler = handler + + if err := p.fetchBotOpenID(); err != nil { + slog.Warn("feishu-webhook: failed to get bot open_id", "error", err) + } + + return nil +} + +func (p *HTTPChannel) fetchBotOpenID() error { + client := lark.NewClient(p.appID, p.appSecret) + resp, err := client.Get(context.Background(), + "/open-apis/bot/v3/info", nil, larkcore.AccessTokenTypeTenant) + if err != nil { + return fmt.Errorf("fetch bot info: %w", err) + } + var result struct { + Code int `json:"code"` + Bot struct { + OpenID string `json:"open_id"` + } `json:"bot"` + } + if err := json.Unmarshal(resp.RawBody, &result); err != nil { + return fmt.Errorf("parse response: %w", err) + } + if result.Code != 0 { + return fmt.Errorf("api code=%d", result.Code) + } + p.botOpenID = result.Bot.OpenID + slog.Info("feishu-webhook: bot identified", "open_id", p.botOpenID) + return nil +} + +func (p *HTTPChannel) RegisterHTTPRoutes(r chi.Router, callbackPath string, handler channel.MessageHandler) { + p.callbackPath = callbackPath + r.Method(http.MethodGet, callbackPath, http.HandlerFunc(p.handleVerify)) + r.Method(http.MethodPost, callbackPath, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + p.handleMessage(w, r, handler) + })) + slog.Info("feishu-webhook: routes registered", "path", callbackPath) +} + +func (p *HTTPChannel) handleVerify(w http.ResponseWriter, r *http.Request) { + q := r.URL.Query() + signature := q.Get("signature") + timestamp := q.Get("timestamp") + nonce := q.Get("nonce") + echostr := q.Get("echostr") + + if signature == "" { + slog.Warn("feishu-webhook: missing signature in verification") + w.WriteHeader(http.StatusForbidden) + return + } 
+ + if !p.verifySignature(signature, timestamp, nonce, echostr) { + slog.Warn("feishu-webhook: verify signature failed") + w.WriteHeader(http.StatusForbidden) + return + } + + // Decode echostr if encrypted + var plain string + if p.encryptKey != "" { + var err error + plain, err = p.decrypt(echostr) + if err != nil { + slog.Error("feishu-webhook: decrypt echostr failed", "error", err) + w.WriteHeader(http.StatusInternalServerError) + return + } + } else { + plain = echostr + } + + slog.Info("feishu-webhook: URL verification succeeded") + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, plain) +} + +func (p *HTTPChannel) handleMessage(w http.ResponseWriter, r *http.Request, handler channel.MessageHandler) { + q := r.URL.Query() + signature := q.Get("signature") + timestamp := q.Get("timestamp") + nonce := q.Get("nonce") + + body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20)) + if err != nil { + slog.Error("feishu-webhook: read body failed", "error", err) + w.WriteHeader(http.StatusBadRequest) + return + } + + // Parse challenge if present (for URL verification) + var checkReq struct { + Challenge string `json:"challenge"` + } + if err := json.Unmarshal(body, &checkReq); err == nil && checkReq.Challenge != "" { + // This is a URL verification request + if signature == "" || !p.verifySignature(signature, timestamp, nonce, checkReq.Challenge) { + w.WriteHeader(http.StatusForbidden) + return + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]string{"challenge": checkReq.Challenge}) + return + } + + if !p.verifySignature(signature, timestamp, nonce, string(body)) { + slog.Warn("feishu-webhook: message signature verification failed") + w.WriteHeader(http.StatusForbidden) + return + } + + // Return 200 immediately (Feishu requires response within 3 seconds) + w.WriteHeader(http.StatusOK) + + // Parse the webhook event + var event feishuWebhookEvent 
+ if err := json.Unmarshal(body, &event); err != nil { + slog.Error("feishu-webhook: parse event failed", "error", err) + return + } + + // Handle the message event + if event.Header.EventType == "im.message.receive_v1" { + p.onWebhookMessage(&event) + } +} + +func (p *HTTPChannel) onWebhookMessage(event *feishuWebhookEvent) { + msg := &event.Event.Message + sender := &event.Event.Sender + + msgType := msg.MessageType + chatID := msg.ChatID + chatType := msg.ChatType + threadID := msg.RootID + userID := sender.SenderID.OpenID + messageID := msg.MessageID + + // Filter: skip messages without message ID + if messageID == "" { + slog.Debug("feishu-webhook: message without ID ignored") + return + } + + // Debug logging for received messages + slog.Info("feishu-webhook: message received", + "msg_id", messageID, + "msg_type", msgType, + "chat_id", chatID, + "chat_type", chatType, + "thread_id", threadID, + "user_id", userID, + "content_len", len(msg.Content), + ) + + // Deduplicate + if p.dedup != nil { + isDup, _ := p.dedup.IsDuplicate(context.Background(), "feishu", messageID) + if isDup { + slog.Info("feishu-webhook: skipping duplicate message", "msg_id", messageID) + return + } + } + + // Check allow_from filter + if !channels.AllowList(p.allowFrom, userID) { + slog.Info("feishu-webhook: message from unauthorized user", "user", userID) + return + } + + sessionKey := fmt.Sprintf("feishu:%s:%s", chatID, userID) + rctx := webhookReplyContext{ + messageID: messageID, + chatID: chatID, + userID: userID, + } + + switch msgType { + case "text": + var textBody struct { + Text string `json:"text"` + } + if err := json.Unmarshal([]byte(msg.Content), &textBody); err != nil { + slog.Error("feishu-webhook: failed to parse text content", "error", err) + return + } + text := strings.TrimSpace(textBody.Text) + if text == "" { + slog.Debug("feishu-webhook: dropping empty text") + return + } + go p.handler(p, &channel.Message{ + SessionKey: sessionKey, Channel: "feishu", + MessageID: 
messageID, + UserID: userID, UserName: p.resolveUserName(userID), + Content: text, ReplyCtx: rctx, + }) + + case "image": + var imgBody struct { + ImageKey string `json:"image_key"` + } + if err := json.Unmarshal([]byte(msg.Content), &imgBody); err != nil { + slog.Error("feishu-webhook: failed to parse image content", "error", err) + return + } + imgData, mimeType, err := p.downloadImage(messageID, imgBody.ImageKey) + if err != nil { + slog.Error("feishu-webhook: download image failed", "error", err) + return + } + go p.handler(p, &channel.Message{ + SessionKey: sessionKey, Channel: "feishu", + MessageID: messageID, + UserID: userID, UserName: p.resolveUserName(userID), + Images: []channel.ImageAttachment{{MimeType: mimeType, Data: imgData}}, + ReplyCtx: rctx, + }) + + case "audio": + var audioBody struct { + FileKey string `json:"file_key"` + Duration int `json:"duration"` + } + if err := json.Unmarshal([]byte(msg.Content), &audioBody); err != nil { + slog.Error("feishu-webhook: failed to parse audio content", "error", err) + return + } + slog.Debug("feishu-webhook: audio received", "user", userID, "file_key", audioBody.FileKey) + audioData, err := p.downloadResource(messageID, audioBody.FileKey, "file") + if err != nil { + slog.Error("feishu-webhook: download audio failed", "error", err) + return + } + go p.handler(p, &channel.Message{ + SessionKey: sessionKey, Channel: "feishu", + MessageID: messageID, + UserID: userID, UserName: p.resolveUserName(userID), + Audio: &channel.AudioAttachment{MimeType: "audio/opus", Data: audioData, Format: "ogg"}, + ReplyCtx: rctx, + }) + + default: + slog.Debug("feishu-webhook: ignoring unsupported message type", "type", msgType) + } +} + +func (p *HTTPChannel) Reply(ctx context.Context, rctx any, content string) error { + rc, ok := rctx.(webhookReplyContext) + if !ok { + return fmt.Errorf("feishu-webhook: invalid reply context type %T", rctx) + } + if content == "" { + return nil + } + + client := lark.NewClient(p.appID, 
p.appSecret) + msgType, msgBody := buildReplyContent(content) + + resp, err := client.Im.Message.Reply(ctx, larkim.NewReplyMessageReqBuilder(). + MessageId(rc.messageID). + Body(larkim.NewReplyMessageReqBodyBuilder(). + MsgType(msgType). + Content(msgBody). + ReplyInThread(true). + Build()). + Build()) + if err != nil { + return fmt.Errorf("feishu-webhook: reply api call: %w", err) + } + if !resp.Success() { + return fmt.Errorf("feishu-webhook: reply failed code=%d msg=%s", resp.Code, resp.Msg) + } + return nil +} + +func (p *HTTPChannel) Send(ctx context.Context, rctx any, content string) error { + rc, ok := rctx.(webhookReplyContext) + if !ok { + return fmt.Errorf("feishu-webhook: invalid reply context type %T", rctx) + } + if content == "" { + return nil + } + if rc.chatID == "" { + return fmt.Errorf("feishu-webhook: chatID is empty, cannot send proactive message") + } + + client := lark.NewClient(p.appID, p.appSecret) + msgType, _ := buildReplyContent(content) + + chunks := splitByBytes(content, 4000) + for _, chunk := range chunks { + _, body := buildReplyContent(chunk) + resp, err := client.Im.Message.Create(ctx, larkim.NewCreateMessageReqBuilder(). + ReceiveIdType(larkim.ReceiveIdTypeChatId). + Body(larkim.NewCreateMessageReqBodyBuilder(). + ReceiveId(rc.chatID). + MsgType(msgType). + Content(body). + Build()). 
+ Build()) + if err != nil { + return fmt.Errorf("feishu-webhook: send api call: %w", err) + } + if !resp.Success() { + return fmt.Errorf("feishu-webhook: send failed code=%d msg=%s", resp.Code, resp.Msg) + } + } + return nil +} + +func (p *HTTPChannel) Stop() error { + if p.server != nil { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := p.server.Shutdown(ctx); err != nil { + slog.Error("feishu-webhook: server shutdown error", "error", err) + } + } + return nil +} + +func (p *HTTPChannel) ReconstructReplyCtx(sessionKey string) (any, error) { + parts := strings.SplitN(sessionKey, ":", 3) + if len(parts) < 3 || parts[0] != "feishu" { + return nil, fmt.Errorf("feishu-webhook: invalid session key %q", sessionKey) + } + return webhookReplyContext{chatID: parts[1], userID: parts[2]}, nil +} + +func (p *HTTPChannel) resolveUserName(openID string) string { + if cached, ok := p.userNameCache.Load(openID); ok { + return cached.(string) + } + client := lark.NewClient(p.appID, p.appSecret) + resp, err := client.Contact.User.Get(context.Background(), + larkcontact.NewGetUserReqBuilder(). + UserId(openID). + UserIdType("open_id"). + Build()) + if err != nil { + slog.Debug("feishu-webhook: resolve user name failed", "open_id", openID, "error", err) + return openID + } + if !resp.Success() || resp.Data == nil || resp.Data.User == nil || resp.Data.User.Name == nil { + return openID + } + name := *resp.Data.User.Name + p.userNameCache.Store(openID, name) + return name +} + +func (p *HTTPChannel) downloadImage(messageID, imageKey string) ([]byte, string, error) { + client := lark.NewClient(p.appID, p.appSecret) + resp, err := client.Im.MessageResource.Get(context.Background(), + larkim.NewGetMessageResourceReqBuilder(). + MessageId(messageID). + FileKey(imageKey). + Type("image"). 
+ Build()) + if err != nil { + return nil, "", fmt.Errorf("feishu-webhook: image API: %w", err) + } + if !resp.Success() { + return nil, "", fmt.Errorf("feishu-webhook: image API code=%d msg=%s", resp.Code, resp.Msg) + } + if resp.File == nil { + return nil, "", fmt.Errorf("feishu-webhook: image API returned nil file body") + } + data, err := io.ReadAll(resp.File) + if err != nil { + return nil, "", fmt.Errorf("feishu-webhook: read image: %w", err) + } + mimeType := detectMimeType(data) + return data, mimeType, nil +} + +func (p *HTTPChannel) downloadResource(messageID, fileKey, resType string) ([]byte, error) { + client := lark.NewClient(p.appID, p.appSecret) + resp, err := client.Im.MessageResource.Get(context.Background(), + larkim.NewGetMessageResourceReqBuilder(). + MessageId(messageID). + FileKey(fileKey). + Type(resType). + Build()) + if err != nil { + return nil, fmt.Errorf("feishu-webhook: resource API: %w", err) + } + if !resp.Success() { + return nil, fmt.Errorf("feishu-webhook: resource API code=%d msg=%s", resp.Code, resp.Msg) + } + if resp.File == nil { + return nil, fmt.Errorf("feishu-webhook: resource API returned nil file body") + } + return io.ReadAll(resp.File) +} + +func (p *HTTPChannel) verifySignature(expected, timestamp, nonce, encrypt string) bool { + parts := []string{timestamp, nonce, encrypt} + sort.Strings(parts) + h := sha256.New() + h.Write([]byte(strings.Join(parts, ""))) + got := hex.EncodeToString(h.Sum(nil)) + return got == expected +} + +func (p *HTTPChannel) decrypt(echostr string) (string, error) { + // TODO: Implement AES decryption for encrypted messages + // For now, return the echostr as-is + return echostr, nil +} + +var _ channel.Channel = (*HTTPChannel)(nil) diff --git a/pkg/services/channels/feishu/websocket.go b/pkg/services/channels/feishu/websocket.go new file mode 100644 index 0000000..865d697 --- /dev/null +++ b/pkg/services/channels/feishu/websocket.go @@ -0,0 +1,479 @@ +package feishu + +import ( + "context" + 
"encoding/json" + "fmt" + "log/slog" + "strings" + "sync" + + lark "github.com/larksuite/oapi-sdk-go/v3" + larkcore "github.com/larksuite/oapi-sdk-go/v3/core" + "github.com/larksuite/oapi-sdk-go/v3/event/dispatcher" + larkcontact "github.com/larksuite/oapi-sdk-go/v3/service/contact/v3" + larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1" + larkws "github.com/larksuite/oapi-sdk-go/v3/ws" + + "github.com/liut/morign/pkg/models/channel" + "github.com/liut/morign/pkg/services/channels" + "github.com/liut/morign/pkg/services/stores" +) + +// WSChannel implements channel.Channel using Feishu WebSocket long-connection mode. +type WSChannel struct { + appID string + appSecret string + allowFrom string + wsClient *larkws.Client + handler channel.MessageHandler + eventHandler *dispatcher.EventDispatcher + ctx context.Context + cancel context.CancelFunc + dedup *channels.Dedup + botOpenID string + userNameCache sync.Map // open_id -> display name +} + +type wsReplyContext struct { + messageID string + chatID string + userID string +} + +func newWebSocket(opts map[string]any) (channel.Channel, error) { + appID, _ := opts["app_id"].(string) + appSecret, _ := opts["app_secret"].(string) + if appID == "" || appSecret == "" { + return nil, fmt.Errorf("feishu-ws: app_id and app_secret are required") + } + allowFrom, _ := opts["allow_from"].(string) + + return &WSChannel{ + appID: appID, + appSecret: appSecret, + allowFrom: allowFrom, + dedup: channels.NewDedup(stores.SgtRC()), + }, nil +} + +func (p *WSChannel) Name() string { return "feishu" } + +func (p *WSChannel) Start(handler channel.MessageHandler) error { + p.handler = handler + p.ctx, p.cancel = context.WithCancel(context.Background()) + + p.eventHandler = dispatcher.NewEventDispatcher("", ""). 
+ OnP2MessageReceiveV1(func(ctx context.Context, event *larkim.P2MessageReceiveV1) error { + return p.onMessage(event) + }) + + wsOpts := []larkws.ClientOption{ + larkws.WithEventHandler(p.eventHandler), + larkws.WithLogLevel(larkcore.LogLevelDebug), + } + p.wsClient = larkws.NewClient(p.appID, p.appSecret, wsOpts...) + + go func() { + if err := p.wsClient.Start(p.ctx); err != nil { + slog.Error("feishu-ws: websocket error", "error", err) + } + }() + + if err := p.fetchBotOpenID(); err != nil { + slog.Warn("feishu-ws: failed to get bot open_id", "error", err) + } + + return nil +} + +func (p *WSChannel) fetchBotOpenID() error { + client := lark.NewClient(p.appID, p.appSecret) + resp, err := client.Get(context.Background(), + "/open-apis/bot/v3/info", nil, larkcore.AccessTokenTypeTenant) + if err != nil { + return fmt.Errorf("fetch bot info: %w", err) + } + var result struct { + Code int `json:"code"` + Bot struct { + OpenID string `json:"open_id"` + } `json:"bot"` + } + if err := json.Unmarshal(resp.RawBody, &result); err != nil { + return fmt.Errorf("parse response: %w", err) + } + if result.Code != 0 { + return fmt.Errorf("api code=%d", result.Code) + } + p.botOpenID = result.Bot.OpenID + slog.Info("feishu-ws: bot identified", "open_id", p.botOpenID) + return nil +} + +func (p *WSChannel) Stop() error { + if p.cancel != nil { + p.cancel() + } + return nil +} + +func (p *WSChannel) Reply(ctx context.Context, rctx any, content string) error { + rc, ok := rctx.(wsReplyContext) + if !ok { + return fmt.Errorf("feishu-ws: invalid reply context type %T", rctx) + } + if content == "" { + return nil + } + + client := lark.NewClient(p.appID, p.appSecret) + msgType, msgBody := buildReplyContent(content) + + resp, err := client.Im.Message.Reply(ctx, larkim.NewReplyMessageReqBuilder(). + MessageId(rc.messageID). + Body(larkim.NewReplyMessageReqBodyBuilder(). + MsgType(msgType). + Content(msgBody). + ReplyInThread(true). + Build()). 
+ Build()) + if err != nil { + return fmt.Errorf("feishu-ws: reply api call: %w", err) + } + if !resp.Success() { + return fmt.Errorf("feishu-ws: reply failed code=%d msg=%s", resp.Code, resp.Msg) + } + return nil +} + +func (p *WSChannel) Send(ctx context.Context, rctx any, content string) error { + rc, ok := rctx.(wsReplyContext) + if !ok { + return fmt.Errorf("feishu-ws: invalid reply context type %T", rctx) + } + if content == "" { + return nil + } + if rc.chatID == "" { + return fmt.Errorf("feishu-ws: chatID is empty, cannot send proactive message") + } + + client := lark.NewClient(p.appID, p.appSecret) + msgType, _ := buildReplyContent(content) + + chunks := splitByBytes(content, 4000) + for _, chunk := range chunks { + _, body := buildReplyContent(chunk) + resp, err := client.Im.Message.Create(ctx, larkim.NewCreateMessageReqBuilder(). + ReceiveIdType(larkim.ReceiveIdTypeChatId). + Body(larkim.NewCreateMessageReqBodyBuilder(). + ReceiveId(rc.chatID). + MsgType(msgType). + Content(body). + Build()). 
+ Build()) + if err != nil { + return fmt.Errorf("feishu-ws: send api call: %w", err) + } + if !resp.Success() { + return fmt.Errorf("feishu-ws: send failed code=%d msg=%s", resp.Code, resp.Msg) + } + } + return nil +} + +func (p *WSChannel) ReconstructReplyCtx(sessionKey string) (any, error) { + parts := strings.SplitN(sessionKey, ":", 3) + if len(parts) < 3 || parts[0] != "feishu" { + return nil, fmt.Errorf("feishu-ws: invalid session key %q", sessionKey) + } + return wsReplyContext{chatID: parts[1], userID: parts[2]}, nil +} + +func (p *WSChannel) onMessage(event *larkim.P2MessageReceiveV1) error { + slog.Debug("feishu-ws: onMessage called") + + msg := event.Event.Message + sender := event.Event.Sender + + msgType := "" + if msg.MessageType != nil { + msgType = *msg.MessageType + } + + chatID := "" + if msg.ChatId != nil { + chatID = *msg.ChatId + } + + userID := "" + if sender.SenderId != nil && sender.SenderId.OpenId != nil { + userID = *sender.SenderId.OpenId + } + + messageID := "" + if msg.MessageId != nil { + messageID = *msg.MessageId + } + + // Filter: skip messages without message ID + if messageID == "" { + slog.Debug("feishu-ws: message without ID ignored") + return nil + } + + // Debug logging for received messages + slog.Info("feishu-ws: message received", + "msg_id", messageID, + "msg_type", msgType, + "chat_id", chatID, + "chat_type", ptrStr(msg.ChatType), + "thread_id", ptrStr(msg.ThreadId), + "user_id", userID, + "content_len", len(ptrStr(msg.Content)), + ) + + // Deduplicate + if p.dedup != nil { + isDup, _ := p.dedup.IsDuplicate(context.Background(), "feishu", messageID) + if isDup { + slog.Info("feishu-ws: skipping duplicate message", "msg_id", messageID) + return nil + } + } + + // Check allow_from filter + if !channels.AllowList(p.allowFrom, userID) { + slog.Info("feishu-ws: message from unauthorized user", "user", userID) + return nil + } + + sessionKey := fmt.Sprintf("feishu:%s:%s", chatID, userID) + rctx := wsReplyContext{ + messageID: 
messageID, + chatID: chatID, + userID: userID, + } + + switch msgType { + case "text": + var textBody struct { + Text string `json:"text"` + } + if err := json.Unmarshal([]byte(ptrStr(msg.Content)), &textBody); err != nil { + slog.Error("feishu-ws: failed to parse text content", "error", err) + return nil + } + text := stripMentions(textBody.Text, msg.Mentions, p.botOpenID) + if text == "" { + slog.Debug("feishu-ws: dropping empty text after mention stripping") + return nil + } + go p.handler(p, &channel.Message{ + SessionKey: sessionKey, Channel: "feishu", + MessageID: messageID, + UserID: userID, UserName: p.resolveUserName(userID), + Content: text, ReplyCtx: rctx, + }) + + case "image": + var imgBody struct { + ImageKey string `json:"image_key"` + } + if err := json.Unmarshal([]byte(ptrStr(msg.Content)), &imgBody); err != nil { + slog.Error("feishu-ws: failed to parse image content", "error", err) + return nil + } + imgData, mimeType, err := p.downloadImage(messageID, imgBody.ImageKey) + if err != nil { + slog.Error("feishu-ws: download image failed", "error", err) + return nil + } + go p.handler(p, &channel.Message{ + SessionKey: sessionKey, Channel: "feishu", + MessageID: messageID, + UserID: userID, UserName: p.resolveUserName(userID), + Images: []channel.ImageAttachment{{MimeType: mimeType, Data: imgData}}, + ReplyCtx: rctx, + }) + + case "audio": + var audioBody struct { + FileKey string `json:"file_key"` + Duration int `json:"duration"` + } + if err := json.Unmarshal([]byte(ptrStr(msg.Content)), &audioBody); err != nil { + slog.Error("feishu-ws: failed to parse audio content", "error", err) + return nil + } + slog.Debug("feishu-ws: audio received", "user", userID, "file_key", audioBody.FileKey) + audioData, err := p.downloadResource(messageID, audioBody.FileKey, "file") + if err != nil { + slog.Error("feishu-ws: download audio failed", "error", err) + return nil + } + go p.handler(p, &channel.Message{ + SessionKey: sessionKey, Channel: "feishu", + 
MessageID: messageID, + UserID: userID, UserName: p.resolveUserName(userID), + Audio: &channel.AudioAttachment{MimeType: "audio/opus", Data: audioData, Format: "ogg"}, + ReplyCtx: rctx, + }) + + default: + slog.Debug("feishu-ws: ignoring unsupported message type", "type", msgType) + } + + return nil +} + +func (p *WSChannel) resolveUserName(openID string) string { + if cached, ok := p.userNameCache.Load(openID); ok { + return cached.(string) + } + client := lark.NewClient(p.appID, p.appSecret) + resp, err := client.Contact.User.Get(context.Background(), + larkcontact.NewGetUserReqBuilder(). + UserId(openID). + UserIdType("open_id"). + Build()) + if err != nil { + slog.Debug("feishu-ws: resolve user name failed", "open_id", openID, "error", err) + return openID + } + if !resp.Success() || resp.Data == nil || resp.Data.User == nil || resp.Data.User.Name == nil { + return openID + } + name := *resp.Data.User.Name + p.userNameCache.Store(openID, name) + return name +} + +func (p *WSChannel) downloadImage(messageID, imageKey string) ([]byte, string, error) { + client := lark.NewClient(p.appID, p.appSecret) + resp, err := client.Im.MessageResource.Get(context.Background(), + larkim.NewGetMessageResourceReqBuilder(). + MessageId(messageID). + FileKey(imageKey). + Type("image"). 
// ioReadAll reads r until EOF and returns everything that was read.
//
// Reaching EOF is success and yields a nil error. Any other read error is
// returned together with the bytes read so far, so callers can tell a
// truncated download from a complete one instead of receiving partial data
// with a nil error.
func ioReadAll(r interface{ Read([]byte) (int, error) }) ([]byte, error) {
	// The parameter's method set is exactly io.Reader's, so the standard
	// implementation can be used directly.
	return io.ReadAll(r)
}
webhook callbacks. +// Implementations should register their GET (verification) and POST (message) handlers on the router. +type HTTPRouter interface { + RegisterHTTPRoutes(r chi.Router, callbackPath string, handler channel.MessageHandler) +} + +// Factory creates a channel instance from configuration options. +type Factory func(opts map[string]any) (channel.Channel, error) + +// Registry manages channel adapters. +type Registry struct { + channels map[string]Factory + started map[string]channel.Channel + mu sync.Mutex +} + +var registry = &Registry{ + channels: make(map[string]Factory), + started: make(map[string]channel.Channel), +} + +// RegisterChannel registers a channel factory under the given name. +// Each channel package should call this in its init() function. +func RegisterChannel(name string, factory Factory) { + registry.mu.Lock() + defer registry.mu.Unlock() + if _, exists := registry.channels[name]; exists { + slog.Warn("channel: overwriting existing channel registration", + "channel", name) + } + registry.channels[name] = factory + slog.Info("channel: registered", "name", name) +} + +// NewChannel creates a new channel instance by name with the given options. +func NewChannel(name string, opts map[string]any) (channel.Channel, error) { + factory, exists := registry.channels[name] + if !exists { + return nil, fmt.Errorf("channel %q not registered, available: %v", + name, availableChannels()) + } + return factory(opts) +} + +// TrackChannel adds a started channel to the registry with a unique key. +func TrackChannel(key string, p channel.Channel) { + registry.mu.Lock() + defer registry.mu.Unlock() + registry.started[key] = p +} + +// StopAll stops all tracked channels. 
+func StopAll() { + registry.mu.Lock() + defer registry.mu.Unlock() + for name, p := range registry.started { + if err := p.Stop(); err != nil { + slog.Warn("channel: stop failed", "name", name, "error", err) + } + } + registry.started = make(map[string]channel.Channel) +} + +// availableChannels returns a list of registered channel names. +func availableChannels() []string { + var names []string + for name := range registry.channels { + names = append(names, name) + } + return names +} diff --git a/pkg/services/channels/wecom/websocket.go b/pkg/services/channels/wecom/websocket.go new file mode 100644 index 0000000..a5965ee --- /dev/null +++ b/pkg/services/channels/wecom/websocket.go @@ -0,0 +1,521 @@ +package wecom + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/cupogo/andvari/models/oid" + "github.com/gorilla/websocket" + "github.com/liut/morign/pkg/models/channel" + "github.com/liut/morign/pkg/services/channels" + "github.com/liut/morign/pkg/services/stores" +) + +const ( + wsEndpoint = "wss://openws.work.weixin.qq.com" + wsPingInterval = 30 * time.Second + wsMaxBackoff = 30 * time.Second + wsMaxMissed = 2 +) + +const wsAckTimeout = 5 * time.Second + +// WSChannel implements channel.Channel using WeChat Work WebSocket long-connection mode. +type WSChannel struct { + botID string + secret string + allowFrom string + conn *websocket.Conn + handler channel.MessageHandler + ctx context.Context + cancel context.CancelFunc + mu sync.Mutex + dedup *channels.Dedup + // reqSeq atomic.Int64 + missedPong atomic.Int32 + pendingAcks sync.Map +} + +// wsReplyContext holds context needed to reply to a specific message. 
+type wsReplyContext struct { + reqID string + chatID string + chatType string + userID string +} + +type wsFrame struct { + Cmd string `json:"cmd,omitempty"` + Headers wsFrameHeaders `json:"headers"` + Body json.RawMessage `json:"body,omitempty"` + ErrCode *int `json:"errcode,omitempty"` + ErrMsg string `json:"errmsg,omitempty"` +} + +type wsFrameHeaders struct { + ReqID string `json:"req_id"` +} + +type wsMsgCallbackBody struct { + MsgID string `json:"msgid"` + AibotID string `json:"aibotid"` + ChatID string `json:"chatid"` + ChatType string `json:"chattype"` + From struct { + UserID string `json:"userid"` + } `json:"from"` + MsgType string `json:"msgtype"` + Text struct { + Content string `json:"content"` + } `json:"text"` + Voice struct { + Text string `json:"text"` + } `json:"voice"` + CreateTime int64 `json:"create_time"` +} + +func newWebSocket(opts map[string]any) (channel.Channel, error) { + botID, _ := opts["bot_id"].(string) + secret, _ := opts["bot_secret"].(string) + if botID == "" || secret == "" { + return nil, fmt.Errorf("wecom-ws: bot_id and bot_secret are required for websocket mode") + } + allowFrom, _ := opts["allow_from"].(string) + + return &WSChannel{ + botID: botID, + secret: secret, + allowFrom: allowFrom, + dedup: channels.NewDedup(stores.SgtRC()), + }, nil +} + +func (p *WSChannel) generateReqID(prefix string) string { + id := oid.NewID(oid.OtEvent) + return id.String() + // seq := p.reqSeq.Add(1) + // return fmt.Sprintf("%s_%d", prefix, seq) +} + +func (p *WSChannel) Name() string { return "wecom" } + +func (p *WSChannel) Start(handler channel.MessageHandler) error { + p.handler = handler + p.ctx, p.cancel = context.WithCancel(context.Background()) + go p.connectLoop() + return nil +} + +func (p *WSChannel) connectLoop() { + backoff := time.Second + for { + select { + case <-p.ctx.Done(): + return + default: + } + + start := time.Now() + err := p.runConnection() + if p.ctx.Err() != nil { + return + } + + if time.Since(start) > 
2*wsPingInterval { + backoff = time.Second + } + + slog.Warn("wecom-ws: connection lost, reconnecting", "error", err, "backoff", backoff) + select { + case <-time.After(backoff): + case <-p.ctx.Done(): + return + } + + backoff *= 2 + if backoff > wsMaxBackoff { + backoff = wsMaxBackoff + } + } +} + +func (p *WSChannel) runConnection() error { + slog.Info("wecom-ws: connecting", "endpoint", wsEndpoint) + + conn, _, err := websocket.DefaultDialer.DialContext(p.ctx, wsEndpoint, nil) + if err != nil { + slog.Info("wecom-ws: dial failed", "error", err) + return fmt.Errorf("dial: %w", err) + } + + p.mu.Lock() + p.conn = conn + p.mu.Unlock() + + defer func() { + slog.Debug("wecom-ws: connection closed") + p.mu.Lock() + p.conn = nil + p.mu.Unlock() + conn.Close() + + var staleKeys []any + p.pendingAcks.Range(func(key, value any) bool { + if ch, ok := value.(chan error); ok { + select { + case ch <- fmt.Errorf("wecom-ws: connection closed"): + default: + } + } + staleKeys = append(staleKeys, key) + return true + }) + for _, k := range staleKeys { + p.pendingAcks.Delete(k) + } + }() + + subReqID := p.generateReqID("aibot_subscribe") + subFrame := map[string]any{ + "cmd": "aibot_subscribe", + "headers": map[string]string{"req_id": subReqID}, + "body": map[string]string{ + "bot_id": p.botID, + "secret": p.secret, + }, + } + if err := p.writeJSON(subFrame); err != nil { + slog.Info("wecom-ws: subscribe write failed", "error", err) + return fmt.Errorf("subscribe: %w", err) + } + + var subResp wsFrame + if err := conn.ReadJSON(&subResp); err != nil { + slog.Info("wecom-ws: subscribe response failed", "error", err) + return fmt.Errorf("subscribe response: %w", err) + } + if subResp.ErrCode == nil || *subResp.ErrCode != 0 { + errCode := 0 + if subResp.ErrCode != nil { + errCode = *subResp.ErrCode + } + slog.Info("wecom-ws: subscribe rejected", "errcode", errCode, "errmsg", subResp.ErrMsg) + return fmt.Errorf("subscribe failed: errcode=%d errmsg=%s", errCode, subResp.ErrMsg) + } + 
slog.Info("wecom-ws: subscribed successfully", "bot_id", p.botID, "subReqID", subReqID) + p.missedPong.Store(0) + + heartCtx, heartCancel := context.WithCancel(p.ctx) + defer heartCancel() + slog.Debug("wecom-ws: heartbeat starting") + go p.heartbeat(heartCtx, conn) + + for { + select { + case <-p.ctx.Done(): + return p.ctx.Err() + default: + } + + _, raw, err := conn.ReadMessage() + if err != nil { + slog.Info("wecom-ws: read message failed", "error", err) + return fmt.Errorf("read: %w", err) + } + + var frame wsFrame + if err := json.Unmarshal(raw, &frame); err != nil { + slog.Warn("wecom-ws: invalid json", "error", err) + continue + } + + p.handleFrame(frame) + } +} + +func (p *WSChannel) handleFrame(frame wsFrame) { + switch frame.Cmd { + case "aibot_msg_callback": + p.handleMsgCallback(frame) + case "aibot_event_callback": + slog.Debug("wecom-ws: event callback received (ignored)", "req_id", frame.Headers.ReqID) + case "ping", "": + // pong response: cmd can be "ping" or "" depending on server version + p.missedPong.Store(0) + slog.Debug("wecom-ws: heartbeat ack received", "cmd", frame.Cmd, "req_id", frame.Headers.ReqID) + case "aibot_subscribe": + slog.Debug("wecom-ws: late subscribe ack", "req_id", frame.Headers.ReqID) + default: + var ackErr error + if frame.ErrCode != nil && *frame.ErrCode != 0 { + ackErr = fmt.Errorf("wecom-ws: ack error: errcode=%d errmsg=%s", *frame.ErrCode, frame.ErrMsg) + slog.Warn("wecom-ws: reply/send ack error", "req_id", frame.Headers.ReqID, "errcode", *frame.ErrCode, "errmsg", frame.ErrMsg) + } + if ch, ok := p.pendingAcks.LoadAndDelete(frame.Headers.ReqID); ok { + ch.(chan error) <- ackErr + } + } +} + +func (p *WSChannel) heartbeat(ctx context.Context, conn *websocket.Conn) { + ticker := time.NewTicker(wsPingInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + missed := int(p.missedPong.Load()) + if missed >= wsMaxMissed { + slog.Warn("wecom-ws: no heartbeat ack, connection 
considered dead", "missed", missed) + conn.Close() + return + } + + p.missedPong.Add(1) + pingFrame := map[string]any{ + "cmd": "ping", + "headers": map[string]string{"req_id": p.generateReqID("ping")}, + } + if err := p.writeJSON(pingFrame); err != nil { + slog.Warn("wecom-ws: ping failed", "error", err) + return + } + slog.Debug("wecom-ws: ping sent", "missed_pong", p.missedPong.Load()) + } + } +} + +func (p *WSChannel) handleMsgCallback(frame wsFrame) { + var body wsMsgCallbackBody + if err := json.Unmarshal(frame.Body, &body); err != nil { + slog.Warn("wecom-ws: parse msg_callback body failed", "error", err) + return + } + + reqID := frame.Headers.ReqID + + if p.dedup != nil { + ctx := context.Background() + isDup, _ := p.dedup.IsDuplicate(ctx, "wecom", body.MsgID) + if isDup { + slog.Info("wecom-ws: skipping duplicate message", "msg_id", body.MsgID) + return + } + } + + // Check allow_from filter + if !channels.AllowList(p.allowFrom, body.From.UserID) { + slog.Info("wecom-ws: message from unauthorized user", "user", body.From.UserID) + return + } + + chatID := body.ChatID + if chatID == "" { + chatID = body.From.UserID + } + + sessionKey := fmt.Sprintf("wecom:%s:%s", chatID, body.From.UserID) + rctx := wsReplyContext{ + reqID: reqID, + chatID: chatID, + chatType: body.ChatType, + userID: body.From.UserID, + } + + chatName := "" + if body.ChatType == "group" { + chatName = body.ChatID + } + + switch body.MsgType { + case "text": + text := stripWeComAtMentions(body.Text.Content, p.botID) + slog.Debug("wecom-ws: text received", "user", body.From.UserID, "len", len(text)) + go p.handler(p, &channel.Message{ + SessionKey: sessionKey, Channel: "wecom", + MessageID: body.MsgID, + UserID: body.From.UserID, UserName: body.From.UserID, + ChatName: chatName, + Content: text, ReplyCtx: rctx, + }) + + case "voice": + text := stripWeComAtMentions(body.Voice.Text, p.botID) + if text == "" { + slog.Debug("wecom-ws: voice message with empty transcription, ignoring") + return + 
} + slog.Debug("wecom-ws: voice received (transcribed)", "user", body.From.UserID, "len", len(text)) + go p.handler(p, &channel.Message{ + SessionKey: sessionKey, Channel: "wecom", + MessageID: body.MsgID, + UserID: body.From.UserID, UserName: body.From.UserID, + ChatName: chatName, + Content: text, ReplyCtx: rctx, FromVoice: true, + }) + + default: + slog.Debug("wecom-ws: ignoring unsupported message type", "type", body.MsgType) + } +} + +func (p *WSChannel) Reply(ctx context.Context, rctx any, content string) error { + rc, ok := rctx.(wsReplyContext) + if !ok { + return fmt.Errorf("wecom-ws: invalid reply context type %T", rctx) + } + if content == "" { + return nil + } + + streamID := p.generateReqID("stream") + frame := map[string]any{ + "cmd": "aibot_respond_msg", + "headers": map[string]string{"req_id": rc.reqID}, + "body": map[string]any{ + "msgtype": "stream", + "stream": map[string]any{ + "id": streamID, + "finish": true, + "content": content, + }, + }, + } + if err := p.writeJSON(frame); err != nil { + slog.Error("wecom-ws: reply failed", "user", rc.userID, "error", err) + return err + } + slog.Debug("wecom-ws: reply sent", "user", rc.userID, "len", len(content)) + return nil +} + +func (p *WSChannel) Send(ctx context.Context, rctx any, content string) error { + rc, ok := rctx.(wsReplyContext) + if !ok { + return fmt.Errorf("wecom-ws: invalid reply context type %T", rctx) + } + if content == "" { + return nil + } + if rc.chatID == "" { + slog.Info("wecom-ws: chatID is empty, cannot send proactive message") + return fmt.Errorf("wecom-ws: chatID is empty, cannot send proactive message") + } + + chunks := splitByBytes(content, 2000) + for i, chunk := range chunks { + reqID := p.generateReqID("aibot_send_msg") + frame := map[string]any{ + "cmd": "aibot_send_msg", + "headers": map[string]string{"req_id": reqID}, + "body": map[string]any{ + "chatid": rc.chatID, + "msgtype": "markdown", + "markdown": map[string]string{ + "content": chunk, + }, + }, + } + if err 
// stripWeComAtMentions removes "@<botID>" and "@所有人" (@everyone) mentions
// from content and trims surrounding whitespace.
//
// An empty botID is skipped: the mention would otherwise be a bare "@", and
// removing it would delete every "@" in the message (e.g. inside e-mail
// addresses).
func stripWeComAtMentions(content, botID string) string {
	if content == "" {
		return ""
	}
	if botID != "" {
		content = strings.ReplaceAll(content, "@"+botID, "")
	}
	content = strings.ReplaceAll(content, "@所有人", "")
	return strings.TrimSpace(content)
}
//
// Each chunk is at most maxBytes long. A cut that would land inside a
// multi-byte UTF-8 sequence is moved back to the start of that sequence; only
// when that would leave an empty chunk is the string cut at exactly maxBytes.
// A non-positive maxBytes cannot be honoured, so s is returned as a single
// chunk instead of looping forever.
func splitByBytes(s string, maxBytes int) []string {
	if maxBytes <= 0 || len(s) <= maxBytes {
		return []string{s}
	}

	parts := make([]string, 0, len(s)/maxBytes+1)
	for len(s) > maxBytes {
		end := maxBytes
		// 0b10xxxxxx marks a UTF-8 continuation byte: step back to the lead byte.
		for end > 0 && s[end]>>6 == 0b10 {
			end--
		}
		if end == 0 {
			// Nothing but continuation bytes before the limit; cut hard so
			// the loop always makes progress.
			end = maxBytes
		}
		parts = append(parts, s[:end])
		s = s[end:]
	}
	// The remainder is non-empty here and fits within the limit.
	return append(parts, s)
}
type HTTPChannel struct {
	corpID         string // "corp_id" option; required by newHTTP
	corpSecret     string // "corp_secret" option; required by newHTTP
	agentID        string // "agent_id" option; also passed to stripAtMentions for incoming text
	allowFrom      string // "allow_from" option; senders are checked with channels.AllowList
	token          string // "callback_token" option
	aesKey         []byte // "callback_aes_key" option after decodeAESKey
	callbackPath   string // path given to RegisterHTTPRoutes
	enableMarkdown bool   // "enable_markdown" option; Reply uses sendMarkdown instead of sendText when set
	handler        channel.MessageHandler // stored by Start; the HTTP routes use the handler given to RegisterHTTPRoutes
	apiClient      *http.Client           // built in newHTTP with a 30s timeout and the optional "proxy" settings
	tokenCache     tokenCache             // presumably the access_token cache behind getAccessToken — TODO confirm
	userNameCache  sync.Map               // presumably user ID -> display name, filled by resolveUserName — TODO confirm
}

// tokenCache holds a token string with an expiry time, guarded by mu.
// NOTE(review): presumably the WeCom access_token; the code that fills it is
// not visible here.
type tokenCache struct {
	mu        sync.Mutex
	token     string
	expiresAt time.Time
}

// replyContext identifies the recipient of a reply in webhook mode.
type replyContext struct {
	userID string // sender of the incoming message (FromUserName); Reply sends to this user
}
+ allowFrom: allowFrom, + token: callbackToken, + aesKey: aesKey, + enableMarkdown: enableMarkdown, + apiClient: apiClient, + }, nil +} + +func (p *HTTPChannel) Name() string { return "wecom" } + +func (p *HTTPChannel) Start(handler channel.MessageHandler) error { + p.handler = handler + return nil +} + +func (p *HTTPChannel) RegisterHTTPRoutes(r chi.Router, callbackPath string, handler channel.MessageHandler) { + p.callbackPath = callbackPath + r.Method(http.MethodGet, callbackPath, http.HandlerFunc(p.handleVerify)) + r.Method(http.MethodPost, callbackPath, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + p.handleMessage(w, r, handler) + })) + slog.Info("wecom-http: routes registered", "path", callbackPath) +} + +func (p *HTTPChannel) handleVerify(w http.ResponseWriter, r *http.Request) { + q := r.URL.Query() + msgSignature := q.Get("msg_signature") + timestamp := q.Get("timestamp") + nonce := q.Get("nonce") + echostr := q.Get("echostr") + + if !p.verifySignature(msgSignature, timestamp, nonce, echostr) { + slog.Warn("wecom-http: verify signature failed") + w.WriteHeader(http.StatusForbidden) + return + } + + plain, err := p.decrypt(echostr) + if err != nil { + slog.Error("wecom-http: decrypt echostr failed", "error", err) + w.WriteHeader(http.StatusInternalServerError) + return + } + + slog.Info("wecom-http: URL verification succeeded") + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, plain) +} + +func (p *HTTPChannel) handleMessage(w http.ResponseWriter, r *http.Request, handler channel.MessageHandler) { + q := r.URL.Query() + msgSignature := q.Get("msg_signature") + timestamp := q.Get("timestamp") + nonce := q.Get("nonce") + + body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20)) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + + var encMsg xmlEncryptedMsg + if err := xml.Unmarshal(body, &encMsg); err != nil { + slog.Error("wecom-http: parse xml failed", "error", err) + 
w.WriteHeader(http.StatusBadRequest) + return + } + + if !p.verifySignature(msgSignature, timestamp, nonce, encMsg.Encrypt) { + slog.Warn("wecom-http: message signature verification failed") + w.WriteHeader(http.StatusForbidden) + return + } + + plainXML, err := p.decrypt(encMsg.Encrypt) + if err != nil { + slog.Error("wecom-http: decrypt message failed", "error", err) + w.WriteHeader(http.StatusInternalServerError) + return + } + + // Return 200 immediately (WeChat Work requires response within 5 seconds) + w.WriteHeader(http.StatusOK) + + var msg xmlMessage + if err := xml.Unmarshal([]byte(plainXML), &msg); err != nil { + slog.Error("wecom-http: parse decrypted xml failed", "error", err) + return + } + + // Check allow_from filter + if !channels.AllowList(p.allowFrom, msg.FromUserName) { + slog.Info("wecom-http: message from unauthorized user", "user", msg.FromUserName) + return + } + + sessionKey := fmt.Sprintf("wecom:%s", msg.FromUserName) + rctx := replyContext{userID: msg.FromUserName} + + switch msg.MsgType { + case "text": + text := stripAtMentions(msg.Content, p.agentID) + slog.Debug("wecom-http: message received", "user", msg.FromUserName, "text_len", len(text)) + go handler(p, &channel.Message{ + SessionKey: sessionKey, Channel: "wecom", + MessageID: strconv.FormatInt(msg.MsgId, 10), + UserID: msg.FromUserName, UserName: p.resolveUserName(msg.FromUserName), + Content: text, ReplyCtx: rctx, + }) + + case "image": + slog.Debug("wecom-http: image received", "user", msg.FromUserName) + go func() { + imgData, err := p.downloadMedia(msg.MediaId) + if err != nil { + slog.Error("wecom-http: download image failed", "error", err) + return + } + handler(p, &channel.Message{ + SessionKey: sessionKey, Channel: "wecom", + MessageID: strconv.FormatInt(msg.MsgId, 10), + UserID: msg.FromUserName, UserName: p.resolveUserName(msg.FromUserName), + Images: []channel.ImageAttachment{{MimeType: "image/jpeg", Data: imgData}}, + ReplyCtx: rctx, + }) + }() + + case "voice": + 
slog.Debug("wecom-http: voice received", "user", msg.FromUserName, "format", msg.Format) + go func() { + audioData, err := p.downloadMedia(msg.MediaId) + if err != nil { + slog.Error("wecom-http: download voice failed", "error", err) + return + } + format := strings.ToLower(msg.Format) + if format == "" { + format = "amr" + } + handler(p, &channel.Message{ + SessionKey: sessionKey, Channel: "wecom", + MessageID: strconv.FormatInt(msg.MsgId, 10), + UserID: msg.FromUserName, UserName: p.resolveUserName(msg.FromUserName), + Audio: &channel.AudioAttachment{MimeType: "audio/" + format, Data: audioData, Format: format}, + ReplyCtx: rctx, + }) + }() + + default: + slog.Debug("wecom-http: ignoring unsupported message type", "type", msg.MsgType) + } +} + +func (p *HTTPChannel) Reply(ctx context.Context, rctx any, content string) error { + rc, ok := rctx.(replyContext) + if !ok { + return fmt.Errorf("wecom-http: invalid reply context type %T", rctx) + } + if content == "" { + return nil + } + + accessToken, err := p.getAccessToken() + if err != nil { + return fmt.Errorf("wecom-http: get access_token: %w", err) + } + + chunks := splitByBytes(content, 2000) + for i, chunk := range chunks { + var sendErr error + if p.enableMarkdown { + sendErr = p.sendMarkdown(accessToken, rc.userID, chunk) + } else { + sendErr = p.sendText(accessToken, rc.userID, chunk) + } + if sendErr != nil { + slog.Error("wecom-http: send failed", "user", rc.userID, "chunk", i, "error", sendErr) + return sendErr + } + } + return nil +} + +func (p *HTTPChannel) Send(ctx context.Context, rctx any, content string) error { + return p.Reply(ctx, rctx, content) +} + +func (p *HTTPChannel) Stop() error { + return nil +} + +func (p *HTTPChannel) ReconstructReplyCtx(sessionKey string) (any, error) { + parts := strings.SplitN(sessionKey, ":", 2) + if len(parts) < 2 || parts[0] != "wecom" { + return nil, fmt.Errorf("wecom-http: invalid session key %q", sessionKey) + } + return replyContext{userID: parts[1]}, nil +} + 
+// redactURLError drops the request URL from an http.Client error. The URLs
+// built in this file carry corpsecret or access_token in the query string,
+// and *url.Error would otherwise copy them into error text and logs.
+func redactURLError(err error) error {
+	if ue, ok := err.(*url.Error); ok {
+		return fmt.Errorf("%s: %w", ue.Op, ue.Err)
+	}
+	return err
+}
+
+// sendMarkdown posts content to toUser as a markdown message and returns an
+// error when the request fails or the API reports a non-zero errcode.
+func (p *HTTPChannel) sendMarkdown(accessToken, toUser, content string) error {
+	payload := map[string]any{
+		"touser":   toUser,
+		"msgtype":  "markdown",
+		"agentid":  p.agentID,
+		"markdown": map[string]string{"content": content},
+	}
+
+	body, _ := json.Marshal(payload)
+	apiURL := "https://qyapi.weixin.qq.com/cgi-bin/message/send?access_token=" + accessToken
+
+	resp, err := p.apiClient.Post(apiURL, "application/json", strings.NewReader(string(body)))
+	if err != nil {
+		return fmt.Errorf("wecom-http: send markdown: %w", redactURLError(err))
+	}
+	defer resp.Body.Close()
+
+	var result struct {
+		ErrCode int    `json:"errcode"`
+		ErrMsg  string `json:"errmsg"`
+	}
+	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
+		return fmt.Errorf("wecom-http: decode send response: %w", err)
+	}
+	if result.ErrCode != 0 {
+		return fmt.Errorf("wecom-http: send markdown failed: %d %s", result.ErrCode, result.ErrMsg)
+	}
+	return nil
+}
+
+// sendText posts text to toUser as a plain text message and returns an error
+// when the request fails or the API reports a non-zero errcode.
+func (p *HTTPChannel) sendText(accessToken, toUser, text string) error {
+	payload := map[string]any{
+		"touser":  toUser,
+		"msgtype": "text",
+		"agentid": p.agentID,
+		"text":    map[string]string{"content": text},
+		"safe":    0,
+	}
+
+	body, _ := json.Marshal(payload)
+	apiURL := "https://qyapi.weixin.qq.com/cgi-bin/message/send?access_token=" + accessToken
+
+	resp, err := p.apiClient.Post(apiURL, "application/json", strings.NewReader(string(body)))
+	if err != nil {
+		return fmt.Errorf("wecom-http: send message: %w", redactURLError(err))
+	}
+	defer resp.Body.Close()
+
+	var result struct {
+		ErrCode int    `json:"errcode"`
+		ErrMsg  string `json:"errmsg"`
+	}
+	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
+		return fmt.Errorf("wecom-http: decode send response: %w", err)
+	}
+	if result.ErrCode != 0 {
+		return fmt.Errorf("wecom-http: send failed: %d %s", result.ErrCode, result.ErrMsg)
+	}
+	return nil
+}
+
+// getAccessToken returns the cached access_token while it is still valid and
+// otherwise fetches a new one from the gettoken API. The cache entry expires
+// 60 seconds before the lifetime reported by the API. The cache mutex is held
+// for the whole call, including the HTTP request.
+func (p *HTTPChannel) getAccessToken() (string, error) {
+	p.tokenCache.mu.Lock()
+	defer p.tokenCache.mu.Unlock()
+
+	if p.tokenCache.token != "" && time.Now().Before(p.tokenCache.expiresAt) {
+		return p.tokenCache.token, nil
+	}
+
+	apiURL := fmt.Sprintf(
+		"https://qyapi.weixin.qq.com/cgi-bin/gettoken?corpid=%s&corpsecret=%s",
+		url.QueryEscape(p.corpID), url.QueryEscape(p.corpSecret),
+	)
+
+	resp, err := p.apiClient.Get(apiURL)
+	if err != nil {
+		return "", fmt.Errorf("wecom-http: request access_token: %w", redactURLError(err))
+	}
+	defer resp.Body.Close()
+
+	var result struct {
+		ErrCode     int    `json:"errcode"`
+		ErrMsg      string `json:"errmsg"`
+		AccessToken string `json:"access_token"`
+		ExpiresIn   int    `json:"expires_in"`
+	}
+	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
+		return "", fmt.Errorf("wecom-http: decode token response: %w", err)
+	}
+	if result.ErrCode != 0 {
+		return "", fmt.Errorf("wecom-http: get token failed: %d %s", result.ErrCode, result.ErrMsg)
+	}
+	if result.AccessToken == "" {
+		return "", fmt.Errorf("wecom-http: get token failed: empty access_token")
+	}
+
+	p.tokenCache.token = result.AccessToken
+	p.tokenCache.expiresAt = time.Now().Add(time.Duration(result.ExpiresIn-60) * time.Second)
+	return result.AccessToken, nil
+}
+
+// httpMaxMediaBytes caps how much of a media download is kept in memory.
+const httpMaxMediaBytes = 20 << 20
+
+// downloadMedia fetches the media item mediaID through the media/get API and
+// returns its raw bytes. It fails on a non-200 status, on a body larger than
+// httpMaxMediaBytes, and when the body is a JSON error object rather than
+// media data.
+func (p *HTTPChannel) downloadMedia(mediaID string) ([]byte, error) {
+	accessToken, err := p.getAccessToken()
+	if err != nil {
+		return nil, fmt.Errorf("get token: %w", err)
+	}
+	u := fmt.Sprintf("https://qyapi.weixin.qq.com/cgi-bin/media/get?access_token=%s&media_id=%s",
+		accessToken, url.QueryEscape(mediaID))
+	resp, err := p.apiClient.Get(u)
+	if err != nil {
+		return nil, fmt.Errorf("download: %w", redactURLError(err))
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK {
+		return nil, fmt.Errorf("download: unexpected status %s", resp.Status)
+	}
+
+	// Read one byte past the cap so an oversized body can be told apart from
+	// one that is exactly at the limit.
+	data, err := io.ReadAll(io.LimitReader(resp.Body, httpMaxMediaBytes+1))
+	if err != nil {
+		return nil, fmt.Errorf("download: %w", err)
+	}
+	if len(data) > httpMaxMediaBytes {
+		return nil, fmt.Errorf("download: media exceeds %d bytes", httpMaxMediaBytes)
+	}
+
+	// A body that parses as a JSON object with a non-zero errcode is an API
+	// error report, not media content.
+	if len(data) > 0 && data[0] == '{' {
+		var apiErr struct {
+			ErrCode int    `json:"errcode"`
+			ErrMsg  string `json:"errmsg"`
+		}
+		if json.Unmarshal(data, &apiErr) == nil && apiErr.ErrCode != 0 {
+			return nil, fmt.Errorf("download: %d %s", apiErr.ErrCode, apiErr.ErrMsg)
+		}
+	}
+	return data, nil
+}
+
+// resolveUserName returns the display name for userID, using userNameCache
+// first and the user/get API otherwise. On any failure, or when the API
+// returns an empty name, it falls back to userID itself; only non-empty
+// names are cached.
+func (p *HTTPChannel) resolveUserName(userID string) string {
+	if cached, ok := p.userNameCache.Load(userID); ok {
+		return cached.(string)
+	}
+	accessToken, err := p.getAccessToken()
+	if err != nil {
+		return userID
+	}
+	apiURL := fmt.Sprintf("https://qyapi.weixin.qq.com/cgi-bin/user/get?access_token=%s&userid=%s", accessToken, url.QueryEscape(userID))
+	resp, err := p.apiClient.Get(apiURL)
+	if err != nil {
+		return userID
+	}
+	defer resp.Body.Close()
+	var result struct {
+		ErrCode int    `json:"errcode"`
+		Name    string `json:"name"`
+	}
+	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil || result.ErrCode != 0 {
+		return userID
+	}
+	if result.Name != "" {
+		p.userNameCache.Store(userID, result.Name)
+		return result.Name
+	}
+	return userID
+}
+
+// verifySignature reports whether expected equals the lowercase hex SHA-1 of
+// the callback token, timestamp, nonce and encrypt after sorting the four
+// strings and concatenating them.
+func (p *HTTPChannel) verifySignature(expected, timestamp, nonce, encrypt string) bool {
+	parts := []string{p.token, timestamp, nonce, encrypt}
+	sort.Strings(parts)
+	h := sha1.New()
+	h.Write([]byte(strings.Join(parts, "")))
+	got := fmt.Sprintf("%x", h.Sum(nil))
+	return got == expected
+}
+
+// decodeAESKey decodes a 43-character EncodingAESKey by appending the missing
+// "=" padding and base64-decoding the result.
+func decodeAESKey(encodingAESKey string) ([]byte, error) {
+	if len(encodingAESKey) != 43 {
+		return nil, fmt.Errorf("EncodingAESKey must be 43 characters, got %d", len(encodingAESKey))
+	}
+	return base64.StdEncoding.DecodeString(encodingAESKey + "=")
+}
+
+// decrypt decodes and decrypts a base64 callback payload with AES-CBC, using
+// the first 16 bytes of the key as IV. The plaintext layout read here is:
+// 16 bytes that are skipped, a 4-byte big-endian message length, the message,
+// and the corp ID. It returns the message, or an error when the input is
+// malformed or the trailing corp ID does not match the configured one.
+func (p *HTTPChannel) decrypt(cipherBase64 string) (string, error) {
+	cipherData, err := base64.StdEncoding.DecodeString(cipherBase64)
+	if err != nil {
+		return "", fmt.Errorf("base64 decode: %w", err)
+	}
+
+	block, err := aes.NewCipher(p.aesKey)
+	if err != nil {
+		return "", fmt.Errorf("aes new cipher: %w", err)
+	}
+
+	if len(cipherData) < aes.BlockSize || len(cipherData)%aes.BlockSize != 0 {
+		return "", fmt.Errorf("invalid ciphertext length %d", len(cipherData))
+	}
+
+	iv := p.aesKey[:16]
+	mode := cipher.NewCBCDecrypter(block, iv)
+	plain := make([]byte, len(cipherData))
+	mode.CryptBlocks(plain, cipherData)
+
+	plain = pkcs7Unpad(plain)
+
+	if len(plain) < 20 {
+		return "", fmt.Errorf("decrypted data too short")
+	}
+
+	// Compare in uint64 so a huge length field cannot wrap around on
+	// platforms where int is 32 bits wide.
+	msgLen := binary.BigEndian.Uint32(plain[16:20])
+	if uint64(msgLen) > uint64(len(plain)-20) {
+		return "", fmt.Errorf("invalid message length %d in decrypted data (total %d)", msgLen, len(plain))
+	}
+
+	end := 20 + int(msgLen)
+	msg := string(plain[20:end])
+	corpID := string(plain[end:])
+
+	if corpID != p.corpID {
+		return "", fmt.Errorf("corp_id mismatch: expected %s, got %s", p.corpID, corpID)
+	}
+
+	return msg, nil
+}
+
+// pkcs7Unpad strips PKCS#7 padding whose length is taken from the last byte.
+// Data is returned unchanged when it is empty or when that length is outside
+// 1..32 or larger than the data itself.
+func pkcs7Unpad(data []byte) []byte {
+	if len(data) == 0 {
+		return data
+	}
+	pad := int(data[len(data)-1])
+	if pad < 1 || pad > 32 || pad > len(data) {
+		return data
+	}
+	return data[:len(data)-pad]
+}
+
+// stripAtMentions removes every occurrence of "@" followed by agentID from
+// content.
+func stripAtMentions(content, agentID string) string {
+	agentIDStr := "@" + agentID
+	return strings.ReplaceAll(content, agentIDStr, "")
+}
+
+var _ channel.Channel = (*HTTPChannel)(nil)
diff --git a/pkg/services/channels/wecom/wecom_msg.go b/pkg/services/channels/wecom/wecom_msg.go
new file mode 100644
index 0000000..f7e5b14
--- /dev/null
+++ b/pkg/services/channels/wecom/wecom_msg.go
@@ -0,0 +1,28 @@
+package wecom
+
+import (
+	"encoding/xml"
+)
+
+// Incoming XML envelope from WeChat Work callback.
+type xmlEncryptedMsg struct {
+	XMLName    xml.Name `xml:"xml"`
+	ToUserName string   `xml:"ToUserName"`
+	AgentID    string   `xml:"AgentID"`
+	Encrypt    string   `xml:"Encrypt"`
+}
+
+// Decrypted message body.
+type xmlMessage struct {
+	XMLName      xml.Name `xml:"xml"`
+	ToUserName   string   `xml:"ToUserName"`
+	FromUserName string   `xml:"FromUserName"`
+	CreateTime   int64    `xml:"CreateTime"`
+	MsgType      string   `xml:"MsgType"`
+	Content      string   `xml:"Content"`
+	PicUrl       string   `xml:"PicUrl"`
+	MediaId      string   `xml:"MediaId"`
+	Format       string   `xml:"Format"` // voice format: amr, speex, etc.
+	MsgId        int64    `xml:"MsgId"`
+	AgentID      int64    `xml:"AgentID"`
+}
diff --git a/pkg/services/stores/conversation.go b/pkg/services/stores/conversation.go
index 6ba46a0..a464005 100644
--- a/pkg/services/stores/conversation.go
+++ b/pkg/services/stores/conversation.go
@@ -34,6 +34,24 @@ func NewConversation(ctx context.Context, id any) Conversation {
 	return newConversation(ctx, id, SgtRC())
 }
 
+const sessionKeyCSIDPrefix = "platform:csid:"
+
+// GetOrCreateConversationBySessionKey finds or creates the Conversation bound to a channel sessionKey.
+// sessionKey format: "{channel}:{chatID}:{userID}" or "{channel}:{userID}".
+// It looks up the bound Conversation OID in the Redis mapping; if none exists it creates a new Conversation and stores the mapping.
+func GetOrCreateConversationBySessionKey(ctx context.Context, sessionKey string) Conversation {
+	key := sessionKeyCSIDPrefix + sessionKey
+	oidStr, _ := SgtRC().Get(ctx, key).Result()
+
+	cs := NewConversation(ctx, oidStr)
+	if oidStr == "" {
+		SgtRC().Set(ctx, key, cs.GetID(), 30*24*time.Hour)
+	} else {
+		SgtRC().Expire(ctx, key, 30*24*time.Hour)
+	}
+	return cs
+}
+
 // newConversation is internal constructor, supports injecting Redis client (for testing)
 func newConversation(ctx context.Context, id any, rc RedisClient) Conversation {
 	sto := Sgt()
diff --git a/pkg/web/api/api.go b/pkg/web/api/api.go
index 0625900..3ff448a 100644
--- a/pkg/web/api/api.go
+++ b/pkg/web/api/api.go
@@ -23,6 +23,9 @@ import (
 	"github.com/liut/morign/pkg/settings"
 	"github.com/liut/morign/pkg/web/resp"
 	"github.com/liut/morign/pkg/web/routes"
+
+	_ "github.com/liut/morign/pkg/services/channels/feishu"
+	_ "github.com/liut/morign/pkg/services/channels/wecom"
 )
 
 var handles = []handleIn{}
@@ -50,6 +53,9 @@ type api struct {
 	llm     llm.Client
 	preset  aigc.Preset
 	toolreg *tools.Registry
+	toolExec *ToolExecutor
+
+	router chi.Router // used to register channel HTTP callbacks
 }
 
 func init() {
@@ -117,11 +123,14 @@ func newapi(sto stores.Storage) *api {
 		llm:     stores.GetLLMClient(),
 		preset:  preset,
 		toolreg: toolreg,
+		toolExec: NewToolExecutor(toolreg),
 	}
 }
 
 // Strap 注册路由到 chi.Router
 func (a *api) Strap(router chi.Router) {
+	a.router = router
+
 	// staffio 认证路由
 	router.Get(authLoginPath, staffio.LoginHandler)
 	router.Get(authLogoutPath, handleLogout)
@@ -178,6 +187,12 @@ func (a *api) Strap(router chi.Router) {
 		limited.Post("/chat", a.postChat)
 		limited.Post("/chat-{suffix}", a.postChat)
 	})
+
+	// Initialize channel adapters (HTTP webhook callbacks, etc.)
+	if err := InitChannels(a.router, &a.preset, a.sto, a.llm, a.toolreg); err != nil {
+		logger().Warnw("init channels failed", "err", err)
+	}
+
 }
 
 func (a *api) authPerm(permID string) func(next http.Handler) http.Handler {
diff --git a/pkg/web/api/handle_convo.go b/pkg/web/api/handle_convo.go
index 1eac99a..c13ada0 100644
--- a/pkg/web/api/handle_convo.go
+++ b/pkg/web/api/handle_convo.go
@@ -740,68 +740,11 @@ func convertToolCallsForJSON(tcs []llm.ToolCall) []map[string]any {
 	return result
 }
 
-// chatExecutor 定义聊天执行函数类型,支持流式/非流式
-type chatExecutor func(ctx context.Context, messages []llm.Message, tools []llm.ToolDefinition) (string, []llm.ToolCall, *llm.Usage, error)
-
 // executeToolCallLoop 执行工具调用循环,直到没有 tool calls
 // - messages: 初始消息列表,会被修改
 // - tools: 工具定义
 // - exec: 执行聊天的函数(流式或非流式)
 // 返回最终的 answer、最后的 toolCalls(如果有)、usage
 func (a *api) executeToolCallLoop(ctx context.Context, messages []llm.Message, tools []llm.ToolDefinition, exec chatExecutor) (string, []llm.ToolCall, *llm.Usage, error) {
-	for {
-		answer, toolCalls, usage, err := exec(ctx, messages, tools)
-		if err != nil {
-			return "", nil, nil, err
-		}
-
-		if len(toolCalls) == 0 {
-			return answer, nil, usage, nil
-		}
-
-		// 添加 assistant 消息(带 tool calls)
-		messages = append(messages, llm.Message{
-			Role:      llm.RoleAssistant,
-			ToolCalls: toolCalls,
-		})
-
-		// 执行工具调用
-		for _, tc := range toolCalls {
-			logger().Infow("chat", "toolCallID", tc.ID, "toolCallType", tc.Type, "toolCallName", tc.Function.Name)
-
-			if tc.Type != "function" {
-				continue
-			}
-
-			var parameters map[string]any
-			args := string(tc.Function.Arguments)
-			if args != "" && args != "{}" {
-				if err := json.Unmarshal(tc.Function.Arguments, &parameters); err != nil {
-					logger().Infow("chat", "toolCallID", tc.ID, "args", args, "err", err)
-					continue
-				}
-			}
-			// 空参数时使用空 map
-			if parameters == nil {
-				parameters = make(map[string]any)
-			}
-
-			content, err := a.toolreg.Invoke(ctx, tc.Function.Name, parameters)
-			if err != nil {
-				logger().Infow("invokeTool fail", "toolCallName", tc.Function.Name, "err", err)
-				continue
-			}
-
-			logger().Infow("invokeTool ok", "toolCallName", tc.Function.Name,
-				"content", toolsvc.ResultLogs(content))
-			messages = append(messages, llm.Message{
-				Role:       llm.RoleTool,
-				Content:    formatToolResult(content),
-				ToolCallID: tc.ID,
-			})
-		}
-
-		// 清除工具定义,避免死循环
-		tools = nil
-	}
+	return a.toolExec.ExecuteToolCallLoop(ctx, messages, tools, exec)
 }
diff --git a/pkg/web/api/handle_platform.go b/pkg/web/api/handle_platform.go
new file mode 100644
index 0000000..4160767
--- /dev/null
+++ b/pkg/web/api/handle_platform.go
@@ -0,0 +1,194 @@
+package api
+
+import (
+	"context"
+	"log/slog"
+	"time"
+
+	"github.com/go-chi/chi/v5"
+
+	"github.com/liut/morign/pkg/models/aigc"
+	"github.com/liut/morign/pkg/models/channel"
+	"github.com/liut/morign/pkg/services/channels"
+	"github.com/liut/morign/pkg/services/llm"
+	"github.com/liut/morign/pkg/services/stores"
+	"github.com/liut/morign/pkg/services/tools"
+)
+
+// channelHandler holds dependencies for handling channel messages.
+type channelHandler struct {
+	sto      stores.Storage
+	llm      llm.Client
+	toolreg  *tools.Registry
+	toolExec *ToolExecutor
+}
+
+// InitChannels initializes channel adapters from preset configuration.
+//
+// For every enabled entry in preset.Channels it creates the channel, starts
+// it with the shared message handler, registers HTTP callback routes on r
+// when the channel implements channels.HTTPRouter, and tracks it under the
+// key "{name}" or "{name}-{mode}". A channel that cannot be created or
+// started is logged and skipped, so the returned error is always nil.
+func InitChannels(r chi.Router, preset *aigc.Preset, sto stores.Storage, llmClient llm.Client, toolreg *tools.Registry) error {
+	chandler := &channelHandler{
+		sto:      sto,
+		llm:      llmClient,
+		toolreg:  toolreg,
+		toolExec: NewToolExecutor(toolreg),
+	}
+
+	if preset == nil || len(preset.Channels) == 0 {
+		slog.Info("channel: no platforms configured")
+		return nil
+	}
+
+	for name, cfg := range preset.Channels {
+		if !cfg.Enable {
+			slog.Debug("channel: skipping disabled channel", "name", name)
+			continue
+		}
+
+		// Inject mode into config for channel factory. The entries are copied
+		// into a fresh map first so the preset's own config map is not mutated.
+		channelConfig := make(map[string]any, len(cfg.Config)+1)
+		for k, v := range cfg.Config {
+			channelConfig[k] = v
+		}
+		channelConfig["mode"] = cfg.Mode
+
+		p, err := channels.NewChannel(name, channelConfig)
+		if err != nil {
+			slog.Warn("channel: create failed", "name", name, "error", err)
+			continue
+		}
+
+		if err := p.Start(chandler.MessageHandler); err != nil {
+			slog.Warn("channel: start failed", "name", name, "error", err)
+			continue
+		}
+
+		// Register HTTP routes if channel supports webhook callback
+		if httpRouter, ok := p.(channels.HTTPRouter); ok {
+			callbackPath, _ := channelConfig["callback_path"].(string)
+			if callbackPath == "" {
+				callbackPath = "/" + name + "/callback"
+			}
+			httpRouter.RegisterHTTPRoutes(r, callbackPath, chandler.MessageHandler)
+			slog.Info("channel: HTTP routes registered", "name", name, "path", callbackPath)
+		}
+
+		// Use name + mode as unique key to support multiple instances of same channel type
+		key := name
+		if cfg.Mode != "" {
+			key = name + "-" + cfg.Mode
+		}
+		channels.TrackChannel(key, p)
+		slog.Info("channel: started", "name", name, "mode", cfg.Mode, "key", key)
+	}
+
+	slog.Info("channel: manager initialized")
+	return nil
+}
+
+// StopChannels stops all channel adapters.
+func StopChannels() {
+	channels.StopAll()
+}
+
+// MessageHandler processes incoming messages from channel adapters.
+//
+// It resolves the conversation bound to msg.SessionKey, builds the prompt
+// from the system message, the stored history and the new user message, runs
+// the chat with the tool-call loop, stores the exchange in the conversation
+// history when the answer is non-empty, and sends the answer back through p.
+// When the chat fails, a short error text is sent to the channel instead.
+func (chandler *channelHandler) MessageHandler(p channel.Channel, msg *channel.Message) {
+	if chandler == nil {
+		slog.Error("channel: handler not initialized")
+		return
+	}
+
+	ctx := context.Background()
+
+	// Build the chat request
+	cs := stores.GetOrCreateConversationBySessionKey(ctx, msg.SessionKey)
+
+	slog.Info("channel: message received",
+		"channel", p.Name(),
+		"session", msg.SessionKey,
+		"conversation", cs.GetID(),
+		"user", msg.UserID,
+		"content_len", len(msg.Content),
+	)
+
+	// Prepare system message and tools
+	sysMsg, tools := prepareSystemMessage(ctx, stores.Sgt(), chandler.toolreg, msg.Content, cs)
+
+	// Build user message with any attachments
+	content := msg.Content
+	if len(msg.Images) > 0 {
+		content += "\n[User sent an image]"
+	}
+	if msg.Audio != nil {
+		content += "\n[User sent a voice message]"
+	}
+
+	// Load conversation history. A failed load is logged and the chat goes on
+	// with whatever history was returned.
+	messages := []llm.Message{sysMsg}
+	history, err := cs.ListHistory(ctx)
+	if err != nil {
+		slog.Warn("channel: list history failed", "err", err)
+	}
+	for _, hi := range history {
+		if hi.ChatItem != nil {
+			if hi.ChatItem.User != "" {
+				messages = append(messages, llm.Message{Role: llm.RoleUser, Content: hi.ChatItem.User})
+			}
+			if hi.ChatItem.Assistant != "" {
+				messages = append(messages, llm.Message{Role: llm.RoleAssistant, Content: hi.ChatItem.Assistant})
+			}
+		}
+	}
+	messages = append(messages, llm.Message{
+		Role:    llm.RoleUser,
+		Content: content,
+	})
+
+	// Execute the chat with tool call loop
+	exec := func(ctx context.Context, messages []llm.Message, tools []llm.ToolDefinition) (string, []llm.ToolCall, *llm.Usage, error) {
+		result, err := chandler.llm.Chat(ctx, messages, tools)
+		if err != nil {
+			return "", nil, nil, err
+		}
+		return result.Content, result.ToolCalls, result.Usage, nil
+	}
+	answer, _, _, err := chandler.executeToolCallLoop(ctx, messages, tools, exec)
+	if err != nil {
+		slog.Error("channel: chat execution failed",
+			"channel", p.Name(), "error", err)
+		channelReplyError(p, msg, "AI processing failed")
+		return
+	}
+
+	// Save to history (only final answer, not tool call content)
+	if len(answer) > 0 {
+		hi := &aigc.HistoryItem{
+			Time: time.Now().Unix(),
+			UID:  msg.UserID,
+			ChatItem: &aigc.HistoryChatItem{
+				User:      msg.Content,
+				Assistant: answer,
+			},
+		}
+		if err := cs.AddHistory(ctx, hi); err != nil {
+			slog.Warn("channel: add history failed", "err", err)
+		} else if err := cs.Save(ctx); err != nil {
+			slog.Warn("channel: save history failed", "err", err)
+		}
+	}
+
+	// Send reply to channel
+	if err := p.Reply(ctx, msg.ReplyCtx, answer); err != nil {
+		slog.Error("channel: reply failed",
+			"channel", p.Name(), "error", err)
+	}
+}
+
+// executeToolCallLoop executes tool calls in a loop until no more tool calls
+func (chandler *channelHandler) executeToolCallLoop(ctx context.Context, messages []llm.Message, tools []llm.ToolDefinition, exec chatExecutor) (string, []llm.ToolCall, *llm.Usage, error) {
+	return chandler.toolExec.ExecuteToolCallLoop(ctx, messages, tools, exec)
+}
+
+// channelReplyError sends an error message back to the channel.
+func channelReplyError(p channel.Channel, msg *channel.Message, errorText string) {
+	ctx := context.Background()
+	if err := p.Reply(ctx, msg.ReplyCtx, errorText); err != nil {
+		slog.Error("channel: send error reply failed",
+			"channel", p.Name(), "error", err)
+	}
+}
diff --git a/pkg/web/api/tool_executor.go b/pkg/web/api/tool_executor.go
new file mode 100644
index 0000000..f19a9a7
--- /dev/null
+++ b/pkg/web/api/tool_executor.go
@@ -0,0 +1,109 @@
+package api
+
+import (
+	"context"
+	"encoding/json"
+
+	"github.com/liut/morign/pkg/services/llm"
+	"github.com/liut/morign/pkg/services/tools"
+	toolsvc "github.com/liut/morign/pkg/services/tools"
+)
+
+// chatExecutor performs one chat round; callers supply a streaming or
+// non-streaming implementation.
+type chatExecutor func(ctx context.Context, messages []llm.Message, tools []llm.ToolDefinition) (string, []llm.ToolCall, *llm.Usage, error)
+
+// ToolExecutor encapsulates the tool-call loop shared by the chat handlers.
+type ToolExecutor struct {
+	toolreg *tools.Registry
+}
+
+// NewToolExecutor creates a ToolExecutor that invokes tools through toolreg.
+func NewToolExecutor(toolreg *tools.Registry) *ToolExecutor {
+	return &ToolExecutor{toolreg: toolreg}
+}
+
+// ExecuteToolCallLoop runs chat rounds until the reply carries no tool calls.
+//
+// Each round calls exec with the current messages and tools. When the reply
+// requests tools, the assistant turn and one tool message per "function"
+// call are appended to messages and exec is called again. Tool definitions
+// are withheld after the first round to keep the loop from running forever.
+//
+// It returns the final answer and usage; the tool-call result is always nil.
+// An error returned by exec ends the loop and is passed through unchanged.
+func (e *ToolExecutor) ExecuteToolCallLoop(
+	ctx context.Context,
+	messages []llm.Message,
+	tools []llm.ToolDefinition,
+	exec chatExecutor,
+) (string, []llm.ToolCall, *llm.Usage, error) {
+	for {
+		answer, toolCalls, usage, err := exec(ctx, messages, tools)
+		if err != nil {
+			return "", nil, nil, err
+		}
+
+		if len(toolCalls) == 0 {
+			return answer, nil, usage, nil
+		}
+
+		// Record the assistant turn that carries the tool calls.
+		messages = append(messages, llm.Message{
+			Role:      llm.RoleAssistant,
+			ToolCalls: toolCalls,
+		})
+
+		// Run every requested tool.
+		for _, tc := range toolCalls {
+			logger().Infow("chat", "toolCallID", tc.ID, "toolCallType", tc.Type, "toolCallName", tc.Function.Name)
+
+			if tc.Type != "function" {
+				continue
+			}
+
+			// fail records a failed call as a tool message, so that every
+			// call on the assistant turn still gets a matching reply.
+			fail := func(reason string, err error) {
+				messages = append(messages, llm.Message{
+					Role:       llm.RoleTool,
+					Content:    reason + ": " + err.Error(),
+					ToolCallID: tc.ID,
+				})
+			}
+
+			var parameters map[string]any
+			args := string(tc.Function.Arguments)
+			if args != "" && args != "{}" {
+				if err := json.Unmarshal(tc.Function.Arguments, &parameters); err != nil {
+					logger().Infow("chat", "toolCallID", tc.ID, "args", args, "err", err)
+					fail("invalid tool arguments", err)
+					continue
+				}
+			}
+			// Tools always receive a non-nil parameter map.
+			if parameters == nil {
+				parameters = make(map[string]any)
+			}
+
+			content, err := e.toolreg.Invoke(ctx, tc.Function.Name, parameters)
+			if err != nil {
+				logger().Infow("invokeTool fail", "toolCallName", tc.Function.Name, "err", err)
+				fail("tool invocation failed", err)
+				continue
+			}
+
+			logger().Infow("invokeTool ok", "toolCallName", tc.Function.Name,
+				"content", toolsvc.ResultLogs(content))
+			messages = append(messages, llm.Message{
+				Role:       llm.RoleTool,
+				Content:    formatToolResult(content),
+				ToolCallID: tc.ID,
+			})
+		}
+
+		// Withhold tool definitions from the following rounds to avoid an
+		// endless tool-call loop.
+		tools = nil
+	}
+}