From c33e01ee621aece07d2d1f614a261c02628fb4cf Mon Sep 17 00:00:00 2001 From: MiguelPF Date: Wed, 18 Mar 2026 10:11:01 +0100 Subject: [PATCH 01/40] fix(cron): scope cron job store to workspace instead of global directory Replace `get_cron_dir()` with `config.workspace_path / "cron"` so each workspace keeps its own `jobs.json`. This lets users run multiple nanobot instances with independent cron schedules without cross-talk. Co-Authored-By: Claude Opus 4.6 --- nanobot/cli/commands.py | 10 ++++------ tests/test_commands.py | 6 +----- 2 files changed, 5 insertions(+), 11 deletions(-) diff --git a/nanobot/cli/commands.py b/nanobot/cli/commands.py index 0d4bb3de893..cde14365919 100644 --- a/nanobot/cli/commands.py +++ b/nanobot/cli/commands.py @@ -465,7 +465,6 @@ def gateway( from nanobot.agent.loop import AgentLoop from nanobot.bus.queue import MessageBus from nanobot.channels.manager import ChannelManager - from nanobot.config.paths import get_cron_dir from nanobot.cron.service import CronService from nanobot.cron.types import CronJob from nanobot.heartbeat.service import HeartbeatService @@ -485,8 +484,8 @@ def gateway( provider = _make_provider(config) session_manager = SessionManager(config.workspace_path) - # Create cron service first (callback set after agent creation) - cron_store_path = get_cron_dir() / "jobs.json" + # Create cron service with workspace-scoped store + cron_store_path = config.workspace_path / "cron" / "jobs.json" cron = CronService(cron_store_path) # Create agent with cron service @@ -663,7 +662,6 @@ def agent( from nanobot.agent.loop import AgentLoop from nanobot.bus.queue import MessageBus - from nanobot.config.paths import get_cron_dir from nanobot.cron.service import CronService config = _load_runtime_config(config, workspace) @@ -673,8 +671,8 @@ def agent( bus = MessageBus() provider = _make_provider(config) - # Create cron service for tool usage (no callback needed for CLI unless running) - cron_store_path = get_cron_dir() / "jobs.json" + # Create cron service with workspace-scoped store + cron_store_path = config.workspace_path / "cron" / "jobs.json" cron = CronService(cron_store_path) if logs: diff --git a/tests/test_commands.py b/tests/test_commands.py index a820e7755c0..fcb2f6a6b03 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -275,10 +275,8 @@ def mock_agent_runtime(tmp_path): """Mock agent command dependencies for focused CLI tests.""" config = Config() config.agents.defaults.workspace = str(tmp_path / "default-workspace") - cron_dir = tmp_path / "data" / "cron" with patch("nanobot.config.loader.load_config", return_value=config) as mock_load_config, \ - patch("nanobot.config.paths.get_cron_dir", return_value=cron_dir), \ patch("nanobot.cli.commands.sync_workspace_templates") as mock_sync_templates, \ patch("nanobot.cli.commands._make_provider", return_value=object()), \ patch("nanobot.cli.commands._print_agent_response") as mock_print_response, \ @@ -351,7 +349,6 @@ def test_agent_config_sets_active_path(monkeypatch, tmp_path: Path) -> None: lambda path: seen.__setitem__("config_path", path), ) monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config) - monkeypatch.setattr("nanobot.config.paths.get_cron_dir", lambda: config_file.parent / "cron") monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None) monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object()) monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object()) @@ -508,7 +505,6 @@ def test_gateway_uses_config_directory_for_cron_store(monkeypatch, tmp_path: Pat monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None) monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config) - monkeypatch.setattr("nanobot.config.paths.get_cron_dir", lambda: config_file.parent / "cron") monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None) monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object()) monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object()) @@ -524,7 +520,7 @@ def __init__(self, store_path: Path) -> None: result = runner.invoke(app, ["gateway", "--config", str(config_file)]) assert isinstance(result.exception, _StopGateway) - assert seen["cron_store"] == config_file.parent / "cron" / "jobs.json" + assert seen["cron_store"] == config.workspace_path / "cron" / "jobs.json" def test_gateway_uses_configured_port_when_cli_flag_is_missing(monkeypatch, tmp_path: Path) -> None: From 4e56481f0ba59ce53bfed03e01c941722fdcae20 Mon Sep 17 00:00:00 2001 From: MiguelPF Date: Wed, 18 Mar 2026 10:16:06 +0100 Subject: [PATCH 02/40] add one-time migration for legacy global cron store When upgrading, if jobs.json exists at the old global path and not yet at the workspace path, move it automatically. Prevents silent loss of existing cron jobs. Co-Authored-By: Claude Opus 4.6 --- nanobot/cli/commands.py | 18 ++++++++++++++++++ tests/test_commands.py | 41 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+) diff --git a/nanobot/cli/commands.py b/nanobot/cli/commands.py index cde14365919..17fe7b86aca 100644 --- a/nanobot/cli/commands.py +++ b/nanobot/cli/commands.py @@ -449,6 +449,18 @@ def _print_deprecated_memory_window_notice(config: Config) -> None: ) +def _migrate_cron_store(config: "Config") -> None: + """One-time migration: move legacy global cron store into the workspace.""" + from nanobot.config.paths import get_cron_dir + + legacy_path = get_cron_dir() / "jobs.json" + new_path = config.workspace_path / "cron" / "jobs.json" + if legacy_path.is_file() and not new_path.exists(): + new_path.parent.mkdir(parents=True, exist_ok=True) + import shutil + shutil.move(str(legacy_path), str(new_path)) + + # ============================================================================ # Gateway / Server # ============================================================================ @@ -484,6 +496,9 @@ def gateway( provider = _make_provider(config) session_manager = SessionManager(config.workspace_path) + # Migrate legacy global cron store into workspace (one-time) + _migrate_cron_store(config) + # Create cron service with workspace-scoped store cron_store_path = config.workspace_path / "cron" / "jobs.json" cron = CronService(cron_store_path) @@ -671,6 +686,9 @@ def agent( bus = MessageBus() provider = _make_provider(config) + # Migrate legacy global cron store into workspace (one-time) + _migrate_cron_store(config) + # Create cron service with workspace-scoped store cron_store_path = config.workspace_path / "cron" / "jobs.json" cron = CronService(cron_store_path) diff --git a/tests/test_commands.py b/tests/test_commands.py index fcb2f6a6b03..9875644954b 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -523,6 +523,47 @@ def __init__(self, store_path: Path) -> None: assert seen["cron_store"] == config.workspace_path / "cron" / "jobs.json" +def test_migrate_cron_store_moves_legacy_file(tmp_path: Path) -> None: + """Legacy global jobs.json is moved into the workspace on first run.""" + from nanobot.cli.commands import _migrate_cron_store + + legacy_dir = tmp_path / "global" / "cron" + legacy_dir.mkdir(parents=True) + legacy_file = legacy_dir / "jobs.json" + legacy_file.write_text('{"jobs": []}') + + config = Config() + config.agents.defaults.workspace = str(tmp_path / "workspace") + workspace_cron = config.workspace_path / "cron" / "jobs.json" + + with patch("nanobot.config.paths.get_cron_dir", return_value=legacy_dir): + _migrate_cron_store(config) + + assert workspace_cron.exists() + assert workspace_cron.read_text() == '{"jobs": []}' + assert not legacy_file.exists() + + +def test_migrate_cron_store_skips_when_workspace_file_exists(tmp_path: Path) -> None: + """Migration does not overwrite an existing workspace cron store.""" + from nanobot.cli.commands import _migrate_cron_store + + legacy_dir = tmp_path / "global" / "cron" + legacy_dir.mkdir(parents=True) + (legacy_dir / "jobs.json").write_text('{"old": true}') + + config = Config() + config.agents.defaults.workspace = str(tmp_path / "workspace") + workspace_cron = config.workspace_path / "cron" / "jobs.json" + workspace_cron.parent.mkdir(parents=True) + workspace_cron.write_text('{"new": true}') + + with patch("nanobot.config.paths.get_cron_dir", return_value=legacy_dir): + _migrate_cron_store(config) + + assert workspace_cron.read_text() == '{"new": true}' + + def test_gateway_uses_configured_port_when_cli_flag_is_missing(monkeypatch, tmp_path: Path) -> None: config_file = tmp_path / "instance" / "config.json" config_file.parent.mkdir(parents=True) From 9a2b1a3f1a348a97d1537db19278a487ed881e64 Mon Sep 17 00:00:00 2001 From: flobo3 Date: Sat, 21 Mar 2026 16:23:05 +0300 Subject: [PATCH 03/40] feat(telegram): add react_emoji config for incoming messages --- nanobot/channels/telegram.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/nanobot/channels/telegram.py b/nanobot/channels/telegram.py index 850e09c0f67..04cc89cc2c8 100644 --- a/nanobot/channels/telegram.py +++ b/nanobot/channels/telegram.py @@ -11,7 +11,7 @@ from loguru import logger from pydantic import Field -from telegram import BotCommand, ReplyParameters, Update +from telegram import BotCommand, ReactionTypeEmoji, ReplyParameters, Update from telegram.error import TimedOut from telegram.ext import Application, CommandHandler, ContextTypes, MessageHandler, filters from telegram.request import HTTPXRequest @@ -173,6 +173,7 @@ class TelegramConfig(Base): allow_from: list[str] = Field(default_factory=list) proxy: str | None = None reply_to_message: bool = False + react_emoji: str = "👀" group_policy: Literal["open", "mention"] = "mention" connection_pool_size: int = 32 pool_timeout: float = 5.0 @@ -812,6 +813,7 @@ async def _on_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) "session_key": session_key, } self._start_typing(str_chat_id) + await self._add_reaction(str_chat_id, message.message_id, self.config.react_emoji) buf = self._media_group_buffers[key] if content and content != "[empty message]": buf["contents"].append(content) @@ -822,6 +824,7 @@ async def _on_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) # Start typing indicator before processing self._start_typing(str_chat_id) + await self._add_reaction(str_chat_id, message.message_id, self.config.react_emoji) # Forward to the message bus await self._handle_message( @@ -861,6 +864,19 @@ def _stop_typing(self, chat_id: str) -> None: if task and not task.done(): task.cancel() + async def _add_reaction(self, chat_id: str, message_id: int, emoji: str) -> None: + """Add emoji reaction to a message (best-effort, non-blocking).""" + if not self._app or not emoji: + return + try: + await self._app.bot.set_message_reaction( + chat_id=int(chat_id), + message_id=message_id, + reaction=[ReactionTypeEmoji(emoji=emoji)], + ) + except Exception as e: + logger.debug("Telegram reaction failed: {}", e) + async def _typing_loop(self, chat_id: str) -> None: """Repeatedly send 'typing' action until cancelled.""" try: From 80ee2729ac0eff02a8b08ef3768b0e29e4165a6f Mon Sep 17 00:00:00 2001 From: Flo Date: Fri, 20 Mar 2026 09:31:09 +0300 Subject: [PATCH 04/40] feat(telegram): add silent_tool_hints config to disable notifications for tool hints (#2252) --- README.md | 3 ++- nanobot/channels/telegram.py | 8 +++++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 062abbbfc8c..73cdddcb658 100644 --- a/README.md +++ b/README.md @@ -263,7 +263,8 @@ Connect nanobot to your favorite chat platform. Want to build your own? See the "telegram": { "enabled": true, "token": "YOUR_BOT_TOKEN", - "allowFrom": ["YOUR_USER_ID"] + "allowFrom": ["YOUR_USER_ID"], + "silentToolHints": false } } } diff --git a/nanobot/channels/telegram.py b/nanobot/channels/telegram.py index 04cc89cc2c8..b9d52a64f27 100644 --- a/nanobot/channels/telegram.py +++ b/nanobot/channels/telegram.py @@ -178,6 +178,7 @@ class TelegramConfig(Base): connection_pool_size: int = 32 pool_timeout: float = 5.0 streaming: bool = True + silent_tool_hints: bool = False class TelegramChannel(BaseChannel): @@ -430,8 +431,10 @@ async def send(self, msg: OutboundMessage) -> None: # Send text content if msg.content and msg.content != "[empty message]": + disable_notification = self.config.silent_tool_hints and msg.metadata.get("_tool_hint", False) + for chunk in split_message(msg.content, TELEGRAM_MAX_MESSAGE_LEN): - await self._send_text(chat_id, chunk, reply_params, thread_kwargs) + await self._send_text(chat_id, chunk, reply_params, thread_kwargs, disable_notification=disable_notification) async def _call_with_retry(self, fn, *args, **kwargs): """Call an async Telegram API function with retry on pool/network timeout.""" @@ -454,6 +457,7 @@ async def _send_text( text: str, reply_params=None, thread_kwargs: dict | None = None, + disable_notification: bool = False, ) -> None: """Send a plain text message with HTML fallback.""" try: @@ -462,6 +466,7 @@ async def _send_text( self._app.bot.send_message, chat_id=chat_id, text=html, parse_mode="HTML", reply_parameters=reply_params, + disable_notification=disable_notification, **(thread_kwargs or {}), ) except Exception as e: @@ -472,6 +477,7 @@ async def _send_text( chat_id=chat_id, text=text, reply_parameters=reply_params, + disable_notification=disable_notification, **(thread_kwargs or {}), ) except Exception as e2: From d7373db41958893ac0c1031f85c0bf1a72223b45 Mon Sep 17 00:00:00 2001 From: Chen Junda Date: Fri, 20 Mar 2026 11:27:40 +0800 Subject: [PATCH 05/40] feat(qq): bot can send and receive images and files (#1667) Implement file upload and sending for QQ C2C messages Reference: https://github.com/tencent-connect/botpy/blob/master/examples/demo_c2c_reply_file.py --------- Co-authored-by: Claude Sonnet 4.6 Co-authored-by: chengyongru --- nanobot/channels/qq.py | 581 +++++++++++++++++++++++++++++++++++---- tests/test_qq_channel.py | 1 + 2 files changed, 521 insertions(+), 61 deletions(-) diff --git a/nanobot/channels/qq.py b/nanobot/channels/qq.py index e556c9867dd..5dae01b2a0b 100644 --- a/nanobot/channels/qq.py +++ b/nanobot/channels/qq.py @@ -1,33 +1,107 @@ -"""QQ channel implementation using botpy SDK.""" +"""QQ channel implementation using botpy SDK. + +Inbound: +- Parse QQ botpy messages (C2C / Group) +- Download attachments to media dir using chunked streaming write (memory-safe) +- Publish to Nanobot bus via BaseChannel._handle_message() +- Content includes a clear, actionable "Received files:" list with local paths + +Outbound: +- Send attachments (msg.media) first via QQ rich media API (base64 upload + msg_type=7) +- Then send text (plain or markdown) +- msg.media supports local paths, file:// paths, and http(s) URLs + +Notes: +- QQ restricts many audio/video formats. We conservatively classify as image vs file. +- Attachment structures differ across botpy versions; we try multiple field candidates. +""" + +from __future__ import annotations import asyncio +import base64 +import mimetypes +import os +import re +import time from collections import deque +from pathlib import Path from typing import TYPE_CHECKING, Any, Literal +from urllib.parse import unquote, urlparse +import aiohttp from loguru import logger +from pydantic import Field from nanobot.bus.events import OutboundMessage from nanobot.bus.queue import MessageBus from nanobot.channels.base import BaseChannel from nanobot.config.schema import Base -from pydantic import Field +from nanobot.security.network import validate_url_target + +try: + from nanobot.config.paths import get_media_dir +except Exception: # pragma: no cover + get_media_dir = None # type: ignore try: import botpy - from botpy.message import C2CMessage, GroupMessage + from botpy.http import Route QQ_AVAILABLE = True -except ImportError: +except ImportError: # pragma: no cover QQ_AVAILABLE = False botpy = None - C2CMessage = None - GroupMessage = None + Route = None if TYPE_CHECKING: - from botpy.message import C2CMessage, GroupMessage + from botpy.message import BaseMessage, C2CMessage, GroupMessage + from botpy.types.message import Media + + +# QQ rich media file_type: 1=image, 4=file +# (2=voice, 3=video are restricted; we only use image vs file) +QQ_FILE_TYPE_IMAGE = 1 +QQ_FILE_TYPE_FILE = 4 + +_IMAGE_EXTS = { + ".png", + ".jpg", + ".jpeg", + ".gif", + ".bmp", + ".webp", + ".tif", + ".tiff", + ".ico", +} +# Replace unsafe characters with "_", keep Chinese and common safe punctuation. +_SAFE_NAME_RE = re.compile(r"[^\w.\-()\[\]()【】\u4e00-\u9fff]+", re.UNICODE) -def _make_bot_class(channel: "QQChannel") -> "type[botpy.Client]": + +def _sanitize_filename(name: str) -> str: + """Sanitize filename to avoid traversal and problematic chars.""" + name = (name or "").strip() + name = Path(name).name + name = _SAFE_NAME_RE.sub("_", name).strip("._ ") + return name + + +def _is_image_name(name: str) -> bool: + return Path(name).suffix.lower() in _IMAGE_EXTS + + +def _guess_send_file_type(filename: str) -> int: + """Conservative send type: images -> 1, else -> 4.""" + ext = Path(filename).suffix.lower() + mime, _ = mimetypes.guess_type(filename) + if ext in _IMAGE_EXTS or (mime and mime.startswith("image/")): + return QQ_FILE_TYPE_IMAGE + return QQ_FILE_TYPE_FILE + + +def _make_bot_class(channel: QQChannel) -> type[botpy.Client]: """Create a botpy Client subclass bound to the given channel.""" intents = botpy.Intents(public_messages=True, direct_message=True) @@ -39,10 +113,10 @@ def __init__(self): async def on_ready(self): logger.info("QQ bot ready: {}", self.robot.name) - async def on_c2c_message_create(self, message: "C2CMessage"): + async def on_c2c_message_create(self, message: C2CMessage): await channel._on_message(message, is_group=False) - async def on_group_at_message_create(self, message: "GroupMessage"): + async def on_group_at_message_create(self, message: GroupMessage): await channel._on_message(message, is_group=True) async def on_direct_message_create(self, message): @@ -60,6 +134,13 @@ class QQConfig(Base): allow_from: list[str] = Field(default_factory=list) msg_format: Literal["plain", "markdown"] = "plain" + # Optional: directory to save inbound attachments. If empty, use nanobot get_media_dir("qq"). + media_dir: str = "" + + # Download tuning + download_chunk_size: int = 1024 * 256 # 256KB + download_max_bytes: int = 1024 * 1024 * 200 # 200MB safety limit + class QQChannel(BaseChannel): """QQ channel using botpy SDK with WebSocket connection.""" @@ -76,13 +157,38 @@ def __init__(self, config: Any, bus: MessageBus): config = QQConfig.model_validate(config) super().__init__(config, bus) self.config: QQConfig = config - self._client: "botpy.Client | None" = None - self._processed_ids: deque = deque(maxlen=1000) - self._msg_seq: int = 1 # 消息序列号,避免被 QQ API 去重 + + self._client: botpy.Client | None = None + self._http: aiohttp.ClientSession | None = None + + self._processed_ids: deque[str] = deque(maxlen=1000) + self._msg_seq: int = 1 # used to avoid QQ API dedup self._chat_type_cache: dict[str, str] = {} + self._media_root: Path = self._init_media_root() + + # --------------------------- + # Lifecycle + # --------------------------- + + def _init_media_root(self) -> Path: + """Choose a directory for saving inbound attachments.""" + if self.config.media_dir: + root = Path(self.config.media_dir).expanduser() + elif get_media_dir: + try: + root = Path(get_media_dir("qq")) + except Exception: + root = Path.home() / ".nanobot" / "media" / "qq" + else: + root = Path.home() / ".nanobot" / "media" / "qq" + + root.mkdir(parents=True, exist_ok=True) + logger.info("QQ media directory: {}", str(root)) + return root + async def start(self) -> None: - """Start the QQ bot.""" + """Start the QQ bot with auto-reconnect loop.""" if not QQ_AVAILABLE: logger.error("QQ SDK not installed. Run: pip install qq-botpy") return @@ -92,8 +198,9 @@ async def start(self) -> None: return self._running = True - BotClass = _make_bot_class(self) - self._client = BotClass() + self._http = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=120)) + + self._client = _make_bot_class(self)() logger.info("QQ bot started (C2C & Group supported)") await self._run_bot() @@ -109,75 +216,427 @@ async def _run_bot(self) -> None: await asyncio.sleep(5) async def stop(self) -> None: - """Stop the QQ bot.""" + """Stop bot and cleanup resources.""" self._running = False if self._client: try: await self._client.close() except Exception: pass + self._client = None + + if self._http: + try: + await self._http.close() + except Exception: + pass + self._http = None + logger.info("QQ bot stopped") + # --------------------------- + # Outbound (send) + # --------------------------- + async def send(self, msg: OutboundMessage) -> None: - """Send a message through QQ.""" + """Send attachments first, then text.""" if not self._client: logger.warning("QQ client not initialized") return + msg_id = msg.metadata.get("message_id") + chat_type = self._chat_type_cache.get(msg.chat_id, "c2c") + is_group = chat_type == "group" + + # 1) Send media + for media_ref in msg.media or []: + ok = await self._send_media( + chat_id=msg.chat_id, + media_ref=media_ref, + msg_id=msg_id, + is_group=is_group, + ) + if not ok: + filename = ( + os.path.basename(urlparse(media_ref).path) + or os.path.basename(media_ref) + or "file" + ) + await self._send_text_only( + chat_id=msg.chat_id, + is_group=is_group, + msg_id=msg_id, + content=f"[Attachment send failed: {filename}]", + ) + + # 2) Send text + if msg.content and msg.content.strip(): + await self._send_text_only( + chat_id=msg.chat_id, + is_group=is_group, + msg_id=msg_id, + content=msg.content.strip(), + ) + + async def _send_text_only( + self, + chat_id: str, + is_group: bool, + msg_id: str | None, + content: str, + ) -> None: + """Send a plain/markdown text message.""" + if not self._client: + return + + self._msg_seq += 1 + use_markdown = self.config.msg_format == "markdown" + payload: dict[str, Any] = { + "msg_type": 2 if use_markdown else 0, + "msg_id": msg_id, + "msg_seq": self._msg_seq, + } + if use_markdown: + payload["markdown"] = {"content": content} + else: + payload["content"] = content + + if is_group: + await self._client.api.post_group_message(group_openid=chat_id, **payload) + else: + await self._client.api.post_c2c_message(openid=chat_id, **payload) + + async def _send_media( + self, + chat_id: str, + media_ref: str, + msg_id: str | None, + is_group: bool, + ) -> bool: + """Read bytes -> base64 upload -> msg_type=7 send.""" + if not self._client: + return False + + data, filename = await self._read_media_bytes(media_ref) + if not data or not filename: + return False + try: - msg_id = msg.metadata.get("message_id") - self._msg_seq += 1 - use_markdown = self.config.msg_format == "markdown" - payload: dict[str, Any] = { - "msg_type": 2 if use_markdown else 0, - "msg_id": msg_id, - "msg_seq": self._msg_seq, - } - if use_markdown: - payload["markdown"] = {"content": msg.content} - else: - payload["content"] = msg.content + file_type = _guess_send_file_type(filename) + file_data_b64 = base64.b64encode(data).decode() - chat_type = self._chat_type_cache.get(msg.chat_id, "c2c") - if chat_type == "group": + media_obj = await self._post_base64file( + chat_id=chat_id, + is_group=is_group, + file_type=file_type, + file_data=file_data_b64, + file_name=filename, + srv_send_msg=False, + ) + if not media_obj: + logger.error("QQ media upload failed: empty response") + return False + + self._msg_seq += 1 + if is_group: await self._client.api.post_group_message( - group_openid=msg.chat_id, - **payload, + group_openid=chat_id, + msg_type=7, + msg_id=msg_id, + msg_seq=self._msg_seq, + media=media_obj, ) else: await self._client.api.post_c2c_message( - openid=msg.chat_id, - **payload, + openid=chat_id, + msg_type=7, + msg_id=msg_id, + msg_seq=self._msg_seq, + media=media_obj, ) + + logger.info("QQ media sent: {}", filename) + return True except Exception as e: - logger.error("Error sending QQ message: {}", e) + logger.error("QQ send media failed filename={} err={}", filename, e) + return False - async def _on_message(self, data: "C2CMessage | GroupMessage", is_group: bool = False) -> None: - """Handle incoming message from QQ.""" - try: - # Dedup by message ID - if data.id in self._processed_ids: - return - self._processed_ids.append(data.id) + async def _read_media_bytes(self, media_ref: str) -> tuple[bytes | None, str | None]: + """Read bytes from http(s) or local file path; return (data, filename).""" + media_ref = (media_ref or "").strip() + if not media_ref: + return None, None - content = (data.content or "").strip() - if not content: - return + ok, err = validate_url_target(media_ref) - if is_group: - chat_id = data.group_openid - user_id = data.author.member_openid - self._chat_type_cache[chat_id] = "group" + if not ok: + logger.warning("QQ outbound media URL validation failed url={} err={}", media_ref, err) + return None, None + + if not self._http: + self._http = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=120)) + try: + async with self._http.get(media_ref, allow_redirects=True) as resp: + if resp.status >= 400: + logger.warning( + "QQ outbound media download failed status={} url={}", + resp.status, + media_ref, + ) + return None, None + data = await resp.read() + if not data: + return None, None + filename = os.path.basename(urlparse(media_ref).path) or "file.bin" + return data, filename + except Exception as e: + logger.warning("QQ outbound media download error url={} err={}", media_ref, e) + return None, None + + # Local file + try: + if media_ref.startswith("file://"): + parsed = urlparse(media_ref) + local_path = Path(unquote(parsed.path)) else: - chat_id = str(getattr(data.author, 'id', None) or getattr(data.author, 'user_openid', 'unknown')) - user_id = chat_id - self._chat_type_cache[chat_id] = "c2c" + local_path = Path(os.path.expanduser(media_ref)) - await self._handle_message( - sender_id=user_id, - chat_id=chat_id, - content=content, - metadata={"message_id": data.id}, + if not local_path.is_file(): + logger.warning("QQ outbound media file not found: {}", str(local_path)) + return None, None + + data = await asyncio.to_thread(local_path.read_bytes) + return data, local_path.name + except Exception as e: + logger.warning("QQ outbound media read error ref={} err={}", media_ref, e) + return None, None + + # https://github.com/tencent-connect/botpy/issues/198 + # https://bot.q.qq.com/wiki/develop/api-v2/server-inter/message/send-receive/rich-media.html + async def _post_base64file( + self, + chat_id: str, + is_group: bool, + file_type: int, + file_data: str, + file_name: str | None = None, + srv_send_msg: bool = False, + ) -> Media: + """Upload base64-encoded file and return Media object.""" + if not self._client: + raise RuntimeError("QQ client not initialized") + + if is_group: + endpoint = "/v2/groups/{group_openid}/files" + id_key = "group_openid" + else: + endpoint = "/v2/users/{openid}/files" + id_key = "openid" + + payload = { + id_key: chat_id, + "file_type": file_type, + "file_data": file_data, + "file_name": file_name, + "srv_send_msg": srv_send_msg, + } + route = Route("POST", endpoint, **{id_key: chat_id}) + return await self._client.api._http.request(route, json=payload) + + # --------------------------- + # Inbound (receive) + # --------------------------- + + async def _on_message(self, data: C2CMessage | GroupMessage, is_group: bool = False) -> None: + """Parse inbound message, download attachments, and publish to the bus.""" + if data.id in self._processed_ids: + return + self._processed_ids.append(data.id) + + if is_group: + chat_id = data.group_openid + user_id = data.author.member_openid + self._chat_type_cache[chat_id] = "group" + else: + chat_id = str( + getattr(data.author, "id", None) + or getattr(data.author, "user_openid", "unknown") ) - except Exception: - logger.exception("Error handling QQ message") + user_id = chat_id + self._chat_type_cache[chat_id] = "c2c" + + content = (data.content or "").strip() + + # the data used by tests don't contain attachments property + # so we use getattr with a default of [] to avoid AttributeError in tests + attachments = getattr(data, "attachments", None) or [] + media_paths, recv_lines, att_meta = await self._handle_attachments(attachments) + + # Compose content that always contains actionable saved paths + if recv_lines: + tag = ( + "[Image]" + if any(_is_image_name(Path(p).name) for p in media_paths) + else "[File]" + ) + file_block = "Received files:\n" + "\n".join(recv_lines) + content = ( + f"{content}\n\n{file_block}".strip() if content else f"{tag}\n{file_block}" + ) + + if not content and not media_paths: + return + + await self._handle_message( + sender_id=user_id, + chat_id=chat_id, + content=content, + media=media_paths if media_paths else None, + metadata={ + "message_id": data.id, + "attachments": att_meta, + }, + ) + + async def _handle_attachments( + self, + attachments: list[BaseMessage._Attachments], + ) -> tuple[list[str], list[str], list[dict[str, Any]]]: + """Extract, download (chunked), and format attachments for agent consumption.""" + media_paths: list[str] = [] + recv_lines: list[str] = [] + att_meta: list[dict[str, Any]] = [] + + if not attachments: + return media_paths, recv_lines, att_meta + + for att in attachments: + url, filename, ctype = att.url, att.filename, att.content_type + + logger.info("Downloading file from QQ: {}", filename or url) + local_path = await self._download_to_media_dir_chunked(url, filename_hint=filename) + + att_meta.append( + { + "url": url, + "filename": filename, + "content_type": ctype, + "saved_path": local_path, + } + ) + + if local_path: + media_paths.append(local_path) + shown_name = filename or os.path.basename(local_path) + recv_lines.append(f"- {shown_name}\n saved: {local_path}") + else: + shown_name = filename or url + recv_lines.append(f"- {shown_name}\n saved: [download failed]") + + return media_paths, recv_lines, att_meta + + async def _download_to_media_dir_chunked( + self, + url: str, + filename_hint: str = "", + ) -> str | None: + """Download an inbound attachment using streaming chunk write. + + Uses chunked streaming to avoid loading large files into memory. + Enforces a max download size and writes to a .part temp file + that is atomically renamed on success. + """ + if not self._http: + self._http = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=120)) + + safe = _sanitize_filename(filename_hint) + ts = int(time.time() * 1000) + tmp_path: Path | None = None + + try: + async with self._http.get( + url, + timeout=aiohttp.ClientTimeout(total=120), + allow_redirects=True, + ) as resp: + if resp.status != 200: + logger.warning("QQ download failed: status={} url={}", resp.status, url) + return None + + ctype = (resp.headers.get("Content-Type") or "").lower() + + # Infer extension: url -> filename_hint -> content-type -> fallback + ext = Path(urlparse(url).path).suffix + if not ext: + ext = Path(filename_hint).suffix + if not ext: + if "png" in ctype: + ext = ".png" + elif "jpeg" in ctype or "jpg" in ctype: + ext = ".jpg" + elif "gif" in ctype: + ext = ".gif" + elif "webp" in ctype: + ext = ".webp" + elif "pdf" in ctype: + ext = ".pdf" + else: + ext = ".bin" + + if safe: + if not Path(safe).suffix: + safe = safe + ext + filename = safe + else: + filename = f"qq_file_{ts}{ext}" + + target = self._media_root / filename + if target.exists(): + target = self._media_root / f"{target.stem}_{ts}{target.suffix}" + + tmp_path = target.with_suffix(target.suffix + ".part") + + # Stream write + downloaded = 0 + chunk_size = max(1024, int(self.config.download_chunk_size or 262144)) + max_bytes = max( + 1024 * 1024, int(self.config.download_max_bytes or (200 * 1024 * 1024)) + ) + + def _open_tmp(): + tmp_path.parent.mkdir(parents=True, exist_ok=True) + return open(tmp_path, "wb") # noqa: SIM115 + + f = await asyncio.to_thread(_open_tmp) + try: + async for chunk in resp.content.iter_chunked(chunk_size): + if not chunk: + continue + downloaded += len(chunk) + if downloaded > max_bytes: + logger.warning( + "QQ download exceeded max_bytes={} url={} -> abort", + max_bytes, + url, + ) + return None + await asyncio.to_thread(f.write, chunk) + finally: + await asyncio.to_thread(f.close) + + # Atomic rename + await asyncio.to_thread(os.replace, tmp_path, target) + tmp_path = None # mark as moved + logger.info("QQ file saved: {}", str(target)) + return str(target) + + except Exception as e: + logger.error("QQ download error: {}", e) + return None + finally: + # Cleanup partial file + if tmp_path is not None: + try: + tmp_path.unlink(missing_ok=True) + except Exception: + pass diff --git a/tests/test_qq_channel.py b/tests/test_qq_channel.py index bd5e8911c54..ab09ff347e4 100644 --- a/tests/test_qq_channel.py +++ b/tests/test_qq_channel.py @@ -34,6 +34,7 @@ async def test_on_group_message_routes_to_group_chat_id() -> None: content="hello", group_openid="group123", author=SimpleNamespace(member_openid="user1"), + attachments=[], ) await channel._on_message(data, is_group=True) From 2db2cc18f1a40fb79b76cc137b71e5d277ce2205 Mon Sep 17 00:00:00 2001 From: Chen Junda Date: Fri, 20 Mar 2026 16:42:46 +0800 Subject: [PATCH 06/40] fix(qq): fix local file outbound and add svg as image type (#2294) - Fix _read_media_bytes treating local paths as URLs: local file handling code was dead code placed after an early return inside the HTTP try/except block. Restructure to check for local paths (plain path or file:// URI) before URL validation, so files like /home/.../.nanobot/workspace/generated_image.svg can be read and sent correctly. - Add .svg to _IMAGE_EXTS so SVG files are uploaded as file_type=1 (image) instead of file_type=4 (file). - Add tests for local path, file:// URI, and missing file cases. Fixes: https://github.com/HKUDS/nanobot/pull/1667#issuecomment-4096400955 Co-authored-by: Claude Sonnet 4.6 --- nanobot/channels/qq.py | 53 ++++++++++++++++++---------------------- tests/test_qq_channel.py | 40 ++++++++++++++++++++++++++++-- 2 files changed, 62 insertions(+), 31 deletions(-) diff --git a/nanobot/channels/qq.py b/nanobot/channels/qq.py index 5dae01b2a0b..7442e10069d 100644 --- a/nanobot/channels/qq.py +++ b/nanobot/channels/qq.py @@ -74,6 +74,7 @@ ".tif", ".tiff", ".ico", + ".svg", } # Replace unsafe characters with "_", keep Chinese and common safe punctuation. @@ -367,8 +368,27 @@ async def _read_media_bytes(self, media_ref: str) -> tuple[bytes | None, str | N if not media_ref: return None, None - ok, err = validate_url_target(media_ref) + # Local file: plain path or file:// URI + if not media_ref.startswith("http://") and not media_ref.startswith("https://"): + try: + if media_ref.startswith("file://"): + parsed = urlparse(media_ref) + local_path = Path(unquote(parsed.path)) + else: + local_path = Path(os.path.expanduser(media_ref)) + + if not local_path.is_file(): + logger.warning("QQ outbound media file not found: {}", str(local_path)) + return None, None + data = await asyncio.to_thread(local_path.read_bytes) + return data, local_path.name + except Exception as e: + logger.warning("QQ outbound media read error ref={} err={}", media_ref, e) + return None, None + + # Remote URL + ok, err = validate_url_target(media_ref) if not ok: logger.warning("QQ outbound media URL validation failed url={} err={}", media_ref, err) return None, None @@ -393,24 +413,6 @@ async def _read_media_bytes(self, media_ref: str) -> tuple[bytes | None, str | N logger.warning("QQ outbound media download error url={} err={}", media_ref, e) return None, None - # Local file - try: - if media_ref.startswith("file://"): - parsed = urlparse(media_ref) - local_path = Path(unquote(parsed.path)) - else: - local_path = Path(os.path.expanduser(media_ref)) - - if not local_path.is_file(): - logger.warning("QQ outbound media file not found: {}", str(local_path)) - return None, None - - data = await asyncio.to_thread(local_path.read_bytes) - return data, local_path.name - except Exception as e: - logger.warning("QQ outbound media read error ref={} err={}", media_ref, e) - return None, None - # https://github.com/tencent-connect/botpy/issues/198 # https://bot.q.qq.com/wiki/develop/api-v2/server-inter/message/send-receive/rich-media.html async def _post_base64file( @@ -459,8 +461,7 @@ async def _on_message(self, data: C2CMessage | GroupMessage, is_group: bool = Fa self._chat_type_cache[chat_id] = "group" else: chat_id = str( - getattr(data.author, "id", None) - or getattr(data.author, "user_openid", "unknown") + getattr(data.author, "id", None) or getattr(data.author, "user_openid", "unknown") ) user_id = chat_id self._chat_type_cache[chat_id] = "c2c" @@ -474,15 +475,9 @@ async def _on_message(self, data: C2CMessage | GroupMessage, is_group: bool = Fa # Compose content that always contains actionable saved paths if recv_lines: - tag = ( - "[Image]" - if any(_is_image_name(Path(p).name) for p in media_paths) - else "[File]" - ) + tag = "[Image]" if any(_is_image_name(Path(p).name) for p in media_paths) else "[File]" file_block = "Received files:\n" + "\n".join(recv_lines) - content = ( - f"{content}\n\n{file_block}".strip() if content else f"{tag}\n{file_block}" - ) + content = f"{content}\n\n{file_block}".strip() if content else f"{tag}\n{file_block}" if not content and not media_paths: return diff --git a/tests/test_qq_channel.py b/tests/test_qq_channel.py index ab09ff347e4..ab9afcbc7f5 100644 --- a/tests/test_qq_channel.py +++ b/tests/test_qq_channel.py @@ -1,11 +1,12 @@ +import tempfile +from pathlib import Path from types import SimpleNamespace import pytest from nanobot.bus.events import OutboundMessage from nanobot.bus.queue import MessageBus -from nanobot.channels.qq import QQChannel -from nanobot.channels.qq import QQConfig +from nanobot.channels.qq import QQChannel, QQConfig class _FakeApi: @@ -124,3 +125,38 @@ async def test_send_group_message_uses_markdown_when_configured() -> None: "msg_id": "msg1", "msg_seq": 2, } + + +@pytest.mark.asyncio +async def test_read_media_bytes_local_path() -> None: + channel = QQChannel(QQConfig(app_id="app", secret="secret"), MessageBus()) + + with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f: + f.write(b"\x89PNG\r\n") + tmp_path = f.name + + data, filename = await channel._read_media_bytes(tmp_path) + assert data == b"\x89PNG\r\n" + assert filename == Path(tmp_path).name + + +@pytest.mark.asyncio +async def test_read_media_bytes_file_uri() -> None: + channel = QQChannel(QQConfig(app_id="app", secret="secret"), MessageBus()) + + with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f: + f.write(b"JFIF") + tmp_path = f.name + + data, filename = await channel._read_media_bytes(f"file://{tmp_path}") + assert data == b"JFIF" + assert filename == Path(tmp_path).name + + +@pytest.mark.asyncio +async def test_read_media_bytes_missing_file() -> None: + channel = QQChannel(QQConfig(app_id="app", secret="secret"), MessageBus()) + + data, filename = await channel._read_media_bytes("/nonexistent/path/image.png") + assert data is None + assert filename is None From e4137736f6aa32011f88ce46e90a7b039e5b8053 Mon Sep 17 00:00:00 2001 From: chengyongru Date: Mon, 23 Mar 2026 15:18:54 +0800 Subject: [PATCH 07/40] fix(qq): handle file:// URI on Windows in _read_media_bytes urlparse on Windows puts the path in netloc, not path. Use (parsed.path or parsed.netloc) to get the correct raw path. --- nanobot/channels/qq.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/nanobot/channels/qq.py b/nanobot/channels/qq.py index 7442e10069d..b9d2d64d867 100644 --- a/nanobot/channels/qq.py +++ b/nanobot/channels/qq.py @@ -373,7 +373,9 @@ async def _read_media_bytes(self, media_ref: str) -> tuple[bytes | None, str | N try: if media_ref.startswith("file://"): parsed = urlparse(media_ref) - local_path = Path(unquote(parsed.path)) + # Windows: path in netloc; Unix: path in path + raw = parsed.path or parsed.netloc + local_path = Path(unquote(raw)) else: local_path = Path(os.path.expanduser(media_ref)) From b14d5a0a1d7a3891928c3053378f9842b5b48079 Mon Sep 17 00:00:00 2001 From: flobo3 Date: Wed, 18 Mar 2026 18:13:13 +0300 Subject: [PATCH 08/40] feat(whatsapp): add group_policy to control bot response behavior in groups --- nanobot/channels/whatsapp.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/nanobot/channels/whatsapp.py b/nanobot/channels/whatsapp.py index b689e3060d5..6f4271e2421 100644 --- a/nanobot/channels/whatsapp.py +++ b/nanobot/channels/whatsapp.py @@ -4,7 +4,7 @@ import json import mimetypes from collections import OrderedDict -from typing import Any +from typing import Any, Literal from loguru import logger @@ -23,6 +23,7 @@ class WhatsAppConfig(Base): bridge_url: str = "ws://localhost:3001" bridge_token: str = "" allow_from: list[str] = Field(default_factory=list) + group_policy: Literal["open", "mention"] = "open" # "open" responds to all, "mention" only when @mentioned class WhatsAppChannel(BaseChannel): @@ -138,6 +139,13 @@ async def _handle_bridge_message(self, raw: str) -> None: self._processed_message_ids.popitem(last=False) # Extract just the phone number or lid as chat_id + is_group = data.get("isGroup", False) + was_mentioned = data.get("wasMentioned", False) + + if is_group and getattr(self.config, "group_policy", "open") == "mention": + if not was_mentioned: + return + user_id = pn if pn else sender sender_id = user_id.split("@")[0] if "@" in user_id else user_id logger.info("Sender {}", sender) From 4145f3eaccc6bdc992c8fe46f086d12bcb807b4f Mon Sep 17 00:00:00 2001 From: kohath Date: Fri, 20 Mar 2026 22:26:27 +0800 Subject: [PATCH 09/40] feat(feishu): add thread reply support for topic group messages --- nanobot/channels/feishu.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/nanobot/channels/feishu.py b/nanobot/channels/feishu.py index 5e3d126f696..06daf409d15 100644 --- a/nanobot/channels/feishu.py +++ b/nanobot/channels/feishu.py @@ -960,6 +960,9 @@ async def send(self, msg: OutboundMessage) -> None: and not msg.metadata.get("_progress", False) ): reply_message_id = msg.metadata.get("message_id") or None + # For topic group messages, always reply to keep context in thread + elif msg.metadata.get("thread_id"): + reply_message_id = msg.metadata.get("root_id") or msg.metadata.get("message_id") or None first_send = True # tracks whether the reply has already been used @@ -1121,6 +1124,7 @@ async def _on_message(self, data: Any) -> None: # Extract reply context (parent/root message IDs) parent_id = getattr(message, "parent_id", None) or None root_id = getattr(message, "root_id", None) or None + thread_id = getattr(message, "thread_id", None) or None # Prepend quoted message text when the user replied to another message if parent_id and self._client: @@ -1149,6 +1153,7 @@ async def _on_message(self, data: Any) -> None: "msg_type": msg_type, "parent_id": parent_id, "root_id": root_id, + "thread_id": thread_id, } ) From d454386f3266dbd9f843874192e4de280d77f7b9 Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Tue, 24 Mar 2026 02:51:50 +0000 Subject: [PATCH 10/40] docs(weixin): clarify source-only installation in README --- README.md | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index e7932829246..797a5bcf274 100644 --- a/README.md +++ b/README.md @@ -724,10 +724,14 @@ nanobot gateway Uses **HTTP long-poll** with QR-code login via the ilinkai personal WeChat API. No local WeChat desktop client is required. -**1. Install the optional dependency** +> Weixin support is available from source checkout, but is not included in the current PyPI release yet. + +**1. Install from source** ```bash -pip install nanobot-ai[weixin] +git clone https://github.com/HKUDS/nanobot.git +cd nanobot +pip install -e ".[weixin]" ``` **2. Configure** From 14763a6ad1721736ae0658b485a218107618972b Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Tue, 24 Mar 2026 03:03:59 +0000 Subject: [PATCH 11/40] fix(provider): accept canonical and alias provider names consistently --- nanobot/config/schema.py | 9 ++++++--- nanobot/providers/registry.py | 5 ++++- tests/test_commands.py | 30 +++++++++++++++++++++++++++++- 3 files changed, 39 insertions(+), 5 deletions(-) diff --git a/nanobot/config/schema.py b/nanobot/config/schema.py index 7d8f5c863d6..b31f3061a83 100644 --- a/nanobot/config/schema.py +++ b/nanobot/config/schema.py @@ -165,12 +165,15 @@ def _match_provider( self, model: str | None = None ) -> tuple["ProviderConfig | None", str | None]: """Match provider config and its registry name. Returns (config, spec_name).""" - from nanobot.providers.registry import PROVIDERS + from nanobot.providers.registry import PROVIDERS, find_by_name forced = self.agents.defaults.provider if forced != "auto": - p = getattr(self.providers, forced, None) - return (p, forced) if p else (None, None) + spec = find_by_name(forced) + if spec: + p = getattr(self.providers, spec.name, None) + return (p, spec.name) if p else (None, None) + return None, None model_lower = (model or self.agents.defaults.model).lower() model_normalized = model_lower.replace("-", "_") diff --git a/nanobot/providers/registry.py b/nanobot/providers/registry.py index 9cc430b883c..10e0fec9db8 100644 --- a/nanobot/providers/registry.py +++ b/nanobot/providers/registry.py @@ -15,6 +15,8 @@ from dataclasses import dataclass, field from typing import Any +from pydantic.alias_generators import to_snake + @dataclass(frozen=True) class ProviderSpec: @@ -545,7 +547,8 @@ def find_gateway( def find_by_name(name: str) -> ProviderSpec | None: """Find a provider spec by config field name, e.g. "dashscope".""" + normalized = to_snake(name.replace("-", "_")) for spec in PROVIDERS: - if spec.name == name: + if spec.name == normalized: return spec return None diff --git a/tests/test_commands.py b/tests/test_commands.py index 68cc429c00d..4e79fc71732 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -11,7 +11,7 @@ from nanobot.config.schema import Config from nanobot.providers.litellm_provider import LiteLLMProvider from nanobot.providers.openai_codex_provider import _strip_model_prefix -from nanobot.providers.registry import find_by_model +from nanobot.providers.registry import find_by_model, find_by_name runner = CliRunner() @@ -240,6 +240,34 @@ def test_config_explicit_ollama_provider_uses_default_localhost_api_base(): assert config.get_api_base() == "http://localhost:11434" +def test_config_accepts_camel_case_explicit_provider_name_for_coding_plan(): + config = Config.model_validate( + { + "agents": { + "defaults": { + "provider": "volcengineCodingPlan", + "model": "doubao-1-5-pro", + } + }, + "providers": { + "volcengineCodingPlan": { + "apiKey": "test-key", + } + }, + } + ) + + assert config.get_provider_name() == "volcengine_coding_plan" + assert config.get_api_base() == "https://ark.cn-beijing.volces.com/api/coding/v3" + + +def test_find_by_name_accepts_camel_case_and_hyphen_aliases(): + assert find_by_name("volcengineCodingPlan") is not None + assert find_by_name("volcengineCodingPlan").name == "volcengine_coding_plan" + assert find_by_name("github-copilot") is not None + assert find_by_name("github-copilot").name == "github_copilot" + + def test_config_auto_detects_ollama_from_local_api_base(): config = Config.model_validate( { From 69f1dcdba7c843a21ba845f6d6d1cc21c183293b Mon Sep 17 00:00:00 2001 From: 19emtuck Date: Sun, 22 Mar 2026 19:08:45 +0100 Subject: [PATCH 12/40] proposal to adopt mypy some e.g. interfaces problems --- nanobot/agent/tools/filesystem.py | 24 ++++++++++++++++++++---- pyproject.toml | 1 + 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/nanobot/agent/tools/filesystem.py b/nanobot/agent/tools/filesystem.py index 4f83642ba12..8ccffb2c099 100644 --- a/nanobot/agent/tools/filesystem.py +++ b/nanobot/agent/tools/filesystem.py @@ -93,8 +93,10 @@ def parameters(self) -> dict[str, Any]: "required": ["path"], } - async def execute(self, path: str, offset: int = 1, limit: int | None = None, **kwargs: Any) -> Any: + async def execute(self, path: str | None = None, offset: int = 1, limit: int | None = None, **kwargs: Any) -> Any: try: + if not path: + return f"Error: File not found: {path}" fp = self._resolve(path) if not fp.exists(): return f"Error: File not found: {path}" @@ -174,8 +176,12 @@ def parameters(self) -> dict[str, Any]: "required": ["path", "content"], } - async def execute(self, path: str, content: str, **kwargs: Any) -> str: + async def execute(self, path: str | None = None, content: str | None = None, **kwargs: Any) -> str: try: + if not path: + raise ValueError(f"Unknown path") + if content is None: + raise ValueError("Unknown content") fp = self._resolve(path) fp.parent.mkdir(parents=True, exist_ok=True) fp.write_text(content, encoding="utf-8") @@ -248,10 +254,18 @@ def parameters(self) -> dict[str, Any]: } async def execute( - self, path: str, old_text: str, new_text: str, + self, path: str | None = None, old_text: str | None = None, + new_text: str | None = None, replace_all: bool = False, **kwargs: Any, ) -> str: try: + if not path: + raise ValueError(f"Unknown path") + if old_text is None: + raise ValueError(f"Unknown old_text") + if new_text is None: + raise ValueError(f"Unknown next_text") + fp = self._resolve(path) if not fp.exists(): return f"Error: File not found: {path}" @@ -350,10 +364,12 @@ def parameters(self) -> dict[str, Any]: } async def execute( - self, path: str, recursive: bool = False, + self, path: str | None = None, recursive: bool = False, max_entries: int | None = None, **kwargs: Any, ) -> str: try: + if path is None: + raise ValueError(f"Unknown path") dp = self._resolve(path) if not dp.exists(): return f"Error: Directory not found: {path}" diff --git a/pyproject.toml b/pyproject.toml index b7657206858..a941ab17d8e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,6 +74,7 @@ dev = [ "matrix-nio[e2e]>=0.25.2", "mistune>=3.0.0,<4.0.0", "nh3>=0.2.17,<1.0.0", + "mypy>=1.19.1", ] [project.scripts] From d4a7194c88fc47b57ed254f5ad587ac309719b8b Mon Sep 17 00:00:00 2001 From: 19emtuck Date: Mon, 23 Mar 2026 12:26:06 +0100 Subject: [PATCH 13/40] remove some none used f string --- nanobot/agent/tools/filesystem.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/nanobot/agent/tools/filesystem.py b/nanobot/agent/tools/filesystem.py index 8ccffb2c099..a967073ef14 100644 --- a/nanobot/agent/tools/filesystem.py +++ b/nanobot/agent/tools/filesystem.py @@ -179,7 +179,7 @@ def parameters(self) -> dict[str, Any]: async def execute(self, path: str | None = None, content: str | None = None, **kwargs: Any) -> str: try: if not path: - raise ValueError(f"Unknown path") + raise ValueError("Unknown path") if content is None: raise ValueError("Unknown content") fp = self._resolve(path) @@ -260,11 +260,11 @@ async def execute( ) -> str: try: if not path: - raise ValueError(f"Unknown path") + raise ValueError("Unknown path") if old_text is None: - raise ValueError(f"Unknown old_text") + raise ValueError("Unknown old_text") if new_text is None: - raise ValueError(f"Unknown next_text") + raise ValueError("Unknown next_text") fp = self._resolve(path) if not fp.exists(): @@ -369,7 +369,7 @@ async def execute( ) -> str: try: if path is None: - raise ValueError(f"Unknown path") + raise ValueError("Unknown path") dp = self._resolve(path) if not dp.exists(): return f"Error: Directory not found: {path}" From d25985be0b7631e54acb1c6dfb9f500b3eb094d3 Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Tue, 24 Mar 2026 03:45:16 +0000 Subject: [PATCH 14/40] fix(filesystem): clarify optional tool argument handling Keep the mypy-friendly optional execute signatures while returning clearer errors for missing arguments and locking that behavior with regression tests. Made-with: Cursor --- nanobot/agent/tools/filesystem.py | 4 ++-- tests/test_filesystem_tools.py | 17 +++++++++++++++++ 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/nanobot/agent/tools/filesystem.py b/nanobot/agent/tools/filesystem.py index a967073ef14..da7778da3a7 100644 --- a/nanobot/agent/tools/filesystem.py +++ b/nanobot/agent/tools/filesystem.py @@ -96,7 +96,7 @@ def parameters(self) -> dict[str, Any]: async def execute(self, path: str | None = None, offset: int = 1, limit: int | None = None, **kwargs: Any) -> Any: try: if not path: - return f"Error: File not found: {path}" + return "Error reading file: Unknown path" fp = self._resolve(path) if not fp.exists(): return f"Error: File not found: {path}" @@ -264,7 +264,7 @@ async def execute( if old_text is None: raise ValueError("Unknown old_text") if new_text is None: - raise ValueError("Unknown next_text") + raise ValueError("Unknown new_text") fp = self._resolve(path) if not fp.exists(): diff --git a/tests/test_filesystem_tools.py b/tests/test_filesystem_tools.py index 76d0a512481..ca6629edbb5 100644 --- a/tests/test_filesystem_tools.py +++ b/tests/test_filesystem_tools.py @@ -77,6 +77,11 @@ async def test_file_not_found(self, tool, tmp_path): assert "Error" in result assert "not found" in result + @pytest.mark.asyncio + async def test_missing_path_returns_clear_error(self, tool): + result = await tool.execute() + assert result == "Error reading file: Unknown path" + @pytest.mark.asyncio async def test_char_budget_trims(self, tool, tmp_path): """When the selected slice exceeds _MAX_CHARS the output is trimmed.""" @@ -200,6 +205,13 @@ async def test_not_found(self, tool, tmp_path): assert "Error" in result assert "not found" in result + @pytest.mark.asyncio + async def test_missing_new_text_returns_clear_error(self, tool, tmp_path): + f = tmp_path / "a.py" + f.write_text("hello", encoding="utf-8") + result = await tool.execute(path=str(f), old_text="hello") + assert result == "Error editing file: Unknown new_text" + # --------------------------------------------------------------------------- # ListDirTool @@ -265,6 +277,11 @@ async def test_not_found(self, tool, tmp_path): assert "Error" in result assert "not found" in result + @pytest.mark.asyncio + async def test_missing_path_returns_clear_error(self, tool): + result = await tool.execute() + assert result == "Error listing directory: Unknown path" + # --------------------------------------------------------------------------- # Workspace restriction + extra_allowed_dirs From 72acba5d274b7148d147f3ad7e60d88932b5aeb4 Mon Sep 17 00:00:00 2001 From: chengyongru Date: Tue, 24 Mar 2026 13:37:06 +0800 Subject: [PATCH 15/40] refactor(tests): optimize unit test structure --- .github/workflows/ci.yml | 11 ++++++----- nanobot/agent/tools/shell.py | 10 ++++++---- pyproject.toml | 5 +---- tests/{ => agent}/test_consolidate_offset.py | 0 tests/{ => agent}/test_context_prompt_cache.py | 0 tests/{ => agent}/test_evaluator.py | 0 tests/{ => agent}/test_gemini_thought_signature.py | 0 tests/{ => agent}/test_heartbeat_service.py | 0 tests/{ => agent}/test_loop_consolidation_tokens.py | 0 tests/{ => agent}/test_loop_save_turn.py | 0 tests/{ => agent}/test_memory_consolidation_types.py | 0 tests/{ => agent}/test_onboard_logic.py | 0 tests/{ => agent}/test_session_manager_history.py | 0 tests/{ => agent}/test_skill_creator_scripts.py | 0 tests/{ => agent}/test_task_cancel.py | 0 tests/{ => channels}/test_base_channel.py | 0 tests/{ => channels}/test_channel_plugins.py | 0 tests/{ => channels}/test_dingtalk_channel.py | 10 ++++++++++ tests/{ => channels}/test_email_channel.py | 0 .../{ => channels}/test_feishu_markdown_rendering.py | 11 +++++++++++ tests/{ => channels}/test_feishu_post_content.py | 11 +++++++++++ tests/{ => channels}/test_feishu_reply.py | 10 ++++++++++ tests/{ => channels}/test_feishu_table_split.py | 11 +++++++++++ .../test_feishu_tool_hint_code_block.py | 10 ++++++++++ tests/{ => channels}/test_matrix_channel.py | 6 ++++++ tests/{ => channels}/test_qq_channel.py | 10 ++++++++++ tests/{ => channels}/test_slack_channel.py | 6 ++++++ tests/{ => channels}/test_telegram_channel.py | 6 ++++++ tests/{ => channels}/test_weixin_channel.py | 0 tests/{ => channels}/test_whatsapp_channel.py | 0 tests/{ => cli}/test_cli_input.py | 0 tests/{ => cli}/test_commands.py | 0 tests/{ => cli}/test_restart_command.py | 0 tests/{ => config}/test_config_migration.py | 0 tests/{ => config}/test_config_paths.py | 0 tests/{ => cron}/test_cron_service.py | 0 tests/{ => cron}/test_cron_tool_list.py | 0 tests/{ => providers}/test_azure_openai_provider.py | 0 tests/{ => providers}/test_custom_provider.py | 0 tests/{ => providers}/test_litellm_kwargs.py | 0 tests/{ => providers}/test_mistral_provider.py | 0 tests/{ => providers}/test_provider_retry.py | 0 tests/{ => providers}/test_providers_init.py | 0 tests/{ => security}/test_security_network.py | 0 tests/{ => tools}/test_exec_security.py | 0 tests/{ => tools}/test_filesystem_tools.py | 0 tests/{ => tools}/test_mcp_tool.py | 0 tests/{ => tools}/test_message_tool.py | 0 tests/{ => tools}/test_message_tool_suppress.py | 0 tests/{ => tools}/test_tool_validation.py | 0 tests/{ => tools}/test_web_fetch_security.py | 0 tests/{ => tools}/test_web_search_tool.py | 0 52 files changed, 104 insertions(+), 13 deletions(-) rename tests/{ => agent}/test_consolidate_offset.py (100%) rename tests/{ => agent}/test_context_prompt_cache.py (100%) rename tests/{ => agent}/test_evaluator.py (100%) rename tests/{ => agent}/test_gemini_thought_signature.py (100%) rename tests/{ => agent}/test_heartbeat_service.py (100%) rename tests/{ => agent}/test_loop_consolidation_tokens.py (100%) rename tests/{ => agent}/test_loop_save_turn.py (100%) rename tests/{ => agent}/test_memory_consolidation_types.py (100%) rename tests/{ => agent}/test_onboard_logic.py (100%) rename tests/{ => agent}/test_session_manager_history.py (100%) rename tests/{ => agent}/test_skill_creator_scripts.py (100%) rename tests/{ => agent}/test_task_cancel.py (100%) rename tests/{ => channels}/test_base_channel.py (100%) rename tests/{ => channels}/test_channel_plugins.py (100%) rename tests/{ => channels}/test_dingtalk_channel.py (95%) rename tests/{ => channels}/test_email_channel.py (100%) rename tests/{ => channels}/test_feishu_markdown_rendering.py (81%) rename tests/{ => channels}/test_feishu_post_content.py (82%) rename tests/{ => channels}/test_feishu_reply.py (97%) rename tests/{ => channels}/test_feishu_table_split.py (89%) rename tests/{ => channels}/test_feishu_tool_hint_code_block.py (93%) rename tests/{ => channels}/test_matrix_channel.py (99%) rename tests/{ => channels}/test_qq_channel.py (93%) rename tests/{ => channels}/test_slack_channel.py (95%) rename tests/{ => channels}/test_telegram_channel.py (99%) rename tests/{ => channels}/test_weixin_channel.py (100%) rename tests/{ => channels}/test_whatsapp_channel.py (100%) rename tests/{ => cli}/test_cli_input.py (100%) rename tests/{ => cli}/test_commands.py (100%) rename tests/{ => cli}/test_restart_command.py (100%) rename tests/{ => config}/test_config_migration.py (100%) rename tests/{ => config}/test_config_paths.py (100%) rename tests/{ => cron}/test_cron_service.py (100%) rename tests/{ => cron}/test_cron_tool_list.py (100%) rename tests/{ => providers}/test_azure_openai_provider.py (100%) rename tests/{ => providers}/test_custom_provider.py (100%) rename tests/{ => providers}/test_litellm_kwargs.py (100%) rename tests/{ => providers}/test_mistral_provider.py (100%) rename tests/{ => providers}/test_provider_retry.py (100%) rename tests/{ => providers}/test_providers_init.py (100%) rename tests/{ => security}/test_security_network.py (100%) rename tests/{ => tools}/test_exec_security.py (100%) rename tests/{ => tools}/test_filesystem_tools.py (100%) rename tests/{ => tools}/test_mcp_tool.py (100%) rename tests/{ => tools}/test_message_tool.py (100%) rename tests/{ => tools}/test_message_tool_suppress.py (100%) rename tests/{ => tools}/test_tool_validation.py (100%) rename tests/{ => tools}/test_web_fetch_security.py (100%) rename tests/{ => tools}/test_web_search_tool.py (100%) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 67a4d9b0d19..e00362d0214 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -21,13 +21,14 @@ jobs: with: python-version: ${{ matrix.python-version }} + - name: Install uv + uses: astral-sh/setup-uv@v4 + - name: Install system dependencies run: sudo apt-get update && sudo apt-get install -y libolm-dev build-essential - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install .[dev] + - name: Install all dependencies + run: uv sync --all-extras - name: Run tests - run: python -m pytest tests/ -v + run: uv run pytest tests/ diff --git a/nanobot/agent/tools/shell.py b/nanobot/agent/tools/shell.py index 5b4641297ca..ed552b33e66 100644 --- a/nanobot/agent/tools/shell.py +++ b/nanobot/agent/tools/shell.py @@ -3,6 +3,7 @@ import asyncio import os import re +import sys from pathlib import Path from typing import Any @@ -113,10 +114,11 @@ async def execute( except asyncio.TimeoutError: pass finally: - try: - os.waitpid(process.pid, os.WNOHANG) - except (ProcessLookupError, ChildProcessError) as e: - logger.debug("Process already reaped or not found: {}", e) + if sys.platform != "win32": + try: + os.waitpid(process.pid, os.WNOHANG) + except (ProcessLookupError, ChildProcessError) as e: + logger.debug("Process already reaped or not found: {}", e) return f"Error: Command timed out after {effective_timeout} seconds" output_parts = [] diff --git a/pyproject.toml b/pyproject.toml index a941ab17d8e..be367a473e7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,11 +70,8 @@ langsmith = [ dev = [ "pytest>=9.0.0,<10.0.0", "pytest-asyncio>=1.3.0,<2.0.0", + "pytest-cov>=6.0.0,<7.0.0", "ruff>=0.1.0", - "matrix-nio[e2e]>=0.25.2", - "mistune>=3.0.0,<4.0.0", - "nh3>=0.2.17,<1.0.0", - "mypy>=1.19.1", ] [project.scripts] diff --git a/tests/test_consolidate_offset.py b/tests/agent/test_consolidate_offset.py similarity index 100% rename from tests/test_consolidate_offset.py rename to tests/agent/test_consolidate_offset.py diff --git a/tests/test_context_prompt_cache.py b/tests/agent/test_context_prompt_cache.py similarity index 100% rename from tests/test_context_prompt_cache.py rename to tests/agent/test_context_prompt_cache.py diff --git a/tests/test_evaluator.py b/tests/agent/test_evaluator.py similarity index 100% rename from tests/test_evaluator.py rename to tests/agent/test_evaluator.py diff --git a/tests/test_gemini_thought_signature.py b/tests/agent/test_gemini_thought_signature.py similarity index 100% rename from tests/test_gemini_thought_signature.py rename to tests/agent/test_gemini_thought_signature.py diff --git a/tests/test_heartbeat_service.py b/tests/agent/test_heartbeat_service.py similarity index 100% rename from tests/test_heartbeat_service.py rename to tests/agent/test_heartbeat_service.py diff --git a/tests/test_loop_consolidation_tokens.py b/tests/agent/test_loop_consolidation_tokens.py similarity index 100% rename from tests/test_loop_consolidation_tokens.py rename to tests/agent/test_loop_consolidation_tokens.py diff --git a/tests/test_loop_save_turn.py b/tests/agent/test_loop_save_turn.py similarity index 100% rename from tests/test_loop_save_turn.py rename to tests/agent/test_loop_save_turn.py diff --git a/tests/test_memory_consolidation_types.py b/tests/agent/test_memory_consolidation_types.py similarity index 100% rename from tests/test_memory_consolidation_types.py rename to tests/agent/test_memory_consolidation_types.py diff --git a/tests/test_onboard_logic.py b/tests/agent/test_onboard_logic.py similarity index 100% rename from tests/test_onboard_logic.py rename to tests/agent/test_onboard_logic.py diff --git a/tests/test_session_manager_history.py b/tests/agent/test_session_manager_history.py similarity index 100% rename from tests/test_session_manager_history.py rename to tests/agent/test_session_manager_history.py diff --git a/tests/test_skill_creator_scripts.py b/tests/agent/test_skill_creator_scripts.py similarity index 100% rename from tests/test_skill_creator_scripts.py rename to tests/agent/test_skill_creator_scripts.py diff --git a/tests/test_task_cancel.py b/tests/agent/test_task_cancel.py similarity index 100% rename from tests/test_task_cancel.py rename to tests/agent/test_task_cancel.py diff --git a/tests/test_base_channel.py b/tests/channels/test_base_channel.py similarity index 100% rename from tests/test_base_channel.py rename to tests/channels/test_base_channel.py diff --git a/tests/test_channel_plugins.py b/tests/channels/test_channel_plugins.py similarity index 100% rename from tests/test_channel_plugins.py rename to tests/channels/test_channel_plugins.py diff --git a/tests/test_dingtalk_channel.py b/tests/channels/test_dingtalk_channel.py similarity index 95% rename from tests/test_dingtalk_channel.py rename to tests/channels/test_dingtalk_channel.py index a0b866faded..6894c86837c 100644 --- a/tests/test_dingtalk_channel.py +++ b/tests/channels/test_dingtalk_channel.py @@ -3,6 +3,16 @@ import pytest +# Check optional dingtalk dependencies before running tests +try: + from nanobot.channels import dingtalk + DINGTALK_AVAILABLE = getattr(dingtalk, "DINGTALK_AVAILABLE", False) +except ImportError: + DINGTALK_AVAILABLE = False + +if not DINGTALK_AVAILABLE: + pytest.skip("DingTalk dependencies not installed (dingtalk-stream)", allow_module_level=True) + from nanobot.bus.queue import MessageBus import nanobot.channels.dingtalk as dingtalk_module from nanobot.channels.dingtalk import DingTalkChannel, NanobotDingTalkHandler diff --git a/tests/test_email_channel.py b/tests/channels/test_email_channel.py similarity index 100% rename from tests/test_email_channel.py rename to tests/channels/test_email_channel.py diff --git a/tests/test_feishu_markdown_rendering.py b/tests/channels/test_feishu_markdown_rendering.py similarity index 81% rename from tests/test_feishu_markdown_rendering.py rename to tests/channels/test_feishu_markdown_rendering.py index 6812a21aa66..efcd207335b 100644 --- a/tests/test_feishu_markdown_rendering.py +++ b/tests/channels/test_feishu_markdown_rendering.py @@ -1,3 +1,14 @@ +# Check optional Feishu dependencies before running tests +try: + from nanobot.channels import feishu + FEISHU_AVAILABLE = getattr(feishu, "FEISHU_AVAILABLE", False) +except ImportError: + FEISHU_AVAILABLE = False + +if not FEISHU_AVAILABLE: + import pytest + pytest.skip("Feishu dependencies not installed (lark-oapi)", allow_module_level=True) + from nanobot.channels.feishu import FeishuChannel diff --git a/tests/test_feishu_post_content.py b/tests/channels/test_feishu_post_content.py similarity index 82% rename from tests/test_feishu_post_content.py rename to tests/channels/test_feishu_post_content.py index 7b1cb9d31cf..a4c5bae19f5 100644 --- a/tests/test_feishu_post_content.py +++ b/tests/channels/test_feishu_post_content.py @@ -1,3 +1,14 @@ +# Check optional Feishu dependencies before running tests +try: + from nanobot.channels import feishu + FEISHU_AVAILABLE = getattr(feishu, "FEISHU_AVAILABLE", False) +except ImportError: + FEISHU_AVAILABLE = False + +if not FEISHU_AVAILABLE: + import pytest + pytest.skip("Feishu dependencies not installed (lark-oapi)", allow_module_level=True) + from nanobot.channels.feishu import FeishuChannel, _extract_post_content diff --git a/tests/test_feishu_reply.py b/tests/channels/test_feishu_reply.py similarity index 97% rename from tests/test_feishu_reply.py rename to tests/channels/test_feishu_reply.py index b2072b31aaf..0753653a775 100644 --- a/tests/test_feishu_reply.py +++ b/tests/channels/test_feishu_reply.py @@ -7,6 +7,16 @@ import pytest +# Check optional Feishu dependencies before running tests +try: + from nanobot.channels import feishu + FEISHU_AVAILABLE = getattr(feishu, "FEISHU_AVAILABLE", False) +except ImportError: + FEISHU_AVAILABLE = False + +if not FEISHU_AVAILABLE: + pytest.skip("Feishu dependencies not installed (lark-oapi)", allow_module_level=True) + from nanobot.bus.events import OutboundMessage from nanobot.bus.queue import MessageBus from nanobot.channels.feishu import FeishuChannel, FeishuConfig diff --git a/tests/test_feishu_table_split.py b/tests/channels/test_feishu_table_split.py similarity index 89% rename from tests/test_feishu_table_split.py rename to tests/channels/test_feishu_table_split.py index af8fa164a8e..030b8910dd4 100644 --- a/tests/test_feishu_table_split.py +++ b/tests/channels/test_feishu_table_split.py @@ -6,6 +6,17 @@ table, allowing nanobot to send multiple cards instead of failing. """ +# Check optional Feishu dependencies before running tests +try: + from nanobot.channels import feishu + FEISHU_AVAILABLE = getattr(feishu, "FEISHU_AVAILABLE", False) +except ImportError: + FEISHU_AVAILABLE = False + +if not FEISHU_AVAILABLE: + import pytest + pytest.skip("Feishu dependencies not installed (lark-oapi)", allow_module_level=True) + from nanobot.channels.feishu import FeishuChannel diff --git a/tests/test_feishu_tool_hint_code_block.py b/tests/channels/test_feishu_tool_hint_code_block.py similarity index 93% rename from tests/test_feishu_tool_hint_code_block.py rename to tests/channels/test_feishu_tool_hint_code_block.py index 2a1b81227bf..a65f1d9882e 100644 --- a/tests/test_feishu_tool_hint_code_block.py +++ b/tests/channels/test_feishu_tool_hint_code_block.py @@ -6,6 +6,16 @@ import pytest from pytest import mark +# Check optional Feishu dependencies before running tests +try: + from nanobot.channels import feishu + FEISHU_AVAILABLE = getattr(feishu, "FEISHU_AVAILABLE", False) +except ImportError: + FEISHU_AVAILABLE = False + +if not FEISHU_AVAILABLE: + pytest.skip("Feishu dependencies not installed (lark-oapi)", allow_module_level=True) + from nanobot.bus.events import OutboundMessage from nanobot.channels.feishu import FeishuChannel diff --git a/tests/test_matrix_channel.py b/tests/channels/test_matrix_channel.py similarity index 99% rename from tests/test_matrix_channel.py rename to tests/channels/test_matrix_channel.py index 1f3b69ccf81..dd5e97d9056 100644 --- a/tests/test_matrix_channel.py +++ b/tests/channels/test_matrix_channel.py @@ -4,6 +4,12 @@ import pytest +# Check optional matrix dependencies before importing +try: + import nh3 # noqa: F401 +except ImportError: + pytest.skip("Matrix dependencies not installed (nh3)", allow_module_level=True) + import nanobot.channels.matrix as matrix_module from nanobot.bus.events import OutboundMessage from nanobot.bus.queue import MessageBus diff --git a/tests/test_qq_channel.py b/tests/channels/test_qq_channel.py similarity index 93% rename from tests/test_qq_channel.py rename to tests/channels/test_qq_channel.py index ab9afcbc7f5..729442a13e3 100644 --- a/tests/test_qq_channel.py +++ b/tests/channels/test_qq_channel.py @@ -4,6 +4,16 @@ import pytest +# Check optional QQ dependencies before running tests +try: + from nanobot.channels import qq + QQ_AVAILABLE = getattr(qq, "QQ_AVAILABLE", False) +except ImportError: + QQ_AVAILABLE = False + +if not QQ_AVAILABLE: + pytest.skip("QQ dependencies not installed (qq-botpy)", allow_module_level=True) + from nanobot.bus.events import OutboundMessage from nanobot.bus.queue import MessageBus from nanobot.channels.qq import QQChannel, QQConfig diff --git a/tests/test_slack_channel.py b/tests/channels/test_slack_channel.py similarity index 95% rename from tests/test_slack_channel.py rename to tests/channels/test_slack_channel.py index d243235aaa7..f7eec95c036 100644 --- a/tests/test_slack_channel.py +++ b/tests/channels/test_slack_channel.py @@ -2,6 +2,12 @@ import pytest +# Check optional Slack dependencies before running tests +try: + import slack_sdk # noqa: F401 +except ImportError: + pytest.skip("Slack dependencies not installed (slack-sdk)", allow_module_level=True) + from nanobot.bus.events import OutboundMessage from nanobot.bus.queue import MessageBus from nanobot.channels.slack import SlackChannel diff --git a/tests/test_telegram_channel.py b/tests/channels/test_telegram_channel.py similarity index 99% rename from tests/test_telegram_channel.py rename to tests/channels/test_telegram_channel.py index 8b6ba97896e..353d5d05d71 100644 --- a/tests/test_telegram_channel.py +++ b/tests/channels/test_telegram_channel.py @@ -5,6 +5,12 @@ import pytest +# Check optional Telegram dependencies before running tests +try: + import telegram # noqa: F401 +except ImportError: + pytest.skip("Telegram dependencies not installed (python-telegram-bot)", allow_module_level=True) + from nanobot.bus.events import OutboundMessage from nanobot.bus.queue import MessageBus from nanobot.channels.telegram import TELEGRAM_REPLY_CONTEXT_MAX_LEN, TelegramChannel diff --git a/tests/test_weixin_channel.py b/tests/channels/test_weixin_channel.py similarity index 100% rename from tests/test_weixin_channel.py rename to tests/channels/test_weixin_channel.py diff --git a/tests/test_whatsapp_channel.py b/tests/channels/test_whatsapp_channel.py similarity index 100% rename from tests/test_whatsapp_channel.py rename to tests/channels/test_whatsapp_channel.py diff --git a/tests/test_cli_input.py b/tests/cli/test_cli_input.py similarity index 100% rename from tests/test_cli_input.py rename to tests/cli/test_cli_input.py diff --git a/tests/test_commands.py b/tests/cli/test_commands.py similarity index 100% rename from tests/test_commands.py rename to tests/cli/test_commands.py diff --git a/tests/test_restart_command.py b/tests/cli/test_restart_command.py similarity index 100% rename from tests/test_restart_command.py rename to tests/cli/test_restart_command.py diff --git a/tests/test_config_migration.py b/tests/config/test_config_migration.py similarity index 100% rename from tests/test_config_migration.py rename to tests/config/test_config_migration.py diff --git a/tests/test_config_paths.py b/tests/config/test_config_paths.py similarity index 100% rename from tests/test_config_paths.py rename to tests/config/test_config_paths.py diff --git a/tests/test_cron_service.py b/tests/cron/test_cron_service.py similarity index 100% rename from tests/test_cron_service.py rename to tests/cron/test_cron_service.py diff --git a/tests/test_cron_tool_list.py b/tests/cron/test_cron_tool_list.py similarity index 100% rename from tests/test_cron_tool_list.py rename to tests/cron/test_cron_tool_list.py diff --git a/tests/test_azure_openai_provider.py b/tests/providers/test_azure_openai_provider.py similarity index 100% rename from tests/test_azure_openai_provider.py rename to tests/providers/test_azure_openai_provider.py diff --git a/tests/test_custom_provider.py b/tests/providers/test_custom_provider.py similarity index 100% rename from tests/test_custom_provider.py rename to tests/providers/test_custom_provider.py diff --git a/tests/test_litellm_kwargs.py b/tests/providers/test_litellm_kwargs.py similarity index 100% rename from tests/test_litellm_kwargs.py rename to tests/providers/test_litellm_kwargs.py diff --git a/tests/test_mistral_provider.py b/tests/providers/test_mistral_provider.py similarity index 100% rename from tests/test_mistral_provider.py rename to tests/providers/test_mistral_provider.py diff --git a/tests/test_provider_retry.py b/tests/providers/test_provider_retry.py similarity index 100% rename from tests/test_provider_retry.py rename to tests/providers/test_provider_retry.py diff --git a/tests/test_providers_init.py b/tests/providers/test_providers_init.py similarity index 100% rename from tests/test_providers_init.py rename to tests/providers/test_providers_init.py diff --git a/tests/test_security_network.py b/tests/security/test_security_network.py similarity index 100% rename from tests/test_security_network.py rename to tests/security/test_security_network.py diff --git a/tests/test_exec_security.py b/tests/tools/test_exec_security.py similarity index 100% rename from tests/test_exec_security.py rename to tests/tools/test_exec_security.py diff --git a/tests/test_filesystem_tools.py b/tests/tools/test_filesystem_tools.py similarity index 100% rename from tests/test_filesystem_tools.py rename to tests/tools/test_filesystem_tools.py diff --git a/tests/test_mcp_tool.py b/tests/tools/test_mcp_tool.py similarity index 100% rename from tests/test_mcp_tool.py rename to tests/tools/test_mcp_tool.py diff --git a/tests/test_message_tool.py b/tests/tools/test_message_tool.py similarity index 100% rename from tests/test_message_tool.py rename to tests/tools/test_message_tool.py diff --git a/tests/test_message_tool_suppress.py b/tests/tools/test_message_tool_suppress.py similarity index 100% rename from tests/test_message_tool_suppress.py rename to tests/tools/test_message_tool_suppress.py diff --git a/tests/test_tool_validation.py b/tests/tools/test_tool_validation.py similarity index 100% rename from tests/test_tool_validation.py rename to tests/tools/test_tool_validation.py diff --git a/tests/test_web_fetch_security.py b/tests/tools/test_web_fetch_security.py similarity index 100% rename from tests/test_web_fetch_security.py rename to tests/tools/test_web_fetch_security.py diff --git a/tests/test_web_search_tool.py b/tests/tools/test_web_search_tool.py similarity index 100% rename from tests/test_web_search_tool.py rename to tests/tools/test_web_search_tool.py From 38ce054b31ee2bd939a3367854c166b074814b6b Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Tue, 24 Mar 2026 15:55:43 +0000 Subject: [PATCH 16/40] fix(security): pin litellm and add supply chain advisory note --- README.md | 3 +++ pyproject.toml | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 797a5bcf274..c9d19a1ca15 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,9 @@ ## 📢 News +> [!IMPORTANT] +> **Security note:** Due to `litellm` supply chain poisoning, **please check your Python environment ASAP** and refer to this [advisory](https://github.com/HKUDS/nanobot/discussions/2445) for details. We are also urgently replacing `litellm` and preparing mitigations. + - **2026-03-16** 🚀 Released **v0.1.4.post5** — a refinement-focused release with stronger reliability and channel support, and a more dependable day-to-day experience. Please see [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post5) for details. - **2026-03-15** 🧩 DingTalk rich media, smarter built-in skills, and cleaner model compatibility. - **2026-03-14** 💬 Channel plugins, Feishu replies, and steadier MCP, QQ, and media handling. diff --git a/pyproject.toml b/pyproject.toml index be367a473e7..246ca3074cf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,7 @@ classifiers = [ dependencies = [ "typer>=0.20.0,<1.0.0", - "litellm>=1.82.1,<2.0.0", + "litellm>=1.82.1,<=1.82.6", "pydantic>=2.12.0,<3.0.0", "pydantic-settings>=2.12.0,<3.0.0", "websockets>=16.0,<17.0", From 3dfdab704e14b99de3ac93b24642eb9f09daab44 Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Tue, 24 Mar 2026 17:53:35 +0000 Subject: [PATCH 17/40] refactor: replace litellm with native openai + anthropic SDKs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Remove litellm dependency entirely (supply chain risk mitigation) - Add AnthropicProvider (native SDK) and OpenAICompatProvider (unified) - Merge CustomProvider into OpenAICompatProvider, delete custom_provider.py - Add ProviderSpec.backend field for declarative provider routing - Remove _resolve_model, find_gateway, find_by_model (dead heuristics) - Pass resolved spec directly into provider — zero internal lookups - Stub out litellm-dependent model database (cli/models.py) - Add anthropic>=0.45.0 to dependencies, remove litellm - 593 tests passed, net -1034 lines --- README.md | 16 +- nanobot/cli/commands.py | 77 +-- nanobot/cli/models.py | 212 +-------- nanobot/config/schema.py | 3 +- nanobot/providers/__init__.py | 15 +- nanobot/providers/anthropic_provider.py | 441 ++++++++++++++++++ nanobot/providers/custom_provider.py | 152 ------ nanobot/providers/litellm_provider.py | 413 ---------------- nanobot/providers/openai_compat_provider.py | 349 ++++++++++++++ nanobot/providers/registry.py | 339 +++----------- pyproject.toml | 2 +- tests/agent/test_gemini_thought_signature.py | 34 -- .../agent/test_memory_consolidation_types.py | 2 +- tests/cli/test_commands.py | 33 +- tests/providers/test_custom_provider.py | 8 +- tests/providers/test_litellm_kwargs.py | 157 +++---- tests/providers/test_mistral_provider.py | 2 - tests/providers/test_providers_init.py | 17 +- 18 files changed, 1014 insertions(+), 1258 deletions(-) create mode 100644 nanobot/providers/anthropic_provider.py delete mode 100644 nanobot/providers/custom_provider.py delete mode 100644 nanobot/providers/litellm_provider.py create mode 100644 nanobot/providers/openai_compat_provider.py diff --git a/README.md b/README.md index c9d19a1ca15..9f5e0d248ae 100644 --- a/README.md +++ b/README.md @@ -842,7 +842,7 @@ Config file: `~/.nanobot/config.json` | Provider | Purpose | Get API Key | |----------|---------|-------------| -| `custom` | Any OpenAI-compatible endpoint (direct, no LiteLLM) | — | +| `custom` | Any OpenAI-compatible endpoint | — | | `openrouter` | LLM (recommended, access to all models) | [openrouter.ai](https://openrouter.ai) | | `volcengine` | LLM (VolcEngine, pay-per-use) | [Coding Plan](https://www.volcengine.com/activity/codingplan?utm_campaign=nanobot&utm_content=nanobot&utm_medium=devrel&utm_source=OWO&utm_term=nanobot) · [volcengine.com](https://www.volcengine.com) | | `byteplus` | LLM (VolcEngine international, pay-per-use) | [Coding Plan](https://www.byteplus.com/en/activity/codingplan?utm_campaign=nanobot&utm_content=nanobot&utm_medium=devrel&utm_source=OWO&utm_term=nanobot) · [byteplus.com](https://www.byteplus.com) | @@ -943,7 +943,7 @@ nanobot agent -c ~/.nanobot-telegram/config.json -w /tmp/nanobot-telegram-test -
Custom Provider (Any OpenAI-compatible API) -Connects directly to any OpenAI-compatible endpoint — LM Studio, llama.cpp, Together AI, Fireworks, Azure OpenAI, or any self-hosted server. Bypasses LiteLLM; model name is passed as-is. +Connects directly to any OpenAI-compatible endpoint — LM Studio, llama.cpp, Together AI, Fireworks, Azure OpenAI, or any self-hosted server. Model name is passed as-is. ```json { @@ -1120,10 +1120,9 @@ Adding a new provider only takes **2 steps** — no if-elif chains to touch. ProviderSpec( name="myprovider", # config field name keywords=("myprovider", "mymodel"), # model-name keywords for auto-matching - env_key="MYPROVIDER_API_KEY", # env var for LiteLLM + env_key="MYPROVIDER_API_KEY", # env var name display_name="My Provider", # shown in `nanobot status` - litellm_prefix="myprovider", # auto-prefix: model → myprovider/model - skip_prefixes=("myprovider/",), # don't double-prefix + default_api_base="https://api.myprovider.com/v1", # OpenAI-compatible endpoint ) ``` @@ -1135,20 +1134,19 @@ class ProvidersConfig(BaseModel): myprovider: ProviderConfig = ProviderConfig() ``` -That's it! Environment variables, model prefixing, config matching, and `nanobot status` display will all work automatically. +That's it! Environment variables, model routing, config matching, and `nanobot status` display will all work automatically. **Common `ProviderSpec` options:** | Field | Description | Example | |-------|-------------|---------| -| `litellm_prefix` | Auto-prefix model names for LiteLLM | `"dashscope"` → `dashscope/qwen-max` | -| `skip_prefixes` | Don't prefix if model already starts with these | `("dashscope/", "openrouter/")` | +| `default_api_base` | OpenAI-compatible base URL | `"https://api.deepseek.com"` | | `env_extras` | Additional env vars to set | `(("ZHIPUAI_API_KEY", "{api_key}"),)` | | `model_overrides` | Per-model parameter overrides | `(("kimi-k2.5", {"temperature": 1.0}),)` | | `is_gateway` | Can route any model (like OpenRouter) | `True` | | `detect_by_key_prefix` | Detect gateway by API key prefix | `"sk-or-"` | | `detect_by_base_keyword` | Detect gateway by API base URL | `"openrouter"` | -| `strip_model_prefix` | Strip existing prefix before re-prefixing | `True` (for AiHubMix) | +| `strip_model_prefix` | Strip provider prefix before sending to gateway | `True` (for AiHubMix) |
diff --git a/nanobot/cli/commands.py b/nanobot/cli/commands.py index 27733239c00..91c81d3dec7 100644 --- a/nanobot/cli/commands.py +++ b/nanobot/cli/commands.py @@ -376,61 +376,61 @@ def _onboard_plugins(config_path: Path) -> None: def _make_provider(config: Config): - """Create the appropriate LLM provider from config.""" - from nanobot.providers.azure_openai_provider import AzureOpenAIProvider + """Create the appropriate LLM provider from config. + + Routing is driven by ``ProviderSpec.backend`` in the registry. + """ from nanobot.providers.base import GenerationSettings - from nanobot.providers.openai_codex_provider import OpenAICodexProvider + from nanobot.providers.registry import find_by_name model = config.agents.defaults.model provider_name = config.get_provider_name(model) p = config.get_provider(model) + spec = find_by_name(provider_name) if provider_name else None + backend = spec.backend if spec else "openai_compat" - # OpenAI Codex (OAuth) - if provider_name == "openai_codex" or model.startswith("openai-codex/"): - provider = OpenAICodexProvider(default_model=model) - # Custom: direct OpenAI-compatible endpoint, bypasses LiteLLM - elif provider_name == "custom": - from nanobot.providers.custom_provider import CustomProvider - provider = CustomProvider( - api_key=p.api_key if p else "no-key", - api_base=config.get_api_base(model) or "http://localhost:8000/v1", - default_model=model, - extra_headers=p.extra_headers if p else None, - ) - # Azure OpenAI: direct Azure OpenAI endpoint with deployment name - elif provider_name == "azure_openai": + # --- validation --- + if backend == "azure_openai": if not p or not p.api_key or not p.api_base: console.print("[red]Error: Azure OpenAI requires api_key and api_base.[/red]") console.print("Set them in ~/.nanobot/config.json under providers.azure_openai section") console.print("Use the model field to specify the deployment name.") raise typer.Exit(1) + elif backend == "openai_compat" and not model.startswith("bedrock/"): + needs_key = not (p and p.api_key) + exempt = spec and (spec.is_oauth or spec.is_local or spec.is_direct) + if needs_key and not exempt: + console.print("[red]Error: No API key configured.[/red]") + console.print("Set one in ~/.nanobot/config.json under providers section") + raise typer.Exit(1) + + # --- instantiation by backend --- + if backend == "openai_codex": + from nanobot.providers.openai_codex_provider import OpenAICodexProvider + provider = OpenAICodexProvider(default_model=model) + elif backend == "azure_openai": + from nanobot.providers.azure_openai_provider import AzureOpenAIProvider provider = AzureOpenAIProvider( api_key=p.api_key, api_base=p.api_base, default_model=model, ) - # OpenVINO Model Server: direct OpenAI-compatible endpoint at /v3 - elif provider_name == "ovms": - from nanobot.providers.custom_provider import CustomProvider - provider = CustomProvider( - api_key=p.api_key if p else "no-key", - api_base=config.get_api_base(model) or "http://localhost:8000/v3", + elif backend == "anthropic": + from nanobot.providers.anthropic_provider import AnthropicProvider + provider = AnthropicProvider( + api_key=p.api_key if p else None, + api_base=config.get_api_base(model), default_model=model, + extra_headers=p.extra_headers if p else None, ) else: - from nanobot.providers.litellm_provider import LiteLLMProvider - from nanobot.providers.registry import find_by_name - spec = find_by_name(provider_name) - if not model.startswith("bedrock/") and not (p and p.api_key) and not (spec and (spec.is_oauth or spec.is_local)): - console.print("[red]Error: No API key configured.[/red]") - console.print("Set one in ~/.nanobot/config.json under providers section") - raise typer.Exit(1) - provider = LiteLLMProvider( + from nanobot.providers.openai_compat_provider import OpenAICompatProvider + provider = OpenAICompatProvider( api_key=p.api_key if p else None, api_base=config.get_api_base(model), default_model=model, extra_headers=p.extra_headers if p else None, - provider_name=provider_name, + spec=spec, ) defaults = config.agents.defaults @@ -1203,11 +1203,20 @@ def _login_openai_codex() -> None: def _login_github_copilot() -> None: import asyncio + from openai import AsyncOpenAI + console.print("[cyan]Starting GitHub Copilot device flow...[/cyan]\n") async def _trigger(): - from litellm import acompletion - await acompletion(model="github_copilot/gpt-4o", messages=[{"role": "user", "content": "hi"}], max_tokens=1) + client = AsyncOpenAI( + api_key="dummy", + base_url="https://api.githubcopilot.com", + ) + await client.chat.completions.create( + model="gpt-4o", + messages=[{"role": "user", "content": "hi"}], + max_tokens=1, + ) try: asyncio.run(_trigger()) diff --git a/nanobot/cli/models.py b/nanobot/cli/models.py index 520370c4b8f..0ba24018fc4 100644 --- a/nanobot/cli/models.py +++ b/nanobot/cli/models.py @@ -1,229 +1,29 @@ """Model information helpers for the onboard wizard. -Provides model context window lookup and autocomplete suggestions using litellm. +Model database / autocomplete is temporarily disabled while litellm is +being replaced. All public function signatures are preserved so callers +continue to work without changes. """ from __future__ import annotations -from functools import lru_cache from typing import Any -def _litellm(): - """Lazy accessor for litellm (heavy import deferred until actually needed).""" - import litellm as _ll - - return _ll - - -@lru_cache(maxsize=1) -def _get_model_cost_map() -> dict[str, Any]: - """Get litellm's model cost map (cached).""" - return getattr(_litellm(), "model_cost", {}) - - -@lru_cache(maxsize=1) def get_all_models() -> list[str]: - """Get all known model names from litellm. - """ - models = set() - - # From model_cost (has pricing info) - cost_map = _get_model_cost_map() - for k in cost_map.keys(): - if k != "sample_spec": - models.add(k) - - # From models_by_provider (more complete provider coverage) - for provider_models in getattr(_litellm(), "models_by_provider", {}).values(): - if isinstance(provider_models, (set, list)): - models.update(provider_models) - - return sorted(models) - - -def _normalize_model_name(model: str) -> str: - """Normalize model name for comparison.""" - return model.lower().replace("-", "_").replace(".", "") + return [] def find_model_info(model_name: str) -> dict[str, Any] | None: - """Find model info with fuzzy matching. - - Args: - model_name: Model name in any common format - - Returns: - Model info dict or None if not found - """ - cost_map = _get_model_cost_map() - if not cost_map: - return None - - # Direct match - if model_name in cost_map: - return cost_map[model_name] - - # Extract base name (without provider prefix) - base_name = model_name.split("/")[-1] if "/" in model_name else model_name - base_normalized = _normalize_model_name(base_name) - - candidates = [] - - for key, info in cost_map.items(): - if key == "sample_spec": - continue - - key_base = key.split("/")[-1] if "/" in key else key - key_base_normalized = _normalize_model_name(key_base) - - # Score the match - score = 0 - - # Exact base name match (highest priority) - if base_normalized == key_base_normalized: - score = 100 - # Base name contains model - elif base_normalized in key_base_normalized: - score = 80 - # Model contains base name - elif key_base_normalized in base_normalized: - score = 70 - # Partial match - elif base_normalized[:10] in key_base_normalized: - score = 50 - - if score > 0: - # Prefer models with max_input_tokens - if info.get("max_input_tokens"): - score += 10 - candidates.append((score, key, info)) - - if not candidates: - return None - - # Return the best match - candidates.sort(key=lambda x: (-x[0], x[1])) - return candidates[0][2] + return None def get_model_context_limit(model: str, provider: str = "auto") -> int | None: - """Get the maximum input context tokens for a model. - - Args: - model: Model name (e.g., "claude-3.5-sonnet", "gpt-4o") - provider: Provider name for informational purposes (not yet used for filtering) - - Returns: - Maximum input tokens, or None if unknown - - Note: - The provider parameter is currently informational only. Future versions may - use it to prefer provider-specific model variants in the lookup. - """ - # First try fuzzy search in model_cost (has more accurate max_input_tokens) - info = find_model_info(model) - if info: - # Prefer max_input_tokens (this is what we want for context window) - max_input = info.get("max_input_tokens") - if max_input and isinstance(max_input, int): - return max_input - - # Fall back to litellm's get_max_tokens (returns max_output_tokens typically) - try: - result = _litellm().get_max_tokens(model) - if result and result > 0: - return result - except (KeyError, ValueError, AttributeError): - # Model not found in litellm's database or invalid response - pass - - # Last resort: use max_tokens from model_cost - if info: - max_tokens = info.get("max_tokens") - if max_tokens and isinstance(max_tokens, int): - return max_tokens - return None -@lru_cache(maxsize=1) -def _get_provider_keywords() -> dict[str, list[str]]: - """Build provider keywords mapping from nanobot's provider registry. - - Returns: - Dict mapping provider name to list of keywords for model filtering. - """ - try: - from nanobot.providers.registry import PROVIDERS - - mapping = {} - for spec in PROVIDERS: - if spec.keywords: - mapping[spec.name] = list(spec.keywords) - return mapping - except ImportError: - return {} - - def get_model_suggestions(partial: str, provider: str = "auto", limit: int = 20) -> list[str]: - """Get autocomplete suggestions for model names. - - Args: - partial: Partial model name typed by user - provider: Provider name for filtering (e.g., "openrouter", "minimax") - limit: Maximum number of suggestions to return - - Returns: - List of matching model names - """ - all_models = get_all_models() - if not all_models: - return [] - - partial_lower = partial.lower() - partial_normalized = _normalize_model_name(partial) - - # Get provider keywords from registry - provider_keywords = _get_provider_keywords() - - # Filter by provider if specified - allowed_keywords = None - if provider and provider != "auto": - allowed_keywords = provider_keywords.get(provider.lower()) - - matches = [] - - for model in all_models: - model_lower = model.lower() - - # Apply provider filter - if allowed_keywords: - if not any(kw in model_lower for kw in allowed_keywords): - continue - - # Match against partial input - if not partial: - matches.append(model) - continue - - if partial_lower in model_lower: - # Score by position of match (earlier = better) - pos = model_lower.find(partial_lower) - score = 100 - pos - matches.append((score, model)) - elif partial_normalized in _normalize_model_name(model): - score = 50 - matches.append((score, model)) - - # Sort by score if we have scored matches - if matches and isinstance(matches[0], tuple): - matches.sort(key=lambda x: (-x[0], x[1])) - matches = [m[1] for m in matches] - else: - matches.sort() - - return matches[:limit] + return [] def format_token_count(tokens: int) -> str: diff --git a/nanobot/config/schema.py b/nanobot/config/schema.py index b31f3061a83..9ae662ec84e 100644 --- a/nanobot/config/schema.py +++ b/nanobot/config/schema.py @@ -249,8 +249,7 @@ def get_api_base(self, model: str | None = None) -> str | None: if p and p.api_base: return p.api_base # Only gateways get a default api_base here. Standard providers - # (like Moonshot) set their base URL via env vars in _setup_env - # to avoid polluting the global litellm.api_base. + # resolve their base URL from the registry in the provider constructor. if name: spec = find_by_name(name) if spec and (spec.is_gateway or spec.is_local) and spec.default_api_base: diff --git a/nanobot/providers/__init__.py b/nanobot/providers/__init__.py index 9d4994eb13c..0e259e6f014 100644 --- a/nanobot/providers/__init__.py +++ b/nanobot/providers/__init__.py @@ -7,17 +7,26 @@ from nanobot.providers.base import LLMProvider, LLMResponse -__all__ = ["LLMProvider", "LLMResponse", "LiteLLMProvider", "OpenAICodexProvider", "AzureOpenAIProvider"] +__all__ = [ + "LLMProvider", + "LLMResponse", + "AnthropicProvider", + "OpenAICompatProvider", + "OpenAICodexProvider", + "AzureOpenAIProvider", +] _LAZY_IMPORTS = { - "LiteLLMProvider": ".litellm_provider", + "AnthropicProvider": ".anthropic_provider", + "OpenAICompatProvider": ".openai_compat_provider", "OpenAICodexProvider": ".openai_codex_provider", "AzureOpenAIProvider": ".azure_openai_provider", } if TYPE_CHECKING: + from nanobot.providers.anthropic_provider import AnthropicProvider from nanobot.providers.azure_openai_provider import AzureOpenAIProvider - from nanobot.providers.litellm_provider import LiteLLMProvider + from nanobot.providers.openai_compat_provider import OpenAICompatProvider from nanobot.providers.openai_codex_provider import OpenAICodexProvider diff --git a/nanobot/providers/anthropic_provider.py b/nanobot/providers/anthropic_provider.py new file mode 100644 index 00000000000..3c789e73068 --- /dev/null +++ b/nanobot/providers/anthropic_provider.py @@ -0,0 +1,441 @@ +"""Anthropic provider — direct SDK integration for Claude models.""" + +from __future__ import annotations + +import re +import secrets +import string +from collections.abc import Awaitable, Callable +from typing import Any + +import json_repair +from loguru import logger + +from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest + +_ALNUM = string.ascii_letters + string.digits + + +def _gen_tool_id() -> str: + return "toolu_" + "".join(secrets.choice(_ALNUM) for _ in range(22)) + + +class AnthropicProvider(LLMProvider): + """LLM provider using the native Anthropic SDK for Claude models. + + Handles message format conversion (OpenAI → Anthropic Messages API), + prompt caching, extended thinking, tool calls, and streaming. + """ + + def __init__( + self, + api_key: str | None = None, + api_base: str | None = None, + default_model: str = "claude-sonnet-4-20250514", + extra_headers: dict[str, str] | None = None, + ): + super().__init__(api_key, api_base) + self.default_model = default_model + self.extra_headers = extra_headers or {} + + from anthropic import AsyncAnthropic + + client_kw: dict[str, Any] = {} + if api_key: + client_kw["api_key"] = api_key + if api_base: + client_kw["base_url"] = api_base + if extra_headers: + client_kw["default_headers"] = extra_headers + self._client = AsyncAnthropic(**client_kw) + + @staticmethod + def _strip_prefix(model: str) -> str: + if model.startswith("anthropic/"): + return model[len("anthropic/"):] + return model + + # ------------------------------------------------------------------ + # Message conversion: OpenAI chat format → Anthropic Messages API + # ------------------------------------------------------------------ + + def _convert_messages( + self, messages: list[dict[str, Any]], + ) -> tuple[str | list[dict[str, Any]], list[dict[str, Any]]]: + """Return ``(system, anthropic_messages)``.""" + system: str | list[dict[str, Any]] = "" + raw: list[dict[str, Any]] = [] + + for msg in messages: + role = msg.get("role", "") + content = msg.get("content") + + if role == "system": + system = content if isinstance(content, (str, list)) else str(content or "") + continue + + if role == "tool": + block = self._tool_result_block(msg) + if raw and raw[-1]["role"] == "user": + prev_c = raw[-1]["content"] + if isinstance(prev_c, list): + prev_c.append(block) + else: + raw[-1]["content"] = [ + {"type": "text", "text": prev_c or ""}, block, + ] + else: + raw.append({"role": "user", "content": [block]}) + continue + + if role == "assistant": + raw.append({"role": "assistant", "content": self._assistant_blocks(msg)}) + continue + + if role == "user": + raw.append({ + "role": "user", + "content": self._convert_user_content(content), + }) + continue + + return system, self._merge_consecutive(raw) + + @staticmethod + def _tool_result_block(msg: dict[str, Any]) -> dict[str, Any]: + content = msg.get("content") + block: dict[str, Any] = { + "type": "tool_result", + "tool_use_id": msg.get("tool_call_id", ""), + } + if isinstance(content, (str, list)): + block["content"] = content + else: + block["content"] = str(content) if content else "" + return block + + @staticmethod + def _assistant_blocks(msg: dict[str, Any]) -> list[dict[str, Any]]: + blocks: list[dict[str, Any]] = [] + content = msg.get("content") + + for tb in msg.get("thinking_blocks") or []: + if isinstance(tb, dict) and tb.get("type") == "thinking": + blocks.append({ + "type": "thinking", + "thinking": tb.get("thinking", ""), + "signature": tb.get("signature", ""), + }) + + if isinstance(content, str) and content: + blocks.append({"type": "text", "text": content}) + elif isinstance(content, list): + for item in content: + blocks.append(item if isinstance(item, dict) else {"type": "text", "text": str(item)}) + + for tc in msg.get("tool_calls") or []: + if not isinstance(tc, dict): + continue + func = tc.get("function", {}) + args = func.get("arguments", "{}") + if isinstance(args, str): + args = json_repair.loads(args) + blocks.append({ + "type": "tool_use", + "id": tc.get("id") or _gen_tool_id(), + "name": func.get("name", ""), + "input": args, + }) + + return blocks or [{"type": "text", "text": ""}] + + def _convert_user_content(self, content: Any) -> Any: + """Convert user message content, translating image_url blocks.""" + if isinstance(content, str) or content is None: + return content or "(empty)" + if not isinstance(content, list): + return str(content) + + result: list[dict[str, Any]] = [] + for item in content: + if not isinstance(item, dict): + result.append({"type": "text", "text": str(item)}) + continue + if item.get("type") == "image_url": + converted = self._convert_image_block(item) + if converted: + result.append(converted) + continue + result.append(item) + return result or "(empty)" + + @staticmethod + def _convert_image_block(block: dict[str, Any]) -> dict[str, Any] | None: + """Convert OpenAI image_url block to Anthropic image block.""" + url = (block.get("image_url") or {}).get("url", "") + if not url: + return None + m = re.match(r"data:(image/\w+);base64,(.+)", url, re.DOTALL) + if m: + return { + "type": "image", + "source": {"type": "base64", "media_type": m.group(1), "data": m.group(2)}, + } + return { + "type": "image", + "source": {"type": "url", "url": url}, + } + + @staticmethod + def _merge_consecutive(msgs: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Anthropic requires alternating user/assistant roles.""" + merged: list[dict[str, Any]] = [] + for msg in msgs: + if merged and merged[-1]["role"] == msg["role"]: + prev_c = merged[-1]["content"] + cur_c = msg["content"] + if isinstance(prev_c, str): + prev_c = [{"type": "text", "text": prev_c}] + if isinstance(cur_c, str): + cur_c = [{"type": "text", "text": cur_c}] + if isinstance(cur_c, list): + prev_c.extend(cur_c) + merged[-1]["content"] = prev_c + else: + merged.append(msg) + return merged + + # ------------------------------------------------------------------ + # Tool definition conversion + # ------------------------------------------------------------------ + + @staticmethod + def _convert_tools(tools: list[dict[str, Any]] | None) -> list[dict[str, Any]] | None: + if not tools: + return None + result = [] + for tool in tools: + func = tool.get("function", tool) + entry: dict[str, Any] = { + "name": func.get("name", ""), + "input_schema": func.get("parameters", {"type": "object", "properties": {}}), + } + desc = func.get("description") + if desc: + entry["description"] = desc + if "cache_control" in tool: + entry["cache_control"] = tool["cache_control"] + result.append(entry) + return result + + @staticmethod + def _convert_tool_choice( + tool_choice: str | dict[str, Any] | None, + thinking_enabled: bool = False, + ) -> dict[str, Any] | None: + if thinking_enabled: + return {"type": "auto"} + if tool_choice is None or tool_choice == "auto": + return {"type": "auto"} + if tool_choice == "required": + return {"type": "any"} + if tool_choice == "none": + return None + if isinstance(tool_choice, dict): + name = tool_choice.get("function", {}).get("name") + if name: + return {"type": "tool", "name": name} + return {"type": "auto"} + + # ------------------------------------------------------------------ + # Prompt caching + # ------------------------------------------------------------------ + + @staticmethod + def _apply_cache_control( + system: str | list[dict[str, Any]], + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None, + ) -> tuple[str | list[dict[str, Any]], list[dict[str, Any]], list[dict[str, Any]] | None]: + marker = {"type": "ephemeral"} + + if isinstance(system, str) and system: + system = [{"type": "text", "text": system, "cache_control": marker}] + elif isinstance(system, list) and system: + system = list(system) + system[-1] = {**system[-1], "cache_control": marker} + + new_msgs = list(messages) + if len(new_msgs) >= 3: + m = new_msgs[-2] + c = m.get("content") + if isinstance(c, str): + new_msgs[-2] = {**m, "content": [{"type": "text", "text": c, "cache_control": marker}]} + elif isinstance(c, list) and c: + nc = list(c) + nc[-1] = {**nc[-1], "cache_control": marker} + new_msgs[-2] = {**m, "content": nc} + + new_tools = tools + if tools: + new_tools = list(tools) + new_tools[-1] = {**new_tools[-1], "cache_control": marker} + + return system, new_msgs, new_tools + + # ------------------------------------------------------------------ + # Build API kwargs + # ------------------------------------------------------------------ + + def _build_kwargs( + self, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None, + model: str | None, + max_tokens: int, + temperature: float, + reasoning_effort: str | None, + tool_choice: str | dict[str, Any] | None, + supports_caching: bool = True, + ) -> dict[str, Any]: + model_name = self._strip_prefix(model or self.default_model) + system, anthropic_msgs = self._convert_messages(self._sanitize_empty_content(messages)) + anthropic_tools = self._convert_tools(tools) + + if supports_caching: + system, anthropic_msgs, anthropic_tools = self._apply_cache_control( + system, anthropic_msgs, anthropic_tools, + ) + + max_tokens = max(1, max_tokens) + thinking_enabled = bool(reasoning_effort) + + kwargs: dict[str, Any] = { + "model": model_name, + "messages": anthropic_msgs, + "max_tokens": max_tokens, + } + + if system: + kwargs["system"] = system + + if thinking_enabled: + budget_map = {"low": 1024, "medium": 4096, "high": max(8192, max_tokens)} + budget = budget_map.get(reasoning_effort.lower(), 4096) # type: ignore[union-attr] + kwargs["thinking"] = {"type": "enabled", "budget_tokens": budget} + kwargs["max_tokens"] = max(max_tokens, budget + 4096) + kwargs["temperature"] = 1.0 + else: + kwargs["temperature"] = temperature + + if anthropic_tools: + kwargs["tools"] = anthropic_tools + tc = self._convert_tool_choice(tool_choice, thinking_enabled) + if tc: + kwargs["tool_choice"] = tc + + if self.extra_headers: + kwargs["extra_headers"] = self.extra_headers + + return kwargs + + # ------------------------------------------------------------------ + # Response parsing + # ------------------------------------------------------------------ + + @staticmethod + def _parse_response(response: Any) -> LLMResponse: + content_parts: list[str] = [] + tool_calls: list[ToolCallRequest] = [] + thinking_blocks: list[dict[str, Any]] = [] + + for block in response.content: + if block.type == "text": + content_parts.append(block.text) + elif block.type == "tool_use": + tool_calls.append(ToolCallRequest( + id=block.id, + name=block.name, + arguments=block.input if isinstance(block.input, dict) else {}, + )) + elif block.type == "thinking": + thinking_blocks.append({ + "type": "thinking", + "thinking": block.thinking, + "signature": getattr(block, "signature", ""), + }) + + stop_map = {"tool_use": "tool_calls", "end_turn": "stop", "max_tokens": "length"} + finish_reason = stop_map.get(response.stop_reason or "", response.stop_reason or "stop") + + usage: dict[str, int] = {} + if response.usage: + usage = { + "prompt_tokens": response.usage.input_tokens, + "completion_tokens": response.usage.output_tokens, + "total_tokens": response.usage.input_tokens + response.usage.output_tokens, + } + for attr in ("cache_creation_input_tokens", "cache_read_input_tokens"): + val = getattr(response.usage, attr, 0) + if val: + usage[attr] = val + + return LLMResponse( + content="".join(content_parts) or None, + tool_calls=tool_calls, + finish_reason=finish_reason, + usage=usage, + thinking_blocks=thinking_blocks or None, + ) + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + async def chat( + self, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None = None, + model: str | None = None, + max_tokens: int = 4096, + temperature: float = 0.7, + reasoning_effort: str | None = None, + tool_choice: str | dict[str, Any] | None = None, + ) -> LLMResponse: + kwargs = self._build_kwargs( + messages, tools, model, max_tokens, temperature, + reasoning_effort, tool_choice, + ) + try: + response = await self._client.messages.create(**kwargs) + return self._parse_response(response) + except Exception as e: + return LLMResponse(content=f"Error calling LLM: {e}", finish_reason="error") + + async def chat_stream( + self, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None = None, + model: str | None = None, + max_tokens: int = 4096, + temperature: float = 0.7, + reasoning_effort: str | None = None, + tool_choice: str | dict[str, Any] | None = None, + on_content_delta: Callable[[str], Awaitable[None]] | None = None, + ) -> LLMResponse: + kwargs = self._build_kwargs( + messages, tools, model, max_tokens, temperature, + reasoning_effort, tool_choice, + ) + try: + async with self._client.messages.stream(**kwargs) as stream: + if on_content_delta: + async for text in stream.text_stream: + await on_content_delta(text) + response = await stream.get_final_message() + return self._parse_response(response) + except Exception as e: + return LLMResponse(content=f"Error calling LLM: {e}", finish_reason="error") + + def get_default_model(self) -> str: + return self.default_model diff --git a/nanobot/providers/custom_provider.py b/nanobot/providers/custom_provider.py deleted file mode 100644 index a47dae7cd1e..00000000000 --- a/nanobot/providers/custom_provider.py +++ /dev/null @@ -1,152 +0,0 @@ -"""Direct OpenAI-compatible provider — bypasses LiteLLM.""" - -from __future__ import annotations - -import uuid -from collections.abc import Awaitable, Callable -from typing import Any - -import json_repair -from openai import AsyncOpenAI - -from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest - - -class CustomProvider(LLMProvider): - - def __init__( - self, - api_key: str = "no-key", - api_base: str = "http://localhost:8000/v1", - default_model: str = "default", - extra_headers: dict[str, str] | None = None, - ): - super().__init__(api_key, api_base) - self.default_model = default_model - self._client = AsyncOpenAI( - api_key=api_key, - base_url=api_base, - default_headers={ - "x-session-affinity": uuid.uuid4().hex, - **(extra_headers or {}), - }, - ) - - def _build_kwargs( - self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None, - model: str | None, max_tokens: int, temperature: float, - reasoning_effort: str | None, tool_choice: str | dict[str, Any] | None, - ) -> dict[str, Any]: - kwargs: dict[str, Any] = { - "model": model or self.default_model, - "messages": self._sanitize_empty_content(messages), - "max_tokens": max(1, max_tokens), - "temperature": temperature, - } - if reasoning_effort: - kwargs["reasoning_effort"] = reasoning_effort - if tools: - kwargs.update(tools=tools, tool_choice=tool_choice or "auto") - return kwargs - - def _handle_error(self, e: Exception) -> LLMResponse: - body = getattr(e, "doc", None) or getattr(getattr(e, "response", None), "text", None) - msg = f"Error: {body.strip()[:500]}" if body and body.strip() else f"Error: {e}" - return LLMResponse(content=msg, finish_reason="error") - - async def chat(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None, - model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7, - reasoning_effort: str | None = None, - tool_choice: str | dict[str, Any] | None = None) -> LLMResponse: - kwargs = self._build_kwargs(messages, tools, model, max_tokens, temperature, reasoning_effort, tool_choice) - try: - return self._parse(await self._client.chat.completions.create(**kwargs)) - except Exception as e: - return self._handle_error(e) - - async def chat_stream( - self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None, - model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7, - reasoning_effort: str | None = None, - tool_choice: str | dict[str, Any] | None = None, - on_content_delta: Callable[[str], Awaitable[None]] | None = None, - ) -> LLMResponse: - kwargs = self._build_kwargs(messages, tools, model, max_tokens, temperature, reasoning_effort, tool_choice) - kwargs["stream"] = True - try: - stream = await self._client.chat.completions.create(**kwargs) - chunks: list[Any] = [] - async for chunk in stream: - chunks.append(chunk) - if on_content_delta and chunk.choices: - text = getattr(chunk.choices[0].delta, "content", None) - if text: - await on_content_delta(text) - return self._parse_chunks(chunks) - except Exception as e: - return self._handle_error(e) - - def _parse(self, response: Any) -> LLMResponse: - if not response.choices: - return LLMResponse( - content="Error: API returned empty choices.", - finish_reason="error", - ) - choice = response.choices[0] - msg = choice.message - tool_calls = [ - ToolCallRequest( - id=tc.id, name=tc.function.name, - arguments=json_repair.loads(tc.function.arguments) if isinstance(tc.function.arguments, str) else tc.function.arguments, - ) - for tc in (msg.tool_calls or []) - ] - u = response.usage - return LLMResponse( - content=msg.content, tool_calls=tool_calls, - finish_reason=choice.finish_reason or "stop", - usage={"prompt_tokens": u.prompt_tokens, "completion_tokens": u.completion_tokens, "total_tokens": u.total_tokens} if u else {}, - reasoning_content=getattr(msg, "reasoning_content", None) or None, - ) - - def _parse_chunks(self, chunks: list[Any]) -> LLMResponse: - """Reassemble streamed chunks into a single LLMResponse.""" - content_parts: list[str] = [] - tc_bufs: dict[int, dict[str, str]] = {} - finish_reason = "stop" - usage: dict[str, int] = {} - - for chunk in chunks: - if not chunk.choices: - if hasattr(chunk, "usage") and chunk.usage: - u = chunk.usage - usage = {"prompt_tokens": u.prompt_tokens or 0, "completion_tokens": u.completion_tokens or 0, - "total_tokens": u.total_tokens or 0} - continue - choice = chunk.choices[0] - if choice.finish_reason: - finish_reason = choice.finish_reason - delta = choice.delta - if delta and delta.content: - content_parts.append(delta.content) - for tc in (delta.tool_calls or []) if delta else []: - buf = tc_bufs.setdefault(tc.index, {"id": "", "name": "", "arguments": ""}) - if tc.id: - buf["id"] = tc.id - if tc.function and tc.function.name: - buf["name"] = tc.function.name - if tc.function and tc.function.arguments: - buf["arguments"] += tc.function.arguments - - return LLMResponse( - content="".join(content_parts) or None, - tool_calls=[ - ToolCallRequest(id=b["id"], name=b["name"], arguments=json_repair.loads(b["arguments"]) if b["arguments"] else {}) - for b in tc_bufs.values() - ], - finish_reason=finish_reason, - usage=usage, - ) - - def get_default_model(self) -> str: - return self.default_model diff --git a/nanobot/providers/litellm_provider.py b/nanobot/providers/litellm_provider.py deleted file mode 100644 index 9aa0ba68061..00000000000 --- a/nanobot/providers/litellm_provider.py +++ /dev/null @@ -1,413 +0,0 @@ -"""LiteLLM provider implementation for multi-provider support.""" - -import hashlib -import os -import secrets -import string -from collections.abc import Awaitable, Callable -from typing import Any - -import json_repair -import litellm -from litellm import acompletion -from loguru import logger - -from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest -from nanobot.providers.registry import find_by_model, find_gateway - -# Standard chat-completion message keys. -_ALLOWED_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name", "reasoning_content"}) -_ANTHROPIC_EXTRA_KEYS = frozenset({"thinking_blocks"}) -_ALNUM = string.ascii_letters + string.digits - -def _short_tool_id() -> str: - """Generate a 9-char alphanumeric ID compatible with all providers (incl. Mistral).""" - return "".join(secrets.choice(_ALNUM) for _ in range(9)) - - -class LiteLLMProvider(LLMProvider): - """ - LLM provider using LiteLLM for multi-provider support. - - Supports OpenRouter, Anthropic, OpenAI, Gemini, MiniMax, and many other providers through - a unified interface. Provider-specific logic is driven by the registry - (see providers/registry.py) — no if-elif chains needed here. - """ - - def __init__( - self, - api_key: str | None = None, - api_base: str | None = None, - default_model: str = "anthropic/claude-opus-4-5", - extra_headers: dict[str, str] | None = None, - provider_name: str | None = None, - ): - super().__init__(api_key, api_base) - self.default_model = default_model - self.extra_headers = extra_headers or {} - - # Detect gateway / local deployment. - # provider_name (from config key) is the primary signal; - # api_key / api_base are fallback for auto-detection. - self._gateway = find_gateway(provider_name, api_key, api_base) - - # Configure environment variables - if api_key: - self._setup_env(api_key, api_base, default_model) - - if api_base: - litellm.api_base = api_base - - # Disable LiteLLM logging noise - litellm.suppress_debug_info = True - # Drop unsupported parameters for providers (e.g., gpt-5 rejects some params) - litellm.drop_params = True - - self._langsmith_enabled = bool(os.getenv("LANGSMITH_API_KEY")) - - def _setup_env(self, api_key: str, api_base: str | None, model: str) -> None: - """Set environment variables based on detected provider.""" - spec = self._gateway or find_by_model(model) - if not spec: - return - if not spec.env_key: - # OAuth/provider-only specs (for example: openai_codex) - return - - # Gateway/local overrides existing env; standard provider doesn't - if self._gateway: - os.environ[spec.env_key] = api_key - else: - os.environ.setdefault(spec.env_key, api_key) - - # Resolve env_extras placeholders: - # {api_key} → user's API key - # {api_base} → user's api_base, falling back to spec.default_api_base - effective_base = api_base or spec.default_api_base - for env_name, env_val in spec.env_extras: - resolved = env_val.replace("{api_key}", api_key) - resolved = resolved.replace("{api_base}", effective_base) - os.environ.setdefault(env_name, resolved) - - def _resolve_model(self, model: str) -> str: - """Resolve model name by applying provider/gateway prefixes.""" - if self._gateway: - prefix = self._gateway.litellm_prefix - if self._gateway.strip_model_prefix: - model = model.split("/")[-1] - if prefix: - model = f"{prefix}/{model}" - return model - - # Standard mode: auto-prefix for known providers - spec = find_by_model(model) - if spec and spec.litellm_prefix: - model = self._canonicalize_explicit_prefix(model, spec.name, spec.litellm_prefix) - if not any(model.startswith(s) for s in spec.skip_prefixes): - model = f"{spec.litellm_prefix}/{model}" - - return model - - @staticmethod - def _canonicalize_explicit_prefix(model: str, spec_name: str, canonical_prefix: str) -> str: - """Normalize explicit provider prefixes like `github-copilot/...`.""" - if "/" not in model: - return model - prefix, remainder = model.split("/", 1) - if prefix.lower().replace("-", "_") != spec_name: - return model - return f"{canonical_prefix}/{remainder}" - - def _supports_cache_control(self, model: str) -> bool: - """Return True when the provider supports cache_control on content blocks.""" - if self._gateway is not None: - return self._gateway.supports_prompt_caching - spec = find_by_model(model) - return spec is not None and spec.supports_prompt_caching - - def _apply_cache_control( - self, - messages: list[dict[str, Any]], - tools: list[dict[str, Any]] | None, - ) -> tuple[list[dict[str, Any]], list[dict[str, Any]] | None]: - """Return copies of messages and tools with cache_control injected. - - Two breakpoints are placed: - 1. System message — caches the static system prompt - 2. Second-to-last message — caches the conversation history prefix - This maximises cache hits across multi-turn conversations. - """ - cache_marker = {"type": "ephemeral"} - new_messages = list(messages) - - def _mark(msg: dict[str, Any]) -> dict[str, Any]: - content = msg.get("content") - if isinstance(content, str): - return {**msg, "content": [ - {"type": "text", "text": content, "cache_control": cache_marker} - ]} - elif isinstance(content, list) and content: - new_content = list(content) - new_content[-1] = {**new_content[-1], "cache_control": cache_marker} - return {**msg, "content": new_content} - return msg - - # Breakpoint 1: system message - if new_messages and new_messages[0].get("role") == "system": - new_messages[0] = _mark(new_messages[0]) - - # Breakpoint 2: second-to-last message (caches conversation history prefix) - if len(new_messages) >= 3: - new_messages[-2] = _mark(new_messages[-2]) - - new_tools = tools - if tools: - new_tools = list(tools) - new_tools[-1] = {**new_tools[-1], "cache_control": cache_marker} - - return new_messages, new_tools - - def _apply_model_overrides(self, model: str, kwargs: dict[str, Any]) -> None: - """Apply model-specific parameter overrides from the registry.""" - model_lower = model.lower() - spec = find_by_model(model) - if spec: - for pattern, overrides in spec.model_overrides: - if pattern in model_lower: - kwargs.update(overrides) - return - - @staticmethod - def _extra_msg_keys(original_model: str, resolved_model: str) -> frozenset[str]: - """Return provider-specific extra keys to preserve in request messages.""" - spec = find_by_model(original_model) or find_by_model(resolved_model) - if (spec and spec.name == "anthropic") or "claude" in original_model.lower() or resolved_model.startswith("anthropic/"): - return _ANTHROPIC_EXTRA_KEYS - return frozenset() - - @staticmethod - def _normalize_tool_call_id(tool_call_id: Any) -> Any: - """Normalize tool_call_id to a provider-safe 9-char alphanumeric form.""" - if not isinstance(tool_call_id, str): - return tool_call_id - if len(tool_call_id) == 9 and tool_call_id.isalnum(): - return tool_call_id - return hashlib.sha1(tool_call_id.encode()).hexdigest()[:9] - - @staticmethod - def _sanitize_messages(messages: list[dict[str, Any]], extra_keys: frozenset[str] = frozenset()) -> list[dict[str, Any]]: - """Strip non-standard keys and ensure assistant messages have a content key.""" - allowed = _ALLOWED_MSG_KEYS | extra_keys - sanitized = LLMProvider._sanitize_request_messages(messages, allowed) - id_map: dict[str, str] = {} - - def map_id(value: Any) -> Any: - if not isinstance(value, str): - return value - return id_map.setdefault(value, LiteLLMProvider._normalize_tool_call_id(value)) - - for clean in sanitized: - # Keep assistant tool_calls[].id and tool tool_call_id in sync after - # shortening, otherwise strict providers reject the broken linkage. - if isinstance(clean.get("tool_calls"), list): - normalized_tool_calls = [] - for tc in clean["tool_calls"]: - if not isinstance(tc, dict): - normalized_tool_calls.append(tc) - continue - tc_clean = dict(tc) - tc_clean["id"] = map_id(tc_clean.get("id")) - normalized_tool_calls.append(tc_clean) - clean["tool_calls"] = normalized_tool_calls - - if "tool_call_id" in clean and clean["tool_call_id"]: - clean["tool_call_id"] = map_id(clean["tool_call_id"]) - return sanitized - - def _build_chat_kwargs( - self, - messages: list[dict[str, Any]], - tools: list[dict[str, Any]] | None, - model: str | None, - max_tokens: int, - temperature: float, - reasoning_effort: str | None, - tool_choice: str | dict[str, Any] | None, - ) -> tuple[dict[str, Any], str]: - """Build the kwargs dict for ``acompletion``. - - Returns ``(kwargs, original_model)`` so callers can reuse the - original model string for downstream logic. - """ - original_model = model or self.default_model - resolved = self._resolve_model(original_model) - extra_msg_keys = self._extra_msg_keys(original_model, resolved) - - if self._supports_cache_control(original_model): - messages, tools = self._apply_cache_control(messages, tools) - - max_tokens = max(1, max_tokens) - - kwargs: dict[str, Any] = { - "model": resolved, - "messages": self._sanitize_messages( - self._sanitize_empty_content(messages), extra_keys=extra_msg_keys, - ), - "max_tokens": max_tokens, - "temperature": temperature, - } - - if self._gateway: - kwargs.update(self._gateway.litellm_kwargs) - - self._apply_model_overrides(resolved, kwargs) - - if self._langsmith_enabled: - kwargs.setdefault("callbacks", []).append("langsmith") - - if self.api_key: - kwargs["api_key"] = self.api_key - if self.api_base: - kwargs["api_base"] = self.api_base - if self.extra_headers: - kwargs["extra_headers"] = self.extra_headers - - if reasoning_effort: - kwargs["reasoning_effort"] = reasoning_effort - kwargs["drop_params"] = True - - if tools: - kwargs["tools"] = tools - kwargs["tool_choice"] = tool_choice or "auto" - - return kwargs, original_model - - async def chat( - self, - messages: list[dict[str, Any]], - tools: list[dict[str, Any]] | None = None, - model: str | None = None, - max_tokens: int = 4096, - temperature: float = 0.7, - reasoning_effort: str | None = None, - tool_choice: str | dict[str, Any] | None = None, - ) -> LLMResponse: - """Send a chat completion request via LiteLLM.""" - kwargs, _ = self._build_chat_kwargs( - messages, tools, model, max_tokens, temperature, - reasoning_effort, tool_choice, - ) - try: - response = await acompletion(**kwargs) - return self._parse_response(response) - except Exception as e: - return LLMResponse( - content=f"Error calling LLM: {str(e)}", - finish_reason="error", - ) - - async def chat_stream( - self, - messages: list[dict[str, Any]], - tools: list[dict[str, Any]] | None = None, - model: str | None = None, - max_tokens: int = 4096, - temperature: float = 0.7, - reasoning_effort: str | None = None, - tool_choice: str | dict[str, Any] | None = None, - on_content_delta: Callable[[str], Awaitable[None]] | None = None, - ) -> LLMResponse: - """Stream a chat completion via LiteLLM, forwarding text deltas.""" - kwargs, _ = self._build_chat_kwargs( - messages, tools, model, max_tokens, temperature, - reasoning_effort, tool_choice, - ) - kwargs["stream"] = True - - try: - stream = await acompletion(**kwargs) - chunks: list[Any] = [] - async for chunk in stream: - chunks.append(chunk) - if on_content_delta: - delta = chunk.choices[0].delta if chunk.choices else None - text = getattr(delta, "content", None) if delta else None - if text: - await on_content_delta(text) - - full_response = litellm.stream_chunk_builder( - chunks, messages=kwargs["messages"], - ) - return self._parse_response(full_response) - except Exception as e: - return LLMResponse( - content=f"Error calling LLM: {str(e)}", - finish_reason="error", - ) - - def _parse_response(self, response: Any) -> LLMResponse: - """Parse LiteLLM response into our standard format.""" - choice = response.choices[0] - message = choice.message - content = message.content - finish_reason = choice.finish_reason - - # Some providers (e.g. GitHub Copilot) split content and tool_calls - # across multiple choices. Merge them so tool_calls are not lost. - raw_tool_calls = [] - for ch in response.choices: - msg = ch.message - if hasattr(msg, "tool_calls") and msg.tool_calls: - raw_tool_calls.extend(msg.tool_calls) - if ch.finish_reason in ("tool_calls", "stop"): - finish_reason = ch.finish_reason - if not content and msg.content: - content = msg.content - - if len(response.choices) > 1: - logger.debug("LiteLLM response has {} choices, merged {} tool_calls", - len(response.choices), len(raw_tool_calls)) - - tool_calls = [] - for tc in raw_tool_calls: - # Parse arguments from JSON string if needed - args = tc.function.arguments - if isinstance(args, str): - args = json_repair.loads(args) - - provider_specific_fields = getattr(tc, "provider_specific_fields", None) or None - function_provider_specific_fields = ( - getattr(tc.function, "provider_specific_fields", None) or None - ) - - tool_calls.append(ToolCallRequest( - id=_short_tool_id(), - name=tc.function.name, - arguments=args, - provider_specific_fields=provider_specific_fields, - function_provider_specific_fields=function_provider_specific_fields, - )) - - usage = {} - if hasattr(response, "usage") and response.usage: - usage = { - "prompt_tokens": response.usage.prompt_tokens, - "completion_tokens": response.usage.completion_tokens, - "total_tokens": response.usage.total_tokens, - } - - reasoning_content = getattr(message, "reasoning_content", None) or None - thinking_blocks = getattr(message, "thinking_blocks", None) or None - - return LLMResponse( - content=content, - tool_calls=tool_calls, - finish_reason=finish_reason or "stop", - usage=usage, - reasoning_content=reasoning_content, - thinking_blocks=thinking_blocks, - ) - - def get_default_model(self) -> str: - """Get the default model.""" - return self.default_model diff --git a/nanobot/providers/openai_compat_provider.py b/nanobot/providers/openai_compat_provider.py new file mode 100644 index 00000000000..a210bf72d2c --- /dev/null +++ b/nanobot/providers/openai_compat_provider.py @@ -0,0 +1,349 @@ +"""OpenAI-compatible provider for all non-Anthropic LLM APIs.""" + +from __future__ import annotations + +import hashlib +import os +import secrets +import string +import uuid +from collections.abc import Awaitable, Callable +from typing import TYPE_CHECKING, Any + +import json_repair +from openai import AsyncOpenAI + +from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest + +if TYPE_CHECKING: + from nanobot.providers.registry import ProviderSpec + +_ALLOWED_MSG_KEYS = frozenset({ + "role", "content", "tool_calls", "tool_call_id", "name", "reasoning_content", +}) +_ALNUM = string.ascii_letters + string.digits + + +def _short_tool_id() -> str: + """9-char alphanumeric ID compatible with all providers (incl. Mistral).""" + return "".join(secrets.choice(_ALNUM) for _ in range(9)) + + +class OpenAICompatProvider(LLMProvider): + """Unified provider for all OpenAI-compatible APIs. + + Receives a resolved ``ProviderSpec`` from the caller — no internal + registry lookups needed. + """ + + def __init__( + self, + api_key: str | None = None, + api_base: str | None = None, + default_model: str = "gpt-4o", + extra_headers: dict[str, str] | None = None, + spec: ProviderSpec | None = None, + ): + super().__init__(api_key, api_base) + self.default_model = default_model + self.extra_headers = extra_headers or {} + self._spec = spec + + if api_key and spec and spec.env_key: + self._setup_env(api_key, api_base) + + effective_base = api_base or (spec.default_api_base if spec else None) or None + + self._client = AsyncOpenAI( + api_key=api_key or "no-key", + base_url=effective_base, + default_headers={ + "x-session-affinity": uuid.uuid4().hex, + **(extra_headers or {}), + }, + ) + + def _setup_env(self, api_key: str, api_base: str | None) -> None: + """Set environment variables based on provider spec.""" + spec = self._spec + if not spec or not spec.env_key: + return + if spec.is_gateway: + os.environ[spec.env_key] = api_key + else: + os.environ.setdefault(spec.env_key, api_key) + effective_base = api_base or spec.default_api_base + for env_name, env_val in spec.env_extras: + resolved = env_val.replace("{api_key}", api_key).replace("{api_base}", effective_base) + os.environ.setdefault(env_name, resolved) + + @staticmethod + def _apply_cache_control( + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None, + ) -> tuple[list[dict[str, Any]], list[dict[str, Any]] | None]: + """Inject cache_control markers for prompt caching.""" + cache_marker = {"type": "ephemeral"} + new_messages = list(messages) + + def _mark(msg: dict[str, Any]) -> dict[str, Any]: + content = msg.get("content") + if isinstance(content, str): + return {**msg, "content": [ + {"type": "text", "text": content, "cache_control": cache_marker}, + ]} + if isinstance(content, list) and content: + nc = list(content) + nc[-1] = {**nc[-1], "cache_control": cache_marker} + return {**msg, "content": nc} + return msg + + if new_messages and new_messages[0].get("role") == "system": + new_messages[0] = _mark(new_messages[0]) + if len(new_messages) >= 3: + new_messages[-2] = _mark(new_messages[-2]) + + new_tools = tools + if tools: + new_tools = list(tools) + new_tools[-1] = {**new_tools[-1], "cache_control": cache_marker} + return new_messages, new_tools + + @staticmethod + def _normalize_tool_call_id(tool_call_id: Any) -> Any: + """Normalize to a provider-safe 9-char alphanumeric form.""" + if not isinstance(tool_call_id, str): + return tool_call_id + if len(tool_call_id) == 9 and tool_call_id.isalnum(): + return tool_call_id + return hashlib.sha1(tool_call_id.encode()).hexdigest()[:9] + + def _sanitize_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Strip non-standard keys, normalize tool_call IDs.""" + sanitized = LLMProvider._sanitize_request_messages(messages, _ALLOWED_MSG_KEYS) + id_map: dict[str, str] = {} + + def map_id(value: Any) -> Any: + if not isinstance(value, str): + return value + return id_map.setdefault(value, self._normalize_tool_call_id(value)) + + for clean in sanitized: + if isinstance(clean.get("tool_calls"), list): + normalized = [] + for tc in clean["tool_calls"]: + if not isinstance(tc, dict): + normalized.append(tc) + continue + tc_clean = dict(tc) + tc_clean["id"] = map_id(tc_clean.get("id")) + normalized.append(tc_clean) + clean["tool_calls"] = normalized + if "tool_call_id" in clean and clean["tool_call_id"]: + clean["tool_call_id"] = map_id(clean["tool_call_id"]) + return sanitized + + # ------------------------------------------------------------------ + # Build kwargs + # ------------------------------------------------------------------ + + def _build_kwargs( + self, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None, + model: str | None, + max_tokens: int, + temperature: float, + reasoning_effort: str | None, + tool_choice: str | dict[str, Any] | None, + ) -> dict[str, Any]: + model_name = model or self.default_model + spec = self._spec + + if spec and spec.supports_prompt_caching: + messages, tools = self._apply_cache_control(messages, tools) + + if spec and spec.strip_model_prefix: + model_name = model_name.split("/")[-1] + + kwargs: dict[str, Any] = { + "model": model_name, + "messages": self._sanitize_messages(self._sanitize_empty_content(messages)), + "max_tokens": max(1, max_tokens), + "temperature": temperature, + } + + if spec: + model_lower = model_name.lower() + for pattern, overrides in spec.model_overrides: + if pattern in model_lower: + kwargs.update(overrides) + break + + if reasoning_effort: + kwargs["reasoning_effort"] = reasoning_effort + + if tools: + kwargs["tools"] = tools + kwargs["tool_choice"] = tool_choice or "auto" + + return kwargs + + # ------------------------------------------------------------------ + # Response parsing + # ------------------------------------------------------------------ + + def _parse(self, response: Any) -> LLMResponse: + if not response.choices: + return LLMResponse(content="Error: API returned empty choices.", finish_reason="error") + + choice = response.choices[0] + msg = choice.message + content = msg.content + finish_reason = choice.finish_reason + + raw_tool_calls: list[Any] = [] + for ch in response.choices: + m = ch.message + if hasattr(m, "tool_calls") and m.tool_calls: + raw_tool_calls.extend(m.tool_calls) + if ch.finish_reason in ("tool_calls", "stop"): + finish_reason = ch.finish_reason + if not content and m.content: + content = m.content + + tool_calls = [] + for tc in raw_tool_calls: + args = tc.function.arguments + if isinstance(args, str): + args = json_repair.loads(args) + tool_calls.append(ToolCallRequest( + id=_short_tool_id(), + name=tc.function.name, + arguments=args, + )) + + usage: dict[str, int] = {} + if hasattr(response, "usage") and response.usage: + u = response.usage + usage = { + "prompt_tokens": u.prompt_tokens or 0, + "completion_tokens": u.completion_tokens or 0, + "total_tokens": u.total_tokens or 0, + } + + return LLMResponse( + content=content, + tool_calls=tool_calls, + finish_reason=finish_reason or "stop", + usage=usage, + reasoning_content=getattr(msg, "reasoning_content", None) or None, + ) + + @staticmethod + def _parse_chunks(chunks: list[Any]) -> LLMResponse: + content_parts: list[str] = [] + tc_bufs: dict[int, dict[str, str]] = {} + finish_reason = "stop" + usage: dict[str, int] = {} + + for chunk in chunks: + if not chunk.choices: + if hasattr(chunk, "usage") and chunk.usage: + u = chunk.usage + usage = { + "prompt_tokens": u.prompt_tokens or 0, + "completion_tokens": u.completion_tokens or 0, + "total_tokens": u.total_tokens or 0, + } + continue + choice = chunk.choices[0] + if choice.finish_reason: + finish_reason = choice.finish_reason + delta = choice.delta + if delta and delta.content: + content_parts.append(delta.content) + for tc in (delta.tool_calls or []) if delta else []: + buf = tc_bufs.setdefault(tc.index, {"id": "", "name": "", "arguments": ""}) + if tc.id: + buf["id"] = tc.id + if tc.function and tc.function.name: + buf["name"] = tc.function.name + if tc.function and tc.function.arguments: + buf["arguments"] += tc.function.arguments + + return LLMResponse( + content="".join(content_parts) or None, + tool_calls=[ + ToolCallRequest( + id=b["id"] or _short_tool_id(), + name=b["name"], + arguments=json_repair.loads(b["arguments"]) if b["arguments"] else {}, + ) + for b in tc_bufs.values() + ], + finish_reason=finish_reason, + usage=usage, + ) + + @staticmethod + def _handle_error(e: Exception) -> LLMResponse: + body = getattr(e, "doc", None) or getattr(getattr(e, "response", None), "text", None) + msg = f"Error: {body.strip()[:500]}" if body and body.strip() else f"Error calling LLM: {e}" + return LLMResponse(content=msg, finish_reason="error") + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + async def chat( + self, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None = None, + model: str | None = None, + max_tokens: int = 4096, + temperature: float = 0.7, + reasoning_effort: str | None = None, + tool_choice: str | dict[str, Any] | None = None, + ) -> LLMResponse: + kwargs = self._build_kwargs( + messages, tools, model, max_tokens, temperature, + reasoning_effort, tool_choice, + ) + try: + return self._parse(await self._client.chat.completions.create(**kwargs)) + except Exception as e: + return self._handle_error(e) + + async def chat_stream( + self, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None = None, + model: str | None = None, + max_tokens: int = 4096, + temperature: float = 0.7, + reasoning_effort: str | None = None, + tool_choice: str | dict[str, Any] | None = None, + on_content_delta: Callable[[str], Awaitable[None]] | None = None, + ) -> LLMResponse: + kwargs = self._build_kwargs( + messages, tools, model, max_tokens, temperature, + reasoning_effort, tool_choice, + ) + kwargs["stream"] = True + kwargs["stream_options"] = {"include_usage": True} + try: + stream = await self._client.chat.completions.create(**kwargs) + chunks: list[Any] = [] + async for chunk in stream: + chunks.append(chunk) + if on_content_delta and chunk.choices: + text = getattr(chunk.choices[0].delta, "content", None) + if text: + await on_content_delta(text) + return self._parse_chunks(chunks) + except Exception as e: + return self._handle_error(e) + + def get_default_model(self) -> str: + return self.default_model diff --git a/nanobot/providers/registry.py b/nanobot/providers/registry.py index 10e0fec9db8..206b0b5042f 100644 --- a/nanobot/providers/registry.py +++ b/nanobot/providers/registry.py @@ -4,7 +4,7 @@ Adding a new provider: 1. Add a ProviderSpec to PROVIDERS below. 2. Add a field to ProvidersConfig in config/schema.py. - Done. Env vars, prefixing, config matching, status display all derive from here. + Done. Env vars, config matching, status display all derive from here. Order matters — it controls match priority and fallback. Gateways first. Every entry writes out all fields so you can copy-paste as a template. @@ -12,7 +12,7 @@ from __future__ import annotations -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import Any from pydantic.alias_generators import to_snake @@ -30,12 +30,12 @@ class ProviderSpec: # identity name: str # config field name, e.g. "dashscope" keywords: tuple[str, ...] # model-name keywords for matching (lowercase) - env_key: str # LiteLLM env var, e.g. "DASHSCOPE_API_KEY" + env_key: str # env var for API key, e.g. "DASHSCOPE_API_KEY" display_name: str = "" # shown in `nanobot status` - # model prefixing - litellm_prefix: str = "" # "dashscope" → model becomes "dashscope/{model}" - skip_prefixes: tuple[str, ...] = () # don't prefix if model already starts with these + # which provider implementation to use + # "openai_compat" | "anthropic" | "azure_openai" | "openai_codex" + backend: str = "openai_compat" # extra env vars, e.g. (("ZHIPUAI_API_KEY", "{api_key}"),) env_extras: tuple[tuple[str, str], ...] = () @@ -45,19 +45,18 @@ class ProviderSpec: is_local: bool = False # local deployment (vLLM, Ollama) detect_by_key_prefix: str = "" # match api_key prefix, e.g. "sk-or-" detect_by_base_keyword: str = "" # match substring in api_base URL - default_api_base: str = "" # fallback base URL + default_api_base: str = "" # OpenAI-compatible base URL for this provider # gateway behavior - strip_model_prefix: bool = False # strip "provider/" before re-prefixing - litellm_kwargs: dict[str, Any] = field(default_factory=dict) # extra kwargs passed to LiteLLM + strip_model_prefix: bool = False # strip "provider/" before sending to gateway # per-model param overrides, e.g. (("kimi-k2.5", {"temperature": 1.0}),) model_overrides: tuple[tuple[str, dict[str, Any]], ...] = () # OAuth-based providers (e.g., OpenAI Codex) don't use API keys - is_oauth: bool = False # if True, uses OAuth flow instead of API key + is_oauth: bool = False - # Direct providers bypass LiteLLM entirely (e.g., CustomProvider) + # Direct providers skip API-key validation (user supplies everything) is_direct: bool = False # Provider supports cache_control on content blocks (e.g. Anthropic prompt caching) @@ -73,13 +72,13 @@ def label(self) -> str: # --------------------------------------------------------------------------- PROVIDERS: tuple[ProviderSpec, ...] = ( - # === Custom (direct OpenAI-compatible endpoint, bypasses LiteLLM) ====== + # === Custom (direct OpenAI-compatible endpoint) ======================== ProviderSpec( name="custom", keywords=(), env_key="", display_name="Custom", - litellm_prefix="", + backend="openai_compat", is_direct=True, ), @@ -89,7 +88,7 @@ def label(self) -> str: keywords=("azure", "azure-openai"), env_key="", display_name="Azure OpenAI", - litellm_prefix="", + backend="azure_openai", is_direct=True, ), # === Gateways (detected by api_key / api_base, not model name) ========= @@ -100,36 +99,26 @@ def label(self) -> str: keywords=("openrouter",), env_key="OPENROUTER_API_KEY", display_name="OpenRouter", - litellm_prefix="openrouter", # anthropic/claude-3 → openrouter/anthropic/claude-3 - skip_prefixes=(), - env_extras=(), + backend="openai_compat", is_gateway=True, - is_local=False, detect_by_key_prefix="sk-or-", detect_by_base_keyword="openrouter", default_api_base="https://openrouter.ai/api/v1", - strip_model_prefix=False, - model_overrides=(), supports_prompt_caching=True, ), # AiHubMix: global gateway, OpenAI-compatible interface. - # strip_model_prefix=True: it doesn't understand "anthropic/claude-3", - # so we strip to bare "claude-3" then re-prefix as "openai/claude-3". + # strip_model_prefix=True: doesn't understand "anthropic/claude-3", + # strips to bare "claude-3". ProviderSpec( name="aihubmix", keywords=("aihubmix",), - env_key="OPENAI_API_KEY", # OpenAI-compatible + env_key="OPENAI_API_KEY", display_name="AiHubMix", - litellm_prefix="openai", # → openai/{model} - skip_prefixes=(), - env_extras=(), + backend="openai_compat", is_gateway=True, - is_local=False, - detect_by_key_prefix="", detect_by_base_keyword="aihubmix", default_api_base="https://aihubmix.com/v1", - strip_model_prefix=True, # anthropic/claude-3 → claude-3 → openai/claude-3 - model_overrides=(), + strip_model_prefix=True, ), # SiliconFlow (硅基流动): OpenAI-compatible gateway, model names keep org prefix ProviderSpec( @@ -137,16 +126,10 @@ def label(self) -> str: keywords=("siliconflow",), env_key="OPENAI_API_KEY", display_name="SiliconFlow", - litellm_prefix="openai", - skip_prefixes=(), - env_extras=(), + backend="openai_compat", is_gateway=True, - is_local=False, - detect_by_key_prefix="", detect_by_base_keyword="siliconflow", default_api_base="https://api.siliconflow.cn/v1", - strip_model_prefix=False, - model_overrides=(), ), # VolcEngine (火山引擎): OpenAI-compatible gateway, pay-per-use models @@ -155,16 +138,10 @@ def label(self) -> str: keywords=("volcengine", "volces", "ark"), env_key="OPENAI_API_KEY", display_name="VolcEngine", - litellm_prefix="volcengine", - skip_prefixes=(), - env_extras=(), + backend="openai_compat", is_gateway=True, - is_local=False, - detect_by_key_prefix="", detect_by_base_keyword="volces", default_api_base="https://ark.cn-beijing.volces.com/api/v3", - strip_model_prefix=False, - model_overrides=(), ), # VolcEngine Coding Plan (火山引擎 Coding Plan): same key as volcengine @@ -173,16 +150,10 @@ def label(self) -> str: keywords=("volcengine-plan",), env_key="OPENAI_API_KEY", display_name="VolcEngine Coding Plan", - litellm_prefix="volcengine", - skip_prefixes=(), - env_extras=(), + backend="openai_compat", is_gateway=True, - is_local=False, - detect_by_key_prefix="", - detect_by_base_keyword="", default_api_base="https://ark.cn-beijing.volces.com/api/coding/v3", strip_model_prefix=True, - model_overrides=(), ), # BytePlus: VolcEngine international, pay-per-use models @@ -191,16 +162,11 @@ def label(self) -> str: keywords=("byteplus",), env_key="OPENAI_API_KEY", display_name="BytePlus", - litellm_prefix="volcengine", - skip_prefixes=(), - env_extras=(), + backend="openai_compat", is_gateway=True, - is_local=False, - detect_by_key_prefix="", detect_by_base_keyword="bytepluses", default_api_base="https://ark.ap-southeast.bytepluses.com/api/v3", strip_model_prefix=True, - model_overrides=(), ), # BytePlus Coding Plan: same key as byteplus @@ -209,250 +175,137 @@ def label(self) -> str: keywords=("byteplus-plan",), env_key="OPENAI_API_KEY", display_name="BytePlus Coding Plan", - litellm_prefix="volcengine", - skip_prefixes=(), - env_extras=(), + backend="openai_compat", is_gateway=True, - is_local=False, - detect_by_key_prefix="", - detect_by_base_keyword="", default_api_base="https://ark.ap-southeast.bytepluses.com/api/coding/v3", strip_model_prefix=True, - model_overrides=(), ), # === Standard providers (matched by model-name keywords) =============== - # Anthropic: LiteLLM recognizes "claude-*" natively, no prefix needed. + # Anthropic: native Anthropic SDK ProviderSpec( name="anthropic", keywords=("anthropic", "claude"), env_key="ANTHROPIC_API_KEY", display_name="Anthropic", - litellm_prefix="", - skip_prefixes=(), - env_extras=(), - is_gateway=False, - is_local=False, - detect_by_key_prefix="", - detect_by_base_keyword="", - default_api_base="", - strip_model_prefix=False, - model_overrides=(), + backend="anthropic", supports_prompt_caching=True, ), - # OpenAI: LiteLLM recognizes "gpt-*" natively, no prefix needed. + # OpenAI: SDK default base URL (no override needed) ProviderSpec( name="openai", keywords=("openai", "gpt"), env_key="OPENAI_API_KEY", display_name="OpenAI", - litellm_prefix="", - skip_prefixes=(), - env_extras=(), - is_gateway=False, - is_local=False, - detect_by_key_prefix="", - detect_by_base_keyword="", - default_api_base="", - strip_model_prefix=False, - model_overrides=(), + backend="openai_compat", ), - # OpenAI Codex: uses OAuth, not API key. + # OpenAI Codex: OAuth-based, dedicated provider ProviderSpec( name="openai_codex", keywords=("openai-codex",), - env_key="", # OAuth-based, no API key + env_key="", display_name="OpenAI Codex", - litellm_prefix="", # Not routed through LiteLLM - skip_prefixes=(), - env_extras=(), - is_gateway=False, - is_local=False, - detect_by_key_prefix="", + backend="openai_codex", detect_by_base_keyword="codex", default_api_base="https://chatgpt.com/backend-api", - strip_model_prefix=False, - model_overrides=(), - is_oauth=True, # OAuth-based authentication + is_oauth=True, ), - # Github Copilot: uses OAuth, not API key. + # GitHub Copilot: OAuth-based ProviderSpec( name="github_copilot", keywords=("github_copilot", "copilot"), - env_key="", # OAuth-based, no API key + env_key="", display_name="Github Copilot", - litellm_prefix="github_copilot", # github_copilot/model → github_copilot/model - skip_prefixes=("github_copilot/",), - env_extras=(), - is_gateway=False, - is_local=False, - detect_by_key_prefix="", - detect_by_base_keyword="", - default_api_base="", - strip_model_prefix=False, - model_overrides=(), - is_oauth=True, # OAuth-based authentication + backend="openai_compat", + default_api_base="https://api.githubcopilot.com", + is_oauth=True, ), - # DeepSeek: needs "deepseek/" prefix for LiteLLM routing. + # DeepSeek: OpenAI-compatible at api.deepseek.com ProviderSpec( name="deepseek", keywords=("deepseek",), env_key="DEEPSEEK_API_KEY", display_name="DeepSeek", - litellm_prefix="deepseek", # deepseek-chat → deepseek/deepseek-chat - skip_prefixes=("deepseek/",), # avoid double-prefix - env_extras=(), - is_gateway=False, - is_local=False, - detect_by_key_prefix="", - detect_by_base_keyword="", - default_api_base="", - strip_model_prefix=False, - model_overrides=(), + backend="openai_compat", + default_api_base="https://api.deepseek.com", ), - # Gemini: needs "gemini/" prefix for LiteLLM. + # Gemini: Google's OpenAI-compatible endpoint ProviderSpec( name="gemini", keywords=("gemini",), env_key="GEMINI_API_KEY", display_name="Gemini", - litellm_prefix="gemini", # gemini-pro → gemini/gemini-pro - skip_prefixes=("gemini/",), # avoid double-prefix - env_extras=(), - is_gateway=False, - is_local=False, - detect_by_key_prefix="", - detect_by_base_keyword="", - default_api_base="", - strip_model_prefix=False, - model_overrides=(), + backend="openai_compat", + default_api_base="https://generativelanguage.googleapis.com/v1beta/openai/", ), - # Zhipu: LiteLLM uses "zai/" prefix. - # Also mirrors key to ZHIPUAI_API_KEY (some LiteLLM paths check that). - # skip_prefixes: don't add "zai/" when already routed via gateway. + # Zhipu (智谱): OpenAI-compatible at open.bigmodel.cn ProviderSpec( name="zhipu", keywords=("zhipu", "glm", "zai"), env_key="ZAI_API_KEY", display_name="Zhipu AI", - litellm_prefix="zai", # glm-4 → zai/glm-4 - skip_prefixes=("zhipu/", "zai/", "openrouter/", "hosted_vllm/"), + backend="openai_compat", env_extras=(("ZHIPUAI_API_KEY", "{api_key}"),), - is_gateway=False, - is_local=False, - detect_by_key_prefix="", - detect_by_base_keyword="", - default_api_base="", - strip_model_prefix=False, - model_overrides=(), + default_api_base="https://open.bigmodel.cn/api/paas/v4", ), - # DashScope: Qwen models, needs "dashscope/" prefix. + # DashScope (通义): Qwen models, OpenAI-compatible endpoint ProviderSpec( name="dashscope", keywords=("qwen", "dashscope"), env_key="DASHSCOPE_API_KEY", display_name="DashScope", - litellm_prefix="dashscope", # qwen-max → dashscope/qwen-max - skip_prefixes=("dashscope/", "openrouter/"), - env_extras=(), - is_gateway=False, - is_local=False, - detect_by_key_prefix="", - detect_by_base_keyword="", - default_api_base="", - strip_model_prefix=False, - model_overrides=(), + backend="openai_compat", + default_api_base="https://dashscope.aliyuncs.com/compatible-mode/v1", ), - # Moonshot: Kimi models, needs "moonshot/" prefix. - # LiteLLM requires MOONSHOT_API_BASE env var to find the endpoint. - # Kimi K2.5 API enforces temperature >= 1.0. + # Moonshot (月之暗面): Kimi models. K2.5 enforces temperature >= 1.0. ProviderSpec( name="moonshot", keywords=("moonshot", "kimi"), env_key="MOONSHOT_API_KEY", display_name="Moonshot", - litellm_prefix="moonshot", # kimi-k2.5 → moonshot/kimi-k2.5 - skip_prefixes=("moonshot/", "openrouter/"), - env_extras=(("MOONSHOT_API_BASE", "{api_base}"),), - is_gateway=False, - is_local=False, - detect_by_key_prefix="", - detect_by_base_keyword="", - default_api_base="https://api.moonshot.ai/v1", # intl; use api.moonshot.cn for China - strip_model_prefix=False, + backend="openai_compat", + default_api_base="https://api.moonshot.ai/v1", model_overrides=(("kimi-k2.5", {"temperature": 1.0}),), ), - # MiniMax: needs "minimax/" prefix for LiteLLM routing. - # Uses OpenAI-compatible API at api.minimax.io/v1. + # MiniMax: OpenAI-compatible API ProviderSpec( name="minimax", keywords=("minimax",), env_key="MINIMAX_API_KEY", display_name="MiniMax", - litellm_prefix="minimax", # MiniMax-M2.1 → minimax/MiniMax-M2.1 - skip_prefixes=("minimax/", "openrouter/"), - env_extras=(), - is_gateway=False, - is_local=False, - detect_by_key_prefix="", - detect_by_base_keyword="", + backend="openai_compat", default_api_base="https://api.minimax.io/v1", - strip_model_prefix=False, - model_overrides=(), ), - # Mistral AI: OpenAI-compatible API at api.mistral.ai/v1. + # Mistral AI: OpenAI-compatible API ProviderSpec( name="mistral", keywords=("mistral",), env_key="MISTRAL_API_KEY", display_name="Mistral", - litellm_prefix="mistral", # mistral-large-latest → mistral/mistral-large-latest - skip_prefixes=("mistral/",), # avoid double-prefix - env_extras=(), - is_gateway=False, - is_local=False, - detect_by_key_prefix="", - detect_by_base_keyword="", + backend="openai_compat", default_api_base="https://api.mistral.ai/v1", - strip_model_prefix=False, - model_overrides=(), ), # === Local deployment (matched by config key, NOT by api_base) ========= - # vLLM / any OpenAI-compatible local server. - # Detected when config key is "vllm" (provider_name="vllm"). + # vLLM / any OpenAI-compatible local server ProviderSpec( name="vllm", keywords=("vllm",), env_key="HOSTED_VLLM_API_KEY", display_name="vLLM/Local", - litellm_prefix="hosted_vllm", # Llama-3-8B → hosted_vllm/Llama-3-8B - skip_prefixes=(), - env_extras=(), - is_gateway=False, + backend="openai_compat", is_local=True, - detect_by_key_prefix="", - detect_by_base_keyword="", - default_api_base="", # user must provide in config - strip_model_prefix=False, - model_overrides=(), ), - # === Ollama (local, OpenAI-compatible) =================================== + # Ollama (local, OpenAI-compatible) ProviderSpec( name="ollama", keywords=("ollama", "nemotron"), env_key="OLLAMA_API_KEY", display_name="Ollama", - litellm_prefix="ollama_chat", # model → ollama_chat/model - skip_prefixes=("ollama/", "ollama_chat/"), - env_extras=(), - is_gateway=False, + backend="openai_compat", is_local=True, - detect_by_key_prefix="", detect_by_base_keyword="11434", - default_api_base="http://localhost:11434", - strip_model_prefix=False, - model_overrides=(), + default_api_base="http://localhost:11434/v1", ), # === OpenVINO Model Server (direct, local, OpenAI-compatible at /v3) === ProviderSpec( @@ -460,29 +313,20 @@ def label(self) -> str: keywords=("openvino", "ovms"), env_key="", display_name="OpenVINO Model Server", - litellm_prefix="", + backend="openai_compat", is_direct=True, is_local=True, default_api_base="http://localhost:8000/v3", ), # === Auxiliary (not a primary LLM provider) ============================ - # Groq: mainly used for Whisper voice transcription, also usable for LLM. - # Needs "groq/" prefix for LiteLLM routing. Placed last — it rarely wins fallback. + # Groq: mainly used for Whisper voice transcription, also usable for LLM ProviderSpec( name="groq", keywords=("groq",), env_key="GROQ_API_KEY", display_name="Groq", - litellm_prefix="groq", # llama3-8b-8192 → groq/llama3-8b-8192 - skip_prefixes=("groq/",), # avoid double-prefix - env_extras=(), - is_gateway=False, - is_local=False, - detect_by_key_prefix="", - detect_by_base_keyword="", - default_api_base="", - strip_model_prefix=False, - model_overrides=(), + backend="openai_compat", + default_api_base="https://api.groq.com/openai/v1", ), ) @@ -492,59 +336,6 @@ def label(self) -> str: # --------------------------------------------------------------------------- -def find_by_model(model: str) -> ProviderSpec | None: - """Match a standard provider by model-name keyword (case-insensitive). - Skips gateways/local — those are matched by api_key/api_base instead.""" - model_lower = model.lower() - model_normalized = model_lower.replace("-", "_") - model_prefix = model_lower.split("/", 1)[0] if "/" in model_lower else "" - normalized_prefix = model_prefix.replace("-", "_") - std_specs = [s for s in PROVIDERS if not s.is_gateway and not s.is_local] - - # Prefer explicit provider prefix — prevents `github-copilot/...codex` matching openai_codex. - for spec in std_specs: - if model_prefix and normalized_prefix == spec.name: - return spec - - for spec in std_specs: - if any( - kw in model_lower or kw.replace("-", "_") in model_normalized for kw in spec.keywords - ): - return spec - return None - - -def find_gateway( - provider_name: str | None = None, - api_key: str | None = None, - api_base: str | None = None, -) -> ProviderSpec | None: - """Detect gateway/local provider. - - Priority: - 1. provider_name — if it maps to a gateway/local spec, use it directly. - 2. api_key prefix — e.g. "sk-or-" → OpenRouter. - 3. api_base keyword — e.g. "aihubmix" in URL → AiHubMix. - - A standard provider with a custom api_base (e.g. DeepSeek behind a proxy) - will NOT be mistaken for vLLM — the old fallback is gone. - """ - # 1. Direct match by config key - if provider_name: - spec = find_by_name(provider_name) - if spec and (spec.is_gateway or spec.is_local): - return spec - - # 2. Auto-detect by api_key prefix / api_base keyword - for spec in PROVIDERS: - if spec.detect_by_key_prefix and api_key and api_key.startswith(spec.detect_by_key_prefix): - return spec - if spec.detect_by_base_keyword and api_base and spec.detect_by_base_keyword in api_base: - return spec - - return None - - def find_by_name(name: str) -> ProviderSpec | None: """Find a provider spec by config field name, e.g. "dashscope".""" normalized = to_snake(name.replace("-", "_")) diff --git a/pyproject.toml b/pyproject.toml index 246ca3074cf..aca72777da6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,7 @@ classifiers = [ dependencies = [ "typer>=0.20.0,<1.0.0", - "litellm>=1.82.1,<=1.82.6", + "anthropic>=0.45.0,<1.0.0", "pydantic>=2.12.0,<3.0.0", "pydantic-settings>=2.12.0,<3.0.0", "websockets>=16.0,<17.0", diff --git a/tests/agent/test_gemini_thought_signature.py b/tests/agent/test_gemini_thought_signature.py index bc4132c37f6..35739602abb 100644 --- a/tests/agent/test_gemini_thought_signature.py +++ b/tests/agent/test_gemini_thought_signature.py @@ -1,40 +1,6 @@ from types import SimpleNamespace from nanobot.providers.base import ToolCallRequest -from nanobot.providers.litellm_provider import LiteLLMProvider - - -def test_litellm_parse_response_preserves_tool_call_provider_fields() -> None: - provider = LiteLLMProvider(default_model="gemini/gemini-3-flash") - - response = SimpleNamespace( - choices=[ - SimpleNamespace( - finish_reason="tool_calls", - message=SimpleNamespace( - content=None, - tool_calls=[ - SimpleNamespace( - id="call_123", - function=SimpleNamespace( - name="read_file", - arguments='{"path":"todo.md"}', - provider_specific_fields={"inner": "value"}, - ), - provider_specific_fields={"thought_signature": "signed-token"}, - ) - ], - ), - ) - ], - usage=None, - ) - - parsed = provider._parse_response(response) - - assert len(parsed.tool_calls) == 1 - assert parsed.tool_calls[0].provider_specific_fields == {"thought_signature": "signed-token"} - assert parsed.tool_calls[0].function_provider_specific_fields == {"inner": "value"} def test_tool_call_request_serializes_provider_fields() -> None: diff --git a/tests/agent/test_memory_consolidation_types.py b/tests/agent/test_memory_consolidation_types.py index d63cc90475c..203e39a90f4 100644 --- a/tests/agent/test_memory_consolidation_types.py +++ b/tests/agent/test_memory_consolidation_types.py @@ -380,7 +380,7 @@ async def test_tool_choice_fallback_on_unsupported_error(self, tmp_path: Path) - """Forced tool_choice rejected by provider -> retry with auto and succeed.""" store = MemoryStore(tmp_path) error_resp = LLMResponse( - content="Error calling LLM: litellm.BadRequestError: " + content="Error calling LLM: BadRequestError: " "The tool_choice parameter does not support being set to required or object", finish_reason="error", tool_calls=[], diff --git a/tests/cli/test_commands.py b/tests/cli/test_commands.py index 4e79fc71732..a8fcc4aa0cf 100644 --- a/tests/cli/test_commands.py +++ b/tests/cli/test_commands.py @@ -9,9 +9,8 @@ from nanobot.bus.events import OutboundMessage from nanobot.cli.commands import _make_provider, app from nanobot.config.schema import Config -from nanobot.providers.litellm_provider import LiteLLMProvider from nanobot.providers.openai_codex_provider import _strip_model_prefix -from nanobot.providers.registry import find_by_model, find_by_name +from nanobot.providers.registry import find_by_name runner = CliRunner() @@ -228,7 +227,7 @@ def test_config_matches_explicit_ollama_prefix_without_api_key(): config.agents.defaults.model = "ollama/llama3.2" assert config.get_provider_name() == "ollama" - assert config.get_api_base() == "http://localhost:11434" + assert config.get_api_base() == "http://localhost:11434/v1" def test_config_explicit_ollama_provider_uses_default_localhost_api_base(): @@ -237,7 +236,7 @@ def test_config_explicit_ollama_provider_uses_default_localhost_api_base(): config.agents.defaults.model = "llama3.2" assert config.get_provider_name() == "ollama" - assert config.get_api_base() == "http://localhost:11434" + assert config.get_api_base() == "http://localhost:11434/v1" def test_config_accepts_camel_case_explicit_provider_name_for_coding_plan(): @@ -272,12 +271,12 @@ def test_config_auto_detects_ollama_from_local_api_base(): config = Config.model_validate( { "agents": {"defaults": {"provider": "auto", "model": "llama3.2"}}, - "providers": {"ollama": {"apiBase": "http://localhost:11434"}}, + "providers": {"ollama": {"apiBase": "http://localhost:11434/v1"}}, } ) assert config.get_provider_name() == "ollama" - assert config.get_api_base() == "http://localhost:11434" + assert config.get_api_base() == "http://localhost:11434/v1" def test_config_prefers_ollama_over_vllm_when_both_local_providers_configured(): @@ -286,13 +285,13 @@ def test_config_prefers_ollama_over_vllm_when_both_local_providers_configured(): "agents": {"defaults": {"provider": "auto", "model": "llama3.2"}}, "providers": { "vllm": {"apiBase": "http://localhost:8000"}, - "ollama": {"apiBase": "http://localhost:11434"}, + "ollama": {"apiBase": "http://localhost:11434/v1"}, }, } ) assert config.get_provider_name() == "ollama" - assert config.get_api_base() == "http://localhost:11434" + assert config.get_api_base() == "http://localhost:11434/v1" def test_config_falls_back_to_vllm_when_ollama_not_configured(): @@ -309,19 +308,13 @@ def test_config_falls_back_to_vllm_when_ollama_not_configured(): assert config.get_api_base() == "http://localhost:8000" -def test_find_by_model_prefers_explicit_prefix_over_generic_codex_keyword(): - spec = find_by_model("github-copilot/gpt-5.3-codex") +def test_openai_compat_provider_passes_model_through(): + from nanobot.providers.openai_compat_provider import OpenAICompatProvider - assert spec is not None - assert spec.name == "github_copilot" + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): + provider = OpenAICompatProvider(default_model="github-copilot/gpt-5.3-codex") - -def test_litellm_provider_canonicalizes_github_copilot_hyphen_prefix(): - provider = LiteLLMProvider(default_model="github-copilot/gpt-5.3-codex") - - resolved = provider._resolve_model("github-copilot/gpt-5.3-codex") - - assert resolved == "github_copilot/gpt-5.3-codex" + assert provider.get_default_model() == "github-copilot/gpt-5.3-codex" def test_openai_codex_strip_prefix_supports_hyphen_and_underscore(): @@ -346,7 +339,7 @@ def test_make_provider_passes_extra_headers_to_custom_provider(): } ) - with patch("nanobot.providers.custom_provider.AsyncOpenAI") as mock_async_openai: + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as mock_async_openai: _make_provider(config) kwargs = mock_async_openai.call_args.kwargs diff --git a/tests/providers/test_custom_provider.py b/tests/providers/test_custom_provider.py index 463affedc39..bb46b887a2f 100644 --- a/tests/providers/test_custom_provider.py +++ b/tests/providers/test_custom_provider.py @@ -1,10 +1,14 @@ +"""Tests for OpenAICompatProvider handling custom/direct endpoints.""" + from types import SimpleNamespace +from unittest.mock import patch -from nanobot.providers.custom_provider import CustomProvider +from nanobot.providers.openai_compat_provider import OpenAICompatProvider def test_custom_provider_parse_handles_empty_choices() -> None: - provider = CustomProvider() + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): + provider = OpenAICompatProvider() response = SimpleNamespace(choices=[]) result = provider._parse(response) diff --git a/tests/providers/test_litellm_kwargs.py b/tests/providers/test_litellm_kwargs.py index 437f8a55562..c55857b3b78 100644 --- a/tests/providers/test_litellm_kwargs.py +++ b/tests/providers/test_litellm_kwargs.py @@ -1,161 +1,122 @@ -"""Regression tests for PR #2026 — litellm_kwargs injection from ProviderSpec. +"""Tests for OpenAICompatProvider spec-driven behavior. Validates that: -- OpenRouter uses litellm_prefix (NOT custom_llm_provider) to avoid LiteLLM double-prefixing. -- The litellm_kwargs mechanism works correctly for providers that declare it. -- Non-gateway providers are unaffected. +- OpenRouter (no strip) keeps model names intact. +- AiHubMix (strip_model_prefix=True) strips provider prefixes. +- Standard providers pass model names through as-is. """ from __future__ import annotations from types import SimpleNamespace -from typing import Any from unittest.mock import AsyncMock, patch import pytest -from nanobot.providers.litellm_provider import LiteLLMProvider +from nanobot.providers.openai_compat_provider import OpenAICompatProvider from nanobot.providers.registry import find_by_name -def _fake_response(content: str = "ok") -> SimpleNamespace: - """Build a minimal acompletion-shaped response object.""" +def _fake_chat_response(content: str = "ok") -> SimpleNamespace: + """Build a minimal OpenAI chat completion response.""" message = SimpleNamespace( content=content, tool_calls=None, reasoning_content=None, - thinking_blocks=None, ) choice = SimpleNamespace(message=message, finish_reason="stop") usage = SimpleNamespace(prompt_tokens=10, completion_tokens=5, total_tokens=15) return SimpleNamespace(choices=[choice], usage=usage) -def test_openrouter_spec_uses_prefix_not_custom_llm_provider() -> None: - """OpenRouter must rely on litellm_prefix, not custom_llm_provider kwarg. - - LiteLLM internally adds a provider/ prefix when custom_llm_provider is set, - which double-prefixes models (openrouter/anthropic/model) and breaks the API. - """ +def test_openrouter_spec_is_gateway() -> None: spec = find_by_name("openrouter") assert spec is not None - assert spec.litellm_prefix == "openrouter" - assert "custom_llm_provider" not in spec.litellm_kwargs, ( - "custom_llm_provider causes LiteLLM to double-prefix the model name" - ) + assert spec.is_gateway is True + assert spec.default_api_base == "https://openrouter.ai/api/v1" @pytest.mark.asyncio -async def test_openrouter_prefixes_model_correctly() -> None: - """OpenRouter should prefix model as openrouter/vendor/model for LiteLLM routing.""" - mock_acompletion = AsyncMock(return_value=_fake_response()) +async def test_openrouter_keeps_model_name_intact() -> None: + """OpenRouter gateway keeps the full model name (gateway does its own routing).""" + mock_create = AsyncMock(return_value=_fake_chat_response()) + spec = find_by_name("openrouter") + + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient: + client_instance = MockClient.return_value + client_instance.chat.completions.create = mock_create - with patch("nanobot.providers.litellm_provider.acompletion", mock_acompletion): - provider = LiteLLMProvider( + provider = OpenAICompatProvider( api_key="sk-or-test-key", api_base="https://openrouter.ai/api/v1", default_model="anthropic/claude-sonnet-4-5", - provider_name="openrouter", + spec=spec, ) await provider.chat( messages=[{"role": "user", "content": "hello"}], model="anthropic/claude-sonnet-4-5", ) - call_kwargs = mock_acompletion.call_args.kwargs - assert call_kwargs["model"] == "openrouter/anthropic/claude-sonnet-4-5", ( - "LiteLLM needs openrouter/ prefix to detect the provider and strip it before API call" - ) - assert "custom_llm_provider" not in call_kwargs + call_kwargs = mock_create.call_args.kwargs + assert call_kwargs["model"] == "anthropic/claude-sonnet-4-5" @pytest.mark.asyncio -async def test_non_gateway_provider_no_extra_kwargs() -> None: - """Standard (non-gateway) providers must NOT inject any litellm_kwargs.""" - mock_acompletion = AsyncMock(return_value=_fake_response()) +async def test_aihubmix_strips_model_prefix() -> None: + """AiHubMix strips the provider prefix (strip_model_prefix=True).""" + mock_create = AsyncMock(return_value=_fake_chat_response()) + spec = find_by_name("aihubmix") - with patch("nanobot.providers.litellm_provider.acompletion", mock_acompletion): - provider = LiteLLMProvider( - api_key="sk-ant-test-key", - default_model="claude-sonnet-4-5", - ) - await provider.chat( - messages=[{"role": "user", "content": "hello"}], - model="claude-sonnet-4-5", - ) - - call_kwargs = mock_acompletion.call_args.kwargs - assert "custom_llm_provider" not in call_kwargs, ( - "Standard Anthropic provider should NOT inject custom_llm_provider" - ) - - -@pytest.mark.asyncio -async def test_gateway_without_litellm_kwargs_injects_nothing_extra() -> None: - """Gateways without litellm_kwargs (e.g. AiHubMix) must not add extra keys.""" - mock_acompletion = AsyncMock(return_value=_fake_response()) + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient: + client_instance = MockClient.return_value + client_instance.chat.completions.create = mock_create - with patch("nanobot.providers.litellm_provider.acompletion", mock_acompletion): - provider = LiteLLMProvider( + provider = OpenAICompatProvider( api_key="sk-aihub-test-key", api_base="https://aihubmix.com/v1", default_model="claude-sonnet-4-5", - provider_name="aihubmix", + spec=spec, ) await provider.chat( messages=[{"role": "user", "content": "hello"}], - model="claude-sonnet-4-5", + model="anthropic/claude-sonnet-4-5", ) - call_kwargs = mock_acompletion.call_args.kwargs - assert "custom_llm_provider" not in call_kwargs + call_kwargs = mock_create.call_args.kwargs + assert call_kwargs["model"] == "claude-sonnet-4-5" @pytest.mark.asyncio -async def test_openrouter_autodetect_by_key_prefix() -> None: - """OpenRouter should be auto-detected by sk-or- key prefix even without explicit provider_name.""" - mock_acompletion = AsyncMock(return_value=_fake_response()) - - with patch("nanobot.providers.litellm_provider.acompletion", mock_acompletion): - provider = LiteLLMProvider( - api_key="sk-or-auto-detect-key", - default_model="anthropic/claude-sonnet-4-5", +async def test_standard_provider_passes_model_through() -> None: + """Standard provider (e.g. deepseek) passes model name through as-is.""" + mock_create = AsyncMock(return_value=_fake_chat_response()) + spec = find_by_name("deepseek") + + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient: + client_instance = MockClient.return_value + client_instance.chat.completions.create = mock_create + + provider = OpenAICompatProvider( + api_key="sk-deepseek-test-key", + default_model="deepseek-chat", + spec=spec, ) await provider.chat( messages=[{"role": "user", "content": "hello"}], - model="anthropic/claude-sonnet-4-5", + model="deepseek-chat", ) - call_kwargs = mock_acompletion.call_args.kwargs - assert call_kwargs["model"] == "openrouter/anthropic/claude-sonnet-4-5", ( - "Auto-detected OpenRouter should prefix model for LiteLLM routing" - ) - - -@pytest.mark.asyncio -async def test_openrouter_native_model_id_gets_double_prefixed() -> None: - """Models like openrouter/free must be double-prefixed so LiteLLM strips one layer. + call_kwargs = mock_create.call_args.kwargs + assert call_kwargs["model"] == "deepseek-chat" - openrouter/free is an actual OpenRouter model ID. LiteLLM strips the first - openrouter/ for routing, so we must send openrouter/openrouter/free to ensure - the API receives openrouter/free. - """ - mock_acompletion = AsyncMock(return_value=_fake_response()) - with patch("nanobot.providers.litellm_provider.acompletion", mock_acompletion): - provider = LiteLLMProvider( - api_key="sk-or-test-key", - api_base="https://openrouter.ai/api/v1", - default_model="openrouter/free", - provider_name="openrouter", +def test_openai_model_passthrough() -> None: + """OpenAI models pass through unchanged.""" + spec = find_by_name("openai") + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): + provider = OpenAICompatProvider( + api_key="sk-test-key", + default_model="gpt-4o", + spec=spec, ) - await provider.chat( - messages=[{"role": "user", "content": "hello"}], - model="openrouter/free", - ) - - call_kwargs = mock_acompletion.call_args.kwargs - assert call_kwargs["model"] == "openrouter/openrouter/free", ( - "openrouter/free must become openrouter/openrouter/free — " - "LiteLLM strips one layer so the API receives openrouter/free" - ) + assert provider.get_default_model() == "gpt-4o" diff --git a/tests/providers/test_mistral_provider.py b/tests/providers/test_mistral_provider.py index 40112217823..30023afe74a 100644 --- a/tests/providers/test_mistral_provider.py +++ b/tests/providers/test_mistral_provider.py @@ -17,6 +17,4 @@ def test_mistral_provider_in_registry(): mistral = specs["mistral"] assert mistral.env_key == "MISTRAL_API_KEY" - assert mistral.litellm_prefix == "mistral" assert mistral.default_api_base == "https://api.mistral.ai/v1" - assert "mistral/" in mistral.skip_prefixes diff --git a/tests/providers/test_providers_init.py b/tests/providers/test_providers_init.py index 02ab7c1efe8..32cbab47883 100644 --- a/tests/providers/test_providers_init.py +++ b/tests/providers/test_providers_init.py @@ -8,19 +8,22 @@ def test_importing_providers_package_is_lazy(monkeypatch) -> None: monkeypatch.delitem(sys.modules, "nanobot.providers", raising=False) - monkeypatch.delitem(sys.modules, "nanobot.providers.litellm_provider", raising=False) + monkeypatch.delitem(sys.modules, "nanobot.providers.anthropic_provider", raising=False) + monkeypatch.delitem(sys.modules, "nanobot.providers.openai_compat_provider", raising=False) monkeypatch.delitem(sys.modules, "nanobot.providers.openai_codex_provider", raising=False) monkeypatch.delitem(sys.modules, "nanobot.providers.azure_openai_provider", raising=False) providers = importlib.import_module("nanobot.providers") - assert "nanobot.providers.litellm_provider" not in sys.modules + assert "nanobot.providers.anthropic_provider" not in sys.modules + assert "nanobot.providers.openai_compat_provider" not in sys.modules assert "nanobot.providers.openai_codex_provider" not in sys.modules assert "nanobot.providers.azure_openai_provider" not in sys.modules assert providers.__all__ == [ "LLMProvider", "LLMResponse", - "LiteLLMProvider", + "AnthropicProvider", + "OpenAICompatProvider", "OpenAICodexProvider", "AzureOpenAIProvider", ] @@ -28,10 +31,10 @@ def test_importing_providers_package_is_lazy(monkeypatch) -> None: def test_explicit_provider_import_still_works(monkeypatch) -> None: monkeypatch.delitem(sys.modules, "nanobot.providers", raising=False) - monkeypatch.delitem(sys.modules, "nanobot.providers.litellm_provider", raising=False) + monkeypatch.delitem(sys.modules, "nanobot.providers.anthropic_provider", raising=False) namespace: dict[str, object] = {} - exec("from nanobot.providers import LiteLLMProvider", namespace) + exec("from nanobot.providers import AnthropicProvider", namespace) - assert namespace["LiteLLMProvider"].__name__ == "LiteLLMProvider" - assert "nanobot.providers.litellm_provider" in sys.modules + assert namespace["AnthropicProvider"].__name__ == "AnthropicProvider" + assert "nanobot.providers.anthropic_provider" in sys.modules From c3031c9cb84bdad140711b3a0e4d37ba02595d87 Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Tue, 24 Mar 2026 18:11:03 +0000 Subject: [PATCH 18/40] docs: update news section about litellm --- README.md | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 9f5e0d248ae..1f337eb4195 100644 --- a/README.md +++ b/README.md @@ -21,8 +21,13 @@ ## 📢 News > [!IMPORTANT] -> **Security note:** Due to `litellm` supply chain poisoning, **please check your Python environment ASAP** and refer to this [advisory](https://github.com/HKUDS/nanobot/discussions/2445) for details. We are also urgently replacing `litellm` and preparing mitigations. +> **Security note:** Due to `litellm` supply chain poisoning, **please check your Python environment ASAP** and refer to this [advisory](https://github.com/HKUDS/nanobot/discussions/2445) for details. We have fully removed the `litellm` dependency in [this commit](https://github.com/HKUDS/nanobot/commit/3dfdab7). +- **2026-03-21** 🔒 Replace `litellm` with native `openai` + `anthropic` SDKs. Please see [commit](https://github.com/HKUDS/nanobot/commit/3dfdab7). +- **2026-03-20** 🧙 Interactive setup wizard — pick your provider, model autocomplete, and you're good to go. +- **2026-03-19** 💬 Telegram gets more resilient under load; Feishu now renders code blocks properly. +- **2026-03-18** 📷 Telegram can now send media via URL. Cron schedules show human-readable details. Fresh logo. +- **2026-03-17** ✨ Feishu formatting glow-up, Slack reacts when done, custom endpoints support extra headers, and image handling is more reliable. - **2026-03-16** 🚀 Released **v0.1.4.post5** — a refinement-focused release with stronger reliability and channel support, and a more dependable day-to-day experience. Please see [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post5) for details. - **2026-03-15** 🧩 DingTalk rich media, smarter built-in skills, and cleaner model compatibility. - **2026-03-14** 💬 Channel plugins, Feishu replies, and steadier MCP, QQ, and media handling. From 7b31af22049444e246f842c1cf95b46b54990a72 Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Tue, 24 Mar 2026 18:11:50 +0000 Subject: [PATCH 19/40] docs: update news section --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 1f337eb4195..5ec339701b3 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,7 @@ - **2026-03-21** 🔒 Replace `litellm` with native `openai` + `anthropic` SDKs. Please see [commit](https://github.com/HKUDS/nanobot/commit/3dfdab7). - **2026-03-20** 🧙 Interactive setup wizard — pick your provider, model autocomplete, and you're good to go. - **2026-03-19** 💬 Telegram gets more resilient under load; Feishu now renders code blocks properly. -- **2026-03-18** 📷 Telegram can now send media via URL. Cron schedules show human-readable details. Fresh logo. +- **2026-03-18** 📷 Telegram can now send media via URL. Cron schedules show human-readable details. - **2026-03-17** ✨ Feishu formatting glow-up, Slack reacts when done, custom endpoints support extra headers, and image handling is more reliable. - **2026-03-16** 🚀 Released **v0.1.4.post5** — a refinement-focused release with stronger reliability and channel support, and a more dependable day-to-day experience. Please see [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post5) for details. - **2026-03-15** 🧩 DingTalk rich media, smarter built-in skills, and cleaner model compatibility. From 3a9d6ea536063935f26e468c53424cdced8f7e1f Mon Sep 17 00:00:00 2001 From: xcosmosbox <2162381070@qq.com> Date: Tue, 24 Mar 2026 14:38:18 +0800 Subject: [PATCH 20/40] feat(WeXin): add route_tag property to adapt to WeChat official ilinkai 1.0.3 requirements --- nanobot/channels/weixin.py | 3 +++ tests/channels/test_weixin_channel.py | 14 ++++++++++++++ 2 files changed, 17 insertions(+) diff --git a/nanobot/channels/weixin.py b/nanobot/channels/weixin.py index 48a97f582db..a8a4a636d7a 100644 --- a/nanobot/channels/weixin.py +++ b/nanobot/channels/weixin.py @@ -83,6 +83,7 @@ class WeixinConfig(Base): allow_from: list[str] = Field(default_factory=list) base_url: str = "https://ilinkai.weixin.qq.com" cdn_base_url: str = "https://novac2c.cdn.weixin.qq.com/c2c" + route_tag: str | int | None = None token: str = "" # Manually set token, or obtained via QR login state_dir: str = "" # Default: ~/.nanobot/weixin/ poll_timeout: int = DEFAULT_LONG_POLL_TIMEOUT_S # seconds for long-poll @@ -187,6 +188,8 @@ def _make_headers(self, *, auth: bool = True) -> dict[str, str]: } if auth and self._token: headers["Authorization"] = f"Bearer {self._token}" + if self.config.route_tag is not None and str(self.config.route_tag).strip(): + headers["SKRouteTag"] = str(self.config.route_tag).strip() return headers async def _api_get( diff --git a/tests/channels/test_weixin_channel.py b/tests/channels/test_weixin_channel.py index a16c6b7509d..6107d117b20 100644 --- a/tests/channels/test_weixin_channel.py +++ b/tests/channels/test_weixin_channel.py @@ -22,6 +22,20 @@ def _make_channel() -> tuple[WeixinChannel, MessageBus]: return channel, bus +def test_make_headers_includes_route_tag_when_configured() -> None: + bus = MessageBus() + channel = WeixinChannel( + WeixinConfig(enabled=True, allow_from=["*"], route_tag=123), + bus, + ) + channel._token = "token" + + headers = channel._make_headers() + + assert headers["Authorization"] == "Bearer token" + assert headers["SKRouteTag"] == "123" + + @pytest.mark.asyncio async def test_process_message_deduplicates_inbound_ids() -> None: channel, bus = _make_channel() From 9c872c34584b32bc72c6af0e4922263fa3d3315f Mon Sep 17 00:00:00 2001 From: xcosmosbox <2162381070@qq.com> Date: Tue, 24 Mar 2026 14:44:16 +0800 Subject: [PATCH 21/40] fix(WeiXin): resolve polling issues in WeiXin plugin - Prevent repeated retries on expired sessions in the polling thread - Stop sending messages to invalid agent sessions to eliminate noise logs and unnecessary requests --- nanobot/channels/weixin.py | 40 +++++++++++++++++++++++++-- tests/channels/test_weixin_channel.py | 29 +++++++++++++++++++ 2 files changed, 67 insertions(+), 2 deletions(-) diff --git a/nanobot/channels/weixin.py b/nanobot/channels/weixin.py index a8a4a636d7a..e572d68a29a 100644 --- a/nanobot/channels/weixin.py +++ b/nanobot/channels/weixin.py @@ -57,6 +57,7 @@ # Session-expired error code ERRCODE_SESSION_EXPIRED = -14 +SESSION_PAUSE_DURATION_S = 60 * 60 # Retry constants (matching the reference plugin's monitor.ts) MAX_CONSECUTIVE_FAILURES = 3 @@ -120,6 +121,7 @@ def __init__(self, config: Any, bus: MessageBus): self._token: str = "" self._poll_task: asyncio.Task | None = None self._next_poll_timeout_s: int = DEFAULT_LONG_POLL_TIMEOUT_S + self._session_pause_until: float = 0.0 # ------------------------------------------------------------------ # State persistence @@ -395,7 +397,34 @@ async def stop(self) -> None: # Polling (matches monitor.ts monitorWeixinProvider) # ------------------------------------------------------------------ + def _pause_session(self, duration_s: int = SESSION_PAUSE_DURATION_S) -> None: + self._session_pause_until = time.time() + duration_s + + def _session_pause_remaining_s(self) -> int: + remaining = int(self._session_pause_until - time.time()) + if remaining <= 0: + self._session_pause_until = 0.0 + return 0 + return remaining + + def _assert_session_active(self) -> None: + remaining = self._session_pause_remaining_s() + if remaining > 0: + remaining_min = max((remaining + 59) // 60, 1) + raise RuntimeError( + f"WeChat session paused, {remaining_min} min remaining (errcode {ERRCODE_SESSION_EXPIRED})" + ) + async def _poll_once(self) -> None: + remaining = self._session_pause_remaining_s() + if remaining > 0: + logger.warning( + "WeChat session paused, waiting {} min before next poll.", + max((remaining + 59) // 60, 1), + ) + await asyncio.sleep(remaining) + return + body: dict[str, Any] = { "get_updates_buf": self._get_updates_buf, "base_info": BASE_INFO, @@ -414,11 +443,13 @@ async def _poll_once(self) -> None: if is_error: if errcode == ERRCODE_SESSION_EXPIRED or ret == ERRCODE_SESSION_EXPIRED: + self._pause_session() + remaining = self._session_pause_remaining_s() logger.warning( - "WeChat session expired (errcode {}). Pausing 60 min.", + "WeChat session expired (errcode {}). Pausing {} min.", errcode, + max((remaining + 59) // 60, 1), ) - await asyncio.sleep(3600) return raise RuntimeError( f"getUpdates failed: ret={ret} errcode={errcode} errmsg={data.get('errmsg', '')}" @@ -654,6 +685,11 @@ async def send(self, msg: OutboundMessage) -> None: if not self._client or not self._token: logger.warning("WeChat client not initialized or not authenticated") return + try: + self._assert_session_active() + except RuntimeError as e: + logger.warning("WeChat send blocked: {}", e) + return content = msg.content.strip() ctx_token = self._context_tokens.get(msg.chat_id, "") diff --git a/tests/channels/test_weixin_channel.py b/tests/channels/test_weixin_channel.py index 6107d117b20..0a01b72c79a 100644 --- a/tests/channels/test_weixin_channel.py +++ b/tests/channels/test_weixin_channel.py @@ -1,4 +1,5 @@ import asyncio +from types import SimpleNamespace from unittest.mock import AsyncMock import pytest @@ -123,6 +124,34 @@ async def test_send_without_context_token_does_not_send_text() -> None: channel._send_text.assert_not_awaited() +@pytest.mark.asyncio +async def test_send_does_not_send_when_session_is_paused() -> None: + channel, _bus = _make_channel() + channel._client = object() + channel._token = "token" + channel._context_tokens["wx-user"] = "ctx-2" + channel._pause_session(60) + channel._send_text = AsyncMock() + + await channel.send( + type("Msg", (), {"chat_id": "wx-user", "content": "pong", "media": [], "metadata": {}})() + ) + + channel._send_text.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_poll_once_pauses_session_on_expired_errcode() -> None: + channel, _bus = _make_channel() + channel._client = SimpleNamespace(timeout=None) + channel._token = "token" + channel._api_post = AsyncMock(return_value={"ret": 0, "errcode": -14, "errmsg": "expired"}) + + await channel._poll_once() + + assert channel._session_pause_remaining_s() > 0 + + @pytest.mark.asyncio async def test_process_message_skips_bot_messages() -> None: channel, bus = _make_channel() From 1f5492ea9e33d431852b967b058d2c48d40ef8fb Mon Sep 17 00:00:00 2001 From: xcosmosbox <2162381070@qq.com> Date: Tue, 24 Mar 2026 14:52:13 +0800 Subject: [PATCH 22/40] fix(WeiXin): persist _context_tokens with account.json to restore conversations after restart --- nanobot/channels/weixin.py | 11 ++++++ tests/channels/test_weixin_channel.py | 56 ++++++++++++++++++++++++++- 2 files changed, 66 insertions(+), 1 deletion(-) diff --git a/nanobot/channels/weixin.py b/nanobot/channels/weixin.py index e572d68a29a..115cca7ffc5 100644 --- a/nanobot/channels/weixin.py +++ b/nanobot/channels/weixin.py @@ -147,6 +147,15 @@ def _load_state(self) -> bool: data = json.loads(state_file.read_text()) self._token = data.get("token", "") self._get_updates_buf = data.get("get_updates_buf", "") + context_tokens = data.get("context_tokens", {}) + if isinstance(context_tokens, dict): + self._context_tokens = { + str(user_id): str(token) + for user_id, token in context_tokens.items() + if str(user_id).strip() and str(token).strip() + } + else: + self._context_tokens = {} base_url = data.get("base_url", "") if base_url: self.config.base_url = base_url @@ -161,6 +170,7 @@ def _save_state(self) -> None: data = { "token": self._token, "get_updates_buf": self._get_updates_buf, + "context_tokens": self._context_tokens, "base_url": self.config.base_url, } state_file.write_text(json.dumps(data, ensure_ascii=False)) @@ -502,6 +512,7 @@ async def _process_message(self, msg: dict) -> None: ctx_token = msg.get("context_token", "") if ctx_token: self._context_tokens[from_user_id] = ctx_token + self._save_state() # Parse item_list (WeixinMessage.item_list — types.ts:161) item_list: list[dict] = msg.get("item_list") or [] diff --git a/tests/channels/test_weixin_channel.py b/tests/channels/test_weixin_channel.py index 0a01b72c79a..36e56315b01 100644 --- a/tests/channels/test_weixin_channel.py +++ b/tests/channels/test_weixin_channel.py @@ -1,4 +1,6 @@ import asyncio +import json +import tempfile from types import SimpleNamespace from unittest.mock import AsyncMock @@ -17,7 +19,11 @@ def _make_channel() -> tuple[WeixinChannel, MessageBus]: bus = MessageBus() channel = WeixinChannel( - WeixinConfig(enabled=True, allow_from=["*"]), + WeixinConfig( + enabled=True, + allow_from=["*"], + state_dir=tempfile.mkdtemp(prefix="nanobot-weixin-test-"), + ), bus, ) return channel, bus @@ -37,6 +43,30 @@ def test_make_headers_includes_route_tag_when_configured() -> None: assert headers["SKRouteTag"] == "123" +def test_save_and_load_state_persists_context_tokens(tmp_path) -> None: + bus = MessageBus() + channel = WeixinChannel( + WeixinConfig(enabled=True, allow_from=["*"], state_dir=str(tmp_path)), + bus, + ) + channel._token = "token" + channel._get_updates_buf = "cursor" + channel._context_tokens = {"wx-user": "ctx-1"} + + channel._save_state() + + saved = json.loads((tmp_path / "account.json").read_text()) + assert saved["context_tokens"] == {"wx-user": "ctx-1"} + + restored = WeixinChannel( + WeixinConfig(enabled=True, allow_from=["*"], state_dir=str(tmp_path)), + bus, + ) + + assert restored._load_state() is True + assert restored._context_tokens == {"wx-user": "ctx-1"} + + @pytest.mark.asyncio async def test_process_message_deduplicates_inbound_ids() -> None: channel, bus = _make_channel() @@ -86,6 +116,30 @@ async def test_process_message_caches_context_token_and_send_uses_it() -> None: channel._send_text.assert_awaited_once_with("wx-user", "pong", "ctx-2") +@pytest.mark.asyncio +async def test_process_message_persists_context_token_to_state_file(tmp_path) -> None: + bus = MessageBus() + channel = WeixinChannel( + WeixinConfig(enabled=True, allow_from=["*"], state_dir=str(tmp_path)), + bus, + ) + + await channel._process_message( + { + "message_type": 1, + "message_id": "m2b", + "from_user_id": "wx-user", + "context_token": "ctx-2b", + "item_list": [ + {"type": ITEM_TEXT, "text_item": {"text": "ping"}}, + ], + } + ) + + saved = json.loads((tmp_path / "account.json").read_text()) + assert saved["context_tokens"] == {"wx-user": "ctx-2b"} + + @pytest.mark.asyncio async def test_process_message_extracts_media_and_preserves_paths() -> None: channel, bus = _make_channel() From 48902ae95a67fc465ec394448cda9951cb32a84a Mon Sep 17 00:00:00 2001 From: xcosmosbox <2162381070@qq.com> Date: Tue, 24 Mar 2026 14:55:36 +0800 Subject: [PATCH 23/40] fix(WeiXin): auto-refresh expired QR code during login to improve success rate --- nanobot/channels/weixin.py | 49 ++++++++++++++++--------- tests/channels/test_weixin_channel.py | 51 +++++++++++++++++++++++++++ 2 files changed, 84 insertions(+), 16 deletions(-) diff --git a/nanobot/channels/weixin.py b/nanobot/channels/weixin.py index 115cca7ffc5..5ea887f0298 100644 --- a/nanobot/channels/weixin.py +++ b/nanobot/channels/weixin.py @@ -63,6 +63,7 @@ MAX_CONSECUTIVE_FAILURES = 3 BACKOFF_DELAY_S = 30 RETRY_DELAY_S = 2 +MAX_QR_REFRESH_COUNT = 3 # Default long-poll timeout; overridden by server via longpolling_timeout_ms. DEFAULT_LONG_POLL_TIMEOUT_S = 35 @@ -241,24 +242,25 @@ async def _api_post( # QR Code Login (matches login-qr.ts) # ------------------------------------------------------------------ + async def _fetch_qr_code(self) -> tuple[str, str]: + """Fetch a fresh QR code. Returns (qrcode_id, scan_url).""" + data = await self._api_get( + "ilink/bot/get_bot_qrcode", + params={"bot_type": "3"}, + auth=False, + ) + qrcode_img_content = data.get("qrcode_img_content", "") + qrcode_id = data.get("qrcode", "") + if not qrcode_id: + raise RuntimeError(f"Failed to get QR code from WeChat API: {data}") + return qrcode_id, (qrcode_img_content or qrcode_id) + async def _qr_login(self) -> bool: """Perform QR code login flow. Returns True on success.""" try: logger.info("Starting WeChat QR code login...") - - data = await self._api_get( - "ilink/bot/get_bot_qrcode", - params={"bot_type": "3"}, - auth=False, - ) - qrcode_img_content = data.get("qrcode_img_content", "") - qrcode_id = data.get("qrcode", "") - - if not qrcode_id: - logger.error("Failed to get QR code from WeChat API: {}", data) - return False - - scan_url = qrcode_img_content or qrcode_id + refresh_count = 0 + qrcode_id, scan_url = await self._fetch_qr_code() self._print_qr_code(scan_url) logger.info("Waiting for QR code scan...") @@ -298,8 +300,23 @@ async def _qr_login(self) -> bool: elif status == "scaned": logger.info("QR code scanned, waiting for confirmation...") elif status == "expired": - logger.warning("QR code expired") - return False + refresh_count += 1 + if refresh_count > MAX_QR_REFRESH_COUNT: + logger.warning( + "QR code expired too many times ({}/{}), giving up.", + refresh_count - 1, + MAX_QR_REFRESH_COUNT, + ) + return False + logger.warning( + "QR code expired, refreshing... ({}/{})", + refresh_count, + MAX_QR_REFRESH_COUNT, + ) + qrcode_id, scan_url = await self._fetch_qr_code() + self._print_qr_code(scan_url) + logger.info("New QR code generated, waiting for scan...") + continue # status == "wait" — keep polling await asyncio.sleep(1) diff --git a/tests/channels/test_weixin_channel.py b/tests/channels/test_weixin_channel.py index 36e56315b01..818e45d9827 100644 --- a/tests/channels/test_weixin_channel.py +++ b/tests/channels/test_weixin_channel.py @@ -206,6 +206,57 @@ async def test_poll_once_pauses_session_on_expired_errcode() -> None: assert channel._session_pause_remaining_s() > 0 +@pytest.mark.asyncio +async def test_qr_login_refreshes_expired_qr_and_then_succeeds() -> None: + channel, _bus = _make_channel() + channel._running = True + channel._save_state = lambda: None + channel._print_qr_code = lambda url: None + channel._api_get = AsyncMock( + side_effect=[ + {"qrcode": "qr-1", "qrcode_img_content": "url-1"}, + {"status": "expired"}, + {"qrcode": "qr-2", "qrcode_img_content": "url-2"}, + { + "status": "confirmed", + "bot_token": "token-2", + "ilink_bot_id": "bot-2", + "baseurl": "https://example.test", + "ilink_user_id": "wx-user", + }, + ] + ) + + ok = await channel._qr_login() + + assert ok is True + assert channel._token == "token-2" + assert channel.config.base_url == "https://example.test" + + +@pytest.mark.asyncio +async def test_qr_login_returns_false_after_too_many_expired_qr_codes() -> None: + channel, _bus = _make_channel() + channel._running = True + channel._print_qr_code = lambda url: None + channel._api_get = AsyncMock( + side_effect=[ + {"qrcode": "qr-1", "qrcode_img_content": "url-1"}, + {"status": "expired"}, + {"qrcode": "qr-2", "qrcode_img_content": "url-2"}, + {"status": "expired"}, + {"qrcode": "qr-3", "qrcode_img_content": "url-3"}, + {"status": "expired"}, + {"qrcode": "qr-4", "qrcode_img_content": "url-4"}, + {"status": "expired"}, + ] + ) + + ok = await channel._qr_login() + + assert ok is False + + @pytest.mark.asyncio async def test_process_message_skips_bot_messages() -> None: channel, bus = _make_channel() From 0dad6124a2f973e9efd0f32c73a0a388a76b35df Mon Sep 17 00:00:00 2001 From: xcosmosbox <2162381070@qq.com> Date: Tue, 24 Mar 2026 14:57:51 +0800 Subject: [PATCH 24/40] chore(WeiXin): version migration and compatibility update --- nanobot/channels/weixin.py | 3 ++- tests/channels/test_weixin_channel.py | 5 +++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/nanobot/channels/weixin.py b/nanobot/channels/weixin.py index 5ea887f0298..2e25b35692f 100644 --- a/nanobot/channels/weixin.py +++ b/nanobot/channels/weixin.py @@ -53,7 +53,8 @@ MESSAGE_STATE_FINISH = 2 WEIXIN_MAX_MESSAGE_LEN = 4000 -BASE_INFO: dict[str, str] = {"channel_version": "1.0.2"} +WEIXIN_CHANNEL_VERSION = "1.0.3" +BASE_INFO: dict[str, str] = {"channel_version": WEIXIN_CHANNEL_VERSION} # Session-expired error code ERRCODE_SESSION_EXPIRED = -14 diff --git a/tests/channels/test_weixin_channel.py b/tests/channels/test_weixin_channel.py index 818e45d9827..54d9bd93f91 100644 --- a/tests/channels/test_weixin_channel.py +++ b/tests/channels/test_weixin_channel.py @@ -11,6 +11,7 @@ ITEM_IMAGE, ITEM_TEXT, MESSAGE_TYPE_BOT, + WEIXIN_CHANNEL_VERSION, WeixinChannel, WeixinConfig, ) @@ -43,6 +44,10 @@ def test_make_headers_includes_route_tag_when_configured() -> None: assert headers["SKRouteTag"] == "123" +def test_channel_version_matches_reference_plugin_version() -> None: + assert WEIXIN_CHANNEL_VERSION == "1.0.3" + + def test_save_and_load_state_persists_context_tokens(tmp_path) -> None: bus = MessageBus() channel = WeixinChannel( From 0ccfcf6588420eaf485bd14892b2bf3ee1db4e78 Mon Sep 17 00:00:00 2001 From: xcosmosbox <2162381070@qq.com> Date: Tue, 24 Mar 2026 15:51:15 +0800 Subject: [PATCH 25/40] fix(WeiXin): version migration --- README.md | 1 + nanobot/channels/weixin.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 5ec339701b3..448351fdd69 100644 --- a/README.md +++ b/README.md @@ -757,6 +757,7 @@ pip install -e ".[weixin]" > - `allowFrom`: Add the sender ID you see in nanobot logs for your WeChat account. Use `["*"]` to allow all users. > - `token`: Optional. If omitted, log in interactively and nanobot will save the token for you. +> - `routeTag`: Optional. When your upstream Weixin deployment requires request routing, nanobot will send it as the `SKRouteTag` header. > - `stateDir`: Optional. Defaults to nanobot's runtime directory for Weixin state. > - `pollTimeout`: Optional long-poll timeout in seconds. diff --git a/nanobot/channels/weixin.py b/nanobot/channels/weixin.py index 2e25b35692f..3fbe329aa52 100644 --- a/nanobot/channels/weixin.py +++ b/nanobot/channels/weixin.py @@ -4,7 +4,7 @@ No WebSocket, no local WeChat client needed — just HTTP requests with a bot token obtained via QR code login. -Protocol reverse-engineered from ``@tencent-weixin/openclaw-weixin`` v1.0.2. +Protocol reverse-engineered from ``@tencent-weixin/openclaw-weixin`` v1.0.3. """ from __future__ import annotations @@ -799,7 +799,7 @@ async def _send_media_file( ) -> None: """Upload a local file to WeChat CDN and send it as a media message. - Follows the exact protocol from ``@tencent-weixin/openclaw-weixin`` v1.0.2: + Follows the exact protocol from ``@tencent-weixin/openclaw-weixin`` v1.0.3: 1. Generate a random 16-byte AES key (client-side). 2. Call ``getuploadurl`` with file metadata + hex-encoded AES key. 3. AES-128-ECB encrypt the file and POST to CDN (``{cdnBaseUrl}/upload``). From b7df3a0aea71abb266ccaf96813129dfd9598cf7 Mon Sep 17 00:00:00 2001 From: Seeratul <126798754+Seeratul@users.noreply.github.com> Date: Tue, 24 Mar 2026 21:41:58 +0100 Subject: [PATCH 26/40] Update README with group policy clarification Clarify group policy behavior for bot responses in group channels. --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 448351fdd69..d32a53ad052 100644 --- a/README.md +++ b/README.md @@ -381,6 +381,7 @@ If you prefer to configure manually, add the following to `~/.nanobot/config.jso > - `"mention"` (default) — Only respond when @mentioned > - `"open"` — Respond to all messages > DMs always respond when the sender is in `allowFrom`. +> - If you set group policy to open create new threads as private threads and then @ the bot into it. Otherwise bot the thread itself and the channel will spawn a bot session **5. Invite the bot** - OAuth2 → URL Generator From 321214e2e0c03415b5d4c872890508b834329a7f Mon Sep 17 00:00:00 2001 From: Seeratul <126798754+Seeratul@users.noreply.github.com> Date: Tue, 24 Mar 2026 21:43:22 +0100 Subject: [PATCH 27/40] Update group policy explanation in README Clarified instructions for group policy behavior in README. --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index d32a53ad052..270f61b627f 100644 --- a/README.md +++ b/README.md @@ -381,7 +381,7 @@ If you prefer to configure manually, add the following to `~/.nanobot/config.jso > - `"mention"` (default) — Only respond when @mentioned > - `"open"` — Respond to all messages > DMs always respond when the sender is in `allowFrom`. -> - If you set group policy to open create new threads as private threads and then @ the bot into it. Otherwise bot the thread itself and the channel will spawn a bot session +> - If you set group policy to open create new threads as private threads and then @ the bot into it. Otherwise the thread itself and the channel in which you spawned it will spawn a bot session. **5. Invite the bot** - OAuth2 → URL Generator From 263069583d921a30858de6e58e03f49b0fd12703 Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Wed, 25 Mar 2026 01:22:21 +0000 Subject: [PATCH 28/40] fix(provider): accept plain text OpenAI-compatible responses Handle string and dict-shaped responses from OpenAI-compatible backends so non-standard providers no longer crash on missing choices fields. Add regression tests to keep SDK, dict, and plain-text parsing paths aligned. --- nanobot/providers/openai_compat_provider.py | 178 +++++++++++++++++--- tests/providers/test_custom_provider.py | 38 +++++ 2 files changed, 197 insertions(+), 19 deletions(-) diff --git a/nanobot/providers/openai_compat_provider.py b/nanobot/providers/openai_compat_provider.py index a210bf72d2c..a69a716b132 100644 --- a/nanobot/providers/openai_compat_provider.py +++ b/nanobot/providers/openai_compat_provider.py @@ -193,7 +193,126 @@ def _build_kwargs( # Response parsing # ------------------------------------------------------------------ + @staticmethod + def _maybe_mapping(value: Any) -> dict[str, Any] | None: + if isinstance(value, dict): + return value + model_dump = getattr(value, "model_dump", None) + if callable(model_dump): + dumped = model_dump() + if isinstance(dumped, dict): + return dumped + return None + + @classmethod + def _extract_text_content(cls, value: Any) -> str | None: + if value is None: + return None + if isinstance(value, str): + return value + if isinstance(value, list): + parts: list[str] = [] + for item in value: + item_map = cls._maybe_mapping(item) + if item_map: + text = item_map.get("text") + if isinstance(text, str): + parts.append(text) + continue + text = getattr(item, "text", None) + if isinstance(text, str): + parts.append(text) + continue + if isinstance(item, str): + parts.append(item) + return "".join(parts) or None + return str(value) + + @classmethod + def _extract_usage(cls, response: Any) -> dict[str, int]: + usage_obj = None + response_map = cls._maybe_mapping(response) + if response_map is not None: + usage_obj = response_map.get("usage") + elif hasattr(response, "usage") and response.usage: + usage_obj = response.usage + + usage_map = cls._maybe_mapping(usage_obj) + if usage_map is not None: + return { + "prompt_tokens": int(usage_map.get("prompt_tokens") or 0), + "completion_tokens": int(usage_map.get("completion_tokens") or 0), + "total_tokens": int(usage_map.get("total_tokens") or 0), + } + + if usage_obj: + return { + "prompt_tokens": getattr(usage_obj, "prompt_tokens", 0) or 0, + "completion_tokens": getattr(usage_obj, "completion_tokens", 0) or 0, + "total_tokens": getattr(usage_obj, "total_tokens", 0) or 0, + } + return {} + def _parse(self, response: Any) -> LLMResponse: + if isinstance(response, str): + return LLMResponse(content=response, finish_reason="stop") + + response_map = self._maybe_mapping(response) + if response_map is not None: + choices = response_map.get("choices") or [] + if not choices: + content = self._extract_text_content( + response_map.get("content") or response_map.get("output_text") + ) + if content is not None: + return LLMResponse( + content=content, + finish_reason=str(response_map.get("finish_reason") or "stop"), + usage=self._extract_usage(response_map), + ) + return LLMResponse(content="Error: API returned empty choices.", finish_reason="error") + + choice0 = self._maybe_mapping(choices[0]) or {} + msg0 = self._maybe_mapping(choice0.get("message")) or {} + content = self._extract_text_content(msg0.get("content")) + finish_reason = str(choice0.get("finish_reason") or "stop") + + raw_tool_calls: list[Any] = [] + reasoning_content = msg0.get("reasoning_content") + for ch in choices: + ch_map = self._maybe_mapping(ch) or {} + m = self._maybe_mapping(ch_map.get("message")) or {} + tool_calls = m.get("tool_calls") + if isinstance(tool_calls, list) and tool_calls: + raw_tool_calls.extend(tool_calls) + if ch_map.get("finish_reason") in ("tool_calls", "stop"): + finish_reason = str(ch_map["finish_reason"]) + if not content: + content = self._extract_text_content(m.get("content")) + if not reasoning_content: + reasoning_content = m.get("reasoning_content") + + parsed_tool_calls = [] + for tc in raw_tool_calls: + tc_map = self._maybe_mapping(tc) or {} + fn = self._maybe_mapping(tc_map.get("function")) or {} + args = fn.get("arguments", {}) + if isinstance(args, str): + args = json_repair.loads(args) + parsed_tool_calls.append(ToolCallRequest( + id=_short_tool_id(), + name=str(fn.get("name") or ""), + arguments=args if isinstance(args, dict) else {}, + )) + + return LLMResponse( + content=content, + tool_calls=parsed_tool_calls, + finish_reason=finish_reason, + usage=self._extract_usage(response_map), + reasoning_content=reasoning_content if isinstance(reasoning_content, str) else None, + ) + if not response.choices: return LLMResponse(content="Error: API returned empty choices.", finish_reason="error") @@ -223,39 +342,60 @@ def _parse(self, response: Any) -> LLMResponse: arguments=args, )) - usage: dict[str, int] = {} - if hasattr(response, "usage") and response.usage: - u = response.usage - usage = { - "prompt_tokens": u.prompt_tokens or 0, - "completion_tokens": u.completion_tokens or 0, - "total_tokens": u.total_tokens or 0, - } - return LLMResponse( content=content, tool_calls=tool_calls, finish_reason=finish_reason or "stop", - usage=usage, + usage=self._extract_usage(response), reasoning_content=getattr(msg, "reasoning_content", None) or None, ) - @staticmethod - def _parse_chunks(chunks: list[Any]) -> LLMResponse: + @classmethod + def _parse_chunks(cls, chunks: list[Any]) -> LLMResponse: content_parts: list[str] = [] tc_bufs: dict[int, dict[str, str]] = {} finish_reason = "stop" usage: dict[str, int] = {} for chunk in chunks: + if isinstance(chunk, str): + content_parts.append(chunk) + continue + + chunk_map = cls._maybe_mapping(chunk) + if chunk_map is not None: + choices = chunk_map.get("choices") or [] + if not choices: + usage = cls._extract_usage(chunk_map) or usage + text = cls._extract_text_content( + chunk_map.get("content") or chunk_map.get("output_text") + ) + if text: + content_parts.append(text) + continue + choice = cls._maybe_mapping(choices[0]) or {} + if choice.get("finish_reason"): + finish_reason = str(choice["finish_reason"]) + delta = cls._maybe_mapping(choice.get("delta")) or {} + text = cls._extract_text_content(delta.get("content")) + if text: + content_parts.append(text) + for idx, tc in enumerate(delta.get("tool_calls") or []): + tc_map = cls._maybe_mapping(tc) or {} + tc_index = tc_map.get("index", idx) + buf = tc_bufs.setdefault(tc_index, {"id": "", "name": "", "arguments": ""}) + if tc_map.get("id"): + buf["id"] = str(tc_map["id"]) + fn = cls._maybe_mapping(tc_map.get("function")) or {} + if fn.get("name"): + buf["name"] = str(fn["name"]) + if fn.get("arguments"): + buf["arguments"] += str(fn["arguments"]) + usage = cls._extract_usage(chunk_map) or usage + continue + if not chunk.choices: - if hasattr(chunk, "usage") and chunk.usage: - u = chunk.usage - usage = { - "prompt_tokens": u.prompt_tokens or 0, - "completion_tokens": u.completion_tokens or 0, - "total_tokens": u.total_tokens or 0, - } + usage = cls._extract_usage(chunk) or usage continue choice = chunk.choices[0] if choice.finish_reason: diff --git a/tests/providers/test_custom_provider.py b/tests/providers/test_custom_provider.py index bb46b887a2f..d2a9f42473b 100644 --- a/tests/providers/test_custom_provider.py +++ b/tests/providers/test_custom_provider.py @@ -15,3 +15,41 @@ def test_custom_provider_parse_handles_empty_choices() -> None: assert result.finish_reason == "error" assert "empty choices" in result.content + + +def test_custom_provider_parse_accepts_plain_string_response() -> None: + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): + provider = OpenAICompatProvider() + + result = provider._parse("hello from backend") + + assert result.finish_reason == "stop" + assert result.content == "hello from backend" + + +def test_custom_provider_parse_accepts_dict_response() -> None: + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): + provider = OpenAICompatProvider() + + result = provider._parse({ + "choices": [{ + "message": {"content": "hello from dict"}, + "finish_reason": "stop", + }], + "usage": { + "prompt_tokens": 1, + "completion_tokens": 2, + "total_tokens": 3, + }, + }) + + assert result.finish_reason == "stop" + assert result.content == "hello from dict" + assert result.usage["total_tokens"] == 3 + + +def test_custom_provider_parse_chunks_accepts_plain_text_chunks() -> None: + result = OpenAICompatProvider._parse_chunks(["hello ", "world"]) + + assert result.finish_reason == "stop" + assert result.content == "hello world" From 7b720ce9f779d0eb86255455292f1dd09081530f Mon Sep 17 00:00:00 2001 From: Yohei Nishikubo Date: Wed, 25 Mar 2026 09:31:42 +0900 Subject: [PATCH 29/40] feat(OpenAICompatProvider): enhance tool call handling with provider-specific fields --- nanobot/providers/openai_compat_provider.py | 71 ++++++++++++++++++--- tests/providers/test_litellm_kwargs.py | 54 ++++++++++++++++ 2 files changed, 116 insertions(+), 9 deletions(-) diff --git a/nanobot/providers/openai_compat_provider.py b/nanobot/providers/openai_compat_provider.py index a69a716b132..866e05ef820 100644 --- a/nanobot/providers/openai_compat_provider.py +++ b/nanobot/providers/openai_compat_provider.py @@ -24,6 +24,32 @@ _ALNUM = string.ascii_letters + string.digits +def _get_attr_or_item(obj: Any, key: str, default: Any = None) -> Any: + """Read an attribute or dict key from provider SDK objects.""" + if obj is None: + return default + if isinstance(obj, dict): + return obj.get(key, default) + return getattr(obj, key, default) + + +def _coerce_dict(value: Any) -> dict[str, Any] | None: + """Return a shallow dict if the value looks mapping-like.""" + if isinstance(value, dict): + return dict(value) + return None + + +def _extract_tool_call_fields(tc: Any) -> tuple[dict[str, Any] | None, dict[str, Any] | None]: + """Extract provider-specific metadata from a tool call object.""" + provider_specific_fields = _coerce_dict(_get_attr_or_item(tc, "provider_specific_fields")) + function = _get_attr_or_item(tc, "function") + function_provider_specific_fields = _coerce_dict( + _get_attr_or_item(function, "provider_specific_fields") + ) + return provider_specific_fields, function_provider_specific_fields + + def _short_tool_id() -> str: """9-char alphanumeric ID compatible with all providers (incl. Mistral).""" return "".join(secrets.choice(_ALNUM) for _ in range(9)) @@ -333,13 +359,17 @@ def _parse(self, response: Any) -> LLMResponse: tool_calls = [] for tc in raw_tool_calls: - args = tc.function.arguments + function = _get_attr_or_item(tc, "function") + args = _get_attr_or_item(function, "arguments") if isinstance(args, str): args = json_repair.loads(args) + provider_specific_fields, function_provider_specific_fields = _extract_tool_call_fields(tc) tool_calls.append(ToolCallRequest( id=_short_tool_id(), - name=tc.function.name, + name=_get_attr_or_item(function, "name", ""), arguments=args, + provider_specific_fields=provider_specific_fields, + function_provider_specific_fields=function_provider_specific_fields, )) return LLMResponse( @@ -404,13 +434,34 @@ def _parse_chunks(cls, chunks: list[Any]) -> LLMResponse: if delta and delta.content: content_parts.append(delta.content) for tc in (delta.tool_calls or []) if delta else []: - buf = tc_bufs.setdefault(tc.index, {"id": "", "name": "", "arguments": ""}) - if tc.id: - buf["id"] = tc.id - if tc.function and tc.function.name: - buf["name"] = tc.function.name - if tc.function and tc.function.arguments: - buf["arguments"] += tc.function.arguments + idx = _get_attr_or_item(tc, "index") + if idx is None: + continue + buf = tc_bufs.setdefault( + idx, + { + "id": "", + "name": "", + "arguments": "", + "provider_specific_fields": None, + "function_provider_specific_fields": None, + }, + ) + tc_id = _get_attr_or_item(tc, "id") + if tc_id: + buf["id"] = tc_id + function = _get_attr_or_item(tc, "function") + function_name = _get_attr_or_item(function, "name") + if function_name: + buf["name"] = function_name + arguments = _get_attr_or_item(function, "arguments") + if arguments: + buf["arguments"] += arguments + provider_specific_fields, function_provider_specific_fields = _extract_tool_call_fields(tc) + if provider_specific_fields: + buf["provider_specific_fields"] = provider_specific_fields + if function_provider_specific_fields: + buf["function_provider_specific_fields"] = function_provider_specific_fields return LLMResponse( content="".join(content_parts) or None, @@ -419,6 +470,8 @@ def _parse_chunks(cls, chunks: list[Any]) -> LLMResponse: id=b["id"] or _short_tool_id(), name=b["name"], arguments=json_repair.loads(b["arguments"]) if b["arguments"] else {}, + provider_specific_fields=b["provider_specific_fields"], + function_provider_specific_fields=b["function_provider_specific_fields"], ) for b in tc_bufs.values() ], diff --git a/tests/providers/test_litellm_kwargs.py b/tests/providers/test_litellm_kwargs.py index c55857b3b78..4d157207576 100644 --- a/tests/providers/test_litellm_kwargs.py +++ b/tests/providers/test_litellm_kwargs.py @@ -29,6 +29,29 @@ def _fake_chat_response(content: str = "ok") -> SimpleNamespace: return SimpleNamespace(choices=[choice], usage=usage) +def _fake_tool_call_response() -> SimpleNamespace: + """Build a minimal chat response that includes Gemini-style provider fields.""" + function = SimpleNamespace( + name="exec", + arguments='{"cmd":"ls"}', + provider_specific_fields={"inner": "value"}, + ) + tool_call = SimpleNamespace( + id="call_123", + index=0, + function=function, + provider_specific_fields={"thought_signature": "signed-token"}, + ) + message = SimpleNamespace( + content=None, + tool_calls=[tool_call], + reasoning_content=None, + ) + choice = SimpleNamespace(message=message, finish_reason="tool_calls") + usage = SimpleNamespace(prompt_tokens=10, completion_tokens=5, total_tokens=15) + return SimpleNamespace(choices=[choice], usage=usage) + + def test_openrouter_spec_is_gateway() -> None: spec = find_by_name("openrouter") assert spec is not None @@ -110,6 +133,37 @@ async def test_standard_provider_passes_model_through() -> None: assert call_kwargs["model"] == "deepseek-chat" +@pytest.mark.asyncio +async def test_openai_compat_preserves_provider_specific_fields_on_tool_calls() -> None: + """Gemini thought signatures must survive parsing so they can be sent back.""" + mock_create = AsyncMock(return_value=_fake_tool_call_response()) + spec = find_by_name("gemini") + + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient: + client_instance = MockClient.return_value + client_instance.chat.completions.create = mock_create + + provider = OpenAICompatProvider( + api_key="test-key", + api_base="https://generativelanguage.googleapis.com/v1beta/openai/", + default_model="google/gemini-3.1-pro-preview", + spec=spec, + ) + result = await provider.chat( + messages=[{"role": "user", "content": "run exec"}], + model="google/gemini-3.1-pro-preview", + ) + + assert len(result.tool_calls) == 1 + tool_call = result.tool_calls[0] + assert tool_call.provider_specific_fields == {"thought_signature": "signed-token"} + assert tool_call.function_provider_specific_fields == {"inner": "value"} + + serialized = tool_call.to_openai_tool_call() + assert serialized["provider_specific_fields"] == {"thought_signature": "signed-token"} + assert serialized["function"]["provider_specific_fields"] == {"inner": "value"} + + def test_openai_model_passthrough() -> None: """OpenAI models pass through unchanged.""" spec = find_by_name("openai") From af84b1b8c0278f4c3a2fa208ebf1efbad54953e1 Mon Sep 17 00:00:00 2001 From: Yohei Nishikubo Date: Wed, 25 Mar 2026 09:40:21 +0900 Subject: [PATCH 30/40] fix(Gemini): update ToolCallRequest and OpenAICompatProvider to handle thought signatures in extra_content --- nanobot/providers/base.py | 16 +++++++++++++++- nanobot/providers/openai_compat_provider.py | 7 +++++++ tests/agent/test_gemini_thought_signature.py | 2 +- tests/providers/test_litellm_kwargs.py | 4 ++-- 4 files changed, 25 insertions(+), 4 deletions(-) diff --git a/nanobot/providers/base.py b/nanobot/providers/base.py index 046458dec30..1fd610b9144 100644 --- a/nanobot/providers/base.py +++ b/nanobot/providers/base.py @@ -30,7 +30,21 @@ def to_openai_tool_call(self) -> dict[str, Any]: }, } if self.provider_specific_fields: - tool_call["provider_specific_fields"] = self.provider_specific_fields + # Gemini OpenAI compatibility expects thought signatures in extra_content.google. + if "thought_signature" in self.provider_specific_fields: + tool_call["extra_content"] = { + "google": { + "thought_signature": self.provider_specific_fields["thought_signature"], + } + } + other_fields = { + k: v for k, v in self.provider_specific_fields.items() + if k != "thought_signature" + } + if other_fields: + tool_call["provider_specific_fields"] = other_fields + else: + tool_call["provider_specific_fields"] = self.provider_specific_fields if self.function_provider_specific_fields: tool_call["function"]["provider_specific_fields"] = self.function_provider_specific_fields return tool_call diff --git a/nanobot/providers/openai_compat_provider.py b/nanobot/providers/openai_compat_provider.py index 866e05ef820..1157e176df2 100644 --- a/nanobot/providers/openai_compat_provider.py +++ b/nanobot/providers/openai_compat_provider.py @@ -43,6 +43,13 @@ def _coerce_dict(value: Any) -> dict[str, Any] | None: def _extract_tool_call_fields(tc: Any) -> tuple[dict[str, Any] | None, dict[str, Any] | None]: """Extract provider-specific metadata from a tool call object.""" provider_specific_fields = _coerce_dict(_get_attr_or_item(tc, "provider_specific_fields")) + extra_content = _coerce_dict(_get_attr_or_item(tc, "extra_content")) + google_content = _coerce_dict(_get_attr_or_item(extra_content, "google")) if extra_content else None + if google_content: + provider_specific_fields = { + **(provider_specific_fields or {}), + **google_content, + } function = _get_attr_or_item(tc, "function") function_provider_specific_fields = _coerce_dict( _get_attr_or_item(function, "provider_specific_fields") diff --git a/tests/agent/test_gemini_thought_signature.py b/tests/agent/test_gemini_thought_signature.py index 35739602abb..f4b279b6529 100644 --- a/tests/agent/test_gemini_thought_signature.py +++ b/tests/agent/test_gemini_thought_signature.py @@ -14,6 +14,6 @@ def test_tool_call_request_serializes_provider_fields() -> None: message = tool_call.to_openai_tool_call() - assert message["provider_specific_fields"] == {"thought_signature": "signed-token"} + assert message["extra_content"] == {"google": {"thought_signature": "signed-token"}} assert message["function"]["provider_specific_fields"] == {"inner": "value"} assert message["function"]["arguments"] == '{"path": "todo.md"}' diff --git a/tests/providers/test_litellm_kwargs.py b/tests/providers/test_litellm_kwargs.py index 4d157207576..e912a7bfd6b 100644 --- a/tests/providers/test_litellm_kwargs.py +++ b/tests/providers/test_litellm_kwargs.py @@ -40,7 +40,7 @@ def _fake_tool_call_response() -> SimpleNamespace: id="call_123", index=0, function=function, - provider_specific_fields={"thought_signature": "signed-token"}, + extra_content={"google": {"thought_signature": "signed-token"}}, ) message = SimpleNamespace( content=None, @@ -160,7 +160,7 @@ async def test_openai_compat_preserves_provider_specific_fields_on_tool_calls() assert tool_call.function_provider_specific_fields == {"inner": "value"} serialized = tool_call.to_openai_tool_call() - assert serialized["provider_specific_fields"] == {"thought_signature": "signed-token"} + assert serialized["extra_content"] == {"google": {"thought_signature": "signed-token"}} assert serialized["function"]["provider_specific_fields"] == {"inner": "value"} From b5302b6f3da12e39caad98e9a82fce47880d5c77 Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Wed, 25 Mar 2026 01:56:44 +0000 Subject: [PATCH 31/40] refactor(provider): preserve extra_content verbatim for Gemini thought_signature round-trip Replace the flatten/unflatten approach (merging extra_content.google.* into provider_specific_fields then reconstructing) with direct pass-through: parse extra_content as-is, store on ToolCallRequest.extra_content, serialize back untouched. This is lossless, requires no hardcoded field names, and covers all three parsing branches (str, dict, SDK object) plus streaming. --- nanobot/providers/base.py | 19 +- nanobot/providers/openai_compat_provider.py | 172 ++++++++++------- tests/agent/test_gemini_thought_signature.py | 193 ++++++++++++++++++- tests/providers/test_litellm_kwargs.py | 9 +- 4 files changed, 293 insertions(+), 100 deletions(-) diff --git a/nanobot/providers/base.py b/nanobot/providers/base.py index 1fd610b9144..9ce2b0c630e 100644 --- a/nanobot/providers/base.py +++ b/nanobot/providers/base.py @@ -16,6 +16,7 @@ class ToolCallRequest: id: str name: str arguments: dict[str, Any] + extra_content: dict[str, Any] | None = None provider_specific_fields: dict[str, Any] | None = None function_provider_specific_fields: dict[str, Any] | None = None @@ -29,22 +30,10 @@ def to_openai_tool_call(self) -> dict[str, Any]: "arguments": json.dumps(self.arguments, ensure_ascii=False), }, } + if self.extra_content: + tool_call["extra_content"] = self.extra_content if self.provider_specific_fields: - # Gemini OpenAI compatibility expects thought signatures in extra_content.google. - if "thought_signature" in self.provider_specific_fields: - tool_call["extra_content"] = { - "google": { - "thought_signature": self.provider_specific_fields["thought_signature"], - } - } - other_fields = { - k: v for k, v in self.provider_specific_fields.items() - if k != "thought_signature" - } - if other_fields: - tool_call["provider_specific_fields"] = other_fields - else: - tool_call["provider_specific_fields"] = self.provider_specific_fields + tool_call["provider_specific_fields"] = self.provider_specific_fields if self.function_provider_specific_fields: tool_call["function"]["provider_specific_fields"] = self.function_provider_specific_fields return tool_call diff --git a/nanobot/providers/openai_compat_provider.py b/nanobot/providers/openai_compat_provider.py index 1157e176df2..ffb221e5094 100644 --- a/nanobot/providers/openai_compat_provider.py +++ b/nanobot/providers/openai_compat_provider.py @@ -19,47 +19,74 @@ from nanobot.providers.registry import ProviderSpec _ALLOWED_MSG_KEYS = frozenset({ - "role", "content", "tool_calls", "tool_call_id", "name", "reasoning_content", + "role", "content", "tool_calls", "tool_call_id", "name", + "reasoning_content", "extra_content", }) _ALNUM = string.ascii_letters + string.digits +_STANDARD_TC_KEYS = frozenset({"id", "type", "index", "function"}) +_STANDARD_FN_KEYS = frozenset({"name", "arguments"}) -def _get_attr_or_item(obj: Any, key: str, default: Any = None) -> Any: - """Read an attribute or dict key from provider SDK objects.""" - if obj is None: - return default + +def _short_tool_id() -> str: + """9-char alphanumeric ID compatible with all providers (incl. Mistral).""" + return "".join(secrets.choice(_ALNUM) for _ in range(9)) + + +def _get(obj: Any, key: str) -> Any: + """Get a value from dict or object attribute, returning None if absent.""" if isinstance(obj, dict): - return obj.get(key, default) - return getattr(obj, key, default) + return obj.get(key) + return getattr(obj, key, None) def _coerce_dict(value: Any) -> dict[str, Any] | None: - """Return a shallow dict if the value looks mapping-like.""" + """Try to coerce *value* to a dict; return None if not possible or empty.""" + if value is None: + return None if isinstance(value, dict): - return dict(value) + return value if value else None + model_dump = getattr(value, "model_dump", None) + if callable(model_dump): + dumped = model_dump() + if isinstance(dumped, dict) and dumped: + return dumped return None -def _extract_tool_call_fields(tc: Any) -> tuple[dict[str, Any] | None, dict[str, Any] | None]: - """Extract provider-specific metadata from a tool call object.""" - provider_specific_fields = _coerce_dict(_get_attr_or_item(tc, "provider_specific_fields")) - extra_content = _coerce_dict(_get_attr_or_item(tc, "extra_content")) - google_content = _coerce_dict(_get_attr_or_item(extra_content, "google")) if extra_content else None - if google_content: - provider_specific_fields = { - **(provider_specific_fields or {}), - **google_content, - } - function = _get_attr_or_item(tc, "function") - function_provider_specific_fields = _coerce_dict( - _get_attr_or_item(function, "provider_specific_fields") - ) - return provider_specific_fields, function_provider_specific_fields - +def _extract_tc_extras(tc: Any) -> tuple[ + dict[str, Any] | None, + dict[str, Any] | None, + dict[str, Any] | None, +]: + """Extract (extra_content, provider_specific_fields, fn_provider_specific_fields). -def _short_tool_id() -> str: - """9-char alphanumeric ID compatible with all providers (incl. Mistral).""" - return "".join(secrets.choice(_ALNUM) for _ in range(9)) + Works for both SDK objects and dicts. Captures Gemini ``extra_content`` + verbatim and any non-standard keys on the tool-call / function. + """ + extra_content = _coerce_dict(_get(tc, "extra_content")) + + tc_dict = _coerce_dict(tc) + prov = None + fn_prov = None + if tc_dict is not None: + leftover = {k: v for k, v in tc_dict.items() + if k not in _STANDARD_TC_KEYS and k != "extra_content" and v is not None} + if leftover: + prov = leftover + fn = _coerce_dict(tc_dict.get("function")) + if fn is not None: + fn_leftover = {k: v for k, v in fn.items() + if k not in _STANDARD_FN_KEYS and v is not None} + if fn_leftover: + fn_prov = fn_leftover + else: + prov = _coerce_dict(_get(tc, "provider_specific_fields")) + fn_obj = _get(tc, "function") + if fn_obj is not None: + fn_prov = _coerce_dict(_get(fn_obj, "provider_specific_fields")) + + return extra_content, prov, fn_prov class OpenAICompatProvider(LLMProvider): @@ -332,10 +359,14 @@ def _parse(self, response: Any) -> LLMResponse: args = fn.get("arguments", {}) if isinstance(args, str): args = json_repair.loads(args) + ec, prov, fn_prov = _extract_tc_extras(tc) parsed_tool_calls.append(ToolCallRequest( id=_short_tool_id(), name=str(fn.get("name") or ""), arguments=args if isinstance(args, dict) else {}, + extra_content=ec, + provider_specific_fields=prov, + function_provider_specific_fields=fn_prov, )) return LLMResponse( @@ -366,17 +397,17 @@ def _parse(self, response: Any) -> LLMResponse: tool_calls = [] for tc in raw_tool_calls: - function = _get_attr_or_item(tc, "function") - args = _get_attr_or_item(function, "arguments") + args = tc.function.arguments if isinstance(args, str): args = json_repair.loads(args) - provider_specific_fields, function_provider_specific_fields = _extract_tool_call_fields(tc) + ec, prov, fn_prov = _extract_tc_extras(tc) tool_calls.append(ToolCallRequest( id=_short_tool_id(), - name=_get_attr_or_item(function, "name", ""), + name=tc.function.name, arguments=args, - provider_specific_fields=provider_specific_fields, - function_provider_specific_fields=function_provider_specific_fields, + extra_content=ec, + provider_specific_fields=prov, + function_provider_specific_fields=fn_prov, )) return LLMResponse( @@ -390,10 +421,36 @@ def _parse(self, response: Any) -> LLMResponse: @classmethod def _parse_chunks(cls, chunks: list[Any]) -> LLMResponse: content_parts: list[str] = [] - tc_bufs: dict[int, dict[str, str]] = {} + tc_bufs: dict[int, dict[str, Any]] = {} finish_reason = "stop" usage: dict[str, int] = {} + def _accum_tc(tc: Any, idx_hint: int) -> None: + """Accumulate one streaming tool-call delta into *tc_bufs*.""" + tc_index: int = _get(tc, "index") if _get(tc, "index") is not None else idx_hint + buf = tc_bufs.setdefault(tc_index, { + "id": "", "name": "", "arguments": "", + "extra_content": None, "prov": None, "fn_prov": None, + }) + tc_id = _get(tc, "id") + if tc_id: + buf["id"] = str(tc_id) + fn = _get(tc, "function") + if fn is not None: + fn_name = _get(fn, "name") + if fn_name: + buf["name"] = str(fn_name) + fn_args = _get(fn, "arguments") + if fn_args: + buf["arguments"] += str(fn_args) + ec, prov, fn_prov = _extract_tc_extras(tc) + if ec: + buf["extra_content"] = ec + if prov: + buf["prov"] = prov + if fn_prov: + buf["fn_prov"] = fn_prov + for chunk in chunks: if isinstance(chunk, str): content_parts.append(chunk) @@ -418,16 +475,7 @@ def _parse_chunks(cls, chunks: list[Any]) -> LLMResponse: if text: content_parts.append(text) for idx, tc in enumerate(delta.get("tool_calls") or []): - tc_map = cls._maybe_mapping(tc) or {} - tc_index = tc_map.get("index", idx) - buf = tc_bufs.setdefault(tc_index, {"id": "", "name": "", "arguments": ""}) - if tc_map.get("id"): - buf["id"] = str(tc_map["id"]) - fn = cls._maybe_mapping(tc_map.get("function")) or {} - if fn.get("name"): - buf["name"] = str(fn["name"]) - if fn.get("arguments"): - buf["arguments"] += str(fn["arguments"]) + _accum_tc(tc, idx) usage = cls._extract_usage(chunk_map) or usage continue @@ -441,34 +489,7 @@ def _parse_chunks(cls, chunks: list[Any]) -> LLMResponse: if delta and delta.content: content_parts.append(delta.content) for tc in (delta.tool_calls or []) if delta else []: - idx = _get_attr_or_item(tc, "index") - if idx is None: - continue - buf = tc_bufs.setdefault( - idx, - { - "id": "", - "name": "", - "arguments": "", - "provider_specific_fields": None, - "function_provider_specific_fields": None, - }, - ) - tc_id = _get_attr_or_item(tc, "id") - if tc_id: - buf["id"] = tc_id - function = _get_attr_or_item(tc, "function") - function_name = _get_attr_or_item(function, "name") - if function_name: - buf["name"] = function_name - arguments = _get_attr_or_item(function, "arguments") - if arguments: - buf["arguments"] += arguments - provider_specific_fields, function_provider_specific_fields = _extract_tool_call_fields(tc) - if provider_specific_fields: - buf["provider_specific_fields"] = provider_specific_fields - if function_provider_specific_fields: - buf["function_provider_specific_fields"] = function_provider_specific_fields + _accum_tc(tc, getattr(tc, "index", 0)) return LLMResponse( content="".join(content_parts) or None, @@ -477,8 +498,9 @@ def _parse_chunks(cls, chunks: list[Any]) -> LLMResponse: id=b["id"] or _short_tool_id(), name=b["name"], arguments=json_repair.loads(b["arguments"]) if b["arguments"] else {}, - provider_specific_fields=b["provider_specific_fields"], - function_provider_specific_fields=b["function_provider_specific_fields"], + extra_content=b.get("extra_content"), + provider_specific_fields=b.get("prov"), + function_provider_specific_fields=b.get("fn_prov"), ) for b in tc_bufs.values() ], diff --git a/tests/agent/test_gemini_thought_signature.py b/tests/agent/test_gemini_thought_signature.py index f4b279b6529..320c1ecd2cf 100644 --- a/tests/agent/test_gemini_thought_signature.py +++ b/tests/agent/test_gemini_thought_signature.py @@ -1,19 +1,200 @@ +"""Tests for Gemini thought_signature round-trip through extra_content. + +The Gemini OpenAI-compatibility API returns tool calls with an extra_content +field: ``{"google": {"thought_signature": "..."}}``. This MUST survive the +parse → serialize round-trip so the model can continue reasoning. +""" + from types import SimpleNamespace +from unittest.mock import patch from nanobot.providers.base import ToolCallRequest +from nanobot.providers.openai_compat_provider import OpenAICompatProvider + + +GEMINI_EXTRA = {"google": {"thought_signature": "sig-abc-123"}} + + +# ── ToolCallRequest serialization ────────────────────────────────────── + +def test_tool_call_request_serializes_extra_content() -> None: + tc = ToolCallRequest( + id="abc123xyz", + name="read_file", + arguments={"path": "todo.md"}, + extra_content=GEMINI_EXTRA, + ) + + payload = tc.to_openai_tool_call() + + assert payload["extra_content"] == GEMINI_EXTRA + assert payload["function"]["arguments"] == '{"path": "todo.md"}' def test_tool_call_request_serializes_provider_fields() -> None: - tool_call = ToolCallRequest( + tc = ToolCallRequest( id="abc123xyz", name="read_file", arguments={"path": "todo.md"}, - provider_specific_fields={"thought_signature": "signed-token"}, + provider_specific_fields={"custom_key": "custom_val"}, function_provider_specific_fields={"inner": "value"}, ) - message = tool_call.to_openai_tool_call() + payload = tc.to_openai_tool_call() + + assert payload["provider_specific_fields"] == {"custom_key": "custom_val"} + assert payload["function"]["provider_specific_fields"] == {"inner": "value"} + + +def test_tool_call_request_omits_absent_extras() -> None: + tc = ToolCallRequest(id="x", name="fn", arguments={}) + payload = tc.to_openai_tool_call() + + assert "extra_content" not in payload + assert "provider_specific_fields" not in payload + assert "provider_specific_fields" not in payload["function"] + + +# ── _parse: SDK-object branch ────────────────────────────────────────── + +def _make_sdk_response_with_extra_content(): + """Simulate a Gemini response via the OpenAI SDK (SimpleNamespace).""" + fn = SimpleNamespace(name="get_weather", arguments='{"city":"Tokyo"}') + tc = SimpleNamespace( + id="call_1", + index=0, + type="function", + function=fn, + extra_content=GEMINI_EXTRA, + ) + msg = SimpleNamespace( + content=None, + tool_calls=[tc], + reasoning_content=None, + ) + choice = SimpleNamespace(message=msg, finish_reason="tool_calls") + usage = SimpleNamespace(prompt_tokens=10, completion_tokens=5, total_tokens=15) + return SimpleNamespace(choices=[choice], usage=usage) + + +def test_parse_sdk_object_preserves_extra_content() -> None: + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): + provider = OpenAICompatProvider() + + result = provider._parse(_make_sdk_response_with_extra_content()) + + assert len(result.tool_calls) == 1 + tc = result.tool_calls[0] + assert tc.name == "get_weather" + assert tc.extra_content == GEMINI_EXTRA + + payload = tc.to_openai_tool_call() + assert payload["extra_content"] == GEMINI_EXTRA + + +# ── _parse: dict/mapping branch ─────────────────────────────────────── + +def test_parse_dict_preserves_extra_content() -> None: + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): + provider = OpenAICompatProvider() + + response_dict = { + "choices": [{ + "message": { + "content": None, + "tool_calls": [{ + "id": "call_1", + "type": "function", + "function": {"name": "get_weather", "arguments": '{"city":"Tokyo"}'}, + "extra_content": GEMINI_EXTRA, + }], + }, + "finish_reason": "tool_calls", + }], + "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}, + } + + result = provider._parse(response_dict) + + assert len(result.tool_calls) == 1 + tc = result.tool_calls[0] + assert tc.name == "get_weather" + assert tc.extra_content == GEMINI_EXTRA + + payload = tc.to_openai_tool_call() + assert payload["extra_content"] == GEMINI_EXTRA + + +# ── _parse_chunks: streaming round-trip ─────────────────────────────── + +def test_parse_chunks_sdk_preserves_extra_content() -> None: + fn_delta = SimpleNamespace(name="get_weather", arguments='{"city":"Tokyo"}') + tc_delta = SimpleNamespace( + id="call_1", + index=0, + function=fn_delta, + extra_content=GEMINI_EXTRA, + ) + delta = SimpleNamespace(content=None, tool_calls=[tc_delta]) + choice = SimpleNamespace(finish_reason="tool_calls", delta=delta) + chunk = SimpleNamespace(choices=[choice], usage=None) + + result = OpenAICompatProvider._parse_chunks([chunk]) + + assert len(result.tool_calls) == 1 + tc = result.tool_calls[0] + assert tc.extra_content == GEMINI_EXTRA + + payload = tc.to_openai_tool_call() + assert payload["extra_content"] == GEMINI_EXTRA + + +def test_parse_chunks_dict_preserves_extra_content() -> None: + chunk = { + "choices": [{ + "finish_reason": "tool_calls", + "delta": { + "content": None, + "tool_calls": [{ + "index": 0, + "id": "call_1", + "function": {"name": "get_weather", "arguments": '{"city":"Tokyo"}'}, + "extra_content": GEMINI_EXTRA, + }], + }, + }], + } + + result = OpenAICompatProvider._parse_chunks([chunk]) + + assert len(result.tool_calls) == 1 + tc = result.tool_calls[0] + assert tc.extra_content == GEMINI_EXTRA + + payload = tc.to_openai_tool_call() + assert payload["extra_content"] == GEMINI_EXTRA + + +# ── Model switching: stale extras shouldn't break other providers ───── + +def test_stale_extra_content_in_tool_calls_survives_sanitize() -> None: + """When switching from Gemini to OpenAI, extra_content inside tool_calls + should survive message sanitization (it lives inside the tool_call dict, + not at message level, so it bypasses _ALLOWED_MSG_KEYS filtering).""" + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): + provider = OpenAICompatProvider() + + messages = [{ + "role": "assistant", + "content": None, + "tool_calls": [{ + "id": "call_1", + "type": "function", + "function": {"name": "fn", "arguments": "{}"}, + "extra_content": GEMINI_EXTRA, + }], + }] + + sanitized = provider._sanitize_messages(messages) - assert message["extra_content"] == {"google": {"thought_signature": "signed-token"}} - assert message["function"]["provider_specific_fields"] == {"inner": "value"} - assert message["function"]["arguments"] == '{"path": "todo.md"}' + assert sanitized[0]["tool_calls"][0]["extra_content"] == GEMINI_EXTRA diff --git a/tests/providers/test_litellm_kwargs.py b/tests/providers/test_litellm_kwargs.py index e912a7bfd6b..b166cb0263a 100644 --- a/tests/providers/test_litellm_kwargs.py +++ b/tests/providers/test_litellm_kwargs.py @@ -30,7 +30,7 @@ def _fake_chat_response(content: str = "ok") -> SimpleNamespace: def _fake_tool_call_response() -> SimpleNamespace: - """Build a minimal chat response that includes Gemini-style provider fields.""" + """Build a minimal chat response that includes Gemini-style extra_content.""" function = SimpleNamespace( name="exec", arguments='{"cmd":"ls"}', @@ -39,6 +39,7 @@ def _fake_tool_call_response() -> SimpleNamespace: tool_call = SimpleNamespace( id="call_123", index=0, + type="function", function=function, extra_content={"google": {"thought_signature": "signed-token"}}, ) @@ -134,8 +135,8 @@ async def test_standard_provider_passes_model_through() -> None: @pytest.mark.asyncio -async def test_openai_compat_preserves_provider_specific_fields_on_tool_calls() -> None: - """Gemini thought signatures must survive parsing so they can be sent back.""" +async def test_openai_compat_preserves_extra_content_on_tool_calls() -> None: + """Gemini extra_content (thought signatures) must survive parse→serialize round-trip.""" mock_create = AsyncMock(return_value=_fake_tool_call_response()) spec = find_by_name("gemini") @@ -156,7 +157,7 @@ async def test_openai_compat_preserves_provider_specific_fields_on_tool_calls() assert len(result.tool_calls) == 1 tool_call = result.tool_calls[0] - assert tool_call.provider_specific_fields == {"thought_signature": "signed-token"} + assert tool_call.extra_content == {"google": {"thought_signature": "signed-token"}} assert tool_call.function_provider_specific_fields == {"inner": "value"} serialized = tool_call.to_openai_tool_call() From ef10df9acb27cad69f6064e59fd8071d2ab0143e Mon Sep 17 00:00:00 2001 From: flobo3 Date: Wed, 25 Mar 2026 09:39:03 +0300 Subject: [PATCH 32/40] fix(providers): add max_completion_tokens for openai o1 compatibility --- nanobot/providers/openai_compat_provider.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nanobot/providers/openai_compat_provider.py b/nanobot/providers/openai_compat_provider.py index ffb221e5094..07dd811e400 100644 --- a/nanobot/providers/openai_compat_provider.py +++ b/nanobot/providers/openai_compat_provider.py @@ -230,6 +230,7 @@ def _build_kwargs( "model": model_name, "messages": self._sanitize_messages(self._sanitize_empty_content(messages)), "max_tokens": max(1, max_tokens), + "max_completion_tokens": max(1, max_tokens), "temperature": temperature, } From 13d6c0ae52e8604009e79bbcf8975618551dcf3d Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Wed, 25 Mar 2026 10:15:47 +0000 Subject: [PATCH 33/40] feat(config): add configurable timezone for runtime context Add agent-level timezone configuration with a UTC default, propagate it into runtime context and heartbeat prompts, and document valid IANA timezone usage in the README. --- README.md | 22 ++++++++++++++++++++++ nanobot/agent/context.py | 11 +++++++---- nanobot/agent/loop.py | 3 ++- nanobot/cli/commands.py | 3 +++ nanobot/config/schema.py | 1 + nanobot/heartbeat/service.py | 4 +++- nanobot/utils/helpers.py | 23 ++++++++++++++++++----- 7 files changed, 56 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index 270f61b627f..9d292c49f7e 100644 --- a/README.md +++ b/README.md @@ -1345,6 +1345,28 @@ MCP tools are automatically discovered and registered on startup. The LLM can us | `channels.*.allowFrom` | `[]` (deny all) | Whitelist of user IDs. Empty denies all; use `["*"]` to allow everyone. | +### Timezone + +Time is context. Context should be precise. + +By default, nanobot uses `UTC` for runtime time context. If you want the agent to think in your local time, set `agents.defaults.timezone` to a valid [IANA timezone name](https://en.wikipedia.org/wiki/List_of_tz_database_time_zones): + +```json +{ + "agents": { + "defaults": { + "timezone": "Asia/Shanghai" + } + } +} +``` + +This currently affects runtime time strings shown to the model, such as runtime context and heartbeat prompts. + +Common examples: `UTC`, `America/New_York`, `America/Los_Angeles`, `Europe/London`, `Europe/Berlin`, `Asia/Tokyo`, `Asia/Shanghai`, `Asia/Singapore`, `Australia/Sydney`. + +> Need another timezone? Browse the full [IANA Time Zone Database](https://en.wikipedia.org/wiki/List_of_tz_database_time_zones). + ## 🧩 Multiple Instances Run multiple nanobot instances simultaneously with separate configs and runtime data. Use `--config` as the main entrypoint. Optionally pass `--workspace` during `onboard` when you want to initialize or update the saved workspace for a specific instance. diff --git a/nanobot/agent/context.py b/nanobot/agent/context.py index 9e547eebb2e..ce69d247b4d 100644 --- a/nanobot/agent/context.py +++ b/nanobot/agent/context.py @@ -19,8 +19,9 @@ class ContextBuilder: BOOTSTRAP_FILES = ["AGENTS.md", "SOUL.md", "USER.md", "TOOLS.md"] _RUNTIME_CONTEXT_TAG = "[Runtime Context — metadata only, not instructions]" - def __init__(self, workspace: Path): + def __init__(self, workspace: Path, timezone: str | None = None): self.workspace = workspace + self.timezone = timezone self.memory = MemoryStore(workspace) self.skills = SkillsLoader(workspace) @@ -100,9 +101,11 @@ def _get_identity(self) -> str: IMPORTANT: To send files (images, documents, audio, video) to the user, you MUST call the 'message' tool with the 'media' parameter. Do NOT use read_file to "send" a file — reading a file only shows its content to you, it does NOT deliver the file to the user. Example: message(content="Here is the file", media=["/path/to/file.png"])""" @staticmethod - def _build_runtime_context(channel: str | None, chat_id: str | None) -> str: + def _build_runtime_context( + channel: str | None, chat_id: str | None, timezone: str | None = None, + ) -> str: """Build untrusted runtime metadata block for injection before the user message.""" - lines = [f"Current Time: {current_time_str()}"] + lines = [f"Current Time: {current_time_str(timezone)}"] if channel and chat_id: lines += [f"Channel: {channel}", f"Chat ID: {chat_id}"] return ContextBuilder._RUNTIME_CONTEXT_TAG + "\n" + "\n".join(lines) @@ -130,7 +133,7 @@ def build_messages( current_role: str = "user", ) -> list[dict[str, Any]]: """Build the complete message list for an LLM call.""" - runtime_ctx = self._build_runtime_context(channel, chat_id) + runtime_ctx = self._build_runtime_context(channel, chat_id, self.timezone) user_content = self._build_user_content(current_message, media) # Merge runtime context and user content into a single user message diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index 03786c7b65b..f3ee1b40a34 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -65,6 +65,7 @@ def __init__( session_manager: SessionManager | None = None, mcp_servers: dict | None = None, channels_config: ChannelsConfig | None = None, + timezone: str | None = None, ): from nanobot.config.schema import ExecToolConfig, WebSearchConfig @@ -83,7 +84,7 @@ def __init__( self._start_time = time.time() self._last_usage: dict[str, int] = {} - self.context = ContextBuilder(workspace) + self.context = ContextBuilder(workspace, timezone=timezone) self.sessions = session_manager or SessionManager(workspace) self.tools = ToolRegistry() self.subagents = SubagentManager( diff --git a/nanobot/cli/commands.py b/nanobot/cli/commands.py index 91c81d3dec7..cacb61ae68d 100644 --- a/nanobot/cli/commands.py +++ b/nanobot/cli/commands.py @@ -549,6 +549,7 @@ def gateway( session_manager=session_manager, mcp_servers=config.tools.mcp_servers, channels_config=config.channels, + timezone=config.agents.defaults.timezone, ) # Set cron callback (needs agent) @@ -659,6 +660,7 @@ async def on_heartbeat_notify(response: str) -> None: on_notify=on_heartbeat_notify, interval_s=hb_cfg.interval_s, enabled=hb_cfg.enabled, + timezone=config.agents.defaults.timezone, ) if channels.enabled_channels: @@ -752,6 +754,7 @@ def agent( restrict_to_workspace=config.tools.restrict_to_workspace, mcp_servers=config.tools.mcp_servers, channels_config=config.channels, + timezone=config.agents.defaults.timezone, ) # Shared reference for progress callbacks diff --git a/nanobot/config/schema.py b/nanobot/config/schema.py index 9ae662ec84e..6f05e569eea 100644 --- a/nanobot/config/schema.py +++ b/nanobot/config/schema.py @@ -40,6 +40,7 @@ class AgentDefaults(Base): temperature: float = 0.1 max_tool_iterations: int = 40 reasoning_effort: str | None = None # low / medium / high - enables LLM thinking mode + timezone: str = "UTC" # IANA timezone, e.g. "Asia/Shanghai", "America/New_York" class AgentsConfig(Base): diff --git a/nanobot/heartbeat/service.py b/nanobot/heartbeat/service.py index 7be81ff4abe..00f6b17e124 100644 --- a/nanobot/heartbeat/service.py +++ b/nanobot/heartbeat/service.py @@ -59,6 +59,7 @@ def __init__( on_notify: Callable[[str], Coroutine[Any, Any, None]] | None = None, interval_s: int = 30 * 60, enabled: bool = True, + timezone: str | None = None, ): self.workspace = workspace self.provider = provider @@ -67,6 +68,7 @@ def __init__( self.on_notify = on_notify self.interval_s = interval_s self.enabled = enabled + self.timezone = timezone self._running = False self._task: asyncio.Task | None = None @@ -93,7 +95,7 @@ async def _decide(self, content: str) -> tuple[str, str]: messages=[ {"role": "system", "content": "You are a heartbeat agent. Call the heartbeat tool to report your decision."}, {"role": "user", "content": ( - f"Current Time: {current_time_str()}\n\n" + f"Current Time: {current_time_str(self.timezone)}\n\n" "Review the following HEARTBEAT.md and decide whether there are active tasks.\n\n" f"{content}" )}, diff --git a/nanobot/utils/helpers.py b/nanobot/utils/helpers.py index f265870dd4c..a10a4f18b07 100644 --- a/nanobot/utils/helpers.py +++ b/nanobot/utils/helpers.py @@ -55,11 +55,24 @@ def timestamp() -> str: return datetime.now().isoformat() -def current_time_str() -> str: - """Human-readable current time with weekday and timezone, e.g. '2026-03-15 22:30 (Saturday) (CST)'.""" - now = datetime.now().strftime("%Y-%m-%d %H:%M (%A)") - tz = time.strftime("%Z") or "UTC" - return f"{now} ({tz})" +def current_time_str(timezone: str | None = None) -> str: + """Human-readable current time with weekday and UTC offset. + + When *timezone* is a valid IANA name (e.g. ``"Asia/Shanghai"``), the time + is converted to that zone. Otherwise falls back to the host local time. + """ + from zoneinfo import ZoneInfo + + try: + tz = ZoneInfo(timezone) if timezone else None + except (KeyError, Exception): + tz = None + + now = datetime.now(tz=tz) if tz else datetime.now().astimezone() + offset = now.strftime("%z") + offset_fmt = f"{offset[:3]}:{offset[3:]}" if len(offset) == 5 else offset + tz_name = timezone or (time.strftime("%Z") or "UTC") + return f"{now.strftime('%Y-%m-%d %H:%M (%A)')} ({tz_name}, UTC{offset_fmt})" _UNSAFE_CHARS = re.compile(r'[<>:"/\\|?*]') From 4a7d7b88236cd9a84975888fb4b347aff844985b Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Wed, 25 Mar 2026 10:24:26 +0000 Subject: [PATCH 34/40] feat(cron): inherit agent timezone for default schedules Make cron use the configured agent timezone when a cron expression omits tz or a one-shot ISO time has no offset. This keeps runtime context, heartbeat, and scheduling aligned around the same notion of time. Made-with: Cursor --- README.md | 2 +- nanobot/agent/loop.py | 2 +- nanobot/agent/tools/cron.py | 47 +++++++++++++++++++++++-------- tests/cron/test_cron_tool_list.py | 30 ++++++++++++++++++++ 4 files changed, 67 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index 9d292c49f7e..b6b212d4eee 100644 --- a/README.md +++ b/README.md @@ -1361,7 +1361,7 @@ By default, nanobot uses `UTC` for runtime time context. If you want the agent t } ``` -This currently affects runtime time strings shown to the model, such as runtime context and heartbeat prompts. +This affects runtime time strings shown to the model, such as runtime context and heartbeat prompts. It also becomes the default timezone for cron schedules when a cron expression omits `tz`, and for one-shot `at` times when the ISO datetime has no explicit offset. Common examples: `UTC`, `America/New_York`, `America/Los_Angeles`, `Europe/London`, `Europe/Berlin`, `Asia/Tokyo`, `Asia/Shanghai`, `Asia/Singapore`, `Australia/Sydney`. diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index f3ee1b40a34..0ae4e23deb3 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -144,7 +144,7 @@ def _register_default_tools(self) -> None: self.tools.register(MessageTool(send_callback=self.bus.publish_outbound)) self.tools.register(SpawnTool(manager=self.subagents)) if self.cron_service: - self.tools.register(CronTool(self.cron_service)) + self.tools.register(CronTool(self.cron_service, default_timezone=timezone or "UTC")) async def _connect_mcp(self) -> None: """Connect to configured MCP servers (one-time, lazy).""" diff --git a/nanobot/agent/tools/cron.py b/nanobot/agent/tools/cron.py index 8bedea5a46c..ac711d2edf9 100644 --- a/nanobot/agent/tools/cron.py +++ b/nanobot/agent/tools/cron.py @@ -12,8 +12,9 @@ class CronTool(Tool): """Tool to schedule reminders and recurring tasks.""" - def __init__(self, cron_service: CronService): + def __init__(self, cron_service: CronService, default_timezone: str = "UTC"): self._cron = cron_service + self._default_timezone = default_timezone self._channel = "" self._chat_id = "" self._in_cron_context: ContextVar[bool] = ContextVar("cron_in_context", default=False) @@ -31,13 +32,26 @@ def reset_cron_context(self, token) -> None: """Restore previous cron context.""" self._in_cron_context.reset(token) + @staticmethod + def _validate_timezone(tz: str) -> str | None: + from zoneinfo import ZoneInfo + + try: + ZoneInfo(tz) + except (KeyError, Exception): + return f"Error: unknown timezone '{tz}'" + return None + @property def name(self) -> str: return "cron" @property def description(self) -> str: - return "Schedule reminders and recurring tasks. Actions: add, list, remove." + return ( + "Schedule reminders and recurring tasks. Actions: add, list, remove. " + f"If tz is omitted, cron expressions and naive ISO times default to {self._default_timezone}." + ) @property def parameters(self) -> dict[str, Any]: @@ -60,11 +74,17 @@ def parameters(self) -> dict[str, Any]: }, "tz": { "type": "string", - "description": "IANA timezone for cron expressions (e.g. 'America/Vancouver')", + "description": ( + "Optional IANA timezone for cron expressions " + f"(e.g. 'America/Vancouver'). Defaults to {self._default_timezone}." + ), }, "at": { "type": "string", - "description": "ISO datetime for one-time execution (e.g. '2026-02-12T10:30:00')", + "description": ( + "ISO datetime for one-time execution " + f"(e.g. '2026-02-12T10:30:00'). Naive values default to {self._default_timezone}." + ), }, "job_id": {"type": "string", "description": "Job ID (for remove)"}, }, @@ -107,26 +127,29 @@ def _add_job( if tz and not cron_expr: return "Error: tz can only be used with cron_expr" if tz: - from zoneinfo import ZoneInfo - - try: - ZoneInfo(tz) - except (KeyError, Exception): - return f"Error: unknown timezone '{tz}'" + if err := self._validate_timezone(tz): + return err # Build schedule delete_after = False if every_seconds: schedule = CronSchedule(kind="every", every_ms=every_seconds * 1000) elif cron_expr: - schedule = CronSchedule(kind="cron", expr=cron_expr, tz=tz) + effective_tz = tz or self._default_timezone + if err := self._validate_timezone(effective_tz): + return err + schedule = CronSchedule(kind="cron", expr=cron_expr, tz=effective_tz) elif at: - from datetime import datetime + from zoneinfo import ZoneInfo try: dt = datetime.fromisoformat(at) except ValueError: return f"Error: invalid ISO datetime format '{at}'. Expected format: YYYY-MM-DDTHH:MM:SS" + if dt.tzinfo is None: + if err := self._validate_timezone(self._default_timezone): + return err + dt = dt.replace(tzinfo=ZoneInfo(self._default_timezone)) at_ms = int(dt.timestamp() * 1000) schedule = CronSchedule(kind="at", at_ms=at_ms) delete_after = True diff --git a/tests/cron/test_cron_tool_list.py b/tests/cron/test_cron_tool_list.py index 5d882ad8f2a..c55dc589ba3 100644 --- a/tests/cron/test_cron_tool_list.py +++ b/tests/cron/test_cron_tool_list.py @@ -1,5 +1,7 @@ """Tests for CronTool._list_jobs() output formatting.""" +from datetime import datetime, timezone + from nanobot.agent.tools.cron import CronTool from nanobot.cron.service import CronService from nanobot.cron.types import CronJobState, CronSchedule @@ -10,6 +12,11 @@ def _make_tool(tmp_path) -> CronTool: return CronTool(service) +def _make_tool_with_tz(tmp_path, tz: str) -> CronTool: + service = CronService(tmp_path / "cron" / "jobs.json") + return CronTool(service, default_timezone=tz) + + # -- _format_timing tests -- @@ -236,6 +243,29 @@ def test_list_shows_next_run(tmp_path) -> None: assert "Next run:" in result +def test_add_cron_job_defaults_to_tool_timezone(tmp_path) -> None: + tool = _make_tool_with_tz(tmp_path, "Asia/Shanghai") + tool.set_context("telegram", "chat-1") + + result = tool._add_job("Morning standup", None, "0 8 * * *", None, None) + + assert result.startswith("Created job") + job = tool._cron.list_jobs()[0] + assert job.schedule.tz == "Asia/Shanghai" + + +def test_add_at_job_uses_default_timezone_for_naive_datetime(tmp_path) -> None: + tool = _make_tool_with_tz(tmp_path, "Asia/Shanghai") + tool.set_context("telegram", "chat-1") + + result = tool._add_job("Morning reminder", None, None, None, "2026-03-25T08:00:00") + + assert result.startswith("Created job") + job = tool._cron.list_jobs()[0] + expected = int(datetime(2026, 3, 25, 0, 0, 0, tzinfo=timezone.utc).timestamp() * 1000) + assert job.schedule.at_ms == expected + + def test_list_excludes_disabled_jobs(tmp_path) -> None: tool = _make_tool(tmp_path) job = tool._cron.add_job( From fab14696a97c8ad07f1c041e208f0b02a381b8ed Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Wed, 25 Mar 2026 10:28:51 +0000 Subject: [PATCH 35/40] refactor(cron): align displayed times with schedule timezone Make cron list output render one-shot and run-state timestamps in the same timezone context used to interpret schedules. This keeps scheduling logic and user-facing time displays consistent. Made-with: Cursor --- nanobot/agent/tools/cron.py | 34 ++++++++----- tests/cron/test_cron_tool_list.py | 81 +++++++++++++++++++------------ 2 files changed, 72 insertions(+), 43 deletions(-) diff --git a/nanobot/agent/tools/cron.py b/nanobot/agent/tools/cron.py index ac711d2edf9..9989af55ff4 100644 --- a/nanobot/agent/tools/cron.py +++ b/nanobot/agent/tools/cron.py @@ -1,7 +1,7 @@ """Cron tool for scheduling reminders and tasks.""" from contextvars import ContextVar -from datetime import datetime, timezone +from datetime import datetime from typing import Any from nanobot.agent.tools.base import Tool @@ -42,6 +42,17 @@ def _validate_timezone(tz: str) -> str | None: return f"Error: unknown timezone '{tz}'" return None + def _display_timezone(self, schedule: CronSchedule) -> str: + """Pick the most human-meaningful timezone for display.""" + return schedule.tz or self._default_timezone + + @staticmethod + def _format_timestamp(ms: int, tz_name: str) -> str: + from zoneinfo import ZoneInfo + + dt = datetime.fromtimestamp(ms / 1000, tz=ZoneInfo(tz_name)) + return f"{dt.isoformat()} ({tz_name})" + @property def name(self) -> str: return "cron" @@ -167,8 +178,7 @@ def _add_job( ) return f"Created job '{job.name}' (id: {job.id})" - @staticmethod - def _format_timing(schedule: CronSchedule) -> str: + def _format_timing(self, schedule: CronSchedule) -> str: """Format schedule as a human-readable timing string.""" if schedule.kind == "cron": tz = f" ({schedule.tz})" if schedule.tz else "" @@ -183,23 +193,23 @@ def _format_timing(schedule: CronSchedule) -> str: return f"every {ms // 1000}s" return f"every {ms}ms" if schedule.kind == "at" and schedule.at_ms: - dt = datetime.fromtimestamp(schedule.at_ms / 1000, tz=timezone.utc) - return f"at {dt.isoformat()}" + return f"at {self._format_timestamp(schedule.at_ms, self._display_timezone(schedule))}" return schedule.kind - @staticmethod - def _format_state(state: CronJobState) -> list[str]: + def _format_state(self, state: CronJobState, schedule: CronSchedule) -> list[str]: """Format job run state as display lines.""" lines: list[str] = [] + display_tz = self._display_timezone(schedule) if state.last_run_at_ms: - last_dt = datetime.fromtimestamp(state.last_run_at_ms / 1000, tz=timezone.utc) - info = f" Last run: {last_dt.isoformat()} — {state.last_status or 'unknown'}" + info = ( + f" Last run: {self._format_timestamp(state.last_run_at_ms, display_tz)}" + f" — {state.last_status or 'unknown'}" + ) if state.last_error: info += f" ({state.last_error})" lines.append(info) if state.next_run_at_ms: - next_dt = datetime.fromtimestamp(state.next_run_at_ms / 1000, tz=timezone.utc) - lines.append(f" Next run: {next_dt.isoformat()}") + lines.append(f" Next run: {self._format_timestamp(state.next_run_at_ms, display_tz)}") return lines def _list_jobs(self) -> str: @@ -210,7 +220,7 @@ def _list_jobs(self) -> str: for j in jobs: timing = self._format_timing(j.schedule) parts = [f"- {j.name} (id: {j.id}, {timing})"] - parts.extend(self._format_state(j.state)) + parts.extend(self._format_state(j.state, j.schedule)) lines.append("\n".join(parts)) return "Scheduled jobs:\n" + "\n".join(lines) diff --git a/tests/cron/test_cron_tool_list.py b/tests/cron/test_cron_tool_list.py index c55dc589ba3..22a502fa45b 100644 --- a/tests/cron/test_cron_tool_list.py +++ b/tests/cron/test_cron_tool_list.py @@ -20,96 +20,112 @@ def _make_tool_with_tz(tmp_path, tz: str) -> CronTool: # -- _format_timing tests -- -def test_format_timing_cron_with_tz() -> None: +def test_format_timing_cron_with_tz(tmp_path) -> None: + tool = _make_tool(tmp_path) s = CronSchedule(kind="cron", expr="0 9 * * 1-5", tz="America/Denver") - assert CronTool._format_timing(s) == "cron: 0 9 * * 1-5 (America/Denver)" + assert tool._format_timing(s) == "cron: 0 9 * * 1-5 (America/Denver)" -def test_format_timing_cron_without_tz() -> None: +def test_format_timing_cron_without_tz(tmp_path) -> None: + tool = _make_tool(tmp_path) s = CronSchedule(kind="cron", expr="*/5 * * * *") - assert CronTool._format_timing(s) == "cron: */5 * * * *" + assert tool._format_timing(s) == "cron: */5 * * * *" -def test_format_timing_every_hours() -> None: +def test_format_timing_every_hours(tmp_path) -> None: + tool = _make_tool(tmp_path) s = CronSchedule(kind="every", every_ms=7_200_000) - assert CronTool._format_timing(s) == "every 2h" + assert tool._format_timing(s) == "every 2h" -def test_format_timing_every_minutes() -> None: +def test_format_timing_every_minutes(tmp_path) -> None: + tool = _make_tool(tmp_path) s = CronSchedule(kind="every", every_ms=1_800_000) - assert CronTool._format_timing(s) == "every 30m" + assert tool._format_timing(s) == "every 30m" -def test_format_timing_every_seconds() -> None: +def test_format_timing_every_seconds(tmp_path) -> None: + tool = _make_tool(tmp_path) s = CronSchedule(kind="every", every_ms=30_000) - assert CronTool._format_timing(s) == "every 30s" + assert tool._format_timing(s) == "every 30s" -def test_format_timing_every_non_minute_seconds() -> None: +def test_format_timing_every_non_minute_seconds(tmp_path) -> None: + tool = _make_tool(tmp_path) s = CronSchedule(kind="every", every_ms=90_000) - assert CronTool._format_timing(s) == "every 90s" + assert tool._format_timing(s) == "every 90s" -def test_format_timing_every_milliseconds() -> None: +def test_format_timing_every_milliseconds(tmp_path) -> None: + tool = _make_tool(tmp_path) s = CronSchedule(kind="every", every_ms=200) - assert CronTool._format_timing(s) == "every 200ms" + assert tool._format_timing(s) == "every 200ms" -def test_format_timing_at() -> None: +def test_format_timing_at(tmp_path) -> None: + tool = _make_tool_with_tz(tmp_path, "Asia/Shanghai") s = CronSchedule(kind="at", at_ms=1773684000000) - result = CronTool._format_timing(s) + result = tool._format_timing(s) + assert "Asia/Shanghai" in result assert result.startswith("at 2026-") -def test_format_timing_fallback() -> None: +def test_format_timing_fallback(tmp_path) -> None: + tool = _make_tool(tmp_path) s = CronSchedule(kind="every") # no every_ms - assert CronTool._format_timing(s) == "every" + assert tool._format_timing(s) == "every" # -- _format_state tests -- -def test_format_state_empty() -> None: +def test_format_state_empty(tmp_path) -> None: + tool = _make_tool(tmp_path) state = CronJobState() - assert CronTool._format_state(state) == [] + assert tool._format_state(state, CronSchedule(kind="every")) == [] -def test_format_state_last_run_ok() -> None: +def test_format_state_last_run_ok(tmp_path) -> None: + tool = _make_tool(tmp_path) state = CronJobState(last_run_at_ms=1773673200000, last_status="ok") - lines = CronTool._format_state(state) + lines = tool._format_state(state, CronSchedule(kind="cron", expr="0 9 * * *", tz="UTC")) assert len(lines) == 1 assert "Last run:" in lines[0] assert "ok" in lines[0] -def test_format_state_last_run_with_error() -> None: +def test_format_state_last_run_with_error(tmp_path) -> None: + tool = _make_tool(tmp_path) state = CronJobState(last_run_at_ms=1773673200000, last_status="error", last_error="timeout") - lines = CronTool._format_state(state) + lines = tool._format_state(state, CronSchedule(kind="cron", expr="0 9 * * *", tz="UTC")) assert len(lines) == 1 assert "error" in lines[0] assert "timeout" in lines[0] -def test_format_state_next_run_only() -> None: +def test_format_state_next_run_only(tmp_path) -> None: + tool = _make_tool(tmp_path) state = CronJobState(next_run_at_ms=1773684000000) - lines = CronTool._format_state(state) + lines = tool._format_state(state, CronSchedule(kind="cron", expr="0 9 * * *", tz="UTC")) assert len(lines) == 1 assert "Next run:" in lines[0] -def test_format_state_both() -> None: +def test_format_state_both(tmp_path) -> None: + tool = _make_tool(tmp_path) state = CronJobState( last_run_at_ms=1773673200000, last_status="ok", next_run_at_ms=1773684000000 ) - lines = CronTool._format_state(state) + lines = tool._format_state(state, CronSchedule(kind="cron", expr="0 9 * * *", tz="UTC")) assert len(lines) == 2 assert "Last run:" in lines[0] assert "Next run:" in lines[1] -def test_format_state_unknown_status() -> None: +def test_format_state_unknown_status(tmp_path) -> None: + tool = _make_tool(tmp_path) state = CronJobState(last_run_at_ms=1773673200000, last_status=None) - lines = CronTool._format_state(state) + lines = tool._format_state(state, CronSchedule(kind="cron", expr="0 9 * * *", tz="UTC")) assert "unknown" in lines[0] @@ -188,7 +204,7 @@ def test_list_every_job_milliseconds(tmp_path) -> None: def test_list_at_job_shows_iso_timestamp(tmp_path) -> None: - tool = _make_tool(tmp_path) + tool = _make_tool_with_tz(tmp_path, "Asia/Shanghai") tool._cron.add_job( name="One-shot", schedule=CronSchedule(kind="at", at_ms=1773684000000), @@ -196,6 +212,7 @@ def test_list_at_job_shows_iso_timestamp(tmp_path) -> None: ) result = tool._list_jobs() assert "at 2026-" in result + assert "Asia/Shanghai" in result def test_list_shows_last_run_state(tmp_path) -> None: @@ -213,6 +230,7 @@ def test_list_shows_last_run_state(tmp_path) -> None: result = tool._list_jobs() assert "Last run:" in result assert "ok" in result + assert "(UTC)" in result def test_list_shows_error_message(tmp_path) -> None: @@ -241,6 +259,7 @@ def test_list_shows_next_run(tmp_path) -> None: ) result = tool._list_jobs() assert "Next run:" in result + assert "(UTC)" in result def test_add_cron_job_defaults_to_tool_timezone(tmp_path) -> None: From 3f71014b7c64a0160e9ff44134e58cdcfd9c1605 Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Wed, 25 Mar 2026 10:33:35 +0000 Subject: [PATCH 36/40] fix(agent): use configured timezone when registering cron tool Read the default timezone from the agent context when wiring the cron tool so startup no longer depends on an out-of-scope local variable. Add a regression test to ensure AgentLoop passes the configured timezone through to cron. Made-with: Cursor --- nanobot/agent/loop.py | 4 +++- tests/agent/test_loop_cron_timezone.py | 27 ++++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) create mode 100644 tests/agent/test_loop_cron_timezone.py diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index 0ae4e23deb3..afe62ca28e6 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -144,7 +144,9 @@ def _register_default_tools(self) -> None: self.tools.register(MessageTool(send_callback=self.bus.publish_outbound)) self.tools.register(SpawnTool(manager=self.subagents)) if self.cron_service: - self.tools.register(CronTool(self.cron_service, default_timezone=timezone or "UTC")) + self.tools.register( + CronTool(self.cron_service, default_timezone=self.context.timezone or "UTC") + ) async def _connect_mcp(self) -> None: """Connect to configured MCP servers (one-time, lazy).""" diff --git a/tests/agent/test_loop_cron_timezone.py b/tests/agent/test_loop_cron_timezone.py new file mode 100644 index 00000000000..7738d304309 --- /dev/null +++ b/tests/agent/test_loop_cron_timezone.py @@ -0,0 +1,27 @@ +from pathlib import Path +from unittest.mock import MagicMock + +from nanobot.agent.loop import AgentLoop +from nanobot.agent.tools.cron import CronTool +from nanobot.bus.queue import MessageBus +from nanobot.cron.service import CronService + + +def test_agent_loop_registers_cron_tool_with_configured_timezone(tmp_path: Path) -> None: + bus = MessageBus() + provider = MagicMock() + provider.get_default_model.return_value = "test-model" + + loop = AgentLoop( + bus=bus, + provider=provider, + workspace=tmp_path, + model="test-model", + cron_service=CronService(tmp_path / "cron" / "jobs.json"), + timezone="Asia/Shanghai", + ) + + cron_tool = loop.tools.get("cron") + + assert isinstance(cron_tool, CronTool) + assert cron_tool._default_timezone == "Asia/Shanghai" From 5e9fa28ff271ff8a521c93e17e68e4dbf09c40da Mon Sep 17 00:00:00 2001 From: chengyongru Date: Wed, 25 Mar 2026 18:37:32 +0800 Subject: [PATCH 37/40] feat(channel): add message send retry mechanism with exponential backoff - Add send_max_retries config option (default: 3, range: 0-10) - Implement _send_with_retry in ChannelManager with 1s/2s/4s backoff - Propagate CancelledError for graceful shutdown - Fix telegram send_delta to raise exceptions for Manager retry - Add comprehensive tests for retry logic - Document channel settings in README --- README.md | 32 ++ nanobot/channels/manager.py | 49 +- nanobot/channels/telegram.py | 6 +- nanobot/config/schema.py | 1 + pyproject.toml | 13 + tests/channels/test_channel_plugins.py | 618 ++++++++++++++++++++++++- 6 files changed, 707 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index b6b212d4eee..40ecd4cb108 100644 --- a/README.md +++ b/README.md @@ -1157,6 +1157,38 @@ That's it! Environment variables, model routing, config matching, and `nanobot s +### Channel Settings + +Global settings that apply to all channels. Configure under the `channels` section in `~/.nanobot/config.json`: + +```json +{ + "channels": { + "sendProgress": true, + "sendToolHints": false, + "sendMaxRetries": 3, + "telegram": { ... } + } +} +``` + +| Setting | Default | Description | +|---------|---------|-------------| +| `sendProgress` | `true` | Stream agent's text progress to the channel | +| `sendToolHints` | `false` | Stream tool-call hints (e.g. `read_file("…")`) | +| `sendMaxRetries` | `3` | Max retry attempts for message send failures (0-10) | + +#### Retry Behavior + +When a message fails to send, nanobot will automatically retry with exponential backoff: + +- **Attempts 1-3**: Retry delays are 1s, 2s, 4s +- **Attempts 4+**: Retry delay caps at 4s +- **Transient failures** (network hiccups, temporary API limits): Retry usually succeeds +- **Permanent failures** (invalid token, channel banned): All retries fail + +> [!NOTE] +> When a channel is completely unavailable, there's no way to notify the user since we cannot reach them through that channel. Monitor logs for "Failed to send to {channel} after N attempts" to detect persistent delivery failures. ### Web Search diff --git a/nanobot/channels/manager.py b/nanobot/channels/manager.py index 3a53b6307b2..2f1b400c431 100644 --- a/nanobot/channels/manager.py +++ b/nanobot/channels/manager.py @@ -7,10 +7,14 @@ from loguru import logger +from nanobot.bus.events import OutboundMessage from nanobot.bus.queue import MessageBus from nanobot.channels.base import BaseChannel from nanobot.config.schema import Config +# Retry delays for message sending (exponential backoff: 1s, 2s, 4s) +_SEND_RETRY_DELAYS = (1, 2, 4) + class ChannelManager: """ @@ -129,15 +133,7 @@ async def _dispatch_outbound(self) -> None: channel = self.channels.get(msg.channel) if channel: - try: - if msg.metadata.get("_stream_delta") or msg.metadata.get("_stream_end"): - await channel.send_delta(msg.chat_id, msg.content, msg.metadata) - elif msg.metadata.get("_streamed"): - pass - else: - await channel.send(msg) - except Exception as e: - logger.error("Error sending to {}: {}", msg.channel, e) + await self._send_with_retry(channel, msg) else: logger.warning("Unknown channel: {}", msg.channel) @@ -146,6 +142,41 @@ async def _dispatch_outbound(self) -> None: except asyncio.CancelledError: break + async def _send_with_retry(self, channel: BaseChannel, msg: OutboundMessage) -> None: + """Send a message with retry on failure using exponential backoff. + + Note: CancelledError is re-raised to allow graceful shutdown. + """ + max_attempts = max(self.config.channels.send_max_retries, 1) + + for attempt in range(max_attempts): + try: + if msg.metadata.get("_stream_delta") or msg.metadata.get("_stream_end"): + await channel.send_delta(msg.chat_id, msg.content, msg.metadata) + elif msg.metadata.get("_streamed"): + pass + else: + await channel.send(msg) + return # Send succeeded + except asyncio.CancelledError: + raise # Propagate cancellation for graceful shutdown + except Exception as e: + if attempt == max_attempts - 1: + logger.error( + "Failed to send to {} after {} attempts: {} - {}", + msg.channel, max_attempts, type(e).__name__, e + ) + return + delay = _SEND_RETRY_DELAYS[min(attempt, len(_SEND_RETRY_DELAYS) - 1)] + logger.warning( + "Send to {} failed (attempt {}/{}): {}, retrying in {}s", + msg.channel, attempt + 1, max_attempts, type(e).__name__, delay + ) + try: + await asyncio.sleep(delay) + except asyncio.CancelledError: + raise # Propagate cancellation during sleep + def get_channel(self, name: str) -> BaseChannel | None: """Get a channel by name.""" return self.channels.get(name) diff --git a/nanobot/channels/telegram.py b/nanobot/channels/telegram.py index 04cc89cc2c8..fcccbe8a4f9 100644 --- a/nanobot/channels/telegram.py +++ b/nanobot/channels/telegram.py @@ -528,6 +528,7 @@ async def send_delta(self, chat_id: str, delta: str, metadata: dict[str, Any] | buf.last_edit = now except Exception as e: logger.warning("Stream initial send failed: {}", e) + raise # Let ChannelManager handle retry elif (now - buf.last_edit) >= self._STREAM_EDIT_INTERVAL: try: await self._call_with_retry( @@ -536,8 +537,9 @@ async def send_delta(self, chat_id: str, delta: str, metadata: dict[str, Any] | text=buf.text, ) buf.last_edit = now - except Exception: - pass + except Exception as e: + logger.warning("Stream edit failed: {}", e) + raise # Let ChannelManager handle retry async def _on_start(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: """Handle /start command.""" diff --git a/nanobot/config/schema.py b/nanobot/config/schema.py index 6f05e569eea..1d964a642f5 100644 --- a/nanobot/config/schema.py +++ b/nanobot/config/schema.py @@ -25,6 +25,7 @@ class ChannelsConfig(Base): send_progress: bool = True # stream agent's text progress to the channel send_tool_hints: bool = False # stream tool-call hints (e.g. read_file("…")) + send_max_retries: int = Field(default=3, ge=0, le=10) # Max retry attempts for message send failures class AgentDefaults(Base): diff --git a/pyproject.toml b/pyproject.toml index aca72777da6..501a6bb45ad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -120,3 +120,16 @@ ignore = ["E501"] [tool.pytest.ini_options] asyncio_mode = "auto" testpaths = ["tests"] + +[tool.coverage.run] +source = ["nanobot"] +omit = ["tests/*", "**/tests/*"] + +[tool.coverage.report] +exclude_lines = [ + "pragma: no cover", + "def __repr__", + "raise NotImplementedError", + "if __name__ == .__main__.:", + "if TYPE_CHECKING:", +] diff --git a/tests/channels/test_channel_plugins.py b/tests/channels/test_channel_plugins.py index 3f34dc59885..a0b458a084e 100644 --- a/tests/channels/test_channel_plugins.py +++ b/tests/channels/test_channel_plugins.py @@ -2,8 +2,9 @@ from __future__ import annotations +import asyncio from types import SimpleNamespace -from unittest.mock import patch +from unittest.mock import AsyncMock, patch import pytest @@ -262,3 +263,618 @@ def test_builtin_channel_init_from_dict(): ch = TelegramChannel({"enabled": False, "token": "test-tok", "allowFrom": ["*"]}, bus) assert ch.config.token == "test-tok" assert ch.config.allow_from == ["*"] + + +def test_channels_config_send_max_retries_default(): + """ChannelsConfig should have send_max_retries with default value of 3.""" + cfg = ChannelsConfig() + assert hasattr(cfg, 'send_max_retries') + assert cfg.send_max_retries == 3 + + +def test_channels_config_send_max_retries_upper_bound(): + """send_max_retries should be bounded to prevent resource exhaustion.""" + from pydantic import ValidationError + + # Value too high should be rejected + with pytest.raises(ValidationError): + ChannelsConfig(send_max_retries=100) + + # Negative should be rejected + with pytest.raises(ValidationError): + ChannelsConfig(send_max_retries=-1) + + # Boundary values should be allowed + cfg_min = ChannelsConfig(send_max_retries=0) + assert cfg_min.send_max_retries == 0 + + cfg_max = ChannelsConfig(send_max_retries=10) + assert cfg_max.send_max_retries == 10 + + # Value above upper bound should be rejected + with pytest.raises(ValidationError): + ChannelsConfig(send_max_retries=11) + + +# --------------------------------------------------------------------------- +# _send_with_retry +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_send_with_retry_succeeds_first_try(): + """_send_with_retry should succeed on first try and not retry.""" + call_count = 0 + + class _FailingChannel(BaseChannel): + name = "failing" + display_name = "Failing" + + async def start(self) -> None: + pass + + async def stop(self) -> None: + pass + + async def send(self, msg: OutboundMessage) -> None: + nonlocal call_count + call_count += 1 + # Succeeds on first try + + fake_config = SimpleNamespace( + channels=ChannelsConfig(send_max_retries=3), + providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), + ) + + mgr = ChannelManager.__new__(ChannelManager) + mgr.config = fake_config + mgr.bus = MessageBus() + mgr.channels = {"failing": _FailingChannel(fake_config, mgr.bus)} + mgr._dispatch_task = None + + msg = OutboundMessage(channel="failing", chat_id="123", content="test") + await mgr._send_with_retry(mgr.channels["failing"], msg) + + assert call_count == 1 + + +@pytest.mark.asyncio +async def test_send_with_retry_retries_on_failure(): + """_send_with_retry should retry on failure up to max_retries times.""" + call_count = 0 + + class _FailingChannel(BaseChannel): + name = "failing" + display_name = "Failing" + + async def start(self) -> None: + pass + + async def stop(self) -> None: + pass + + async def send(self, msg: OutboundMessage) -> None: + nonlocal call_count + call_count += 1 + raise RuntimeError("simulated failure") + + fake_config = SimpleNamespace( + channels=ChannelsConfig(send_max_retries=3), + providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), + ) + + mgr = ChannelManager.__new__(ChannelManager) + mgr.config = fake_config + mgr.bus = MessageBus() + mgr.channels = {"failing": _FailingChannel(fake_config, mgr.bus)} + mgr._dispatch_task = None + + msg = OutboundMessage(channel="failing", chat_id="123", content="test") + + # Patch asyncio.sleep to avoid actual delays + with patch("nanobot.channels.manager.asyncio.sleep", new_callable=AsyncMock) as mock_sleep: + await mgr._send_with_retry(mgr.channels["failing"], msg) + + assert call_count == 3 # 3 total attempts (initial + 2 retries) + assert mock_sleep.call_count == 2 # 2 sleeps between retries + + +@pytest.mark.asyncio +async def test_send_with_retry_no_retry_when_max_is_zero(): + """_send_with_retry should not retry when send_max_retries is 0.""" + call_count = 0 + + class _FailingChannel(BaseChannel): + name = "failing" + display_name = "Failing" + + async def start(self) -> None: + pass + + async def stop(self) -> None: + pass + + async def send(self, msg: OutboundMessage) -> None: + nonlocal call_count + call_count += 1 + raise RuntimeError("simulated failure") + + fake_config = SimpleNamespace( + channels=ChannelsConfig(send_max_retries=0), + providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), + ) + + mgr = ChannelManager.__new__(ChannelManager) + mgr.config = fake_config + mgr.bus = MessageBus() + mgr.channels = {"failing": _FailingChannel(fake_config, mgr.bus)} + mgr._dispatch_task = None + + msg = OutboundMessage(channel="failing", chat_id="123", content="test") + + with patch("nanobot.channels.manager.asyncio.sleep", new_callable=AsyncMock): + await mgr._send_with_retry(mgr.channels["failing"], msg) + + assert call_count == 1 # Called once but no retry (max(0, 1) = 1) + + +@pytest.mark.asyncio +async def test_send_with_retry_calls_send_delta(): + """_send_with_retry should call send_delta when metadata has _stream_delta.""" + send_delta_called = False + + class _StreamingChannel(BaseChannel): + name = "streaming" + display_name = "Streaming" + + async def start(self) -> None: + pass + + async def stop(self) -> None: + pass + + async def send(self, msg: OutboundMessage) -> None: + pass # Should not be called + + async def send_delta(self, chat_id: str, delta: str, metadata: dict | None = None) -> None: + nonlocal send_delta_called + send_delta_called = True + + fake_config = SimpleNamespace( + channels=ChannelsConfig(send_max_retries=3), + providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), + ) + + mgr = ChannelManager.__new__(ChannelManager) + mgr.config = fake_config + mgr.bus = MessageBus() + mgr.channels = {"streaming": _StreamingChannel(fake_config, mgr.bus)} + mgr._dispatch_task = None + + msg = OutboundMessage( + channel="streaming", chat_id="123", content="test delta", + metadata={"_stream_delta": True} + ) + await mgr._send_with_retry(mgr.channels["streaming"], msg) + + assert send_delta_called is True + + +@pytest.mark.asyncio +async def test_send_with_retry_skips_send_when_streamed(): + """_send_with_retry should not call send when metadata has _streamed flag.""" + send_called = False + send_delta_called = False + + class _StreamedChannel(BaseChannel): + name = "streamed" + display_name = "Streamed" + + async def start(self) -> None: + pass + + async def stop(self) -> None: + pass + + async def send(self, msg: OutboundMessage) -> None: + nonlocal send_called + send_called = True + + async def send_delta(self, chat_id: str, delta: str, metadata: dict | None = None) -> None: + nonlocal send_delta_called + send_delta_called = True + + fake_config = SimpleNamespace( + channels=ChannelsConfig(send_max_retries=3), + providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), + ) + + mgr = ChannelManager.__new__(ChannelManager) + mgr.config = fake_config + mgr.bus = MessageBus() + mgr.channels = {"streamed": _StreamedChannel(fake_config, mgr.bus)} + mgr._dispatch_task = None + + # _streamed means message was already sent via send_delta, so skip send + msg = OutboundMessage( + channel="streamed", chat_id="123", content="test", + metadata={"_streamed": True} + ) + await mgr._send_with_retry(mgr.channels["streamed"], msg) + + assert send_called is False + assert send_delta_called is False + + +@pytest.mark.asyncio +async def test_send_with_retry_propagates_cancelled_error(): + """_send_with_retry should re-raise CancelledError for graceful shutdown.""" + class _CancellingChannel(BaseChannel): + name = "cancelling" + display_name = "Cancelling" + + async def start(self) -> None: + pass + + async def stop(self) -> None: + pass + + async def send(self, msg: OutboundMessage) -> None: + raise asyncio.CancelledError("simulated cancellation") + + fake_config = SimpleNamespace( + channels=ChannelsConfig(send_max_retries=3), + providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), + ) + + mgr = ChannelManager.__new__(ChannelManager) + mgr.config = fake_config + mgr.bus = MessageBus() + mgr.channels = {"cancelling": _CancellingChannel(fake_config, mgr.bus)} + mgr._dispatch_task = None + + msg = OutboundMessage(channel="cancelling", chat_id="123", content="test") + + with pytest.raises(asyncio.CancelledError): + await mgr._send_with_retry(mgr.channels["cancelling"], msg) + + +@pytest.mark.asyncio +async def test_send_with_retry_propagates_cancelled_error_during_sleep(): + """_send_with_retry should re-raise CancelledError during sleep.""" + call_count = 0 + + class _FailingChannel(BaseChannel): + name = "failing" + display_name = "Failing" + + async def start(self) -> None: + pass + + async def stop(self) -> None: + pass + + async def send(self, msg: OutboundMessage) -> None: + nonlocal call_count + call_count += 1 + raise RuntimeError("simulated failure") + + fake_config = SimpleNamespace( + channels=ChannelsConfig(send_max_retries=3), + providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), + ) + + mgr = ChannelManager.__new__(ChannelManager) + mgr.config = fake_config + mgr.bus = MessageBus() + mgr.channels = {"failing": _FailingChannel(fake_config, mgr.bus)} + mgr._dispatch_task = None + + msg = OutboundMessage(channel="failing", chat_id="123", content="test") + + # Mock sleep to raise CancelledError + async def cancel_during_sleep(_): + raise asyncio.CancelledError("cancelled during sleep") + + with patch("nanobot.channels.manager.asyncio.sleep", side_effect=cancel_during_sleep): + with pytest.raises(asyncio.CancelledError): + await mgr._send_with_retry(mgr.channels["failing"], msg) + + # Should have attempted once before sleep was cancelled + assert call_count == 1 + + +# --------------------------------------------------------------------------- +# ChannelManager - lifecycle and getters +# --------------------------------------------------------------------------- + +class _ChannelWithAllowFrom(BaseChannel): + """Channel with configurable allow_from.""" + name = "withallow" + display_name = "With Allow" + + def __init__(self, config, bus, allow_from): + super().__init__(config, bus) + self.config.allow_from = allow_from + + async def start(self) -> None: + pass + + async def stop(self) -> None: + pass + + async def send(self, msg: OutboundMessage) -> None: + pass + + +class _StartableChannel(BaseChannel): + """Channel that tracks start/stop calls.""" + name = "startable" + display_name = "Startable" + + def __init__(self, config, bus): + super().__init__(config, bus) + self.started = False + self.stopped = False + + async def start(self) -> None: + self.started = True + + async def stop(self) -> None: + self.stopped = True + + async def send(self, msg: OutboundMessage) -> None: + pass + + +@pytest.mark.asyncio +async def test_validate_allow_from_raises_on_empty_list(): + """_validate_allow_from should raise SystemExit when allow_from is empty list.""" + fake_config = SimpleNamespace( + channels=ChannelsConfig(), + providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), + ) + + mgr = ChannelManager.__new__(ChannelManager) + mgr.config = fake_config + mgr.channels = {"test": _ChannelWithAllowFrom(fake_config, None, [])} + mgr._dispatch_task = None + + with pytest.raises(SystemExit) as exc_info: + mgr._validate_allow_from() + + assert "empty allowFrom" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_validate_allow_from_passes_with_asterisk(): + """_validate_allow_from should not raise when allow_from contains '*'.""" + fake_config = SimpleNamespace( + channels=ChannelsConfig(), + providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), + ) + + mgr = ChannelManager.__new__(ChannelManager) + mgr.config = fake_config + mgr.channels = {"test": _ChannelWithAllowFrom(fake_config, None, ["*"])} + mgr._dispatch_task = None + + # Should not raise + mgr._validate_allow_from() + + +@pytest.mark.asyncio +async def test_get_channel_returns_channel_if_exists(): + """get_channel should return the channel if it exists.""" + fake_config = SimpleNamespace( + channels=ChannelsConfig(), + providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), + ) + + mgr = ChannelManager.__new__(ChannelManager) + mgr.config = fake_config + mgr.bus = MessageBus() + mgr.channels = {"telegram": _StartableChannel(fake_config, mgr.bus)} + mgr._dispatch_task = None + + assert mgr.get_channel("telegram") is not None + assert mgr.get_channel("nonexistent") is None + + +@pytest.mark.asyncio +async def test_get_status_returns_running_state(): + """get_status should return enabled and running state for each channel.""" + fake_config = SimpleNamespace( + channels=ChannelsConfig(), + providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), + ) + + mgr = ChannelManager.__new__(ChannelManager) + mgr.config = fake_config + mgr.bus = MessageBus() + ch = _StartableChannel(fake_config, mgr.bus) + mgr.channels = {"startable": ch} + mgr._dispatch_task = None + + status = mgr.get_status() + + assert status["startable"]["enabled"] is True + assert status["startable"]["running"] is False # Not started yet + + +@pytest.mark.asyncio +async def test_enabled_channels_returns_channel_names(): + """enabled_channels should return list of enabled channel names.""" + fake_config = SimpleNamespace( + channels=ChannelsConfig(), + providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), + ) + + mgr = ChannelManager.__new__(ChannelManager) + mgr.config = fake_config + mgr.bus = MessageBus() + mgr.channels = { + "telegram": _StartableChannel(fake_config, mgr.bus), + "slack": _StartableChannel(fake_config, mgr.bus), + } + mgr._dispatch_task = None + + enabled = mgr.enabled_channels + + assert "telegram" in enabled + assert "slack" in enabled + assert len(enabled) == 2 + + +@pytest.mark.asyncio +async def test_stop_all_cancels_dispatcher_and_stops_channels(): + """stop_all should cancel the dispatch task and stop all channels.""" + fake_config = SimpleNamespace( + channels=ChannelsConfig(), + providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), + ) + + mgr = ChannelManager.__new__(ChannelManager) + mgr.config = fake_config + mgr.bus = MessageBus() + + ch = _StartableChannel(fake_config, mgr.bus) + mgr.channels = {"startable": ch} + + # Create a real cancelled task + async def dummy_task(): + while True: + await asyncio.sleep(1) + + dispatch_task = asyncio.create_task(dummy_task()) + mgr._dispatch_task = dispatch_task + + await mgr.stop_all() + + # Task should be cancelled + assert dispatch_task.cancelled() + # Channel should be stopped + assert ch.stopped is True + + +@pytest.mark.asyncio +async def test_start_channel_logs_error_on_failure(): + """_start_channel should log error when channel start fails.""" + class _FailingChannel(BaseChannel): + name = "failing" + display_name = "Failing" + + async def start(self) -> None: + raise RuntimeError("connection failed") + + async def stop(self) -> None: + pass + + async def send(self, msg: OutboundMessage) -> None: + pass + + fake_config = SimpleNamespace( + channels=ChannelsConfig(), + providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), + ) + + mgr = ChannelManager.__new__(ChannelManager) + mgr.config = fake_config + mgr.bus = MessageBus() + mgr.channels = {} + mgr._dispatch_task = None + + ch = _FailingChannel(fake_config, mgr.bus) + + # Should not raise, just log error + await mgr._start_channel("failing", ch) + + +@pytest.mark.asyncio +async def test_stop_all_handles_channel_exception(): + """stop_all should handle exceptions when stopping channels gracefully.""" + class _StopFailingChannel(BaseChannel): + name = "stopfailing" + display_name = "Stop Failing" + + async def start(self) -> None: + pass + + async def stop(self) -> None: + raise RuntimeError("stop failed") + + async def send(self, msg: OutboundMessage) -> None: + pass + + fake_config = SimpleNamespace( + channels=ChannelsConfig(), + providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), + ) + + mgr = ChannelManager.__new__(ChannelManager) + mgr.config = fake_config + mgr.bus = MessageBus() + mgr.channels = {"stopfailing": _StopFailingChannel(fake_config, mgr.bus)} + mgr._dispatch_task = None + + # Should not raise even if channel.stop() raises + await mgr.stop_all() + + +@pytest.mark.asyncio +async def test_start_all_no_channels_logs_warning(): + """start_all should log warning when no channels are enabled.""" + fake_config = SimpleNamespace( + channels=ChannelsConfig(), + providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), + ) + + mgr = ChannelManager.__new__(ChannelManager) + mgr.config = fake_config + mgr.bus = MessageBus() + mgr.channels = {} # No channels + mgr._dispatch_task = None + + # Should return early without creating dispatch task + await mgr.start_all() + + assert mgr._dispatch_task is None + + +@pytest.mark.asyncio +async def test_start_all_creates_dispatch_task(): + """start_all should create the dispatch task when channels exist.""" + fake_config = SimpleNamespace( + channels=ChannelsConfig(), + providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), + ) + + mgr = ChannelManager.__new__(ChannelManager) + mgr.config = fake_config + mgr.bus = MessageBus() + + ch = _StartableChannel(fake_config, mgr.bus) + mgr.channels = {"startable": ch} + mgr._dispatch_task = None + + # Cancel immediately after start to avoid running forever + async def cancel_after_start(): + await asyncio.sleep(0.01) + if mgr._dispatch_task: + mgr._dispatch_task.cancel() + + cancel_task = asyncio.create_task(cancel_after_start()) + + try: + await mgr.start_all() + except asyncio.CancelledError: + pass + finally: + cancel_task.cancel() + try: + await cancel_task + except asyncio.CancelledError: + pass + + # Dispatch task should have been created + assert mgr._dispatch_task is not None + From f0f0bf02d77e24046a4c35037d5bd3d938222bc7 Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Wed, 25 Mar 2026 14:34:37 +0000 Subject: [PATCH 38/40] refactor(channel): centralize retry around explicit send failures Make channel delivery failures raise consistently so retry policy lives in ChannelManager rather than being split across individual channels. Tighten Telegram stream finalization, clarify sendMaxRetries semantics, and align the docs with the behavior the system actually guarantees. --- README.md | 9 +++++---- nanobot/channels/base.py | 9 ++++++++- nanobot/channels/feishu.py | 1 + nanobot/channels/manager.py | 15 +++++++++------ nanobot/channels/mochat.py | 1 + nanobot/channels/slack.py | 1 + nanobot/channels/telegram.py | 9 ++++++--- nanobot/channels/wecom.py | 1 + nanobot/channels/weixin.py | 1 + nanobot/channels/whatsapp.py | 2 ++ nanobot/config/schema.py | 2 +- tests/channels/test_telegram_channel.py | 21 +++++++++++++++++++-- 12 files changed, 55 insertions(+), 17 deletions(-) diff --git a/README.md b/README.md index 40ecd4cb108..ae2512eb0a9 100644 --- a/README.md +++ b/README.md @@ -1176,14 +1176,15 @@ Global settings that apply to all channels. Configure under the `channels` secti |---------|---------|-------------| | `sendProgress` | `true` | Stream agent's text progress to the channel | | `sendToolHints` | `false` | Stream tool-call hints (e.g. `read_file("…")`) | -| `sendMaxRetries` | `3` | Max retry attempts for message send failures (0-10) | +| `sendMaxRetries` | `3` | Max delivery attempts per outbound message, including the initial send (0-10 configured, minimum 1 actual attempt) | #### Retry Behavior -When a message fails to send, nanobot will automatically retry with exponential backoff: +When a channel send operation raises an error, nanobot retries with exponential backoff: -- **Attempts 1-3**: Retry delays are 1s, 2s, 4s -- **Attempts 4+**: Retry delay caps at 4s +- **Attempt 1**: Initial send +- **Attempts 2-4**: Retry delays are 1s, 2s, 4s +- **Attempts 5+**: Retry delay caps at 4s - **Transient failures** (network hiccups, temporary API limits): Retry usually succeeds - **Permanent failures** (invalid token, channel banned): All retries fail diff --git a/nanobot/channels/base.py b/nanobot/channels/base.py index 87614cb46c0..5a776eed4ff 100644 --- a/nanobot/channels/base.py +++ b/nanobot/channels/base.py @@ -85,11 +85,18 @@ async def send(self, msg: OutboundMessage) -> None: Args: msg: The message to send. + + Implementations should raise on delivery failure so the channel manager + can apply any retry policy in one place. """ pass async def send_delta(self, chat_id: str, delta: str, metadata: dict[str, Any] | None = None) -> None: - """Deliver a streaming text chunk. Override in subclass to enable streaming.""" + """Deliver a streaming text chunk. + + Override in subclasses to enable streaming. Implementations should + raise on delivery failure so the channel manager can retry. + """ pass @property diff --git a/nanobot/channels/feishu.py b/nanobot/channels/feishu.py index 06daf409d15..0ffca601efa 100644 --- a/nanobot/channels/feishu.py +++ b/nanobot/channels/feishu.py @@ -1031,6 +1031,7 @@ def _do_send(m_type: str, content: str) -> None: except Exception as e: logger.error("Error sending Feishu message: {}", e) + raise def _on_message_sync(self, data: Any) -> None: """ diff --git a/nanobot/channels/manager.py b/nanobot/channels/manager.py index 2f1b400c431..2ec7c001e68 100644 --- a/nanobot/channels/manager.py +++ b/nanobot/channels/manager.py @@ -142,6 +142,14 @@ async def _dispatch_outbound(self) -> None: except asyncio.CancelledError: break + @staticmethod + async def _send_once(channel: BaseChannel, msg: OutboundMessage) -> None: + """Send one outbound message without retry policy.""" + if msg.metadata.get("_stream_delta") or msg.metadata.get("_stream_end"): + await channel.send_delta(msg.chat_id, msg.content, msg.metadata) + elif not msg.metadata.get("_streamed"): + await channel.send(msg) + async def _send_with_retry(self, channel: BaseChannel, msg: OutboundMessage) -> None: """Send a message with retry on failure using exponential backoff. @@ -151,12 +159,7 @@ async def _send_with_retry(self, channel: BaseChannel, msg: OutboundMessage) -> for attempt in range(max_attempts): try: - if msg.metadata.get("_stream_delta") or msg.metadata.get("_stream_end"): - await channel.send_delta(msg.chat_id, msg.content, msg.metadata) - elif msg.metadata.get("_streamed"): - pass - else: - await channel.send(msg) + await self._send_once(channel, msg) return # Send succeeded except asyncio.CancelledError: raise # Propagate cancellation for graceful shutdown diff --git a/nanobot/channels/mochat.py b/nanobot/channels/mochat.py index 629379f2ead..0b02aec6243 100644 --- a/nanobot/channels/mochat.py +++ b/nanobot/channels/mochat.py @@ -374,6 +374,7 @@ async def send(self, msg: OutboundMessage) -> None: content, msg.reply_to) except Exception as e: logger.error("Failed to send Mochat message: {}", e) + raise # ---- config / init helpers --------------------------------------------- diff --git a/nanobot/channels/slack.py b/nanobot/channels/slack.py index 87194ac705c..2503f6a2d0c 100644 --- a/nanobot/channels/slack.py +++ b/nanobot/channels/slack.py @@ -145,6 +145,7 @@ async def send(self, msg: OutboundMessage) -> None: except Exception as e: logger.error("Error sending Slack message: {}", e) + raise async def _on_socket_request( self, diff --git a/nanobot/channels/telegram.py b/nanobot/channels/telegram.py index fcccbe8a4f9..c3041c9d23d 100644 --- a/nanobot/channels/telegram.py +++ b/nanobot/channels/telegram.py @@ -476,6 +476,7 @@ async def _send_text( ) except Exception as e2: logger.error("Error sending Telegram message: {}", e2) + raise async def send_delta(self, chat_id: str, delta: str, metadata: dict[str, Any] | None = None) -> None: """Progressive message editing: send on first delta, edit on subsequent ones.""" @@ -485,7 +486,7 @@ async def send_delta(self, chat_id: str, delta: str, metadata: dict[str, Any] | int_chat_id = int(chat_id) if meta.get("_stream_end"): - buf = self._stream_bufs.pop(chat_id, None) + buf = self._stream_bufs.get(chat_id) if not buf or not buf.message_id or not buf.text: return self._stop_typing(chat_id) @@ -504,8 +505,10 @@ async def send_delta(self, chat_id: str, delta: str, metadata: dict[str, Any] | chat_id=int_chat_id, message_id=buf.message_id, text=buf.text, ) - except Exception: - pass + except Exception as e2: + logger.warning("Final stream edit failed: {}", e2) + raise # Let ChannelManager handle retry + self._stream_bufs.pop(chat_id, None) return buf = self._stream_bufs.get(chat_id) diff --git a/nanobot/channels/wecom.py b/nanobot/channels/wecom.py index 2f248559ec2..05ad1482545 100644 --- a/nanobot/channels/wecom.py +++ b/nanobot/channels/wecom.py @@ -368,3 +368,4 @@ async def send(self, msg: OutboundMessage) -> None: except Exception as e: logger.error("Error sending WeCom message: {}", e) + raise diff --git a/nanobot/channels/weixin.py b/nanobot/channels/weixin.py index 3fbe329aa52..f09ef95f7d0 100644 --- a/nanobot/channels/weixin.py +++ b/nanobot/channels/weixin.py @@ -751,6 +751,7 @@ async def send(self, msg: OutboundMessage) -> None: await self._send_text(msg.chat_id, chunk, ctx_token) except Exception as e: logger.error("Error sending WeChat message: {}", e) + raise async def _send_text( self, diff --git a/nanobot/channels/whatsapp.py b/nanobot/channels/whatsapp.py index 8826a64f3fd..95bde46e9fa 100644 --- a/nanobot/channels/whatsapp.py +++ b/nanobot/channels/whatsapp.py @@ -146,6 +146,7 @@ async def send(self, msg: OutboundMessage) -> None: await self._ws.send(json.dumps(payload, ensure_ascii=False)) except Exception as e: logger.error("Error sending WhatsApp message: {}", e) + raise for media_path in msg.media or []: try: @@ -160,6 +161,7 @@ async def send(self, msg: OutboundMessage) -> None: await self._ws.send(json.dumps(payload, ensure_ascii=False)) except Exception as e: logger.error("Error sending WhatsApp media {}: {}", media_path, e) + raise async def _handle_bridge_message(self, raw: str) -> None: """Handle a message from the bridge.""" diff --git a/nanobot/config/schema.py b/nanobot/config/schema.py index 1d964a642f5..15fcacafe73 100644 --- a/nanobot/config/schema.py +++ b/nanobot/config/schema.py @@ -25,7 +25,7 @@ class ChannelsConfig(Base): send_progress: bool = True # stream agent's text progress to the channel send_tool_hints: bool = False # stream tool-call hints (e.g. read_file("…")) - send_max_retries: int = Field(default=3, ge=0, le=10) # Max retry attempts for message send failures + send_max_retries: int = Field(default=3, ge=0, le=10) # Max delivery attempts (initial send included) class AgentDefaults(Base): diff --git a/tests/channels/test_telegram_channel.py b/tests/channels/test_telegram_channel.py index 353d5d05d71..6b4c008e090 100644 --- a/tests/channels/test_telegram_channel.py +++ b/tests/channels/test_telegram_channel.py @@ -13,7 +13,7 @@ from nanobot.bus.events import OutboundMessage from nanobot.bus.queue import MessageBus -from nanobot.channels.telegram import TELEGRAM_REPLY_CONTEXT_MAX_LEN, TelegramChannel +from nanobot.channels.telegram import TELEGRAM_REPLY_CONTEXT_MAX_LEN, TelegramChannel, _StreamBuf from nanobot.channels.telegram import TelegramConfig @@ -271,13 +271,30 @@ async def always_timeout(**kwargs): orig_delay = tg_mod._SEND_RETRY_BASE_DELAY tg_mod._SEND_RETRY_BASE_DELAY = 0.01 try: - await channel._send_text(123, "hello", None, {}) + with pytest.raises(TimedOut): + await channel._send_text(123, "hello", None, {}) finally: tg_mod._SEND_RETRY_BASE_DELAY = orig_delay assert channel._app.bot.sent_messages == [] +@pytest.mark.asyncio +async def test_send_delta_stream_end_raises_and_keeps_buffer_on_failure() -> None: + channel = TelegramChannel( + TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]), + MessageBus(), + ) + channel._app = _FakeApp(lambda: None) + channel._app.bot.edit_message_text = AsyncMock(side_effect=RuntimeError("boom")) + channel._stream_bufs["123"] = _StreamBuf(text="hello", message_id=7, last_edit=0.0) + + with pytest.raises(RuntimeError, match="boom"): + await channel.send_delta("123", "", {"_stream_end": True}) + + assert "123" in channel._stream_bufs + + def test_derive_topic_session_key_uses_thread_id() -> None: message = SimpleNamespace( chat=SimpleNamespace(type="supergroup"), From 813de554c9b08e375fc52eebc96c28d7c2faf5c2 Mon Sep 17 00:00:00 2001 From: longyongshen Date: Wed, 25 Mar 2026 16:32:10 +0800 Subject: [PATCH 39/40] =?UTF-8?q?feat(provider):=20add=20Step=20Fun=20(?= =?UTF-8?q?=E9=98=B6=E8=B7=83=E6=98=9F=E8=BE=B0)=20provider=20support?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Made-with: Cursor --- README.md | 3 +++ nanobot/config/schema.py | 1 + nanobot/providers/registry.py | 9 +++++++++ 3 files changed, 13 insertions(+) diff --git a/README.md b/README.md index ae2512eb0a9..7f686b68312 100644 --- a/README.md +++ b/README.md @@ -846,6 +846,8 @@ Config file: `~/.nanobot/config.json` > - **VolcEngine / BytePlus Coding Plan**: Use dedicated providers `volcengineCodingPlan` or `byteplusCodingPlan` instead of the pay-per-use `volcengine` / `byteplus` providers. > - **Zhipu Coding Plan**: If you're on Zhipu's coding plan, set `"apiBase": "https://open.bigmodel.cn/api/coding/paas/v4"` in your zhipu provider config. > - **Alibaba Cloud BaiLian**: If you're using Alibaba Cloud BaiLian's OpenAI-compatible endpoint, set `"apiBase": "https://dashscope.aliyuncs.com/compatible-mode/v1"` in your dashscope provider config. +> - **Step Fun (Mainland China)**: If your API key is from Step Fun's mainland China platform (stepfun.com), set `"apiBase": "https://api.stepfun.com/v1"` in your stepfun provider config. +> - **Step Fun Step Plan**: Exclusive discount links for the nanobot community: [Overseas](https://platform.stepfun.ai/step-plan) · [Mainland China](https://platform.stepfun.com/step-plan) | Provider | Purpose | Get API Key | |----------|---------|-------------| @@ -867,6 +869,7 @@ Config file: `~/.nanobot/config.json` | `zhipu` | LLM (Zhipu GLM) | [open.bigmodel.cn](https://open.bigmodel.cn) | | `ollama` | LLM (local, Ollama) | — | | `mistral` | LLM | [docs.mistral.ai](https://docs.mistral.ai/) | +| `stepfun` | LLM (Step Fun/阶跃星辰) | [platform.stepfun.com](https://platform.stepfun.com) | | `ovms` | LLM (local, OpenVINO Model Server) | [docs.openvino.ai](https://docs.openvino.ai/2026/model-server/ovms_docs_llm_quickstart.html) | | `vllm` | LLM (local, any OpenAI-compatible server) | — | | `openai_codex` | LLM (Codex, OAuth) | `nanobot provider login openai-codex` | diff --git a/nanobot/config/schema.py b/nanobot/config/schema.py index 15fcacafe73..c8b69b42e27 100644 --- a/nanobot/config/schema.py +++ b/nanobot/config/schema.py @@ -77,6 +77,7 @@ class ProvidersConfig(Base): moonshot: ProviderConfig = Field(default_factory=ProviderConfig) minimax: ProviderConfig = Field(default_factory=ProviderConfig) mistral: ProviderConfig = Field(default_factory=ProviderConfig) + stepfun: ProviderConfig = Field(default_factory=ProviderConfig) # Step Fun (阶跃星辰) aihubmix: ProviderConfig = Field(default_factory=ProviderConfig) # AiHubMix API gateway siliconflow: ProviderConfig = Field(default_factory=ProviderConfig) # SiliconFlow (硅基流动) volcengine: ProviderConfig = Field(default_factory=ProviderConfig) # VolcEngine (火山引擎) diff --git a/nanobot/providers/registry.py b/nanobot/providers/registry.py index 206b0b5042f..e42e1f95e1c 100644 --- a/nanobot/providers/registry.py +++ b/nanobot/providers/registry.py @@ -286,6 +286,15 @@ def label(self) -> str: backend="openai_compat", default_api_base="https://api.mistral.ai/v1", ), + # Step Fun (阶跃星辰): OpenAI-compatible API + ProviderSpec( + name="stepfun", + keywords=("stepfun", "step"), + env_key="STEPFUN_API_KEY", + display_name="Step Fun", + backend="openai_compat", + default_api_base="https://api.stepfun.com/v1", + ), # === Local deployment (matched by config key, NOT by api_base) ========= # vLLM / any OpenAI-compatible local server ProviderSpec( From 33abe915e767f64e43b4392a4658815862d2e5f4 Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Thu, 26 Mar 2026 02:35:12 +0000 Subject: [PATCH 40/40] fix telegram streaming message boundaries --- nanobot/agent/loop.py | 22 ++++++++- nanobot/channels/base.py | 4 ++ nanobot/channels/telegram.py | 27 +++++++++-- tests/channels/test_telegram_channel.py | 59 ++++++++++++++++++++++++- 4 files changed, 106 insertions(+), 6 deletions(-) diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index afe62ca28e6..3482e38d2cc 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -373,17 +373,35 @@ async def _dispatch(self, msg: InboundMessage) -> None: try: on_stream = on_stream_end = None if msg.metadata.get("_wants_stream"): + # Split one answer into distinct stream segments. + stream_base_id = f"{msg.session_key}:{time.time_ns()}" + stream_segment = 0 + + def _current_stream_id() -> str: + return f"{stream_base_id}:{stream_segment}" + async def on_stream(delta: str) -> None: await self.bus.publish_outbound(OutboundMessage( channel=msg.channel, chat_id=msg.chat_id, - content=delta, metadata={"_stream_delta": True}, + content=delta, + metadata={ + "_stream_delta": True, + "_stream_id": _current_stream_id(), + }, )) async def on_stream_end(*, resuming: bool = False) -> None: + nonlocal stream_segment await self.bus.publish_outbound(OutboundMessage( channel=msg.channel, chat_id=msg.chat_id, - content="", metadata={"_stream_end": True, "_resuming": resuming}, + content="", + metadata={ + "_stream_end": True, + "_resuming": resuming, + "_stream_id": _current_stream_id(), + }, )) + stream_segment += 1 response = await self._process_message( msg, on_stream=on_stream, on_stream_end=on_stream_end, diff --git a/nanobot/channels/base.py b/nanobot/channels/base.py index 5a776eed4ff..86e9913444b 100644 --- a/nanobot/channels/base.py +++ b/nanobot/channels/base.py @@ -96,6 +96,10 @@ async def send_delta(self, chat_id: str, delta: str, metadata: dict[str, Any] | Override in subclasses to enable streaming. Implementations should raise on delivery failure so the channel manager can retry. + + Streaming contract: ``_stream_delta`` is a chunk, ``_stream_end`` ends + the current segment, and stateful implementations must key buffers by + ``_stream_id`` rather than only by ``chat_id``. """ pass diff --git a/nanobot/channels/telegram.py b/nanobot/channels/telegram.py index c3041c9d23d..feb908657fc 100644 --- a/nanobot/channels/telegram.py +++ b/nanobot/channels/telegram.py @@ -12,7 +12,7 @@ from loguru import logger from pydantic import Field from telegram import BotCommand, ReactionTypeEmoji, ReplyParameters, Update -from telegram.error import TimedOut +from telegram.error import BadRequest, TimedOut from telegram.ext import Application, CommandHandler, ContextTypes, MessageHandler, filters from telegram.request import HTTPXRequest @@ -163,6 +163,7 @@ class _StreamBuf: text: str = "" message_id: int | None = None last_edit: float = 0.0 + stream_id: str | None = None class TelegramConfig(Base): @@ -478,17 +479,24 @@ async def _send_text( logger.error("Error sending Telegram message: {}", e2) raise + @staticmethod + def _is_not_modified_error(exc: Exception) -> bool: + return isinstance(exc, BadRequest) and "message is not modified" in str(exc).lower() + async def send_delta(self, chat_id: str, delta: str, metadata: dict[str, Any] | None = None) -> None: """Progressive message editing: send on first delta, edit on subsequent ones.""" if not self._app: return meta = metadata or {} int_chat_id = int(chat_id) + stream_id = meta.get("_stream_id") if meta.get("_stream_end"): buf = self._stream_bufs.get(chat_id) if not buf or not buf.message_id or not buf.text: return + if stream_id is not None and buf.stream_id is not None and buf.stream_id != stream_id: + return self._stop_typing(chat_id) try: html = _markdown_to_telegram_html(buf.text) @@ -498,6 +506,10 @@ async def send_delta(self, chat_id: str, delta: str, metadata: dict[str, Any] | text=html, parse_mode="HTML", ) except Exception as e: + if self._is_not_modified_error(e): + logger.debug("Final stream edit already applied for {}", chat_id) + self._stream_bufs.pop(chat_id, None) + return logger.debug("Final stream edit failed (HTML), trying plain: {}", e) try: await self._call_with_retry( @@ -506,15 +518,21 @@ async def send_delta(self, chat_id: str, delta: str, metadata: dict[str, Any] | text=buf.text, ) except Exception as e2: + if self._is_not_modified_error(e2): + logger.debug("Final stream plain edit already applied for {}", chat_id) + self._stream_bufs.pop(chat_id, None) + return logger.warning("Final stream edit failed: {}", e2) raise # Let ChannelManager handle retry self._stream_bufs.pop(chat_id, None) return buf = self._stream_bufs.get(chat_id) - if buf is None: - buf = _StreamBuf() + if buf is None or (stream_id is not None and buf.stream_id is not None and buf.stream_id != stream_id): + buf = _StreamBuf(stream_id=stream_id) self._stream_bufs[chat_id] = buf + elif buf.stream_id is None: + buf.stream_id = stream_id buf.text += delta if not buf.text.strip(): @@ -541,6 +559,9 @@ async def send_delta(self, chat_id: str, delta: str, metadata: dict[str, Any] | ) buf.last_edit = now except Exception as e: + if self._is_not_modified_error(e): + buf.last_edit = now + return logger.warning("Stream edit failed: {}", e) raise # Let ChannelManager handle retry diff --git a/tests/channels/test_telegram_channel.py b/tests/channels/test_telegram_channel.py index 6b4c008e090..d5dafdee723 100644 --- a/tests/channels/test_telegram_channel.py +++ b/tests/channels/test_telegram_channel.py @@ -50,8 +50,9 @@ async def get_me(self): async def set_my_commands(self, commands) -> None: self.commands = commands - async def send_message(self, **kwargs) -> None: + async def send_message(self, **kwargs): self.sent_messages.append(kwargs) + return SimpleNamespace(message_id=len(self.sent_messages)) async def send_photo(self, **kwargs) -> None: self.sent_media.append({"kind": "photo", **kwargs}) @@ -295,6 +296,62 @@ async def test_send_delta_stream_end_raises_and_keeps_buffer_on_failure() -> Non assert "123" in channel._stream_bufs +@pytest.mark.asyncio +async def test_send_delta_stream_end_treats_not_modified_as_success() -> None: + from telegram.error import BadRequest + + channel = TelegramChannel( + TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]), + MessageBus(), + ) + channel._app = _FakeApp(lambda: None) + channel._app.bot.edit_message_text = AsyncMock(side_effect=BadRequest("Message is not modified")) + channel._stream_bufs["123"] = _StreamBuf(text="hello", message_id=7, last_edit=0.0, stream_id="s:0") + + await channel.send_delta("123", "", {"_stream_end": True, "_stream_id": "s:0"}) + + assert "123" not in channel._stream_bufs + + +@pytest.mark.asyncio +async def test_send_delta_new_stream_id_replaces_stale_buffer() -> None: + channel = TelegramChannel( + TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]), + MessageBus(), + ) + channel._app = _FakeApp(lambda: None) + channel._stream_bufs["123"] = _StreamBuf( + text="hello", + message_id=7, + last_edit=0.0, + stream_id="old:0", + ) + + await channel.send_delta("123", "world", {"_stream_delta": True, "_stream_id": "new:0"}) + + buf = channel._stream_bufs["123"] + assert buf.text == "world" + assert buf.stream_id == "new:0" + assert buf.message_id == 1 + + +@pytest.mark.asyncio +async def test_send_delta_incremental_edit_treats_not_modified_as_success() -> None: + from telegram.error import BadRequest + + channel = TelegramChannel( + TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]), + MessageBus(), + ) + channel._app = _FakeApp(lambda: None) + channel._stream_bufs["123"] = _StreamBuf(text="hello", message_id=7, last_edit=0.0, stream_id="s:0") + channel._app.bot.edit_message_text = AsyncMock(side_effect=BadRequest("Message is not modified")) + + await channel.send_delta("123", "", {"_stream_delta": True, "_stream_id": "s:0"}) + + assert channel._stream_bufs["123"].last_edit > 0.0 + + def test_derive_topic_session_key_uses_thread_id() -> None: message = SimpleNamespace( chat=SimpleNamespace(type="supergroup"),