From 36807b1b0f75bd6c199ad85d5dcd50885d68b0b6 Mon Sep 17 00:00:00 2001 From: Frost Ming Date: Fri, 24 Apr 2026 12:18:30 +0800 Subject: [PATCH 1/3] feat: new bub command: onboard Signed-off-by: Frost Ming --- README.md | 20 +- env.example | 6 +- pyproject.toml | 1 + src/bub/builtin/cli.py | 31 ++ src/bub/builtin/hook_impl.py | 134 ++++++++- src/bub/builtin/settings.py | 2 +- src/bub/configure.py | 51 ++++ src/bub/framework.py | 24 +- src/bub/hookspecs.py | 14 +- tests/test_builtin_cli.py | 276 +++++++++++++++++- tests/test_configure.py | 77 +++++ tests/test_framework.py | 81 ++++- tests/test_settings.py | 4 +- uv.lock | 14 + .../src/content/docs/docs/extending/hooks.mdx | 36 +++ .../content/docs/docs/guides/deployment.mdx | 2 +- .../docs/zh-cn/docs/extending/hooks.mdx | 36 +++ .../docs/zh-cn/docs/guides/deployment.mdx | 2 +- 18 files changed, 768 insertions(+), 43 deletions(-) create mode 100644 tests/test_configure.py diff --git a/README.md b/README.md index 30d186cd..cec14a14 100644 --- a/README.md +++ b/README.md @@ -106,16 +106,16 @@ Lines starting with `,` enter internal command mode (`,help`, `,skill name=my-sk ## Configuration -| Variable | Default | Description | -|----------|---------|-------------| -| `BUB_MODEL` | `openrouter:qwen/qwen3-coder-next` | Model identifier | -| `BUB_API_KEY` | — | Provider key (optional with `bub login openai`) | -| `BUB_API_BASE` | — | Custom provider endpoint | -| `BUB_API_FORMAT` | `completion` | `completion`, `responses`, or `messages` | -| `BUB_CLIENT_ARGS` | — | JSON object forwarded to the underlying model client | -| `BUB_MAX_STEPS` | `50` | Max tool-use loop iterations | -| `BUB_MAX_TOKENS` | `1024` | Max tokens per model call | -| `BUB_MODEL_TIMEOUT_SECONDS` | — | Model call timeout (seconds) | +| Variable | Default | Description | +| --------------------------- | ---------------------------- | ---------------------------------------------------- | +| `BUB_MODEL` | `openrouter:openrouter/free` | Model identifier | +| `BUB_API_KEY` | — | Provider key (optional with `bub login openai`) | +| `BUB_API_BASE` | — | Custom provider endpoint | +| `BUB_API_FORMAT` | `completion` | `completion`, `responses`, or `messages` | +| `BUB_CLIENT_ARGS` | — | JSON object forwarded to the underlying model client | +| `BUB_MAX_STEPS` | `50` | Max tool-use loop iterations | +| `BUB_MAX_TOKENS` | `1024` | Max tokens per model call | +| `BUB_MODEL_TIMEOUT_SECONDS` | — | Model call timeout (seconds) | ## Background diff --git a/env.example b/env.example index 54ad7b23..32366cc1 100644 --- a/env.example +++ b/env.example @@ -5,8 +5,8 @@ # Agent runtime # --------------------------------------------------------------------------- # Republic model format: provider:model_id -# Default in code is `openrouter:qwen/qwen3-coder-next`. -# BUB_MODEL=openrouter:qwen/qwen3-coder-next +# Default in code is `openrouter:openrouter/free`. +# BUB_MODEL=openrouter:openrouter/free # BUB_MAX_STEPS=50 # BUB_MAX_TOKENS=1024 # BUB_MODEL_TIMEOUT_SECONDS=300 @@ -58,6 +58,6 @@ # --------------------------------------------------------------------------- # Example minimal OpenRouter setup # --------------------------------------------------------------------------- -# BUB_MODEL=openrouter:qwen/qwen3-coder-next +# BUB_MODEL=openrouter:openrouter/free # BUB_API_KEY=sk-or-... # BUB_CLIENT_ARGS={"extra_headers":{"HTTP-Referer":"https://openclaw.ai","X-Title":"OpenClaw"}} diff --git a/pyproject.toml b/pyproject.toml index 425c3deb..9323216f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ dependencies = [ "pydantic-settings>=2.0.0", "pyyaml>=6.0.0", "pluggy>=1.6.0", + "questionary>=2.1.0", "typer>=0.9.0", "republic>=0.5.4", "any-llm-sdk[anthropic]", diff --git a/src/bub/builtin/cli.py b/src/bub/builtin/cli.py index df6c405d..65b42813 100644 --- a/src/bub/builtin/cli.py +++ b/src/bub/builtin/cli.py @@ -16,11 +16,23 @@ import typer +from bub import __version__, configure from bub.builtin.auth import app as login_app # noqa: F401 from bub.channels.message import ChannelMessage from bub.envelope import field_of from bub.framework import BubFramework +ONBOARD_BANNER = r""" + ███████████ █████ +▒▒███▒▒▒▒▒███ ▒▒███ + ▒███ ▒███ █████ ████ ▒███████ + ▒██████████ ▒▒███ ▒███ ▒███▒▒███ + ▒███▒▒▒▒▒███ ▒███ ▒███ ▒███ ▒███ + ▒███ ▒███ ▒███ ▒███ ▒███ ▒███ + ███████████ ▒▒████████ ████████ +▒▒▒▒▒▒▒▒▒▒▒ ▒▒▒▒▒▒▒▒ ▒▒▒▒▒▒▒▒ v{version} +""".strip("\n") + def run( ctx: typer.Context, @@ -92,6 +104,25 @@ def chat( asyncio.run(manager.listen_and_run()) +def onboard(ctx: typer.Context) -> None: + """Interactively collect plugin configuration and write it to Bub's config file.""" + + framework = ctx.ensure_object(BubFramework) + typer.echo(ONBOARD_BANNER.format(version=__version__)) + typer.echo("\nWelcome to Bub! Let's get you set up.\n") + + try: + config_data = framework.collect_onboard_config() + configure.save(framework.config_file, config_data) + except (typer.Abort, typer.Exit): + raise + except Exception as exc: + typer.secho(f"Onboarding failed: {exc}", err=True, fg="red") + raise typer.Exit(1) from exc + + typer.echo(f"Saved config to {framework.config_file}") + + @lru_cache(maxsize=1) def _find_uv() -> str: import shutil diff --git a/src/bub/builtin/hook_impl.py b/src/bub/builtin/hook_impl.py index 6c717d6d..5c0deb54 100644 --- a/src/bub/builtin/hook_impl.py +++ b/src/bub/builtin/hook_impl.py @@ -1,8 +1,9 @@ import sys from datetime import UTC, datetime from pathlib import Path -from typing import cast +from typing import Any, cast +import questionary import typer from loguru import logger from republic import AsyncStreamEvents, TapeContext @@ -10,6 +11,7 @@ from bub.builtin.agent import Agent from bub.builtin.context import default_tape_context +from bub.builtin.settings import DEFAULT_MODEL from bub.channels.base import Channel from bub.channels.message import ChannelMessage, MediaItem from bub.envelope import content_of, field_of @@ -18,6 +20,19 @@ from bub.types import Envelope, MessageHandler, State AGENTS_FILE_NAME = "AGENTS.md" +MODEL_PROVIDER_CHOICES: tuple[str, ...] = ( + "openrouter", + "openai", + "anthropic", + "gemini", + "azure", + "bedrock", + "ollama", + "groq", + "mistral", + "deepseek", +) +API_FORMAT_CHOICES: tuple[str, ...] = ("completion", "responses", "messages") DEFAULT_SYSTEM_PROMPT = """\ Call tools or skills to finish the task. @@ -55,6 +70,49 @@ def _get_agent(self) -> Agent: self._agent = Agent(self.framework) return self._agent + @staticmethod + def _ask_onboard_value(question: Any) -> Any: + answer = question.ask() + if answer is None: + raise typer.Abort() + return answer + + @classmethod + def _ask_onboard(cls, question: Any) -> str: + answer = cls._ask_onboard_value(question) + return str(answer).strip() + + @staticmethod + async def _discard_message(_: ChannelMessage) -> None: + return + + @staticmethod + def _split_model_identifier(model: str) -> tuple[str, str]: + provider, separator, model_name = model.partition(":") + if separator and provider and model_name: + return provider.strip(), model_name.strip() + default_provider, _, default_model_name = DEFAULT_MODEL.partition(":") + fallback_model_name = model.strip() or default_model_name + return default_provider, fallback_model_name + + @staticmethod + def _provider_choices(current_provider: str) -> list[str]: + choices = list(MODEL_PROVIDER_CHOICES) + if current_provider and current_provider not in choices: + choices.append(current_provider) + choices.append("custom") + return choices + + def _channel_choices(self) -> list[str]: + return [c for c in self.framework.get_channels(self._discard_message) if c != "cli"] + + @staticmethod + def _default_enabled_channels(current_value: object, available_channels: list[str]) -> list[str]: + if isinstance(current_value, str) and current_value.strip() and current_value.strip().lower() != "all": + selected = [name.strip() for name in current_value.split(",") if name.strip() in available_channels] + return selected + return available_channels + @hookimpl def resolve_session(self, message: ChannelMessage) -> str: session_id = field_of(message, "session_id") @@ -124,6 +182,7 @@ def register_cli_commands(self, app: typer.Typer) -> None: app.command("run")(cli.run) app.command("chat")(cli.chat) + app.command("onboard")(cli.onboard) app.add_typer(cli.login_app) app.command("hooks", hidden=True)(cli.list_hooks) app.command("gateway")(cli.gateway) @@ -131,6 +190,79 @@ def register_cli_commands(self, app: typer.Typer) -> None: app.command("uninstall")(cli.uninstall) app.command("update")(cli.update) + @hookimpl + def onboard_config(self, current_config: dict[str, object]) -> dict[str, object] | None: + current_model = current_config.get("model") + model_default = str(current_model) if isinstance(current_model, str) and current_model else DEFAULT_MODEL + provider_default, model_name_default = self._split_model_identifier(model_default) + + provider = self._ask_onboard( + questionary.autocomplete( + "LLM provider", + choices=self._provider_choices(provider_default), + match_middle=True, + default=provider_default, + ) + ) + if provider == "custom": + provider = ( + self._ask_onboard(questionary.text("Custom provider", default=provider_default)) or provider_default + ) + + model_name = self._ask_onboard(questionary.text("LLM model", default=model_name_default)) + if not model_name: + model_name = model_name_default + model = f"{provider}:{model_name}" + + current_api_key = current_config.get("api_key") + api_key_default = str(current_api_key) if isinstance(current_api_key, str) else "" + api_key = self._ask_onboard(questionary.password("API key (optional)", default=api_key_default)) + + current_api_base = current_config.get("api_base") + api_base_default = str(current_api_base) if isinstance(current_api_base, str) else "" + api_base = self._ask_onboard(questionary.text("API base (optional)", default=api_base_default)) + + current_api_format = current_config.get("api_format") + api_format_default = ( + str(current_api_format) + if isinstance(current_api_format, str) and current_api_format in API_FORMAT_CHOICES + else API_FORMAT_CHOICES[0] + ) + api_format = self._ask_onboard( + questionary.select("API format", choices=list(API_FORMAT_CHOICES), default=api_format_default) + ) + + available_channels = self._channel_choices() + default_channels = self._default_enabled_channels(current_config.get("enabled_channels"), available_channels) + enabled_channels = cast( + "list[str]", + self._ask_onboard_value( + questionary.checkbox( + "Channels", + choices=[questionary.Choice(name, checked=name in default_channels) for name in available_channels], + validate=lambda values: True if values else "Select at least one channel.", + ) + ), + ) + + stream_output = cast( + "bool", + self._ask_onboard_value( + questionary.confirm("Stream output", default=bool(current_config.get("stream_output"))) + ), + ) + config: dict[str, object] = { + "model": model, + "api_format": api_format, + "enabled_channels": ",".join(enabled_channels), + "stream_output": stream_output, + } + if api_key: + config["api_key"] = api_key + if api_base: + config["api_base"] = api_base + return config + def _read_agents_file(self, state: State) -> str: workspace = state.get("_runtime_workspace", str(Path.cwd())) prompt_path = Path(workspace) / AGENTS_FILE_NAME diff --git a/src/bub/builtin/settings.py b/src/bub/builtin/settings.py index 878b4bac..8eefca93 100644 --- a/src/bub/builtin/settings.py +++ b/src/bub/builtin/settings.py @@ -11,7 +11,7 @@ from bub import Settings, config, ensure_config -DEFAULT_MODEL = "openrouter:qwen/qwen3-coder-next" +DEFAULT_MODEL = "openrouter:openrouter/free" DEFAULT_MAX_TOKENS = 1024 diff --git a/src/bub/configure.py b/src/bub/configure.py index 65994c33..0d0d6b17 100644 --- a/src/bub/configure.py +++ b/src/bub/configure.py @@ -42,6 +42,7 @@ def load(config_file: Path) -> dict[str, Any]: """Load config from a file.""" import yaml + _global_config.clear() _config_data.clear() if config_file.exists(): with config_file.open() as f: @@ -49,6 +50,34 @@ def load(config_file: Path) -> dict[str, Any]: return _config_data +def merge(base: dict[str, Any], *updates: dict[str, Any]) -> dict[str, Any]: + """Update base in place with config updates, preferring incoming values on conflict.""" + + for update in updates: + _merge_into(base, update, path=()) + return base + + +def validate(config_data: dict[str, Any]) -> dict[str, Any]: + """Validate config data against all registered config classes.""" + + for section, config_classes in CONFIG_MAP.items(): + section_data = config_data if section == ROOT else config_data.get(section, {}) + for config_cls in config_classes: + config_cls.model_validate(section_data) + return config_data + + +def save(config_file: Path, config_data: dict[str, Any]) -> None: + """Validate and persist config data to a YAML file.""" + import yaml + + validated = validate(config_data) + config_file.parent.mkdir(parents=True, exist_ok=True) + with config_file.open("w", encoding="utf-8") as f: + yaml.safe_dump(validated, f, sort_keys=False) + + def ensure_config[C: BaseSettings](config_cls: type[C]) -> C: """No-op function to ensure a config class is registered and can be imported.""" section = getattr(config_cls, "__config_name__", ROOT) @@ -64,3 +93,25 @@ def ensure_config[C: BaseSettings](config_cls: type[C]) -> C: instance = config_cls.model_validate(section_data) instances.append(instance) return instance + + +def _copy_dict(data: dict[str, Any]) -> dict[str, Any]: + copied: dict[str, Any] = {} + for key, value in data.items(): + if isinstance(value, dict): + copied[key] = _copy_dict(value) + else: + copied[key] = value + return copied + + +def _merge_into(target: dict[str, Any], incoming: dict[str, Any], path: tuple[str, ...]) -> None: + for key, value in incoming.items(): + existing = target.get(key) + if key not in target: + target[key] = _copy_dict(value) if isinstance(value, dict) else value + continue + if isinstance(existing, dict) and isinstance(value, dict): + _merge_into(existing, value, path=(*path, key)) + continue + target[key] = _copy_dict(value) if isinstance(value, dict) else value diff --git a/src/bub/framework.py b/src/bub/framework.py index a89fdd1c..8573fea3 100644 --- a/src/bub/framework.py +++ b/src/bub/framework.py @@ -16,7 +16,7 @@ from bub import configure from bub.envelope import content_of, field_of, unpack_batch -from bub.hook_runtime import HookRuntime +from bub.hook_runtime import _SKIP_VALUE, HookRuntime from bub.hookspecs import BUB_HOOK_NAMESPACE, BubHookSpecs from bub.types import Envelope, MessageHandler, OutboundChannelRouter, TurnResult @@ -40,12 +40,13 @@ class BubFramework: def __init__(self, config_file: Path = DEFAULT_CONFIG_FILE) -> None: self.workspace = Path.cwd().resolve() + self.config_file = config_file.resolve() self._plugin_manager = pluggy.PluginManager(BUB_HOOK_NAMESPACE) self._plugin_manager.add_hookspecs(BubHookSpecs) self._hook_runtime = HookRuntime(self._plugin_manager) self._plugin_status: dict[str, PluginStatus] = {} self._outbound_router: OutboundChannelRouter | None = None - configure.load(config_file) + configure.load(self.config_file) def _load_builtin_hooks(self) -> None: from bub.builtin.hook_impl import BuiltinImpl @@ -264,3 +265,22 @@ def get_system_prompt(self, prompt: str | list[dict], state: dict[str, Any]) -> def build_tape_context(self) -> TapeContext: return self._hook_runtime.call_first_sync("build_tape_context") + + def collect_onboard_config(self) -> dict[str, Any]: + current_config: dict[str, Any] = {} + + for impl in self._hook_runtime._iter_hookimpls("onboard_config"): + result = self._hook_runtime._invoke_impl_sync( + hook_name="onboard_config", + impl=impl, + call_kwargs={"current_config": current_config}, + kwargs={"current_config": current_config}, + ) + if result is _SKIP_VALUE: + continue + if result is None: + continue + if not isinstance(result, dict): + raise TypeError("hook.onboard_config must return dict or None") + configure.merge(current_config, result) + return configure.validate(current_config) diff --git a/src/bub/hookspecs.py b/src/bub/hookspecs.py index 47da2b66..000aa9c5 100644 --- a/src/bub/hookspecs.py +++ b/src/bub/hookspecs.py @@ -26,11 +26,6 @@ def resolve_session(self, message: Envelope) -> str: """Resolve session id for one inbound message.""" raise NotImplementedError - @hookspec(firstresult=True) - def load_state(self, message: Envelope, session_id: str) -> State: - """Load state snapshot for one session.""" - raise NotImplementedError - @hookspec(firstresult=True) def build_prompt(self, message: Envelope, session_id: str, state: State) -> str | list[dict]: """Build model prompt for this turn. @@ -50,6 +45,11 @@ def run_model_stream(self, prompt: str | list[dict], session_id: str, state: Sta """Run model for one turn and return a stream of events. Should not be implemented if `run_model` is implemented.""" raise NotImplementedError + @hookspec + def load_state(self, message: Envelope, session_id: str) -> State: + """Load state snapshot for one session.""" + raise NotImplementedError + @hookspec def save_state( self, @@ -80,6 +80,10 @@ def dispatch_outbound(self, message: Envelope) -> bool: def register_cli_commands(self, app: Any) -> None: """Register CLI commands onto the root Typer application.""" + @hookspec + def onboard_config(self, current_config: dict[str, Any]) -> dict[str, Any] | None: + """Collect a plugin config fragment for the interactive onboarding command.""" + @hookspec def on_error(self, stage: str, error: Exception, message: Envelope | None) -> None: """Observe framework errors from any stage.""" diff --git a/tests/test_builtin_cli.py b/tests/test_builtin_cli.py index 4f0cdd1f..386ceef6 100644 --- a/tests/test_builtin_cli.py +++ b/tests/test_builtin_cli.py @@ -1,21 +1,295 @@ from __future__ import annotations import json +import os from pathlib import Path +from typing import Any +from unittest.mock import patch +import typer from typer.testing import CliRunner import bub.builtin.auth as auth import bub.builtin.cli as cli +import bub.builtin.hook_impl as builtin_hook_impl +import bub.configure as configure from bub.framework import BubFramework +from bub.hookspecs import hookimpl -def _create_app() -> object: +class _FakeQuestion: + def __init__(self, answer: Any) -> None: + self._answer = answer + + def ask(self) -> Any: + return self._answer + + +def _create_app() -> typer.Typer: framework = BubFramework() framework.load_hooks() return framework.create_cli_app() +def _rendered_onboard_banner() -> str: + return cli.ONBOARD_BANNER.format(version=cli.__version__) + + +def test_onboard_collects_plugin_config_and_writes_file(tmp_path: Path, monkeypatch) -> None: + config_file = tmp_path / "config.yml" + + with patch.dict(os.environ, {}, clear=True): + monkeypatch.chdir(tmp_path) + framework = BubFramework(config_file=config_file) + framework.load_hooks() + + class OnboardPlugin: + @hookimpl + def onboard_config(self, current_config): + assert current_config == {} + return { + "model": cli.typer.prompt("Model", default="openai:gpt-5"), + "telegram": {"token": cli.typer.prompt("Telegram token", hide_input=True)}, + } + + framework._plugin_manager.register(OnboardPlugin(), name="onboard-plugin") + app = framework.create_cli_app() + + answers = iter([ + "openai:gpt-5", + "123:abc", + "openai:gpt-5", + "", + "", + ]) + monkeypatch.setattr( + cli.typer, + "prompt", + lambda message, default=None, hide_input=False, show_default=True: next(answers), + ) + monkeypatch.setattr( + builtin_hook_impl.questionary, + "text", + lambda message, default="": _FakeQuestion(default), + ) + monkeypatch.setattr( + builtin_hook_impl.questionary, + "autocomplete", + lambda message, choices, default="", match_middle=False: _FakeQuestion(default), + ) + monkeypatch.setattr( + builtin_hook_impl.questionary, + "select", + lambda message, choices, default="": _FakeQuestion(default), + ) + monkeypatch.setattr( + builtin_hook_impl.questionary, + "checkbox", + lambda message, choices, validate=None: _FakeQuestion(["telegram"]), + ) + monkeypatch.setattr( + builtin_hook_impl.questionary, + "confirm", + lambda message, default=False: _FakeQuestion(default), + ) + monkeypatch.setattr( + builtin_hook_impl.questionary, + "password", + lambda message, default="": _FakeQuestion(default), + ) + + result = CliRunner().invoke(app, ["onboard"]) + + loaded = configure.load(config_file) + + assert result.exit_code == 0 + assert _rendered_onboard_banner() in result.stdout + assert f"Saved config to {config_file.resolve()}" in result.stdout + assert loaded == { + "model": "openai:gpt-5", + "api_format": "completion", + "enabled_channels": "telegram", + "stream_output": False, + "telegram": {"token": "123:abc"}, + } + + +def test_onboard_collects_builtin_runtime_config(tmp_path: Path, monkeypatch) -> None: + config_file = tmp_path / "config.yml" + + with patch.dict(os.environ, {}, clear=True): + monkeypatch.chdir(tmp_path) + framework = BubFramework(config_file=config_file) + framework.load_hooks() + app = framework.create_cli_app() + + monkeypatch.setattr( + builtin_hook_impl.questionary, + "text", + lambda message, default="": _FakeQuestion( + { + "LLM model": "openrouter/free", + "API base (optional)": "https://openrouter.ai/api/v1", + }.get(message, default) + ), + ) + monkeypatch.setattr( + builtin_hook_impl.questionary, + "autocomplete", + lambda message, choices, default="", match_middle=False: _FakeQuestion("openrouter"), + ) + monkeypatch.setattr( + builtin_hook_impl.questionary, + "select", + lambda message, choices, default="": _FakeQuestion("responses"), + ) + monkeypatch.setattr( + builtin_hook_impl.questionary, + "checkbox", + lambda message, choices, validate=None: _FakeQuestion(["telegram", "cli"]), + ) + monkeypatch.setattr( + builtin_hook_impl.questionary, + "confirm", + lambda message, default=False: _FakeQuestion(True), + ) + monkeypatch.setattr( + builtin_hook_impl.questionary, + "password", + lambda message, default="": _FakeQuestion("sk-test"), + ) + + result = CliRunner().invoke(app, ["onboard"]) + + loaded = configure.load(config_file) + + assert result.exit_code == 0 + assert loaded == { + "model": "openrouter:openrouter/free", + "api_format": "responses", + "enabled_channels": "telegram,cli", + "stream_output": True, + "api_key": "sk-test", + "api_base": "https://openrouter.ai/api/v1", + } + + +def test_onboard_aborts_immediately_when_builtin_prompt_is_interrupted(tmp_path: Path, monkeypatch) -> None: + config_file = tmp_path / "config.yml" + asked_messages: list[str] = [] + + with patch.dict(os.environ, {}, clear=True): + monkeypatch.chdir(tmp_path) + framework = BubFramework(config_file=config_file) + framework.load_hooks() + app = framework.create_cli_app() + + def fake_autocomplete( + message: str, choices: list[str], default: str = "", match_middle: bool = False + ) -> _FakeQuestion: + asked_messages.append(message) + return _FakeQuestion(default) + + def fake_select(message: str, choices: list[str], default: str = "") -> _FakeQuestion: + asked_messages.append(message) + return _FakeQuestion(default) + + def fake_checkbox(message: str, choices: list[object], validate=None) -> _FakeQuestion: + asked_messages.append(message) + return _FakeQuestion(["telegram"]) + + def fake_confirm(message: str, default: bool = False) -> _FakeQuestion: + asked_messages.append(message) + return _FakeQuestion(default) + + def fake_text(message: str, default: str = "") -> _FakeQuestion: + asked_messages.append(message) + if message == "API base (optional)": + raise AssertionError("Onboarding should stop after interruption") + return _FakeQuestion("openrouter:openrouter/free") + + def fake_password(message: str, default: str = "") -> _FakeQuestion: + asked_messages.append(message) + return _FakeQuestion(None) + + monkeypatch.setattr(builtin_hook_impl.questionary, "autocomplete", fake_autocomplete) + monkeypatch.setattr(builtin_hook_impl.questionary, "select", fake_select) + monkeypatch.setattr(builtin_hook_impl.questionary, "checkbox", fake_checkbox) + monkeypatch.setattr(builtin_hook_impl.questionary, "confirm", fake_confirm) + monkeypatch.setattr(builtin_hook_impl.questionary, "text", fake_text) + monkeypatch.setattr(builtin_hook_impl.questionary, "password", fake_password) + + result = CliRunner().invoke(app, ["onboard"]) + + assert result.exit_code == 1 + assert _rendered_onboard_banner() in result.stdout + assert asked_messages == [ + "LLM provider", + "LLM model", + "API key (optional)", + ] + assert not config_file.exists() + + +def test_onboard_collects_builtin_runtime_config_with_custom_provider(tmp_path: Path, monkeypatch) -> None: + config_file = tmp_path / "config.yml" + + with patch.dict(os.environ, {}, clear=True): + monkeypatch.chdir(tmp_path) + framework = BubFramework(config_file=config_file) + framework.load_hooks() + app = framework.create_cli_app() + + monkeypatch.setattr( + builtin_hook_impl.questionary, + "autocomplete", + lambda message, choices, default="", match_middle=False: _FakeQuestion("custom"), + ) + monkeypatch.setattr( + builtin_hook_impl.questionary, + "select", + lambda message, choices, default="": _FakeQuestion("messages"), + ) + monkeypatch.setattr( + builtin_hook_impl.questionary, + "checkbox", + lambda message, choices, validate=None: _FakeQuestion(["telegram"]), + ) + monkeypatch.setattr( + builtin_hook_impl.questionary, + "confirm", + lambda message, default=False: _FakeQuestion(False), + ) + monkeypatch.setattr( + builtin_hook_impl.questionary, + "text", + lambda message, default="": _FakeQuestion( + { + "Custom provider": "acme", + "LLM model": "ultra-1", + }.get(message, default) + ), + ) + monkeypatch.setattr( + builtin_hook_impl.questionary, + "password", + lambda message, default="": _FakeQuestion(""), + ) + + result = CliRunner().invoke(app, ["onboard"]) + + loaded = configure.load(config_file) + + assert result.exit_code == 0 + assert _rendered_onboard_banner() in result.stdout + assert loaded == { + "model": "acme:ultra-1", + "api_format": "messages", + "enabled_channels": "telegram", + "stream_output": False, + } + + def test_login_openai_runs_oauth_flow_and_prints_usage_hint( tmp_path: Path, monkeypatch, diff --git a/tests/test_configure.py b/tests/test_configure.py new file mode 100644 index 00000000..8f854a6a --- /dev/null +++ b/tests/test_configure.py @@ -0,0 +1,77 @@ +from __future__ import annotations + +import os +from pathlib import Path +from unittest.mock import patch + +import pytest +from pydantic import ValidationError + +import bub.configure as configure +from bub.builtin.settings import AgentSettings +from bub.channels.telegram import TelegramSettings + + +def test_merge_recursively_combines_non_conflicting_dicts() -> None: + base = {"model": "openai:gpt-5", "telegram": {"token": "token"}} + + result = configure.merge( + base, + {"telegram": {"allow_users": "1,2"}}, + ) + + assert result is base + assert result == { + "model": "openai:gpt-5", + "telegram": { + "token": "token", + "allow_users": "1,2", + }, + } + + +def test_merge_overrides_conflicting_scalar_values() -> None: + base = {"model": "openai:gpt-5"} + + result = configure.merge(base, {"model": "anthropic:claude-3-7-sonnet"}) + + assert result is base + assert base == {"model": "anthropic:claude-3-7-sonnet"} + + +def test_validate_checks_registered_config_sections() -> None: + valid_data = { + "model": "openai:gpt-5", + "telegram": {"token": "123:abc"}, + } + + assert configure.validate(valid_data) == valid_data + + with pytest.raises(ValidationError): + configure.validate({"max_steps": "not-an-int"}) + + +def test_save_writes_yaml_and_refreshes_loaded_config(tmp_path: Path) -> None: + config_file = tmp_path / "config.yml" + expected_token = "123:abc" # noqa: S105 + + with patch.dict(os.environ, {}, clear=True): + previous_cwd = Path.cwd() + os.chdir(tmp_path) + configure.save( + config_file, + { + "model": "openai:gpt-5", + "telegram": {"token": expected_token}, + }, + ) + + try: + loaded = configure.load(config_file) + + assert loaded["model"] == "openai:gpt-5" + assert loaded["telegram"]["token"] == expected_token + assert configure.ensure_config(AgentSettings).model == "openai:gpt-5" + assert configure.ensure_config(TelegramSettings).token == expected_token + finally: + os.chdir(previous_cwd) diff --git a/tests/test_framework.py b/tests/test_framework.py index e306bb86..ed925498 100644 --- a/tests/test_framework.py +++ b/tests/test_framework.py @@ -4,13 +4,15 @@ import os from pathlib import Path from types import SimpleNamespace +from typing import Any, cast from unittest.mock import patch import pytest import typer -from republic import AsyncStreamEvents, StreamEvent +from republic import AsyncStreamEvents, StreamEvent, StreamState from typer.testing import CliRunner +from bub import configure from bub.builtin.settings import load_settings from bub.channels.base import Channel from bub.channels.message import ChannelMessage @@ -20,16 +22,23 @@ from bub.hookspecs import hookimpl -class NamedChannel(Channel): - def __init__(self, name: str, label: str) -> None: - self.name = name - self.label = label +def make_named_channel(name: str, label: str) -> Channel: + channel_name = name + channel_label = label - async def start(self, stop_event) -> None: - return None + class NamedChannelImpl(Channel): + name = channel_name - async def stop(self) -> None: - return None + def __init__(self) -> None: + self.label = channel_label + + async def start(self, stop_event) -> None: + return None + + async def stop(self) -> None: + return None + + return NamedChannelImpl() def test_create_cli_app_sets_workspace_and_context(tmp_path: Path) -> None: @@ -56,25 +65,28 @@ def workspace_command(ctx: typer.Context) -> None: def test_get_channels_prefers_high_priority_plugin_for_duplicate_names() -> None: framework = BubFramework() + async def message_handler(message) -> None: + return None + class LowPriorityPlugin: @hookimpl def provide_channels(self, message_handler): - return [NamedChannel("shared", "low"), NamedChannel("low-only", "low")] + return [make_named_channel("shared", "low"), make_named_channel("low-only", "low")] class HighPriorityPlugin: @hookimpl def provide_channels(self, message_handler): - return [NamedChannel("shared", "high"), NamedChannel("high-only", "high")] + return [make_named_channel("shared", "high"), make_named_channel("high-only", "high")] framework._plugin_manager.register(LowPriorityPlugin(), name="low") framework._plugin_manager.register(HighPriorityPlugin(), name="high") - channels = framework.get_channels(lambda message: None) + channels = framework.get_channels(message_handler) assert set(channels) == {"shared", "low-only", "high-only"} - assert channels["shared"].label == "high" - assert channels["low-only"].label == "low" - assert channels["high-only"].label == "high" + assert cast(Any, channels["shared"]).label == "high" + assert cast(Any, channels["low-only"]).label == "low" + assert cast(Any, channels["high-only"]).label == "high" def test_get_system_prompt_uses_priority_order_and_skips_empty_results() -> None: @@ -117,6 +129,7 @@ def test_builtin_cli_exposes_login_and_gateway_command(write_config) -> None: assert help_result.exit_code == 0 assert "login" in help_result.stdout assert "gateway" in help_result.stdout + assert "onboard" in help_result.stdout assert "│ message" not in help_result.stdout assert gateway_result.exit_code == 0 assert "bub gateway" in gateway_result.stdout @@ -165,6 +178,36 @@ def register_cli_commands(self, app: typer.Typer) -> None: assert framework._plugin_status["config-plugin"].is_success is True +def test_collect_onboard_config_passes_accumulated_updates_to_later_hooks(write_config) -> None: + with patch.dict(os.environ, {}, clear=True): + framework = BubFramework(config_file=write_config("model: openai:gpt-5")) + observed_configs: list[tuple[str, dict[str, Any]]] = [] + + class FirstPlugin: + @hookimpl + def onboard_config(self, current_config): + observed_configs.append(("first", configure.merge({}, current_config))) + return {"first": {"enabled": True}} + + class SecondPlugin: + @hookimpl + def onboard_config(self, current_config): + observed_configs.append(("second", configure.merge({}, current_config))) + return {"second": {"enabled": True}} + + framework._plugin_manager.register(FirstPlugin(), name="first") + framework._plugin_manager.register(SecondPlugin(), name="second") + + result = framework.collect_onboard_config() + + assert observed_configs[0][1] == {} + assert observed_configs[1][1] == {observed_configs[0][0]: {"enabled": True}} + assert result == { + "first": {"enabled": True}, + "second": {"enabled": True}, + } + + @pytest.mark.asyncio async def test_process_inbound_defaults_to_non_streaming_run_model() -> None: framework = BubFramework() @@ -237,7 +280,7 @@ async def iterator(): yield StreamEvent("text", {"delta": "ed"}) yield StreamEvent("final", {"text": "streamed", "ok": True}) - return AsyncStreamEvents(iterator(), state=SimpleNamespace(error=None, usage=None)) + return AsyncStreamEvents(iterator(), state=StreamState()) @hookimpl async def save_state(self, session_id, state, message, model_output) -> None: @@ -260,6 +303,12 @@ async def iterator(): return iterator() + async def dispatch_output(self, message) -> bool: + return True + + async def quit(self, session_id: str) -> None: + return None + framework._plugin_manager.register(StreamingPlugin(), name="streaming") framework.bind_outbound_router(RecordingRouter()) diff --git a/tests/test_settings.py b/tests/test_settings.py index 96e663da..81399d89 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -139,12 +139,12 @@ def test_load_settings_returns_loaded_config(load_config) -> None: with patch.dict(os.environ, {}, clear=True): load_config( """ -model: openrouter:qwen/qwen3-coder-next +model: openrouter:openrouter/free api_format: responses """.strip(), ) settings = load_settings() - assert settings.model == "openrouter:qwen/qwen3-coder-next" + assert settings.model == "openrouter:openrouter/free" assert settings.api_format == "responses" diff --git a/uv.lock b/uv.lock index 5ab2c14a..7d60a6f2 100644 --- a/uv.lock +++ b/uv.lock @@ -216,6 +216,7 @@ dependencies = [ { name = "pydantic-settings" }, { name = "python-telegram-bot" }, { name = "pyyaml" }, + { name = "questionary" }, { name = "rapidfuzz" }, { name = "republic" }, { name = "rich" }, @@ -251,6 +252,7 @@ requires-dist = [ { name = "pydantic-settings", specifier = ">=2.0.0" }, { name = "python-telegram-bot", specifier = ">=21.0" }, { name = "pyyaml", specifier = ">=6.0.0" }, + { name = "questionary", specifier = ">=2.1.0" }, { name = "rapidfuzz", specifier = ">=3.14.3" }, { name = "republic", specifier = ">=0.5.4" }, { name = "rich", specifier = ">=13.0.0" }, @@ -1583,6 +1585,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f1/12/de94a39c2ef588c7e6455cfbe7343d3b2dc9d6b6b2f40c4c6565744c873d/pyyaml-6.0.3-cp314-cp314t-win_arm64.whl", hash = "sha256:ebc55a14a21cb14062aa4162f906cd962b28e2e9ea38f9b4391244cd8de4ae0b", size = 149341, upload-time = "2025-09-25T21:32:56.828Z" }, ] +[[package]] +name = "questionary" +version = "2.1.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "prompt-toolkit" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f6/45/eafb0bba0f9988f6a2520f9ca2df2c82ddfa8d67c95d6625452e97b204a5/questionary-2.1.1.tar.gz", hash = "sha256:3d7e980292bb0107abaa79c68dd3eee3c561b83a0f89ae482860b181c8bd412d", size = 25845, upload-time = "2025-08-28T19:00:20.851Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3c/26/1062c7ec1b053db9e499b4d2d5bc231743201b74051c973dadeac80a8f43/questionary-2.1.1-py3-none-any.whl", hash = "sha256:a51af13f345f1cdea62347589fbb6df3b290306ab8930713bfae4d475a7d4a59", size = 36753, upload-time = "2025-08-28T19:00:19.56Z" }, +] + [[package]] name = "rapidfuzz" version = "3.14.5" diff --git a/website/src/content/docs/docs/extending/hooks.mdx b/website/src/content/docs/docs/extending/hooks.mdx index 39da89fa..36d4de53 100644 --- a/website/src/content/docs/docs/extending/hooks.mdx +++ b/website/src/content/docs/docs/extending/hooks.mdx @@ -31,6 +31,7 @@ Compatibility note: Other hook consumers: - `register_cli_commands`: called by `call_many_sync` +- `onboard_config`: called by `BubFramework.collect_onboard_config()` during `bub onboard` - `provide_channels`: called by `call_many_sync` in `BubFramework.get_channels()` - `system_prompt`, `provide_tape_store`, `build_tape_context`: consumed by `BubFramework` and the builtin `Agent` @@ -47,10 +48,45 @@ Other hook consumers: - Sync hook calls skip awaitable return values and log a warning. - Therefore, keep bootstrap hooks synchronous: - `register_cli_commands` + - `onboard_config` - `provide_channels` - `provide_tape_store` - `build_tape_context` +## Interactive Onboarding + +`onboard_config(current_config)` lets a plugin participate in the interactive `bub onboard` flow. + +- Bub calls implementations by priority order, the same way it does for other sync bootstrap hooks. +- Each hook receives the accumulated `current_config` built by earlier hooks. +- Return a config fragment as `dict[str, Any]` to merge into the onboarding result. +- Return `None` to skip without changing the accumulated config. +- Returning any non-dict value raises `TypeError` and aborts onboarding. + +This makes `onboard_config` a good fit for provider-specific questions, plugin setup steps, or defaults that depend on answers collected earlier in the same session. + +```python +from bub import hookimpl + + +class WeatherPlugin: + @hookimpl + def onboard_config(self, current_config): + weather_config = dict(current_config.get("weather", {})) + api_key = weather_config.get("api_key", "") + + return { + "weather": { + "api_key": api_key or "demo-key", + "enabled": True, + } + } +``` + +Use `current_config` when your prompt or default value depends on answers from Bub itself or higher-priority plugins. For example, a plugin can branch on the selected model provider, pre-fill an API base, or avoid asking for settings that another plugin already supplied. + +To verify the hook, run `uv run bub onboard` and confirm that the saved config file contains the merged fragment from your plugin. + ## Signature Matching `HookRuntime` passes only parameters declared in your function signature. diff --git a/website/src/content/docs/docs/guides/deployment.mdx b/website/src/content/docs/docs/guides/deployment.mdx index 6538d416..7db1bfa6 100644 --- a/website/src/content/docs/docs/guides/deployment.mdx +++ b/website/src/content/docs/docs/guides/deployment.mdx @@ -25,7 +25,7 @@ Use `uv sync` here on purpose: deployment hosts only need the Python runtime. `m Minimum `.env` example: ```bash -BUB_MODEL=openrouter:qwen/qwen3-coder-next +BUB_MODEL=openrouter:openrouter/free OPENROUTER_API_KEY=sk-or-... ``` diff --git a/website/src/content/docs/zh-cn/docs/extending/hooks.mdx b/website/src/content/docs/zh-cn/docs/extending/hooks.mdx index 62318262..ea15f55a 100644 --- a/website/src/content/docs/zh-cn/docs/extending/hooks.mdx +++ b/website/src/content/docs/zh-cn/docs/extending/hooks.mdx @@ -31,6 +31,7 @@ description: Hook 执行语义、优先级、同步/异步规则、签名匹配 其他 hook 消费者: - `register_cli_commands`:由 `call_many_sync` 调用 +- `onboard_config`:在执行 `bub onboard` 时由 `BubFramework.collect_onboard_config()` 调用 - `provide_channels`:在 `BubFramework.get_channels()` 中由 `call_many_sync` 调用 - `system_prompt`、`provide_tape_store`、`build_tape_context`:由 `BubFramework` 和内置 `Agent` 消费 @@ -47,10 +48,45 @@ description: Hook 执行语义、优先级、同步/异步规则、签名匹配 - 同步 hook 调用会跳过可等待的返回值并记录警告。 - 因此,引导阶段的 hook 应保持同步: - `register_cli_commands` + - `onboard_config` - `provide_channels` - `provide_tape_store` - `build_tape_context` +## 交互式引导 + +`onboard_config(current_config)` 允许插件参与交互式 `bub onboard` 流程。 + +- Bub 会按优先级顺序调用实现,这一点与其他同步引导 hook 一致。 +- 每个 hook 都会收到由更早执行的 hook 累积出来的 `current_config`。 +- 返回 `dict[str, Any]` 配置片段时,结果会被合并进最终的引导配置。 +- 返回 `None` 表示跳过,不修改当前累积配置。 +- 返回任何非字典值都会触发 `TypeError`,并中止引导流程。 + +因此,`onboard_config` 很适合承载 provider 相关提问、插件初始化步骤,或者依赖前序回答的默认值计算。 + +```python +from bub import hookimpl + + +class WeatherPlugin: + @hookimpl + def onboard_config(self, current_config): + weather_config = dict(current_config.get("weather", {})) + api_key = weather_config.get("api_key", "") + + return { + "weather": { + "api_key": api_key or "demo-key", + "enabled": True, + } + } +``` + +当你的提问或默认值依赖 Bub 自身、或更高优先级插件先前收集到的答案时,请读取 `current_config`。例如,插件可以根据已选择的模型 provider 分支、预填 `api_base`,或避免重复询问其他插件已经提供过的设置。 + +验证这个 hook 时,可以运行 `uv run bub onboard`,确认最终写入的配置文件包含了插件返回并合并后的配置片段。 + ## 签名匹配 `HookRuntime` 只传递你函数签名中声明的参数。 diff --git a/website/src/content/docs/zh-cn/docs/guides/deployment.mdx b/website/src/content/docs/zh-cn/docs/guides/deployment.mdx index 48d7cb9e..1356630f 100644 --- a/website/src/content/docs/zh-cn/docs/guides/deployment.mdx +++ b/website/src/content/docs/zh-cn/docs/guides/deployment.mdx @@ -25,7 +25,7 @@ cp env.example .env 最小 `.env` 示例: ```bash -BUB_MODEL=openrouter:qwen/qwen3-coder-next +BUB_MODEL=openrouter:openrouter/free OPENROUTER_API_KEY=sk-or-... ``` From 53d105ad45da512dc818fb59610bd8563f6c1d1d Mon Sep 17 00:00:00 2001 From: Frost Ming Date: Fri, 24 Apr 2026 16:26:18 +0800 Subject: [PATCH 2/3] feat: replace questionary with inquirer-textual for improved CLI prompts Signed-off-by: Frost Ming --- pyproject.toml | 2 +- src/bub/builtin/hook_impl.py | 94 ++++++----- tests/test_builtin_cli.py | 157 ++++++++++-------- uv.lock | 83 +++++++-- .../src/content/docs/docs/extending/hooks.mdx | 24 +-- .../docs/zh-cn/docs/extending/hooks.mdx | 24 +-- 6 files changed, 237 insertions(+), 147 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 9323216f..cc05424f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ dependencies = [ "pydantic-settings>=2.0.0", "pyyaml>=6.0.0", "pluggy>=1.6.0", - "questionary>=2.1.0", + "inquirer-textual>=0.5.1", "typer>=0.9.0", "republic>=0.5.4", "any-llm-sdk[anthropic]", diff --git a/src/bub/builtin/hook_impl.py b/src/bub/builtin/hook_impl.py index 5c0deb54..f5191912 100644 --- a/src/bub/builtin/hook_impl.py +++ b/src/bub/builtin/hook_impl.py @@ -3,8 +3,12 @@ from pathlib import Path from typing import Any, cast -import questionary import typer +from inquirer_textual import prompts +from inquirer_textual.common.Choice import Choice +from inquirer_textual.common.InquirerResult import InquirerResult +from inquirer_textual.common.PromptSettings import PromptSettings +from inquirer_textual.common.Shortcut import Shortcut from loguru import logger from republic import AsyncStreamEvents, TapeContext from republic.tape import TapeStore @@ -20,6 +24,7 @@ from bub.types import Envelope, MessageHandler, State AGENTS_FILE_NAME = "AGENTS.md" +CHECKBOX_HINT_SETTINGS = PromptSettings(shortcuts=[Shortcut("space", "toggle", "Space check/uncheck")]) MODEL_PROVIDER_CHOICES: tuple[str, ...] = ( "openrouter", "openai", @@ -56,6 +61,15 @@ """ +def ask_prompt(question: InquirerResult[Any]) -> Any: + if question.command in {"ctrl+c", "quit"}: + raise typer.Abort() + answer = question.value + if answer is None: + raise typer.Abort() + return answer + + class BuiltinImpl: """Default hook implementations for basic runtime operations.""" @@ -70,18 +84,6 @@ def _get_agent(self) -> Agent: self._agent = Agent(self.framework) return self._agent - @staticmethod - def _ask_onboard_value(question: Any) -> Any: - answer = question.ask() - if answer is None: - raise typer.Abort() - return answer - - @classmethod - def _ask_onboard(cls, question: Any) -> str: - answer = cls._ask_onboard_value(question) - return str(answer).strip() - @staticmethod async def _discard_message(_: ChannelMessage) -> None: return @@ -113,6 +115,31 @@ def _default_enabled_channels(current_value: object, available_channels: list[st return selected return available_channels + @classmethod + def _ask_onboard_checkbox( + cls, + message: str, + choices: list[str], + enabled: list[str] | None = None, + validate: Any = None, + ) -> list[str]: + while True: + answer: list[str | Choice] = ask_prompt( + prompts.checkbox( + message, + choices=cast("list[str | Choice]", choices), + enabled=cast("list[str | Choice] | None", enabled), + settings=CHECKBOX_HINT_SETTINGS, + ) + ) + values = list(cast("list[str]", answer or [])) + if validate is None: + return values + validation_result = validate(values) + if validation_result is True: + return values + typer.secho(str(validation_result), err=True, fg="red") + @hookimpl def resolve_session(self, message: ChannelMessage) -> str: session_id = field_of(message, "session_id") @@ -196,31 +223,26 @@ def onboard_config(self, current_config: dict[str, object]) -> dict[str, object] model_default = str(current_model) if isinstance(current_model, str) and current_model else DEFAULT_MODEL provider_default, model_name_default = self._split_model_identifier(model_default) - provider = self._ask_onboard( - questionary.autocomplete( + provider: str = ask_prompt( + prompts.fuzzy( "LLM provider", - choices=self._provider_choices(provider_default), - match_middle=True, + choices=cast("list[str | Choice]", self._provider_choices(provider_default)), default=provider_default, ) ) if provider == "custom": - provider = ( - self._ask_onboard(questionary.text("Custom provider", default=provider_default)) or provider_default - ) + provider = ask_prompt(prompts.text("Custom provider", default=provider_default)) or provider_default - model_name = self._ask_onboard(questionary.text("LLM model", default=model_name_default)) + model_name: str = ask_prompt(prompts.text("LLM model", default=model_name_default)) if not model_name: model_name = model_name_default model = f"{provider}:{model_name}" - current_api_key = current_config.get("api_key") - api_key_default = str(current_api_key) if isinstance(current_api_key, str) else "" - api_key = self._ask_onboard(questionary.password("API key (optional)", default=api_key_default)) + api_key: str = ask_prompt(prompts.secret("API key (optional)")) current_api_base = current_config.get("api_base") api_base_default = str(current_api_base) if isinstance(current_api_base, str) else "" - api_base = self._ask_onboard(questionary.text("API base (optional)", default=api_base_default)) + api_base: str = ask_prompt(prompts.text("API base (optional)", default=api_base_default)) current_api_format = current_config.get("api_format") api_format_default = ( @@ -228,28 +250,22 @@ def onboard_config(self, current_config: dict[str, object]) -> dict[str, object] if isinstance(current_api_format, str) and current_api_format in API_FORMAT_CHOICES else API_FORMAT_CHOICES[0] ) - api_format = self._ask_onboard( - questionary.select("API format", choices=list(API_FORMAT_CHOICES), default=api_format_default) + api_format: str = ask_prompt( + prompts.select("API format", choices=list(API_FORMAT_CHOICES), default=api_format_default) ) available_channels = self._channel_choices() default_channels = self._default_enabled_channels(current_config.get("enabled_channels"), available_channels) - enabled_channels = cast( - "list[str]", - self._ask_onboard_value( - questionary.checkbox( - "Channels", - choices=[questionary.Choice(name, checked=name in default_channels) for name in available_channels], - validate=lambda values: True if values else "Select at least one channel.", - ) - ), + enabled_channels = self._ask_onboard_checkbox( + "Channels", + choices=available_channels, + enabled=default_channels, + validate=lambda values: True if values else "Select at least one channel.", ) stream_output = cast( "bool", - self._ask_onboard_value( - questionary.confirm("Stream output", default=bool(current_config.get("stream_output"))) - ), + ask_prompt(prompts.confirm("Stream output", default=bool(current_config.get("stream_output")))), ) config: dict[str, object] = { "model": model, diff --git a/tests/test_builtin_cli.py b/tests/test_builtin_cli.py index 386ceef6..cb5c8f78 100644 --- a/tests/test_builtin_cli.py +++ b/tests/test_builtin_cli.py @@ -7,6 +7,8 @@ from unittest.mock import patch import typer +from inquirer_textual.common.InquirerResult import InquirerResult +from inquirer_textual.common.PromptSettings import PromptSettings from typer.testing import CliRunner import bub.builtin.auth as auth @@ -17,12 +19,16 @@ from bub.hookspecs import hookimpl -class _FakeQuestion: - def __init__(self, answer: Any) -> None: - self._answer = answer +def _fake_result(answer: Any, command: str | None = "enter") -> InquirerResult[Any]: + return InquirerResult(None, answer, command) - def ask(self) -> Any: - return self._answer + +def _assert_checkbox_hint(settings: PromptSettings | None) -> None: + assert settings is not None + assert settings.shortcuts is not None + assert [(shortcut.key, shortcut.command, shortcut.description) for shortcut in settings.shortcuts] == [ + ("space", "toggle", "Space check/uncheck") + ] def _create_app() -> typer.Typer: @@ -68,34 +74,37 @@ def onboard_config(self, current_config): lambda message, default=None, hide_input=False, show_default=True: next(answers), ) monkeypatch.setattr( - builtin_hook_impl.questionary, + builtin_hook_impl.prompts, "text", - lambda message, default="": _FakeQuestion(default), + lambda message, default="": _fake_result(default), ) monkeypatch.setattr( - builtin_hook_impl.questionary, - "autocomplete", - lambda message, choices, default="", match_middle=False: _FakeQuestion(default), + builtin_hook_impl.prompts, + "fuzzy", + lambda message, choices, default=None: _fake_result(default), ) monkeypatch.setattr( - builtin_hook_impl.questionary, + builtin_hook_impl.prompts, "select", - lambda message, choices, default="": _FakeQuestion(default), + lambda message, choices, default="": _fake_result(default), ) monkeypatch.setattr( - builtin_hook_impl.questionary, + builtin_hook_impl.prompts, "checkbox", - lambda message, choices, validate=None: _FakeQuestion(["telegram"]), + lambda message, choices, enabled=None, settings=None: ( + _assert_checkbox_hint(settings), + _fake_result(["telegram"]), + )[1], ) monkeypatch.setattr( - builtin_hook_impl.questionary, + builtin_hook_impl.prompts, "confirm", - lambda message, default=False: _FakeQuestion(default), + lambda message, default=False: _fake_result(default), ) monkeypatch.setattr( - builtin_hook_impl.questionary, - "password", - lambda message, default="": _FakeQuestion(default), + builtin_hook_impl.prompts, + "secret", + lambda message: _fake_result(""), ) result = CliRunner().invoke(app, ["onboard"]) @@ -124,9 +133,9 @@ def test_onboard_collects_builtin_runtime_config(tmp_path: Path, monkeypatch) -> app = framework.create_cli_app() monkeypatch.setattr( - builtin_hook_impl.questionary, + builtin_hook_impl.prompts, "text", - lambda message, default="": _FakeQuestion( + lambda message, default="": _fake_result( { "LLM model": "openrouter/free", "API base (optional)": "https://openrouter.ai/api/v1", @@ -134,29 +143,32 @@ def test_onboard_collects_builtin_runtime_config(tmp_path: Path, monkeypatch) -> ), ) monkeypatch.setattr( - builtin_hook_impl.questionary, - "autocomplete", - lambda message, choices, default="", match_middle=False: _FakeQuestion("openrouter"), + builtin_hook_impl.prompts, + "fuzzy", + lambda message, choices, default=None: _fake_result("openrouter"), ) monkeypatch.setattr( - builtin_hook_impl.questionary, + builtin_hook_impl.prompts, "select", - lambda message, choices, default="": _FakeQuestion("responses"), + lambda message, choices, default="": _fake_result("responses"), ) monkeypatch.setattr( - builtin_hook_impl.questionary, + builtin_hook_impl.prompts, "checkbox", - lambda message, choices, validate=None: _FakeQuestion(["telegram", "cli"]), + lambda message, choices, enabled=None, settings=None: ( + _assert_checkbox_hint(settings), + _fake_result(["telegram", "cli"]), + )[1], ) monkeypatch.setattr( - builtin_hook_impl.questionary, + builtin_hook_impl.prompts, "confirm", - lambda message, default=False: _FakeQuestion(True), + lambda message, default=False: _fake_result(True), ) monkeypatch.setattr( - builtin_hook_impl.questionary, - "password", - lambda message, default="": _FakeQuestion("sk-test"), + builtin_hook_impl.prompts, + "secret", + lambda message: _fake_result("sk-test"), ) result = CliRunner().invoke(app, ["onboard"]) @@ -184,40 +196,44 @@ def test_onboard_aborts_immediately_when_builtin_prompt_is_interrupted(tmp_path: framework.load_hooks() app = framework.create_cli_app() - def fake_autocomplete( - message: str, choices: list[str], default: str = "", match_middle: bool = False - ) -> _FakeQuestion: + def fake_fuzzy(message: str, choices: list[str], default: str | None = None) -> InquirerResult[Any]: asked_messages.append(message) - return _FakeQuestion(default) + return _fake_result(default) - def fake_select(message: str, choices: list[str], default: str = "") -> _FakeQuestion: + def fake_select(message: str, choices: list[str], default: str = "") -> InquirerResult[Any]: asked_messages.append(message) - return _FakeQuestion(default) - - def fake_checkbox(message: str, choices: list[object], validate=None) -> _FakeQuestion: + return _fake_result(default) + + def fake_checkbox( + message: str, + choices: list[object], + enabled=None, + settings: PromptSettings | None = None, + ) -> InquirerResult[Any]: asked_messages.append(message) - return _FakeQuestion(["telegram"]) + _assert_checkbox_hint(settings) + return _fake_result(["telegram"]) - def fake_confirm(message: str, default: bool = False) -> _FakeQuestion: + def fake_confirm(message: str, default: bool = False) -> InquirerResult[Any]: asked_messages.append(message) - return _FakeQuestion(default) + return _fake_result(default) - def fake_text(message: str, default: str = "") -> _FakeQuestion: + def fake_text(message: str, default: str = "") -> InquirerResult[Any]: asked_messages.append(message) if message == "API base (optional)": raise AssertionError("Onboarding should stop after interruption") - return _FakeQuestion("openrouter:openrouter/free") + return _fake_result("openrouter:openrouter/free") - def fake_password(message: str, default: str = "") -> _FakeQuestion: - asked_messages.append(message) - return _FakeQuestion(None) + def fake_secret(message: str) -> InquirerResult[Any]: + asked_messages.append("API key (optional)") + return _fake_result(None) - monkeypatch.setattr(builtin_hook_impl.questionary, "autocomplete", fake_autocomplete) - monkeypatch.setattr(builtin_hook_impl.questionary, "select", fake_select) - monkeypatch.setattr(builtin_hook_impl.questionary, "checkbox", fake_checkbox) - monkeypatch.setattr(builtin_hook_impl.questionary, "confirm", fake_confirm) - monkeypatch.setattr(builtin_hook_impl.questionary, "text", fake_text) - monkeypatch.setattr(builtin_hook_impl.questionary, "password", fake_password) + monkeypatch.setattr(builtin_hook_impl.prompts, "fuzzy", fake_fuzzy) + monkeypatch.setattr(builtin_hook_impl.prompts, "select", fake_select) + monkeypatch.setattr(builtin_hook_impl.prompts, "checkbox", fake_checkbox) + monkeypatch.setattr(builtin_hook_impl.prompts, "confirm", fake_confirm) + monkeypatch.setattr(builtin_hook_impl.prompts, "text", fake_text) + monkeypatch.setattr(builtin_hook_impl.prompts, "secret", fake_secret) result = CliRunner().invoke(app, ["onboard"]) @@ -241,29 +257,32 @@ def test_onboard_collects_builtin_runtime_config_with_custom_provider(tmp_path: app = framework.create_cli_app() monkeypatch.setattr( - builtin_hook_impl.questionary, - "autocomplete", - lambda message, choices, default="", match_middle=False: _FakeQuestion("custom"), + builtin_hook_impl.prompts, + "fuzzy", + lambda message, choices, default=None: _fake_result("custom"), ) monkeypatch.setattr( - builtin_hook_impl.questionary, + builtin_hook_impl.prompts, "select", - lambda message, choices, default="": _FakeQuestion("messages"), + lambda message, choices, default="": _fake_result("messages"), ) monkeypatch.setattr( - builtin_hook_impl.questionary, + builtin_hook_impl.prompts, "checkbox", - lambda message, choices, validate=None: _FakeQuestion(["telegram"]), + lambda message, choices, enabled=None, settings=None: ( + _assert_checkbox_hint(settings), + _fake_result(["telegram"]), + )[1], ) monkeypatch.setattr( - builtin_hook_impl.questionary, + builtin_hook_impl.prompts, "confirm", - lambda message, default=False: _FakeQuestion(False), + lambda message, default=False: _fake_result(False), ) monkeypatch.setattr( - builtin_hook_impl.questionary, + builtin_hook_impl.prompts, "text", - lambda message, default="": _FakeQuestion( + lambda message, default="": _fake_result( { "Custom provider": "acme", "LLM model": "ultra-1", @@ -271,9 +290,9 @@ def test_onboard_collects_builtin_runtime_config_with_custom_provider(tmp_path: ), ) monkeypatch.setattr( - builtin_hook_impl.questionary, - "password", - lambda message, default="": _FakeQuestion(""), + builtin_hook_impl.prompts, + "secret", + lambda message: _fake_result(""), ) result = CliRunner().invoke(app, ["onboard"]) diff --git a/uv.lock b/uv.lock index 7d60a6f2..739c4e6a 100644 --- a/uv.lock +++ b/uv.lock @@ -209,6 +209,7 @@ dependencies = [ { name = "aiohttp" }, { name = "any-llm-sdk" }, { name = "httpx", extra = ["socks"] }, + { name = "inquirer-textual" }, { name = "loguru" }, { name = "pluggy" }, { name = "prompt-toolkit" }, @@ -216,7 +217,6 @@ dependencies = [ { name = "pydantic-settings" }, { name = "python-telegram-bot" }, { name = "pyyaml" }, - { name = "questionary" }, { name = "rapidfuzz" }, { name = "republic" }, { name = "rich" }, @@ -244,6 +244,7 @@ requires-dist = [ { name = "aiohttp", specifier = ">=3.13.3" }, { name = "any-llm-sdk", extras = ["anthropic"] }, { name = "httpx", extras = ["socks"], specifier = ">=0.28.1" }, + { name = "inquirer-textual", specifier = ">=0.5.1" }, { name = "logfire", marker = "extra == 'logfire'", specifier = ">=4.31.0" }, { name = "loguru", specifier = ">=0.7.2" }, { name = "pluggy", specifier = ">=1.6.0" }, @@ -252,7 +253,6 @@ requires-dist = [ { name = "pydantic-settings", specifier = ">=2.0.0" }, { name = "python-telegram-bot", specifier = ">=21.0" }, { name = "pyyaml", specifier = ">=6.0.0" }, - { name = "questionary", specifier = ">=2.1.0" }, { name = "rapidfuzz", specifier = ">=3.14.3" }, { name = "republic", specifier = ">=0.5.4" }, { name = "rich", specifier = ">=13.0.0" }, @@ -711,6 +711,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" }, ] +[[package]] +name = "inquirer-textual" +version = "0.5.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "textual" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ba/6c/544e4216e2dc66f931f7bd65c90b91fc9a19a4701e0cf3bee207eaeecb54/inquirer_textual-0.5.1.tar.gz", hash = "sha256:4e4604a303ba58e7321d96044ba235b3332f2f330b081282914be504e1f59c71", size = 3666912, upload-time = "2026-03-24T21:41:46.598Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a2/4f/cf9ee5811c2e74e4be29e5500c8121710dd8e35d241907a0e768cbdc06cb/inquirer_textual-0.5.1-py3-none-any.whl", hash = "sha256:7e5aca4ea112d947d9bb3e096477bcb3555a9a44fd3a3daae6f28a2487da7a5a", size = 31612, upload-time = "2026-03-24T21:41:44.294Z" }, +] + [[package]] name = "jiter" version = "0.14.0" @@ -843,6 +855,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/aa/47/7d70414bcdbb3bc1f458a8d10558f00bbfdb24e5a11740fc8197e12c3255/librt-0.9.0-cp314-cp314t-win_arm64.whl", hash = "sha256:a4b25c6c25cac5d0d9d6d6da855195b254e0021e513e0249f0e3b444dc6e0e61", size = 50009, upload-time = "2026-04-09T16:06:07.995Z" }, ] +[[package]] +name = "linkify-it-py" +version = "2.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "uc-micro-py" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/2e/c9/06ea13676ef354f0af6169587ae292d3e2406e212876a413bf9eece4eb23/linkify_it_py-2.1.0.tar.gz", hash = "sha256:43360231720999c10e9328dc3691160e27a718e280673d444c38d7d3aaa3b98b", size = 29158, upload-time = "2026-03-01T07:48:47.683Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b4/de/88b3be5c31b22333b3ca2f6ff1de4e863d8fe45aaea7485f591970ec1d3e/linkify_it_py-2.1.0-py3-none-any.whl", hash = "sha256:0d252c1594ecba2ecedc444053db5d3a9b7ec1b0dd929c8f1d74dce89f86c05e", size = 19878, upload-time = "2026-03-01T07:48:46.098Z" }, +] + [[package]] name = "logfire" version = "4.32.0" @@ -886,6 +910,23 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/94/54/e7d793b573f298e1c9013b8c4dade17d481164aa517d1d7148619c2cedbf/markdown_it_py-4.0.0-py3-none-any.whl", hash = "sha256:87327c59b172c5011896038353a81343b6754500a08cd7a4973bb48c6d578147", size = 87321, upload-time = "2025-08-11T12:57:51.923Z" }, ] +[package.optional-dependencies] +linkify = [ + { name = "linkify-it-py" }, +] + +[[package]] +name = "mdit-py-plugins" +version = "0.5.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markdown-it-py" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b2/fd/a756d36c0bfba5f6e39a1cdbdbfdd448dc02692467d83816dff4592a1ebc/mdit_py_plugins-0.5.0.tar.gz", hash = "sha256:f4918cb50119f50446560513a8e311d574ff6aaed72606ddae6d35716fe809c6", size = 44655, upload-time = "2025-08-11T07:25:49.083Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fb/86/dd6e5db36df29e76c7a7699123569a4a18c1623ce68d826ed96c62643cae/mdit_py_plugins-0.5.0-py3-none-any.whl", hash = "sha256:07a08422fc1936a5d26d146759e9155ea466e842f5ab2f7d2266dd084c8dab1f", size = 57205, upload-time = "2025-08-11T07:25:47.597Z" }, +] + [[package]] name = "mdurl" version = "0.1.2" @@ -1585,18 +1626,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f1/12/de94a39c2ef588c7e6455cfbe7343d3b2dc9d6b6b2f40c4c6565744c873d/pyyaml-6.0.3-cp314-cp314t-win_arm64.whl", hash = "sha256:ebc55a14a21cb14062aa4162f906cd962b28e2e9ea38f9b4391244cd8de4ae0b", size = 149341, upload-time = "2025-09-25T21:32:56.828Z" }, ] -[[package]] -name = "questionary" -version = "2.1.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "prompt-toolkit" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/f6/45/eafb0bba0f9988f6a2520f9ca2df2c82ddfa8d67c95d6625452e97b204a5/questionary-2.1.1.tar.gz", hash = "sha256:3d7e980292bb0107abaa79c68dd3eee3c561b83a0f89ae482860b181c8bd412d", size = 25845, upload-time = "2025-08-28T19:00:20.851Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/3c/26/1062c7ec1b053db9e499b4d2d5bc231743201b74051c973dadeac80a8f43/questionary-2.1.1-py3-none-any.whl", hash = "sha256:a51af13f345f1cdea62347589fbb6df3b290306ab8930713bfae4d475a7d4a59", size = 36753, upload-time = "2025-08-28T19:00:19.56Z" }, -] - [[package]] name = "rapidfuzz" version = "3.14.5" @@ -1755,6 +1784,23 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/37/c3/6eeb6034408dac0fa653d126c9204ade96b819c936e136c5e8a6897eee9c/socksio-1.0.0-py3-none-any.whl", hash = "sha256:95dc1f15f9b34e8d7b16f06d74b8ccf48f609af32ab33c608d08761c5dcbb1f3", size = 12763, upload-time = "2020-04-17T15:50:31.878Z" }, ] +[[package]] +name = "textual" +version = "8.2.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markdown-it-py", extra = ["linkify"] }, + { name = "mdit-py-plugins" }, + { name = "platformdirs" }, + { name = "pygments" }, + { name = "rich" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/19/89/bec5709fb759f9c784bbcb30b2e3497df3f901691d13c2b864dbf6694a17/textual-8.2.4.tar.gz", hash = "sha256:d4e2b2ddd7157191d00b228592b7c739ea080b7d792fd410f23ca75f05ea76c4", size = 1848933, upload-time = "2026-04-19T04:20:45.845Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5c/32/02932f0d597cdbb34e34bf24266ff0f2cf292ccb3aafc37dd9efcb0cc416/textual-8.2.4-py3-none-any.whl", hash = "sha256:a83bd3f0cc7125ca203845af753f9d6b6be030025ecd1b05cc75ebe645b9c4ba", size = 724390, upload-time = "2026-04-19T04:20:49.968Z" }, +] + [[package]] name = "tomli-w" version = "1.2.0" @@ -1867,6 +1913,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/dc/9b/47798a6c91d8bdb567fe2698fe81e0c6b7cb7ef4d13da4114b41d239f65d/typing_inspection-0.4.2-py3-none-any.whl", hash = "sha256:4ed1cacbdc298c220f1bd249ed5287caa16f34d44ef4e9c3d0cbad5b521545e7", size = 14611, upload-time = "2025-10-01T02:14:40.154Z" }, ] +[[package]] +name = "uc-micro-py" +version = "2.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/78/67/9a363818028526e2d4579334460df777115bdec1bb77c08f9db88f6389f2/uc_micro_py-2.0.0.tar.gz", hash = "sha256:c53691e495c8db60e16ffc4861a35469b0ba0821fe409a8a7a0a71864d33a811", size = 6611, upload-time = "2026-03-01T06:31:27.526Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/61/73/d21edf5b204d1467e06500080a50f79d49ef2b997c79123a536d4a17d97c/uc_micro_py-2.0.0-py3-none-any.whl", hash = "sha256:3603a3859af53e5a39bc7677713c78ea6589ff188d70f4fee165db88e22b242c", size = 6383, upload-time = "2026-03-01T06:31:26.257Z" }, +] + [[package]] name = "urllib3" version = "2.6.3" diff --git a/website/src/content/docs/docs/extending/hooks.mdx b/website/src/content/docs/docs/extending/hooks.mdx index 36d4de53..b8bf006b 100644 --- a/website/src/content/docs/docs/extending/hooks.mdx +++ b/website/src/content/docs/docs/extending/hooks.mdx @@ -48,7 +48,7 @@ Other hook consumers: - Sync hook calls skip awaitable return values and log a warning. - Therefore, keep bootstrap hooks synchronous: - `register_cli_commands` - - `onboard_config` + - `onboard_config` - `provide_channels` - `provide_tape_store` - `build_tape_context` @@ -70,17 +70,17 @@ from bub import hookimpl class WeatherPlugin: - @hookimpl - def onboard_config(self, current_config): - weather_config = dict(current_config.get("weather", {})) - api_key = weather_config.get("api_key", "") - - return { - "weather": { - "api_key": api_key or "demo-key", - "enabled": True, - } - } + @hookimpl + def onboard_config(self, current_config): + weather_config = dict(current_config.get("weather", {})) + api_key = weather_config.get("api_key", "") + + return { + "weather": { + "api_key": api_key or "demo-key", + "enabled": True, + } + } ``` Use `current_config` when your prompt or default value depends on answers from Bub itself or higher-priority plugins. For example, a plugin can branch on the selected model provider, pre-fill an API base, or avoid asking for settings that another plugin already supplied. diff --git a/website/src/content/docs/zh-cn/docs/extending/hooks.mdx b/website/src/content/docs/zh-cn/docs/extending/hooks.mdx index ea15f55a..82560783 100644 --- a/website/src/content/docs/zh-cn/docs/extending/hooks.mdx +++ b/website/src/content/docs/zh-cn/docs/extending/hooks.mdx @@ -48,7 +48,7 @@ description: Hook 执行语义、优先级、同步/异步规则、签名匹配 - 同步 hook 调用会跳过可等待的返回值并记录警告。 - 因此,引导阶段的 hook 应保持同步: - `register_cli_commands` - - `onboard_config` + - `onboard_config` - `provide_channels` - `provide_tape_store` - `build_tape_context` @@ -70,17 +70,17 @@ from bub import hookimpl class WeatherPlugin: - @hookimpl - def onboard_config(self, current_config): - weather_config = dict(current_config.get("weather", {})) - api_key = weather_config.get("api_key", "") - - return { - "weather": { - "api_key": api_key or "demo-key", - "enabled": True, - } - } + @hookimpl + def onboard_config(self, current_config): + weather_config = dict(current_config.get("weather", {})) + api_key = weather_config.get("api_key", "") + + return { + "weather": { + "api_key": api_key or "demo-key", + "enabled": True, + } + } ``` 当你的提问或默认值依赖 Bub 自身、或更高优先级插件先前收集到的答案时,请读取 `current_config`。例如,插件可以根据已选择的模型 provider 分支、预填 `api_base`,或避免重复询问其他插件已经提供过的设置。 From ff9ee1d5b5b4434695001bf746af4973520ebd28 Mon Sep 17 00:00:00 2001 From: Frost Ming Date: Sat, 25 Apr 2026 09:06:52 +0800 Subject: [PATCH 3/3] feat: implement inquirer utility functions for improved CLI interaction Signed-off-by: Frost Ming --- src/bub/builtin/hook_impl.py | 72 +++----------- src/bub/inquirer.py | 86 +++++++++++++++++ tests/test_builtin_cli.py | 175 ++++++++++++++++------------------- 3 files changed, 180 insertions(+), 153 deletions(-) create mode 100644 src/bub/inquirer.py diff --git a/src/bub/builtin/hook_impl.py b/src/bub/builtin/hook_impl.py index f5191912..5fb2a728 100644 --- a/src/bub/builtin/hook_impl.py +++ b/src/bub/builtin/hook_impl.py @@ -1,18 +1,14 @@ import sys from datetime import UTC, datetime from pathlib import Path -from typing import Any, cast +from typing import cast import typer -from inquirer_textual import prompts -from inquirer_textual.common.Choice import Choice -from inquirer_textual.common.InquirerResult import InquirerResult -from inquirer_textual.common.PromptSettings import PromptSettings -from inquirer_textual.common.Shortcut import Shortcut from loguru import logger from republic import AsyncStreamEvents, TapeContext from republic.tape import TapeStore +from bub import inquirer as bub_inquirer from bub.builtin.agent import Agent from bub.builtin.context import default_tape_context from bub.builtin.settings import DEFAULT_MODEL @@ -24,7 +20,6 @@ from bub.types import Envelope, MessageHandler, State AGENTS_FILE_NAME = "AGENTS.md" -CHECKBOX_HINT_SETTINGS = PromptSettings(shortcuts=[Shortcut("space", "toggle", "Space check/uncheck")]) MODEL_PROVIDER_CHOICES: tuple[str, ...] = ( "openrouter", "openai", @@ -61,15 +56,6 @@ """ -def ask_prompt(question: InquirerResult[Any]) -> Any: - if question.command in {"ctrl+c", "quit"}: - raise typer.Abort() - answer = question.value - if answer is None: - raise typer.Abort() - return answer - - class BuiltinImpl: """Default hook implementations for basic runtime operations.""" @@ -115,31 +101,6 @@ def _default_enabled_channels(current_value: object, available_channels: list[st return selected return available_channels - @classmethod - def _ask_onboard_checkbox( - cls, - message: str, - choices: list[str], - enabled: list[str] | None = None, - validate: Any = None, - ) -> list[str]: - while True: - answer: list[str | Choice] = ask_prompt( - prompts.checkbox( - message, - choices=cast("list[str | Choice]", choices), - enabled=cast("list[str | Choice] | None", enabled), - settings=CHECKBOX_HINT_SETTINGS, - ) - ) - values = list(cast("list[str]", answer or [])) - if validate is None: - return values - validation_result = validate(values) - if validation_result is True: - return values - typer.secho(str(validation_result), err=True, fg="red") - @hookimpl def resolve_session(self, message: ChannelMessage) -> str: session_id = field_of(message, "session_id") @@ -223,26 +184,24 @@ def onboard_config(self, current_config: dict[str, object]) -> dict[str, object] model_default = str(current_model) if isinstance(current_model, str) and current_model else DEFAULT_MODEL provider_default, model_name_default = self._split_model_identifier(model_default) - provider: str = ask_prompt( - prompts.fuzzy( - "LLM provider", - choices=cast("list[str | Choice]", self._provider_choices(provider_default)), - default=provider_default, - ) + provider = bub_inquirer.ask_fuzzy( + "LLM provider", + choices=self._provider_choices(provider_default), + default=provider_default, ) if provider == "custom": - provider = ask_prompt(prompts.text("Custom provider", default=provider_default)) or provider_default + provider = bub_inquirer.ask_text("Custom provider", default=provider_default) or provider_default - model_name: str = ask_prompt(prompts.text("LLM model", default=model_name_default)) + model_name = bub_inquirer.ask_text("LLM model", default=model_name_default) if not model_name: model_name = model_name_default model = f"{provider}:{model_name}" - api_key: str = ask_prompt(prompts.secret("API key (optional)")) + api_key = bub_inquirer.ask_secret("API key (optional)") current_api_base = current_config.get("api_base") api_base_default = str(current_api_base) if isinstance(current_api_base, str) else "" - api_base: str = ask_prompt(prompts.text("API base (optional)", default=api_base_default)) + api_base = bub_inquirer.ask_text("API base (optional)", default=api_base_default) current_api_format = current_config.get("api_format") api_format_default = ( @@ -250,23 +209,18 @@ def onboard_config(self, current_config: dict[str, object]) -> dict[str, object] if isinstance(current_api_format, str) and current_api_format in API_FORMAT_CHOICES else API_FORMAT_CHOICES[0] ) - api_format: str = ask_prompt( - prompts.select("API format", choices=list(API_FORMAT_CHOICES), default=api_format_default) - ) + api_format = bub_inquirer.ask_select("API format", choices=list(API_FORMAT_CHOICES), default=api_format_default) available_channels = self._channel_choices() default_channels = self._default_enabled_channels(current_config.get("enabled_channels"), available_channels) - enabled_channels = self._ask_onboard_checkbox( + enabled_channels = bub_inquirer.ask_checkbox( "Channels", choices=available_channels, enabled=default_channels, validate=lambda values: True if values else "Select at least one channel.", ) - stream_output = cast( - "bool", - ask_prompt(prompts.confirm("Stream output", default=bool(current_config.get("stream_output")))), - ) + stream_output = bub_inquirer.ask_confirm("Stream output", default=bool(current_config.get("stream_output"))) config: dict[str, object] = { "model": model, "api_format": api_format, diff --git a/src/bub/inquirer.py b/src/bub/inquirer.py new file mode 100644 index 00000000..63a96c5e --- /dev/null +++ b/src/bub/inquirer.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +from collections.abc import Callable +from typing import Any, cast + +import typer +from inquirer_textual import prompts +from inquirer_textual.common.Choice import Choice +from inquirer_textual.common.InquirerResult import InquirerResult +from inquirer_textual.common.PromptSettings import PromptSettings +from inquirer_textual.common.Shortcut import Shortcut + +CheckboxValidator = Callable[[list[str]], bool | str] + +CHECKBOX_HINT_SETTINGS = PromptSettings(shortcuts=[Shortcut("space", "toggle", "Space check/uncheck")]) + + +def ask_prompt(question: InquirerResult[Any]) -> Any: + if question.command in {"ctrl+c", "quit"}: + raise typer.Abort() + answer = question.value + if answer is None: + raise typer.Abort() + return answer + + +def ask_text(message: str, default: str = "") -> str: + return cast("str", ask_prompt(prompts.text(message, default=default))) + + +def ask_secret(message: str) -> str: + return cast("str", ask_prompt(prompts.secret(message))) + + +def ask_confirm(message: str, default: bool = False) -> bool: + return cast("bool", ask_prompt(prompts.confirm(message, default=default))) + + +def ask_select(message: str, choices: list[str], default: str = "") -> str: + return cast( + "str", + ask_prompt( + prompts.select( + message, + choices=cast("list[str | Choice]", choices), + default=default, + ) + ), + ) + + +def ask_fuzzy(message: str, choices: list[str], default: str | None = None) -> str: + return cast( + "str", + ask_prompt( + prompts.fuzzy( + message, + choices=cast("list[str | Choice]", choices), + default=default, + ) + ), + ) + + +def ask_checkbox( + message: str, + choices: list[str], + enabled: list[str] | None = None, + validate: CheckboxValidator | None = None, +) -> list[str]: + while True: + answer: list[str | Choice] = ask_prompt( + prompts.checkbox( + message, + choices=cast("list[str | Choice]", choices), + enabled=cast("list[str | Choice] | None", enabled), + settings=CHECKBOX_HINT_SETTINGS, + ) + ) + values = list(cast("list[str]", answer or [])) + if validate is None: + return values + validation_result = validate(values) + if validation_result is True: + return values + typer.secho(str(validation_result), err=True, fg="red") diff --git a/tests/test_builtin_cli.py b/tests/test_builtin_cli.py index cb5c8f78..9d69686e 100644 --- a/tests/test_builtin_cli.py +++ b/tests/test_builtin_cli.py @@ -13,8 +13,8 @@ import bub.builtin.auth as auth import bub.builtin.cli as cli -import bub.builtin.hook_impl as builtin_hook_impl import bub.configure as configure +import bub.inquirer as bub_inquirer from bub.framework import BubFramework from bub.hookspecs import hookimpl @@ -74,37 +74,34 @@ def onboard_config(self, current_config): lambda message, default=None, hide_input=False, show_default=True: next(answers), ) monkeypatch.setattr( - builtin_hook_impl.prompts, - "text", - lambda message, default="": _fake_result(default), + bub_inquirer, + "ask_text", + lambda message, default="": default, ) monkeypatch.setattr( - builtin_hook_impl.prompts, - "fuzzy", - lambda message, choices, default=None: _fake_result(default), + bub_inquirer, + "ask_fuzzy", + lambda message, choices, default=None: default, ) monkeypatch.setattr( - builtin_hook_impl.prompts, - "select", - lambda message, choices, default="": _fake_result(default), + bub_inquirer, + "ask_select", + lambda message, choices, default="": default, ) monkeypatch.setattr( - builtin_hook_impl.prompts, - "checkbox", - lambda message, choices, enabled=None, settings=None: ( - _assert_checkbox_hint(settings), - _fake_result(["telegram"]), - )[1], + bub_inquirer, + "ask_checkbox", + lambda message, choices, enabled=None, validate=None: ["telegram"], ) monkeypatch.setattr( - builtin_hook_impl.prompts, - "confirm", - lambda message, default=False: _fake_result(default), + bub_inquirer, + "ask_confirm", + lambda message, default=False: default, ) monkeypatch.setattr( - builtin_hook_impl.prompts, - "secret", - lambda message: _fake_result(""), + bub_inquirer, + "ask_secret", + lambda message: "", ) result = CliRunner().invoke(app, ["onboard"]) @@ -133,42 +130,37 @@ def test_onboard_collects_builtin_runtime_config(tmp_path: Path, monkeypatch) -> app = framework.create_cli_app() monkeypatch.setattr( - builtin_hook_impl.prompts, - "text", - lambda message, default="": _fake_result( - { - "LLM model": "openrouter/free", - "API base (optional)": "https://openrouter.ai/api/v1", - }.get(message, default) - ), + bub_inquirer, + "ask_text", + lambda message, default="": { + "LLM model": "openrouter/free", + "API base (optional)": "https://openrouter.ai/api/v1", + }.get(message, default), ) monkeypatch.setattr( - builtin_hook_impl.prompts, - "fuzzy", - lambda message, choices, default=None: _fake_result("openrouter"), + bub_inquirer, + "ask_fuzzy", + lambda message, choices, default=None: "openrouter", ) monkeypatch.setattr( - builtin_hook_impl.prompts, - "select", - lambda message, choices, default="": _fake_result("responses"), + bub_inquirer, + "ask_select", + lambda message, choices, default="": "responses", ) monkeypatch.setattr( - builtin_hook_impl.prompts, - "checkbox", - lambda message, choices, enabled=None, settings=None: ( - _assert_checkbox_hint(settings), - _fake_result(["telegram", "cli"]), - )[1], + bub_inquirer, + "ask_checkbox", + lambda message, choices, enabled=None, validate=None: ["telegram", "cli"], ) monkeypatch.setattr( - builtin_hook_impl.prompts, - "confirm", - lambda message, default=False: _fake_result(True), + bub_inquirer, + "ask_confirm", + lambda message, default=False: True, ) monkeypatch.setattr( - builtin_hook_impl.prompts, - "secret", - lambda message: _fake_result("sk-test"), + bub_inquirer, + "ask_secret", + lambda message: "sk-test", ) result = CliRunner().invoke(app, ["onboard"]) @@ -196,44 +188,44 @@ def test_onboard_aborts_immediately_when_builtin_prompt_is_interrupted(tmp_path: framework.load_hooks() app = framework.create_cli_app() - def fake_fuzzy(message: str, choices: list[str], default: str | None = None) -> InquirerResult[Any]: + def fake_fuzzy(message: str, choices: list[str], default: str | None = None) -> str: asked_messages.append(message) - return _fake_result(default) + assert default is not None + return default - def fake_select(message: str, choices: list[str], default: str = "") -> InquirerResult[Any]: + def fake_select(message: str, choices: list[str], default: str = "") -> str: asked_messages.append(message) - return _fake_result(default) + return default def fake_checkbox( message: str, choices: list[object], enabled=None, - settings: PromptSettings | None = None, - ) -> InquirerResult[Any]: + validate=None, + ) -> list[str]: asked_messages.append(message) - _assert_checkbox_hint(settings) - return _fake_result(["telegram"]) + return ["telegram"] - def fake_confirm(message: str, default: bool = False) -> InquirerResult[Any]: + def fake_confirm(message: str, default: bool = False) -> bool: asked_messages.append(message) - return _fake_result(default) + return default - def fake_text(message: str, default: str = "") -> InquirerResult[Any]: + def fake_text(message: str, default: str = "") -> str: asked_messages.append(message) if message == "API base (optional)": raise AssertionError("Onboarding should stop after interruption") - return _fake_result("openrouter:openrouter/free") + return "openrouter:openrouter/free" - def fake_secret(message: str) -> InquirerResult[Any]: + def fake_secret(message: str) -> str: asked_messages.append("API key (optional)") - return _fake_result(None) + raise typer.Abort() - monkeypatch.setattr(builtin_hook_impl.prompts, "fuzzy", fake_fuzzy) - monkeypatch.setattr(builtin_hook_impl.prompts, "select", fake_select) - monkeypatch.setattr(builtin_hook_impl.prompts, "checkbox", fake_checkbox) - monkeypatch.setattr(builtin_hook_impl.prompts, "confirm", fake_confirm) - monkeypatch.setattr(builtin_hook_impl.prompts, "text", fake_text) - monkeypatch.setattr(builtin_hook_impl.prompts, "secret", fake_secret) + monkeypatch.setattr(bub_inquirer, "ask_fuzzy", fake_fuzzy) + monkeypatch.setattr(bub_inquirer, "ask_select", fake_select) + monkeypatch.setattr(bub_inquirer, "ask_checkbox", fake_checkbox) + monkeypatch.setattr(bub_inquirer, "ask_confirm", fake_confirm) + monkeypatch.setattr(bub_inquirer, "ask_text", fake_text) + monkeypatch.setattr(bub_inquirer, "ask_secret", fake_secret) result = CliRunner().invoke(app, ["onboard"]) @@ -257,42 +249,37 @@ def test_onboard_collects_builtin_runtime_config_with_custom_provider(tmp_path: app = framework.create_cli_app() monkeypatch.setattr( - builtin_hook_impl.prompts, - "fuzzy", - lambda message, choices, default=None: _fake_result("custom"), + bub_inquirer, + "ask_fuzzy", + lambda message, choices, default=None: "custom", ) monkeypatch.setattr( - builtin_hook_impl.prompts, - "select", - lambda message, choices, default="": _fake_result("messages"), + bub_inquirer, + "ask_select", + lambda message, choices, default="": "messages", ) monkeypatch.setattr( - builtin_hook_impl.prompts, - "checkbox", - lambda message, choices, enabled=None, settings=None: ( - _assert_checkbox_hint(settings), - _fake_result(["telegram"]), - )[1], + bub_inquirer, + "ask_checkbox", + lambda message, choices, enabled=None, validate=None: ["telegram"], ) monkeypatch.setattr( - builtin_hook_impl.prompts, - "confirm", - lambda message, default=False: _fake_result(False), + bub_inquirer, + "ask_confirm", + lambda message, default=False: False, ) monkeypatch.setattr( - builtin_hook_impl.prompts, - "text", - lambda message, default="": _fake_result( - { - "Custom provider": "acme", - "LLM model": "ultra-1", - }.get(message, default) - ), + bub_inquirer, + "ask_text", + lambda message, default="": { + "Custom provider": "acme", + "LLM model": "ultra-1", + }.get(message, default), ) monkeypatch.setattr( - builtin_hook_impl.prompts, - "secret", - lambda message: _fake_result(""), + bub_inquirer, + "ask_secret", + lambda message: "", ) result = CliRunner().invoke(app, ["onboard"])