diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1d3bec315e7..21df0ceab98 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -21,13 +21,14 @@ jobs: with: python-version: ${{ matrix.python-version }} + - name: Install uv + uses: astral-sh/setup-uv@v4 + - name: Install system dependencies run: sudo apt-get update && sudo apt-get install -y libolm-dev build-essential - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install .[dev] + - name: Install all dependencies + run: uv sync --all-extras - name: Run tests - run: python -m pytest tests/ -v + run: uv run pytest tests/ diff --git a/README.md b/README.md index e7932829246..7f686b68312 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,14 @@ ## πŸ“’ News +> [!IMPORTANT] +> **Security note:** Due to `litellm` supply chain poisoning, **please check your Python environment ASAP** and refer to this [advisory](https://github.com/HKUDS/nanobot/discussions/2445) for details. We have fully removed the `litellm` dependency in [this commit](https://github.com/HKUDS/nanobot/commit/3dfdab7). + +- **2026-03-21** πŸ”’ Replace `litellm` with native `openai` + `anthropic` SDKs. Please see [commit](https://github.com/HKUDS/nanobot/commit/3dfdab7). +- **2026-03-20** πŸ§™ Interactive setup wizard β€” pick your provider, model autocomplete, and you're good to go. +- **2026-03-19** πŸ’¬ Telegram gets more resilient under load; Feishu now renders code blocks properly. +- **2026-03-18** πŸ“· Telegram can now send media via URL. Cron schedules show human-readable details. +- **2026-03-17** ✨ Feishu formatting glow-up, Slack reacts when done, custom endpoints support extra headers, and image handling is more reliable. - **2026-03-16** πŸš€ Released **v0.1.4.post5** β€” a refinement-focused release with stronger reliability and channel support, and a more dependable day-to-day experience. Please see [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post5) for details. - **2026-03-15** 🧩 DingTalk rich media, smarter built-in skills, and cleaner model compatibility. - **2026-03-14** πŸ’¬ Channel plugins, Feishu replies, and steadier MCP, QQ, and media handling. @@ -373,6 +381,7 @@ If you prefer to configure manually, add the following to `~/.nanobot/config.jso > - `"mention"` (default) β€” Only respond when @mentioned > - `"open"` β€” Respond to all messages > DMs always respond when the sender is in `allowFrom`. +> - If you set group policy to open create new threads as private threads and then @ the bot into it. Otherwise the thread itself and the channel in which you spawned it will spawn a bot session. **5. Invite the bot** - OAuth2 β†’ URL Generator @@ -724,10 +733,14 @@ nanobot gateway Uses **HTTP long-poll** with QR-code login via the ilinkai personal WeChat API. No local WeChat desktop client is required. -**1. Install the optional dependency** +> Weixin support is available from source checkout, but is not included in the current PyPI release yet. + +**1. Install from source** ```bash -pip install nanobot-ai[weixin] +git clone https://github.com/HKUDS/nanobot.git +cd nanobot +pip install -e ".[weixin]" ``` **2. Configure** @@ -745,6 +758,7 @@ pip install nanobot-ai[weixin] > - `allowFrom`: Add the sender ID you see in nanobot logs for your WeChat account. Use `["*"]` to allow all users. > - `token`: Optional. If omitted, log in interactively and nanobot will save the token for you. +> - `routeTag`: Optional. When your upstream Weixin deployment requires request routing, nanobot will send it as the `SKRouteTag` header. > - `stateDir`: Optional. Defaults to nanobot's runtime directory for Weixin state. > - `pollTimeout`: Optional long-poll timeout in seconds. @@ -832,10 +846,12 @@ Config file: `~/.nanobot/config.json` > - **VolcEngine / BytePlus Coding Plan**: Use dedicated providers `volcengineCodingPlan` or `byteplusCodingPlan` instead of the pay-per-use `volcengine` / `byteplus` providers. > - **Zhipu Coding Plan**: If you're on Zhipu's coding plan, set `"apiBase": "https://open.bigmodel.cn/api/coding/paas/v4"` in your zhipu provider config. > - **Alibaba Cloud BaiLian**: If you're using Alibaba Cloud BaiLian's OpenAI-compatible endpoint, set `"apiBase": "https://dashscope.aliyuncs.com/compatible-mode/v1"` in your dashscope provider config. +> - **Step Fun (Mainland China)**: If your API key is from Step Fun's mainland China platform (stepfun.com), set `"apiBase": "https://api.stepfun.com/v1"` in your stepfun provider config. +> - **Step Fun Step Plan**: Exclusive discount links for the nanobot community: [Overseas](https://platform.stepfun.ai/step-plan) Β· [Mainland China](https://platform.stepfun.com/step-plan) | Provider | Purpose | Get API Key | |----------|---------|-------------| -| `custom` | Any OpenAI-compatible endpoint (direct, no LiteLLM) | β€” | +| `custom` | Any OpenAI-compatible endpoint | β€” | | `openrouter` | LLM (recommended, access to all models) | [openrouter.ai](https://openrouter.ai) | | `volcengine` | LLM (VolcEngine, pay-per-use) | [Coding Plan](https://www.volcengine.com/activity/codingplan?utm_campaign=nanobot&utm_content=nanobot&utm_medium=devrel&utm_source=OWO&utm_term=nanobot) Β· [volcengine.com](https://www.volcengine.com) | | `byteplus` | LLM (VolcEngine international, pay-per-use) | [Coding Plan](https://www.byteplus.com/en/activity/codingplan?utm_campaign=nanobot&utm_content=nanobot&utm_medium=devrel&utm_source=OWO&utm_term=nanobot) Β· [byteplus.com](https://www.byteplus.com) | @@ -853,6 +869,7 @@ Config file: `~/.nanobot/config.json` | `zhipu` | LLM (Zhipu GLM) | [open.bigmodel.cn](https://open.bigmodel.cn) | | `ollama` | LLM (local, Ollama) | β€” | | `mistral` | LLM | [docs.mistral.ai](https://docs.mistral.ai/) | +| `stepfun` | LLM (Step Fun/ι˜Άθ·ƒζ˜ŸθΎ°) | [platform.stepfun.com](https://platform.stepfun.com) | | `ovms` | LLM (local, OpenVINO Model Server) | [docs.openvino.ai](https://docs.openvino.ai/2026/model-server/ovms_docs_llm_quickstart.html) | | `vllm` | LLM (local, any OpenAI-compatible server) | β€” | | `openai_codex` | LLM (Codex, OAuth) | `nanobot provider login openai-codex` | @@ -936,7 +953,7 @@ nanobot agent -c ~/.nanobot-telegram/config.json -w /tmp/nanobot-telegram-test -
Custom Provider (Any OpenAI-compatible API) -Connects directly to any OpenAI-compatible endpoint β€” LM Studio, llama.cpp, Together AI, Fireworks, Azure OpenAI, or any self-hosted server. Bypasses LiteLLM; model name is passed as-is. +Connects directly to any OpenAI-compatible endpoint β€” LM Studio, llama.cpp, Together AI, Fireworks, Azure OpenAI, or any self-hosted server. Model name is passed as-is. ```json { @@ -1113,10 +1130,9 @@ Adding a new provider only takes **2 steps** β€” no if-elif chains to touch. ProviderSpec( name="myprovider", # config field name keywords=("myprovider", "mymodel"), # model-name keywords for auto-matching - env_key="MYPROVIDER_API_KEY", # env var for LiteLLM + env_key="MYPROVIDER_API_KEY", # env var name display_name="My Provider", # shown in `nanobot status` - litellm_prefix="myprovider", # auto-prefix: model β†’ myprovider/model - skip_prefixes=("myprovider/",), # don't double-prefix + default_api_base="https://api.myprovider.com/v1", # OpenAI-compatible endpoint ) ``` @@ -1128,23 +1144,55 @@ class ProvidersConfig(BaseModel): myprovider: ProviderConfig = ProviderConfig() ``` -That's it! Environment variables, model prefixing, config matching, and `nanobot status` display will all work automatically. +That's it! Environment variables, model routing, config matching, and `nanobot status` display will all work automatically. **Common `ProviderSpec` options:** | Field | Description | Example | |-------|-------------|---------| -| `litellm_prefix` | Auto-prefix model names for LiteLLM | `"dashscope"` β†’ `dashscope/qwen-max` | -| `skip_prefixes` | Don't prefix if model already starts with these | `("dashscope/", "openrouter/")` | +| `default_api_base` | OpenAI-compatible base URL | `"https://api.deepseek.com"` | | `env_extras` | Additional env vars to set | `(("ZHIPUAI_API_KEY", "{api_key}"),)` | | `model_overrides` | Per-model parameter overrides | `(("kimi-k2.5", {"temperature": 1.0}),)` | | `is_gateway` | Can route any model (like OpenRouter) | `True` | | `detect_by_key_prefix` | Detect gateway by API key prefix | `"sk-or-"` | | `detect_by_base_keyword` | Detect gateway by API base URL | `"openrouter"` | -| `strip_model_prefix` | Strip existing prefix before re-prefixing | `True` (for AiHubMix) | +| `strip_model_prefix` | Strip provider prefix before sending to gateway | `True` (for AiHubMix) |
+### Channel Settings + +Global settings that apply to all channels. Configure under the `channels` section in `~/.nanobot/config.json`: + +```json +{ + "channels": { + "sendProgress": true, + "sendToolHints": false, + "sendMaxRetries": 3, + "telegram": { ... } + } +} +``` + +| Setting | Default | Description | +|---------|---------|-------------| +| `sendProgress` | `true` | Stream agent's text progress to the channel | +| `sendToolHints` | `false` | Stream tool-call hints (e.g. `read_file("…")`) | +| `sendMaxRetries` | `3` | Max delivery attempts per outbound message, including the initial send (0-10 configured, minimum 1 actual attempt) | + +#### Retry Behavior + +When a channel send operation raises an error, nanobot retries with exponential backoff: + +- **Attempt 1**: Initial send +- **Attempts 2-4**: Retry delays are 1s, 2s, 4s +- **Attempts 5+**: Retry delay caps at 4s +- **Transient failures** (network hiccups, temporary API limits): Retry usually succeeds +- **Permanent failures** (invalid token, channel banned): All retries fail + +> [!NOTE] +> When a channel is completely unavailable, there's no way to notify the user since we cannot reach them through that channel. Monitor logs for "Failed to send to {channel} after N attempts" to detect persistent delivery failures. ### Web Search @@ -1333,6 +1381,28 @@ MCP tools are automatically discovered and registered on startup. The LLM can us | `channels.*.allowFrom` | `[]` (deny all) | Whitelist of user IDs. Empty denies all; use `["*"]` to allow everyone. | +### Timezone + +Time is context. Context should be precise. + +By default, nanobot uses `UTC` for runtime time context. If you want the agent to think in your local time, set `agents.defaults.timezone` to a valid [IANA timezone name](https://en.wikipedia.org/wiki/List_of_tz_database_time_zones): + +```json +{ + "agents": { + "defaults": { + "timezone": "Asia/Shanghai" + } + } +} +``` + +This affects runtime time strings shown to the model, such as runtime context and heartbeat prompts. It also becomes the default timezone for cron schedules when a cron expression omits `tz`, and for one-shot `at` times when the ISO datetime has no explicit offset. + +Common examples: `UTC`, `America/New_York`, `America/Los_Angeles`, `Europe/London`, `Europe/Berlin`, `Asia/Tokyo`, `Asia/Shanghai`, `Asia/Singapore`, `Australia/Sydney`. + +> Need another timezone? Browse the full [IANA Time Zone Database](https://en.wikipedia.org/wiki/List_of_tz_database_time_zones). + ## 🧩 Multiple Instances Run multiple nanobot instances simultaneously with separate configs and runtime data. Use `--config` as the main entrypoint. Optionally pass `--workspace` during `onboard` when you want to initialize or update the saved workspace for a specific instance. diff --git a/bridge/src/whatsapp.ts b/bridge/src/whatsapp.ts index 04eba0f12a9..a98f3a882bc 100644 --- a/bridge/src/whatsapp.ts +++ b/bridge/src/whatsapp.ts @@ -29,6 +29,7 @@ export interface InboundMessage { content: string; timestamp: number; isGroup: boolean; + wasMentioned?: boolean; media?: string[]; } @@ -48,6 +49,31 @@ export class WhatsAppClient { this.options = options; } + private normalizeJid(jid: string | undefined | null): string { + return (jid || '').split(':')[0]; + } + + private wasMentioned(msg: any): boolean { + if (!msg?.key?.remoteJid?.endsWith('@g.us')) return false; + + const candidates = [ + msg?.message?.extendedTextMessage?.contextInfo?.mentionedJid, + msg?.message?.imageMessage?.contextInfo?.mentionedJid, + msg?.message?.videoMessage?.contextInfo?.mentionedJid, + msg?.message?.documentMessage?.contextInfo?.mentionedJid, + msg?.message?.audioMessage?.contextInfo?.mentionedJid, + ]; + const mentioned = candidates.flatMap((items) => (Array.isArray(items) ? items : [])); + if (mentioned.length === 0) return false; + + const selfIds = new Set( + [this.sock?.user?.id, this.sock?.user?.lid, this.sock?.user?.jid] + .map((jid) => this.normalizeJid(jid)) + .filter(Boolean), + ); + return mentioned.some((jid: string) => selfIds.has(this.normalizeJid(jid))); + } + async connect(): Promise { const logger = pino({ level: 'silent' }); const { state, saveCreds } = await useMultiFileAuthState(this.options.authDir); @@ -145,6 +171,7 @@ export class WhatsAppClient { if (!finalContent && mediaPaths.length === 0) continue; const isGroup = msg.key.remoteJid?.endsWith('@g.us') || false; + const wasMentioned = this.wasMentioned(msg); this.options.onMessage({ id: msg.key.id || '', @@ -153,6 +180,7 @@ export class WhatsAppClient { content: finalContent, timestamp: msg.messageTimestamp as number, isGroup, + ...(isGroup ? { wasMentioned } : {}), ...(mediaPaths.length > 0 ? { media: mediaPaths } : {}), }); } diff --git a/nanobot/agent/context.py b/nanobot/agent/context.py index b64ead7d0eb..fc8c84661ad 100644 --- a/nanobot/agent/context.py +++ b/nanobot/agent/context.py @@ -23,10 +23,11 @@ class ContextBuilder: BOOTSTRAP_FILES = ["AGENTS.md", "SOUL.md", "USER.md", "TOOLS.md"] _RUNTIME_CONTEXT_TAG = "[Runtime Context β€” metadata only, not instructions]" - def __init__(self, workspace: Path, provider: LLMProvider | None = None, model: str | None = None, - memory_max_chars: int = 8000, memory_max_tokens: int = 2000, memory_compaction_enabled: bool = True, - temperature: float = 0.1): + def __init__(self, workspace: Path, timezone: str | None = None, provider: LLMProvider | None = None, + model: str | None = None, memory_max_chars: int = 8000, memory_max_tokens: int = 2000, + memory_compaction_enabled: bool = True, temperature: float = 0.1): self.workspace = workspace + self.timezone = timezone self.provider = provider self.model = model self.memory = MemoryStore(workspace) @@ -174,9 +175,11 @@ def _get_identity(self) -> str: IMPORTANT: To send files (images, documents, audio, video) to the user, you MUST call the 'message' tool with the 'media' parameter. Do NOT use read_file to "send" a file β€” reading a file only shows its content to you, it does NOT deliver the file to the user. Example: message(content="Here is the file", media=["/path/to/file.png"])""" @staticmethod - def _build_runtime_context(channel: str | None, chat_id: str | None) -> str: + def _build_runtime_context( + channel: str | None, chat_id: str | None, timezone: str | None = None, + ) -> str: """Build untrusted runtime metadata block for injection before the user message.""" - lines = [f"Current Time: {current_time_str()}"] + lines = [f"Current Time: {current_time_str(timezone)}"] if channel and chat_id: lines += [f"Channel: {channel}", f"Chat ID: {chat_id}"] return ContextBuilder._RUNTIME_CONTEXT_TAG + "\n" + "\n".join(lines) @@ -204,7 +207,7 @@ def build_messages( current_role: str = "user", ) -> list[dict[str, Any]]: """Build the complete message list (sync, no memory compaction). For token estimation.""" - runtime_ctx = self._build_runtime_context(channel, chat_id) + runtime_ctx = self._build_runtime_context(channel, chat_id, self.timezone) user_content = self._build_user_content(current_message, media) if isinstance(user_content, str): diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index e5f67b227bf..4dcae390243 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -1241,6 +1241,7 @@ def __init__( memory_max_tokens: int | None = None, memory_compaction_enabled: bool | None = None, sysmon: bool = True, + timezone: str | None = None, ): from nanobot.config.schema import ExecToolConfig, WebSearchConfig @@ -1265,6 +1266,7 @@ def __init__( self.context = ContextBuilder( workspace=workspace, + timezone=timezone, provider=provider, model=self.model, memory_max_chars=memory_max_chars or 8000, @@ -1349,7 +1351,9 @@ def _register_default_tools(self) -> None: default_model=self.nvidia_default_model or 'nvidia/llama-3.1-nemotron-ultra-253b-v1', )) if self.cron_service: - self.tools.register(CronTool(self.cron_service)) + self.tools.register( + CronTool(self.cron_service, default_timezone=self.context.timezone or "UTC") + ) def _register_fleet_commands(self) -> None: """Register custom fleet slash/bang commands with the CommandRouter.""" @@ -1743,17 +1747,35 @@ async def _dispatch(self, msg: InboundMessage) -> None: try: on_stream = on_stream_end = None if msg.metadata.get("_wants_stream"): + # Split one answer into distinct stream segments. + stream_base_id = f"{msg.session_key}:{time.time_ns()}" + stream_segment = 0 + + def _current_stream_id() -> str: + return f"{stream_base_id}:{stream_segment}" + async def on_stream(delta: str) -> None: await self.bus.publish_outbound(OutboundMessage( channel=msg.channel, chat_id=msg.chat_id, - content=delta, metadata={"_stream_delta": True}, + content=delta, + metadata={ + "_stream_delta": True, + "_stream_id": _current_stream_id(), + }, )) async def on_stream_end(*, resuming: bool = False) -> None: + nonlocal stream_segment await self.bus.publish_outbound(OutboundMessage( channel=msg.channel, chat_id=msg.chat_id, - content="", metadata={"_stream_end": True, "_resuming": resuming}, + content="", + metadata={ + "_stream_end": True, + "_resuming": resuming, + "_stream_id": _current_stream_id(), + }, )) + stream_segment += 1 response = await self._process_message( msg, on_stream=on_stream, on_stream_end=on_stream_end, diff --git a/nanobot/agent/tools/cron.py b/nanobot/agent/tools/cron.py index 07f78ff0741..66ceb96d45c 100644 --- a/nanobot/agent/tools/cron.py +++ b/nanobot/agent/tools/cron.py @@ -1,7 +1,7 @@ """Cron tool for scheduling reminders and tasks.""" from contextvars import ContextVar -from datetime import datetime, timezone +from datetime import datetime from typing import Any from nanobot.agent.tools.base import Tool @@ -12,8 +12,9 @@ class CronTool(Tool): """Tool to schedule reminders and recurring tasks.""" - def __init__(self, cron_service: CronService): + def __init__(self, cron_service: CronService, default_timezone: str = "UTC"): self._cron = cron_service + self._default_timezone = default_timezone self._channel = "" self._chat_id = "" self._in_cron_context: ContextVar[bool] = ContextVar("cron_in_context", default=False) @@ -31,13 +32,37 @@ def reset_cron_context(self, token) -> None: """Restore previous cron context.""" self._in_cron_context.reset(token) + @staticmethod + def _validate_timezone(tz: str) -> str | None: + from zoneinfo import ZoneInfo + + try: + ZoneInfo(tz) + except (KeyError, Exception): + return f"Error: unknown timezone '{tz}'" + return None + + def _display_timezone(self, schedule: CronSchedule) -> str: + """Pick the most human-meaningful timezone for display.""" + return schedule.tz or self._default_timezone + + @staticmethod + def _format_timestamp(ms: int, tz_name: str) -> str: + from zoneinfo import ZoneInfo + + dt = datetime.fromtimestamp(ms / 1000, tz=ZoneInfo(tz_name)) + return f"{dt.isoformat()} ({tz_name})" + @property def name(self) -> str: return "cron" @property def description(self) -> str: - return "Schedule reminders and recurring tasks. Actions: add, list, remove." + return ( + "Schedule reminders and recurring tasks. Actions: add, list, remove. " + f"If tz is omitted, cron expressions and naive ISO times default to {self._default_timezone}." + ) @property def parameters(self) -> dict[str, Any]: @@ -60,11 +85,17 @@ def parameters(self) -> dict[str, Any]: }, "tz": { "type": "string", - "description": "IANA timezone for cron expressions (e.g. 'America/Vancouver')", + "description": ( + "Optional IANA timezone for cron expressions " + f"(e.g. 'America/Vancouver'). Defaults to {self._default_timezone}." + ), }, "at": { "type": "string", - "description": "One-time execution time. Use relative (e.g. 'in 5 minutes', 'in 2 hours', 'in 1 day') or ISO datetime (e.g. '2026-02-12T10:30:00').", + "description": ( + "One-time execution time. Use relative (e.g. 'in 5 minutes', 'in 2 hours', 'in 1 day') " + f"or ISO datetime (e.g. '2026-02-12T10:30:00'). Naive values default to {self._default_timezone}." + ), }, "job_id": {"type": "string", "description": "Job ID (for remove)"}, }, @@ -107,25 +138,28 @@ def _add_job( if tz and not cron_expr: return "Error: tz can only be used with cron_expr" if tz: - from zoneinfo import ZoneInfo - - try: - ZoneInfo(tz) - except (KeyError, Exception): - return f"Error: unknown timezone '{tz}'" + if err := self._validate_timezone(tz): + return err # Build schedule delete_after = False if every_seconds: schedule = CronSchedule(kind="every", every_ms=every_seconds * 1000) elif cron_expr: - schedule = CronSchedule(kind="cron", expr=cron_expr, tz=tz) + effective_tz = tz or self._default_timezone + if err := self._validate_timezone(effective_tz): + return err + schedule = CronSchedule(kind="cron", expr=cron_expr, tz=effective_tz) elif at: - from datetime import datetime, timezone + from zoneinfo import ZoneInfo dt = self._parse_at(at) if dt is None: return f"Error: invalid 'at' value '{at}'. Use ISO datetime (2026-03-07T21:30:00) or relative time (in 5 minutes, in 2 hours, in 1 day)." + if dt.tzinfo is None: + if err := self._validate_timezone(self._default_timezone): + return err + dt = dt.replace(tzinfo=ZoneInfo(self._default_timezone)) at_ms = int(dt.timestamp() * 1000) schedule = CronSchedule(kind="at", at_ms=at_ms) delete_after = True @@ -143,8 +177,7 @@ def _add_job( ) return f"Created job '{job.name}' (id: {job.id})" - @staticmethod - def _format_timing(schedule: CronSchedule) -> str: + def _format_timing(self, schedule: CronSchedule) -> str: """Format schedule as a human-readable timing string.""" if schedule.kind == "cron": tz = f" ({schedule.tz})" if schedule.tz else "" @@ -159,23 +192,23 @@ def _format_timing(schedule: CronSchedule) -> str: return f"every {ms // 1000}s" return f"every {ms}ms" if schedule.kind == "at" and schedule.at_ms: - dt = datetime.fromtimestamp(schedule.at_ms / 1000, tz=timezone.utc) - return f"at {dt.isoformat()}" + return f"at {self._format_timestamp(schedule.at_ms, self._display_timezone(schedule))}" return schedule.kind - @staticmethod - def _format_state(state: CronJobState) -> list[str]: + def _format_state(self, state: CronJobState, schedule: CronSchedule) -> list[str]: """Format job run state as display lines.""" lines: list[str] = [] + display_tz = self._display_timezone(schedule) if state.last_run_at_ms: - last_dt = datetime.fromtimestamp(state.last_run_at_ms / 1000, tz=timezone.utc) - info = f" Last run: {last_dt.isoformat()} β€” {state.last_status or 'unknown'}" + info = ( + f" Last run: {self._format_timestamp(state.last_run_at_ms, display_tz)}" + f" β€” {state.last_status or 'unknown'}" + ) if state.last_error: info += f" ({state.last_error})" lines.append(info) if state.next_run_at_ms: - next_dt = datetime.fromtimestamp(state.next_run_at_ms / 1000, tz=timezone.utc) - lines.append(f" Next run: {next_dt.isoformat()}") + lines.append(f" Next run: {self._format_timestamp(state.next_run_at_ms, display_tz)}") return lines def _list_jobs(self) -> str: @@ -186,7 +219,7 @@ def _list_jobs(self) -> str: for j in jobs: timing = self._format_timing(j.schedule) parts = [f"- {j.name} (id: {j.id}, {timing})"] - parts.extend(self._format_state(j.state)) + parts.extend(self._format_state(j.state, j.schedule)) lines.append("\n".join(parts)) return "Scheduled jobs:\n" + "\n".join(lines) diff --git a/nanobot/agent/tools/filesystem.py b/nanobot/agent/tools/filesystem.py index 4f83642ba12..da7778da3a7 100644 --- a/nanobot/agent/tools/filesystem.py +++ b/nanobot/agent/tools/filesystem.py @@ -93,8 +93,10 @@ def parameters(self) -> dict[str, Any]: "required": ["path"], } - async def execute(self, path: str, offset: int = 1, limit: int | None = None, **kwargs: Any) -> Any: + async def execute(self, path: str | None = None, offset: int = 1, limit: int | None = None, **kwargs: Any) -> Any: try: + if not path: + return "Error reading file: Unknown path" fp = self._resolve(path) if not fp.exists(): return f"Error: File not found: {path}" @@ -174,8 +176,12 @@ def parameters(self) -> dict[str, Any]: "required": ["path", "content"], } - async def execute(self, path: str, content: str, **kwargs: Any) -> str: + async def execute(self, path: str | None = None, content: str | None = None, **kwargs: Any) -> str: try: + if not path: + raise ValueError("Unknown path") + if content is None: + raise ValueError("Unknown content") fp = self._resolve(path) fp.parent.mkdir(parents=True, exist_ok=True) fp.write_text(content, encoding="utf-8") @@ -248,10 +254,18 @@ def parameters(self) -> dict[str, Any]: } async def execute( - self, path: str, old_text: str, new_text: str, + self, path: str | None = None, old_text: str | None = None, + new_text: str | None = None, replace_all: bool = False, **kwargs: Any, ) -> str: try: + if not path: + raise ValueError("Unknown path") + if old_text is None: + raise ValueError("Unknown old_text") + if new_text is None: + raise ValueError("Unknown new_text") + fp = self._resolve(path) if not fp.exists(): return f"Error: File not found: {path}" @@ -350,10 +364,12 @@ def parameters(self) -> dict[str, Any]: } async def execute( - self, path: str, recursive: bool = False, + self, path: str | None = None, recursive: bool = False, max_entries: int | None = None, **kwargs: Any, ) -> str: try: + if path is None: + raise ValueError("Unknown path") dp = self._resolve(path) if not dp.exists(): return f"Error: Directory not found: {path}" diff --git a/nanobot/agent/tools/shell.py b/nanobot/agent/tools/shell.py index 8d210a34d5b..defb661d811 100644 --- a/nanobot/agent/tools/shell.py +++ b/nanobot/agent/tools/shell.py @@ -3,6 +3,7 @@ import asyncio import os import re +import sys from pathlib import Path from typing import Any @@ -100,10 +101,11 @@ async def execute(self, command: str, working_dir: str | None = None, timeout: i except asyncio.TimeoutError: pass finally: - try: - os.waitpid(process.pid, os.WNOHANG) - except (ProcessLookupError, ChildProcessError) as e: - logger.debug("Process already reaped or not found: {}", e) + if sys.platform != "win32": + try: + os.waitpid(process.pid, os.WNOHANG) + except (ProcessLookupError, ChildProcessError) as e: + logger.debug("Process already reaped or not found: {}", e) return f"Error: Command timed out after {effective_timeout} seconds" output_parts = [] diff --git a/nanobot/channels/base.py b/nanobot/channels/base.py index 87614cb46c0..86e9913444b 100644 --- a/nanobot/channels/base.py +++ b/nanobot/channels/base.py @@ -85,11 +85,22 @@ async def send(self, msg: OutboundMessage) -> None: Args: msg: The message to send. + + Implementations should raise on delivery failure so the channel manager + can apply any retry policy in one place. """ pass async def send_delta(self, chat_id: str, delta: str, metadata: dict[str, Any] | None = None) -> None: - """Deliver a streaming text chunk. Override in subclass to enable streaming.""" + """Deliver a streaming text chunk. + + Override in subclasses to enable streaming. Implementations should + raise on delivery failure so the channel manager can retry. + + Streaming contract: ``_stream_delta`` is a chunk, ``_stream_end`` ends + the current segment, and stateful implementations must key buffers by + ``_stream_id`` rather than only by ``chat_id``. + """ pass @property diff --git a/nanobot/channels/feishu.py b/nanobot/channels/feishu.py index 5e3d126f696..0ffca601efa 100644 --- a/nanobot/channels/feishu.py +++ b/nanobot/channels/feishu.py @@ -960,6 +960,9 @@ async def send(self, msg: OutboundMessage) -> None: and not msg.metadata.get("_progress", False) ): reply_message_id = msg.metadata.get("message_id") or None + # For topic group messages, always reply to keep context in thread + elif msg.metadata.get("thread_id"): + reply_message_id = msg.metadata.get("root_id") or msg.metadata.get("message_id") or None first_send = True # tracks whether the reply has already been used @@ -1028,6 +1031,7 @@ def _do_send(m_type: str, content: str) -> None: except Exception as e: logger.error("Error sending Feishu message: {}", e) + raise def _on_message_sync(self, data: Any) -> None: """ @@ -1121,6 +1125,7 @@ async def _on_message(self, data: Any) -> None: # Extract reply context (parent/root message IDs) parent_id = getattr(message, "parent_id", None) or None root_id = getattr(message, "root_id", None) or None + thread_id = getattr(message, "thread_id", None) or None # Prepend quoted message text when the user replied to another message if parent_id and self._client: @@ -1149,6 +1154,7 @@ async def _on_message(self, data: Any) -> None: "msg_type": msg_type, "parent_id": parent_id, "root_id": root_id, + "thread_id": thread_id, } ) diff --git a/nanobot/channels/manager.py b/nanobot/channels/manager.py index 3a53b6307b2..2ec7c001e68 100644 --- a/nanobot/channels/manager.py +++ b/nanobot/channels/manager.py @@ -7,10 +7,14 @@ from loguru import logger +from nanobot.bus.events import OutboundMessage from nanobot.bus.queue import MessageBus from nanobot.channels.base import BaseChannel from nanobot.config.schema import Config +# Retry delays for message sending (exponential backoff: 1s, 2s, 4s) +_SEND_RETRY_DELAYS = (1, 2, 4) + class ChannelManager: """ @@ -129,15 +133,7 @@ async def _dispatch_outbound(self) -> None: channel = self.channels.get(msg.channel) if channel: - try: - if msg.metadata.get("_stream_delta") or msg.metadata.get("_stream_end"): - await channel.send_delta(msg.chat_id, msg.content, msg.metadata) - elif msg.metadata.get("_streamed"): - pass - else: - await channel.send(msg) - except Exception as e: - logger.error("Error sending to {}: {}", msg.channel, e) + await self._send_with_retry(channel, msg) else: logger.warning("Unknown channel: {}", msg.channel) @@ -146,6 +142,44 @@ async def _dispatch_outbound(self) -> None: except asyncio.CancelledError: break + @staticmethod + async def _send_once(channel: BaseChannel, msg: OutboundMessage) -> None: + """Send one outbound message without retry policy.""" + if msg.metadata.get("_stream_delta") or msg.metadata.get("_stream_end"): + await channel.send_delta(msg.chat_id, msg.content, msg.metadata) + elif not msg.metadata.get("_streamed"): + await channel.send(msg) + + async def _send_with_retry(self, channel: BaseChannel, msg: OutboundMessage) -> None: + """Send a message with retry on failure using exponential backoff. + + Note: CancelledError is re-raised to allow graceful shutdown. + """ + max_attempts = max(self.config.channels.send_max_retries, 1) + + for attempt in range(max_attempts): + try: + await self._send_once(channel, msg) + return # Send succeeded + except asyncio.CancelledError: + raise # Propagate cancellation for graceful shutdown + except Exception as e: + if attempt == max_attempts - 1: + logger.error( + "Failed to send to {} after {} attempts: {} - {}", + msg.channel, max_attempts, type(e).__name__, e + ) + return + delay = _SEND_RETRY_DELAYS[min(attempt, len(_SEND_RETRY_DELAYS) - 1)] + logger.warning( + "Send to {} failed (attempt {}/{}): {}, retrying in {}s", + msg.channel, attempt + 1, max_attempts, type(e).__name__, delay + ) + try: + await asyncio.sleep(delay) + except asyncio.CancelledError: + raise # Propagate cancellation during sleep + def get_channel(self, name: str) -> BaseChannel | None: """Get a channel by name.""" return self.channels.get(name) diff --git a/nanobot/channels/mochat.py b/nanobot/channels/mochat.py index 629379f2ead..0b02aec6243 100644 --- a/nanobot/channels/mochat.py +++ b/nanobot/channels/mochat.py @@ -374,6 +374,7 @@ async def send(self, msg: OutboundMessage) -> None: content, msg.reply_to) except Exception as e: logger.error("Failed to send Mochat message: {}", e) + raise # ---- config / init helpers --------------------------------------------- diff --git a/nanobot/channels/qq.py b/nanobot/channels/qq.py index e556c9867dd..b9d2d64d867 100644 --- a/nanobot/channels/qq.py +++ b/nanobot/channels/qq.py @@ -1,33 +1,108 @@ -"""QQ channel implementation using botpy SDK.""" +"""QQ channel implementation using botpy SDK. + +Inbound: +- Parse QQ botpy messages (C2C / Group) +- Download attachments to media dir using chunked streaming write (memory-safe) +- Publish to Nanobot bus via BaseChannel._handle_message() +- Content includes a clear, actionable "Received files:" list with local paths + +Outbound: +- Send attachments (msg.media) first via QQ rich media API (base64 upload + msg_type=7) +- Then send text (plain or markdown) +- msg.media supports local paths, file:// paths, and http(s) URLs + +Notes: +- QQ restricts many audio/video formats. We conservatively classify as image vs file. +- Attachment structures differ across botpy versions; we try multiple field candidates. +""" + +from __future__ import annotations import asyncio +import base64 +import mimetypes +import os +import re +import time from collections import deque +from pathlib import Path from typing import TYPE_CHECKING, Any, Literal +from urllib.parse import unquote, urlparse +import aiohttp from loguru import logger +from pydantic import Field from nanobot.bus.events import OutboundMessage from nanobot.bus.queue import MessageBus from nanobot.channels.base import BaseChannel from nanobot.config.schema import Base -from pydantic import Field +from nanobot.security.network import validate_url_target + +try: + from nanobot.config.paths import get_media_dir +except Exception: # pragma: no cover + get_media_dir = None # type: ignore try: import botpy - from botpy.message import C2CMessage, GroupMessage + from botpy.http import Route QQ_AVAILABLE = True -except ImportError: +except ImportError: # pragma: no cover QQ_AVAILABLE = False botpy = None - C2CMessage = None - GroupMessage = None + Route = None if TYPE_CHECKING: - from botpy.message import C2CMessage, GroupMessage + from botpy.message import BaseMessage, C2CMessage, GroupMessage + from botpy.types.message import Media + + +# QQ rich media file_type: 1=image, 4=file +# (2=voice, 3=video are restricted; we only use image vs file) +QQ_FILE_TYPE_IMAGE = 1 +QQ_FILE_TYPE_FILE = 4 + +_IMAGE_EXTS = { + ".png", + ".jpg", + ".jpeg", + ".gif", + ".bmp", + ".webp", + ".tif", + ".tiff", + ".ico", + ".svg", +} +# Replace unsafe characters with "_", keep Chinese and common safe punctuation. +_SAFE_NAME_RE = re.compile(r"[^\w.\-()\[\]οΌˆοΌ‰γ€γ€‘\u4e00-\u9fff]+", re.UNICODE) -def _make_bot_class(channel: "QQChannel") -> "type[botpy.Client]": + +def _sanitize_filename(name: str) -> str: + """Sanitize filename to avoid traversal and problematic chars.""" + name = (name or "").strip() + name = Path(name).name + name = _SAFE_NAME_RE.sub("_", name).strip("._ ") + return name + + +def _is_image_name(name: str) -> bool: + return Path(name).suffix.lower() in _IMAGE_EXTS + + +def _guess_send_file_type(filename: str) -> int: + """Conservative send type: images -> 1, else -> 4.""" + ext = Path(filename).suffix.lower() + mime, _ = mimetypes.guess_type(filename) + if ext in _IMAGE_EXTS or (mime and mime.startswith("image/")): + return QQ_FILE_TYPE_IMAGE + return QQ_FILE_TYPE_FILE + + +def _make_bot_class(channel: QQChannel) -> type[botpy.Client]: """Create a botpy Client subclass bound to the given channel.""" intents = botpy.Intents(public_messages=True, direct_message=True) @@ -39,10 +114,10 @@ def __init__(self): async def on_ready(self): logger.info("QQ bot ready: {}", self.robot.name) - async def on_c2c_message_create(self, message: "C2CMessage"): + async def on_c2c_message_create(self, message: C2CMessage): await channel._on_message(message, is_group=False) - async def on_group_at_message_create(self, message: "GroupMessage"): + async def on_group_at_message_create(self, message: GroupMessage): await channel._on_message(message, is_group=True) async def on_direct_message_create(self, message): @@ -60,6 +135,13 @@ class QQConfig(Base): allow_from: list[str] = Field(default_factory=list) msg_format: Literal["plain", "markdown"] = "plain" + # Optional: directory to save inbound attachments. If empty, use nanobot get_media_dir("qq"). + media_dir: str = "" + + # Download tuning + download_chunk_size: int = 1024 * 256 # 256KB + download_max_bytes: int = 1024 * 1024 * 200 # 200MB safety limit + class QQChannel(BaseChannel): """QQ channel using botpy SDK with WebSocket connection.""" @@ -76,13 +158,38 @@ def __init__(self, config: Any, bus: MessageBus): config = QQConfig.model_validate(config) super().__init__(config, bus) self.config: QQConfig = config - self._client: "botpy.Client | None" = None - self._processed_ids: deque = deque(maxlen=1000) - self._msg_seq: int = 1 # ζΆˆζ―εΊεˆ—ε·οΌŒιΏε…θ’« QQ API εŽ»ι‡ + + self._client: botpy.Client | None = None + self._http: aiohttp.ClientSession | None = None + + self._processed_ids: deque[str] = deque(maxlen=1000) + self._msg_seq: int = 1 # used to avoid QQ API dedup self._chat_type_cache: dict[str, str] = {} + self._media_root: Path = self._init_media_root() + + # --------------------------- + # Lifecycle + # --------------------------- + + def _init_media_root(self) -> Path: + """Choose a directory for saving inbound attachments.""" + if self.config.media_dir: + root = Path(self.config.media_dir).expanduser() + elif get_media_dir: + try: + root = Path(get_media_dir("qq")) + except Exception: + root = Path.home() / ".nanobot" / "media" / "qq" + else: + root = Path.home() / ".nanobot" / "media" / "qq" + + root.mkdir(parents=True, exist_ok=True) + logger.info("QQ media directory: {}", str(root)) + return root + async def start(self) -> None: - """Start the QQ bot.""" + """Start the QQ bot with auto-reconnect loop.""" if not QQ_AVAILABLE: logger.error("QQ SDK not installed. Run: pip install qq-botpy") return @@ -92,8 +199,9 @@ async def start(self) -> None: return self._running = True - BotClass = _make_bot_class(self) - self._client = BotClass() + self._http = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=120)) + + self._client = _make_bot_class(self)() logger.info("QQ bot started (C2C & Group supported)") await self._run_bot() @@ -109,75 +217,423 @@ async def _run_bot(self) -> None: await asyncio.sleep(5) async def stop(self) -> None: - """Stop the QQ bot.""" + """Stop bot and cleanup resources.""" self._running = False if self._client: try: await self._client.close() except Exception: pass + self._client = None + + if self._http: + try: + await self._http.close() + except Exception: + pass + self._http = None + logger.info("QQ bot stopped") + # --------------------------- + # Outbound (send) + # --------------------------- + async def send(self, msg: OutboundMessage) -> None: - """Send a message through QQ.""" + """Send attachments first, then text.""" if not self._client: logger.warning("QQ client not initialized") return + msg_id = msg.metadata.get("message_id") + chat_type = self._chat_type_cache.get(msg.chat_id, "c2c") + is_group = chat_type == "group" + + # 1) Send media + for media_ref in msg.media or []: + ok = await self._send_media( + chat_id=msg.chat_id, + media_ref=media_ref, + msg_id=msg_id, + is_group=is_group, + ) + if not ok: + filename = ( + os.path.basename(urlparse(media_ref).path) + or os.path.basename(media_ref) + or "file" + ) + await self._send_text_only( + chat_id=msg.chat_id, + is_group=is_group, + msg_id=msg_id, + content=f"[Attachment send failed: {filename}]", + ) + + # 2) Send text + if msg.content and msg.content.strip(): + await self._send_text_only( + chat_id=msg.chat_id, + is_group=is_group, + msg_id=msg_id, + content=msg.content.strip(), + ) + + async def _send_text_only( + self, + chat_id: str, + is_group: bool, + msg_id: str | None, + content: str, + ) -> None: + """Send a plain/markdown text message.""" + if not self._client: + return + + self._msg_seq += 1 + use_markdown = self.config.msg_format == "markdown" + payload: dict[str, Any] = { + "msg_type": 2 if use_markdown else 0, + "msg_id": msg_id, + "msg_seq": self._msg_seq, + } + if use_markdown: + payload["markdown"] = {"content": content} + else: + payload["content"] = content + + if is_group: + await self._client.api.post_group_message(group_openid=chat_id, **payload) + else: + await self._client.api.post_c2c_message(openid=chat_id, **payload) + + async def _send_media( + self, + chat_id: str, + media_ref: str, + msg_id: str | None, + is_group: bool, + ) -> bool: + """Read bytes -> base64 upload -> msg_type=7 send.""" + if not self._client: + return False + + data, filename = await self._read_media_bytes(media_ref) + if not data or not filename: + return False + try: - msg_id = msg.metadata.get("message_id") - self._msg_seq += 1 - use_markdown = self.config.msg_format == "markdown" - payload: dict[str, Any] = { - "msg_type": 2 if use_markdown else 0, - "msg_id": msg_id, - "msg_seq": self._msg_seq, - } - if use_markdown: - payload["markdown"] = {"content": msg.content} - else: - payload["content"] = msg.content + file_type = _guess_send_file_type(filename) + file_data_b64 = base64.b64encode(data).decode() + + media_obj = await self._post_base64file( + chat_id=chat_id, + is_group=is_group, + file_type=file_type, + file_data=file_data_b64, + file_name=filename, + srv_send_msg=False, + ) + if not media_obj: + logger.error("QQ media upload failed: empty response") + return False - chat_type = self._chat_type_cache.get(msg.chat_id, "c2c") - if chat_type == "group": + self._msg_seq += 1 + if is_group: await self._client.api.post_group_message( - group_openid=msg.chat_id, - **payload, + group_openid=chat_id, + msg_type=7, + msg_id=msg_id, + msg_seq=self._msg_seq, + media=media_obj, ) else: await self._client.api.post_c2c_message( - openid=msg.chat_id, - **payload, + openid=chat_id, + msg_type=7, + msg_id=msg_id, + msg_seq=self._msg_seq, + media=media_obj, ) + + logger.info("QQ media sent: {}", filename) + return True except Exception as e: - logger.error("Error sending QQ message: {}", e) + logger.error("QQ send media failed filename={} err={}", filename, e) + return False + + async def _read_media_bytes(self, media_ref: str) -> tuple[bytes | None, str | None]: + """Read bytes from http(s) or local file path; return (data, filename).""" + media_ref = (media_ref or "").strip() + if not media_ref: + return None, None + + # Local file: plain path or file:// URI + if not media_ref.startswith("http://") and not media_ref.startswith("https://"): + try: + if media_ref.startswith("file://"): + parsed = urlparse(media_ref) + # Windows: path in netloc; Unix: path in path + raw = parsed.path or parsed.netloc + local_path = Path(unquote(raw)) + else: + local_path = Path(os.path.expanduser(media_ref)) + + if not local_path.is_file(): + logger.warning("QQ outbound media file not found: {}", str(local_path)) + return None, None + + data = await asyncio.to_thread(local_path.read_bytes) + return data, local_path.name + except Exception as e: + logger.warning("QQ outbound media read error ref={} err={}", media_ref, e) + return None, None + + # Remote URL + ok, err = validate_url_target(media_ref) + if not ok: + logger.warning("QQ outbound media URL validation failed url={} err={}", media_ref, err) + return None, None - async def _on_message(self, data: "C2CMessage | GroupMessage", is_group: bool = False) -> None: - """Handle incoming message from QQ.""" + if not self._http: + self._http = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=120)) try: - # Dedup by message ID - if data.id in self._processed_ids: - return - self._processed_ids.append(data.id) + async with self._http.get(media_ref, allow_redirects=True) as resp: + if resp.status >= 400: + logger.warning( + "QQ outbound media download failed status={} url={}", + resp.status, + media_ref, + ) + return None, None + data = await resp.read() + if not data: + return None, None + filename = os.path.basename(urlparse(media_ref).path) or "file.bin" + return data, filename + except Exception as e: + logger.warning("QQ outbound media download error url={} err={}", media_ref, e) + return None, None + + # https://github.com/tencent-connect/botpy/issues/198 + # https://bot.q.qq.com/wiki/develop/api-v2/server-inter/message/send-receive/rich-media.html + async def _post_base64file( + self, + chat_id: str, + is_group: bool, + file_type: int, + file_data: str, + file_name: str | None = None, + srv_send_msg: bool = False, + ) -> Media: + """Upload base64-encoded file and return Media object.""" + if not self._client: + raise RuntimeError("QQ client not initialized") + + if is_group: + endpoint = "/v2/groups/{group_openid}/files" + id_key = "group_openid" + else: + endpoint = "/v2/users/{openid}/files" + id_key = "openid" + + payload = { + id_key: chat_id, + "file_type": file_type, + "file_data": file_data, + "file_name": file_name, + "srv_send_msg": srv_send_msg, + } + route = Route("POST", endpoint, **{id_key: chat_id}) + return await self._client.api._http.request(route, json=payload) + + # --------------------------- + # Inbound (receive) + # --------------------------- + + async def _on_message(self, data: C2CMessage | GroupMessage, is_group: bool = False) -> None: + """Parse inbound message, download attachments, and publish to the bus.""" + if data.id in self._processed_ids: + return + self._processed_ids.append(data.id) + + if is_group: + chat_id = data.group_openid + user_id = data.author.member_openid + self._chat_type_cache[chat_id] = "group" + else: + chat_id = str( + getattr(data.author, "id", None) or getattr(data.author, "user_openid", "unknown") + ) + user_id = chat_id + self._chat_type_cache[chat_id] = "c2c" - content = (data.content or "").strip() - if not content: - return + content = (data.content or "").strip() - if is_group: - chat_id = data.group_openid - user_id = data.author.member_openid - self._chat_type_cache[chat_id] = "group" - else: - chat_id = str(getattr(data.author, 'id', None) or getattr(data.author, 'user_openid', 'unknown')) - user_id = chat_id - self._chat_type_cache[chat_id] = "c2c" + # the data used by tests don't contain attachments property + # so we use getattr with a default of [] to avoid AttributeError in tests + attachments = getattr(data, "attachments", None) or [] + media_paths, recv_lines, att_meta = await self._handle_attachments(attachments) - await self._handle_message( - sender_id=user_id, - chat_id=chat_id, - content=content, - metadata={"message_id": data.id}, + # Compose content that always contains actionable saved paths + if recv_lines: + tag = "[Image]" if any(_is_image_name(Path(p).name) for p in media_paths) else "[File]" + file_block = "Received files:\n" + "\n".join(recv_lines) + content = f"{content}\n\n{file_block}".strip() if content else f"{tag}\n{file_block}" + + if not content and not media_paths: + return + + await self._handle_message( + sender_id=user_id, + chat_id=chat_id, + content=content, + media=media_paths if media_paths else None, + metadata={ + "message_id": data.id, + "attachments": att_meta, + }, + ) + + async def _handle_attachments( + self, + attachments: list[BaseMessage._Attachments], + ) -> tuple[list[str], list[str], list[dict[str, Any]]]: + """Extract, download (chunked), and format attachments for agent consumption.""" + media_paths: list[str] = [] + recv_lines: list[str] = [] + att_meta: list[dict[str, Any]] = [] + + if not attachments: + return media_paths, recv_lines, att_meta + + for att in attachments: + url, filename, ctype = att.url, att.filename, att.content_type + + logger.info("Downloading file from QQ: {}", filename or url) + local_path = await self._download_to_media_dir_chunked(url, filename_hint=filename) + + att_meta.append( + { + "url": url, + "filename": filename, + "content_type": ctype, + "saved_path": local_path, + } ) - except Exception: - logger.exception("Error handling QQ message") + + if local_path: + media_paths.append(local_path) + shown_name = filename or os.path.basename(local_path) + recv_lines.append(f"- {shown_name}\n saved: {local_path}") + else: + shown_name = filename or url + recv_lines.append(f"- {shown_name}\n saved: [download failed]") + + return media_paths, recv_lines, att_meta + + async def _download_to_media_dir_chunked( + self, + url: str, + filename_hint: str = "", + ) -> str | None: + """Download an inbound attachment using streaming chunk write. + + Uses chunked streaming to avoid loading large files into memory. + Enforces a max download size and writes to a .part temp file + that is atomically renamed on success. + """ + if not self._http: + self._http = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=120)) + + safe = _sanitize_filename(filename_hint) + ts = int(time.time() * 1000) + tmp_path: Path | None = None + + try: + async with self._http.get( + url, + timeout=aiohttp.ClientTimeout(total=120), + allow_redirects=True, + ) as resp: + if resp.status != 200: + logger.warning("QQ download failed: status={} url={}", resp.status, url) + return None + + ctype = (resp.headers.get("Content-Type") or "").lower() + + # Infer extension: url -> filename_hint -> content-type -> fallback + ext = Path(urlparse(url).path).suffix + if not ext: + ext = Path(filename_hint).suffix + if not ext: + if "png" in ctype: + ext = ".png" + elif "jpeg" in ctype or "jpg" in ctype: + ext = ".jpg" + elif "gif" in ctype: + ext = ".gif" + elif "webp" in ctype: + ext = ".webp" + elif "pdf" in ctype: + ext = ".pdf" + else: + ext = ".bin" + + if safe: + if not Path(safe).suffix: + safe = safe + ext + filename = safe + else: + filename = f"qq_file_{ts}{ext}" + + target = self._media_root / filename + if target.exists(): + target = self._media_root / f"{target.stem}_{ts}{target.suffix}" + + tmp_path = target.with_suffix(target.suffix + ".part") + + # Stream write + downloaded = 0 + chunk_size = max(1024, int(self.config.download_chunk_size or 262144)) + max_bytes = max( + 1024 * 1024, int(self.config.download_max_bytes or (200 * 1024 * 1024)) + ) + + def _open_tmp(): + tmp_path.parent.mkdir(parents=True, exist_ok=True) + return open(tmp_path, "wb") # noqa: SIM115 + + f = await asyncio.to_thread(_open_tmp) + try: + async for chunk in resp.content.iter_chunked(chunk_size): + if not chunk: + continue + downloaded += len(chunk) + if downloaded > max_bytes: + logger.warning( + "QQ download exceeded max_bytes={} url={} -> abort", + max_bytes, + url, + ) + return None + await asyncio.to_thread(f.write, chunk) + finally: + await asyncio.to_thread(f.close) + + # Atomic rename + await asyncio.to_thread(os.replace, tmp_path, target) + tmp_path = None # mark as moved + logger.info("QQ file saved: {}", str(target)) + return str(target) + + except Exception as e: + logger.error("QQ download error: {}", e) + return None + finally: + # Cleanup partial file + if tmp_path is not None: + try: + tmp_path.unlink(missing_ok=True) + except Exception: + pass diff --git a/nanobot/channels/slack.py b/nanobot/channels/slack.py index 87194ac705c..2503f6a2d0c 100644 --- a/nanobot/channels/slack.py +++ b/nanobot/channels/slack.py @@ -145,6 +145,7 @@ async def send(self, msg: OutboundMessage) -> None: except Exception as e: logger.error("Error sending Slack message: {}", e) + raise async def _on_socket_request( self, diff --git a/nanobot/channels/telegram.py b/nanobot/channels/telegram.py index 850e09c0f67..feb908657fc 100644 --- a/nanobot/channels/telegram.py +++ b/nanobot/channels/telegram.py @@ -11,8 +11,8 @@ from loguru import logger from pydantic import Field -from telegram import BotCommand, ReplyParameters, Update -from telegram.error import TimedOut +from telegram import BotCommand, ReactionTypeEmoji, ReplyParameters, Update +from telegram.error import BadRequest, TimedOut from telegram.ext import Application, CommandHandler, ContextTypes, MessageHandler, filters from telegram.request import HTTPXRequest @@ -163,6 +163,7 @@ class _StreamBuf: text: str = "" message_id: int | None = None last_edit: float = 0.0 + stream_id: str | None = None class TelegramConfig(Base): @@ -173,6 +174,7 @@ class TelegramConfig(Base): allow_from: list[str] = Field(default_factory=list) proxy: str | None = None reply_to_message: bool = False + react_emoji: str = "πŸ‘€" group_policy: Literal["open", "mention"] = "mention" connection_pool_size: int = 32 pool_timeout: float = 5.0 @@ -475,6 +477,11 @@ async def _send_text( ) except Exception as e2: logger.error("Error sending Telegram message: {}", e2) + raise + + @staticmethod + def _is_not_modified_error(exc: Exception) -> bool: + return isinstance(exc, BadRequest) and "message is not modified" in str(exc).lower() async def send_delta(self, chat_id: str, delta: str, metadata: dict[str, Any] | None = None) -> None: """Progressive message editing: send on first delta, edit on subsequent ones.""" @@ -482,11 +489,14 @@ async def send_delta(self, chat_id: str, delta: str, metadata: dict[str, Any] | return meta = metadata or {} int_chat_id = int(chat_id) + stream_id = meta.get("_stream_id") if meta.get("_stream_end"): - buf = self._stream_bufs.pop(chat_id, None) + buf = self._stream_bufs.get(chat_id) if not buf or not buf.message_id or not buf.text: return + if stream_id is not None and buf.stream_id is not None and buf.stream_id != stream_id: + return self._stop_typing(chat_id) try: html = _markdown_to_telegram_html(buf.text) @@ -496,6 +506,10 @@ async def send_delta(self, chat_id: str, delta: str, metadata: dict[str, Any] | text=html, parse_mode="HTML", ) except Exception as e: + if self._is_not_modified_error(e): + logger.debug("Final stream edit already applied for {}", chat_id) + self._stream_bufs.pop(chat_id, None) + return logger.debug("Final stream edit failed (HTML), trying plain: {}", e) try: await self._call_with_retry( @@ -503,14 +517,22 @@ async def send_delta(self, chat_id: str, delta: str, metadata: dict[str, Any] | chat_id=int_chat_id, message_id=buf.message_id, text=buf.text, ) - except Exception: - pass + except Exception as e2: + if self._is_not_modified_error(e2): + logger.debug("Final stream plain edit already applied for {}", chat_id) + self._stream_bufs.pop(chat_id, None) + return + logger.warning("Final stream edit failed: {}", e2) + raise # Let ChannelManager handle retry + self._stream_bufs.pop(chat_id, None) return buf = self._stream_bufs.get(chat_id) - if buf is None: - buf = _StreamBuf() + if buf is None or (stream_id is not None and buf.stream_id is not None and buf.stream_id != stream_id): + buf = _StreamBuf(stream_id=stream_id) self._stream_bufs[chat_id] = buf + elif buf.stream_id is None: + buf.stream_id = stream_id buf.text += delta if not buf.text.strip(): @@ -527,6 +549,7 @@ async def send_delta(self, chat_id: str, delta: str, metadata: dict[str, Any] | buf.last_edit = now except Exception as e: logger.warning("Stream initial send failed: {}", e) + raise # Let ChannelManager handle retry elif (now - buf.last_edit) >= self._STREAM_EDIT_INTERVAL: try: await self._call_with_retry( @@ -535,8 +558,12 @@ async def send_delta(self, chat_id: str, delta: str, metadata: dict[str, Any] | text=buf.text, ) buf.last_edit = now - except Exception: - pass + except Exception as e: + if self._is_not_modified_error(e): + buf.last_edit = now + return + logger.warning("Stream edit failed: {}", e) + raise # Let ChannelManager handle retry async def _on_start(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: """Handle /start command.""" @@ -812,6 +839,7 @@ async def _on_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) "session_key": session_key, } self._start_typing(str_chat_id) + await self._add_reaction(str_chat_id, message.message_id, self.config.react_emoji) buf = self._media_group_buffers[key] if content and content != "[empty message]": buf["contents"].append(content) @@ -822,6 +850,7 @@ async def _on_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) # Start typing indicator before processing self._start_typing(str_chat_id) + await self._add_reaction(str_chat_id, message.message_id, self.config.react_emoji) # Forward to the message bus await self._handle_message( @@ -861,6 +890,19 @@ def _stop_typing(self, chat_id: str) -> None: if task and not task.done(): task.cancel() + async def _add_reaction(self, chat_id: str, message_id: int, emoji: str) -> None: + """Add emoji reaction to a message (best-effort, non-blocking).""" + if not self._app or not emoji: + return + try: + await self._app.bot.set_message_reaction( + chat_id=int(chat_id), + message_id=message_id, + reaction=[ReactionTypeEmoji(emoji=emoji)], + ) + except Exception as e: + logger.debug("Telegram reaction failed: {}", e) + async def _typing_loop(self, chat_id: str) -> None: """Repeatedly send 'typing' action until cancelled.""" try: diff --git a/nanobot/channels/wecom.py b/nanobot/channels/wecom.py index 2f248559ec2..05ad1482545 100644 --- a/nanobot/channels/wecom.py +++ b/nanobot/channels/wecom.py @@ -368,3 +368,4 @@ async def send(self, msg: OutboundMessage) -> None: except Exception as e: logger.error("Error sending WeCom message: {}", e) + raise diff --git a/nanobot/channels/weixin.py b/nanobot/channels/weixin.py index 48a97f582db..f09ef95f7d0 100644 --- a/nanobot/channels/weixin.py +++ b/nanobot/channels/weixin.py @@ -4,7 +4,7 @@ No WebSocket, no local WeChat client needed β€” just HTTP requests with a bot token obtained via QR code login. -Protocol reverse-engineered from ``@tencent-weixin/openclaw-weixin`` v1.0.2. +Protocol reverse-engineered from ``@tencent-weixin/openclaw-weixin`` v1.0.3. """ from __future__ import annotations @@ -53,15 +53,18 @@ MESSAGE_STATE_FINISH = 2 WEIXIN_MAX_MESSAGE_LEN = 4000 -BASE_INFO: dict[str, str] = {"channel_version": "1.0.2"} +WEIXIN_CHANNEL_VERSION = "1.0.3" +BASE_INFO: dict[str, str] = {"channel_version": WEIXIN_CHANNEL_VERSION} # Session-expired error code ERRCODE_SESSION_EXPIRED = -14 +SESSION_PAUSE_DURATION_S = 60 * 60 # Retry constants (matching the reference plugin's monitor.ts) MAX_CONSECUTIVE_FAILURES = 3 BACKOFF_DELAY_S = 30 RETRY_DELAY_S = 2 +MAX_QR_REFRESH_COUNT = 3 # Default long-poll timeout; overridden by server via longpolling_timeout_ms. DEFAULT_LONG_POLL_TIMEOUT_S = 35 @@ -83,6 +86,7 @@ class WeixinConfig(Base): allow_from: list[str] = Field(default_factory=list) base_url: str = "https://ilinkai.weixin.qq.com" cdn_base_url: str = "https://novac2c.cdn.weixin.qq.com/c2c" + route_tag: str | int | None = None token: str = "" # Manually set token, or obtained via QR login state_dir: str = "" # Default: ~/.nanobot/weixin/ poll_timeout: int = DEFAULT_LONG_POLL_TIMEOUT_S # seconds for long-poll @@ -119,6 +123,7 @@ def __init__(self, config: Any, bus: MessageBus): self._token: str = "" self._poll_task: asyncio.Task | None = None self._next_poll_timeout_s: int = DEFAULT_LONG_POLL_TIMEOUT_S + self._session_pause_until: float = 0.0 # ------------------------------------------------------------------ # State persistence @@ -144,6 +149,15 @@ def _load_state(self) -> bool: data = json.loads(state_file.read_text()) self._token = data.get("token", "") self._get_updates_buf = data.get("get_updates_buf", "") + context_tokens = data.get("context_tokens", {}) + if isinstance(context_tokens, dict): + self._context_tokens = { + str(user_id): str(token) + for user_id, token in context_tokens.items() + if str(user_id).strip() and str(token).strip() + } + else: + self._context_tokens = {} base_url = data.get("base_url", "") if base_url: self.config.base_url = base_url @@ -158,6 +172,7 @@ def _save_state(self) -> None: data = { "token": self._token, "get_updates_buf": self._get_updates_buf, + "context_tokens": self._context_tokens, "base_url": self.config.base_url, } state_file.write_text(json.dumps(data, ensure_ascii=False)) @@ -187,6 +202,8 @@ def _make_headers(self, *, auth: bool = True) -> dict[str, str]: } if auth and self._token: headers["Authorization"] = f"Bearer {self._token}" + if self.config.route_tag is not None and str(self.config.route_tag).strip(): + headers["SKRouteTag"] = str(self.config.route_tag).strip() return headers async def _api_get( @@ -226,24 +243,25 @@ async def _api_post( # QR Code Login (matches login-qr.ts) # ------------------------------------------------------------------ + async def _fetch_qr_code(self) -> tuple[str, str]: + """Fetch a fresh QR code. Returns (qrcode_id, scan_url).""" + data = await self._api_get( + "ilink/bot/get_bot_qrcode", + params={"bot_type": "3"}, + auth=False, + ) + qrcode_img_content = data.get("qrcode_img_content", "") + qrcode_id = data.get("qrcode", "") + if not qrcode_id: + raise RuntimeError(f"Failed to get QR code from WeChat API: {data}") + return qrcode_id, (qrcode_img_content or qrcode_id) + async def _qr_login(self) -> bool: """Perform QR code login flow. Returns True on success.""" try: logger.info("Starting WeChat QR code login...") - - data = await self._api_get( - "ilink/bot/get_bot_qrcode", - params={"bot_type": "3"}, - auth=False, - ) - qrcode_img_content = data.get("qrcode_img_content", "") - qrcode_id = data.get("qrcode", "") - - if not qrcode_id: - logger.error("Failed to get QR code from WeChat API: {}", data) - return False - - scan_url = qrcode_img_content or qrcode_id + refresh_count = 0 + qrcode_id, scan_url = await self._fetch_qr_code() self._print_qr_code(scan_url) logger.info("Waiting for QR code scan...") @@ -283,8 +301,23 @@ async def _qr_login(self) -> bool: elif status == "scaned": logger.info("QR code scanned, waiting for confirmation...") elif status == "expired": - logger.warning("QR code expired") - return False + refresh_count += 1 + if refresh_count > MAX_QR_REFRESH_COUNT: + logger.warning( + "QR code expired too many times ({}/{}), giving up.", + refresh_count - 1, + MAX_QR_REFRESH_COUNT, + ) + return False + logger.warning( + "QR code expired, refreshing... ({}/{})", + refresh_count, + MAX_QR_REFRESH_COUNT, + ) + qrcode_id, scan_url = await self._fetch_qr_code() + self._print_qr_code(scan_url) + logger.info("New QR code generated, waiting for scan...") + continue # status == "wait" β€” keep polling await asyncio.sleep(1) @@ -392,7 +425,34 @@ async def stop(self) -> None: # Polling (matches monitor.ts monitorWeixinProvider) # ------------------------------------------------------------------ + def _pause_session(self, duration_s: int = SESSION_PAUSE_DURATION_S) -> None: + self._session_pause_until = time.time() + duration_s + + def _session_pause_remaining_s(self) -> int: + remaining = int(self._session_pause_until - time.time()) + if remaining <= 0: + self._session_pause_until = 0.0 + return 0 + return remaining + + def _assert_session_active(self) -> None: + remaining = self._session_pause_remaining_s() + if remaining > 0: + remaining_min = max((remaining + 59) // 60, 1) + raise RuntimeError( + f"WeChat session paused, {remaining_min} min remaining (errcode {ERRCODE_SESSION_EXPIRED})" + ) + async def _poll_once(self) -> None: + remaining = self._session_pause_remaining_s() + if remaining > 0: + logger.warning( + "WeChat session paused, waiting {} min before next poll.", + max((remaining + 59) // 60, 1), + ) + await asyncio.sleep(remaining) + return + body: dict[str, Any] = { "get_updates_buf": self._get_updates_buf, "base_info": BASE_INFO, @@ -411,11 +471,13 @@ async def _poll_once(self) -> None: if is_error: if errcode == ERRCODE_SESSION_EXPIRED or ret == ERRCODE_SESSION_EXPIRED: + self._pause_session() + remaining = self._session_pause_remaining_s() logger.warning( - "WeChat session expired (errcode {}). Pausing 60 min.", + "WeChat session expired (errcode {}). Pausing {} min.", errcode, + max((remaining + 59) // 60, 1), ) - await asyncio.sleep(3600) return raise RuntimeError( f"getUpdates failed: ret={ret} errcode={errcode} errmsg={data.get('errmsg', '')}" @@ -468,6 +530,7 @@ async def _process_message(self, msg: dict) -> None: ctx_token = msg.get("context_token", "") if ctx_token: self._context_tokens[from_user_id] = ctx_token + self._save_state() # Parse item_list (WeixinMessage.item_list β€” types.ts:161) item_list: list[dict] = msg.get("item_list") or [] @@ -651,6 +714,11 @@ async def send(self, msg: OutboundMessage) -> None: if not self._client or not self._token: logger.warning("WeChat client not initialized or not authenticated") return + try: + self._assert_session_active() + except RuntimeError as e: + logger.warning("WeChat send blocked: {}", e) + return content = msg.content.strip() ctx_token = self._context_tokens.get(msg.chat_id, "") @@ -683,6 +751,7 @@ async def send(self, msg: OutboundMessage) -> None: await self._send_text(msg.chat_id, chunk, ctx_token) except Exception as e: logger.error("Error sending WeChat message: {}", e) + raise async def _send_text( self, @@ -731,7 +800,7 @@ async def _send_media_file( ) -> None: """Upload a local file to WeChat CDN and send it as a media message. - Follows the exact protocol from ``@tencent-weixin/openclaw-weixin`` v1.0.2: + Follows the exact protocol from ``@tencent-weixin/openclaw-weixin`` v1.0.3: 1. Generate a random 16-byte AES key (client-side). 2. Call ``getuploadurl`` with file metadata + hex-encoded AES key. 3. AES-128-ECB encrypt the file and POST to CDN (``{cdnBaseUrl}/upload``). diff --git a/nanobot/channels/whatsapp.py b/nanobot/channels/whatsapp.py index 7239888b163..95bde46e9fa 100644 --- a/nanobot/channels/whatsapp.py +++ b/nanobot/channels/whatsapp.py @@ -26,6 +26,7 @@ class WhatsAppConfig(Base): bridge_url: str = "ws://localhost:3001" bridge_token: str = "" allow_from: list[str] = Field(default_factory=list) + group_policy: Literal["open", "mention"] = "open" # "open" responds to all, "mention" only when @mentioned class WhatsAppChannel(BaseChannel): @@ -145,6 +146,7 @@ async def send(self, msg: OutboundMessage) -> None: await self._ws.send(json.dumps(payload, ensure_ascii=False)) except Exception as e: logger.error("Error sending WhatsApp message: {}", e) + raise for media_path in msg.media or []: try: @@ -159,6 +161,7 @@ async def send(self, msg: OutboundMessage) -> None: await self._ws.send(json.dumps(payload, ensure_ascii=False)) except Exception as e: logger.error("Error sending WhatsApp media {}: {}", media_path, e) + raise async def _handle_bridge_message(self, raw: str) -> None: """Handle a message from the bridge.""" @@ -187,6 +190,13 @@ async def _handle_bridge_message(self, raw: str) -> None: self._processed_message_ids.popitem(last=False) # Extract just the phone number or lid as chat_id + is_group = data.get("isGroup", False) + was_mentioned = data.get("wasMentioned", False) + + if is_group and getattr(self.config, "group_policy", "open") == "mention": + if not was_mentioned: + return + user_id = pn if pn else sender sender_id = user_id.split("@")[0] if "@" in user_id else user_id logger.info("Sender {}", sender) diff --git a/nanobot/cli/commands.py b/nanobot/cli/commands.py index d492913147f..6b54f51a381 100644 --- a/nanobot/cli/commands.py +++ b/nanobot/cli/commands.py @@ -34,7 +34,7 @@ from nanobot import __logo__, __version__ from nanobot.cli.stream import StreamRenderer, ThinkingSpinner -from nanobot.config.paths import get_workspace_path +from nanobot.config.paths import get_workspace_path, is_default_workspace from nanobot.config.schema import Config from nanobot.utils.helpers import sync_workspace_templates @@ -376,61 +376,61 @@ def _onboard_plugins(config_path: Path) -> None: def _make_provider(config: Config): - """Create the appropriate LLM provider from config.""" - from nanobot.providers.azure_openai_provider import AzureOpenAIProvider + """Create the appropriate LLM provider from config. + + Routing is driven by ``ProviderSpec.backend`` in the registry. + """ from nanobot.providers.base import GenerationSettings - from nanobot.providers.openai_codex_provider import OpenAICodexProvider + from nanobot.providers.registry import find_by_name model = config.agents.defaults.model provider_name = config.get_provider_name(model) p = config.get_provider(model) + spec = find_by_name(provider_name) if provider_name else None + backend = spec.backend if spec else "openai_compat" - # OpenAI Codex (OAuth) - if provider_name == "openai_codex" or model.startswith("openai-codex/"): - provider = OpenAICodexProvider(default_model=model) - # Custom: direct OpenAI-compatible endpoint, bypasses LiteLLM - elif provider_name == "custom": - from nanobot.providers.custom_provider import CustomProvider - provider = CustomProvider( - api_key=p.api_key if p else "no-key", - api_base=config.get_api_base(model) or "http://localhost:8000/v1", - default_model=model, - extra_headers=p.extra_headers if p else None, - ) - # Azure OpenAI: direct Azure OpenAI endpoint with deployment name - elif provider_name == "azure_openai": + # --- validation --- + if backend == "azure_openai": if not p or not p.api_key or not p.api_base: console.print("[red]Error: Azure OpenAI requires api_key and api_base.[/red]") console.print("Set them in ~/.nanobot/config.json under providers.azure_openai section") console.print("Use the model field to specify the deployment name.") raise typer.Exit(1) + elif backend == "openai_compat" and not model.startswith("bedrock/"): + needs_key = not (p and p.api_key) + exempt = spec and (spec.is_oauth or spec.is_local or spec.is_direct) + if needs_key and not exempt: + console.print("[red]Error: No API key configured.[/red]") + console.print("Set one in ~/.nanobot/config.json under providers section") + raise typer.Exit(1) + + # --- instantiation by backend --- + if backend == "openai_codex": + from nanobot.providers.openai_codex_provider import OpenAICodexProvider + provider = OpenAICodexProvider(default_model=model) + elif backend == "azure_openai": + from nanobot.providers.azure_openai_provider import AzureOpenAIProvider provider = AzureOpenAIProvider( api_key=p.api_key, api_base=p.api_base, default_model=model, ) - # OpenVINO Model Server: direct OpenAI-compatible endpoint at /v3 - elif provider_name == "ovms": - from nanobot.providers.custom_provider import CustomProvider - provider = CustomProvider( - api_key=p.api_key if p else "no-key", - api_base=config.get_api_base(model) or "http://localhost:8000/v3", + elif backend == "anthropic": + from nanobot.providers.anthropic_provider import AnthropicProvider + provider = AnthropicProvider( + api_key=p.api_key if p else None, + api_base=config.get_api_base(model), default_model=model, + extra_headers=p.extra_headers if p else None, ) else: - from nanobot.providers.litellm_provider import LiteLLMProvider - from nanobot.providers.registry import find_by_name - spec = find_by_name(provider_name) - if not model.startswith("bedrock/") and not (p and p.api_key) and not (spec and (spec.is_oauth or spec.is_local)): - console.print("[red]Error: No API key configured.[/red]") - console.print("Set one in ~/.nanobot/config.json under providers section") - raise typer.Exit(1) - provider = LiteLLMProvider( + from nanobot.providers.openai_compat_provider import OpenAICompatProvider + provider = OpenAICompatProvider( api_key=p.api_key if p else None, api_base=config.get_api_base(model), default_model=model, extra_headers=p.extra_headers if p else None, - provider_name=provider_name, + spec=spec, suppress_tools_param=p.suppress_tools_param if p else False, request_timeout=p.request_timeout if p else None, ) @@ -481,6 +481,17 @@ def _warn_deprecated_config_keys(config_path: Path | None) -> None: ) +def _migrate_cron_store(config: "Config") -> None: + """One-time migration: move legacy global cron store into the workspace.""" + from nanobot.config.paths import get_cron_dir + + legacy_path = get_cron_dir() / "jobs.json" + new_path = config.workspace_path / "cron" / "jobs.json" + if legacy_path.is_file() and not new_path.exists(): + new_path.parent.mkdir(parents=True, exist_ok=True) + import shutil + shutil.move(str(legacy_path), str(new_path)) + # ============================================================================ # Gateway / Server @@ -498,7 +509,6 @@ def gateway( from nanobot.agent.loop import AgentLoop from nanobot.bus.queue import MessageBus from nanobot.channels.manager import ChannelManager - from nanobot.config.paths import get_cron_dir from nanobot.cron.service import CronService from nanobot.cron.types import CronJob from nanobot.heartbeat.service import HeartbeatService @@ -517,8 +527,12 @@ def gateway( provider = _make_provider(config) session_manager = SessionManager(config.workspace_path) - # Create cron service first (callback set after agent creation) - cron_store_path = get_cron_dir() / "jobs.json" + # Preserve existing single-workspace installs, but keep custom workspaces clean. + if is_default_workspace(config.workspace_path): + _migrate_cron_store(config) + + # Create cron service with workspace-scoped store + cron_store_path = config.workspace_path / "cron" / "jobs.json" cron = CronService(cron_store_path) # Create agent with cron service @@ -542,6 +556,7 @@ def gateway( mcp_servers=config.tools.mcp_servers, channels_config=config.channels, sysmon=config.tools.sysmon, + timezone=config.agents.defaults.timezone, ) # Set cron callback (needs agent) @@ -652,6 +667,7 @@ async def on_heartbeat_notify(response: str) -> None: on_notify=on_heartbeat_notify, interval_s=hb_cfg.interval_s, enabled=hb_cfg.enabled, + timezone=config.agents.defaults.timezone, ) if channels.enabled_channels: @@ -710,7 +726,6 @@ def agent( from nanobot.agent.loop import AgentLoop from nanobot.bus.queue import MessageBus - from nanobot.config.paths import get_cron_dir from nanobot.cron.service import CronService config = _load_runtime_config(config, workspace) @@ -719,8 +734,12 @@ def agent( bus = MessageBus() provider = _make_provider(config) - # Create cron service for tool usage (no callback needed for CLI unless running) - cron_store_path = get_cron_dir() / "jobs.json" + # Preserve existing single-workspace installs, but keep custom workspaces clean. + if is_default_workspace(config.workspace_path): + _migrate_cron_store(config) + + # Create cron service with workspace-scoped store + cron_store_path = config.workspace_path / "cron" / "jobs.json" cron = CronService(cron_store_path) if logs: @@ -747,6 +766,7 @@ def agent( mcp_servers=config.tools.mcp_servers, channels_config=config.channels, sysmon=config.tools.sysmon, + timezone=config.agents.defaults.timezone, ) # Shared reference for progress callbacks @@ -1198,11 +1218,20 @@ def _login_openai_codex() -> None: def _login_github_copilot() -> None: import asyncio + from openai import AsyncOpenAI + console.print("[cyan]Starting GitHub Copilot device flow...[/cyan]\n") async def _trigger(): - from litellm import acompletion - await acompletion(model="github_copilot/gpt-4o", messages=[{"role": "user", "content": "hi"}], max_tokens=1) + client = AsyncOpenAI( + api_key="dummy", + base_url="https://api.githubcopilot.com", + ) + await client.chat.completions.create( + model="gpt-4o", + messages=[{"role": "user", "content": "hi"}], + max_tokens=1, + ) try: asyncio.run(_trigger()) diff --git a/nanobot/cli/models.py b/nanobot/cli/models.py index 520370c4b8f..0ba24018fc4 100644 --- a/nanobot/cli/models.py +++ b/nanobot/cli/models.py @@ -1,229 +1,29 @@ """Model information helpers for the onboard wizard. -Provides model context window lookup and autocomplete suggestions using litellm. +Model database / autocomplete is temporarily disabled while litellm is +being replaced. All public function signatures are preserved so callers +continue to work without changes. """ from __future__ import annotations -from functools import lru_cache from typing import Any -def _litellm(): - """Lazy accessor for litellm (heavy import deferred until actually needed).""" - import litellm as _ll - - return _ll - - -@lru_cache(maxsize=1) -def _get_model_cost_map() -> dict[str, Any]: - """Get litellm's model cost map (cached).""" - return getattr(_litellm(), "model_cost", {}) - - -@lru_cache(maxsize=1) def get_all_models() -> list[str]: - """Get all known model names from litellm. - """ - models = set() - - # From model_cost (has pricing info) - cost_map = _get_model_cost_map() - for k in cost_map.keys(): - if k != "sample_spec": - models.add(k) - - # From models_by_provider (more complete provider coverage) - for provider_models in getattr(_litellm(), "models_by_provider", {}).values(): - if isinstance(provider_models, (set, list)): - models.update(provider_models) - - return sorted(models) - - -def _normalize_model_name(model: str) -> str: - """Normalize model name for comparison.""" - return model.lower().replace("-", "_").replace(".", "") + return [] def find_model_info(model_name: str) -> dict[str, Any] | None: - """Find model info with fuzzy matching. - - Args: - model_name: Model name in any common format - - Returns: - Model info dict or None if not found - """ - cost_map = _get_model_cost_map() - if not cost_map: - return None - - # Direct match - if model_name in cost_map: - return cost_map[model_name] - - # Extract base name (without provider prefix) - base_name = model_name.split("/")[-1] if "/" in model_name else model_name - base_normalized = _normalize_model_name(base_name) - - candidates = [] - - for key, info in cost_map.items(): - if key == "sample_spec": - continue - - key_base = key.split("/")[-1] if "/" in key else key - key_base_normalized = _normalize_model_name(key_base) - - # Score the match - score = 0 - - # Exact base name match (highest priority) - if base_normalized == key_base_normalized: - score = 100 - # Base name contains model - elif base_normalized in key_base_normalized: - score = 80 - # Model contains base name - elif key_base_normalized in base_normalized: - score = 70 - # Partial match - elif base_normalized[:10] in key_base_normalized: - score = 50 - - if score > 0: - # Prefer models with max_input_tokens - if info.get("max_input_tokens"): - score += 10 - candidates.append((score, key, info)) - - if not candidates: - return None - - # Return the best match - candidates.sort(key=lambda x: (-x[0], x[1])) - return candidates[0][2] + return None def get_model_context_limit(model: str, provider: str = "auto") -> int | None: - """Get the maximum input context tokens for a model. - - Args: - model: Model name (e.g., "claude-3.5-sonnet", "gpt-4o") - provider: Provider name for informational purposes (not yet used for filtering) - - Returns: - Maximum input tokens, or None if unknown - - Note: - The provider parameter is currently informational only. Future versions may - use it to prefer provider-specific model variants in the lookup. - """ - # First try fuzzy search in model_cost (has more accurate max_input_tokens) - info = find_model_info(model) - if info: - # Prefer max_input_tokens (this is what we want for context window) - max_input = info.get("max_input_tokens") - if max_input and isinstance(max_input, int): - return max_input - - # Fall back to litellm's get_max_tokens (returns max_output_tokens typically) - try: - result = _litellm().get_max_tokens(model) - if result and result > 0: - return result - except (KeyError, ValueError, AttributeError): - # Model not found in litellm's database or invalid response - pass - - # Last resort: use max_tokens from model_cost - if info: - max_tokens = info.get("max_tokens") - if max_tokens and isinstance(max_tokens, int): - return max_tokens - return None -@lru_cache(maxsize=1) -def _get_provider_keywords() -> dict[str, list[str]]: - """Build provider keywords mapping from nanobot's provider registry. - - Returns: - Dict mapping provider name to list of keywords for model filtering. - """ - try: - from nanobot.providers.registry import PROVIDERS - - mapping = {} - for spec in PROVIDERS: - if spec.keywords: - mapping[spec.name] = list(spec.keywords) - return mapping - except ImportError: - return {} - - def get_model_suggestions(partial: str, provider: str = "auto", limit: int = 20) -> list[str]: - """Get autocomplete suggestions for model names. - - Args: - partial: Partial model name typed by user - provider: Provider name for filtering (e.g., "openrouter", "minimax") - limit: Maximum number of suggestions to return - - Returns: - List of matching model names - """ - all_models = get_all_models() - if not all_models: - return [] - - partial_lower = partial.lower() - partial_normalized = _normalize_model_name(partial) - - # Get provider keywords from registry - provider_keywords = _get_provider_keywords() - - # Filter by provider if specified - allowed_keywords = None - if provider and provider != "auto": - allowed_keywords = provider_keywords.get(provider.lower()) - - matches = [] - - for model in all_models: - model_lower = model.lower() - - # Apply provider filter - if allowed_keywords: - if not any(kw in model_lower for kw in allowed_keywords): - continue - - # Match against partial input - if not partial: - matches.append(model) - continue - - if partial_lower in model_lower: - # Score by position of match (earlier = better) - pos = model_lower.find(partial_lower) - score = 100 - pos - matches.append((score, model)) - elif partial_normalized in _normalize_model_name(model): - score = 50 - matches.append((score, model)) - - # Sort by score if we have scored matches - if matches and isinstance(matches[0], tuple): - matches.sort(key=lambda x: (-x[0], x[1])) - matches = [m[1] for m in matches] - else: - matches.sort() - - return matches[:limit] + return [] def format_token_count(tokens: int) -> str: diff --git a/nanobot/config/__init__.py b/nanobot/config/__init__.py index e2c24f80633..4b9fccec3b1 100644 --- a/nanobot/config/__init__.py +++ b/nanobot/config/__init__.py @@ -7,6 +7,7 @@ get_cron_dir, get_data_dir, get_legacy_sessions_dir, + is_default_workspace, get_logs_dir, get_media_dir, get_runtime_subdir, @@ -24,6 +25,7 @@ "get_cron_dir", "get_logs_dir", "get_workspace_path", + "is_default_workspace", "get_cli_history_path", "get_bridge_install_dir", "get_legacy_sessions_dir", diff --git a/nanobot/config/paths.py b/nanobot/config/paths.py index f4dfbd92a27..527c5f38ebb 100644 --- a/nanobot/config/paths.py +++ b/nanobot/config/paths.py @@ -40,6 +40,13 @@ def get_workspace_path(workspace: str | None = None) -> Path: return ensure_dir(path) +def is_default_workspace(workspace: str | Path | None) -> bool: + """Return whether a workspace resolves to nanobot's default workspace path.""" + current = Path(workspace).expanduser() if workspace is not None else Path.home() / ".nanobot" / "workspace" + default = Path.home() / ".nanobot" / "workspace" + return current.resolve(strict=False) == default.resolve(strict=False) + + def get_cli_history_path() -> Path: """Return the shared CLI history file path.""" return Path.home() / ".nanobot" / "history" / "cli_history" diff --git a/nanobot/config/schema.py b/nanobot/config/schema.py index f8a7850e46c..c6ad0c31760 100644 --- a/nanobot/config/schema.py +++ b/nanobot/config/schema.py @@ -190,6 +190,7 @@ class ChannelsConfig(Base): send_progress: bool = True # stream agent's text progress to the channel send_tool_hints: bool = False # stream tool-call hints (e.g. read_file("…")) http: HttpConfig = Field(default_factory=HttpConfig) + send_max_retries: int = Field(default=3, ge=0, le=10) # Max delivery attempts (initial send included) class AgentDefaults(Base): @@ -205,6 +206,7 @@ class AgentDefaults(Base): temperature: float = 0.1 max_tool_iterations: int = 40 reasoning_effort: str | None = None # low / medium / high - enables LLM thinking mode + timezone: str = "UTC" # IANA timezone, e.g. "Asia/Shanghai", "America/New_York" class AgentsConfig(Base): @@ -242,6 +244,7 @@ class ProvidersConfig(Base): moonshot: ProviderConfig = Field(default_factory=ProviderConfig) minimax: ProviderConfig = Field(default_factory=ProviderConfig) mistral: ProviderConfig = Field(default_factory=ProviderConfig) + stepfun: ProviderConfig = Field(default_factory=ProviderConfig) # Step Fun (ι˜Άθ·ƒζ˜ŸθΎ°) aihubmix: ProviderConfig = Field(default_factory=ProviderConfig) # AiHubMix API gateway siliconflow: ProviderConfig = Field(default_factory=ProviderConfig) # SiliconFlow (η‘…εŸΊζ΅εŠ¨) volcengine: ProviderConfig = Field(default_factory=ProviderConfig) # VolcEngine (η«ε±±εΌ•ζ“Ž) @@ -352,12 +355,15 @@ def _match_provider( self, model: str | None = None ) -> tuple["ProviderConfig | None", str | None]: """Match provider config and its registry name. Returns (config, spec_name).""" - from nanobot.providers.registry import PROVIDERS + from nanobot.providers.registry import PROVIDERS, find_by_name forced = self.agents.defaults.provider if forced != "auto": - p = getattr(self.providers, forced, None) - return (p, forced) if p else (None, None) + spec = find_by_name(forced) + if spec: + p = getattr(self.providers, spec.name, None) + return (p, spec.name) if p else (None, None) + return None, None model_lower = (model or self.agents.defaults.model).lower() model_normalized = model_lower.replace("-", "_") @@ -433,8 +439,7 @@ def get_api_base(self, model: str | None = None) -> str | None: if p and p.api_base: return p.api_base # Only gateways get a default api_base here. Standard providers - # (like Moonshot) set their base URL via env vars in _setup_env - # to avoid polluting the global litellm.api_base. + # resolve their base URL from the registry in the provider constructor. if name: spec = find_by_name(name) if spec and (spec.is_gateway or spec.is_local) and spec.default_api_base: diff --git a/nanobot/heartbeat/service.py b/nanobot/heartbeat/service.py index 7be81ff4abe..00f6b17e124 100644 --- a/nanobot/heartbeat/service.py +++ b/nanobot/heartbeat/service.py @@ -59,6 +59,7 @@ def __init__( on_notify: Callable[[str], Coroutine[Any, Any, None]] | None = None, interval_s: int = 30 * 60, enabled: bool = True, + timezone: str | None = None, ): self.workspace = workspace self.provider = provider @@ -67,6 +68,7 @@ def __init__( self.on_notify = on_notify self.interval_s = interval_s self.enabled = enabled + self.timezone = timezone self._running = False self._task: asyncio.Task | None = None @@ -93,7 +95,7 @@ async def _decide(self, content: str) -> tuple[str, str]: messages=[ {"role": "system", "content": "You are a heartbeat agent. Call the heartbeat tool to report your decision."}, {"role": "user", "content": ( - f"Current Time: {current_time_str()}\n\n" + f"Current Time: {current_time_str(self.timezone)}\n\n" "Review the following HEARTBEAT.md and decide whether there are active tasks.\n\n" f"{content}" )}, diff --git a/nanobot/providers/__init__.py b/nanobot/providers/__init__.py index 9d4994eb13c..0e259e6f014 100644 --- a/nanobot/providers/__init__.py +++ b/nanobot/providers/__init__.py @@ -7,17 +7,26 @@ from nanobot.providers.base import LLMProvider, LLMResponse -__all__ = ["LLMProvider", "LLMResponse", "LiteLLMProvider", "OpenAICodexProvider", "AzureOpenAIProvider"] +__all__ = [ + "LLMProvider", + "LLMResponse", + "AnthropicProvider", + "OpenAICompatProvider", + "OpenAICodexProvider", + "AzureOpenAIProvider", +] _LAZY_IMPORTS = { - "LiteLLMProvider": ".litellm_provider", + "AnthropicProvider": ".anthropic_provider", + "OpenAICompatProvider": ".openai_compat_provider", "OpenAICodexProvider": ".openai_codex_provider", "AzureOpenAIProvider": ".azure_openai_provider", } if TYPE_CHECKING: + from nanobot.providers.anthropic_provider import AnthropicProvider from nanobot.providers.azure_openai_provider import AzureOpenAIProvider - from nanobot.providers.litellm_provider import LiteLLMProvider + from nanobot.providers.openai_compat_provider import OpenAICompatProvider from nanobot.providers.openai_codex_provider import OpenAICodexProvider diff --git a/nanobot/providers/anthropic_provider.py b/nanobot/providers/anthropic_provider.py new file mode 100644 index 00000000000..3c789e73068 --- /dev/null +++ b/nanobot/providers/anthropic_provider.py @@ -0,0 +1,441 @@ +"""Anthropic provider β€” direct SDK integration for Claude models.""" + +from __future__ import annotations + +import re +import secrets +import string +from collections.abc import Awaitable, Callable +from typing import Any + +import json_repair +from loguru import logger + +from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest + +_ALNUM = string.ascii_letters + string.digits + + +def _gen_tool_id() -> str: + return "toolu_" + "".join(secrets.choice(_ALNUM) for _ in range(22)) + + +class AnthropicProvider(LLMProvider): + """LLM provider using the native Anthropic SDK for Claude models. + + Handles message format conversion (OpenAI β†’ Anthropic Messages API), + prompt caching, extended thinking, tool calls, and streaming. + """ + + def __init__( + self, + api_key: str | None = None, + api_base: str | None = None, + default_model: str = "claude-sonnet-4-20250514", + extra_headers: dict[str, str] | None = None, + ): + super().__init__(api_key, api_base) + self.default_model = default_model + self.extra_headers = extra_headers or {} + + from anthropic import AsyncAnthropic + + client_kw: dict[str, Any] = {} + if api_key: + client_kw["api_key"] = api_key + if api_base: + client_kw["base_url"] = api_base + if extra_headers: + client_kw["default_headers"] = extra_headers + self._client = AsyncAnthropic(**client_kw) + + @staticmethod + def _strip_prefix(model: str) -> str: + if model.startswith("anthropic/"): + return model[len("anthropic/"):] + return model + + # ------------------------------------------------------------------ + # Message conversion: OpenAI chat format β†’ Anthropic Messages API + # ------------------------------------------------------------------ + + def _convert_messages( + self, messages: list[dict[str, Any]], + ) -> tuple[str | list[dict[str, Any]], list[dict[str, Any]]]: + """Return ``(system, anthropic_messages)``.""" + system: str | list[dict[str, Any]] = "" + raw: list[dict[str, Any]] = [] + + for msg in messages: + role = msg.get("role", "") + content = msg.get("content") + + if role == "system": + system = content if isinstance(content, (str, list)) else str(content or "") + continue + + if role == "tool": + block = self._tool_result_block(msg) + if raw and raw[-1]["role"] == "user": + prev_c = raw[-1]["content"] + if isinstance(prev_c, list): + prev_c.append(block) + else: + raw[-1]["content"] = [ + {"type": "text", "text": prev_c or ""}, block, + ] + else: + raw.append({"role": "user", "content": [block]}) + continue + + if role == "assistant": + raw.append({"role": "assistant", "content": self._assistant_blocks(msg)}) + continue + + if role == "user": + raw.append({ + "role": "user", + "content": self._convert_user_content(content), + }) + continue + + return system, self._merge_consecutive(raw) + + @staticmethod + def _tool_result_block(msg: dict[str, Any]) -> dict[str, Any]: + content = msg.get("content") + block: dict[str, Any] = { + "type": "tool_result", + "tool_use_id": msg.get("tool_call_id", ""), + } + if isinstance(content, (str, list)): + block["content"] = content + else: + block["content"] = str(content) if content else "" + return block + + @staticmethod + def _assistant_blocks(msg: dict[str, Any]) -> list[dict[str, Any]]: + blocks: list[dict[str, Any]] = [] + content = msg.get("content") + + for tb in msg.get("thinking_blocks") or []: + if isinstance(tb, dict) and tb.get("type") == "thinking": + blocks.append({ + "type": "thinking", + "thinking": tb.get("thinking", ""), + "signature": tb.get("signature", ""), + }) + + if isinstance(content, str) and content: + blocks.append({"type": "text", "text": content}) + elif isinstance(content, list): + for item in content: + blocks.append(item if isinstance(item, dict) else {"type": "text", "text": str(item)}) + + for tc in msg.get("tool_calls") or []: + if not isinstance(tc, dict): + continue + func = tc.get("function", {}) + args = func.get("arguments", "{}") + if isinstance(args, str): + args = json_repair.loads(args) + blocks.append({ + "type": "tool_use", + "id": tc.get("id") or _gen_tool_id(), + "name": func.get("name", ""), + "input": args, + }) + + return blocks or [{"type": "text", "text": ""}] + + def _convert_user_content(self, content: Any) -> Any: + """Convert user message content, translating image_url blocks.""" + if isinstance(content, str) or content is None: + return content or "(empty)" + if not isinstance(content, list): + return str(content) + + result: list[dict[str, Any]] = [] + for item in content: + if not isinstance(item, dict): + result.append({"type": "text", "text": str(item)}) + continue + if item.get("type") == "image_url": + converted = self._convert_image_block(item) + if converted: + result.append(converted) + continue + result.append(item) + return result or "(empty)" + + @staticmethod + def _convert_image_block(block: dict[str, Any]) -> dict[str, Any] | None: + """Convert OpenAI image_url block to Anthropic image block.""" + url = (block.get("image_url") or {}).get("url", "") + if not url: + return None + m = re.match(r"data:(image/\w+);base64,(.+)", url, re.DOTALL) + if m: + return { + "type": "image", + "source": {"type": "base64", "media_type": m.group(1), "data": m.group(2)}, + } + return { + "type": "image", + "source": {"type": "url", "url": url}, + } + + @staticmethod + def _merge_consecutive(msgs: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Anthropic requires alternating user/assistant roles.""" + merged: list[dict[str, Any]] = [] + for msg in msgs: + if merged and merged[-1]["role"] == msg["role"]: + prev_c = merged[-1]["content"] + cur_c = msg["content"] + if isinstance(prev_c, str): + prev_c = [{"type": "text", "text": prev_c}] + if isinstance(cur_c, str): + cur_c = [{"type": "text", "text": cur_c}] + if isinstance(cur_c, list): + prev_c.extend(cur_c) + merged[-1]["content"] = prev_c + else: + merged.append(msg) + return merged + + # ------------------------------------------------------------------ + # Tool definition conversion + # ------------------------------------------------------------------ + + @staticmethod + def _convert_tools(tools: list[dict[str, Any]] | None) -> list[dict[str, Any]] | None: + if not tools: + return None + result = [] + for tool in tools: + func = tool.get("function", tool) + entry: dict[str, Any] = { + "name": func.get("name", ""), + "input_schema": func.get("parameters", {"type": "object", "properties": {}}), + } + desc = func.get("description") + if desc: + entry["description"] = desc + if "cache_control" in tool: + entry["cache_control"] = tool["cache_control"] + result.append(entry) + return result + + @staticmethod + def _convert_tool_choice( + tool_choice: str | dict[str, Any] | None, + thinking_enabled: bool = False, + ) -> dict[str, Any] | None: + if thinking_enabled: + return {"type": "auto"} + if tool_choice is None or tool_choice == "auto": + return {"type": "auto"} + if tool_choice == "required": + return {"type": "any"} + if tool_choice == "none": + return None + if isinstance(tool_choice, dict): + name = tool_choice.get("function", {}).get("name") + if name: + return {"type": "tool", "name": name} + return {"type": "auto"} + + # ------------------------------------------------------------------ + # Prompt caching + # ------------------------------------------------------------------ + + @staticmethod + def _apply_cache_control( + system: str | list[dict[str, Any]], + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None, + ) -> tuple[str | list[dict[str, Any]], list[dict[str, Any]], list[dict[str, Any]] | None]: + marker = {"type": "ephemeral"} + + if isinstance(system, str) and system: + system = [{"type": "text", "text": system, "cache_control": marker}] + elif isinstance(system, list) and system: + system = list(system) + system[-1] = {**system[-1], "cache_control": marker} + + new_msgs = list(messages) + if len(new_msgs) >= 3: + m = new_msgs[-2] + c = m.get("content") + if isinstance(c, str): + new_msgs[-2] = {**m, "content": [{"type": "text", "text": c, "cache_control": marker}]} + elif isinstance(c, list) and c: + nc = list(c) + nc[-1] = {**nc[-1], "cache_control": marker} + new_msgs[-2] = {**m, "content": nc} + + new_tools = tools + if tools: + new_tools = list(tools) + new_tools[-1] = {**new_tools[-1], "cache_control": marker} + + return system, new_msgs, new_tools + + # ------------------------------------------------------------------ + # Build API kwargs + # ------------------------------------------------------------------ + + def _build_kwargs( + self, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None, + model: str | None, + max_tokens: int, + temperature: float, + reasoning_effort: str | None, + tool_choice: str | dict[str, Any] | None, + supports_caching: bool = True, + ) -> dict[str, Any]: + model_name = self._strip_prefix(model or self.default_model) + system, anthropic_msgs = self._convert_messages(self._sanitize_empty_content(messages)) + anthropic_tools = self._convert_tools(tools) + + if supports_caching: + system, anthropic_msgs, anthropic_tools = self._apply_cache_control( + system, anthropic_msgs, anthropic_tools, + ) + + max_tokens = max(1, max_tokens) + thinking_enabled = bool(reasoning_effort) + + kwargs: dict[str, Any] = { + "model": model_name, + "messages": anthropic_msgs, + "max_tokens": max_tokens, + } + + if system: + kwargs["system"] = system + + if thinking_enabled: + budget_map = {"low": 1024, "medium": 4096, "high": max(8192, max_tokens)} + budget = budget_map.get(reasoning_effort.lower(), 4096) # type: ignore[union-attr] + kwargs["thinking"] = {"type": "enabled", "budget_tokens": budget} + kwargs["max_tokens"] = max(max_tokens, budget + 4096) + kwargs["temperature"] = 1.0 + else: + kwargs["temperature"] = temperature + + if anthropic_tools: + kwargs["tools"] = anthropic_tools + tc = self._convert_tool_choice(tool_choice, thinking_enabled) + if tc: + kwargs["tool_choice"] = tc + + if self.extra_headers: + kwargs["extra_headers"] = self.extra_headers + + return kwargs + + # ------------------------------------------------------------------ + # Response parsing + # ------------------------------------------------------------------ + + @staticmethod + def _parse_response(response: Any) -> LLMResponse: + content_parts: list[str] = [] + tool_calls: list[ToolCallRequest] = [] + thinking_blocks: list[dict[str, Any]] = [] + + for block in response.content: + if block.type == "text": + content_parts.append(block.text) + elif block.type == "tool_use": + tool_calls.append(ToolCallRequest( + id=block.id, + name=block.name, + arguments=block.input if isinstance(block.input, dict) else {}, + )) + elif block.type == "thinking": + thinking_blocks.append({ + "type": "thinking", + "thinking": block.thinking, + "signature": getattr(block, "signature", ""), + }) + + stop_map = {"tool_use": "tool_calls", "end_turn": "stop", "max_tokens": "length"} + finish_reason = stop_map.get(response.stop_reason or "", response.stop_reason or "stop") + + usage: dict[str, int] = {} + if response.usage: + usage = { + "prompt_tokens": response.usage.input_tokens, + "completion_tokens": response.usage.output_tokens, + "total_tokens": response.usage.input_tokens + response.usage.output_tokens, + } + for attr in ("cache_creation_input_tokens", "cache_read_input_tokens"): + val = getattr(response.usage, attr, 0) + if val: + usage[attr] = val + + return LLMResponse( + content="".join(content_parts) or None, + tool_calls=tool_calls, + finish_reason=finish_reason, + usage=usage, + thinking_blocks=thinking_blocks or None, + ) + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + async def chat( + self, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None = None, + model: str | None = None, + max_tokens: int = 4096, + temperature: float = 0.7, + reasoning_effort: str | None = None, + tool_choice: str | dict[str, Any] | None = None, + ) -> LLMResponse: + kwargs = self._build_kwargs( + messages, tools, model, max_tokens, temperature, + reasoning_effort, tool_choice, + ) + try: + response = await self._client.messages.create(**kwargs) + return self._parse_response(response) + except Exception as e: + return LLMResponse(content=f"Error calling LLM: {e}", finish_reason="error") + + async def chat_stream( + self, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None = None, + model: str | None = None, + max_tokens: int = 4096, + temperature: float = 0.7, + reasoning_effort: str | None = None, + tool_choice: str | dict[str, Any] | None = None, + on_content_delta: Callable[[str], Awaitable[None]] | None = None, + ) -> LLMResponse: + kwargs = self._build_kwargs( + messages, tools, model, max_tokens, temperature, + reasoning_effort, tool_choice, + ) + try: + async with self._client.messages.stream(**kwargs) as stream: + if on_content_delta: + async for text in stream.text_stream: + await on_content_delta(text) + response = await stream.get_final_message() + return self._parse_response(response) + except Exception as e: + return LLMResponse(content=f"Error calling LLM: {e}", finish_reason="error") + + def get_default_model(self) -> str: + return self.default_model diff --git a/nanobot/providers/base.py b/nanobot/providers/base.py index 8c65d5e9980..87e9de6eb26 100644 --- a/nanobot/providers/base.py +++ b/nanobot/providers/base.py @@ -16,6 +16,7 @@ class ToolCallRequest: id: str name: str arguments: dict[str, Any] + extra_content: dict[str, Any] | None = None provider_specific_fields: dict[str, Any] | None = None function_provider_specific_fields: dict[str, Any] | None = None @@ -29,6 +30,8 @@ def to_openai_tool_call(self) -> dict[str, Any]: "arguments": json.dumps(self.arguments, ensure_ascii=False), }, } + if self.extra_content: + tool_call["extra_content"] = self.extra_content if self.provider_specific_fields: tool_call["provider_specific_fields"] = self.provider_specific_fields if self.function_provider_specific_fields: diff --git a/nanobot/providers/custom_provider.py b/nanobot/providers/custom_provider.py deleted file mode 100644 index a47dae7cd1e..00000000000 --- a/nanobot/providers/custom_provider.py +++ /dev/null @@ -1,152 +0,0 @@ -"""Direct OpenAI-compatible provider β€” bypasses LiteLLM.""" - -from __future__ import annotations - -import uuid -from collections.abc import Awaitable, Callable -from typing import Any - -import json_repair -from openai import AsyncOpenAI - -from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest - - -class CustomProvider(LLMProvider): - - def __init__( - self, - api_key: str = "no-key", - api_base: str = "http://localhost:8000/v1", - default_model: str = "default", - extra_headers: dict[str, str] | None = None, - ): - super().__init__(api_key, api_base) - self.default_model = default_model - self._client = AsyncOpenAI( - api_key=api_key, - base_url=api_base, - default_headers={ - "x-session-affinity": uuid.uuid4().hex, - **(extra_headers or {}), - }, - ) - - def _build_kwargs( - self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None, - model: str | None, max_tokens: int, temperature: float, - reasoning_effort: str | None, tool_choice: str | dict[str, Any] | None, - ) -> dict[str, Any]: - kwargs: dict[str, Any] = { - "model": model or self.default_model, - "messages": self._sanitize_empty_content(messages), - "max_tokens": max(1, max_tokens), - "temperature": temperature, - } - if reasoning_effort: - kwargs["reasoning_effort"] = reasoning_effort - if tools: - kwargs.update(tools=tools, tool_choice=tool_choice or "auto") - return kwargs - - def _handle_error(self, e: Exception) -> LLMResponse: - body = getattr(e, "doc", None) or getattr(getattr(e, "response", None), "text", None) - msg = f"Error: {body.strip()[:500]}" if body and body.strip() else f"Error: {e}" - return LLMResponse(content=msg, finish_reason="error") - - async def chat(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None, - model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7, - reasoning_effort: str | None = None, - tool_choice: str | dict[str, Any] | None = None) -> LLMResponse: - kwargs = self._build_kwargs(messages, tools, model, max_tokens, temperature, reasoning_effort, tool_choice) - try: - return self._parse(await self._client.chat.completions.create(**kwargs)) - except Exception as e: - return self._handle_error(e) - - async def chat_stream( - self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None, - model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7, - reasoning_effort: str | None = None, - tool_choice: str | dict[str, Any] | None = None, - on_content_delta: Callable[[str], Awaitable[None]] | None = None, - ) -> LLMResponse: - kwargs = self._build_kwargs(messages, tools, model, max_tokens, temperature, reasoning_effort, tool_choice) - kwargs["stream"] = True - try: - stream = await self._client.chat.completions.create(**kwargs) - chunks: list[Any] = [] - async for chunk in stream: - chunks.append(chunk) - if on_content_delta and chunk.choices: - text = getattr(chunk.choices[0].delta, "content", None) - if text: - await on_content_delta(text) - return self._parse_chunks(chunks) - except Exception as e: - return self._handle_error(e) - - def _parse(self, response: Any) -> LLMResponse: - if not response.choices: - return LLMResponse( - content="Error: API returned empty choices.", - finish_reason="error", - ) - choice = response.choices[0] - msg = choice.message - tool_calls = [ - ToolCallRequest( - id=tc.id, name=tc.function.name, - arguments=json_repair.loads(tc.function.arguments) if isinstance(tc.function.arguments, str) else tc.function.arguments, - ) - for tc in (msg.tool_calls or []) - ] - u = response.usage - return LLMResponse( - content=msg.content, tool_calls=tool_calls, - finish_reason=choice.finish_reason or "stop", - usage={"prompt_tokens": u.prompt_tokens, "completion_tokens": u.completion_tokens, "total_tokens": u.total_tokens} if u else {}, - reasoning_content=getattr(msg, "reasoning_content", None) or None, - ) - - def _parse_chunks(self, chunks: list[Any]) -> LLMResponse: - """Reassemble streamed chunks into a single LLMResponse.""" - content_parts: list[str] = [] - tc_bufs: dict[int, dict[str, str]] = {} - finish_reason = "stop" - usage: dict[str, int] = {} - - for chunk in chunks: - if not chunk.choices: - if hasattr(chunk, "usage") and chunk.usage: - u = chunk.usage - usage = {"prompt_tokens": u.prompt_tokens or 0, "completion_tokens": u.completion_tokens or 0, - "total_tokens": u.total_tokens or 0} - continue - choice = chunk.choices[0] - if choice.finish_reason: - finish_reason = choice.finish_reason - delta = choice.delta - if delta and delta.content: - content_parts.append(delta.content) - for tc in (delta.tool_calls or []) if delta else []: - buf = tc_bufs.setdefault(tc.index, {"id": "", "name": "", "arguments": ""}) - if tc.id: - buf["id"] = tc.id - if tc.function and tc.function.name: - buf["name"] = tc.function.name - if tc.function and tc.function.arguments: - buf["arguments"] += tc.function.arguments - - return LLMResponse( - content="".join(content_parts) or None, - tool_calls=[ - ToolCallRequest(id=b["id"], name=b["name"], arguments=json_repair.loads(b["arguments"]) if b["arguments"] else {}) - for b in tc_bufs.values() - ], - finish_reason=finish_reason, - usage=usage, - ) - - def get_default_model(self) -> str: - return self.default_model diff --git a/nanobot/providers/litellm_provider.py b/nanobot/providers/litellm_provider.py deleted file mode 100644 index 8e84a3d027d..00000000000 --- a/nanobot/providers/litellm_provider.py +++ /dev/null @@ -1,562 +0,0 @@ -"""LiteLLM provider implementation for multi-provider support.""" - -import hashlib -import os -import secrets -import string -from collections.abc import Awaitable, Callable -from typing import Any - -import json_repair -import litellm -from litellm import acompletion -from loguru import logger - -from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest -from nanobot.providers.registry import find_by_model, find_gateway - -# Standard chat-completion message keys. -_ALLOWED_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name", "reasoning_content"}) -_ANTHROPIC_EXTRA_KEYS = frozenset({"thinking_blocks"}) -_ALNUM = string.ascii_letters + string.digits - -def _short_tool_id() -> str: - """Generate a 9-char alphanumeric ID compatible with all providers (incl. Mistral).""" - return "".join(secrets.choice(_ALNUM) for _ in range(9)) - - -class LiteLLMProvider(LLMProvider): - """ - LLM provider using LiteLLM for multi-provider support. - - Supports OpenRouter, Anthropic, OpenAI, Gemini, MiniMax, and many other providers through - a unified interface. Provider-specific logic is driven by the registry - (see providers/registry.py) β€” no if-elif chains needed here. - """ - - def __init__( - self, - api_key: str | None = None, - api_base: str | None = None, - default_model: str = "anthropic/claude-opus-4-5", - extra_headers: dict[str, str] | None = None, - provider_name: str | None = None, - suppress_tools_param: bool = False, - request_timeout: int | None = None, - ): - super().__init__(api_key, api_base) - self.default_model = default_model - self.extra_headers = extra_headers or {} - self.suppress_tools_param = suppress_tools_param - self.request_timeout = request_timeout - - # Detect gateway / local deployment. - # provider_name (from config key) is the primary signal; - # api_key / api_base are fallback for auto-detection. - self._gateway = find_gateway(provider_name, api_key, api_base) - - # Configure environment variables - if api_key: - self._setup_env(api_key, api_base, default_model) - - if api_base: - litellm.api_base = api_base - - # Disable LiteLLM logging noise - litellm.suppress_debug_info = True - # Drop unsupported parameters for providers (e.g., gpt-5 rejects some params) - litellm.drop_params = True - - self._langsmith_enabled = bool(os.getenv("LANGSMITH_API_KEY")) - - def _setup_env(self, api_key: str, api_base: str | None, model: str) -> None: - """Set environment variables based on detected provider.""" - spec = self._gateway or find_by_model(model) - if not spec: - return - if not spec.env_key: - # OAuth/provider-only specs (for example: openai_codex) - return - - # Gateway/local overrides existing env; standard provider doesn't - if self._gateway: - os.environ[spec.env_key] = api_key - else: - os.environ.setdefault(spec.env_key, api_key) - - # Resolve env_extras placeholders: - # {api_key} β†’ user's API key - # {api_base} β†’ user's api_base, falling back to spec.default_api_base - effective_base = api_base or spec.default_api_base - for env_name, env_val in spec.env_extras: - resolved = env_val.replace("{api_key}", api_key) - resolved = resolved.replace("{api_base}", effective_base) - os.environ.setdefault(env_name, resolved) - - def _resolve_model(self, model: str) -> str: - """Resolve model name by applying provider/gateway prefixes.""" - if self._gateway: - prefix = self._gateway.litellm_prefix - if self._gateway.strip_model_prefix: - model = model.split("/")[-1] - if prefix: - model = f"{prefix}/{model}" - return model - - # Standard mode: auto-prefix for known providers - spec = find_by_model(model) - if spec and spec.litellm_prefix: - model = self._canonicalize_explicit_prefix(model, spec.name, spec.litellm_prefix) - if not any(model.startswith(s) for s in spec.skip_prefixes): - model = f"{spec.litellm_prefix}/{model}" - - return model - - @staticmethod - def _canonicalize_explicit_prefix(model: str, spec_name: str, canonical_prefix: str) -> str: - """Normalize explicit provider prefixes like `github-copilot/...`.""" - if "/" not in model: - return model - prefix, remainder = model.split("/", 1) - if prefix.lower().replace("-", "_") != spec_name: - return model - return f"{canonical_prefix}/{remainder}" - - def _supports_cache_control(self, model: str) -> bool: - """Return True when the provider supports cache_control on content blocks.""" - if self._gateway is not None: - return self._gateway.supports_prompt_caching - spec = find_by_model(model) - return spec is not None and spec.supports_prompt_caching - - def _apply_cache_control( - self, - messages: list[dict[str, Any]], - tools: list[dict[str, Any]] | None, - ) -> tuple[list[dict[str, Any]], list[dict[str, Any]] | None]: - """Return copies of messages and tools with cache_control injected. - - Two breakpoints are placed: - 1. System message β€” caches the static system prompt - 2. Second-to-last message β€” caches the conversation history prefix - This maximises cache hits across multi-turn conversations. - """ - cache_marker = {"type": "ephemeral"} - new_messages = list(messages) - - def _mark(msg: dict[str, Any]) -> dict[str, Any]: - content = msg.get("content") - if isinstance(content, str): - return {**msg, "content": [ - {"type": "text", "text": content, "cache_control": cache_marker} - ]} - elif isinstance(content, list) and content: - new_content = list(content) - new_content[-1] = {**new_content[-1], "cache_control": cache_marker} - return {**msg, "content": new_content} - return msg - - # Breakpoint 1: system message - if new_messages and new_messages[0].get("role") == "system": - new_messages[0] = _mark(new_messages[0]) - - # Breakpoint 2: second-to-last message (caches conversation history prefix) - if len(new_messages) >= 3: - new_messages[-2] = _mark(new_messages[-2]) - - new_tools = tools - if tools: - new_tools = list(tools) - new_tools[-1] = {**new_tools[-1], "cache_control": cache_marker} - - return new_messages, new_tools - - def _apply_model_overrides(self, model: str, kwargs: dict[str, Any]) -> None: - """Apply model-specific parameter overrides from the registry.""" - model_lower = model.lower() - spec = find_by_model(model) - if spec: - for pattern, overrides in spec.model_overrides: - if pattern in model_lower: - kwargs.update(overrides) - return - - @staticmethod - def _extra_msg_keys(original_model: str, resolved_model: str) -> frozenset[str]: - """Return provider-specific extra keys to preserve in request messages.""" - spec = find_by_model(original_model) or find_by_model(resolved_model) - if (spec and spec.name == "anthropic") or "claude" in original_model.lower() or resolved_model.startswith("anthropic/"): - return _ANTHROPIC_EXTRA_KEYS - return frozenset() - - @staticmethod - def _normalize_tool_call_id(tool_call_id: Any) -> Any: - """Normalize tool_call_id to a provider-safe 9-char alphanumeric form.""" - if not isinstance(tool_call_id, str): - return tool_call_id - if len(tool_call_id) == 9 and tool_call_id.isalnum(): - return tool_call_id - return hashlib.sha1(tool_call_id.encode()).hexdigest()[:9] - - @staticmethod - def _sanitize_messages(messages: list[dict[str, Any]], extra_keys: frozenset[str] = frozenset()) -> list[dict[str, Any]]: - """Strip non-standard keys and ensure assistant messages have a content key.""" - allowed = _ALLOWED_MSG_KEYS | extra_keys - sanitized = LLMProvider._sanitize_request_messages(messages, allowed) - id_map: dict[str, str] = {} - - def map_id(value: Any) -> Any: - if not isinstance(value, str): - return value - return id_map.setdefault(value, LiteLLMProvider._normalize_tool_call_id(value)) - - for clean in sanitized: - # Keep assistant tool_calls[].id and tool tool_call_id in sync after - # shortening, otherwise strict providers reject the broken linkage. - if isinstance(clean.get("tool_calls"), list): - normalized_tool_calls = [] - for tc in clean["tool_calls"]: - if not isinstance(tc, dict): - normalized_tool_calls.append(tc) - continue - tc_clean = dict(tc) - tc_clean["id"] = map_id(tc_clean.get("id")) - normalized_tool_calls.append(tc_clean) - clean["tool_calls"] = normalized_tool_calls - - if "tool_call_id" in clean and clean["tool_call_id"]: - clean["tool_call_id"] = map_id(clean["tool_call_id"]) - return sanitized - - def _build_chat_kwargs( - self, - messages: list[dict[str, Any]], - tools: list[dict[str, Any]] | None, - model: str | None, - max_tokens: int, - temperature: float, - reasoning_effort: str | None, - tool_choice: str | dict[str, Any] | None, - ) -> tuple[dict[str, Any], str]: - """Build the kwargs dict for ``acompletion``. - - Returns ``(kwargs, original_model)`` so callers can reuse the - original model string for downstream logic. - """ - original_model = model or self.default_model - resolved = self._resolve_model(original_model) - extra_msg_keys = self._extra_msg_keys(original_model, resolved) - - if self._supports_cache_control(original_model): - messages, tools = self._apply_cache_control(messages, tools) - - max_tokens = max(1, max_tokens) - - kwargs: dict[str, Any] = { - "model": resolved, - "messages": self._sanitize_messages( - self._sanitize_empty_content(messages), extra_keys=extra_msg_keys, - ), - "max_tokens": max_tokens, - "temperature": temperature, - } - - if self._gateway: - kwargs.update(self._gateway.litellm_kwargs) - - self._apply_model_overrides(resolved, kwargs) - - if self._langsmith_enabled: - kwargs.setdefault("callbacks", []).append("langsmith") - - if self.api_key: - kwargs["api_key"] = self.api_key - if self.api_base: - kwargs["api_base"] = self.api_base - if self.extra_headers: - kwargs["extra_headers"] = self.extra_headers - - if reasoning_effort: - kwargs["reasoning_effort"] = reasoning_effort - kwargs["drop_params"] = True - - if tools and not self.suppress_tools_param: - kwargs["tools"] = tools - kwargs["tool_choice"] = tool_choice or "auto" - - if self.request_timeout is not None: - kwargs["timeout"] = self.request_timeout - - return kwargs, original_model - - async def chat( - self, - messages: list[dict[str, Any]], - tools: list[dict[str, Any]] | None = None, - model: str | None = None, - max_tokens: int = 4096, - temperature: float = 0.7, - reasoning_effort: str | None = None, - tool_choice: str | dict[str, Any] | None = None, - ) -> LLMResponse: - """Send a chat completion request via LiteLLM.""" - kwargs, _ = self._build_chat_kwargs( - messages, tools, model, max_tokens, temperature, - reasoning_effort, tool_choice, - ) - try: - response = await acompletion(**kwargs) - return self._parse_response(response) - except Exception as e: - return LLMResponse( - content=f"Error calling LLM: {str(e)}", - finish_reason="error", - ) - - async def chat_stream( - self, - messages: list[dict[str, Any]], - tools: list[dict[str, Any]] | None = None, - model: str | None = None, - max_tokens: int = 4096, - temperature: float = 0.7, - reasoning_effort: str | None = None, - tool_choice: str | dict[str, Any] | None = None, - on_content_delta: Callable[[str], Awaitable[None]] | None = None, - ) -> LLMResponse: - """Stream a chat completion via LiteLLM, forwarding text deltas.""" - kwargs, _ = self._build_chat_kwargs( - messages, tools, model, max_tokens, temperature, - reasoning_effort, tool_choice, - ) - kwargs["stream"] = True - - try: - stream = await acompletion(**kwargs) - chunks: list[Any] = [] - async for chunk in stream: - chunks.append(chunk) - if on_content_delta: - delta = chunk.choices[0].delta if chunk.choices else None - text = getattr(delta, "content", None) if delta else None - if text: - await on_content_delta(text) - - full_response = litellm.stream_chunk_builder( - chunks, messages=kwargs["messages"], - ) - return self._parse_response(full_response) - except Exception as e: - return LLMResponse( - content=f"Error calling LLM: {str(e)}", - finish_reason="error", - ) - - def _parse_response(self, response: Any) -> LLMResponse: - """Parse LiteLLM response into our standard format.""" - choice = response.choices[0] - message = choice.message - content = message.content - finish_reason = choice.finish_reason - - # Some providers (e.g. GitHub Copilot) split content and tool_calls - # across multiple choices. Merge them so tool_calls are not lost. - raw_tool_calls = [] - for ch in response.choices: - msg = ch.message - if hasattr(msg, "tool_calls") and msg.tool_calls: - raw_tool_calls.extend(msg.tool_calls) - if ch.finish_reason in ("tool_calls", "stop"): - finish_reason = ch.finish_reason - if not content and msg.content: - content = msg.content - - if len(response.choices) > 1: - logger.debug("LiteLLM response has {} choices, merged {} tool_calls", - len(response.choices), len(raw_tool_calls)) - - tool_calls = [] - for tc in raw_tool_calls: - # Parse arguments from JSON string if needed - args = tc.function.arguments - if isinstance(args, str): - args = json_repair.loads(args) - if not isinstance(args, dict): - args = {} - - provider_specific_fields = getattr(tc, "provider_specific_fields", None) or None - function_provider_specific_fields = ( - getattr(tc.function, "provider_specific_fields", None) or None - ) - tool_calls.append(ToolCallRequest( - id=_short_tool_id(), - name=tc.function.name, - arguments=args, - provider_specific_fields=provider_specific_fields, - function_provider_specific_fields=function_provider_specific_fields, - )) - - usage = {} - if hasattr(response, "usage") and response.usage: - usage = { - "prompt_tokens": response.usage.prompt_tokens, - "completion_tokens": response.usage.completion_tokens, - "total_tokens": response.usage.total_tokens, - } - - # Fallback: some models (e.g. Qwen via Ollama/vLLM) embed tool-call JSON - # in the text content instead of the structured tool_calls field. - # Strip ... blocks first β€” thinking models (e.g. Qwen3 with - # thinking=1) embed reasoning in content; tool calls appear after the block. - if not tool_calls and content: - import re as _re - stripped = _re.sub(r".*?", "", content, flags=_re.DOTALL).strip() - tool_calls, extracted_content = LiteLLMProvider._extract_text_tool_calls(stripped) - if tool_calls: - content = extracted_content - - # Strip EOS/EOT tokens that some models (e.g. Ministral via llama.cpp) leak - # into response content. If stored in history they break Jinja chat templating - # on the next request. - if content: - for tok in ("<|im_end|>", "<|endoftext|>", "", "<|eot_id|>"): - content = content.replace(tok, "") - content = content.rstrip() or None - - reasoning_content = getattr(message, "reasoning_content", None) or None - thinking_blocks = getattr(message, "thinking_blocks", None) or None - - return LLMResponse( - content=content, - tool_calls=tool_calls, - finish_reason=finish_reason or "stop", - usage=usage, - reasoning_content=reasoning_content, - thinking_blocks=thinking_blocks, - ) - - @staticmethod - def _extract_text_tool_calls(content: str) -> tuple[list[ToolCallRequest], str | None]: - """Parse tool calls embedded as JSON or XML in text content. - - Returns (tool_calls, remaining_content). If the entire content is a - tool-call envelope the remaining content is set to None so the agent - loop doesn't forward raw JSON/XML to the user. - """ - import re - - # Try Qwen3 format: {"name": "...", "arguments": {...}} - qwen_calls = [] - for qwen_match in re.finditer(r"(.*?)", content, re.DOTALL): - try: - obj = json_repair.loads(qwen_match.group(1).strip()) - if isinstance(obj, dict) and isinstance(obj.get("name"), str) and isinstance(obj.get("arguments"), dict): - qwen_calls.append(ToolCallRequest(id=_short_tool_id(), name=obj["name"], arguments=obj["arguments"])) - except Exception: - pass - if qwen_calls: - first = re.search(r"", content) - preamble = (content[:first.start()].strip() if first else None) or None - logger.info("_parse_response: extracted {} Qwen3-embedded tool call(s): {}", - len(qwen_calls), [c.name for c in qwen_calls]) - return qwen_calls, preamble - - # Try [TOOL_CALLS]name[ARGS]{json} format (Nemotron/Orchestrator models) - # Don't regex-match the JSON body β€” find the opening { and let json_repair - # consume as much as it needs. This handles large result strings with } inside. - tc_calls = [] - for tc_match in re.finditer(r"\[TOOL_CALLS\]([\w_\-]+)\[ARGS\](\{)", content, re.DOTALL): - try: - json_start = tc_match.start(2) - args = json_repair.loads(content[json_start:]) - if not isinstance(args, dict): - args = {} - tc_calls.append(ToolCallRequest(id=_short_tool_id(), name=tc_match.group(1).strip(), arguments=args)) - except Exception: - pass - if tc_calls: - first = re.search(r"\[TOOL_CALLS\]", content) - preamble = (content[:first.start()].strip() if first else None) or None - logger.info("_parse_response: extracted {} [TOOL_CALLS]-format tool call(s): {}", - len(tc_calls), [c.name for c in tc_calls]) - return tc_calls, preamble - - # Try Python code-block format: ```python\ntool_name(arg="value")\n``` - import ast as _ast - py_block_match = re.search(r"```(?:python)?\s*\n?([\w_]+\(.*?\))\s*\n?```", content, re.DOTALL) - if py_block_match: - call_src = py_block_match.group(1).strip() - fn_match = re.match(r"([\w_]+)\((.*)\)$", call_src, re.DOTALL) - if fn_match: - fn_name = fn_match.group(1) - args_src = fn_match.group(2).strip() - try: - # Parse as keyword args only: key="val", key2=123, ... - dummy = f"_f({args_src})" - tree = _ast.parse(dummy, mode="eval") - arguments = {} - for kw in tree.body.keywords: # type: ignore[attr-defined] - if kw.arg: - arguments[kw.arg] = _ast.literal_eval(kw.value) - py_call = ToolCallRequest(id=_short_tool_id(), name=fn_name, arguments=arguments) - preamble = content[:py_block_match.start()].strip() or None - logger.info("_parse_response: extracted Python code-block tool call: {}", fn_name) - return [py_call], preamble - except Exception: - pass - - # Try legacy XML format: value... - # Model may omit closing tag, so match up to or end of string. - xml_match = re.search(r"", content) - if xml_match: - calls = [] - xml_body = content[xml_match.start():] - for fn_match in re.finditer( - r"(.*?)(?:||\Z)", - xml_body, - re.DOTALL, - ): - name = fn_match.group(1) - body = fn_match.group(2) - arguments = {} - for param in re.finditer(r"(.*?)", body, re.DOTALL): - arguments[param.group(1)] = param.group(2).strip() - if arguments: - calls.append(ToolCallRequest(id=_short_tool_id(), name=name, arguments=arguments)) - if calls: - logger.info("_parse_response: extracted {} XML-embedded tool call(s): {}", - len(calls), [c.name for c in calls]) - preamble = content[:xml_match.start()].strip() or None - return calls, preamble - - # Fall back to JSON format - match = re.search(r"[{\[]", content) - if not match: - return [], content - json_start = match.start() - candidate = content[json_start:].strip() - try: - parsed = json_repair.loads(candidate) - except Exception: - return [], content - candidates = parsed if isinstance(parsed, list) else [parsed] - calls = [] - for obj in candidates: - if (isinstance(obj, dict) - and isinstance(obj.get("name"), str) - and isinstance(obj.get("arguments"), dict)): - calls.append(ToolCallRequest( - id=_short_tool_id(), - name=obj["name"], - arguments=obj["arguments"], - )) - if calls: - logger.info("_parse_response: extracted {} JSON-embedded tool call(s): {}", - len(calls), [c.name for c in calls]) - preamble = content[:json_start].strip() or None - return calls, preamble - return [], content - - def get_default_model(self) -> str: - """Get the default model.""" - return self.default_model diff --git a/nanobot/providers/openai_compat_provider.py b/nanobot/providers/openai_compat_provider.py new file mode 100644 index 00000000000..19115cdf493 --- /dev/null +++ b/nanobot/providers/openai_compat_provider.py @@ -0,0 +1,700 @@ +"""OpenAI-compatible provider for all non-Anthropic LLM APIs.""" + +from __future__ import annotations + +import hashlib +import os +import secrets +import string +import uuid +from collections.abc import Awaitable, Callable +from typing import TYPE_CHECKING, Any + +import json_repair +from loguru import logger +from openai import AsyncOpenAI + +from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest + +if TYPE_CHECKING: + from nanobot.providers.registry import ProviderSpec + +_ALLOWED_MSG_KEYS = frozenset({ + "role", "content", "tool_calls", "tool_call_id", "name", + "reasoning_content", "extra_content", +}) +_ALNUM = string.ascii_letters + string.digits + +_STANDARD_TC_KEYS = frozenset({"id", "type", "index", "function"}) +_STANDARD_FN_KEYS = frozenset({"name", "arguments"}) + + +def _short_tool_id() -> str: + """9-char alphanumeric ID compatible with all providers (incl. Mistral).""" + return "".join(secrets.choice(_ALNUM) for _ in range(9)) + + +def _get(obj: Any, key: str) -> Any: + """Get a value from dict or object attribute, returning None if absent.""" + if isinstance(obj, dict): + return obj.get(key) + return getattr(obj, key, None) + + +def _coerce_dict(value: Any) -> dict[str, Any] | None: + """Try to coerce *value* to a dict; return None if not possible or empty.""" + if value is None: + return None + if isinstance(value, dict): + return value if value else None + model_dump = getattr(value, "model_dump", None) + if callable(model_dump): + dumped = model_dump() + if isinstance(dumped, dict) and dumped: + return dumped + return None + + +def _extract_tc_extras(tc: Any) -> tuple[ + dict[str, Any] | None, + dict[str, Any] | None, + dict[str, Any] | None, +]: + """Extract (extra_content, provider_specific_fields, fn_provider_specific_fields). + + Works for both SDK objects and dicts. Captures Gemini ``extra_content`` + verbatim and any non-standard keys on the tool-call / function. + """ + extra_content = _coerce_dict(_get(tc, "extra_content")) + + tc_dict = _coerce_dict(tc) + prov = None + fn_prov = None + if tc_dict is not None: + leftover = {k: v for k, v in tc_dict.items() + if k not in _STANDARD_TC_KEYS and k != "extra_content" and v is not None} + if leftover: + prov = leftover + fn = _coerce_dict(tc_dict.get("function")) + if fn is not None: + fn_leftover = {k: v for k, v in fn.items() + if k not in _STANDARD_FN_KEYS and v is not None} + if fn_leftover: + fn_prov = fn_leftover + else: + prov = _coerce_dict(_get(tc, "provider_specific_fields")) + fn_obj = _get(tc, "function") + if fn_obj is not None: + fn_prov = _coerce_dict(_get(fn_obj, "provider_specific_fields")) + + return extra_content, prov, fn_prov + + +class OpenAICompatProvider(LLMProvider): + """Unified provider for all OpenAI-compatible APIs. + + Receives a resolved ``ProviderSpec`` from the caller β€” no internal + registry lookups needed. + """ + + def __init__( + self, + api_key: str | None = None, + api_base: str | None = None, + default_model: str = "gpt-4o", + extra_headers: dict[str, str] | None = None, + spec: ProviderSpec | None = None, + suppress_tools_param: bool = False, + request_timeout: int | None = None, + ): + super().__init__(api_key, api_base) + self.default_model = default_model + self.extra_headers = extra_headers or {} + self._spec = spec + self.suppress_tools_param = suppress_tools_param + self.request_timeout = request_timeout + + if api_key and spec and spec.env_key: + self._setup_env(api_key, api_base) + + effective_base = api_base or (spec.default_api_base if spec else None) or None + + self._client = AsyncOpenAI( + api_key=api_key or "no-key", + base_url=effective_base, + default_headers={ + "x-session-affinity": uuid.uuid4().hex, + **(extra_headers or {}), + }, + ) + + def _setup_env(self, api_key: str, api_base: str | None) -> None: + """Set environment variables based on provider spec.""" + spec = self._spec + if not spec or not spec.env_key: + return + if spec.is_gateway: + os.environ[spec.env_key] = api_key + else: + os.environ.setdefault(spec.env_key, api_key) + effective_base = api_base or spec.default_api_base + for env_name, env_val in spec.env_extras: + resolved = env_val.replace("{api_key}", api_key).replace("{api_base}", effective_base) + os.environ.setdefault(env_name, resolved) + + @staticmethod + def _apply_cache_control( + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None, + ) -> tuple[list[dict[str, Any]], list[dict[str, Any]] | None]: + """Inject cache_control markers for prompt caching.""" + cache_marker = {"type": "ephemeral"} + new_messages = list(messages) + + def _mark(msg: dict[str, Any]) -> dict[str, Any]: + content = msg.get("content") + if isinstance(content, str): + return {**msg, "content": [ + {"type": "text", "text": content, "cache_control": cache_marker}, + ]} + if isinstance(content, list) and content: + nc = list(content) + nc[-1] = {**nc[-1], "cache_control": cache_marker} + return {**msg, "content": nc} + return msg + + if new_messages and new_messages[0].get("role") == "system": + new_messages[0] = _mark(new_messages[0]) + if len(new_messages) >= 3: + new_messages[-2] = _mark(new_messages[-2]) + + new_tools = tools + if tools: + new_tools = list(tools) + new_tools[-1] = {**new_tools[-1], "cache_control": cache_marker} + return new_messages, new_tools + + @staticmethod + def _normalize_tool_call_id(tool_call_id: Any) -> Any: + """Normalize to a provider-safe 9-char alphanumeric form.""" + if not isinstance(tool_call_id, str): + return tool_call_id + if len(tool_call_id) == 9 and tool_call_id.isalnum(): + return tool_call_id + return hashlib.sha1(tool_call_id.encode()).hexdigest()[:9] + + def _sanitize_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Strip non-standard keys, normalize tool_call IDs.""" + sanitized = LLMProvider._sanitize_request_messages(messages, _ALLOWED_MSG_KEYS) + id_map: dict[str, str] = {} + + def map_id(value: Any) -> Any: + if not isinstance(value, str): + return value + return id_map.setdefault(value, self._normalize_tool_call_id(value)) + + for clean in sanitized: + if isinstance(clean.get("tool_calls"), list): + normalized = [] + for tc in clean["tool_calls"]: + if not isinstance(tc, dict): + normalized.append(tc) + continue + tc_clean = dict(tc) + tc_clean["id"] = map_id(tc_clean.get("id")) + normalized.append(tc_clean) + clean["tool_calls"] = normalized + if "tool_call_id" in clean and clean["tool_call_id"]: + clean["tool_call_id"] = map_id(clean["tool_call_id"]) + return sanitized + + # ------------------------------------------------------------------ + # Build kwargs + # ------------------------------------------------------------------ + + def _build_kwargs( + self, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None, + model: str | None, + max_tokens: int, + temperature: float, + reasoning_effort: str | None, + tool_choice: str | dict[str, Any] | None, + ) -> dict[str, Any]: + model_name = model or self.default_model + spec = self._spec + + if spec and spec.supports_prompt_caching: + messages, tools = self._apply_cache_control(messages, tools) + + if spec and spec.strip_model_prefix: + model_name = model_name.split("/")[-1] + + kwargs: dict[str, Any] = { + "model": model_name, + "messages": self._sanitize_messages(self._sanitize_empty_content(messages)), + "max_tokens": max(1, max_tokens), + "max_completion_tokens": max(1, max_tokens), + "temperature": temperature, + } + + if spec: + model_lower = model_name.lower() + for pattern, overrides in spec.model_overrides: + if pattern in model_lower: + kwargs.update(overrides) + break + + if reasoning_effort: + kwargs["reasoning_effort"] = reasoning_effort + + if tools and not self.suppress_tools_param: + kwargs["tools"] = tools + kwargs["tool_choice"] = tool_choice or "auto" + + if self.request_timeout is not None: + kwargs["timeout"] = self.request_timeout + + return kwargs + + # ------------------------------------------------------------------ + # Response parsing + # ------------------------------------------------------------------ + + @staticmethod + def _maybe_mapping(value: Any) -> dict[str, Any] | None: + if isinstance(value, dict): + return value + model_dump = getattr(value, "model_dump", None) + if callable(model_dump): + dumped = model_dump() + if isinstance(dumped, dict): + return dumped + return None + + @classmethod + def _extract_text_content(cls, value: Any) -> str | None: + if value is None: + return None + if isinstance(value, str): + return value + if isinstance(value, list): + parts: list[str] = [] + for item in value: + item_map = cls._maybe_mapping(item) + if item_map: + text = item_map.get("text") + if isinstance(text, str): + parts.append(text) + continue + text = getattr(item, "text", None) + if isinstance(text, str): + parts.append(text) + continue + if isinstance(item, str): + parts.append(item) + return "".join(parts) or None + return str(value) + + @classmethod + def _extract_usage(cls, response: Any) -> dict[str, int]: + usage_obj = None + response_map = cls._maybe_mapping(response) + if response_map is not None: + usage_obj = response_map.get("usage") + elif hasattr(response, "usage") and response.usage: + usage_obj = response.usage + + usage_map = cls._maybe_mapping(usage_obj) + if usage_map is not None: + return { + "prompt_tokens": int(usage_map.get("prompt_tokens") or 0), + "completion_tokens": int(usage_map.get("completion_tokens") or 0), + "total_tokens": int(usage_map.get("total_tokens") or 0), + } + + if usage_obj: + return { + "prompt_tokens": getattr(usage_obj, "prompt_tokens", 0) or 0, + "completion_tokens": getattr(usage_obj, "completion_tokens", 0) or 0, + "total_tokens": getattr(usage_obj, "total_tokens", 0) or 0, + } + return {} + + def _parse(self, response: Any) -> LLMResponse: + if isinstance(response, str): + return LLMResponse(content=response, finish_reason="stop") + + response_map = self._maybe_mapping(response) + if response_map is not None: + choices = response_map.get("choices") or [] + if not choices: + content = self._extract_text_content( + response_map.get("content") or response_map.get("output_text") + ) + if content is not None: + return LLMResponse( + content=content, + finish_reason=str(response_map.get("finish_reason") or "stop"), + usage=self._extract_usage(response_map), + ) + return LLMResponse(content="Error: API returned empty choices.", finish_reason="error") + + choice0 = self._maybe_mapping(choices[0]) or {} + msg0 = self._maybe_mapping(choice0.get("message")) or {} + content = self._extract_text_content(msg0.get("content")) + finish_reason = str(choice0.get("finish_reason") or "stop") + + raw_tool_calls: list[Any] = [] + reasoning_content = msg0.get("reasoning_content") + for ch in choices: + ch_map = self._maybe_mapping(ch) or {} + m = self._maybe_mapping(ch_map.get("message")) or {} + tool_calls = m.get("tool_calls") + if isinstance(tool_calls, list) and tool_calls: + raw_tool_calls.extend(tool_calls) + if ch_map.get("finish_reason") in ("tool_calls", "stop"): + finish_reason = str(ch_map["finish_reason"]) + if not content: + content = self._extract_text_content(m.get("content")) + if not reasoning_content: + reasoning_content = m.get("reasoning_content") + + parsed_tool_calls = [] + for tc in raw_tool_calls: + tc_map = self._maybe_mapping(tc) or {} + fn = self._maybe_mapping(tc_map.get("function")) or {} + args = fn.get("arguments", {}) + if isinstance(args, str): + args = json_repair.loads(args) + ec, prov, fn_prov = _extract_tc_extras(tc) + parsed_tool_calls.append(ToolCallRequest( + id=_short_tool_id(), + name=str(fn.get("name") or ""), + arguments=args if isinstance(args, dict) else {}, + extra_content=ec, + provider_specific_fields=prov, + function_provider_specific_fields=fn_prov, + )) + + # Fallback: some local models embed tool calls in text content + if not parsed_tool_calls and content: + import re as _re + stripped = _re.sub(r".*?", "", content, flags=_re.DOTALL).strip() + parsed_tool_calls, extracted_content = OpenAICompatProvider._extract_text_tool_calls(stripped) + if parsed_tool_calls: + content = extracted_content + + # Strip EOS/EOT tokens leaked by some models (e.g. Ministral via llama.cpp) + if content: + for tok in ("<|im_end|>", "<|endoftext|>", "", "<|eot_id|>"): + content = content.replace(tok, "") + content = content.rstrip() or None + + return LLMResponse( + content=content, + tool_calls=parsed_tool_calls, + finish_reason=finish_reason, + usage=self._extract_usage(response_map), + reasoning_content=reasoning_content if isinstance(reasoning_content, str) else None, + ) + + if not response.choices: + return LLMResponse(content="Error: API returned empty choices.", finish_reason="error") + + choice = response.choices[0] + msg = choice.message + content = msg.content + finish_reason = choice.finish_reason + + raw_tool_calls: list[Any] = [] + for ch in response.choices: + m = ch.message + if hasattr(m, "tool_calls") and m.tool_calls: + raw_tool_calls.extend(m.tool_calls) + if ch.finish_reason in ("tool_calls", "stop"): + finish_reason = ch.finish_reason + if not content and m.content: + content = m.content + + tool_calls = [] + for tc in raw_tool_calls: + args = tc.function.arguments + if isinstance(args, str): + args = json_repair.loads(args) + ec, prov, fn_prov = _extract_tc_extras(tc) + tool_calls.append(ToolCallRequest( + id=_short_tool_id(), + name=tc.function.name, + arguments=args, + extra_content=ec, + provider_specific_fields=prov, + function_provider_specific_fields=fn_prov, + )) + + # Fallback: some local models embed tool calls in text content + if not tool_calls and content: + import re as _re + stripped = _re.sub(r".*?", "", content, flags=_re.DOTALL).strip() + tool_calls, extracted_content = OpenAICompatProvider._extract_text_tool_calls(stripped) + if tool_calls: + content = extracted_content + + # Strip EOS/EOT tokens leaked by some models (e.g. Ministral via llama.cpp) + if content: + for tok in ("<|im_end|>", "<|endoftext|>", "", "<|eot_id|>"): + content = content.replace(tok, "") + content = content.rstrip() or None + + return LLMResponse( + content=content, + tool_calls=tool_calls, + finish_reason=finish_reason or "stop", + usage=self._extract_usage(response), + reasoning_content=getattr(msg, "reasoning_content", None) or None, + ) + + @classmethod + def _parse_chunks(cls, chunks: list[Any]) -> LLMResponse: + content_parts: list[str] = [] + tc_bufs: dict[int, dict[str, Any]] = {} + finish_reason = "stop" + usage: dict[str, int] = {} + + def _accum_tc(tc: Any, idx_hint: int) -> None: + """Accumulate one streaming tool-call delta into *tc_bufs*.""" + tc_index: int = _get(tc, "index") if _get(tc, "index") is not None else idx_hint + buf = tc_bufs.setdefault(tc_index, { + "id": "", "name": "", "arguments": "", + "extra_content": None, "prov": None, "fn_prov": None, + }) + tc_id = _get(tc, "id") + if tc_id: + buf["id"] = str(tc_id) + fn = _get(tc, "function") + if fn is not None: + fn_name = _get(fn, "name") + if fn_name: + buf["name"] = str(fn_name) + fn_args = _get(fn, "arguments") + if fn_args: + buf["arguments"] += str(fn_args) + ec, prov, fn_prov = _extract_tc_extras(tc) + if ec: + buf["extra_content"] = ec + if prov: + buf["prov"] = prov + if fn_prov: + buf["fn_prov"] = fn_prov + + for chunk in chunks: + if isinstance(chunk, str): + content_parts.append(chunk) + continue + + chunk_map = cls._maybe_mapping(chunk) + if chunk_map is not None: + choices = chunk_map.get("choices") or [] + if not choices: + usage = cls._extract_usage(chunk_map) or usage + text = cls._extract_text_content( + chunk_map.get("content") or chunk_map.get("output_text") + ) + if text: + content_parts.append(text) + continue + choice = cls._maybe_mapping(choices[0]) or {} + if choice.get("finish_reason"): + finish_reason = str(choice["finish_reason"]) + delta = cls._maybe_mapping(choice.get("delta")) or {} + text = cls._extract_text_content(delta.get("content")) + if text: + content_parts.append(text) + for idx, tc in enumerate(delta.get("tool_calls") or []): + _accum_tc(tc, idx) + usage = cls._extract_usage(chunk_map) or usage + continue + + if not chunk.choices: + usage = cls._extract_usage(chunk) or usage + continue + choice = chunk.choices[0] + if choice.finish_reason: + finish_reason = choice.finish_reason + delta = choice.delta + if delta and delta.content: + content_parts.append(delta.content) + for tc in (delta.tool_calls or []) if delta else []: + _accum_tc(tc, getattr(tc, "index", 0)) + + return LLMResponse( + content="".join(content_parts) or None, + tool_calls=[ + ToolCallRequest( + id=b["id"] or _short_tool_id(), + name=b["name"], + arguments=json_repair.loads(b["arguments"]) if b["arguments"] else {}, + extra_content=b.get("extra_content"), + provider_specific_fields=b.get("prov"), + function_provider_specific_fields=b.get("fn_prov"), + ) + for b in tc_bufs.values() + ], + finish_reason=finish_reason, + usage=usage, + ) + + @staticmethod + def _handle_error(e: Exception) -> LLMResponse: + body = getattr(e, "doc", None) or getattr(getattr(e, "response", None), "text", None) + msg = f"Error: {body.strip()[:500]}" if body and body.strip() else f"Error calling LLM: {e}" + return LLMResponse(content=msg, finish_reason="error") + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + async def chat( + self, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None = None, + model: str | None = None, + max_tokens: int = 4096, + temperature: float = 0.7, + reasoning_effort: str | None = None, + tool_choice: str | dict[str, Any] | None = None, + ) -> LLMResponse: + kwargs = self._build_kwargs( + messages, tools, model, max_tokens, temperature, + reasoning_effort, tool_choice, + ) + try: + return self._parse(await self._client.chat.completions.create(**kwargs)) + except Exception as e: + return self._handle_error(e) + + async def chat_stream( + self, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None = None, + model: str | None = None, + max_tokens: int = 4096, + temperature: float = 0.7, + reasoning_effort: str | None = None, + tool_choice: str | dict[str, Any] | None = None, + on_content_delta: Callable[[str], Awaitable[None]] | None = None, + ) -> LLMResponse: + kwargs = self._build_kwargs( + messages, tools, model, max_tokens, temperature, + reasoning_effort, tool_choice, + ) + kwargs["stream"] = True + kwargs["stream_options"] = {"include_usage": True} + try: + stream = await self._client.chat.completions.create(**kwargs) + chunks: list[Any] = [] + async for chunk in stream: + chunks.append(chunk) + if on_content_delta and chunk.choices: + text = getattr(chunk.choices[0].delta, "content", None) + if text: + await on_content_delta(text) + return self._parse_chunks(chunks) + except Exception as e: + return self._handle_error(e) + + def get_default_model(self) -> str: + return self.default_model + + @staticmethod + def _extract_text_tool_calls(content: str) -> tuple[list[ToolCallRequest], str | None]: + """Parse tool calls embedded as JSON or XML in text content. + + Handles models that emit tool calls in text rather than structured fields: + Qwen3 (), Nemotron ([TOOL_CALLS]), Python code-blocks, legacy XML, raw JSON. + Returns (tool_calls, remaining_content). remaining_content is None when the full + content is a tool-call envelope. + """ + import re + + # Qwen3 format: {"name": "...", "arguments": {...}} + qwen_calls = [] + for m in re.finditer(r"(.*?)", content, re.DOTALL): + try: + obj = json_repair.loads(m.group(1).strip()) + if isinstance(obj, dict) and isinstance(obj.get("name"), str) and isinstance(obj.get("arguments"), dict): + qwen_calls.append(ToolCallRequest(id=_short_tool_id(), name=obj["name"], arguments=obj["arguments"])) + except Exception: + pass + if qwen_calls: + first = re.search(r"", content) + preamble = (content[:first.start()].strip() if first else None) or None + logger.info("_extract_text_tool_calls: extracted {} Qwen3 tool call(s): {}", len(qwen_calls), [c.name for c in qwen_calls]) + return qwen_calls, preamble + + # Nemotron/Orchestrator format: [TOOL_CALLS]name[ARGS]{json} + tc_calls = [] + for m in re.finditer(r"\[TOOL_CALLS\]([\w_\-]+)\[ARGS\](\{)", content, re.DOTALL): + try: + args = json_repair.loads(content[m.start(2):]) + if not isinstance(args, dict): + args = {} + tc_calls.append(ToolCallRequest(id=_short_tool_id(), name=m.group(1).strip(), arguments=args)) + except Exception: + pass + if tc_calls: + first = re.search(r"\[TOOL_CALLS\]", content) + preamble = (content[:first.start()].strip() if first else None) or None + logger.info("_extract_text_tool_calls: extracted {} [TOOL_CALLS] tool call(s): {}", len(tc_calls), [c.name for c in tc_calls]) + return tc_calls, preamble + + # Python code-block: ```python\nfn(key="val")\n``` + import ast as _ast + py_block = re.search(r"```(?:python)?\s*\n?([\w_]+\(.*?\))\s*\n?```", content, re.DOTALL) + if py_block: + call_src = py_block.group(1).strip() + fn_m = re.match(r"([\w_]+)\((.*)\)$", call_src, re.DOTALL) + if fn_m: + try: + dummy = f"_f({fn_m.group(2).strip()})" + tree = _ast.parse(dummy, mode="eval") + arguments = {kw.arg: _ast.literal_eval(kw.value) for kw in tree.body.keywords if kw.arg} # type: ignore[attr-defined] + py_call = ToolCallRequest(id=_short_tool_id(), name=fn_m.group(1), arguments=arguments) + logger.info("_extract_text_tool_calls: extracted Python code-block tool call: {}", fn_m.group(1)) + return [py_call], content[:py_block.start()].strip() or None + except Exception: + pass + + # Legacy XML: v... + xml_match = re.search(r"", content) + if xml_match: + calls = [] + xml_body = content[xml_match.start():] + for fn_m in re.finditer(r"(.*?)(?:||\Z)", xml_body, re.DOTALL): + arguments = {p.group(1): p.group(2).strip() + for p in re.finditer(r"(.*?)", fn_m.group(2), re.DOTALL)} + if arguments: + calls.append(ToolCallRequest(id=_short_tool_id(), name=fn_m.group(1), arguments=arguments)) + if calls: + logger.info("_extract_text_tool_calls: extracted {} XML tool call(s): {}", len(calls), [c.name for c in calls]) + return calls, content[:xml_match.start()].strip() or None + + # Raw JSON fallback + match = re.search(r"[{\[]", content) + if not match: + return [], content + try: + parsed = json_repair.loads(content[match.start():].strip()) + except Exception: + return [], content + candidates = parsed if isinstance(parsed, list) else [parsed] + calls = [ + ToolCallRequest(id=_short_tool_id(), name=obj["name"], arguments=obj["arguments"]) + for obj in candidates + if isinstance(obj, dict) and isinstance(obj.get("name"), str) and isinstance(obj.get("arguments"), dict) + ] + if calls: + logger.info("_extract_text_tool_calls: extracted {} JSON tool call(s): {}", len(calls), [c.name for c in calls]) + return calls, content[:match.start()].strip() or None + return [], content diff --git a/nanobot/providers/registry.py b/nanobot/providers/registry.py index 9b691dcac0a..e42e1f95e1c 100644 --- a/nanobot/providers/registry.py +++ b/nanobot/providers/registry.py @@ -4,7 +4,7 @@ Adding a new provider: 1. Add a ProviderSpec to PROVIDERS below. 2. Add a field to ProvidersConfig in config/schema.py. - Done. Env vars, prefixing, config matching, status display all derive from here. + Done. Env vars, config matching, status display all derive from here. Order matters β€” it controls match priority and fallback. Gateways first. Every entry writes out all fields so you can copy-paste as a template. @@ -12,9 +12,11 @@ from __future__ import annotations -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import Any +from pydantic.alias_generators import to_snake + @dataclass(frozen=True) class ProviderSpec: @@ -28,12 +30,12 @@ class ProviderSpec: # identity name: str # config field name, e.g. "dashscope" keywords: tuple[str, ...] # model-name keywords for matching (lowercase) - env_key: str # LiteLLM env var, e.g. "DASHSCOPE_API_KEY" + env_key: str # env var for API key, e.g. "DASHSCOPE_API_KEY" display_name: str = "" # shown in `nanobot status` - # model prefixing - litellm_prefix: str = "" # "dashscope" β†’ model becomes "dashscope/{model}" - skip_prefixes: tuple[str, ...] = () # don't prefix if model already starts with these + # which provider implementation to use + # "openai_compat" | "anthropic" | "azure_openai" | "openai_codex" + backend: str = "openai_compat" # extra env vars, e.g. (("ZHIPUAI_API_KEY", "{api_key}"),) env_extras: tuple[tuple[str, str], ...] = () @@ -43,19 +45,18 @@ class ProviderSpec: is_local: bool = False # local deployment (vLLM, Ollama) detect_by_key_prefix: str = "" # match api_key prefix, e.g. "sk-or-" detect_by_base_keyword: str = "" # match substring in api_base URL - default_api_base: str = "" # fallback base URL + default_api_base: str = "" # OpenAI-compatible base URL for this provider # gateway behavior - strip_model_prefix: bool = False # strip "provider/" before re-prefixing - litellm_kwargs: dict[str, Any] = field(default_factory=dict) # extra kwargs passed to LiteLLM + strip_model_prefix: bool = False # strip "provider/" before sending to gateway # per-model param overrides, e.g. (("kimi-k2.5", {"temperature": 1.0}),) model_overrides: tuple[tuple[str, dict[str, Any]], ...] = () # OAuth-based providers (e.g., OpenAI Codex) don't use API keys - is_oauth: bool = False # if True, uses OAuth flow instead of API key + is_oauth: bool = False - # Direct providers bypass LiteLLM entirely (e.g., CustomProvider) + # Direct providers skip API-key validation (user supplies everything) is_direct: bool = False # Provider supports cache_control on content blocks (e.g. Anthropic prompt caching) @@ -71,13 +72,13 @@ def label(self) -> str: # --------------------------------------------------------------------------- PROVIDERS: tuple[ProviderSpec, ...] = ( - # === Custom (direct OpenAI-compatible endpoint, bypasses LiteLLM) ====== + # === Custom (direct OpenAI-compatible endpoint) ======================== ProviderSpec( name="custom", keywords=(), env_key="", display_name="Custom", - litellm_prefix="", + backend="openai_compat", is_direct=True, ), @@ -87,7 +88,7 @@ def label(self) -> str: keywords=("azure", "azure-openai"), env_key="", display_name="Azure OpenAI", - litellm_prefix="", + backend="azure_openai", is_direct=True, ), # === Gateways (detected by api_key / api_base, not model name) ========= @@ -98,36 +99,26 @@ def label(self) -> str: keywords=("openrouter",), env_key="OPENROUTER_API_KEY", display_name="OpenRouter", - litellm_prefix="openrouter", # anthropic/claude-3 β†’ openrouter/anthropic/claude-3 - skip_prefixes=(), - env_extras=(), + backend="openai_compat", is_gateway=True, - is_local=False, detect_by_key_prefix="sk-or-", detect_by_base_keyword="openrouter", default_api_base="https://openrouter.ai/api/v1", - strip_model_prefix=False, - model_overrides=(), supports_prompt_caching=True, ), # AiHubMix: global gateway, OpenAI-compatible interface. - # strip_model_prefix=True: it doesn't understand "anthropic/claude-3", - # so we strip to bare "claude-3" then re-prefix as "openai/claude-3". + # strip_model_prefix=True: doesn't understand "anthropic/claude-3", + # strips to bare "claude-3". ProviderSpec( name="aihubmix", keywords=("aihubmix",), - env_key="OPENAI_API_KEY", # OpenAI-compatible + env_key="OPENAI_API_KEY", display_name="AiHubMix", - litellm_prefix="openai", # β†’ openai/{model} - skip_prefixes=(), - env_extras=(), + backend="openai_compat", is_gateway=True, - is_local=False, - detect_by_key_prefix="", detect_by_base_keyword="aihubmix", default_api_base="https://aihubmix.com/v1", - strip_model_prefix=True, # anthropic/claude-3 β†’ claude-3 β†’ openai/claude-3 - model_overrides=(), + strip_model_prefix=True, ), # SiliconFlow (η‘…εŸΊζ΅εŠ¨): OpenAI-compatible gateway, model names keep org prefix ProviderSpec( @@ -135,16 +126,10 @@ def label(self) -> str: keywords=("siliconflow",), env_key="OPENAI_API_KEY", display_name="SiliconFlow", - litellm_prefix="openai", - skip_prefixes=(), - env_extras=(), + backend="openai_compat", is_gateway=True, - is_local=False, - detect_by_key_prefix="", detect_by_base_keyword="siliconflow", default_api_base="https://api.siliconflow.cn/v1", - strip_model_prefix=False, - model_overrides=(), ), # VolcEngine (η«ε±±εΌ•ζ“Ž): OpenAI-compatible gateway, pay-per-use models @@ -153,16 +138,10 @@ def label(self) -> str: keywords=("volcengine", "volces", "ark"), env_key="OPENAI_API_KEY", display_name="VolcEngine", - litellm_prefix="volcengine", - skip_prefixes=(), - env_extras=(), + backend="openai_compat", is_gateway=True, - is_local=False, - detect_by_key_prefix="", detect_by_base_keyword="volces", default_api_base="https://ark.cn-beijing.volces.com/api/v3", - strip_model_prefix=False, - model_overrides=(), ), # VolcEngine Coding Plan (η«ε±±εΌ•ζ“Ž Coding Plan): same key as volcengine @@ -171,16 +150,10 @@ def label(self) -> str: keywords=("volcengine-plan",), env_key="OPENAI_API_KEY", display_name="VolcEngine Coding Plan", - litellm_prefix="volcengine", - skip_prefixes=(), - env_extras=(), + backend="openai_compat", is_gateway=True, - is_local=False, - detect_by_key_prefix="", - detect_by_base_keyword="", default_api_base="https://ark.cn-beijing.volces.com/api/coding/v3", strip_model_prefix=True, - model_overrides=(), ), # BytePlus: VolcEngine international, pay-per-use models @@ -189,16 +162,11 @@ def label(self) -> str: keywords=("byteplus",), env_key="OPENAI_API_KEY", display_name="BytePlus", - litellm_prefix="volcengine", - skip_prefixes=(), - env_extras=(), + backend="openai_compat", is_gateway=True, - is_local=False, - detect_by_key_prefix="", detect_by_base_keyword="bytepluses", default_api_base="https://ark.ap-southeast.bytepluses.com/api/v3", strip_model_prefix=True, - model_overrides=(), ), # BytePlus Coding Plan: same key as byteplus @@ -207,250 +175,146 @@ def label(self) -> str: keywords=("byteplus-plan",), env_key="OPENAI_API_KEY", display_name="BytePlus Coding Plan", - litellm_prefix="volcengine", - skip_prefixes=(), - env_extras=(), + backend="openai_compat", is_gateway=True, - is_local=False, - detect_by_key_prefix="", - detect_by_base_keyword="", default_api_base="https://ark.ap-southeast.bytepluses.com/api/coding/v3", strip_model_prefix=True, - model_overrides=(), ), # === Standard providers (matched by model-name keywords) =============== - # Anthropic: LiteLLM recognizes "claude-*" natively, no prefix needed. + # Anthropic: native Anthropic SDK ProviderSpec( name="anthropic", keywords=("anthropic", "claude"), env_key="ANTHROPIC_API_KEY", display_name="Anthropic", - litellm_prefix="", - skip_prefixes=(), - env_extras=(), - is_gateway=False, - is_local=False, - detect_by_key_prefix="", - detect_by_base_keyword="", - default_api_base="", - strip_model_prefix=False, - model_overrides=(), + backend="anthropic", supports_prompt_caching=True, ), - # OpenAI: LiteLLM recognizes "gpt-*" natively, no prefix needed. + # OpenAI: SDK default base URL (no override needed) ProviderSpec( name="openai", keywords=("openai", "gpt"), env_key="OPENAI_API_KEY", display_name="OpenAI", - litellm_prefix="", - skip_prefixes=(), - env_extras=(), - is_gateway=False, - is_local=False, - detect_by_key_prefix="", - detect_by_base_keyword="", - default_api_base="", - strip_model_prefix=False, - model_overrides=(), + backend="openai_compat", ), - # OpenAI Codex: uses OAuth, not API key. + # OpenAI Codex: OAuth-based, dedicated provider ProviderSpec( name="openai_codex", keywords=("openai-codex",), - env_key="", # OAuth-based, no API key + env_key="", display_name="OpenAI Codex", - litellm_prefix="", # Not routed through LiteLLM - skip_prefixes=(), - env_extras=(), - is_gateway=False, - is_local=False, - detect_by_key_prefix="", + backend="openai_codex", detect_by_base_keyword="codex", default_api_base="https://chatgpt.com/backend-api", - strip_model_prefix=False, - model_overrides=(), - is_oauth=True, # OAuth-based authentication + is_oauth=True, ), - # Github Copilot: uses OAuth, not API key. + # GitHub Copilot: OAuth-based ProviderSpec( name="github_copilot", keywords=("github_copilot", "copilot"), - env_key="", # OAuth-based, no API key + env_key="", display_name="Github Copilot", - litellm_prefix="github_copilot", # github_copilot/model β†’ github_copilot/model - skip_prefixes=("github_copilot/",), - env_extras=(), - is_gateway=False, - is_local=False, - detect_by_key_prefix="", - detect_by_base_keyword="", - default_api_base="", - strip_model_prefix=False, - model_overrides=(), - is_oauth=True, # OAuth-based authentication + backend="openai_compat", + default_api_base="https://api.githubcopilot.com", + is_oauth=True, ), - # DeepSeek: needs "deepseek/" prefix for LiteLLM routing. + # DeepSeek: OpenAI-compatible at api.deepseek.com ProviderSpec( name="deepseek", keywords=("deepseek",), env_key="DEEPSEEK_API_KEY", display_name="DeepSeek", - litellm_prefix="deepseek", # deepseek-chat β†’ deepseek/deepseek-chat - skip_prefixes=("deepseek/",), # avoid double-prefix - env_extras=(), - is_gateway=False, - is_local=False, - detect_by_key_prefix="", - detect_by_base_keyword="", - default_api_base="", - strip_model_prefix=False, - model_overrides=(), + backend="openai_compat", + default_api_base="https://api.deepseek.com", ), - # Gemini: needs "gemini/" prefix for LiteLLM. + # Gemini: Google's OpenAI-compatible endpoint ProviderSpec( name="gemini", keywords=("gemini",), env_key="GEMINI_API_KEY", display_name="Gemini", - litellm_prefix="gemini", # gemini-pro β†’ gemini/gemini-pro - skip_prefixes=("gemini/",), # avoid double-prefix - env_extras=(), - is_gateway=False, - is_local=False, - detect_by_key_prefix="", - detect_by_base_keyword="", - default_api_base="", - strip_model_prefix=False, - model_overrides=(), + backend="openai_compat", + default_api_base="https://generativelanguage.googleapis.com/v1beta/openai/", ), - # Zhipu: LiteLLM uses "zai/" prefix. - # Also mirrors key to ZHIPUAI_API_KEY (some LiteLLM paths check that). - # skip_prefixes: don't add "zai/" when already routed via gateway. + # Zhipu (ζ™Ίθ°±): OpenAI-compatible at open.bigmodel.cn ProviderSpec( name="zhipu", keywords=("zhipu", "glm", "zai"), env_key="ZAI_API_KEY", display_name="Zhipu AI", - litellm_prefix="zai", # glm-4 β†’ zai/glm-4 - skip_prefixes=("zhipu/", "zai/", "openrouter/", "hosted_vllm/"), + backend="openai_compat", env_extras=(("ZHIPUAI_API_KEY", "{api_key}"),), - is_gateway=False, - is_local=False, - detect_by_key_prefix="", - detect_by_base_keyword="", - default_api_base="", - strip_model_prefix=False, - model_overrides=(), + default_api_base="https://open.bigmodel.cn/api/paas/v4", ), - # DashScope: Qwen models, needs "dashscope/" prefix. + # DashScope (ι€šδΉ‰): Qwen models, OpenAI-compatible endpoint ProviderSpec( name="dashscope", keywords=("qwen", "dashscope"), env_key="DASHSCOPE_API_KEY", display_name="DashScope", - litellm_prefix="dashscope", # qwen-max β†’ dashscope/qwen-max - skip_prefixes=("dashscope/", "openrouter/"), - env_extras=(), - is_gateway=False, - is_local=False, - detect_by_key_prefix="", - detect_by_base_keyword="", - default_api_base="", - strip_model_prefix=False, - model_overrides=(), + backend="openai_compat", + default_api_base="https://dashscope.aliyuncs.com/compatible-mode/v1", ), - # Moonshot: Kimi models, needs "moonshot/" prefix. - # LiteLLM requires MOONSHOT_API_BASE env var to find the endpoint. - # Kimi K2.5 API enforces temperature >= 1.0. + # Moonshot (ζœˆδΉ‹ζš—ι’): Kimi models. K2.5 enforces temperature >= 1.0. ProviderSpec( name="moonshot", keywords=("moonshot", "kimi"), env_key="MOONSHOT_API_KEY", display_name="Moonshot", - litellm_prefix="moonshot", # kimi-k2.5 β†’ moonshot/kimi-k2.5 - skip_prefixes=("moonshot/", "openrouter/"), - env_extras=(("MOONSHOT_API_BASE", "{api_base}"),), - is_gateway=False, - is_local=False, - detect_by_key_prefix="", - detect_by_base_keyword="", - default_api_base="https://api.moonshot.ai/v1", # intl; use api.moonshot.cn for China - strip_model_prefix=False, + backend="openai_compat", + default_api_base="https://api.moonshot.ai/v1", model_overrides=(("kimi-k2.5", {"temperature": 1.0}),), ), - # MiniMax: needs "minimax/" prefix for LiteLLM routing. - # Uses OpenAI-compatible API at api.minimax.io/v1. + # MiniMax: OpenAI-compatible API ProviderSpec( name="minimax", keywords=("minimax",), env_key="MINIMAX_API_KEY", display_name="MiniMax", - litellm_prefix="minimax", # MiniMax-M2.1 β†’ minimax/MiniMax-M2.1 - skip_prefixes=("minimax/", "openrouter/"), - env_extras=(), - is_gateway=False, - is_local=False, - detect_by_key_prefix="", - detect_by_base_keyword="", + backend="openai_compat", default_api_base="https://api.minimax.io/v1", - strip_model_prefix=False, - model_overrides=(), ), - # Mistral AI: OpenAI-compatible API at api.mistral.ai/v1. + # Mistral AI: OpenAI-compatible API ProviderSpec( name="mistral", keywords=("mistral",), env_key="MISTRAL_API_KEY", display_name="Mistral", - litellm_prefix="mistral", # mistral-large-latest β†’ mistral/mistral-large-latest - skip_prefixes=("mistral/",), # avoid double-prefix - env_extras=(), - is_gateway=False, - is_local=False, - detect_by_key_prefix="", - detect_by_base_keyword="", + backend="openai_compat", default_api_base="https://api.mistral.ai/v1", - strip_model_prefix=False, - model_overrides=(), + ), + # Step Fun (ι˜Άθ·ƒζ˜ŸθΎ°): OpenAI-compatible API + ProviderSpec( + name="stepfun", + keywords=("stepfun", "step"), + env_key="STEPFUN_API_KEY", + display_name="Step Fun", + backend="openai_compat", + default_api_base="https://api.stepfun.com/v1", ), # === Local deployment (matched by config key, NOT by api_base) ========= - # vLLM / any OpenAI-compatible local server. - # Detected when config key is "vllm" (provider_name="vllm"). + # vLLM / any OpenAI-compatible local server ProviderSpec( name="vllm", keywords=("vllm",), env_key="HOSTED_VLLM_API_KEY", display_name="vLLM/Local", - litellm_prefix="hosted_vllm", # Llama-3-8B β†’ hosted_vllm/Llama-3-8B - skip_prefixes=(), - env_extras=(), - is_gateway=False, + backend="openai_compat", is_local=True, - detect_by_key_prefix="", - detect_by_base_keyword="", - default_api_base="", # user must provide in config - strip_model_prefix=False, - model_overrides=(), ), - # === Ollama (local, OpenAI-compatible) =================================== + # Ollama (local, OpenAI-compatible) ProviderSpec( name="ollama", keywords=("ollama", "nemotron"), env_key="OLLAMA_API_KEY", display_name="Ollama", - litellm_prefix="ollama_chat", - skip_prefixes=("ollama/", "ollama_chat/"), - env_extras=(), - is_gateway=False, + backend="openai_compat", is_local=True, - detect_by_key_prefix="", detect_by_base_keyword="11434", - default_api_base="http://localhost:11434", - strip_model_prefix=False, - model_overrides=(), + default_api_base="http://localhost:11434/v1", ), # === OpenVINO Model Server (direct, local, OpenAI-compatible at /v3) === ProviderSpec( @@ -458,29 +322,20 @@ def label(self) -> str: keywords=("openvino", "ovms"), env_key="", display_name="OpenVINO Model Server", - litellm_prefix="", + backend="openai_compat", is_direct=True, is_local=True, default_api_base="http://localhost:8000/v3", ), # === Auxiliary (not a primary LLM provider) ============================ - # Groq: mainly used for Whisper voice transcription, also usable for LLM. - # Needs "groq/" prefix for LiteLLM routing. Placed last β€” it rarely wins fallback. + # Groq: mainly used for Whisper voice transcription, also usable for LLM ProviderSpec( name="groq", keywords=("groq",), env_key="GROQ_API_KEY", display_name="Groq", - litellm_prefix="groq", # llama3-8b-8192 β†’ groq/llama3-8b-8192 - skip_prefixes=("groq/",), # avoid double-prefix - env_extras=(), - is_gateway=False, - is_local=False, - detect_by_key_prefix="", - detect_by_base_keyword="", - default_api_base="", - strip_model_prefix=False, - model_overrides=(), + backend="openai_compat", + default_api_base="https://api.groq.com/openai/v1", ), ) @@ -490,62 +345,10 @@ def label(self) -> str: # --------------------------------------------------------------------------- -def find_by_model(model: str) -> ProviderSpec | None: - """Match a standard provider by model-name keyword (case-insensitive). - Skips gateways/local β€” those are matched by api_key/api_base instead.""" - model_lower = model.lower() - model_normalized = model_lower.replace("-", "_") - model_prefix = model_lower.split("/", 1)[0] if "/" in model_lower else "" - normalized_prefix = model_prefix.replace("-", "_") - std_specs = [s for s in PROVIDERS if not s.is_gateway and not s.is_local] - - # Prefer explicit provider prefix β€” prevents `github-copilot/...codex` matching openai_codex. - for spec in std_specs: - if model_prefix and normalized_prefix == spec.name: - return spec - - for spec in std_specs: - if any( - kw in model_lower or kw.replace("-", "_") in model_normalized for kw in spec.keywords - ): - return spec - return None - - -def find_gateway( - provider_name: str | None = None, - api_key: str | None = None, - api_base: str | None = None, -) -> ProviderSpec | None: - """Detect gateway/local provider. - - Priority: - 1. provider_name β€” if it maps to a gateway/local spec, use it directly. - 2. api_key prefix β€” e.g. "sk-or-" β†’ OpenRouter. - 3. api_base keyword β€” e.g. "aihubmix" in URL β†’ AiHubMix. - - A standard provider with a custom api_base (e.g. DeepSeek behind a proxy) - will NOT be mistaken for vLLM β€” the old fallback is gone. - """ - # 1. Direct match by config key - if provider_name: - spec = find_by_name(provider_name) - if spec and (spec.is_gateway or spec.is_local): - return spec - - # 2. Auto-detect by api_key prefix / api_base keyword - for spec in PROVIDERS: - if spec.detect_by_key_prefix and api_key and api_key.startswith(spec.detect_by_key_prefix): - return spec - if spec.detect_by_base_keyword and api_base and spec.detect_by_base_keyword in api_base: - return spec - - return None - - def find_by_name(name: str) -> ProviderSpec | None: """Find a provider spec by config field name, e.g. "dashscope".""" + normalized = to_snake(name.replace("-", "_")) for spec in PROVIDERS: - if spec.name == name: + if spec.name == normalized: return spec return None diff --git a/nanobot/utils/helpers.py b/nanobot/utils/helpers.py index 06441b878bc..afd0bd1b20f 100644 --- a/nanobot/utils/helpers.py +++ b/nanobot/utils/helpers.py @@ -53,11 +53,24 @@ def timestamp() -> str: return datetime.now().isoformat() -def current_time_str() -> str: - """Human-readable current time with weekday and timezone, e.g. '2026-03-15 22:30 (Saturday) (CST)'.""" - now = datetime.now().strftime("%Y-%m-%d %H:%M (%A)") - tz = time.strftime("%Z") or "UTC" - return f"{now} ({tz})" +def current_time_str(timezone: str | None = None) -> str: + """Human-readable current time with weekday and UTC offset. + + When *timezone* is a valid IANA name (e.g. ``"Asia/Shanghai"``), the time + is converted to that zone. Otherwise falls back to the host local time. + """ + from zoneinfo import ZoneInfo + + try: + tz = ZoneInfo(timezone) if timezone else None + except (KeyError, Exception): + tz = None + + now = datetime.now(tz=tz) if tz else datetime.now().astimezone() + offset = now.strftime("%z") + offset_fmt = f"{offset[:3]}:{offset[3:]}" if len(offset) == 5 else offset + tz_name = timezone or (time.strftime("%Z") or "UTC") + return f"{now.strftime('%Y-%m-%d %H:%M (%A)')} ({tz_name}, UTC{offset_fmt})" _UNSAFE_CHARS = re.compile(r'[<>:"/\\|?*]') diff --git a/pyproject.toml b/pyproject.toml index b7657206858..501a6bb45ad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,7 @@ classifiers = [ dependencies = [ "typer>=0.20.0,<1.0.0", - "litellm>=1.82.1,<2.0.0", + "anthropic>=0.45.0,<1.0.0", "pydantic>=2.12.0,<3.0.0", "pydantic-settings>=2.12.0,<3.0.0", "websockets>=16.0,<17.0", @@ -70,10 +70,8 @@ langsmith = [ dev = [ "pytest>=9.0.0,<10.0.0", "pytest-asyncio>=1.3.0,<2.0.0", + "pytest-cov>=6.0.0,<7.0.0", "ruff>=0.1.0", - "matrix-nio[e2e]>=0.25.2", - "mistune>=3.0.0,<4.0.0", - "nh3>=0.2.17,<1.0.0", ] [project.scripts] @@ -122,3 +120,16 @@ ignore = ["E501"] [tool.pytest.ini_options] asyncio_mode = "auto" testpaths = ["tests"] + +[tool.coverage.run] +source = ["nanobot"] +omit = ["tests/*", "**/tests/*"] + +[tool.coverage.report] +exclude_lines = [ + "pragma: no cover", + "def __repr__", + "raise NotImplementedError", + "if __name__ == .__main__.:", + "if TYPE_CHECKING:", +] diff --git a/tests/test_consolidate_offset.py b/tests/agent/test_consolidate_offset.py similarity index 100% rename from tests/test_consolidate_offset.py rename to tests/agent/test_consolidate_offset.py diff --git a/tests/test_context_prompt_cache.py b/tests/agent/test_context_prompt_cache.py similarity index 100% rename from tests/test_context_prompt_cache.py rename to tests/agent/test_context_prompt_cache.py diff --git a/tests/test_evaluator.py b/tests/agent/test_evaluator.py similarity index 100% rename from tests/test_evaluator.py rename to tests/agent/test_evaluator.py diff --git a/tests/agent/test_gemini_thought_signature.py b/tests/agent/test_gemini_thought_signature.py new file mode 100644 index 00000000000..320c1ecd2cf --- /dev/null +++ b/tests/agent/test_gemini_thought_signature.py @@ -0,0 +1,200 @@ +"""Tests for Gemini thought_signature round-trip through extra_content. + +The Gemini OpenAI-compatibility API returns tool calls with an extra_content +field: ``{"google": {"thought_signature": "..."}}``. This MUST survive the +parse β†’ serialize round-trip so the model can continue reasoning. +""" + +from types import SimpleNamespace +from unittest.mock import patch + +from nanobot.providers.base import ToolCallRequest +from nanobot.providers.openai_compat_provider import OpenAICompatProvider + + +GEMINI_EXTRA = {"google": {"thought_signature": "sig-abc-123"}} + + +# ── ToolCallRequest serialization ────────────────────────────────────── + +def test_tool_call_request_serializes_extra_content() -> None: + tc = ToolCallRequest( + id="abc123xyz", + name="read_file", + arguments={"path": "todo.md"}, + extra_content=GEMINI_EXTRA, + ) + + payload = tc.to_openai_tool_call() + + assert payload["extra_content"] == GEMINI_EXTRA + assert payload["function"]["arguments"] == '{"path": "todo.md"}' + + +def test_tool_call_request_serializes_provider_fields() -> None: + tc = ToolCallRequest( + id="abc123xyz", + name="read_file", + arguments={"path": "todo.md"}, + provider_specific_fields={"custom_key": "custom_val"}, + function_provider_specific_fields={"inner": "value"}, + ) + + payload = tc.to_openai_tool_call() + + assert payload["provider_specific_fields"] == {"custom_key": "custom_val"} + assert payload["function"]["provider_specific_fields"] == {"inner": "value"} + + +def test_tool_call_request_omits_absent_extras() -> None: + tc = ToolCallRequest(id="x", name="fn", arguments={}) + payload = tc.to_openai_tool_call() + + assert "extra_content" not in payload + assert "provider_specific_fields" not in payload + assert "provider_specific_fields" not in payload["function"] + + +# ── _parse: SDK-object branch ────────────────────────────────────────── + +def _make_sdk_response_with_extra_content(): + """Simulate a Gemini response via the OpenAI SDK (SimpleNamespace).""" + fn = SimpleNamespace(name="get_weather", arguments='{"city":"Tokyo"}') + tc = SimpleNamespace( + id="call_1", + index=0, + type="function", + function=fn, + extra_content=GEMINI_EXTRA, + ) + msg = SimpleNamespace( + content=None, + tool_calls=[tc], + reasoning_content=None, + ) + choice = SimpleNamespace(message=msg, finish_reason="tool_calls") + usage = SimpleNamespace(prompt_tokens=10, completion_tokens=5, total_tokens=15) + return SimpleNamespace(choices=[choice], usage=usage) + + +def test_parse_sdk_object_preserves_extra_content() -> None: + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): + provider = OpenAICompatProvider() + + result = provider._parse(_make_sdk_response_with_extra_content()) + + assert len(result.tool_calls) == 1 + tc = result.tool_calls[0] + assert tc.name == "get_weather" + assert tc.extra_content == GEMINI_EXTRA + + payload = tc.to_openai_tool_call() + assert payload["extra_content"] == GEMINI_EXTRA + + +# ── _parse: dict/mapping branch ─────────────────────────────────────── + +def test_parse_dict_preserves_extra_content() -> None: + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): + provider = OpenAICompatProvider() + + response_dict = { + "choices": [{ + "message": { + "content": None, + "tool_calls": [{ + "id": "call_1", + "type": "function", + "function": {"name": "get_weather", "arguments": '{"city":"Tokyo"}'}, + "extra_content": GEMINI_EXTRA, + }], + }, + "finish_reason": "tool_calls", + }], + "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}, + } + + result = provider._parse(response_dict) + + assert len(result.tool_calls) == 1 + tc = result.tool_calls[0] + assert tc.name == "get_weather" + assert tc.extra_content == GEMINI_EXTRA + + payload = tc.to_openai_tool_call() + assert payload["extra_content"] == GEMINI_EXTRA + + +# ── _parse_chunks: streaming round-trip ─────────────────────────────── + +def test_parse_chunks_sdk_preserves_extra_content() -> None: + fn_delta = SimpleNamespace(name="get_weather", arguments='{"city":"Tokyo"}') + tc_delta = SimpleNamespace( + id="call_1", + index=0, + function=fn_delta, + extra_content=GEMINI_EXTRA, + ) + delta = SimpleNamespace(content=None, tool_calls=[tc_delta]) + choice = SimpleNamespace(finish_reason="tool_calls", delta=delta) + chunk = SimpleNamespace(choices=[choice], usage=None) + + result = OpenAICompatProvider._parse_chunks([chunk]) + + assert len(result.tool_calls) == 1 + tc = result.tool_calls[0] + assert tc.extra_content == GEMINI_EXTRA + + payload = tc.to_openai_tool_call() + assert payload["extra_content"] == GEMINI_EXTRA + + +def test_parse_chunks_dict_preserves_extra_content() -> None: + chunk = { + "choices": [{ + "finish_reason": "tool_calls", + "delta": { + "content": None, + "tool_calls": [{ + "index": 0, + "id": "call_1", + "function": {"name": "get_weather", "arguments": '{"city":"Tokyo"}'}, + "extra_content": GEMINI_EXTRA, + }], + }, + }], + } + + result = OpenAICompatProvider._parse_chunks([chunk]) + + assert len(result.tool_calls) == 1 + tc = result.tool_calls[0] + assert tc.extra_content == GEMINI_EXTRA + + payload = tc.to_openai_tool_call() + assert payload["extra_content"] == GEMINI_EXTRA + + +# ── Model switching: stale extras shouldn't break other providers ───── + +def test_stale_extra_content_in_tool_calls_survives_sanitize() -> None: + """When switching from Gemini to OpenAI, extra_content inside tool_calls + should survive message sanitization (it lives inside the tool_call dict, + not at message level, so it bypasses _ALLOWED_MSG_KEYS filtering).""" + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): + provider = OpenAICompatProvider() + + messages = [{ + "role": "assistant", + "content": None, + "tool_calls": [{ + "id": "call_1", + "type": "function", + "function": {"name": "fn", "arguments": "{}"}, + "extra_content": GEMINI_EXTRA, + }], + }] + + sanitized = provider._sanitize_messages(messages) + + assert sanitized[0]["tool_calls"][0]["extra_content"] == GEMINI_EXTRA diff --git a/tests/test_heartbeat_service.py b/tests/agent/test_heartbeat_service.py similarity index 100% rename from tests/test_heartbeat_service.py rename to tests/agent/test_heartbeat_service.py diff --git a/tests/test_loop_consolidation_tokens.py b/tests/agent/test_loop_consolidation_tokens.py similarity index 100% rename from tests/test_loop_consolidation_tokens.py rename to tests/agent/test_loop_consolidation_tokens.py diff --git a/tests/agent/test_loop_cron_timezone.py b/tests/agent/test_loop_cron_timezone.py new file mode 100644 index 00000000000..7738d304309 --- /dev/null +++ b/tests/agent/test_loop_cron_timezone.py @@ -0,0 +1,27 @@ +from pathlib import Path +from unittest.mock import MagicMock + +from nanobot.agent.loop import AgentLoop +from nanobot.agent.tools.cron import CronTool +from nanobot.bus.queue import MessageBus +from nanobot.cron.service import CronService + + +def test_agent_loop_registers_cron_tool_with_configured_timezone(tmp_path: Path) -> None: + bus = MessageBus() + provider = MagicMock() + provider.get_default_model.return_value = "test-model" + + loop = AgentLoop( + bus=bus, + provider=provider, + workspace=tmp_path, + model="test-model", + cron_service=CronService(tmp_path / "cron" / "jobs.json"), + timezone="Asia/Shanghai", + ) + + cron_tool = loop.tools.get("cron") + + assert isinstance(cron_tool, CronTool) + assert cron_tool._default_timezone == "Asia/Shanghai" diff --git a/tests/test_loop_save_turn.py b/tests/agent/test_loop_save_turn.py similarity index 100% rename from tests/test_loop_save_turn.py rename to tests/agent/test_loop_save_turn.py diff --git a/tests/test_memory_consolidation_types.py b/tests/agent/test_memory_consolidation_types.py similarity index 99% rename from tests/test_memory_consolidation_types.py rename to tests/agent/test_memory_consolidation_types.py index d63cc90475c..203e39a90f4 100644 --- a/tests/test_memory_consolidation_types.py +++ b/tests/agent/test_memory_consolidation_types.py @@ -380,7 +380,7 @@ async def test_tool_choice_fallback_on_unsupported_error(self, tmp_path: Path) - """Forced tool_choice rejected by provider -> retry with auto and succeed.""" store = MemoryStore(tmp_path) error_resp = LLMResponse( - content="Error calling LLM: litellm.BadRequestError: " + content="Error calling LLM: BadRequestError: " "The tool_choice parameter does not support being set to required or object", finish_reason="error", tool_calls=[], diff --git a/tests/test_onboard_logic.py b/tests/agent/test_onboard_logic.py similarity index 100% rename from tests/test_onboard_logic.py rename to tests/agent/test_onboard_logic.py diff --git a/tests/test_session_manager_history.py b/tests/agent/test_session_manager_history.py similarity index 100% rename from tests/test_session_manager_history.py rename to tests/agent/test_session_manager_history.py diff --git a/tests/test_skill_creator_scripts.py b/tests/agent/test_skill_creator_scripts.py similarity index 100% rename from tests/test_skill_creator_scripts.py rename to tests/agent/test_skill_creator_scripts.py diff --git a/tests/test_task_cancel.py b/tests/agent/test_task_cancel.py similarity index 100% rename from tests/test_task_cancel.py rename to tests/agent/test_task_cancel.py diff --git a/tests/test_base_channel.py b/tests/channels/test_base_channel.py similarity index 100% rename from tests/test_base_channel.py rename to tests/channels/test_base_channel.py diff --git a/tests/channels/test_channel_plugins.py b/tests/channels/test_channel_plugins.py new file mode 100644 index 00000000000..a0b458a084e --- /dev/null +++ b/tests/channels/test_channel_plugins.py @@ -0,0 +1,880 @@ +"""Tests for channel plugin discovery, merging, and config compatibility.""" + +from __future__ import annotations + +import asyncio +from types import SimpleNamespace +from unittest.mock import AsyncMock, patch + +import pytest + +from nanobot.bus.events import OutboundMessage +from nanobot.bus.queue import MessageBus +from nanobot.channels.base import BaseChannel +from nanobot.channels.manager import ChannelManager +from nanobot.config.schema import ChannelsConfig + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +class _FakePlugin(BaseChannel): + name = "fakeplugin" + display_name = "Fake Plugin" + + def __init__(self, config, bus): + super().__init__(config, bus) + self.login_calls: list[bool] = [] + + async def start(self) -> None: + pass + + async def stop(self) -> None: + pass + + async def send(self, msg: OutboundMessage) -> None: + pass + + async def login(self, force: bool = False) -> bool: + self.login_calls.append(force) + return True + + +class _FakeTelegram(BaseChannel): + """Plugin that tries to shadow built-in telegram.""" + name = "telegram" + display_name = "Fake Telegram" + + async def start(self) -> None: + pass + + async def stop(self) -> None: + pass + + async def send(self, msg: OutboundMessage) -> None: + pass + + +def _make_entry_point(name: str, cls: type): + """Create a mock entry point that returns *cls* on load().""" + ep = SimpleNamespace(name=name, load=lambda _cls=cls: _cls) + return ep + + +# --------------------------------------------------------------------------- +# ChannelsConfig extra="allow" +# --------------------------------------------------------------------------- + +def test_channels_config_accepts_unknown_keys(): + cfg = ChannelsConfig.model_validate({ + "myplugin": {"enabled": True, "token": "abc"}, + }) + extra = cfg.model_extra + assert extra is not None + assert extra["myplugin"]["enabled"] is True + assert extra["myplugin"]["token"] == "abc" + + +def test_channels_config_getattr_returns_extra(): + cfg = ChannelsConfig.model_validate({"myplugin": {"enabled": True}}) + section = getattr(cfg, "myplugin", None) + assert isinstance(section, dict) + assert section["enabled"] is True + + +def test_channels_config_builtin_fields_removed(): + """After decoupling, ChannelsConfig has no explicit channel fields.""" + cfg = ChannelsConfig() + assert not hasattr(cfg, "telegram") + assert cfg.send_progress is True + assert cfg.send_tool_hints is False + + +# --------------------------------------------------------------------------- +# discover_plugins +# --------------------------------------------------------------------------- + +_EP_TARGET = "importlib.metadata.entry_points" + + +def test_discover_plugins_loads_entry_points(): + from nanobot.channels.registry import discover_plugins + + ep = _make_entry_point("line", _FakePlugin) + with patch(_EP_TARGET, return_value=[ep]): + result = discover_plugins() + + assert "line" in result + assert result["line"] is _FakePlugin + + +def test_discover_plugins_handles_load_error(): + from nanobot.channels.registry import discover_plugins + + def _boom(): + raise RuntimeError("broken") + + ep = SimpleNamespace(name="broken", load=_boom) + with patch(_EP_TARGET, return_value=[ep]): + result = discover_plugins() + + assert "broken" not in result + + +# --------------------------------------------------------------------------- +# discover_all β€” merge & priority +# --------------------------------------------------------------------------- + +def test_discover_all_includes_builtins(): + from nanobot.channels.registry import discover_all, discover_channel_names + + with patch(_EP_TARGET, return_value=[]): + result = discover_all() + + # discover_all() only returns channels that are actually available (dependencies installed) + # discover_channel_names() returns all built-in channel names + # So we check that all actually loaded channels are in the result + for name in result: + assert name in discover_channel_names() + + +def test_discover_all_includes_external_plugin(): + from nanobot.channels.registry import discover_all + + ep = _make_entry_point("line", _FakePlugin) + with patch(_EP_TARGET, return_value=[ep]): + result = discover_all() + + assert "line" in result + assert result["line"] is _FakePlugin + + +def test_discover_all_builtin_shadows_plugin(): + from nanobot.channels.registry import discover_all + + ep = _make_entry_point("telegram", _FakeTelegram) + with patch(_EP_TARGET, return_value=[ep]): + result = discover_all() + + assert "telegram" in result + assert result["telegram"] is not _FakeTelegram + + +# --------------------------------------------------------------------------- +# Manager _init_channels with dict config (plugin scenario) +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_manager_loads_plugin_from_dict_config(): + """ChannelManager should instantiate a plugin channel from a raw dict config.""" + from nanobot.channels.manager import ChannelManager + + fake_config = SimpleNamespace( + channels=ChannelsConfig.model_validate({ + "fakeplugin": {"enabled": True, "allowFrom": ["*"]}, + }), + providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), + ) + + with patch( + "nanobot.channels.registry.discover_all", + return_value={"fakeplugin": _FakePlugin}, + ): + mgr = ChannelManager.__new__(ChannelManager) + mgr.config = fake_config + mgr.bus = MessageBus() + mgr.channels = {} + mgr._dispatch_task = None + mgr._init_channels() + + assert "fakeplugin" in mgr.channels + assert isinstance(mgr.channels["fakeplugin"], _FakePlugin) + + +def test_channels_login_uses_discovered_plugin_class(monkeypatch): + from nanobot.cli.commands import app + from nanobot.config.schema import Config + from typer.testing import CliRunner + + runner = CliRunner() + seen: dict[str, object] = {} + + class _LoginPlugin(_FakePlugin): + display_name = "Login Plugin" + + async def login(self, force: bool = False) -> bool: + seen["force"] = force + seen["config"] = self.config + return True + + monkeypatch.setattr("nanobot.config.loader.load_config", lambda: Config()) + monkeypatch.setattr( + "nanobot.channels.registry.discover_all", + lambda: {"fakeplugin": _LoginPlugin}, + ) + + result = runner.invoke(app, ["channels", "login", "fakeplugin", "--force"]) + + assert result.exit_code == 0 + assert seen["force"] is True + + +@pytest.mark.asyncio +async def test_manager_skips_disabled_plugin(): + fake_config = SimpleNamespace( + channels=ChannelsConfig.model_validate({ + "fakeplugin": {"enabled": False}, + }), + providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), + ) + + with patch( + "nanobot.channels.registry.discover_all", + return_value={"fakeplugin": _FakePlugin}, + ): + mgr = ChannelManager.__new__(ChannelManager) + mgr.config = fake_config + mgr.bus = MessageBus() + mgr.channels = {} + mgr._dispatch_task = None + mgr._init_channels() + + assert "fakeplugin" not in mgr.channels + + +# --------------------------------------------------------------------------- +# Built-in channel default_config() and dict->Pydantic conversion +# --------------------------------------------------------------------------- + +def test_builtin_channel_default_config(): + """Built-in channels expose default_config() returning a dict with 'enabled': False.""" + from nanobot.channels.telegram import TelegramChannel + cfg = TelegramChannel.default_config() + assert isinstance(cfg, dict) + assert cfg["enabled"] is False + assert "token" in cfg + + +def test_builtin_channel_init_from_dict(): + """Built-in channels accept a raw dict and convert to Pydantic internally.""" + from nanobot.channels.telegram import TelegramChannel + bus = MessageBus() + ch = TelegramChannel({"enabled": False, "token": "test-tok", "allowFrom": ["*"]}, bus) + assert ch.config.token == "test-tok" + assert ch.config.allow_from == ["*"] + + +def test_channels_config_send_max_retries_default(): + """ChannelsConfig should have send_max_retries with default value of 3.""" + cfg = ChannelsConfig() + assert hasattr(cfg, 'send_max_retries') + assert cfg.send_max_retries == 3 + + +def test_channels_config_send_max_retries_upper_bound(): + """send_max_retries should be bounded to prevent resource exhaustion.""" + from pydantic import ValidationError + + # Value too high should be rejected + with pytest.raises(ValidationError): + ChannelsConfig(send_max_retries=100) + + # Negative should be rejected + with pytest.raises(ValidationError): + ChannelsConfig(send_max_retries=-1) + + # Boundary values should be allowed + cfg_min = ChannelsConfig(send_max_retries=0) + assert cfg_min.send_max_retries == 0 + + cfg_max = ChannelsConfig(send_max_retries=10) + assert cfg_max.send_max_retries == 10 + + # Value above upper bound should be rejected + with pytest.raises(ValidationError): + ChannelsConfig(send_max_retries=11) + + +# --------------------------------------------------------------------------- +# _send_with_retry +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_send_with_retry_succeeds_first_try(): + """_send_with_retry should succeed on first try and not retry.""" + call_count = 0 + + class _FailingChannel(BaseChannel): + name = "failing" + display_name = "Failing" + + async def start(self) -> None: + pass + + async def stop(self) -> None: + pass + + async def send(self, msg: OutboundMessage) -> None: + nonlocal call_count + call_count += 1 + # Succeeds on first try + + fake_config = SimpleNamespace( + channels=ChannelsConfig(send_max_retries=3), + providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), + ) + + mgr = ChannelManager.__new__(ChannelManager) + mgr.config = fake_config + mgr.bus = MessageBus() + mgr.channels = {"failing": _FailingChannel(fake_config, mgr.bus)} + mgr._dispatch_task = None + + msg = OutboundMessage(channel="failing", chat_id="123", content="test") + await mgr._send_with_retry(mgr.channels["failing"], msg) + + assert call_count == 1 + + +@pytest.mark.asyncio +async def test_send_with_retry_retries_on_failure(): + """_send_with_retry should retry on failure up to max_retries times.""" + call_count = 0 + + class _FailingChannel(BaseChannel): + name = "failing" + display_name = "Failing" + + async def start(self) -> None: + pass + + async def stop(self) -> None: + pass + + async def send(self, msg: OutboundMessage) -> None: + nonlocal call_count + call_count += 1 + raise RuntimeError("simulated failure") + + fake_config = SimpleNamespace( + channels=ChannelsConfig(send_max_retries=3), + providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), + ) + + mgr = ChannelManager.__new__(ChannelManager) + mgr.config = fake_config + mgr.bus = MessageBus() + mgr.channels = {"failing": _FailingChannel(fake_config, mgr.bus)} + mgr._dispatch_task = None + + msg = OutboundMessage(channel="failing", chat_id="123", content="test") + + # Patch asyncio.sleep to avoid actual delays + with patch("nanobot.channels.manager.asyncio.sleep", new_callable=AsyncMock) as mock_sleep: + await mgr._send_with_retry(mgr.channels["failing"], msg) + + assert call_count == 3 # 3 total attempts (initial + 2 retries) + assert mock_sleep.call_count == 2 # 2 sleeps between retries + + +@pytest.mark.asyncio +async def test_send_with_retry_no_retry_when_max_is_zero(): + """_send_with_retry should not retry when send_max_retries is 0.""" + call_count = 0 + + class _FailingChannel(BaseChannel): + name = "failing" + display_name = "Failing" + + async def start(self) -> None: + pass + + async def stop(self) -> None: + pass + + async def send(self, msg: OutboundMessage) -> None: + nonlocal call_count + call_count += 1 + raise RuntimeError("simulated failure") + + fake_config = SimpleNamespace( + channels=ChannelsConfig(send_max_retries=0), + providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), + ) + + mgr = ChannelManager.__new__(ChannelManager) + mgr.config = fake_config + mgr.bus = MessageBus() + mgr.channels = {"failing": _FailingChannel(fake_config, mgr.bus)} + mgr._dispatch_task = None + + msg = OutboundMessage(channel="failing", chat_id="123", content="test") + + with patch("nanobot.channels.manager.asyncio.sleep", new_callable=AsyncMock): + await mgr._send_with_retry(mgr.channels["failing"], msg) + + assert call_count == 1 # Called once but no retry (max(0, 1) = 1) + + +@pytest.mark.asyncio +async def test_send_with_retry_calls_send_delta(): + """_send_with_retry should call send_delta when metadata has _stream_delta.""" + send_delta_called = False + + class _StreamingChannel(BaseChannel): + name = "streaming" + display_name = "Streaming" + + async def start(self) -> None: + pass + + async def stop(self) -> None: + pass + + async def send(self, msg: OutboundMessage) -> None: + pass # Should not be called + + async def send_delta(self, chat_id: str, delta: str, metadata: dict | None = None) -> None: + nonlocal send_delta_called + send_delta_called = True + + fake_config = SimpleNamespace( + channels=ChannelsConfig(send_max_retries=3), + providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), + ) + + mgr = ChannelManager.__new__(ChannelManager) + mgr.config = fake_config + mgr.bus = MessageBus() + mgr.channels = {"streaming": _StreamingChannel(fake_config, mgr.bus)} + mgr._dispatch_task = None + + msg = OutboundMessage( + channel="streaming", chat_id="123", content="test delta", + metadata={"_stream_delta": True} + ) + await mgr._send_with_retry(mgr.channels["streaming"], msg) + + assert send_delta_called is True + + +@pytest.mark.asyncio +async def test_send_with_retry_skips_send_when_streamed(): + """_send_with_retry should not call send when metadata has _streamed flag.""" + send_called = False + send_delta_called = False + + class _StreamedChannel(BaseChannel): + name = "streamed" + display_name = "Streamed" + + async def start(self) -> None: + pass + + async def stop(self) -> None: + pass + + async def send(self, msg: OutboundMessage) -> None: + nonlocal send_called + send_called = True + + async def send_delta(self, chat_id: str, delta: str, metadata: dict | None = None) -> None: + nonlocal send_delta_called + send_delta_called = True + + fake_config = SimpleNamespace( + channels=ChannelsConfig(send_max_retries=3), + providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), + ) + + mgr = ChannelManager.__new__(ChannelManager) + mgr.config = fake_config + mgr.bus = MessageBus() + mgr.channels = {"streamed": _StreamedChannel(fake_config, mgr.bus)} + mgr._dispatch_task = None + + # _streamed means message was already sent via send_delta, so skip send + msg = OutboundMessage( + channel="streamed", chat_id="123", content="test", + metadata={"_streamed": True} + ) + await mgr._send_with_retry(mgr.channels["streamed"], msg) + + assert send_called is False + assert send_delta_called is False + + +@pytest.mark.asyncio +async def test_send_with_retry_propagates_cancelled_error(): + """_send_with_retry should re-raise CancelledError for graceful shutdown.""" + class _CancellingChannel(BaseChannel): + name = "cancelling" + display_name = "Cancelling" + + async def start(self) -> None: + pass + + async def stop(self) -> None: + pass + + async def send(self, msg: OutboundMessage) -> None: + raise asyncio.CancelledError("simulated cancellation") + + fake_config = SimpleNamespace( + channels=ChannelsConfig(send_max_retries=3), + providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), + ) + + mgr = ChannelManager.__new__(ChannelManager) + mgr.config = fake_config + mgr.bus = MessageBus() + mgr.channels = {"cancelling": _CancellingChannel(fake_config, mgr.bus)} + mgr._dispatch_task = None + + msg = OutboundMessage(channel="cancelling", chat_id="123", content="test") + + with pytest.raises(asyncio.CancelledError): + await mgr._send_with_retry(mgr.channels["cancelling"], msg) + + +@pytest.mark.asyncio +async def test_send_with_retry_propagates_cancelled_error_during_sleep(): + """_send_with_retry should re-raise CancelledError during sleep.""" + call_count = 0 + + class _FailingChannel(BaseChannel): + name = "failing" + display_name = "Failing" + + async def start(self) -> None: + pass + + async def stop(self) -> None: + pass + + async def send(self, msg: OutboundMessage) -> None: + nonlocal call_count + call_count += 1 + raise RuntimeError("simulated failure") + + fake_config = SimpleNamespace( + channels=ChannelsConfig(send_max_retries=3), + providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), + ) + + mgr = ChannelManager.__new__(ChannelManager) + mgr.config = fake_config + mgr.bus = MessageBus() + mgr.channels = {"failing": _FailingChannel(fake_config, mgr.bus)} + mgr._dispatch_task = None + + msg = OutboundMessage(channel="failing", chat_id="123", content="test") + + # Mock sleep to raise CancelledError + async def cancel_during_sleep(_): + raise asyncio.CancelledError("cancelled during sleep") + + with patch("nanobot.channels.manager.asyncio.sleep", side_effect=cancel_during_sleep): + with pytest.raises(asyncio.CancelledError): + await mgr._send_with_retry(mgr.channels["failing"], msg) + + # Should have attempted once before sleep was cancelled + assert call_count == 1 + + +# --------------------------------------------------------------------------- +# ChannelManager - lifecycle and getters +# --------------------------------------------------------------------------- + +class _ChannelWithAllowFrom(BaseChannel): + """Channel with configurable allow_from.""" + name = "withallow" + display_name = "With Allow" + + def __init__(self, config, bus, allow_from): + super().__init__(config, bus) + self.config.allow_from = allow_from + + async def start(self) -> None: + pass + + async def stop(self) -> None: + pass + + async def send(self, msg: OutboundMessage) -> None: + pass + + +class _StartableChannel(BaseChannel): + """Channel that tracks start/stop calls.""" + name = "startable" + display_name = "Startable" + + def __init__(self, config, bus): + super().__init__(config, bus) + self.started = False + self.stopped = False + + async def start(self) -> None: + self.started = True + + async def stop(self) -> None: + self.stopped = True + + async def send(self, msg: OutboundMessage) -> None: + pass + + +@pytest.mark.asyncio +async def test_validate_allow_from_raises_on_empty_list(): + """_validate_allow_from should raise SystemExit when allow_from is empty list.""" + fake_config = SimpleNamespace( + channels=ChannelsConfig(), + providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), + ) + + mgr = ChannelManager.__new__(ChannelManager) + mgr.config = fake_config + mgr.channels = {"test": _ChannelWithAllowFrom(fake_config, None, [])} + mgr._dispatch_task = None + + with pytest.raises(SystemExit) as exc_info: + mgr._validate_allow_from() + + assert "empty allowFrom" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_validate_allow_from_passes_with_asterisk(): + """_validate_allow_from should not raise when allow_from contains '*'.""" + fake_config = SimpleNamespace( + channels=ChannelsConfig(), + providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), + ) + + mgr = ChannelManager.__new__(ChannelManager) + mgr.config = fake_config + mgr.channels = {"test": _ChannelWithAllowFrom(fake_config, None, ["*"])} + mgr._dispatch_task = None + + # Should not raise + mgr._validate_allow_from() + + +@pytest.mark.asyncio +async def test_get_channel_returns_channel_if_exists(): + """get_channel should return the channel if it exists.""" + fake_config = SimpleNamespace( + channels=ChannelsConfig(), + providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), + ) + + mgr = ChannelManager.__new__(ChannelManager) + mgr.config = fake_config + mgr.bus = MessageBus() + mgr.channels = {"telegram": _StartableChannel(fake_config, mgr.bus)} + mgr._dispatch_task = None + + assert mgr.get_channel("telegram") is not None + assert mgr.get_channel("nonexistent") is None + + +@pytest.mark.asyncio +async def test_get_status_returns_running_state(): + """get_status should return enabled and running state for each channel.""" + fake_config = SimpleNamespace( + channels=ChannelsConfig(), + providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), + ) + + mgr = ChannelManager.__new__(ChannelManager) + mgr.config = fake_config + mgr.bus = MessageBus() + ch = _StartableChannel(fake_config, mgr.bus) + mgr.channels = {"startable": ch} + mgr._dispatch_task = None + + status = mgr.get_status() + + assert status["startable"]["enabled"] is True + assert status["startable"]["running"] is False # Not started yet + + +@pytest.mark.asyncio +async def test_enabled_channels_returns_channel_names(): + """enabled_channels should return list of enabled channel names.""" + fake_config = SimpleNamespace( + channels=ChannelsConfig(), + providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), + ) + + mgr = ChannelManager.__new__(ChannelManager) + mgr.config = fake_config + mgr.bus = MessageBus() + mgr.channels = { + "telegram": _StartableChannel(fake_config, mgr.bus), + "slack": _StartableChannel(fake_config, mgr.bus), + } + mgr._dispatch_task = None + + enabled = mgr.enabled_channels + + assert "telegram" in enabled + assert "slack" in enabled + assert len(enabled) == 2 + + +@pytest.mark.asyncio +async def test_stop_all_cancels_dispatcher_and_stops_channels(): + """stop_all should cancel the dispatch task and stop all channels.""" + fake_config = SimpleNamespace( + channels=ChannelsConfig(), + providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), + ) + + mgr = ChannelManager.__new__(ChannelManager) + mgr.config = fake_config + mgr.bus = MessageBus() + + ch = _StartableChannel(fake_config, mgr.bus) + mgr.channels = {"startable": ch} + + # Create a real cancelled task + async def dummy_task(): + while True: + await asyncio.sleep(1) + + dispatch_task = asyncio.create_task(dummy_task()) + mgr._dispatch_task = dispatch_task + + await mgr.stop_all() + + # Task should be cancelled + assert dispatch_task.cancelled() + # Channel should be stopped + assert ch.stopped is True + + +@pytest.mark.asyncio +async def test_start_channel_logs_error_on_failure(): + """_start_channel should log error when channel start fails.""" + class _FailingChannel(BaseChannel): + name = "failing" + display_name = "Failing" + + async def start(self) -> None: + raise RuntimeError("connection failed") + + async def stop(self) -> None: + pass + + async def send(self, msg: OutboundMessage) -> None: + pass + + fake_config = SimpleNamespace( + channels=ChannelsConfig(), + providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), + ) + + mgr = ChannelManager.__new__(ChannelManager) + mgr.config = fake_config + mgr.bus = MessageBus() + mgr.channels = {} + mgr._dispatch_task = None + + ch = _FailingChannel(fake_config, mgr.bus) + + # Should not raise, just log error + await mgr._start_channel("failing", ch) + + +@pytest.mark.asyncio +async def test_stop_all_handles_channel_exception(): + """stop_all should handle exceptions when stopping channels gracefully.""" + class _StopFailingChannel(BaseChannel): + name = "stopfailing" + display_name = "Stop Failing" + + async def start(self) -> None: + pass + + async def stop(self) -> None: + raise RuntimeError("stop failed") + + async def send(self, msg: OutboundMessage) -> None: + pass + + fake_config = SimpleNamespace( + channels=ChannelsConfig(), + providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), + ) + + mgr = ChannelManager.__new__(ChannelManager) + mgr.config = fake_config + mgr.bus = MessageBus() + mgr.channels = {"stopfailing": _StopFailingChannel(fake_config, mgr.bus)} + mgr._dispatch_task = None + + # Should not raise even if channel.stop() raises + await mgr.stop_all() + + +@pytest.mark.asyncio +async def test_start_all_no_channels_logs_warning(): + """start_all should log warning when no channels are enabled.""" + fake_config = SimpleNamespace( + channels=ChannelsConfig(), + providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), + ) + + mgr = ChannelManager.__new__(ChannelManager) + mgr.config = fake_config + mgr.bus = MessageBus() + mgr.channels = {} # No channels + mgr._dispatch_task = None + + # Should return early without creating dispatch task + await mgr.start_all() + + assert mgr._dispatch_task is None + + +@pytest.mark.asyncio +async def test_start_all_creates_dispatch_task(): + """start_all should create the dispatch task when channels exist.""" + fake_config = SimpleNamespace( + channels=ChannelsConfig(), + providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), + ) + + mgr = ChannelManager.__new__(ChannelManager) + mgr.config = fake_config + mgr.bus = MessageBus() + + ch = _StartableChannel(fake_config, mgr.bus) + mgr.channels = {"startable": ch} + mgr._dispatch_task = None + + # Cancel immediately after start to avoid running forever + async def cancel_after_start(): + await asyncio.sleep(0.01) + if mgr._dispatch_task: + mgr._dispatch_task.cancel() + + cancel_task = asyncio.create_task(cancel_after_start()) + + try: + await mgr.start_all() + except asyncio.CancelledError: + pass + finally: + cancel_task.cancel() + try: + await cancel_task + except asyncio.CancelledError: + pass + + # Dispatch task should have been created + assert mgr._dispatch_task is not None + diff --git a/tests/test_dingtalk_channel.py b/tests/channels/test_dingtalk_channel.py similarity index 95% rename from tests/test_dingtalk_channel.py rename to tests/channels/test_dingtalk_channel.py index a0b866faded..6894c86837c 100644 --- a/tests/test_dingtalk_channel.py +++ b/tests/channels/test_dingtalk_channel.py @@ -3,6 +3,16 @@ import pytest +# Check optional dingtalk dependencies before running tests +try: + from nanobot.channels import dingtalk + DINGTALK_AVAILABLE = getattr(dingtalk, "DINGTALK_AVAILABLE", False) +except ImportError: + DINGTALK_AVAILABLE = False + +if not DINGTALK_AVAILABLE: + pytest.skip("DingTalk dependencies not installed (dingtalk-stream)", allow_module_level=True) + from nanobot.bus.queue import MessageBus import nanobot.channels.dingtalk as dingtalk_module from nanobot.channels.dingtalk import DingTalkChannel, NanobotDingTalkHandler diff --git a/tests/test_email_channel.py b/tests/channels/test_email_channel.py similarity index 100% rename from tests/test_email_channel.py rename to tests/channels/test_email_channel.py diff --git a/tests/test_feishu_markdown_rendering.py b/tests/channels/test_feishu_markdown_rendering.py similarity index 81% rename from tests/test_feishu_markdown_rendering.py rename to tests/channels/test_feishu_markdown_rendering.py index 6812a21aa66..efcd207335b 100644 --- a/tests/test_feishu_markdown_rendering.py +++ b/tests/channels/test_feishu_markdown_rendering.py @@ -1,3 +1,14 @@ +# Check optional Feishu dependencies before running tests +try: + from nanobot.channels import feishu + FEISHU_AVAILABLE = getattr(feishu, "FEISHU_AVAILABLE", False) +except ImportError: + FEISHU_AVAILABLE = False + +if not FEISHU_AVAILABLE: + import pytest + pytest.skip("Feishu dependencies not installed (lark-oapi)", allow_module_level=True) + from nanobot.channels.feishu import FeishuChannel diff --git a/tests/test_feishu_post_content.py b/tests/channels/test_feishu_post_content.py similarity index 82% rename from tests/test_feishu_post_content.py rename to tests/channels/test_feishu_post_content.py index 7b1cb9d31cf..a4c5bae19f5 100644 --- a/tests/test_feishu_post_content.py +++ b/tests/channels/test_feishu_post_content.py @@ -1,3 +1,14 @@ +# Check optional Feishu dependencies before running tests +try: + from nanobot.channels import feishu + FEISHU_AVAILABLE = getattr(feishu, "FEISHU_AVAILABLE", False) +except ImportError: + FEISHU_AVAILABLE = False + +if not FEISHU_AVAILABLE: + import pytest + pytest.skip("Feishu dependencies not installed (lark-oapi)", allow_module_level=True) + from nanobot.channels.feishu import FeishuChannel, _extract_post_content diff --git a/tests/test_feishu_reply.py b/tests/channels/test_feishu_reply.py similarity index 97% rename from tests/test_feishu_reply.py rename to tests/channels/test_feishu_reply.py index b2072b31aaf..0753653a775 100644 --- a/tests/test_feishu_reply.py +++ b/tests/channels/test_feishu_reply.py @@ -7,6 +7,16 @@ import pytest +# Check optional Feishu dependencies before running tests +try: + from nanobot.channels import feishu + FEISHU_AVAILABLE = getattr(feishu, "FEISHU_AVAILABLE", False) +except ImportError: + FEISHU_AVAILABLE = False + +if not FEISHU_AVAILABLE: + pytest.skip("Feishu dependencies not installed (lark-oapi)", allow_module_level=True) + from nanobot.bus.events import OutboundMessage from nanobot.bus.queue import MessageBus from nanobot.channels.feishu import FeishuChannel, FeishuConfig diff --git a/tests/test_feishu_table_split.py b/tests/channels/test_feishu_table_split.py similarity index 89% rename from tests/test_feishu_table_split.py rename to tests/channels/test_feishu_table_split.py index af8fa164a8e..030b8910dd4 100644 --- a/tests/test_feishu_table_split.py +++ b/tests/channels/test_feishu_table_split.py @@ -6,6 +6,17 @@ table, allowing nanobot to send multiple cards instead of failing. """ +# Check optional Feishu dependencies before running tests +try: + from nanobot.channels import feishu + FEISHU_AVAILABLE = getattr(feishu, "FEISHU_AVAILABLE", False) +except ImportError: + FEISHU_AVAILABLE = False + +if not FEISHU_AVAILABLE: + import pytest + pytest.skip("Feishu dependencies not installed (lark-oapi)", allow_module_level=True) + from nanobot.channels.feishu import FeishuChannel diff --git a/tests/test_feishu_tool_hint_code_block.py b/tests/channels/test_feishu_tool_hint_code_block.py similarity index 93% rename from tests/test_feishu_tool_hint_code_block.py rename to tests/channels/test_feishu_tool_hint_code_block.py index 2a1b81227bf..a65f1d9882e 100644 --- a/tests/test_feishu_tool_hint_code_block.py +++ b/tests/channels/test_feishu_tool_hint_code_block.py @@ -6,6 +6,16 @@ import pytest from pytest import mark +# Check optional Feishu dependencies before running tests +try: + from nanobot.channels import feishu + FEISHU_AVAILABLE = getattr(feishu, "FEISHU_AVAILABLE", False) +except ImportError: + FEISHU_AVAILABLE = False + +if not FEISHU_AVAILABLE: + pytest.skip("Feishu dependencies not installed (lark-oapi)", allow_module_level=True) + from nanobot.bus.events import OutboundMessage from nanobot.channels.feishu import FeishuChannel diff --git a/tests/test_matrix_channel.py b/tests/channels/test_matrix_channel.py similarity index 99% rename from tests/test_matrix_channel.py rename to tests/channels/test_matrix_channel.py index 1f3b69ccf81..dd5e97d9056 100644 --- a/tests/test_matrix_channel.py +++ b/tests/channels/test_matrix_channel.py @@ -4,6 +4,12 @@ import pytest +# Check optional matrix dependencies before importing +try: + import nh3 # noqa: F401 +except ImportError: + pytest.skip("Matrix dependencies not installed (nh3)", allow_module_level=True) + import nanobot.channels.matrix as matrix_module from nanobot.bus.events import OutboundMessage from nanobot.bus.queue import MessageBus diff --git a/tests/test_qq_channel.py b/tests/channels/test_qq_channel.py similarity index 68% rename from tests/test_qq_channel.py rename to tests/channels/test_qq_channel.py index bd5e8911c54..729442a13e3 100644 --- a/tests/test_qq_channel.py +++ b/tests/channels/test_qq_channel.py @@ -1,11 +1,22 @@ +import tempfile +from pathlib import Path from types import SimpleNamespace import pytest +# Check optional QQ dependencies before running tests +try: + from nanobot.channels import qq + QQ_AVAILABLE = getattr(qq, "QQ_AVAILABLE", False) +except ImportError: + QQ_AVAILABLE = False + +if not QQ_AVAILABLE: + pytest.skip("QQ dependencies not installed (qq-botpy)", allow_module_level=True) + from nanobot.bus.events import OutboundMessage from nanobot.bus.queue import MessageBus -from nanobot.channels.qq import QQChannel -from nanobot.channels.qq import QQConfig +from nanobot.channels.qq import QQChannel, QQConfig class _FakeApi: @@ -34,6 +45,7 @@ async def test_on_group_message_routes_to_group_chat_id() -> None: content="hello", group_openid="group123", author=SimpleNamespace(member_openid="user1"), + attachments=[], ) await channel._on_message(data, is_group=True) @@ -123,3 +135,38 @@ async def test_send_group_message_uses_markdown_when_configured() -> None: "msg_id": "msg1", "msg_seq": 2, } + + +@pytest.mark.asyncio +async def test_read_media_bytes_local_path() -> None: + channel = QQChannel(QQConfig(app_id="app", secret="secret"), MessageBus()) + + with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f: + f.write(b"\x89PNG\r\n") + tmp_path = f.name + + data, filename = await channel._read_media_bytes(tmp_path) + assert data == b"\x89PNG\r\n" + assert filename == Path(tmp_path).name + + +@pytest.mark.asyncio +async def test_read_media_bytes_file_uri() -> None: + channel = QQChannel(QQConfig(app_id="app", secret="secret"), MessageBus()) + + with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f: + f.write(b"JFIF") + tmp_path = f.name + + data, filename = await channel._read_media_bytes(f"file://{tmp_path}") + assert data == b"JFIF" + assert filename == Path(tmp_path).name + + +@pytest.mark.asyncio +async def test_read_media_bytes_missing_file() -> None: + channel = QQChannel(QQConfig(app_id="app", secret="secret"), MessageBus()) + + data, filename = await channel._read_media_bytes("/nonexistent/path/image.png") + assert data is None + assert filename is None diff --git a/tests/test_slack_channel.py b/tests/channels/test_slack_channel.py similarity index 95% rename from tests/test_slack_channel.py rename to tests/channels/test_slack_channel.py index d243235aaa7..f7eec95c036 100644 --- a/tests/test_slack_channel.py +++ b/tests/channels/test_slack_channel.py @@ -2,6 +2,12 @@ import pytest +# Check optional Slack dependencies before running tests +try: + import slack_sdk # noqa: F401 +except ImportError: + pytest.skip("Slack dependencies not installed (slack-sdk)", allow_module_level=True) + from nanobot.bus.events import OutboundMessage from nanobot.bus.queue import MessageBus from nanobot.channels.slack import SlackChannel diff --git a/tests/test_telegram_channel.py b/tests/channels/test_telegram_channel.py similarity index 89% rename from tests/test_telegram_channel.py rename to tests/channels/test_telegram_channel.py index 8b6ba97896e..d5dafdee723 100644 --- a/tests/test_telegram_channel.py +++ b/tests/channels/test_telegram_channel.py @@ -5,9 +5,15 @@ import pytest +# Check optional Telegram dependencies before running tests +try: + import telegram # noqa: F401 +except ImportError: + pytest.skip("Telegram dependencies not installed (python-telegram-bot)", allow_module_level=True) + from nanobot.bus.events import OutboundMessage from nanobot.bus.queue import MessageBus -from nanobot.channels.telegram import TELEGRAM_REPLY_CONTEXT_MAX_LEN, TelegramChannel +from nanobot.channels.telegram import TELEGRAM_REPLY_CONTEXT_MAX_LEN, TelegramChannel, _StreamBuf from nanobot.channels.telegram import TelegramConfig @@ -44,8 +50,9 @@ async def get_me(self): async def set_my_commands(self, commands) -> None: self.commands = commands - async def send_message(self, **kwargs) -> None: + async def send_message(self, **kwargs): self.sent_messages.append(kwargs) + return SimpleNamespace(message_id=len(self.sent_messages)) async def send_photo(self, **kwargs) -> None: self.sent_media.append({"kind": "photo", **kwargs}) @@ -265,13 +272,86 @@ async def always_timeout(**kwargs): orig_delay = tg_mod._SEND_RETRY_BASE_DELAY tg_mod._SEND_RETRY_BASE_DELAY = 0.01 try: - await channel._send_text(123, "hello", None, {}) + with pytest.raises(TimedOut): + await channel._send_text(123, "hello", None, {}) finally: tg_mod._SEND_RETRY_BASE_DELAY = orig_delay assert channel._app.bot.sent_messages == [] +@pytest.mark.asyncio +async def test_send_delta_stream_end_raises_and_keeps_buffer_on_failure() -> None: + channel = TelegramChannel( + TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]), + MessageBus(), + ) + channel._app = _FakeApp(lambda: None) + channel._app.bot.edit_message_text = AsyncMock(side_effect=RuntimeError("boom")) + channel._stream_bufs["123"] = _StreamBuf(text="hello", message_id=7, last_edit=0.0) + + with pytest.raises(RuntimeError, match="boom"): + await channel.send_delta("123", "", {"_stream_end": True}) + + assert "123" in channel._stream_bufs + + +@pytest.mark.asyncio +async def test_send_delta_stream_end_treats_not_modified_as_success() -> None: + from telegram.error import BadRequest + + channel = TelegramChannel( + TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]), + MessageBus(), + ) + channel._app = _FakeApp(lambda: None) + channel._app.bot.edit_message_text = AsyncMock(side_effect=BadRequest("Message is not modified")) + channel._stream_bufs["123"] = _StreamBuf(text="hello", message_id=7, last_edit=0.0, stream_id="s:0") + + await channel.send_delta("123", "", {"_stream_end": True, "_stream_id": "s:0"}) + + assert "123" not in channel._stream_bufs + + +@pytest.mark.asyncio +async def test_send_delta_new_stream_id_replaces_stale_buffer() -> None: + channel = TelegramChannel( + TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]), + MessageBus(), + ) + channel._app = _FakeApp(lambda: None) + channel._stream_bufs["123"] = _StreamBuf( + text="hello", + message_id=7, + last_edit=0.0, + stream_id="old:0", + ) + + await channel.send_delta("123", "world", {"_stream_delta": True, "_stream_id": "new:0"}) + + buf = channel._stream_bufs["123"] + assert buf.text == "world" + assert buf.stream_id == "new:0" + assert buf.message_id == 1 + + +@pytest.mark.asyncio +async def test_send_delta_incremental_edit_treats_not_modified_as_success() -> None: + from telegram.error import BadRequest + + channel = TelegramChannel( + TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]), + MessageBus(), + ) + channel._app = _FakeApp(lambda: None) + channel._stream_bufs["123"] = _StreamBuf(text="hello", message_id=7, last_edit=0.0, stream_id="s:0") + channel._app.bot.edit_message_text = AsyncMock(side_effect=BadRequest("Message is not modified")) + + await channel.send_delta("123", "", {"_stream_delta": True, "_stream_id": "s:0"}) + + assert channel._stream_bufs["123"].last_edit > 0.0 + + def test_derive_topic_session_key_uses_thread_id() -> None: message = SimpleNamespace( chat=SimpleNamespace(type="supergroup"), diff --git a/tests/channels/test_weixin_channel.py b/tests/channels/test_weixin_channel.py new file mode 100644 index 00000000000..54d9bd93f91 --- /dev/null +++ b/tests/channels/test_weixin_channel.py @@ -0,0 +1,280 @@ +import asyncio +import json +import tempfile +from types import SimpleNamespace +from unittest.mock import AsyncMock + +import pytest + +from nanobot.bus.queue import MessageBus +from nanobot.channels.weixin import ( + ITEM_IMAGE, + ITEM_TEXT, + MESSAGE_TYPE_BOT, + WEIXIN_CHANNEL_VERSION, + WeixinChannel, + WeixinConfig, +) + + +def _make_channel() -> tuple[WeixinChannel, MessageBus]: + bus = MessageBus() + channel = WeixinChannel( + WeixinConfig( + enabled=True, + allow_from=["*"], + state_dir=tempfile.mkdtemp(prefix="nanobot-weixin-test-"), + ), + bus, + ) + return channel, bus + + +def test_make_headers_includes_route_tag_when_configured() -> None: + bus = MessageBus() + channel = WeixinChannel( + WeixinConfig(enabled=True, allow_from=["*"], route_tag=123), + bus, + ) + channel._token = "token" + + headers = channel._make_headers() + + assert headers["Authorization"] == "Bearer token" + assert headers["SKRouteTag"] == "123" + + +def test_channel_version_matches_reference_plugin_version() -> None: + assert WEIXIN_CHANNEL_VERSION == "1.0.3" + + +def test_save_and_load_state_persists_context_tokens(tmp_path) -> None: + bus = MessageBus() + channel = WeixinChannel( + WeixinConfig(enabled=True, allow_from=["*"], state_dir=str(tmp_path)), + bus, + ) + channel._token = "token" + channel._get_updates_buf = "cursor" + channel._context_tokens = {"wx-user": "ctx-1"} + + channel._save_state() + + saved = json.loads((tmp_path / "account.json").read_text()) + assert saved["context_tokens"] == {"wx-user": "ctx-1"} + + restored = WeixinChannel( + WeixinConfig(enabled=True, allow_from=["*"], state_dir=str(tmp_path)), + bus, + ) + + assert restored._load_state() is True + assert restored._context_tokens == {"wx-user": "ctx-1"} + + +@pytest.mark.asyncio +async def test_process_message_deduplicates_inbound_ids() -> None: + channel, bus = _make_channel() + msg = { + "message_type": 1, + "message_id": "m1", + "from_user_id": "wx-user", + "context_token": "ctx-1", + "item_list": [ + {"type": ITEM_TEXT, "text_item": {"text": "hello"}}, + ], + } + + await channel._process_message(msg) + first = await asyncio.wait_for(bus.consume_inbound(), timeout=1.0) + await channel._process_message(msg) + + assert first.sender_id == "wx-user" + assert first.chat_id == "wx-user" + assert first.content == "hello" + assert bus.inbound_size == 0 + + +@pytest.mark.asyncio +async def test_process_message_caches_context_token_and_send_uses_it() -> None: + channel, _bus = _make_channel() + channel._client = object() + channel._token = "token" + channel._send_text = AsyncMock() + + await channel._process_message( + { + "message_type": 1, + "message_id": "m2", + "from_user_id": "wx-user", + "context_token": "ctx-2", + "item_list": [ + {"type": ITEM_TEXT, "text_item": {"text": "ping"}}, + ], + } + ) + + await channel.send( + type("Msg", (), {"chat_id": "wx-user", "content": "pong", "media": [], "metadata": {}})() + ) + + channel._send_text.assert_awaited_once_with("wx-user", "pong", "ctx-2") + + +@pytest.mark.asyncio +async def test_process_message_persists_context_token_to_state_file(tmp_path) -> None: + bus = MessageBus() + channel = WeixinChannel( + WeixinConfig(enabled=True, allow_from=["*"], state_dir=str(tmp_path)), + bus, + ) + + await channel._process_message( + { + "message_type": 1, + "message_id": "m2b", + "from_user_id": "wx-user", + "context_token": "ctx-2b", + "item_list": [ + {"type": ITEM_TEXT, "text_item": {"text": "ping"}}, + ], + } + ) + + saved = json.loads((tmp_path / "account.json").read_text()) + assert saved["context_tokens"] == {"wx-user": "ctx-2b"} + + +@pytest.mark.asyncio +async def test_process_message_extracts_media_and_preserves_paths() -> None: + channel, bus = _make_channel() + channel._download_media_item = AsyncMock(return_value="/tmp/test.jpg") + + await channel._process_message( + { + "message_type": 1, + "message_id": "m3", + "from_user_id": "wx-user", + "context_token": "ctx-3", + "item_list": [ + {"type": ITEM_IMAGE, "image_item": {"media": {"encrypt_query_param": "x"}}}, + ], + } + ) + + inbound = await asyncio.wait_for(bus.consume_inbound(), timeout=1.0) + + assert "[image]" in inbound.content + assert "/tmp/test.jpg" in inbound.content + assert inbound.media == ["/tmp/test.jpg"] + + +@pytest.mark.asyncio +async def test_send_without_context_token_does_not_send_text() -> None: + channel, _bus = _make_channel() + channel._client = object() + channel._token = "token" + channel._send_text = AsyncMock() + + await channel.send( + type("Msg", (), {"chat_id": "unknown-user", "content": "pong", "media": [], "metadata": {}})() + ) + + channel._send_text.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_send_does_not_send_when_session_is_paused() -> None: + channel, _bus = _make_channel() + channel._client = object() + channel._token = "token" + channel._context_tokens["wx-user"] = "ctx-2" + channel._pause_session(60) + channel._send_text = AsyncMock() + + await channel.send( + type("Msg", (), {"chat_id": "wx-user", "content": "pong", "media": [], "metadata": {}})() + ) + + channel._send_text.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_poll_once_pauses_session_on_expired_errcode() -> None: + channel, _bus = _make_channel() + channel._client = SimpleNamespace(timeout=None) + channel._token = "token" + channel._api_post = AsyncMock(return_value={"ret": 0, "errcode": -14, "errmsg": "expired"}) + + await channel._poll_once() + + assert channel._session_pause_remaining_s() > 0 + + +@pytest.mark.asyncio +async def test_qr_login_refreshes_expired_qr_and_then_succeeds() -> None: + channel, _bus = _make_channel() + channel._running = True + channel._save_state = lambda: None + channel._print_qr_code = lambda url: None + channel._api_get = AsyncMock( + side_effect=[ + {"qrcode": "qr-1", "qrcode_img_content": "url-1"}, + {"status": "expired"}, + {"qrcode": "qr-2", "qrcode_img_content": "url-2"}, + { + "status": "confirmed", + "bot_token": "token-2", + "ilink_bot_id": "bot-2", + "baseurl": "https://example.test", + "ilink_user_id": "wx-user", + }, + ] + ) + + ok = await channel._qr_login() + + assert ok is True + assert channel._token == "token-2" + assert channel.config.base_url == "https://example.test" + + +@pytest.mark.asyncio +async def test_qr_login_returns_false_after_too_many_expired_qr_codes() -> None: + channel, _bus = _make_channel() + channel._running = True + channel._print_qr_code = lambda url: None + channel._api_get = AsyncMock( + side_effect=[ + {"qrcode": "qr-1", "qrcode_img_content": "url-1"}, + {"status": "expired"}, + {"qrcode": "qr-2", "qrcode_img_content": "url-2"}, + {"status": "expired"}, + {"qrcode": "qr-3", "qrcode_img_content": "url-3"}, + {"status": "expired"}, + {"qrcode": "qr-4", "qrcode_img_content": "url-4"}, + {"status": "expired"}, + ] + ) + + ok = await channel._qr_login() + + assert ok is False + + +@pytest.mark.asyncio +async def test_process_message_skips_bot_messages() -> None: + channel, bus = _make_channel() + + await channel._process_message( + { + "message_type": MESSAGE_TYPE_BOT, + "message_id": "m4", + "from_user_id": "wx-user", + "item_list": [ + {"type": ITEM_TEXT, "text_item": {"text": "hello"}}, + ], + } + ) + + assert bus.inbound_size == 0 diff --git a/tests/test_whatsapp_channel.py b/tests/channels/test_whatsapp_channel.py similarity index 66% rename from tests/test_whatsapp_channel.py rename to tests/channels/test_whatsapp_channel.py index 1413429e35b..dea15d7b291 100644 --- a/tests/test_whatsapp_channel.py +++ b/tests/channels/test_whatsapp_channel.py @@ -106,3 +106,52 @@ async def test_send_when_disconnected_is_noop(): await ch.send(msg) ch._ws.send.assert_not_called() + + +@pytest.mark.asyncio +async def test_group_policy_mention_skips_unmentioned_group_message(): + ch = WhatsAppChannel({"enabled": True, "groupPolicy": "mention"}, MagicMock()) + ch._handle_message = AsyncMock() + + await ch._handle_bridge_message( + json.dumps( + { + "type": "message", + "id": "m1", + "sender": "12345@g.us", + "pn": "user@s.whatsapp.net", + "content": "hello group", + "timestamp": 1, + "isGroup": True, + "wasMentioned": False, + } + ) + ) + + ch._handle_message.assert_not_called() + + +@pytest.mark.asyncio +async def test_group_policy_mention_accepts_mentioned_group_message(): + ch = WhatsAppChannel({"enabled": True, "groupPolicy": "mention"}, MagicMock()) + ch._handle_message = AsyncMock() + + await ch._handle_bridge_message( + json.dumps( + { + "type": "message", + "id": "m1", + "sender": "12345@g.us", + "pn": "user@s.whatsapp.net", + "content": "hello @bot", + "timestamp": 1, + "isGroup": True, + "wasMentioned": True, + } + ) + ) + + ch._handle_message.assert_awaited_once() + kwargs = ch._handle_message.await_args.kwargs + assert kwargs["chat_id"] == "12345@g.us" + assert kwargs["sender_id"] == "user" diff --git a/tests/test_cli_input.py b/tests/cli/test_cli_input.py similarity index 100% rename from tests/test_cli_input.py rename to tests/cli/test_cli_input.py diff --git a/tests/test_commands.py b/tests/cli/test_commands.py similarity index 63% rename from tests/test_commands.py rename to tests/cli/test_commands.py index 5d4c2bcdc0e..a8fcc4aa0cf 100644 --- a/tests/test_commands.py +++ b/tests/cli/test_commands.py @@ -9,9 +9,8 @@ from nanobot.bus.events import OutboundMessage from nanobot.cli.commands import _make_provider, app from nanobot.config.schema import Config -from nanobot.providers.litellm_provider import LiteLLMProvider from nanobot.providers.openai_codex_provider import _strip_model_prefix -from nanobot.providers.registry import find_by_model +from nanobot.providers.registry import find_by_name runner = CliRunner() @@ -228,7 +227,7 @@ def test_config_matches_explicit_ollama_prefix_without_api_key(): config.agents.defaults.model = "ollama/llama3.2" assert config.get_provider_name() == "ollama" - assert config.get_api_base() == "http://localhost:11434" + assert config.get_api_base() == "http://localhost:11434/v1" def test_config_explicit_ollama_provider_uses_default_localhost_api_base(): @@ -237,19 +236,47 @@ def test_config_explicit_ollama_provider_uses_default_localhost_api_base(): config.agents.defaults.model = "llama3.2" assert config.get_provider_name() == "ollama" - assert config.get_api_base() == "http://localhost:11434" + assert config.get_api_base() == "http://localhost:11434/v1" + + +def test_config_accepts_camel_case_explicit_provider_name_for_coding_plan(): + config = Config.model_validate( + { + "agents": { + "defaults": { + "provider": "volcengineCodingPlan", + "model": "doubao-1-5-pro", + } + }, + "providers": { + "volcengineCodingPlan": { + "apiKey": "test-key", + } + }, + } + ) + + assert config.get_provider_name() == "volcengine_coding_plan" + assert config.get_api_base() == "https://ark.cn-beijing.volces.com/api/coding/v3" + + +def test_find_by_name_accepts_camel_case_and_hyphen_aliases(): + assert find_by_name("volcengineCodingPlan") is not None + assert find_by_name("volcengineCodingPlan").name == "volcengine_coding_plan" + assert find_by_name("github-copilot") is not None + assert find_by_name("github-copilot").name == "github_copilot" def test_config_auto_detects_ollama_from_local_api_base(): config = Config.model_validate( { "agents": {"defaults": {"provider": "auto", "model": "llama3.2"}}, - "providers": {"ollama": {"apiBase": "http://localhost:11434"}}, + "providers": {"ollama": {"apiBase": "http://localhost:11434/v1"}}, } ) assert config.get_provider_name() == "ollama" - assert config.get_api_base() == "http://localhost:11434" + assert config.get_api_base() == "http://localhost:11434/v1" def test_config_prefers_ollama_over_vllm_when_both_local_providers_configured(): @@ -258,13 +285,13 @@ def test_config_prefers_ollama_over_vllm_when_both_local_providers_configured(): "agents": {"defaults": {"provider": "auto", "model": "llama3.2"}}, "providers": { "vllm": {"apiBase": "http://localhost:8000"}, - "ollama": {"apiBase": "http://localhost:11434"}, + "ollama": {"apiBase": "http://localhost:11434/v1"}, }, } ) assert config.get_provider_name() == "ollama" - assert config.get_api_base() == "http://localhost:11434" + assert config.get_api_base() == "http://localhost:11434/v1" def test_config_falls_back_to_vllm_when_ollama_not_configured(): @@ -281,19 +308,13 @@ def test_config_falls_back_to_vllm_when_ollama_not_configured(): assert config.get_api_base() == "http://localhost:8000" -def test_find_by_model_prefers_explicit_prefix_over_generic_codex_keyword(): - spec = find_by_model("github-copilot/gpt-5.3-codex") - - assert spec is not None - assert spec.name == "github_copilot" +def test_openai_compat_provider_passes_model_through(): + from nanobot.providers.openai_compat_provider import OpenAICompatProvider + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): + provider = OpenAICompatProvider(default_model="github-copilot/gpt-5.3-codex") -def test_litellm_provider_canonicalizes_github_copilot_hyphen_prefix(): - provider = LiteLLMProvider(default_model="github-copilot/gpt-5.3-codex") - - resolved = provider._resolve_model("github-copilot/gpt-5.3-codex") - - assert resolved == "github_copilot/gpt-5.3-codex" + assert provider.get_default_model() == "github-copilot/gpt-5.3-codex" def test_openai_codex_strip_prefix_supports_hyphen_and_underscore(): @@ -318,7 +339,7 @@ def test_make_provider_passes_extra_headers_to_custom_provider(): } ) - with patch("nanobot.providers.custom_provider.AsyncOpenAI") as mock_async_openai: + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as mock_async_openai: _make_provider(config) kwargs = mock_async_openai.call_args.kwargs @@ -333,10 +354,8 @@ def mock_agent_runtime(tmp_path): """Mock agent command dependencies for focused CLI tests.""" config = Config() config.agents.defaults.workspace = str(tmp_path / "default-workspace") - cron_dir = tmp_path / "data" / "cron" with patch("nanobot.config.loader.load_config", return_value=config) as mock_load_config, \ - patch("nanobot.config.paths.get_cron_dir", return_value=cron_dir), \ patch("nanobot.cli.commands.sync_workspace_templates") as mock_sync_templates, \ patch("nanobot.cli.commands._make_provider", return_value=object()), \ patch("nanobot.cli.commands._print_agent_response") as mock_print_response, \ @@ -413,7 +432,6 @@ def test_agent_config_sets_active_path(monkeypatch, tmp_path: Path) -> None: lambda path: seen.__setitem__("config_path", path), ) monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config) - monkeypatch.setattr("nanobot.config.paths.get_cron_dir", lambda: config_file.parent / "cron") monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None) monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object()) monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object()) @@ -438,6 +456,147 @@ async def close_mcp(self) -> None: assert seen["config_path"] == config_file.resolve() +def test_agent_uses_workspace_directory_for_cron_store(monkeypatch, tmp_path: Path) -> None: + config_file = tmp_path / "instance" / "config.json" + config_file.parent.mkdir(parents=True) + config_file.write_text("{}") + + config = Config() + config.agents.defaults.workspace = str(tmp_path / "agent-workspace") + seen: dict[str, Path] = {} + + monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None) + monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config) + monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None) + monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object()) + monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object()) + + class _FakeCron: + def __init__(self, store_path: Path) -> None: + seen["cron_store"] = store_path + + class _FakeAgentLoop: + def __init__(self, *args, **kwargs) -> None: + pass + + async def process_direct(self, *_args, **_kwargs): + return OutboundMessage(channel="cli", chat_id="direct", content="ok") + + async def close_mcp(self) -> None: + return None + + monkeypatch.setattr("nanobot.cron.service.CronService", _FakeCron) + monkeypatch.setattr("nanobot.agent.loop.AgentLoop", _FakeAgentLoop) + monkeypatch.setattr("nanobot.cli.commands._print_agent_response", lambda *_args, **_kwargs: None) + + result = runner.invoke(app, ["agent", "-m", "hello", "-c", str(config_file)]) + + assert result.exit_code == 0 + assert seen["cron_store"] == config.workspace_path / "cron" / "jobs.json" + + +def test_agent_workspace_override_does_not_migrate_legacy_cron( + monkeypatch, tmp_path: Path +) -> None: + config_file = tmp_path / "instance" / "config.json" + config_file.parent.mkdir(parents=True) + config_file.write_text("{}") + + legacy_dir = tmp_path / "global" / "cron" + legacy_dir.mkdir(parents=True) + legacy_file = legacy_dir / "jobs.json" + legacy_file.write_text('{"jobs": []}') + + override = tmp_path / "override-workspace" + config = Config() + seen: dict[str, Path] = {} + + monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None) + monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config) + monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None) + monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object()) + monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object()) + monkeypatch.setattr("nanobot.config.paths.get_cron_dir", lambda: legacy_dir) + + class _FakeCron: + def __init__(self, store_path: Path) -> None: + seen["cron_store"] = store_path + + class _FakeAgentLoop: + def __init__(self, *args, **kwargs) -> None: + pass + + async def process_direct(self, *_args, **_kwargs): + return OutboundMessage(channel="cli", chat_id="direct", content="ok") + + async def close_mcp(self) -> None: + return None + + monkeypatch.setattr("nanobot.cron.service.CronService", _FakeCron) + monkeypatch.setattr("nanobot.agent.loop.AgentLoop", _FakeAgentLoop) + monkeypatch.setattr("nanobot.cli.commands._print_agent_response", lambda *_args, **_kwargs: None) + + result = runner.invoke( + app, + ["agent", "-m", "hello", "-c", str(config_file), "-w", str(override)], + ) + + assert result.exit_code == 0 + assert seen["cron_store"] == override / "cron" / "jobs.json" + assert legacy_file.exists() + assert not (override / "cron" / "jobs.json").exists() + + +def test_agent_custom_config_workspace_does_not_migrate_legacy_cron( + monkeypatch, tmp_path: Path +) -> None: + config_file = tmp_path / "instance" / "config.json" + config_file.parent.mkdir(parents=True) + config_file.write_text("{}") + + legacy_dir = tmp_path / "global" / "cron" + legacy_dir.mkdir(parents=True) + legacy_file = legacy_dir / "jobs.json" + legacy_file.write_text('{"jobs": []}') + + custom_workspace = tmp_path / "custom-workspace" + config = Config() + config.agents.defaults.workspace = str(custom_workspace) + seen: dict[str, Path] = {} + + monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None) + monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config) + monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None) + monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object()) + monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object()) + monkeypatch.setattr("nanobot.config.paths.get_cron_dir", lambda: legacy_dir) + + class _FakeCron: + def __init__(self, store_path: Path) -> None: + seen["cron_store"] = store_path + + class _FakeAgentLoop: + def __init__(self, *args, **kwargs) -> None: + pass + + async def process_direct(self, *_args, **_kwargs): + return OutboundMessage(channel="cli", chat_id="direct", content="ok") + + async def close_mcp(self) -> None: + return None + + monkeypatch.setattr("nanobot.cron.service.CronService", _FakeCron) + monkeypatch.setattr("nanobot.agent.loop.AgentLoop", _FakeAgentLoop) + monkeypatch.setattr("nanobot.cli.commands._print_agent_response", lambda *_args, **_kwargs: None) + + result = runner.invoke(app, ["agent", "-m", "hello", "-c", str(config_file)]) + + assert result.exit_code == 0 + assert seen["cron_store"] == custom_workspace / "cron" / "jobs.json" + assert legacy_file.exists() + assert not (custom_workspace / "cron" / "jobs.json").exists() + + def test_agent_overrides_workspace_path(mock_agent_runtime): workspace_path = Path("/tmp/agent-workspace") @@ -544,7 +703,7 @@ def test_gateway_workspace_option_overrides_config(monkeypatch, tmp_path: Path) assert config.workspace_path == override -def test_gateway_uses_config_directory_for_cron_store(monkeypatch, tmp_path: Path) -> None: +def test_gateway_uses_workspace_directory_for_cron_store(monkeypatch, tmp_path: Path) -> None: config_file = tmp_path / "instance" / "config.json" config_file.parent.mkdir(parents=True) config_file.write_text("{}") @@ -555,7 +714,6 @@ def test_gateway_uses_config_directory_for_cron_store(monkeypatch, tmp_path: Pat monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None) monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config) - monkeypatch.setattr("nanobot.config.paths.get_cron_dir", lambda: config_file.parent / "cron") monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None) monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object()) monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object()) @@ -571,7 +729,130 @@ def __init__(self, store_path: Path) -> None: result = runner.invoke(app, ["gateway", "--config", str(config_file)]) assert isinstance(result.exception, _StopGatewayError) - assert seen["cron_store"] == config_file.parent / "cron" / "jobs.json" + assert seen["cron_store"] == config.workspace_path / "cron" / "jobs.json" + + +def test_gateway_workspace_override_does_not_migrate_legacy_cron( + monkeypatch, tmp_path: Path +) -> None: + config_file = tmp_path / "instance" / "config.json" + config_file.parent.mkdir(parents=True) + config_file.write_text("{}") + + legacy_dir = tmp_path / "global" / "cron" + legacy_dir.mkdir(parents=True) + legacy_file = legacy_dir / "jobs.json" + legacy_file.write_text('{"jobs": []}') + + override = tmp_path / "override-workspace" + config = Config() + seen: dict[str, Path] = {} + + monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None) + monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config) + monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None) + monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object()) + monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object()) + monkeypatch.setattr("nanobot.session.manager.SessionManager", lambda _workspace: object()) + monkeypatch.setattr("nanobot.config.paths.get_cron_dir", lambda: legacy_dir) + + class _StopCron: + def __init__(self, store_path: Path) -> None: + seen["cron_store"] = store_path + raise _StopGatewayError("stop") + + monkeypatch.setattr("nanobot.cron.service.CronService", _StopCron) + + result = runner.invoke( + app, + ["gateway", "--config", str(config_file), "--workspace", str(override)], + ) + + assert isinstance(result.exception, _StopGatewayError) + assert seen["cron_store"] == override / "cron" / "jobs.json" + assert legacy_file.exists() + assert not (override / "cron" / "jobs.json").exists() + + +def test_gateway_custom_config_workspace_does_not_migrate_legacy_cron( + monkeypatch, tmp_path: Path +) -> None: + config_file = tmp_path / "instance" / "config.json" + config_file.parent.mkdir(parents=True) + config_file.write_text("{}") + + legacy_dir = tmp_path / "global" / "cron" + legacy_dir.mkdir(parents=True) + legacy_file = legacy_dir / "jobs.json" + legacy_file.write_text('{"jobs": []}') + + custom_workspace = tmp_path / "custom-workspace" + config = Config() + config.agents.defaults.workspace = str(custom_workspace) + seen: dict[str, Path] = {} + + monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None) + monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config) + monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None) + monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object()) + monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object()) + monkeypatch.setattr("nanobot.session.manager.SessionManager", lambda _workspace: object()) + monkeypatch.setattr("nanobot.config.paths.get_cron_dir", lambda: legacy_dir) + + class _StopCron: + def __init__(self, store_path: Path) -> None: + seen["cron_store"] = store_path + raise _StopGatewayError("stop") + + monkeypatch.setattr("nanobot.cron.service.CronService", _StopCron) + + result = runner.invoke(app, ["gateway", "--config", str(config_file)]) + + assert isinstance(result.exception, _StopGatewayError) + assert seen["cron_store"] == custom_workspace / "cron" / "jobs.json" + assert legacy_file.exists() + assert not (custom_workspace / "cron" / "jobs.json").exists() + + +def test_migrate_cron_store_moves_legacy_file(tmp_path: Path) -> None: + """Legacy global jobs.json is moved into the workspace on first run.""" + from nanobot.cli.commands import _migrate_cron_store + + legacy_dir = tmp_path / "global" / "cron" + legacy_dir.mkdir(parents=True) + legacy_file = legacy_dir / "jobs.json" + legacy_file.write_text('{"jobs": []}') + + config = Config() + config.agents.defaults.workspace = str(tmp_path / "workspace") + workspace_cron = config.workspace_path / "cron" / "jobs.json" + + with patch("nanobot.config.paths.get_cron_dir", return_value=legacy_dir): + _migrate_cron_store(config) + + assert workspace_cron.exists() + assert workspace_cron.read_text() == '{"jobs": []}' + assert not legacy_file.exists() + + +def test_migrate_cron_store_skips_when_workspace_file_exists(tmp_path: Path) -> None: + """Migration does not overwrite an existing workspace cron store.""" + from nanobot.cli.commands import _migrate_cron_store + + legacy_dir = tmp_path / "global" / "cron" + legacy_dir.mkdir(parents=True) + (legacy_dir / "jobs.json").write_text('{"old": true}') + + config = Config() + config.agents.defaults.workspace = str(tmp_path / "workspace") + workspace_cron = config.workspace_path / "cron" / "jobs.json" + workspace_cron.parent.mkdir(parents=True) + workspace_cron.write_text('{"new": true}') + + with patch("nanobot.config.paths.get_cron_dir", return_value=legacy_dir): + _migrate_cron_store(config) + + assert workspace_cron.read_text() == '{"new": true}' def test_gateway_uses_configured_port_when_cli_flag_is_missing(monkeypatch, tmp_path: Path) -> None: diff --git a/tests/test_restart_command.py b/tests/cli/test_restart_command.py similarity index 100% rename from tests/test_restart_command.py rename to tests/cli/test_restart_command.py diff --git a/tests/test_config_migration.py b/tests/config/test_config_migration.py similarity index 100% rename from tests/test_config_migration.py rename to tests/config/test_config_migration.py diff --git a/tests/test_config_paths.py b/tests/config/test_config_paths.py similarity index 84% rename from tests/test_config_paths.py rename to tests/config/test_config_paths.py index 473a6c8ca1b..6c560ceb187 100644 --- a/tests/test_config_paths.py +++ b/tests/config/test_config_paths.py @@ -10,6 +10,7 @@ get_media_dir, get_runtime_subdir, get_workspace_path, + is_default_workspace, ) @@ -40,3 +41,9 @@ def test_shared_and_legacy_paths_remain_global() -> None: def test_workspace_path_is_explicitly_resolved() -> None: assert get_workspace_path() == Path.home() / ".nanobot" / "workspace" assert get_workspace_path("~/custom-workspace") == Path.home() / "custom-workspace" + + +def test_is_default_workspace_distinguishes_default_and_custom_paths() -> None: + assert is_default_workspace(None) is True + assert is_default_workspace(Path.home() / ".nanobot" / "workspace") is True + assert is_default_workspace("~/custom-workspace") is False diff --git a/tests/test_cron_service.py b/tests/cron/test_cron_service.py similarity index 100% rename from tests/test_cron_service.py rename to tests/cron/test_cron_service.py diff --git a/tests/test_cron_tool_list.py b/tests/cron/test_cron_tool_list.py similarity index 62% rename from tests/test_cron_tool_list.py rename to tests/cron/test_cron_tool_list.py index 5d882ad8f2a..22a502fa45b 100644 --- a/tests/test_cron_tool_list.py +++ b/tests/cron/test_cron_tool_list.py @@ -1,5 +1,7 @@ """Tests for CronTool._list_jobs() output formatting.""" +from datetime import datetime, timezone + from nanobot.agent.tools.cron import CronTool from nanobot.cron.service import CronService from nanobot.cron.types import CronJobState, CronSchedule @@ -10,99 +12,120 @@ def _make_tool(tmp_path) -> CronTool: return CronTool(service) +def _make_tool_with_tz(tmp_path, tz: str) -> CronTool: + service = CronService(tmp_path / "cron" / "jobs.json") + return CronTool(service, default_timezone=tz) + + # -- _format_timing tests -- -def test_format_timing_cron_with_tz() -> None: +def test_format_timing_cron_with_tz(tmp_path) -> None: + tool = _make_tool(tmp_path) s = CronSchedule(kind="cron", expr="0 9 * * 1-5", tz="America/Denver") - assert CronTool._format_timing(s) == "cron: 0 9 * * 1-5 (America/Denver)" + assert tool._format_timing(s) == "cron: 0 9 * * 1-5 (America/Denver)" -def test_format_timing_cron_without_tz() -> None: +def test_format_timing_cron_without_tz(tmp_path) -> None: + tool = _make_tool(tmp_path) s = CronSchedule(kind="cron", expr="*/5 * * * *") - assert CronTool._format_timing(s) == "cron: */5 * * * *" + assert tool._format_timing(s) == "cron: */5 * * * *" -def test_format_timing_every_hours() -> None: +def test_format_timing_every_hours(tmp_path) -> None: + tool = _make_tool(tmp_path) s = CronSchedule(kind="every", every_ms=7_200_000) - assert CronTool._format_timing(s) == "every 2h" + assert tool._format_timing(s) == "every 2h" -def test_format_timing_every_minutes() -> None: +def test_format_timing_every_minutes(tmp_path) -> None: + tool = _make_tool(tmp_path) s = CronSchedule(kind="every", every_ms=1_800_000) - assert CronTool._format_timing(s) == "every 30m" + assert tool._format_timing(s) == "every 30m" -def test_format_timing_every_seconds() -> None: +def test_format_timing_every_seconds(tmp_path) -> None: + tool = _make_tool(tmp_path) s = CronSchedule(kind="every", every_ms=30_000) - assert CronTool._format_timing(s) == "every 30s" + assert tool._format_timing(s) == "every 30s" -def test_format_timing_every_non_minute_seconds() -> None: +def test_format_timing_every_non_minute_seconds(tmp_path) -> None: + tool = _make_tool(tmp_path) s = CronSchedule(kind="every", every_ms=90_000) - assert CronTool._format_timing(s) == "every 90s" + assert tool._format_timing(s) == "every 90s" -def test_format_timing_every_milliseconds() -> None: +def test_format_timing_every_milliseconds(tmp_path) -> None: + tool = _make_tool(tmp_path) s = CronSchedule(kind="every", every_ms=200) - assert CronTool._format_timing(s) == "every 200ms" + assert tool._format_timing(s) == "every 200ms" -def test_format_timing_at() -> None: +def test_format_timing_at(tmp_path) -> None: + tool = _make_tool_with_tz(tmp_path, "Asia/Shanghai") s = CronSchedule(kind="at", at_ms=1773684000000) - result = CronTool._format_timing(s) + result = tool._format_timing(s) + assert "Asia/Shanghai" in result assert result.startswith("at 2026-") -def test_format_timing_fallback() -> None: +def test_format_timing_fallback(tmp_path) -> None: + tool = _make_tool(tmp_path) s = CronSchedule(kind="every") # no every_ms - assert CronTool._format_timing(s) == "every" + assert tool._format_timing(s) == "every" # -- _format_state tests -- -def test_format_state_empty() -> None: +def test_format_state_empty(tmp_path) -> None: + tool = _make_tool(tmp_path) state = CronJobState() - assert CronTool._format_state(state) == [] + assert tool._format_state(state, CronSchedule(kind="every")) == [] -def test_format_state_last_run_ok() -> None: +def test_format_state_last_run_ok(tmp_path) -> None: + tool = _make_tool(tmp_path) state = CronJobState(last_run_at_ms=1773673200000, last_status="ok") - lines = CronTool._format_state(state) + lines = tool._format_state(state, CronSchedule(kind="cron", expr="0 9 * * *", tz="UTC")) assert len(lines) == 1 assert "Last run:" in lines[0] assert "ok" in lines[0] -def test_format_state_last_run_with_error() -> None: +def test_format_state_last_run_with_error(tmp_path) -> None: + tool = _make_tool(tmp_path) state = CronJobState(last_run_at_ms=1773673200000, last_status="error", last_error="timeout") - lines = CronTool._format_state(state) + lines = tool._format_state(state, CronSchedule(kind="cron", expr="0 9 * * *", tz="UTC")) assert len(lines) == 1 assert "error" in lines[0] assert "timeout" in lines[0] -def test_format_state_next_run_only() -> None: +def test_format_state_next_run_only(tmp_path) -> None: + tool = _make_tool(tmp_path) state = CronJobState(next_run_at_ms=1773684000000) - lines = CronTool._format_state(state) + lines = tool._format_state(state, CronSchedule(kind="cron", expr="0 9 * * *", tz="UTC")) assert len(lines) == 1 assert "Next run:" in lines[0] -def test_format_state_both() -> None: +def test_format_state_both(tmp_path) -> None: + tool = _make_tool(tmp_path) state = CronJobState( last_run_at_ms=1773673200000, last_status="ok", next_run_at_ms=1773684000000 ) - lines = CronTool._format_state(state) + lines = tool._format_state(state, CronSchedule(kind="cron", expr="0 9 * * *", tz="UTC")) assert len(lines) == 2 assert "Last run:" in lines[0] assert "Next run:" in lines[1] -def test_format_state_unknown_status() -> None: +def test_format_state_unknown_status(tmp_path) -> None: + tool = _make_tool(tmp_path) state = CronJobState(last_run_at_ms=1773673200000, last_status=None) - lines = CronTool._format_state(state) + lines = tool._format_state(state, CronSchedule(kind="cron", expr="0 9 * * *", tz="UTC")) assert "unknown" in lines[0] @@ -181,7 +204,7 @@ def test_list_every_job_milliseconds(tmp_path) -> None: def test_list_at_job_shows_iso_timestamp(tmp_path) -> None: - tool = _make_tool(tmp_path) + tool = _make_tool_with_tz(tmp_path, "Asia/Shanghai") tool._cron.add_job( name="One-shot", schedule=CronSchedule(kind="at", at_ms=1773684000000), @@ -189,6 +212,7 @@ def test_list_at_job_shows_iso_timestamp(tmp_path) -> None: ) result = tool._list_jobs() assert "at 2026-" in result + assert "Asia/Shanghai" in result def test_list_shows_last_run_state(tmp_path) -> None: @@ -206,6 +230,7 @@ def test_list_shows_last_run_state(tmp_path) -> None: result = tool._list_jobs() assert "Last run:" in result assert "ok" in result + assert "(UTC)" in result def test_list_shows_error_message(tmp_path) -> None: @@ -234,6 +259,30 @@ def test_list_shows_next_run(tmp_path) -> None: ) result = tool._list_jobs() assert "Next run:" in result + assert "(UTC)" in result + + +def test_add_cron_job_defaults_to_tool_timezone(tmp_path) -> None: + tool = _make_tool_with_tz(tmp_path, "Asia/Shanghai") + tool.set_context("telegram", "chat-1") + + result = tool._add_job("Morning standup", None, "0 8 * * *", None, None) + + assert result.startswith("Created job") + job = tool._cron.list_jobs()[0] + assert job.schedule.tz == "Asia/Shanghai" + + +def test_add_at_job_uses_default_timezone_for_naive_datetime(tmp_path) -> None: + tool = _make_tool_with_tz(tmp_path, "Asia/Shanghai") + tool.set_context("telegram", "chat-1") + + result = tool._add_job("Morning reminder", None, None, None, "2026-03-25T08:00:00") + + assert result.startswith("Created job") + job = tool._cron.list_jobs()[0] + expected = int(datetime(2026, 3, 25, 0, 0, 0, tzinfo=timezone.utc).timestamp() * 1000) + assert job.schedule.at_ms == expected def test_list_excludes_disabled_jobs(tmp_path) -> None: diff --git a/tests/test_azure_openai_provider.py b/tests/providers/test_azure_openai_provider.py similarity index 100% rename from tests/test_azure_openai_provider.py rename to tests/providers/test_azure_openai_provider.py diff --git a/tests/providers/test_custom_provider.py b/tests/providers/test_custom_provider.py new file mode 100644 index 00000000000..d2a9f42473b --- /dev/null +++ b/tests/providers/test_custom_provider.py @@ -0,0 +1,55 @@ +"""Tests for OpenAICompatProvider handling custom/direct endpoints.""" + +from types import SimpleNamespace +from unittest.mock import patch + +from nanobot.providers.openai_compat_provider import OpenAICompatProvider + + +def test_custom_provider_parse_handles_empty_choices() -> None: + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): + provider = OpenAICompatProvider() + response = SimpleNamespace(choices=[]) + + result = provider._parse(response) + + assert result.finish_reason == "error" + assert "empty choices" in result.content + + +def test_custom_provider_parse_accepts_plain_string_response() -> None: + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): + provider = OpenAICompatProvider() + + result = provider._parse("hello from backend") + + assert result.finish_reason == "stop" + assert result.content == "hello from backend" + + +def test_custom_provider_parse_accepts_dict_response() -> None: + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): + provider = OpenAICompatProvider() + + result = provider._parse({ + "choices": [{ + "message": {"content": "hello from dict"}, + "finish_reason": "stop", + }], + "usage": { + "prompt_tokens": 1, + "completion_tokens": 2, + "total_tokens": 3, + }, + }) + + assert result.finish_reason == "stop" + assert result.content == "hello from dict" + assert result.usage["total_tokens"] == 3 + + +def test_custom_provider_parse_chunks_accepts_plain_text_chunks() -> None: + result = OpenAICompatProvider._parse_chunks(["hello ", "world"]) + + assert result.finish_reason == "stop" + assert result.content == "hello world" diff --git a/tests/providers/test_litellm_kwargs.py b/tests/providers/test_litellm_kwargs.py new file mode 100644 index 00000000000..b166cb0263a --- /dev/null +++ b/tests/providers/test_litellm_kwargs.py @@ -0,0 +1,177 @@ +"""Tests for OpenAICompatProvider spec-driven behavior. + +Validates that: +- OpenRouter (no strip) keeps model names intact. +- AiHubMix (strip_model_prefix=True) strips provider prefixes. +- Standard providers pass model names through as-is. +""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import AsyncMock, patch + +import pytest + +from nanobot.providers.openai_compat_provider import OpenAICompatProvider +from nanobot.providers.registry import find_by_name + + +def _fake_chat_response(content: str = "ok") -> SimpleNamespace: + """Build a minimal OpenAI chat completion response.""" + message = SimpleNamespace( + content=content, + tool_calls=None, + reasoning_content=None, + ) + choice = SimpleNamespace(message=message, finish_reason="stop") + usage = SimpleNamespace(prompt_tokens=10, completion_tokens=5, total_tokens=15) + return SimpleNamespace(choices=[choice], usage=usage) + + +def _fake_tool_call_response() -> SimpleNamespace: + """Build a minimal chat response that includes Gemini-style extra_content.""" + function = SimpleNamespace( + name="exec", + arguments='{"cmd":"ls"}', + provider_specific_fields={"inner": "value"}, + ) + tool_call = SimpleNamespace( + id="call_123", + index=0, + type="function", + function=function, + extra_content={"google": {"thought_signature": "signed-token"}}, + ) + message = SimpleNamespace( + content=None, + tool_calls=[tool_call], + reasoning_content=None, + ) + choice = SimpleNamespace(message=message, finish_reason="tool_calls") + usage = SimpleNamespace(prompt_tokens=10, completion_tokens=5, total_tokens=15) + return SimpleNamespace(choices=[choice], usage=usage) + + +def test_openrouter_spec_is_gateway() -> None: + spec = find_by_name("openrouter") + assert spec is not None + assert spec.is_gateway is True + assert spec.default_api_base == "https://openrouter.ai/api/v1" + + +@pytest.mark.asyncio +async def test_openrouter_keeps_model_name_intact() -> None: + """OpenRouter gateway keeps the full model name (gateway does its own routing).""" + mock_create = AsyncMock(return_value=_fake_chat_response()) + spec = find_by_name("openrouter") + + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient: + client_instance = MockClient.return_value + client_instance.chat.completions.create = mock_create + + provider = OpenAICompatProvider( + api_key="sk-or-test-key", + api_base="https://openrouter.ai/api/v1", + default_model="anthropic/claude-sonnet-4-5", + spec=spec, + ) + await provider.chat( + messages=[{"role": "user", "content": "hello"}], + model="anthropic/claude-sonnet-4-5", + ) + + call_kwargs = mock_create.call_args.kwargs + assert call_kwargs["model"] == "anthropic/claude-sonnet-4-5" + + +@pytest.mark.asyncio +async def test_aihubmix_strips_model_prefix() -> None: + """AiHubMix strips the provider prefix (strip_model_prefix=True).""" + mock_create = AsyncMock(return_value=_fake_chat_response()) + spec = find_by_name("aihubmix") + + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient: + client_instance = MockClient.return_value + client_instance.chat.completions.create = mock_create + + provider = OpenAICompatProvider( + api_key="sk-aihub-test-key", + api_base="https://aihubmix.com/v1", + default_model="claude-sonnet-4-5", + spec=spec, + ) + await provider.chat( + messages=[{"role": "user", "content": "hello"}], + model="anthropic/claude-sonnet-4-5", + ) + + call_kwargs = mock_create.call_args.kwargs + assert call_kwargs["model"] == "claude-sonnet-4-5" + + +@pytest.mark.asyncio +async def test_standard_provider_passes_model_through() -> None: + """Standard provider (e.g. deepseek) passes model name through as-is.""" + mock_create = AsyncMock(return_value=_fake_chat_response()) + spec = find_by_name("deepseek") + + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient: + client_instance = MockClient.return_value + client_instance.chat.completions.create = mock_create + + provider = OpenAICompatProvider( + api_key="sk-deepseek-test-key", + default_model="deepseek-chat", + spec=spec, + ) + await provider.chat( + messages=[{"role": "user", "content": "hello"}], + model="deepseek-chat", + ) + + call_kwargs = mock_create.call_args.kwargs + assert call_kwargs["model"] == "deepseek-chat" + + +@pytest.mark.asyncio +async def test_openai_compat_preserves_extra_content_on_tool_calls() -> None: + """Gemini extra_content (thought signatures) must survive parseβ†’serialize round-trip.""" + mock_create = AsyncMock(return_value=_fake_tool_call_response()) + spec = find_by_name("gemini") + + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient: + client_instance = MockClient.return_value + client_instance.chat.completions.create = mock_create + + provider = OpenAICompatProvider( + api_key="test-key", + api_base="https://generativelanguage.googleapis.com/v1beta/openai/", + default_model="google/gemini-3.1-pro-preview", + spec=spec, + ) + result = await provider.chat( + messages=[{"role": "user", "content": "run exec"}], + model="google/gemini-3.1-pro-preview", + ) + + assert len(result.tool_calls) == 1 + tool_call = result.tool_calls[0] + assert tool_call.extra_content == {"google": {"thought_signature": "signed-token"}} + assert tool_call.function_provider_specific_fields == {"inner": "value"} + + serialized = tool_call.to_openai_tool_call() + assert serialized["extra_content"] == {"google": {"thought_signature": "signed-token"}} + assert serialized["function"]["provider_specific_fields"] == {"inner": "value"} + + +def test_openai_model_passthrough() -> None: + """OpenAI models pass through unchanged.""" + spec = find_by_name("openai") + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): + provider = OpenAICompatProvider( + api_key="sk-test-key", + default_model="gpt-4o", + spec=spec, + ) + assert provider.get_default_model() == "gpt-4o" diff --git a/tests/test_mistral_provider.py b/tests/providers/test_mistral_provider.py similarity index 87% rename from tests/test_mistral_provider.py rename to tests/providers/test_mistral_provider.py index 40112217823..30023afe74a 100644 --- a/tests/test_mistral_provider.py +++ b/tests/providers/test_mistral_provider.py @@ -17,6 +17,4 @@ def test_mistral_provider_in_registry(): mistral = specs["mistral"] assert mistral.env_key == "MISTRAL_API_KEY" - assert mistral.litellm_prefix == "mistral" assert mistral.default_api_base == "https://api.mistral.ai/v1" - assert "mistral/" in mistral.skip_prefixes diff --git a/tests/test_provider_retry.py b/tests/providers/test_provider_retry.py similarity index 100% rename from tests/test_provider_retry.py rename to tests/providers/test_provider_retry.py diff --git a/tests/test_providers_init.py b/tests/providers/test_providers_init.py similarity index 58% rename from tests/test_providers_init.py rename to tests/providers/test_providers_init.py index 02ab7c1efe8..32cbab47883 100644 --- a/tests/test_providers_init.py +++ b/tests/providers/test_providers_init.py @@ -8,19 +8,22 @@ def test_importing_providers_package_is_lazy(monkeypatch) -> None: monkeypatch.delitem(sys.modules, "nanobot.providers", raising=False) - monkeypatch.delitem(sys.modules, "nanobot.providers.litellm_provider", raising=False) + monkeypatch.delitem(sys.modules, "nanobot.providers.anthropic_provider", raising=False) + monkeypatch.delitem(sys.modules, "nanobot.providers.openai_compat_provider", raising=False) monkeypatch.delitem(sys.modules, "nanobot.providers.openai_codex_provider", raising=False) monkeypatch.delitem(sys.modules, "nanobot.providers.azure_openai_provider", raising=False) providers = importlib.import_module("nanobot.providers") - assert "nanobot.providers.litellm_provider" not in sys.modules + assert "nanobot.providers.anthropic_provider" not in sys.modules + assert "nanobot.providers.openai_compat_provider" not in sys.modules assert "nanobot.providers.openai_codex_provider" not in sys.modules assert "nanobot.providers.azure_openai_provider" not in sys.modules assert providers.__all__ == [ "LLMProvider", "LLMResponse", - "LiteLLMProvider", + "AnthropicProvider", + "OpenAICompatProvider", "OpenAICodexProvider", "AzureOpenAIProvider", ] @@ -28,10 +31,10 @@ def test_importing_providers_package_is_lazy(monkeypatch) -> None: def test_explicit_provider_import_still_works(monkeypatch) -> None: monkeypatch.delitem(sys.modules, "nanobot.providers", raising=False) - monkeypatch.delitem(sys.modules, "nanobot.providers.litellm_provider", raising=False) + monkeypatch.delitem(sys.modules, "nanobot.providers.anthropic_provider", raising=False) namespace: dict[str, object] = {} - exec("from nanobot.providers import LiteLLMProvider", namespace) + exec("from nanobot.providers import AnthropicProvider", namespace) - assert namespace["LiteLLMProvider"].__name__ == "LiteLLMProvider" - assert "nanobot.providers.litellm_provider" in sys.modules + assert namespace["AnthropicProvider"].__name__ == "AnthropicProvider" + assert "nanobot.providers.anthropic_provider" in sys.modules diff --git a/tests/test_security_network.py b/tests/security/test_security_network.py similarity index 100% rename from tests/test_security_network.py rename to tests/security/test_security_network.py diff --git a/tests/test_channel_plugins.py b/tests/test_channel_plugins.py deleted file mode 100644 index 3f34dc59885..00000000000 --- a/tests/test_channel_plugins.py +++ /dev/null @@ -1,264 +0,0 @@ -"""Tests for channel plugin discovery, merging, and config compatibility.""" - -from __future__ import annotations - -from types import SimpleNamespace -from unittest.mock import patch - -import pytest - -from nanobot.bus.events import OutboundMessage -from nanobot.bus.queue import MessageBus -from nanobot.channels.base import BaseChannel -from nanobot.channels.manager import ChannelManager -from nanobot.config.schema import ChannelsConfig - - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - -class _FakePlugin(BaseChannel): - name = "fakeplugin" - display_name = "Fake Plugin" - - def __init__(self, config, bus): - super().__init__(config, bus) - self.login_calls: list[bool] = [] - - async def start(self) -> None: - pass - - async def stop(self) -> None: - pass - - async def send(self, msg: OutboundMessage) -> None: - pass - - async def login(self, force: bool = False) -> bool: - self.login_calls.append(force) - return True - - -class _FakeTelegram(BaseChannel): - """Plugin that tries to shadow built-in telegram.""" - name = "telegram" - display_name = "Fake Telegram" - - async def start(self) -> None: - pass - - async def stop(self) -> None: - pass - - async def send(self, msg: OutboundMessage) -> None: - pass - - -def _make_entry_point(name: str, cls: type): - """Create a mock entry point that returns *cls* on load().""" - ep = SimpleNamespace(name=name, load=lambda _cls=cls: _cls) - return ep - - -# --------------------------------------------------------------------------- -# ChannelsConfig extra="allow" -# --------------------------------------------------------------------------- - -def test_channels_config_accepts_unknown_keys(): - cfg = ChannelsConfig.model_validate({ - "myplugin": {"enabled": True, "token": "abc"}, - }) - extra = cfg.model_extra - assert extra is not None - assert extra["myplugin"]["enabled"] is True - assert extra["myplugin"]["token"] == "abc" - - -def test_channels_config_getattr_returns_extra(): - cfg = ChannelsConfig.model_validate({"myplugin": {"enabled": True}}) - section = getattr(cfg, "myplugin", None) - assert isinstance(section, dict) - assert section["enabled"] is True - - -def test_channels_config_builtin_fields_removed(): - """After decoupling, ChannelsConfig has no explicit channel fields.""" - cfg = ChannelsConfig() - assert not hasattr(cfg, "telegram") - assert cfg.send_progress is True - assert cfg.send_tool_hints is False - - -# --------------------------------------------------------------------------- -# discover_plugins -# --------------------------------------------------------------------------- - -_EP_TARGET = "importlib.metadata.entry_points" - - -def test_discover_plugins_loads_entry_points(): - from nanobot.channels.registry import discover_plugins - - ep = _make_entry_point("line", _FakePlugin) - with patch(_EP_TARGET, return_value=[ep]): - result = discover_plugins() - - assert "line" in result - assert result["line"] is _FakePlugin - - -def test_discover_plugins_handles_load_error(): - from nanobot.channels.registry import discover_plugins - - def _boom(): - raise RuntimeError("broken") - - ep = SimpleNamespace(name="broken", load=_boom) - with patch(_EP_TARGET, return_value=[ep]): - result = discover_plugins() - - assert "broken" not in result - - -# --------------------------------------------------------------------------- -# discover_all β€” merge & priority -# --------------------------------------------------------------------------- - -def test_discover_all_includes_builtins(): - from nanobot.channels.registry import discover_all, discover_channel_names - - with patch(_EP_TARGET, return_value=[]): - result = discover_all() - - # discover_all() only returns channels that are actually available (dependencies installed) - # discover_channel_names() returns all built-in channel names - # So we check that all actually loaded channels are in the result - for name in result: - assert name in discover_channel_names() - - -def test_discover_all_includes_external_plugin(): - from nanobot.channels.registry import discover_all - - ep = _make_entry_point("line", _FakePlugin) - with patch(_EP_TARGET, return_value=[ep]): - result = discover_all() - - assert "line" in result - assert result["line"] is _FakePlugin - - -def test_discover_all_builtin_shadows_plugin(): - from nanobot.channels.registry import discover_all - - ep = _make_entry_point("telegram", _FakeTelegram) - with patch(_EP_TARGET, return_value=[ep]): - result = discover_all() - - assert "telegram" in result - assert result["telegram"] is not _FakeTelegram - - -# --------------------------------------------------------------------------- -# Manager _init_channels with dict config (plugin scenario) -# --------------------------------------------------------------------------- - -@pytest.mark.asyncio -async def test_manager_loads_plugin_from_dict_config(): - """ChannelManager should instantiate a plugin channel from a raw dict config.""" - from nanobot.channels.manager import ChannelManager - - fake_config = SimpleNamespace( - channels=ChannelsConfig.model_validate({ - "fakeplugin": {"enabled": True, "allowFrom": ["*"]}, - }), - providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), - ) - - with patch( - "nanobot.channels.registry.discover_all", - return_value={"fakeplugin": _FakePlugin}, - ): - mgr = ChannelManager.__new__(ChannelManager) - mgr.config = fake_config - mgr.bus = MessageBus() - mgr.channels = {} - mgr._dispatch_task = None - mgr._init_channels() - - assert "fakeplugin" in mgr.channels - assert isinstance(mgr.channels["fakeplugin"], _FakePlugin) - - -def test_channels_login_uses_discovered_plugin_class(monkeypatch): - from nanobot.cli.commands import app - from nanobot.config.schema import Config - from typer.testing import CliRunner - - runner = CliRunner() - seen: dict[str, object] = {} - - class _LoginPlugin(_FakePlugin): - display_name = "Login Plugin" - - async def login(self, force: bool = False) -> bool: - seen["force"] = force - seen["config"] = self.config - return True - - monkeypatch.setattr("nanobot.config.loader.load_config", lambda: Config()) - monkeypatch.setattr( - "nanobot.channels.registry.discover_all", - lambda: {"fakeplugin": _LoginPlugin}, - ) - - result = runner.invoke(app, ["channels", "login", "fakeplugin", "--force"]) - - assert result.exit_code == 0 - assert seen["force"] is True - - -@pytest.mark.asyncio -async def test_manager_skips_disabled_plugin(): - fake_config = SimpleNamespace( - channels=ChannelsConfig.model_validate({ - "fakeplugin": {"enabled": False}, - }), - providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), - ) - - with patch( - "nanobot.channels.registry.discover_all", - return_value={"fakeplugin": _FakePlugin}, - ): - mgr = ChannelManager.__new__(ChannelManager) - mgr.config = fake_config - mgr.bus = MessageBus() - mgr.channels = {} - mgr._dispatch_task = None - mgr._init_channels() - - assert "fakeplugin" not in mgr.channels - - -# --------------------------------------------------------------------------- -# Built-in channel default_config() and dict->Pydantic conversion -# --------------------------------------------------------------------------- - -def test_builtin_channel_default_config(): - """Built-in channels expose default_config() returning a dict with 'enabled': False.""" - from nanobot.channels.telegram import TelegramChannel - cfg = TelegramChannel.default_config() - assert isinstance(cfg, dict) - assert cfg["enabled"] is False - assert "token" in cfg - - -def test_builtin_channel_init_from_dict(): - """Built-in channels accept a raw dict and convert to Pydantic internally.""" - from nanobot.channels.telegram import TelegramChannel - bus = MessageBus() - ch = TelegramChannel({"enabled": False, "token": "test-tok", "allowFrom": ["*"]}, bus) - assert ch.config.token == "test-tok" - assert ch.config.allow_from == ["*"] diff --git a/tests/test_custom_provider.py b/tests/test_custom_provider.py deleted file mode 100644 index 463affedc39..00000000000 --- a/tests/test_custom_provider.py +++ /dev/null @@ -1,13 +0,0 @@ -from types import SimpleNamespace - -from nanobot.providers.custom_provider import CustomProvider - - -def test_custom_provider_parse_handles_empty_choices() -> None: - provider = CustomProvider() - response = SimpleNamespace(choices=[]) - - result = provider._parse(response) - - assert result.finish_reason == "error" - assert "empty choices" in result.content diff --git a/tests/test_litellm_kwargs.py b/tests/test_litellm_kwargs.py deleted file mode 100644 index 437f8a55562..00000000000 --- a/tests/test_litellm_kwargs.py +++ /dev/null @@ -1,161 +0,0 @@ -"""Regression tests for PR #2026 β€” litellm_kwargs injection from ProviderSpec. - -Validates that: -- OpenRouter uses litellm_prefix (NOT custom_llm_provider) to avoid LiteLLM double-prefixing. -- The litellm_kwargs mechanism works correctly for providers that declare it. -- Non-gateway providers are unaffected. -""" - -from __future__ import annotations - -from types import SimpleNamespace -from typing import Any -from unittest.mock import AsyncMock, patch - -import pytest - -from nanobot.providers.litellm_provider import LiteLLMProvider -from nanobot.providers.registry import find_by_name - - -def _fake_response(content: str = "ok") -> SimpleNamespace: - """Build a minimal acompletion-shaped response object.""" - message = SimpleNamespace( - content=content, - tool_calls=None, - reasoning_content=None, - thinking_blocks=None, - ) - choice = SimpleNamespace(message=message, finish_reason="stop") - usage = SimpleNamespace(prompt_tokens=10, completion_tokens=5, total_tokens=15) - return SimpleNamespace(choices=[choice], usage=usage) - - -def test_openrouter_spec_uses_prefix_not_custom_llm_provider() -> None: - """OpenRouter must rely on litellm_prefix, not custom_llm_provider kwarg. - - LiteLLM internally adds a provider/ prefix when custom_llm_provider is set, - which double-prefixes models (openrouter/anthropic/model) and breaks the API. - """ - spec = find_by_name("openrouter") - assert spec is not None - assert spec.litellm_prefix == "openrouter" - assert "custom_llm_provider" not in spec.litellm_kwargs, ( - "custom_llm_provider causes LiteLLM to double-prefix the model name" - ) - - -@pytest.mark.asyncio -async def test_openrouter_prefixes_model_correctly() -> None: - """OpenRouter should prefix model as openrouter/vendor/model for LiteLLM routing.""" - mock_acompletion = AsyncMock(return_value=_fake_response()) - - with patch("nanobot.providers.litellm_provider.acompletion", mock_acompletion): - provider = LiteLLMProvider( - api_key="sk-or-test-key", - api_base="https://openrouter.ai/api/v1", - default_model="anthropic/claude-sonnet-4-5", - provider_name="openrouter", - ) - await provider.chat( - messages=[{"role": "user", "content": "hello"}], - model="anthropic/claude-sonnet-4-5", - ) - - call_kwargs = mock_acompletion.call_args.kwargs - assert call_kwargs["model"] == "openrouter/anthropic/claude-sonnet-4-5", ( - "LiteLLM needs openrouter/ prefix to detect the provider and strip it before API call" - ) - assert "custom_llm_provider" not in call_kwargs - - -@pytest.mark.asyncio -async def test_non_gateway_provider_no_extra_kwargs() -> None: - """Standard (non-gateway) providers must NOT inject any litellm_kwargs.""" - mock_acompletion = AsyncMock(return_value=_fake_response()) - - with patch("nanobot.providers.litellm_provider.acompletion", mock_acompletion): - provider = LiteLLMProvider( - api_key="sk-ant-test-key", - default_model="claude-sonnet-4-5", - ) - await provider.chat( - messages=[{"role": "user", "content": "hello"}], - model="claude-sonnet-4-5", - ) - - call_kwargs = mock_acompletion.call_args.kwargs - assert "custom_llm_provider" not in call_kwargs, ( - "Standard Anthropic provider should NOT inject custom_llm_provider" - ) - - -@pytest.mark.asyncio -async def test_gateway_without_litellm_kwargs_injects_nothing_extra() -> None: - """Gateways without litellm_kwargs (e.g. AiHubMix) must not add extra keys.""" - mock_acompletion = AsyncMock(return_value=_fake_response()) - - with patch("nanobot.providers.litellm_provider.acompletion", mock_acompletion): - provider = LiteLLMProvider( - api_key="sk-aihub-test-key", - api_base="https://aihubmix.com/v1", - default_model="claude-sonnet-4-5", - provider_name="aihubmix", - ) - await provider.chat( - messages=[{"role": "user", "content": "hello"}], - model="claude-sonnet-4-5", - ) - - call_kwargs = mock_acompletion.call_args.kwargs - assert "custom_llm_provider" not in call_kwargs - - -@pytest.mark.asyncio -async def test_openrouter_autodetect_by_key_prefix() -> None: - """OpenRouter should be auto-detected by sk-or- key prefix even without explicit provider_name.""" - mock_acompletion = AsyncMock(return_value=_fake_response()) - - with patch("nanobot.providers.litellm_provider.acompletion", mock_acompletion): - provider = LiteLLMProvider( - api_key="sk-or-auto-detect-key", - default_model="anthropic/claude-sonnet-4-5", - ) - await provider.chat( - messages=[{"role": "user", "content": "hello"}], - model="anthropic/claude-sonnet-4-5", - ) - - call_kwargs = mock_acompletion.call_args.kwargs - assert call_kwargs["model"] == "openrouter/anthropic/claude-sonnet-4-5", ( - "Auto-detected OpenRouter should prefix model for LiteLLM routing" - ) - - -@pytest.mark.asyncio -async def test_openrouter_native_model_id_gets_double_prefixed() -> None: - """Models like openrouter/free must be double-prefixed so LiteLLM strips one layer. - - openrouter/free is an actual OpenRouter model ID. LiteLLM strips the first - openrouter/ for routing, so we must send openrouter/openrouter/free to ensure - the API receives openrouter/free. - """ - mock_acompletion = AsyncMock(return_value=_fake_response()) - - with patch("nanobot.providers.litellm_provider.acompletion", mock_acompletion): - provider = LiteLLMProvider( - api_key="sk-or-test-key", - api_base="https://openrouter.ai/api/v1", - default_model="openrouter/free", - provider_name="openrouter", - ) - await provider.chat( - messages=[{"role": "user", "content": "hello"}], - model="openrouter/free", - ) - - call_kwargs = mock_acompletion.call_args.kwargs - assert call_kwargs["model"] == "openrouter/openrouter/free", ( - "openrouter/free must become openrouter/openrouter/free β€” " - "LiteLLM strips one layer so the API receives openrouter/free" - ) diff --git a/tests/test_weixin_channel.py b/tests/test_weixin_channel.py deleted file mode 100644 index a16c6b7509d..00000000000 --- a/tests/test_weixin_channel.py +++ /dev/null @@ -1,127 +0,0 @@ -import asyncio -from unittest.mock import AsyncMock - -import pytest - -from nanobot.bus.queue import MessageBus -from nanobot.channels.weixin import ( - ITEM_IMAGE, - ITEM_TEXT, - MESSAGE_TYPE_BOT, - WeixinChannel, - WeixinConfig, -) - - -def _make_channel() -> tuple[WeixinChannel, MessageBus]: - bus = MessageBus() - channel = WeixinChannel( - WeixinConfig(enabled=True, allow_from=["*"]), - bus, - ) - return channel, bus - - -@pytest.mark.asyncio -async def test_process_message_deduplicates_inbound_ids() -> None: - channel, bus = _make_channel() - msg = { - "message_type": 1, - "message_id": "m1", - "from_user_id": "wx-user", - "context_token": "ctx-1", - "item_list": [ - {"type": ITEM_TEXT, "text_item": {"text": "hello"}}, - ], - } - - await channel._process_message(msg) - first = await asyncio.wait_for(bus.consume_inbound(), timeout=1.0) - await channel._process_message(msg) - - assert first.sender_id == "wx-user" - assert first.chat_id == "wx-user" - assert first.content == "hello" - assert bus.inbound_size == 0 - - -@pytest.mark.asyncio -async def test_process_message_caches_context_token_and_send_uses_it() -> None: - channel, _bus = _make_channel() - channel._client = object() - channel._token = "token" - channel._send_text = AsyncMock() - - await channel._process_message( - { - "message_type": 1, - "message_id": "m2", - "from_user_id": "wx-user", - "context_token": "ctx-2", - "item_list": [ - {"type": ITEM_TEXT, "text_item": {"text": "ping"}}, - ], - } - ) - - await channel.send( - type("Msg", (), {"chat_id": "wx-user", "content": "pong", "media": [], "metadata": {}})() - ) - - channel._send_text.assert_awaited_once_with("wx-user", "pong", "ctx-2") - - -@pytest.mark.asyncio -async def test_process_message_extracts_media_and_preserves_paths() -> None: - channel, bus = _make_channel() - channel._download_media_item = AsyncMock(return_value="/tmp/test.jpg") - - await channel._process_message( - { - "message_type": 1, - "message_id": "m3", - "from_user_id": "wx-user", - "context_token": "ctx-3", - "item_list": [ - {"type": ITEM_IMAGE, "image_item": {"media": {"encrypt_query_param": "x"}}}, - ], - } - ) - - inbound = await asyncio.wait_for(bus.consume_inbound(), timeout=1.0) - - assert "[image]" in inbound.content - assert "/tmp/test.jpg" in inbound.content - assert inbound.media == ["/tmp/test.jpg"] - - -@pytest.mark.asyncio -async def test_send_without_context_token_does_not_send_text() -> None: - channel, _bus = _make_channel() - channel._client = object() - channel._token = "token" - channel._send_text = AsyncMock() - - await channel.send( - type("Msg", (), {"chat_id": "unknown-user", "content": "pong", "media": [], "metadata": {}})() - ) - - channel._send_text.assert_not_awaited() - - -@pytest.mark.asyncio -async def test_process_message_skips_bot_messages() -> None: - channel, bus = _make_channel() - - await channel._process_message( - { - "message_type": MESSAGE_TYPE_BOT, - "message_id": "m4", - "from_user_id": "wx-user", - "item_list": [ - {"type": ITEM_TEXT, "text_item": {"text": "hello"}}, - ], - } - ) - - assert bus.inbound_size == 0 diff --git a/tests/test_exec_security.py b/tests/tools/test_exec_security.py similarity index 100% rename from tests/test_exec_security.py rename to tests/tools/test_exec_security.py diff --git a/tests/test_filesystem_tools.py b/tests/tools/test_filesystem_tools.py similarity index 95% rename from tests/test_filesystem_tools.py rename to tests/tools/test_filesystem_tools.py index 76d0a512481..ca6629edbb5 100644 --- a/tests/test_filesystem_tools.py +++ b/tests/tools/test_filesystem_tools.py @@ -77,6 +77,11 @@ async def test_file_not_found(self, tool, tmp_path): assert "Error" in result assert "not found" in result + @pytest.mark.asyncio + async def test_missing_path_returns_clear_error(self, tool): + result = await tool.execute() + assert result == "Error reading file: Unknown path" + @pytest.mark.asyncio async def test_char_budget_trims(self, tool, tmp_path): """When the selected slice exceeds _MAX_CHARS the output is trimmed.""" @@ -200,6 +205,13 @@ async def test_not_found(self, tool, tmp_path): assert "Error" in result assert "not found" in result + @pytest.mark.asyncio + async def test_missing_new_text_returns_clear_error(self, tool, tmp_path): + f = tmp_path / "a.py" + f.write_text("hello", encoding="utf-8") + result = await tool.execute(path=str(f), old_text="hello") + assert result == "Error editing file: Unknown new_text" + # --------------------------------------------------------------------------- # ListDirTool @@ -265,6 +277,11 @@ async def test_not_found(self, tool, tmp_path): assert "Error" in result assert "not found" in result + @pytest.mark.asyncio + async def test_missing_path_returns_clear_error(self, tool): + result = await tool.execute() + assert result == "Error listing directory: Unknown path" + # --------------------------------------------------------------------------- # Workspace restriction + extra_allowed_dirs diff --git a/tests/test_mcp_tool.py b/tests/tools/test_mcp_tool.py similarity index 100% rename from tests/test_mcp_tool.py rename to tests/tools/test_mcp_tool.py diff --git a/tests/test_message_tool.py b/tests/tools/test_message_tool.py similarity index 100% rename from tests/test_message_tool.py rename to tests/tools/test_message_tool.py diff --git a/tests/test_message_tool_suppress.py b/tests/tools/test_message_tool_suppress.py similarity index 100% rename from tests/test_message_tool_suppress.py rename to tests/tools/test_message_tool_suppress.py diff --git a/tests/test_tool_validation.py b/tests/tools/test_tool_validation.py similarity index 100% rename from tests/test_tool_validation.py rename to tests/tools/test_tool_validation.py diff --git a/tests/test_web_fetch_security.py b/tests/tools/test_web_fetch_security.py similarity index 100% rename from tests/test_web_fetch_security.py rename to tests/tools/test_web_fetch_security.py diff --git a/tests/test_web_search_tool.py b/tests/tools/test_web_search_tool.py similarity index 100% rename from tests/test_web_search_tool.py rename to tests/tools/test_web_search_tool.py