From fdf45c8a8e86cf99d5bacbef9e8e2d1d8c2af8d4 Mon Sep 17 00:00:00 2001 From: protosphinx <133899485+protosphinx@users.noreply.github.com> Date: Sat, 2 May 2026 12:57:45 -0700 Subject: [PATCH 1/6] chore: format sweep + Python 3.8-3.11 f-string compat fix Two separate hygiene fixes bundled because they both block the new test suite from running green in CI: 1. **black + isort sweep across 17 source files.** The repo was never run through `black` after the formatter became a CI gate in `.github/workflows/test.yml`. Pure whitespace / import-order changes; no logic touched. 2. **`tools/data_tools._generate_create_table_sql` syntax fix.** Line 188 used `{',\n '.join(column_definitions)}` inside an f-string. Backslashes in f-string expression parts are a SyntaxError on Python 3.8 - 3.11 (only legal from 3.12 onward), and `pyproject.toml` declares `requires-python = ">=3.8"` plus the CI matrix runs 3.8-3.12. The module is unimportable on every pre-3.12 Python; CI never caught it because nothing imported the module. Fixed by binding the separator to a name (`sep = ",\n "`) so the f-string expression is backslash-free. Behaviour identical. After this commit, `black --check`, `isort --check-only`, and `flake8 . --select=E9,F63,F7,F82` all pass cleanly. The next commit adds the test suite that surfaced both gaps. --- agent/__init__.py | 2 +- agent/clickhouse_agent.py | 279 ++++++++++++---------- config/__init__.py | 2 +- config/settings.py | 100 +++++--- main.py | 167 ++++++------- providers/__init__.py | 2 +- providers/local_llm.py | 482 +++++++++++++++++++++++++------------- setup.py | 5 +- tools/__init__.py | 2 +- tools/clickhouse_tools.py | 479 +++++++++++++++++++------------------ tools/data_tools.py | 236 ++++++++++--------- ui/__init__.py | 2 +- ui/beautiful_interface.py | 215 ++++++++++------- ui/minimal_interface.py | 257 ++++++++++++-------- ui/onboarding.py | 313 +++++++++++++++---------- ui/settings_manager.py | 193 ++++++++------- utils/__init__.py | 2 +- utils/logging.py | 54 +++-- 18 files changed, 1615 insertions(+), 1177 deletions(-) diff --git a/agent/__init__.py b/agent/__init__.py index 8403c75..7c1bd03 100644 --- a/agent/__init__.py +++ b/agent/__init__.py @@ -1,3 +1,3 @@ from .clickhouse_agent import ClickHouseAgent -__all__ = ['ClickHouseAgent'] \ No newline at end of file +__all__ = ["ClickHouseAgent"] diff --git a/agent/clickhouse_agent.py b/agent/clickhouse_agent.py index 910644d..9384e71 100644 --- a/agent/clickhouse_agent.py +++ b/agent/clickhouse_agent.py @@ -5,33 +5,36 @@ import asyncio import json -from typing import Dict, List, Any, Optional from pathlib import Path +from typing import Any, Dict, List, Optional +# Use OpenAI client directly +from openai import AsyncOpenAI from rich.console import Console -from rich.prompt import Prompt, Confirm -from rich.panel import Panel -from rich.text import Text from rich.live import Live +from rich.panel import Panel +from rich.prompt import Confirm, Prompt from rich.spinner import Spinner - -from ui.minimal_interface import ui +from rich.text import Text from config.settings import ClickHouseConfig from providers.local_llm import LocalLLMProvider -from tools.clickhouse_tools import ClickHouseConnection, ClickHouseToolExecutor, OPENAI_TOOLS -from tools.data_tools import DataLoader, DataVisualizer, DataExporter +from tools.clickhouse_tools import ( + OPENAI_TOOLS, + ClickHouseConnection, + ClickHouseToolExecutor, +) +from tools.data_tools import DataExporter, DataLoader, DataVisualizer +from ui.minimal_interface import ui from utils.logging import get_logger -# Use OpenAI client directly -from openai import AsyncOpenAI - logger = get_logger(__name__) console = Console() + class ClickHouseAgent: """Main ClickHouse AI Agent class""" - + def __init__(self, config: ClickHouseConfig): self.config = config self.connection = ClickHouseConnection(config) @@ -44,110 +47,105 @@ def __init__(self, config: ClickHouseConfig): self.conversation_history = [] self.max_tool_calls = config.max_tool_calls self.current_tool_calls = 0 - + async def initialize(self): """Initialize all components""" - + # Initialize ClickHouse connection await self.connection.connect() - + # Initialize tool executor self.tool_executor = ClickHouseToolExecutor(self.connection) await self.tool_executor.initialize() - + # Initialize data utilities self.data_loader = DataLoader(self.tool_executor.client) self.data_visualizer = DataVisualizer(self.tool_executor.client) self.data_exporter = DataExporter(self.tool_executor.client) - + # Initialize LLM provider and OpenAI client (local only) self.llm_provider = LocalLLMProvider( - base_url=self.config.local_llm_base_url, - model=self.config.local_llm_model + base_url=self.config.local_llm_base_url, model=self.config.local_llm_model ) # Create OpenAI client for local server openai_config = self.llm_provider.get_openai_config() self.openai_client = AsyncOpenAI( - api_key=openai_config["api_key"], - base_url=openai_config["base_url"] + api_key=openai_config["api_key"], base_url=openai_config["base_url"] ) - + logger.info("ClickHouse AI Agent initialized successfully") - + async def start_interactive_session(self): """Start interactive chat session""" - + await self.initialize() - - console.print("[dim bright_cyan]โ—[/dim bright_cyan] [bright_white]Ready to help with your data![/bright_white]") - console.print("[dim]Type your questions or commands. Type 'exit' to quit.[/dim]\n") - + + console.print( + "[dim bright_cyan]โ—[/dim bright_cyan] [bright_white]Ready to help with your data![/bright_white]" + ) + console.print( + "[dim]Type your questions or commands. Type 'exit' to quit.[/dim]\n" + ) + # Force reset conversation history completely self.conversation_history = [] self.current_tool_calls = 0 - + # Initialize conversation history cleanly system_message = self._build_system_prompt() - self.conversation_history = [ - {"role": "system", "content": system_message} - ] - + self.conversation_history = [{"role": "system", "content": system_message}] + try: while True: # Get user input with beautiful prompt ui.show_user_input_prompt() user_input = input().strip() - - if user_input.lower() in ['exit', 'quit', 'bye']: + + if user_input.lower() in ["exit", "quit", "bye"]: ui.show_goodbye() break - - if user_input.lower() in ['clear', 'reset']: + + if user_input.lower() in ["clear", "reset"]: self.conversation_history = [ {"role": "system", "content": system_message} ] self.current_tool_calls = 0 ui.show_success("Conversation reset") continue - + # Add user message to conversation - self.conversation_history.append({ - "role": "user", - "content": user_input - }) - + self.conversation_history.append( + {"role": "user", "content": user_input} + ) + # Process the conversation await self._process_conversation() - + except KeyboardInterrupt: console.print("\n[yellow]๐Ÿ‘‹ Goodbye![/yellow]") finally: await self._cleanup() - + async def _process_conversation(self): """Process the conversation with the AI agent""" - + # Simple validation - ensure clean conversation history if not self.conversation_history: logger.warning("Empty conversation history detected, initializing") system_message = self._build_system_prompt() - self.conversation_history = [ - {"role": "system", "content": system_message} - ] + self.conversation_history = [{"role": "system", "content": system_message}] elif self.conversation_history[0].get("role") != "system": logger.warning("Invalid conversation history, reinitializing") system_message = self._build_system_prompt() - self.conversation_history = [ - {"role": "system", "content": system_message} - ] - + self.conversation_history = [{"role": "system", "content": system_message}] + stop_requested = False last_tool_calls = [] # Track recent tool calls to prevent loops - + # Create the live animation outside the context manager so we can control it manually live = ui.show_thinking_animation() live.start() - + try: while self.current_tool_calls < self.max_tool_calls and not stop_requested: print(f"\n") @@ -162,14 +160,14 @@ async def _process_conversation(self): # Make LLM call with OpenAI client directly # Use the actual model path that llama-server provides model_name = "vishprometa/clickhouse-qwen3-1.7b-gguf" - + response = await self.openai_client.chat.completions.create( model=model_name, messages=self.conversation_history, tools=OPENAI_TOOLS, tool_choice="auto", temperature=self.config.temperature, - max_tokens=self.config.max_tokens + max_tokens=self.config.max_tokens, ) # Extract message from OpenAI response # Debug: print response (remove this in production) @@ -177,14 +175,14 @@ async def _process_conversation(self): message = response.choices[0].message # Extract reasoning content if available - reasoning_content = getattr(message, 'reasoning_content', None) + reasoning_content = getattr(message, "reasoning_content", None) # Convert OpenAI message to dict format for conversation history assistant_msg = { "role": "assistant", "content": message.content, } - + # Add tool calls if present if message.tool_calls: assistant_msg["tool_calls"] = [ @@ -193,12 +191,12 @@ async def _process_conversation(self): "type": tc.type, "function": { "name": tc.function.name, - "arguments": tc.function.arguments - } + "arguments": tc.function.arguments, + }, } for tc in message.tool_calls ] - + self.conversation_history.append(assistant_msg) # Process the response @@ -207,11 +205,11 @@ async def _process_conversation(self): # Display reasoning and text content live.stop() - + # Always show reasoning if available if reasoning_content: ui.show_reasoning(reasoning_content) - + # Show text content if available if text_content: # Show the response with markdown rendering @@ -227,30 +225,40 @@ async def _process_conversation(self): tool_id = tool_call.id # Switch to tool execution animation with arguments display live.stop() - tool_live = ui.show_tool_execution(tool_name, f"Running {tool_name.replace('_', ' ')}", tool_input) + tool_live = ui.show_tool_execution( + tool_name, + f"Running {tool_name.replace('_', ' ')}", + tool_input, + ) tool_live.start() try: # Check for stop agent if tool_name == "stop_agent": - summary = tool_input.get("summary", "Task completed") + summary = tool_input.get( + "summary", "Task completed" + ) result = f"Agent stopped: {summary}" # Add tool result BEFORE breaking (required for OpenAI format) - tool_results.append({ - "tool_call_id": tool_id, - "name": tool_name, - "content": result - }) + tool_results.append( + { + "tool_call_id": tool_id, + "name": tool_name, + "content": result, + } + ) ui.show_success(summary) stop_requested = True break # Execute the tool result = await self._execute_tool(tool_name, tool_input) # Store tool result for conversation history - tool_results.append({ - "tool_call_id": tool_id, - "name": tool_name, - "content": result - }) + tool_results.append( + { + "tool_call_id": tool_id, + "name": tool_name, + "content": result, + } + ) self.current_tool_calls += 1 finally: tool_live.stop() @@ -259,19 +267,28 @@ async def _process_conversation(self): # Truncate very long tool results to prevent context overflow content = tool_result["content"] if len(content) > 2000: # Limit tool result content - content = content[:1500] + f"\n... (truncated, full results shown above)" - - self.conversation_history.append({ - "role": "tool", - "tool_call_id": tool_result["tool_call_id"], - "name": tool_result["name"], - "content": content - }) - + content = ( + content[:1500] + + f"\n... (truncated, full results shown above)" + ) + + self.conversation_history.append( + { + "role": "tool", + "tool_call_id": tool_result["tool_call_id"], + "name": tool_result["name"], + "content": content, + } + ) + # Manage conversation history length to prevent context overflow - if len(self.conversation_history) > 10: # Keep only recent messages + if ( + len(self.conversation_history) > 10 + ): # Keep only recent messages # Keep system message + last 8 messages - self.conversation_history = [self.conversation_history[0]] + self.conversation_history[-8:] + self.conversation_history = [ + self.conversation_history[0] + ] + self.conversation_history[-8:] # If stop was requested, we're done - don't continue the loop if stop_requested: # print(f"Stop requested: {stop_requested}") @@ -296,61 +313,61 @@ async def _process_conversation(self): # Check if we hit the tool call limit if self.current_tool_calls >= self.max_tool_calls: ui.show_warning(f"Reached maximum tool calls limit ({self.max_tool_calls})") - + async def _execute_tool(self, tool_name: str, tool_input: Dict[str, Any]) -> str: """Execute a tool and return the result""" - + try: if tool_name == "list_databases": return await self.tool_executor.list_databases() - + elif tool_name == "switch_database": return await self.tool_executor.switch_database( database_name=tool_input["database_name"] ) - + elif tool_name == "execute_clickhouse_query": return await self.tool_executor.execute_clickhouse_query( query=tool_input["query"] ) - + elif tool_name == "list_tables": return await self.tool_executor.list_tables() - + elif tool_name == "get_table_schema": return await self.tool_executor.get_table_schema( table_name=tool_input["table_name"] ) - + elif tool_name == "search_table": return await self.tool_executor.search_table( table_name=tool_input["table_name"], limit=tool_input.get("limit", 100), - where_clause=tool_input.get("where_clause") + where_clause=tool_input.get("where_clause"), ) - + elif tool_name == "export_data_to_csv": return await self.tool_executor.export_data_to_csv( query=tool_input["query"], filename=tool_input.get("filename"), - analysis_limit=tool_input.get("analysis_limit", 50) + analysis_limit=tool_input.get("analysis_limit", 50), ) - + elif tool_name == "stop_agent": # This is handled in the main conversation loop return f"Agent stopped: {tool_input.get('summary', 'Task completed')}" - + else: return f"Unknown tool: {tool_name}" - + except Exception as e: error_msg = f"Tool execution failed: {str(e)}" logger.error(f"Tool {tool_name} failed: {e}") return error_msg - + def _build_system_prompt(self) -> str: """Build the system prompt for the AI agent""" - + return f"""You are Proto, a ClickHouse AI agent for data analysis and querying. CORE MINDSET: Be proactive and smart. Handle vague queries by exploring what's available. @@ -381,82 +398,83 @@ def _build_system_prompt(self) -> str: Be curious and analytical. Show interesting patterns, outliers, and insights without being asked.""" - - async def execute_single_query(self, query: str, output_format: str = "table", save_to: Optional[Path] = None): + async def execute_single_query( + self, query: str, output_format: str = "table", save_to: Optional[Path] = None + ): """Execute a single query (for CLI query command)""" - + await self.initialize() - + try: result = await self.tool_executor.execute_query( - query=query, - format=output_format + query=query, format=output_format ) - + if save_to and output_format != "table": - with open(save_to, 'w') as f: + with open(save_to, "w") as f: f.write(result) console.print(f"[green]โœ“ Results saved to {save_to}[/green]") - + except Exception as e: console.print(f"[red]โŒ Error: {e}[/red]") finally: await self._cleanup() - + async def analyze_table(self, table_name: str, deep: bool = False): """Analyze a specific table (for CLI analyze command)""" - + await self.initialize() - + try: result = await self.tool_executor.analyze_table( - table_name=table_name, - sample_size=50000 if deep else 10000 + table_name=table_name, sample_size=50000 if deep else 10000 ) console.print(result) - + except Exception as e: console.print(f"[red]โŒ Error: {e}[/red]") finally: await self._cleanup() - + async def load_data_from_file( self, file_path: Path, table_name: str, create_table: bool = True, - batch_size: int = 10000 + batch_size: int = 10000, ): """Load data from file (for CLI load-data command)""" - + await self.initialize() - + try: - if file_path.suffix.lower() == '.csv': + if file_path.suffix.lower() == ".csv": result = await self.data_loader.load_from_csv( file_path=str(file_path), table_name=table_name, create_table=create_table, - batch_size=batch_size + batch_size=batch_size, ) - elif file_path.suffix.lower() == '.json': + elif file_path.suffix.lower() == ".json": result = await self.data_loader.load_from_json( file_path=str(file_path), table_name=table_name, create_table=create_table, - batch_size=batch_size + batch_size=batch_size, ) else: - console.print(f"[red]โŒ Unsupported file format: {file_path.suffix}[/red]") + console.print( + f"[red]โŒ Unsupported file format: {file_path.suffix}[/red]" + ) return - + console.print(f"[green]โœ“ {result}[/green]") - + except Exception as e: console.print(f"[red]โŒ Error: {e}[/red]") finally: await self._cleanup() - + async def _cleanup(self): """Cleanup resources""" try: @@ -469,5 +487,6 @@ async def _cleanup(self): except Exception as e: logger.error(f"Error during cleanup: {e}") + # Export the main class -__all__ = ['ClickHouseAgent'] \ No newline at end of file +__all__ = ["ClickHouseAgent"] diff --git a/config/__init__.py b/config/__init__.py index be6a4fc..2726292 100644 --- a/config/__init__.py +++ b/config/__init__.py @@ -1 +1 @@ -# Config package \ No newline at end of file +# Config package diff --git a/config/settings.py b/config/settings.py index 9cf79b6..e1536aa 100644 --- a/config/settings.py +++ b/config/settings.py @@ -7,16 +7,17 @@ from typing import Optional import typer +from dotenv import load_dotenv from pydantic import BaseModel, Field from rich.console import Console -from rich.prompt import Prompt, Confirm -from dotenv import load_dotenv +from rich.prompt import Confirm, Prompt console = Console() + class ClickHouseConfig(BaseModel): """ClickHouse connection configuration""" - + host: str = Field(default="localhost", description="ClickHouse host") port: int = Field(default=8123, description="ClickHouse HTTP port") username: str = Field(default="default", description="ClickHouse username") @@ -26,37 +27,50 @@ class ClickHouseConfig(BaseModel): # Provider selection provider: str = Field(default="local", description="LLM provider: local only") - + # Local LLM (llama.cpp / llamafile / llama-cpp-python server) configuration - local_llm_base_url: str = Field(default="http://127.0.0.1:8000/v1", description="Local LLM server base URL") - local_llm_model: str = Field(default="qwen3-1.7b", description="Local LLM model name") - + local_llm_base_url: str = Field( + default="http://127.0.0.1:8000/v1", description="Local LLM server base URL" + ) + local_llm_model: str = Field( + default="qwen3-1.7b", description="Local LLM model name" + ) + # Legacy OpenRouter configuration (kept for backward compatibility) openrouter_api_key: str = Field(default="", description="Legacy OpenRouter API key") - openrouter_model: str = Field(default="openai/gpt-4o-mini", description="Legacy OpenRouter model") - openrouter_provider_only: str = Field(default="openai", description="Legacy OpenRouter provider preference") - openrouter_data_collection: str = Field(default="deny", description="Legacy OpenRouter data collection setting") - + openrouter_model: str = Field( + default="openai/gpt-4o-mini", description="Legacy OpenRouter model" + ) + openrouter_provider_only: str = Field( + default="openai", description="Legacy OpenRouter provider preference" + ) + openrouter_data_collection: str = Field( + default="deny", description="Legacy OpenRouter data collection setting" + ) + # Agent configuration - max_tool_calls: int = Field(default=35, description="Maximum tool calls per conversation") + max_tool_calls: int = Field( + default=35, description="Maximum tool calls per conversation" + ) temperature: float = Field(default=0.1, description="LLM temperature") max_tokens: int = Field(default=4000, description="Maximum tokens per response") + def load_config( config_file: Optional[Path] = None, host: Optional[str] = None, port: Optional[int] = None, username: Optional[str] = None, password: Optional[str] = None, - database: Optional[str] = None + database: Optional[str] = None, ) -> ClickHouseConfig: """Load configuration from file and command line arguments""" - + # Load environment variables from .env file load_dotenv() - + config_data = {} - + # Determine default config file if not provided default_config_file = Path.home() / ".config" / "proto" / "proto-config.json" legacy_config_file = Path("proto-config.json") @@ -71,6 +85,7 @@ def load_config( # Load from config file if found if candidate_config and candidate_config.exists(): import json + with open(candidate_config) as f: config_data = json.load(f) @@ -86,27 +101,31 @@ def load_config( for old_key, new_key in key_mapping.items(): if old_key in config_data and new_key not in config_data: config_data[new_key] = config_data[old_key] - + # Load from environment variables - config_data.update({ - k.lower().replace("clickhouse_", ""): v - for k, v in os.environ.items() - if k.startswith("CLICKHOUSE_") - }) - + config_data.update( + { + k.lower().replace("clickhouse_", ""): v + for k, v in os.environ.items() + if k.startswith("CLICKHOUSE_") + } + ) + # OpenRouter configuration from environment if "OPENROUTER_API_KEY" in os.environ: config_data["openrouter_api_key"] = os.environ["OPENROUTER_API_KEY"] - + if "OPENROUTER_MODEL" in os.environ: config_data["openrouter_model"] = os.environ["OPENROUTER_MODEL"] - + if "OPENROUTER_PROVIDER_ONLY" in os.environ: config_data["openrouter_provider_only"] = os.environ["OPENROUTER_PROVIDER_ONLY"] - + if "OPENROUTER_DATA_COLLECTION" in os.environ: - config_data["openrouter_data_collection"] = os.environ["OPENROUTER_DATA_COLLECTION"] - + config_data["openrouter_data_collection"] = os.environ[ + "OPENROUTER_DATA_COLLECTION" + ] + # Override with command line arguments if host: config_data["host"] = host @@ -118,21 +137,22 @@ def load_config( config_data["password"] = password if database: config_data["database"] = database - + config = ClickHouseConfig(**config_data) - + # No interactive configuration needed for local provider if config.provider != "local": console.print("[yellow]โš ๏ธ Only local provider is supported[/yellow]") console.print("[blue]โ„น๏ธ Switching to local provider automatically[/blue]") config.provider = "local" - + return config + def save_env_config(config: ClickHouseConfig): """Save configuration to .env file""" env_path = Path(".env") - + env_content = f"""# ClickHouse Configuration CLICKHOUSE_HOST={config.host} CLICKHOUSE_PORT={config.port} @@ -147,12 +167,13 @@ def save_env_config(config: ClickHouseConfig): OPENROUTER_PROVIDER_ONLY={config.openrouter_provider_only} OPENROUTER_DATA_COLLECTION={config.openrouter_data_collection} """ - + with open(env_path, "w") as f: f.write(env_content) - + console.print(f"[green]โœ“[/green] Configuration saved to {env_path}") + def create_sample_config(): """Create a sample configuration file""" config_data = { @@ -168,13 +189,14 @@ def create_sample_config(): "openrouter_data_collection": "deny", "max_tool_calls": 35, "temperature": 0.1, - "max_tokens": 4000 + "max_tokens": 4000, } - + config_path = Path("proto-config.json") - + import json + with open(config_path, "w") as f: json.dump(config_data, f, indent=2) - - console.print(f"[green]โœ“[/green] Sample configuration created at {config_path}") \ No newline at end of file + + console.print(f"[green]โœ“[/green] Sample configuration created at {config_path}") diff --git a/main.py b/main.py index 364beb8..784c9a7 100644 --- a/main.py +++ b/main.py @@ -17,10 +17,10 @@ from agent import ClickHouseAgent from config.settings import load_config -from utils.logging import setup_logging from ui.minimal_interface import ui -from ui.onboarding import needs_onboarding, OnboardingFlow +from ui.onboarding import OnboardingFlow, needs_onboarding from ui.settings_manager import SettingsManager +from utils.logging import setup_logging app = typer.Typer( name="proto", @@ -32,6 +32,7 @@ console = Console() + @app.callback(invoke_without_command=True) def main(ctx: typer.Context): """Default entry: start interactive chat (with onboarding on first run).""" @@ -53,59 +54,35 @@ def main(ctx: typer.Context): @app.command() def chat( config_file: Optional[Path] = typer.Option( - None, - "--config", - "-c", - help="Path to configuration file" - ), - host: Optional[str] = typer.Option( - None, - "--host", - "-h", - help="ClickHouse host" - ), - port: Optional[int] = typer.Option( - None, - "--port", - "-p", - help="ClickHouse port" + None, "--config", "-c", help="Path to configuration file" ), + host: Optional[str] = typer.Option(None, "--host", "-h", help="ClickHouse host"), + port: Optional[int] = typer.Option(None, "--port", "-p", help="ClickHouse port"), username: Optional[str] = typer.Option( - None, - "--username", - "-u", - help="ClickHouse username" + None, "--username", "-u", help="ClickHouse username" ), password: Optional[str] = typer.Option( - None, - "--password", - help="ClickHouse password" + None, "--password", help="ClickHouse password" ), database: Optional[str] = typer.Option( - None, - "--database", - "-d", - help="ClickHouse database" + None, "--database", "-d", help="ClickHouse database" ), verbose: bool = typer.Option( - False, - "--verbose", - "-v", - help="Enable verbose logging" - ) + False, "--verbose", "-v", help="Enable verbose logging" + ), ): """Start interactive chat with ClickHouse AI Agent""" - + # Setup logging (quiet mode for beautiful UI unless verbose) setup_logging(verbose=verbose, quiet_mode=not verbose) - + # Check if onboarding is needed if needs_onboarding(): onboarding = OnboardingFlow() onboarding.run_onboarding() # After onboarding, rely on default config discovery config_file = None - + # Load configuration config = load_config( config_file=config_file, @@ -113,20 +90,17 @@ def chat( port=port, username=username, password=password, - database=database + database=database, ) - + # Display beautiful welcome screen ui.show_welcome_screen() - + # Show connection status ui.show_connection_status( - host=config.host, - port=config.port, - database=config.database, - connected=True + host=config.host, port=config.port, database=config.database, connected=True ) - + # Start the agent try: agent = ClickHouseAgent(config) @@ -138,17 +112,22 @@ def chat( ui.show_error(str(e)) sys.exit(1) + @app.command() def query( sql: str = typer.Argument(..., help="SQL query to execute"), config_file: Optional[Path] = typer.Option(None, "--config", "-c"), - output_format: str = typer.Option("table", "--format", "-f", help="Output format: table, json, csv"), - save_to: Optional[Path] = typer.Option(None, "--save", "-s", help="Save results to file") + output_format: str = typer.Option( + "table", "--format", "-f", help="Output format: table, json, csv" + ), + save_to: Optional[Path] = typer.Option( + None, "--save", "-s", help="Save results to file" + ), ): """Execute a single SQL query""" - + config = load_config(config_file=config_file) - + try: agent = ClickHouseAgent(config) asyncio.run(agent.execute_single_query(sql, output_format, save_to)) @@ -156,16 +135,17 @@ def query( console.print(f"[red]โŒ Error: {e}[/red]") sys.exit(1) + @app.command() def analyze( table: str = typer.Argument(..., help="Table name to analyze"), config_file: Optional[Path] = typer.Option(None, "--config", "-c"), - deep: bool = typer.Option(False, "--deep", help="Perform deep analysis") + deep: bool = typer.Option(False, "--deep", help="Perform deep analysis"), ): """Analyze a specific table""" - + config = load_config(config_file=config_file) - + try: agent = ClickHouseAgent(config) asyncio.run(agent.analyze_table(table, deep=deep)) @@ -173,77 +153,92 @@ def analyze( console.print(f"[red]โŒ Error: {e}[/red]") sys.exit(1) + @app.command() def load_data( file_path: Path = typer.Argument(..., help="Path to data file"), table: str = typer.Argument(..., help="Target table name"), config_file: Optional[Path] = typer.Option(None, "--config", "-c"), - create_table: bool = typer.Option(True, "--create-table", help="Create table if it doesn't exist"), - batch_size: int = typer.Option(10000, "--batch-size", help="Batch size for data loading") + create_table: bool = typer.Option( + True, "--create-table", help="Create table if it doesn't exist" + ), + batch_size: int = typer.Option( + 10000, "--batch-size", help="Batch size for data loading" + ), ): """Load data from file into ClickHouse""" - + config = load_config(config_file=config_file) - + try: agent = ClickHouseAgent(config) - asyncio.run(agent.load_data_from_file(file_path, table, create_table, batch_size)) + asyncio.run( + agent.load_data_from_file(file_path, table, create_table, batch_size) + ) except Exception as e: console.print(f"[red]โŒ Error: {e}[/red]") sys.exit(1) + @app.command() def settings(): """Manage Proto settings and configuration""" settings_manager = SettingsManager() settings_manager.run_settings_menu() + @app.command() def clear(): """Clear all configuration and start fresh""" from pathlib import Path + from rich.prompt import Confirm - + # Config file locations config_files = [ Path.home() / ".config" / "proto" / "proto-config.json", Path("proto-config.json"), - Path(".env") + Path(".env"), ] - + existing_files = [f for f in config_files if f.exists()] - + if not existing_files: console.print("[yellow]No configuration files found to clear.[/yellow]") return - + console.print("[bold red]โš ๏ธ This will delete all Proto configuration:[/bold red]") for file in existing_files: console.print(f" โ€ข {file}") console.print() - - if Confirm.ask("[bold red]Are you sure you want to clear all configuration?[/bold red]"): + + if Confirm.ask( + "[bold red]Are you sure you want to clear all configuration?[/bold red]" + ): for file in existing_files: try: file.unlink() console.print(f"[green]โœ“[/green] Deleted {file}") except Exception as e: console.print(f"[red]โœ—[/red] Failed to delete {file}: {e}") - + console.print() - console.print("[green]๐ŸŽ‰ Configuration cleared! Run 'proto' to start fresh onboarding.[/green]") + console.print( + "[green]๐ŸŽ‰ Configuration cleared! Run 'proto' to start fresh onboarding.[/green]" + ) else: console.print("[blue]Configuration clearing cancelled.[/blue]") + @app.command() def refresh_template(): """Refresh chat template from Hugging Face repository""" from providers.local_llm import LocalLLMProvider - + console.print("[bold cyan]๐Ÿ”„ Refreshing Chat Template[/bold cyan]") console.print("Fetching latest chat template from Hugging Face repository...") console.print() - + try: # Create a minimal provider instance just for template refresh provider = LocalLLMProvider.__new__(LocalLLMProvider) @@ -251,27 +246,35 @@ def refresh_template(): provider.chat_template_file = "chat_template.jinja" provider.chat_template_url = f"https://huggingface.co/{provider.model_name}/resolve/main/{provider.chat_template_file}" provider.cache_dir = os.path.expanduser("~/.cache/llama.cpp") - provider.chat_template_path = os.path.join(provider.cache_dir, f"{provider.model_name.replace('/', '_')}_{provider.chat_template_file}") - + provider.chat_template_path = os.path.join( + provider.cache_dir, + f"{provider.model_name.replace('/', '_')}_{provider.chat_template_file}", + ) + # Add the announce method for user feedback def _announce(message: str): console.print(f"[cyan]{message}[/cyan]") + provider._announce = _announce - + # Refresh the template success = provider.refresh_chat_template() - + if success: console.print() console.print("[green]โœ… Chat template refreshed successfully![/green]") - console.print(f"[dim]Template location: {provider.chat_template_path}[/dim]") - + console.print( + f"[dim]Template location: {provider.chat_template_path}[/dim]" + ) + # Show template info if os.path.exists(provider.chat_template_path): - with open(provider.chat_template_path, 'r') as f: + with open(provider.chat_template_path, "r") as f: content = f.read() - console.print(f"[dim]Template size: {len(content)} characters[/dim]") - + console.print( + f"[dim]Template size: {len(content)} characters[/dim]" + ) + # Show a preview of the template preview = content[:200] + "..." if len(content) > 200 else content console.print() @@ -280,14 +283,17 @@ def _announce(message: str): else: console.print() console.print("[red]โŒ Failed to refresh chat template[/red]") - console.print("[yellow]The agent will use the default template if available[/yellow]") + console.print( + "[yellow]The agent will use the default template if available[/yellow]" + ) sys.exit(1) - + except Exception as e: console.print() console.print(f"[red]โŒ Error refreshing template: {e}[/red]") sys.exit(1) + @app.command() def version(): """Show version information""" @@ -295,5 +301,6 @@ def version(): console.print("Version: 1.0.0") console.print("Built with โค๏ธ for ClickHouse analysis") + if __name__ == "__main__": - app() \ No newline at end of file + app() diff --git a/providers/__init__.py b/providers/__init__.py index e7073c9..22fb90b 100644 --- a/providers/__init__.py +++ b/providers/__init__.py @@ -1,4 +1,4 @@ # Providers package from .local_llm import LocalLLMProvider -__all__ = ["LocalLLMProvider"] \ No newline at end of file +__all__ = ["LocalLLMProvider"] diff --git a/providers/local_llm.py b/providers/local_llm.py index 23dc3a1..f9f39e2 100644 --- a/providers/local_llm.py +++ b/providers/local_llm.py @@ -2,19 +2,22 @@ Simplified Local LLM provider that manages llama-server process """ -import subprocess -import time -import signal import os -import psutil import shutil +import signal +import subprocess import sys +import time +from typing import Any, Dict, List, Optional + +import psutil import requests -from typing import Dict, List, Any, Optional + from utils.logging import get_logger try: from rich.console import Console + _console: Optional[Console] = Console() except Exception: _console = None @@ -25,26 +28,37 @@ class LocalLLMProvider: """Simple provider that manages llama-server process and provides connection info""" - def __init__(self, base_url: str = "http://127.0.0.1:8000/v1", model: str = "vishprometa/clickhouse-qwen3-1.7b-gguf"): + def __init__( + self, + base_url: str = "http://127.0.0.1:8000/v1", + model: str = "vishprometa/clickhouse-qwen3-1.7b-gguf", + ): self.base_url = base_url.rstrip("/") self.model = model self.host = "127.0.0.1" self.port = 8000 self.server_process = None - + # Model download configuration self.model_name = "vishprometa/clickhouse-qwen3-1.7b-gguf" self.display_name = "clickhouse-qwen3-1.7b-gguf" self.model_file = "unsloth.F16.gguf" - self.model_url = f"https://huggingface.co/{self.model_name}/resolve/main/{self.model_file}" + self.model_url = ( + f"https://huggingface.co/{self.model_name}/resolve/main/{self.model_file}" + ) self.cache_dir = os.path.expanduser("~/.cache/llama.cpp") - self.model_path = os.path.join(self.cache_dir, f"{self.model_name.replace('/', '_')}_{self.model_file}") - + self.model_path = os.path.join( + self.cache_dir, f"{self.model_name.replace('/', '_')}_{self.model_file}" + ) + # Chat template configuration self.chat_template_file = "chat_template.jinja" self.chat_template_url = f"https://huggingface.co/{self.model_name}/resolve/main/{self.chat_template_file}" - self.chat_template_path = os.path.join(self.cache_dir, f"{self.model_name.replace('/', '_')}_{self.chat_template_file}") - + self.chat_template_path = os.path.join( + self.cache_dir, + f"{self.model_name.replace('/', '_')}_{self.chat_template_file}", + ) + # Extract host and port from base_url try: if "://" in base_url: @@ -53,7 +67,7 @@ def __init__(self, base_url: str = "http://127.0.0.1:8000/v1", model: str = "vis host_port = url_part.split("/")[0] else: host_port = url_part - + if ":" in host_port: self.host, port_str = host_port.split(":") self.port = int(port_str) @@ -68,14 +82,19 @@ def _auto_setup(self): try: # Ensure llama-server is installed self._ensure_llama_server_installed() - + # Download model and chat template together - model_exists = os.path.exists(self.model_path) and os.path.getsize(self.model_path) > 3_000_000_000 + model_exists = ( + os.path.exists(self.model_path) + and os.path.getsize(self.model_path) > 3_000_000_000 + ) template_exists = os.path.exists(self.chat_template_path) - + logger.info(f"Model status: exists={model_exists}, path={self.model_path}") - logger.info(f"Template status: exists={template_exists}, path={self.chat_template_path}") - + logger.info( + f"Template status: exists={template_exists}, path={self.chat_template_path}" + ) + # If model doesn't exist, download both model and template if not model_exists: logger.info("Model not found, downloading model and template...") @@ -83,28 +102,38 @@ def _auto_setup(self): raise RuntimeError("Model download failed") # Download template right after model download if not self._download_chat_template(): - logger.warning("Chat template download failed, will try to start server without custom template") + logger.warning( + "Chat template download failed, will try to start server without custom template" + ) self.chat_template_path = None - + # If model exists but template doesn't (existing users), download just the template elif not template_exists: - logger.info("Model exists but template missing, downloading template...") - self._announce("๐Ÿ“ฅ Downloading missing chat template for existing model...") + logger.info( + "Model exists but template missing, downloading template..." + ) + self._announce( + "๐Ÿ“ฅ Downloading missing chat template for existing model..." + ) if not self._download_chat_template(): - logger.warning("Chat template download failed, will try to start server without custom template") + logger.warning( + "Chat template download failed, will try to start server without custom template" + ) self.chat_template_path = None else: logger.info("Both model and template exist, proceeding with startup...") - + # Check if server is already running if self._is_server_running(): - self._announce(f"โœ… ClickHouse AI Agent ready on {self.host}:{self.port}") + self._announce( + f"โœ… ClickHouse AI Agent ready on {self.host}:{self.port}" + ) return - + # Start the server with local model file self._start_server() self._wait_for_server() # No timeout for model download - + except Exception as e: logger.error(f"Auto-setup failed: {e}") raise RuntimeError(f"Failed to setup ClickHouse AI Agent: {e}") @@ -112,35 +141,44 @@ def _auto_setup(self): def _download_model_with_progress(self): """Download model with real progress tracking using curl""" try: - from rich.progress import Progress, SpinnerColumn, TextColumn, TimeElapsedColumn, BarColumn, TaskProgressColumn, DownloadColumn, TransferSpeedColumn - + from rich.progress import ( + BarColumn, + DownloadColumn, + Progress, + SpinnerColumn, + TaskProgressColumn, + TextColumn, + TimeElapsedColumn, + TransferSpeedColumn, + ) + # Ensure cache directory exists os.makedirs(self.cache_dir, exist_ok=True) - + # Check if model already exists if os.path.exists(self.model_path): file_size = os.path.getsize(self.model_path) if file_size > 3_000_000_000: # > 3GB, assume complete self._announce(f"โœ… Model already downloaded: {self.model_path}") return True - + self._announce(f"๐Ÿ“ฅ Downloading {self.display_name}...") - + # Get file size for progress tracking try: response = requests.head(self.model_url, allow_redirects=True) - total_size = int(response.headers.get('content-length', 0)) + total_size = int(response.headers.get("content-length", 0)) except Exception as e: logger.warning(f"Could not get file size: {e}") total_size = 3_447_349_440 # Known size ~3.2GB - + # Try simple download first (more reliable) if self._simple_download_with_progress(total_size): return True - + # Fallback to curl if simple download fails return self._curl_download_with_progress(total_size) - + except Exception as e: logger.error(f"Failed to download model: {e}") raise RuntimeError(f"Model download failed: {e}") @@ -153,7 +191,9 @@ def _download_chat_template(self, quiet: bool = False) -> bool: file_age = time.time() - os.path.getmtime(self.chat_template_path) if file_age < 86400: # 24 hours in seconds if not quiet: - logger.info(f"Chat template already exists and is recent: {self.chat_template_path}") + logger.info( + f"Chat template already exists and is recent: {self.chat_template_path}" + ) return True else: if not quiet: @@ -161,29 +201,35 @@ def _download_chat_template(self, quiet: bool = False) -> bool: else: if not quiet: self._announce("๐Ÿ“ฅ Downloading chat template...") - + # Ensure cache directory exists os.makedirs(self.cache_dir, exist_ok=True) - + # Download the chat template logger.info(f"Downloading chat template from: {self.chat_template_url}") response = requests.get(self.chat_template_url, timeout=30) response.raise_for_status() - + # Write to file - with open(self.chat_template_path, 'w', encoding='utf-8') as f: + with open(self.chat_template_path, "w", encoding="utf-8") as f: f.write(response.text) - + # Verify the file was written correctly file_size = os.path.getsize(self.chat_template_path) - logger.info(f"Chat template saved: {self.chat_template_path} ({file_size} bytes)") - + logger.info( + f"Chat template saved: {self.chat_template_path} ({file_size} bytes)" + ) + if not quiet: - self._announce(f"โœ… Chat template downloaded: {self.chat_template_path}") + self._announce( + f"โœ… Chat template downloaded: {self.chat_template_path}" + ) return True - + except requests.RequestException as e: - logger.error(f"Failed to download chat template from {self.chat_template_url}: {e}") + logger.error( + f"Failed to download chat template from {self.chat_template_url}: {e}" + ) return False except Exception as e: logger.error(f"Failed to save chat template: {e}") @@ -196,10 +242,10 @@ def refresh_chat_template(self) -> bool: if os.path.exists(self.chat_template_path): os.remove(self.chat_template_path) self._announce("๐Ÿ—‘๏ธ Removed old chat template") - + # Download fresh template return self._download_chat_template() - + except Exception as e: logger.error(f"Failed to refresh chat template: {e}") return False @@ -207,8 +253,17 @@ def refresh_chat_template(self) -> bool: def _simple_download_with_progress(self, total_size: int) -> bool: """Simple download using requests with progress tracking""" try: - from rich.progress import Progress, SpinnerColumn, TextColumn, TimeElapsedColumn, BarColumn, TaskProgressColumn, DownloadColumn, TransferSpeedColumn - + from rich.progress import ( + BarColumn, + DownloadColumn, + Progress, + SpinnerColumn, + TaskProgressColumn, + TextColumn, + TimeElapsedColumn, + TransferSpeedColumn, + ) + with Progress( SpinnerColumn(), TextColumn("[progress.description]{task.description}"), @@ -221,38 +276,50 @@ def _simple_download_with_progress(self, total_size: int) -> bool: transient=False, ) as progress: task = progress.add_task( - f"๐Ÿ”„ Downloading {self.display_name}...", - total=total_size + f"๐Ÿ”„ Downloading {self.display_name}...", total=total_size ) - + # Download with requests and stream - response = requests.get(self.model_url, stream=True, allow_redirects=True) + response = requests.get( + self.model_url, stream=True, allow_redirects=True + ) response.raise_for_status() - + downloaded_size = 0 - with open(self.model_path, 'wb') as f: + with open(self.model_path, "wb") as f: for chunk in response.iter_content(chunk_size=8192): if chunk: f.write(chunk) downloaded_size += len(chunk) progress.update(task, completed=downloaded_size) - + if downloaded_size > 3_000_000_000: # > 3GB - progress.update(task, description="โœ… Model download completed!", completed=total_size) - self._announce(f"โœ… Model downloaded successfully: {self.model_path}") - + progress.update( + task, + description="โœ… Model download completed!", + completed=total_size, + ) + self._announce( + f"โœ… Model downloaded successfully: {self.model_path}" + ) + # Download chat template immediately after successful model download progress.update(task, description="๐Ÿ“ฅ Downloading chat template...") if self._download_chat_template(quiet=True): - progress.update(task, description="โœ… Model and template ready!") + progress.update( + task, description="โœ… Model and template ready!" + ) else: - progress.update(task, description="โœ… Model ready (template download failed)") - + progress.update( + task, + description="โœ… Model ready (template download failed)", + ) + return True else: progress.update(task, description="โŒ Download incomplete") return False - + except Exception as e: logger.warning(f"Simple download failed, trying curl: {e}") return False @@ -260,8 +327,17 @@ def _simple_download_with_progress(self, total_size: int) -> bool: def _curl_download_with_progress(self, total_size: int) -> bool: """Download using curl with progress tracking""" try: - from rich.progress import Progress, SpinnerColumn, TextColumn, TimeElapsedColumn, BarColumn, TaskProgressColumn, DownloadColumn, TransferSpeedColumn - + from rich.progress import ( + BarColumn, + DownloadColumn, + Progress, + SpinnerColumn, + TaskProgressColumn, + TextColumn, + TimeElapsedColumn, + TransferSpeedColumn, + ) + with Progress( SpinnerColumn(), TextColumn("[progress.description]{task.description}"), @@ -274,28 +350,30 @@ def _curl_download_with_progress(self, total_size: int) -> bool: transient=False, ) as progress: task = progress.add_task( - f"๐Ÿ”„ Downloading {self.display_name}...", - total=total_size + f"๐Ÿ”„ Downloading {self.display_name}...", total=total_size ) - + # Use curl for reliable download with progress cmd = [ - "curl", "-L", "-o", self.model_path, + "curl", + "-L", + "-o", + self.model_path, "--progress-bar", - self.model_url + self.model_url, ] - + process = subprocess.Popen( cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True, - bufsize=1 + bufsize=1, ) - + downloaded_size = 0 last_update_time = time.time() - + # Simple progress tracking based on file size start_time = time.time() while process.poll() is None: @@ -306,7 +384,7 @@ def _curl_download_with_progress(self, total_size: int) -> bool: downloaded_size = current_size progress.update(task, completed=downloaded_size) last_update_time = time.time() - + # Update description periodically current_time = time.time() if current_time - last_update_time > 2.0: # Update every 2 seconds @@ -315,35 +393,59 @@ def _curl_download_with_progress(self, total_size: int) -> bool: if elapsed > 0 and downloaded_size > 0: # Calculate estimated speed and remaining time speed = downloaded_size / elapsed - remaining = (total_size - downloaded_size) / speed if speed > 0 else 0 - progress.update(task, description=f"๐Ÿ”„ Downloading {self.display_name}... ({downloaded_size}/{total_size} bytes, {speed/1024/1024:.1f} MB/s)") + remaining = ( + (total_size - downloaded_size) / speed + if speed > 0 + else 0 + ) + progress.update( + task, + description=f"๐Ÿ”„ Downloading {self.display_name}... ({downloaded_size}/{total_size} bytes, {speed/1024/1024:.1f} MB/s)", + ) last_update_time = current_time - + time.sleep(0.5) # Check every 0.5 seconds - + process.wait() - + if process.returncode == 0 and os.path.exists(self.model_path): file_size = os.path.getsize(self.model_path) if file_size > 3_000_000_000: # > 3GB - progress.update(task, description="โœ… Model download completed!", completed=total_size) - self._announce(f"โœ… Model downloaded successfully: {self.model_path}") - + progress.update( + task, + description="โœ… Model download completed!", + completed=total_size, + ) + self._announce( + f"โœ… Model downloaded successfully: {self.model_path}" + ) + # Download chat template immediately after successful model download - progress.update(task, description="๐Ÿ“ฅ Downloading chat template...") + progress.update( + task, description="๐Ÿ“ฅ Downloading chat template..." + ) if self._download_chat_template(quiet=True): - progress.update(task, description="โœ… Model and template ready!") + progress.update( + task, description="โœ… Model and template ready!" + ) else: - progress.update(task, description="โœ… Model ready (template download failed)") - + progress.update( + task, + description="โœ… Model ready (template download failed)", + ) + return True else: progress.update(task, description="โŒ Download incomplete") - raise RuntimeError(f"Downloaded file too small: {file_size} bytes") + raise RuntimeError( + f"Downloaded file too small: {file_size} bytes" + ) else: progress.update(task, description="โŒ Download failed") - raise RuntimeError(f"Download failed with return code: {process.returncode}") - + raise RuntimeError( + f"Download failed with return code: {process.returncode}" + ) + except Exception as e: logger.error(f"Failed to download model: {e}") raise RuntimeError(f"Model download failed: {e}") @@ -353,9 +455,9 @@ def _ensure_llama_server_installed(self): if shutil.which("llama-server"): logger.info("llama-server is already installed") return - + self._announce("๐Ÿ“ฆ Installing llama.cpp for ClickHouse AI Agent...") - + try: # Detect platform and install accordingly if sys.platform == "darwin": # macOS @@ -364,28 +466,30 @@ def _ensure_llama_server_installed(self): self._install_llama_server_linux() else: raise RuntimeError(f"Unsupported platform: {sys.platform}") - + except Exception as e: logger.error(f"Failed to install llama-server: {e}") raise RuntimeError( f"Could not install llama-server automatically. " f"Please install llama.cpp manually: https://github.com/ggerganov/llama.cpp" ) - + def _install_llama_server_macos(self): """Install llama-server on macOS using Homebrew""" try: # Check if Homebrew is available if not shutil.which("brew"): - raise RuntimeError("Homebrew is required but not installed. Please install Homebrew first.") - + raise RuntimeError( + "Homebrew is required but not installed. Please install Homebrew first." + ) + # Install llama.cpp subprocess.run(["brew", "install", "llama.cpp"], check=True) self._announce("โœ… llama.cpp installed successfully via Homebrew") - + except subprocess.CalledProcessError as e: raise RuntimeError(f"Failed to install llama.cpp via Homebrew: {e}") - + def _install_llama_server_linux(self): """Install llama-server on Linux""" try: @@ -399,41 +503,51 @@ def _install_llama_server_linux(self): else: # Fallback to building from source self._build_llama_cpp_from_source() - + except Exception as e: raise RuntimeError(f"Failed to install llama.cpp on Linux: {e}") - + def _build_llama_cpp_from_source(self): """Build llama.cpp from source (Linux fallback)""" import tempfile - + with tempfile.TemporaryDirectory() as temp_dir: self._announce("Building llama.cpp from source...") - + # Clone the repository - subprocess.run([ - "git", "clone", "https://github.com/ggerganov/llama.cpp.git", - os.path.join(temp_dir, "llama.cpp") - ], check=True) - + subprocess.run( + [ + "git", + "clone", + "https://github.com/ggerganov/llama.cpp.git", + os.path.join(temp_dir, "llama.cpp"), + ], + check=True, + ) + # Build build_dir = os.path.join(temp_dir, "llama.cpp") subprocess.run(["make", "-j", "4"], cwd=build_dir, check=True) - + # Install to /usr/local/bin llama_server_src = os.path.join(build_dir, "llama-server") llama_server_dst = "/usr/local/bin/llama-server" - - subprocess.run(["sudo", "cp", llama_server_src, llama_server_dst], check=True) + + subprocess.run( + ["sudo", "cp", llama_server_src, llama_server_dst], check=True + ) subprocess.run(["sudo", "chmod", "+x", llama_server_dst], check=True) - + self._announce("โœ… llama.cpp built and installed successfully") def _is_server_running(self) -> bool: """Check if llama-server is already running on our port""" try: import requests - response = requests.get(f"http://{self.host}:{self.port}/v1/models", timeout=2) + + response = requests.get( + f"http://{self.host}:{self.port}/v1/models", timeout=2 + ) return response.status_code == 200 except Exception: return False @@ -441,63 +555,84 @@ def _is_server_running(self) -> bool: def _start_server(self): """Start llama-server process""" try: - self._announce(f"๐Ÿš€ Starting ClickHouse AI Agent on {self.host}:{self.port}") - + self._announce( + f"๐Ÿš€ Starting ClickHouse AI Agent on {self.host}:{self.port}" + ) + # Kill any existing server on this port self._kill_existing_server() - + # Use the downloaded chat template if available cmd = [ "llama-server", - "-m", self.model_path, # Use local model file instead of Hugging Face URL + "-m", + self.model_path, # Use local model file instead of Hugging Face URL "--jinja", ] - + # Add chat template file if available if self.chat_template_path and os.path.exists(self.chat_template_path): cmd.extend(["--chat-template-file", self.chat_template_path]) logger.info(f"โœ… Using custom chat template: {self.chat_template_path}") # Removed the announcement message as requested else: - logger.warning("โŒ No custom chat template found, using llama-server default") + logger.warning( + "โŒ No custom chat template found, using llama-server default" + ) # Removed the announcement message as requested - + # Continue with other parameters - cmd.extend([ - "--reasoning-format", "deepseek", - "-ngl", "99", - "-fa", - "-sm", "row", - "--temp", "0.6", - "--top-k", "20", - "--top-p", "0.95", - "--min-p", "0", - "-c", "40960", - "-n", "32768", - "--no-context-shift", - "--host", self.host, - "--port", str(self.port) - ]) - + cmd.extend( + [ + "--reasoning-format", + "deepseek", + "-ngl", + "99", + "-fa", + "-sm", + "row", + "--temp", + "0.6", + "--top-k", + "20", + "--top-p", + "0.95", + "--min-p", + "0", + "-c", + "40960", + "-n", + "32768", + "--no-context-shift", + "--host", + self.host, + "--port", + str(self.port), + ] + ) + # Log the full command for debugging logger.info(f"Starting llama-server with command: {' '.join(cmd)}") - + # Start in background logger.info("Attempting to start llama-server process...") self.server_process = subprocess.Popen( cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, - preexec_fn=os.setsid if os.name != 'nt' else None, - text=True + preexec_fn=os.setsid if os.name != "nt" else None, + text=True, ) - - logger.info(f"Started ClickHouse AI Agent with PID {self.server_process.pid}") - + + logger.info( + f"Started ClickHouse AI Agent with PID {self.server_process.pid}" + ) + # Give the process a moment to start and check if it's still running import time + time.sleep(2) - + if self.server_process.poll() is not None: # Process has already terminated stdout, stderr = self.server_process.communicate() @@ -507,10 +642,14 @@ def _start_server(self): logger.error(f"STDOUT: {stdout}") if stderr: logger.error(f"STDERR: {stderr}") - raise RuntimeError(f"llama-server failed to start. Exit code: {self.server_process.returncode}") + raise RuntimeError( + f"llama-server failed to start. Exit code: {self.server_process.returncode}" + ) else: - logger.info("llama-server process is running, proceeding with startup...") - + logger.info( + "llama-server process is running, proceeding with startup..." + ) + except Exception as e: logger.error(f"Failed to start ClickHouse AI Agent: {e}") raise @@ -518,13 +657,20 @@ def _start_server(self): def _kill_existing_server(self): """Kill any existing llama-server processes on our port""" try: - for proc in psutil.process_iter(['pid', 'name', 'cmdline']): + for proc in psutil.process_iter(["pid", "name", "cmdline"]): try: - if 'llama-server' in proc.info['name'] or 'llama-server' in ' '.join(proc.info['cmdline'] or []): + if "llama-server" in proc.info[ + "name" + ] or "llama-server" in " ".join(proc.info["cmdline"] or []): # Check if it's using our port - cmdline = ' '.join(proc.info['cmdline'] or []) - if f"--port {self.port}" in cmdline or f":{self.port}" in cmdline: - logger.info(f"Stopping existing ClickHouse AI Agent process {proc.info['pid']}") + cmdline = " ".join(proc.info["cmdline"] or []) + if ( + f"--port {self.port}" in cmdline + or f":{self.port}" in cmdline + ): + logger.info( + f"Stopping existing ClickHouse AI Agent process {proc.info['pid']}" + ) proc.terminate() try: proc.wait(timeout=5) @@ -538,7 +684,7 @@ def _kill_existing_server(self): def _wait_for_server(self, timeout: int = 60): """Wait for server to be ready with simple progress""" from rich.progress import Progress, SpinnerColumn, TextColumn, TimeElapsedColumn - + with Progress( SpinnerColumn(), TextColumn("[progress.description]{task.description}"), @@ -547,18 +693,22 @@ def _wait_for_server(self, timeout: int = 60): transient=False, ) as progress: task = progress.add_task("๐Ÿ”„ Starting ClickHouse AI Agent...", total=None) - + start_time = time.time() - + while time.time() - start_time < timeout: if self._is_server_running(): progress.update(task, description="โœ… ClickHouse AI Agent ready!") return - + time.sleep(1) # Check every second - - progress.update(task, description="โŒ Timeout waiting for ClickHouse AI Agent") - raise TimeoutError(f"ClickHouse AI Agent did not start within {timeout} seconds") + + progress.update( + task, description="โŒ Timeout waiting for ClickHouse AI Agent" + ) + raise TimeoutError( + f"ClickHouse AI Agent did not start within {timeout} seconds" + ) def _announce(self, message: str): """Announce status message""" @@ -572,7 +722,7 @@ def get_openai_config(self) -> Dict[str, Any]: """Get configuration for OpenAI client""" return { "base_url": self.base_url, - "api_key": "not-needed" # llama-server doesn't require API key + "api_key": "not-needed", # llama-server doesn't require API key } async def chat_completion(self, *args, **kwargs): @@ -587,30 +737,30 @@ async def close(self): if self.server_process: try: # Terminate the process group - if os.name != 'nt': + if os.name != "nt": os.killpg(os.getpgid(self.server_process.pid), signal.SIGTERM) else: self.server_process.terminate() - + # Wait for graceful shutdown try: self.server_process.wait(timeout=5) except subprocess.TimeoutExpired: # Force kill if needed - if os.name != 'nt': + if os.name != "nt": os.killpg(os.getpgid(self.server_process.pid), signal.SIGKILL) else: self.server_process.kill() - + logger.info("ClickHouse AI Agent process terminated") except Exception as e: logger.error(f"Error terminating ClickHouse AI Agent: {e}") def __del__(self): """Cleanup on deletion""" - if hasattr(self, 'server_process') and self.server_process: + if hasattr(self, "server_process") and self.server_process: try: if self.server_process.poll() is None: # Still running self.server_process.terminate() except Exception: - pass \ No newline at end of file + pass diff --git a/setup.py b/setup.py index fc760aa..ffc0fdd 100644 --- a/setup.py +++ b/setup.py @@ -3,9 +3,10 @@ Setup script for Proto ClickHouse AI Agent """ -from setuptools import setup, find_packages from pathlib import Path +from setuptools import find_packages, setup + # Read the README file this_directory = Path(__file__).parent long_description = (this_directory / "README.md").read_text() @@ -76,4 +77,4 @@ }, include_package_data=True, zip_safe=False, -) \ No newline at end of file +) diff --git a/tools/__init__.py b/tools/__init__.py index f3c0b28..4d868e2 100644 --- a/tools/__init__.py +++ b/tools/__init__.py @@ -1 +1 @@ -# Tools package \ No newline at end of file +# Tools package diff --git a/tools/clickhouse_tools.py b/tools/clickhouse_tools.py index dc08b57..67e3fa6 100644 --- a/tools/clickhouse_tools.py +++ b/tools/clickhouse_tools.py @@ -3,32 +3,34 @@ """ import json -import pandas as pd -import plotly.graph_objects as go -import plotly.express as px from pathlib import Path -from typing import Dict, List, Any, Optional, Union +from typing import Any, Dict, List, Optional, Union + import clickhouse_connect +import pandas as pd +import plotly.express as px +import plotly.graph_objects as go from clickhouse_connect.driver import Client from rich.console import Console -from rich.table import Table from rich.panel import Panel from rich.syntax import Syntax +from rich.table import Table from config.settings import ClickHouseConfig -from utils.logging import get_logger, ToolExecutionLogger +from utils.logging import ToolExecutionLogger, get_logger logger = get_logger(__name__) console = Console() tool_logger = ToolExecutionLogger(console) + class ClickHouseConnection: """ClickHouse database connection manager""" - + def __init__(self, config: ClickHouseConfig): self.config = config self.client: Optional[Client] = None - + async def connect(self) -> Client: """Establish connection to ClickHouse""" if not self.client: @@ -39,36 +41,33 @@ async def connect(self) -> Client: username=self.config.username, password=self.config.password, database=self.config.database, - secure=self.config.secure + secure=self.config.secure, ) - + # Test connection result = self.client.query("SELECT 1 as test") logger.info("ClickHouse connection established successfully") - + return self.client except Exception as e: logger.error(f"Failed to connect to ClickHouse: {e}") raise - + return self.client - + def close(self): """Close ClickHouse connection""" if self.client: self.client.close() self.client = None + # Simplified tool definitions - just the essentials for data analysis CLICKHOUSE_TOOLS = [ { "name": "list_databases", "description": "List all available databases in ClickHouse, returns JSON with database names", - "input_schema": { - "type": "object", - "properties": {}, - "required": [] - } + "input_schema": {"type": "object", "properties": {}, "required": []}, }, { "name": "switch_database", @@ -78,11 +77,11 @@ def close(self): "properties": { "database_name": { "type": "string", - "description": "Name of the database to switch to" + "description": "Name of the database to switch to", } }, - "required": ["database_name"] - } + "required": ["database_name"], + }, }, { "name": "execute_clickhouse_query", @@ -90,22 +89,15 @@ def close(self): "input_schema": { "type": "object", "properties": { - "query": { - "type": "string", - "description": "SQL query to execute" - } + "query": {"type": "string", "description": "SQL query to execute"} }, - "required": ["query"] - } + "required": ["query"], + }, }, { "name": "list_tables", "description": "List all tables in the current database, returns JSON with table names and count", - "input_schema": { - "type": "object", - "properties": {}, - "required": [] - } + "input_schema": {"type": "object", "properties": {}, "required": []}, }, { "name": "get_table_schema", @@ -115,11 +107,11 @@ def close(self): "properties": { "table_name": { "type": "string", - "description": "Name of the table to get schema for" + "description": "Name of the table to get schema for", } }, - "required": ["table_name"] - } + "required": ["table_name"], + }, }, { "name": "search_table", @@ -129,20 +121,20 @@ def close(self): "properties": { "table_name": { "type": "string", - "description": "Name of the table to search" + "description": "Name of the table to search", }, "limit": { "type": "integer", "description": "Number of rows to return (default: 100)", - "default": 100 + "default": 100, }, "where_clause": { "type": "string", - "description": "Optional WHERE clause for filtering" - } + "description": "Optional WHERE clause for filtering", + }, }, - "required": ["table_name"] - } + "required": ["table_name"], + }, }, { "name": "export_data_to_csv", @@ -152,20 +144,20 @@ def close(self): "properties": { "query": { "type": "string", - "description": "SQL query to execute and export" + "description": "SQL query to execute and export", }, "filename": { "type": "string", - "description": "Optional filename for the CSV export (without extension)" + "description": "Optional filename for the CSV export (without extension)", }, "analysis_limit": { "type": "integer", "description": "Number of rows to show for analysis (default: 50)", - "default": 50 - } + "default": 50, + }, }, - "required": ["query"] - } + "required": ["query"], + }, }, { "name": "stop_agent", @@ -175,64 +167,68 @@ def close(self): "properties": { "summary": { "type": "string", - "description": "Summary of what was accomplished" + "description": "Summary of what was accomplished", } }, - "required": ["summary"] - } - } + "required": ["summary"], + }, + }, ] + class ClickHouseToolExecutor: """Executes ClickHouse-specific tools""" - + def __init__(self, connection: ClickHouseConnection): self.connection = connection self.client = None - + async def initialize(self): """Initialize the connection""" self.client = await self.connection.connect() - + async def execute_clickhouse_query(self, query: str) -> str: """Execute any SQL query and return formatted results""" - + tool_logger.log_tool_start("execute_clickhouse_query", {"query": query[:100]}) - + try: # Show the actual SQL query being executed from ui.minimal_interface import ui + ui.show_query_execution(query) - + # Execute query import time + start_time = time.time() result = self.client.query(query) duration = time.time() - start_time - + # Convert to DataFrame for easier manipulation import pandas as pd + df = pd.DataFrame(result.result_rows, columns=result.column_names) - + tool_logger.log_query_execution(query, duration, len(df)) - + # Always show results in table format for better analysis from ui.minimal_interface import ui - + if len(df) > 0: # Convert to list of dicts for the UI - data_list = df.to_dict('records') - + data_list = df.to_dict("records") + # Show with smart large dataset handling display_limit = 100 if len(df) > 100 else len(df) - + ui.show_data_table( - data_list[:display_limit], - title="Query Results", + data_list[:display_limit], + title="Query Results", max_rows=display_limit, - total_rows=len(df) + total_rows=len(df), ) - + # For queries that return structured data, provide JSON # AGGRESSIVE data limiting to prevent context overflow if len(df) <= 10: # Only return full JSON for very small results @@ -241,7 +237,7 @@ async def execute_clickhouse_query(self, query: str) -> str: "rows": data_list, "row_count": len(df), "columns": list(df.columns), - "summary": f"Query returned {len(df)} rows" + "summary": f"Query returned {len(df)} rows", } output = json.dumps(result_data, indent=2, default=str) elif len(df) <= 50: # Return very limited rows for small results @@ -250,13 +246,13 @@ async def execute_clickhouse_query(self, query: str) -> str: "rows": data_list[:5], # Only first 5 rows for context "row_count": len(df), "columns": list(df.columns), - "summary": f"Query returned {len(df)} rows (showing first 5 for analysis). Full results displayed above." + "summary": f"Query returned {len(df)} rows (showing first 5 for analysis). Full results displayed above.", } output = json.dumps(limited_data, indent=2, default=str) else: # For any larger results, just return summary to avoid context overflow output = f"SUCCESS: Query returned {len(df):,} rows with {len(df.columns)} columns: {', '.join(df.columns[:5])}{'...' if len(df.columns) > 5 else ''}. Results displayed in table above. Use export_data_to_csv for large dataset analysis." - + else: ui.console.print("[dim yellow]No results found[/dim yellow]") result_data = { @@ -264,153 +260,162 @@ async def execute_clickhouse_query(self, query: str) -> str: "rows": [], "row_count": 0, "columns": [], - "summary": "No results found" + "summary": "No results found", } output = json.dumps(result_data, indent=2) - - tool_logger.log_tool_success("execute_clickhouse_query", f"Returned {len(df)} rows") + + tool_logger.log_tool_success( + "execute_clickhouse_query", f"Returned {len(df)} rows" + ) return output - + except Exception as e: error_msg = f"Query execution failed: {str(e)}" tool_logger.log_tool_error("execute_clickhouse_query", error_msg) return error_msg - + async def list_tables(self) -> str: """List all tables in the database""" - + tool_logger.log_tool_start("list_tables", {}) - + try: query = "SHOW TABLES" result = self.client.query(query) df = pd.DataFrame(result.result_rows, columns=result.column_names) - + if len(df) > 0: # Get table names as a list table_names = [str(row.iloc[0]) for _, row in df.iterrows()] - + # Display the table nicely table = Table(title=f"Tables in database ({len(df)} tables)") table.add_column("Table Name") - + for table_name in table_names: table.add_row(table_name) - + console.print(table) - + # Return JSON data - result_data = { - "tables": table_names, - "count": len(table_names) - } + result_data = {"tables": table_names, "count": len(table_names)} output = json.dumps(result_data, indent=2) else: - result_data = { - "tables": [], - "count": 0 - } + result_data = {"tables": [], "count": 0} output = json.dumps(result_data, indent=2) - - tool_logger.log_tool_success("list_tables", f"Found {len(table_names) if len(df) > 0 else 0} tables") + + tool_logger.log_tool_success( + "list_tables", f"Found {len(table_names) if len(df) > 0 else 0} tables" + ) return output - + except Exception as e: error_msg = f"Failed to list tables: {str(e)}" tool_logger.log_tool_error("list_tables", error_msg) return json.dumps({"error": error_msg}) - + async def get_table_schema(self, table_name: str) -> str: """Get the schema/structure of a table""" - + tool_logger.log_tool_start("get_table_schema", {"table_name": table_name}) - + try: # Get table structure describe_query = f"DESCRIBE TABLE {table_name}" result = self.client.query(describe_query) df = pd.DataFrame(result.result_rows, columns=result.column_names) - + # Display table structure table = Table(title=f"Schema: {table_name}") table.add_column("Column") table.add_column("Type") table.add_column("Default Type") table.add_column("Default Expression") - + # Build JSON schema data columns = [] for _, row in df.iterrows(): column_info = { - "name": str(row['name']), - "type": str(row['type']), - "default_type": str(row.get('default_type', '')), - "default_expression": str(row.get('default_expression', '')) + "name": str(row["name"]), + "type": str(row["type"]), + "default_type": str(row.get("default_type", "")), + "default_expression": str(row.get("default_expression", "")), } columns.append(column_info) - + # Add to display table table.add_row( column_info["name"], column_info["type"], column_info["default_type"], - column_info["default_expression"] + column_info["default_expression"], ) - + console.print(table) - + # Return JSON data result_data = { "table_name": table_name, "columns": columns, - "column_count": len(columns) + "column_count": len(columns), } output = json.dumps(result_data, indent=2) - - tool_logger.log_tool_success("get_table_schema", f"Retrieved schema for {table_name} with {len(columns)} columns") + + tool_logger.log_tool_success( + "get_table_schema", + f"Retrieved schema for {table_name} with {len(columns)} columns", + ) return output - + except Exception as e: error_msg = f"Failed to get schema for table {table_name}: {str(e)}" tool_logger.log_tool_error("get_table_schema", error_msg) return json.dumps({"error": error_msg, "table_name": table_name}) - - async def search_table(self, table_name: str, limit: int = 100, where_clause: str = None) -> str: + + async def search_table( + self, table_name: str, limit: int = 100, where_clause: str = None + ) -> str: """Search and preview data in a table""" - - tool_logger.log_tool_start("search_table", {"table_name": table_name, "limit": limit, "where_clause": where_clause}) - + + tool_logger.log_tool_start( + "search_table", + {"table_name": table_name, "limit": limit, "where_clause": where_clause}, + ) + try: # Build query query = f"SELECT * FROM {table_name}" if where_clause: query += f" WHERE {where_clause}" query += f" LIMIT {limit}" - + # Execute query directly to get structured data result = self.client.query(query) df = pd.DataFrame(result.result_rows, columns=result.column_names) - + # Display the results from ui.minimal_interface import ui + if len(df) > 0: - data_list = df.to_dict('records') + data_list = df.to_dict("records") ui.show_data_table( - data_list, - title=f"Search Results: {table_name}", + data_list, + title=f"Search Results: {table_name}", max_rows=min(50, len(df)), - total_rows=len(df) + total_rows=len(df), ) - + # Return JSON data result_data = { "table_name": table_name, "query": query, "rows": data_list, "row_count": len(df), - "columns": list(df.columns) + "columns": list(df.columns), } - output = json.dumps(result_data, indent=2, default=str) # default=str handles dates/timestamps + output = json.dumps( + result_data, indent=2, default=str + ) # default=str handles dates/timestamps else: ui.console.print("[dim yellow]No results found[/dim yellow]") result_data = { @@ -418,158 +423,167 @@ async def search_table(self, table_name: str, limit: int = 100, where_clause: st "query": query, "rows": [], "row_count": 0, - "columns": [] + "columns": [], } output = json.dumps(result_data, indent=2) - - tool_logger.log_tool_success("search_table", f"Searched table {table_name}, found {len(df)} rows") + + tool_logger.log_tool_success( + "search_table", f"Searched table {table_name}, found {len(df)} rows" + ) return output - + except Exception as e: error_msg = f"Failed to search table {table_name}: {str(e)}" tool_logger.log_tool_error("search_table", error_msg) return json.dumps({"error": error_msg, "table_name": table_name}) - + async def export_data_to_csv( - self, - query: str, - filename: str = None, - analysis_limit: int = 50 + self, query: str, filename: str = None, analysis_limit: int = 50 ) -> str: """Export query results to CSV with truncated analysis display""" - - tool_logger.log_tool_start("export_data_to_csv", { - "query": query[:100], - "filename": filename, - "analysis_limit": analysis_limit - }) - + + tool_logger.log_tool_start( + "export_data_to_csv", + { + "query": query[:100], + "filename": filename, + "analysis_limit": analysis_limit, + }, + ) + try: - from ui.minimal_interface import ui import os from datetime import datetime - + + from ui.minimal_interface import ui + # Show the query being executed ui.show_query_execution(query) - + # Execute query import time + start_time = time.time() result = self.client.query(query) duration = time.time() - start_time - + # Convert to DataFrame import pandas as pd + df = pd.DataFrame(result.result_rows, columns=result.column_names) - + if len(df) == 0: return "No data returned from query" - + # Generate filename if not provided if not filename: timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") filename = f"data_export_{timestamp}" - + # Ensure exports directory exists os.makedirs("exports", exist_ok=True) filepath = os.path.join("exports", f"{filename}.csv") - + # Export full data to CSV df.to_csv(filepath, index=False) - + file_size = os.path.getsize(filepath) file_size_mb = file_size / (1024 * 1024) - + # Show truncated data for analysis display_rows = min(analysis_limit, len(df)) truncated_df = df.head(display_rows) - data_list = truncated_df.to_dict('records') - + data_list = truncated_df.to_dict("records") + # Display the truncated results ui.show_data_table( - data_list, - title=f"Analysis Preview ({display_rows} of {len(df):,} rows)", + data_list, + title=f"Analysis Preview ({display_rows} of {len(df):,} rows)", max_rows=display_rows, - total_rows=len(df) + total_rows=len(df), ) - + # Show export information ui.console.print() - ui.console.print(f"[dim bright_green]๐Ÿ“[/dim bright_green] [bright_white]Data Export Complete![/bright_white]") + ui.console.print( + f"[dim bright_green]๐Ÿ“[/dim bright_green] [bright_white]Data Export Complete![/bright_white]" + ) ui.console.print(f"[dim]โ€ข Full dataset: {filepath}[/dim]") ui.console.print(f"[dim]โ€ข File size: {file_size_mb:.2f} MB[/dim]") ui.console.print(f"[dim]โ€ข Total rows: {len(df):,}[/dim]") ui.console.print(f"[dim]โ€ข Analysis shown: {display_rows} rows[/dim]") - ui.console.print(f"[yellow]๐Ÿ’ก Open the CSV file for complete data analysis[/yellow]") + ui.console.print( + f"[yellow]๐Ÿ’ก Open the CSV file for complete data analysis[/yellow]" + ) ui.console.print() - - tool_logger.log_tool_success("export_data_to_csv", f"Exported {len(df)} rows to {filepath}") - + + tool_logger.log_tool_success( + "export_data_to_csv", f"Exported {len(df)} rows to {filepath}" + ) + # Return structured data for the LLM result_data = { "query": query, "total_rows": len(df), - "analysis_rows": truncated_df.to_dict('records'), + "analysis_rows": truncated_df.to_dict("records"), "export_file": filepath, "file_size_mb": round(file_size_mb, 2), "columns": list(df.columns), "summary": f"Exported {len(df):,} rows to CSV. Analysis shows {display_rows} sample rows.", - "message": f"Full dataset with {len(df):,} rows exported to {filepath}. Analysis limited to {display_rows} rows for display. Open CSV file for complete data." + "message": f"Full dataset with {len(df):,} rows exported to {filepath}. Analysis limited to {display_rows} rows for display. Open CSV file for complete data.", } - + return json.dumps(result_data, indent=2, default=str) - + except Exception as e: error_msg = f"CSV export failed: {str(e)}" tool_logger.log_tool_error("export_data_to_csv", error_msg) return error_msg - + async def export_query_results( - self, - query: str, - filename: str = None, - format: str = "csv", - limit: int = None + self, query: str, filename: str = None, format: str = "csv", limit: int = None ) -> str: """Export large query results to file for datasets too big to display""" - - tool_logger.log_tool_start("export_query_results", { - "query": query[:100], - "format": format, - "limit": limit - }) - + + tool_logger.log_tool_start( + "export_query_results", + {"query": query[:100], "format": format, "limit": limit}, + ) + try: - from ui.minimal_interface import ui import os from datetime import datetime - + + from ui.minimal_interface import ui + # Add limit if specified if limit and "LIMIT" not in query.upper(): query = f"{query.rstrip(';')} LIMIT {limit}" - + # Show the query being executed ui.show_query_execution(query) - + # Execute query import time + start_time = time.time() result = self.client.query(query) duration = time.time() - start_time - + # Convert to DataFrame import pandas as pd + df = pd.DataFrame(result.result_rows, columns=result.column_names) - + # Generate filename if not provided if not filename: timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") filename = f"clickhouse_export_{timestamp}.{format}" - + # Ensure exports directory exists os.makedirs("exports", exist_ok=True) filepath = os.path.join("exports", filename) - + # Export based on format if format.lower() == "csv": df.to_csv(filepath, index=False) @@ -579,26 +593,30 @@ async def export_query_results( df.to_excel(filepath, index=False) else: return f"Unsupported export format: {format}. Use csv, json, or excel." - + file_size = os.path.getsize(filepath) file_size_mb = file_size / (1024 * 1024) - + ui.console.print() - ui.console.print(f"[dim bright_green]๐Ÿ“[/dim bright_green] [bright_white]Export completed![/bright_white]") + ui.console.print( + f"[dim bright_green]๐Ÿ“[/dim bright_green] [bright_white]Export completed![/bright_white]" + ) ui.console.print(f"[dim]File: {filepath}[/dim]") - ui.console.print(f"[dim]Size: {file_size_mb:.2f} MB ({len(df):,} rows)[/dim]") + ui.console.print( + f"[dim]Size: {file_size_mb:.2f} MB ({len(df):,} rows)[/dim]" + ) ui.console.print() - - tool_logger.log_tool_success("export_query_results", f"Exported {len(df)} rows to {filepath}") - + + tool_logger.log_tool_success( + "export_query_results", f"Exported {len(df)} rows to {filepath}" + ) + return f"Successfully exported {len(df):,} rows to {filepath} ({file_size_mb:.2f} MB) in {duration:.2f}s" - + except Exception as e: error_msg = f"Export failed: {str(e)}" tool_logger.log_tool_error("export_query_results", error_msg) return error_msg - - # Simplified OpenAI format tools @@ -611,26 +629,19 @@ async def export_query_results( "parameters": { "type": "object", "properties": { - "query": { - "type": "string", - "description": "SQL query to execute" - } + "query": {"type": "string", "description": "SQL query to execute"} }, - "required": ["query"] - } - } + "required": ["query"], + }, + }, }, { "type": "function", "function": { "name": "list_tables", "description": "List all tables in the current database, returns JSON with table names and count", - "parameters": { - "type": "object", - "properties": {}, - "required": [] - } - } + "parameters": {"type": "object", "properties": {}, "required": []}, + }, }, { "type": "function", @@ -642,12 +653,12 @@ async def export_query_results( "properties": { "table_name": { "type": "string", - "description": "Name of the table to get schema for" + "description": "Name of the table to get schema for", } }, - "required": ["table_name"] - } - } + "required": ["table_name"], + }, + }, }, { "type": "function", @@ -659,21 +670,21 @@ async def export_query_results( "properties": { "table_name": { "type": "string", - "description": "Name of the table to search" + "description": "Name of the table to search", }, "limit": { "type": "integer", "description": "Number of rows to return (default: 100)", - "default": 100 + "default": 100, }, "where_clause": { "type": "string", - "description": "Optional WHERE clause for filtering" - } + "description": "Optional WHERE clause for filtering", + }, }, - "required": ["table_name"] - } - } + "required": ["table_name"], + }, + }, }, { "type": "function", @@ -685,21 +696,21 @@ async def export_query_results( "properties": { "query": { "type": "string", - "description": "SQL query to execute and export" + "description": "SQL query to execute and export", }, "filename": { "type": "string", - "description": "Optional filename for the CSV export (without extension)" + "description": "Optional filename for the CSV export (without extension)", }, "analysis_limit": { "type": "integer", "description": "Number of rows to show for analysis (default: 50)", - "default": 50 - } + "default": 50, + }, }, - "required": ["query"] - } - } + "required": ["query"], + }, + }, }, { "type": "function", @@ -711,11 +722,11 @@ async def export_query_results( "properties": { "summary": { "type": "string", - "description": "Summary of what was accomplished" + "description": "Summary of what was accomplished", } }, - "required": ["summary"] - } - } - } -] \ No newline at end of file + "required": ["summary"], + }, + }, + }, +] diff --git a/tools/data_tools.py b/tools/data_tools.py index 1297998..11c7b53 100644 --- a/tools/data_tools.py +++ b/tools/data_tools.py @@ -3,165 +3,183 @@ """ import json -import pandas as pd -import numpy as np from pathlib import Path -from typing import Dict, Any, Optional, List -import plotly.graph_objects as go +from typing import Any, Dict, List, Optional + +import numpy as np +import pandas as pd import plotly.express as px +import plotly.graph_objects as go from rich.console import Console from rich.progress import Progress, TaskID -from utils.logging import get_logger, ToolExecutionLogger +from utils.logging import ToolExecutionLogger, get_logger logger = get_logger(__name__) console = Console() tool_logger = ToolExecutionLogger(console) + class DataLoader: """Data loading utilities for ClickHouse""" - + def __init__(self, client): self.client = client - + async def load_from_csv( self, file_path: str, table_name: str, create_table: bool = True, - batch_size: int = 10000 + batch_size: int = 10000, ) -> str: """Load data from CSV file into ClickHouse table""" - - tool_logger.log_tool_start("load_from_csv", { - "file_path": file_path, - "table_name": table_name, - "create_table": create_table, - "batch_size": batch_size - }) - + + tool_logger.log_tool_start( + "load_from_csv", + { + "file_path": file_path, + "table_name": table_name, + "create_table": create_table, + "batch_size": batch_size, + }, + ) + try: file_path = Path(file_path) if not file_path.exists(): raise FileNotFoundError(f"File not found: {file_path}") - + # Read CSV file df = pd.read_csv(file_path) - + if len(df) == 0: return "CSV file is empty" - + # Create table if requested if create_table: create_sql = self._generate_create_table_sql(df, table_name) console.print(f"[cyan]Creating table with SQL:[/cyan]") console.print(create_sql) - + self.client.command(create_sql) - console.print(f"[green]โœ“ Table {table_name} created successfully[/green]") - + console.print( + f"[green]โœ“ Table {table_name} created successfully[/green]" + ) + # Insert data in batches total_rows = len(df) inserted_rows = 0 - + with Progress() as progress: - task = progress.add_task(f"Loading data into {table_name}", total=total_rows) - + task = progress.add_task( + f"Loading data into {table_name}", total=total_rows + ) + for start_idx in range(0, total_rows, batch_size): end_idx = min(start_idx + batch_size, total_rows) batch_df = df.iloc[start_idx:end_idx] - + # Insert batch self.client.insert_df(table_name, batch_df) - + inserted_rows += len(batch_df) progress.update(task, completed=inserted_rows) - + result = f"Successfully loaded {inserted_rows:,} rows from {file_path.name} into table {table_name}" tool_logger.log_tool_success("load_from_csv", result) return result - + except Exception as e: error_msg = f"Failed to load CSV data: {str(e)}" tool_logger.log_tool_error("load_from_csv", error_msg) return error_msg - + async def load_from_json( self, file_path: str, table_name: str, create_table: bool = True, - batch_size: int = 10000 + batch_size: int = 10000, ) -> str: """Load data from JSON file into ClickHouse table""" - - tool_logger.log_tool_start("load_from_json", { - "file_path": file_path, - "table_name": table_name, - "create_table": create_table, - "batch_size": batch_size - }) - + + tool_logger.log_tool_start( + "load_from_json", + { + "file_path": file_path, + "table_name": table_name, + "create_table": create_table, + "batch_size": batch_size, + }, + ) + try: file_path = Path(file_path) if not file_path.exists(): raise FileNotFoundError(f"File not found: {file_path}") - + # Read JSON file - with open(file_path, 'r') as f: + with open(file_path, "r") as f: data = json.load(f) - + # Convert to DataFrame if isinstance(data, list): df = pd.DataFrame(data) elif isinstance(data, dict): df = pd.DataFrame([data]) else: - raise ValueError("JSON file must contain a list of objects or a single object") - + raise ValueError( + "JSON file must contain a list of objects or a single object" + ) + if len(df) == 0: return "JSON file contains no data" - + # Create table if requested if create_table: create_sql = self._generate_create_table_sql(df, table_name) console.print(f"[cyan]Creating table with SQL:[/cyan]") console.print(create_sql) - + self.client.command(create_sql) - console.print(f"[green]โœ“ Table {table_name} created successfully[/green]") - + console.print( + f"[green]โœ“ Table {table_name} created successfully[/green]" + ) + # Insert data in batches total_rows = len(df) inserted_rows = 0 - + with Progress() as progress: - task = progress.add_task(f"Loading data into {table_name}", total=total_rows) - + task = progress.add_task( + f"Loading data into {table_name}", total=total_rows + ) + for start_idx in range(0, total_rows, batch_size): end_idx = min(start_idx + batch_size, total_rows) batch_df = df.iloc[start_idx:end_idx] - + # Insert batch self.client.insert_df(table_name, batch_df) - + inserted_rows += len(batch_df) progress.update(task, completed=inserted_rows) - + result = f"Successfully loaded {inserted_rows:,} rows from {file_path.name} into table {table_name}" tool_logger.log_tool_success("load_from_json", result) return result - + except Exception as e: error_msg = f"Failed to load JSON data: {str(e)}" tool_logger.log_tool_error("load_from_json", error_msg) return error_msg - + def _generate_create_table_sql(self, df: pd.DataFrame, table_name: str) -> str: """Generate CREATE TABLE SQL from DataFrame schema""" - + column_definitions = [] - + for column_name, dtype in df.dtypes.items(): # Map pandas dtypes to ClickHouse types if pd.api.types.is_integer_dtype(dtype): @@ -177,24 +195,30 @@ def _generate_create_table_sql(self, df: pd.DataFrame, table_name: str) -> str: ch_type = "DateTime" else: ch_type = "String" - + column_definitions.append(f"`{column_name}` {ch_type}") - + + # Bind the separator to a name so the f-string expression + # part stays free of backslashes โ€” that's a SyntaxError on + # Python 3.8 - 3.11 (only legal from 3.12 onward) and clickr + # supports the older versions per pyproject.toml's matrix. + sep = ",\n " create_sql = f""" CREATE TABLE IF NOT EXISTS {table_name} ( - {',\n '.join(column_definitions)} + {sep.join(column_definitions)} ) ENGINE = MergeTree() ORDER BY tuple() """ - + return create_sql + class DataVisualizer: """Data visualization utilities""" - + def __init__(self, client): self.client = client - + async def create_chart( self, query: str, @@ -202,32 +226,35 @@ async def create_chart( x_column: str, y_column: str, title: str, - save_path: Optional[str] = None + save_path: Optional[str] = None, ) -> str: """Create a chart from query results""" - - tool_logger.log_tool_start("create_chart", { - "query": query[:100], - "chart_type": chart_type, - "x_column": x_column, - "y_column": y_column, - "title": title - }) - + + tool_logger.log_tool_start( + "create_chart", + { + "query": query[:100], + "chart_type": chart_type, + "x_column": x_column, + "y_column": y_column, + "title": title, + }, + ) + try: # Execute query result = self.client.query(query) df = result.result_as_dataframe() - + if len(df) == 0: return "No data returned from query" - + if x_column not in df.columns: return f"Column '{x_column}' not found in query results" - + if y_column not in df.columns: return f"Column '{y_column}' not found in query results" - + # Create chart based on type if chart_type == "line": fig = px.line(df, x=x_column, y=y_column, title=title) @@ -243,68 +270,65 @@ async def create_chart( fig = px.box(df, y=y_column, title=title) else: return f"Unsupported chart type: {chart_type}" - + # Save chart if path specified if save_path: save_path = Path(save_path) save_path.parent.mkdir(parents=True, exist_ok=True) - - if save_path.suffix.lower() == '.html': + + if save_path.suffix.lower() == ".html": fig.write_html(str(save_path)) - elif save_path.suffix.lower() in ['.png', '.jpg', '.jpeg']: + elif save_path.suffix.lower() in [".png", ".jpg", ".jpeg"]: fig.write_image(str(save_path)) else: # Default to HTML - save_path = save_path.with_suffix('.html') + save_path = save_path.with_suffix(".html") fig.write_html(str(save_path)) - + console.print(f"[green]โœ“ Chart saved to {save_path}[/green]") result = f"Created {chart_type} chart with {len(df)} data points and saved to {save_path}" else: # Show chart in browser fig.show() result = f"Created {chart_type} chart with {len(df)} data points and displayed in browser" - + tool_logger.log_tool_success("create_chart", result) return result - + except Exception as e: error_msg = f"Failed to create chart: {str(e)}" tool_logger.log_tool_error("create_chart", error_msg) return error_msg + class DataExporter: """Data export utilities""" - + def __init__(self, client): self.client = client - + async def export_to_file( - self, - query: str, - output_path: str, - format: str = "csv" + self, query: str, output_path: str, format: str = "csv" ) -> str: """Export query results to file""" - - tool_logger.log_tool_start("export_to_file", { - "query": query[:100], - "output_path": output_path, - "format": format - }) - + + tool_logger.log_tool_start( + "export_to_file", + {"query": query[:100], "output_path": output_path, "format": format}, + ) + try: # Execute query result = self.client.query(query) df = result.result_as_dataframe() - + if len(df) == 0: return "No data to export" - + # Ensure output directory exists output_path = Path(output_path) output_path.parent.mkdir(parents=True, exist_ok=True) - + # Export based on format if format == "csv": df.to_csv(output_path, index=False) @@ -316,14 +340,14 @@ async def export_to_file( df.to_excel(output_path, index=False) else: return f"Unsupported export format: {format}" - + result = f"Exported {len(df):,} rows to {output_path} in {format} format" console.print(f"[green]โœ“ {result}[/green]") - + tool_logger.log_tool_success("export_to_file", result) return result - + except Exception as e: error_msg = f"Failed to export data: {str(e)}" tool_logger.log_tool_error("export_to_file", error_msg) - return error_msg \ No newline at end of file + return error_msg diff --git a/ui/__init__.py b/ui/__init__.py index f4f10bc..8e61f2c 100644 --- a/ui/__init__.py +++ b/ui/__init__.py @@ -1 +1 @@ -# UI package \ No newline at end of file +# UI package diff --git a/ui/beautiful_interface.py b/ui/beautiful_interface.py index 89b09c5..6acf525 100644 --- a/ui/beautiful_interface.py +++ b/ui/beautiful_interface.py @@ -3,70 +3,84 @@ Designed to feel like a modern app with smooth animations and clean design """ -from rich.console import Console -from rich.panel import Panel -from rich.text import Text +import asyncio +import time +from typing import Any, Dict, List, Optional + from rich.align import Align +from rich.box import DOUBLE, MINIMAL, ROUNDED, SIMPLE +from rich.columns import Columns +from rich.console import Console +from rich.emoji import Emoji from rich.layout import Layout from rich.live import Live -from rich.spinner import Spinner -from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn -from rich.table import Table -from rich.columns import Columns -from rich.box import ROUNDED, SIMPLE, MINIMAL, DOUBLE +from rich.markdown import Markdown from rich.padding import Padding +from rich.panel import Panel +from rich.progress import ( + BarColumn, + Progress, + SpinnerColumn, + TaskProgressColumn, + TextColumn, +) from rich.rule import Rule -from rich.emoji import Emoji -from rich.markdown import Markdown -import time -from typing import Optional, List, Dict, Any -import asyncio +from rich.spinner import Spinner +from rich.table import Table +from rich.text import Text console = Console() + class BeautifulInterface: """Beautiful, animated CLI interface for ClickHouse AI Agent""" - + def __init__(self): self.console = Console() self.current_step = None self.progress = None - + def show_welcome_screen(self): """Display beautiful welcome screen with animations""" self.console.clear() - + # Create gradient-like title title = Text() title.append("โœจ ", style="bold bright_yellow") title.append("Proto", style="bold bright_blue") title.append(" ClickHouse AI Agent", style="bold bright_cyan") title.append(" โœจ", style="bold bright_yellow") - - subtitle = Text("Intelligent Database Analysis & Operations", style="italic bright_white") - + + subtitle = Text( + "Intelligent Database Analysis & Operations", style="italic bright_white" + ) + welcome_panel = Panel( Align.center( Text.assemble( - title, "\n\n", - subtitle, "\n\n", - Text("๐Ÿš€ Ready to explore your data!", style="bold green") + title, + "\n\n", + subtitle, + "\n\n", + Text("๐Ÿš€ Ready to explore your data!", style="bold green"), ) ), box=ROUNDED, border_style="bright_cyan", padding=(2, 4), title="[bold bright_yellow]Welcome[/bold bright_yellow]", - title_align="center" + title_align="center", ) - + self.console.print() self.console.print(welcome_panel) self.console.print() - - def show_connection_status(self, host: str, port: int, database: str, connected: bool = True): + + def show_connection_status( + self, host: str, port: int, database: str, connected: bool = True + ): """Show database connection status with beautiful formatting""" - + if connected: status_icon = "โœ…" status_text = "Connected" @@ -77,171 +91,193 @@ def show_connection_status(self, host: str, port: int, database: str, connected: status_text = "Disconnected" status_style = "bold red" border_style = "red" - + connection_info = Table(show_header=False, box=None, padding=(0, 1)) connection_info.add_column("Field", style="bright_white") connection_info.add_column("Value", style="bright_cyan") - + connection_info.add_row("๐Ÿ  Host:", f"{host}:{port}") connection_info.add_row("๐Ÿ—„๏ธ Database:", database) connection_info.add_row("๐Ÿ“Š Status:", f"{status_icon} {status_text}") - + connection_panel = Panel( connection_info, title=f"[{status_style}]Database Connection[/{status_style}]", border_style=border_style, box=ROUNDED, - padding=(1, 2) + padding=(1, 2), ) - + self.console.print(connection_panel) - + def show_thinking_animation(self, message: str = "Thinking..."): """Show minimal, elegant thinking animation""" return Live( Align.center( - Text.from_markup(f"[dim bright_cyan]โ ‹[/dim bright_cyan] [bright_white]{message}[/bright_white]") + Text.from_markup( + f"[dim bright_cyan]โ ‹[/dim bright_cyan] [bright_white]{message}[/bright_white]" + ) ), - refresh_per_second=8 + refresh_per_second=8, ) - + def show_tool_execution(self, tool_name: str, description: str = ""): """Show minimal tool execution animation""" - clean_name = tool_name.replace('_', ' ').title() + clean_name = tool_name.replace("_", " ").title() return Live( Align.center( - Text.from_markup(f"[dim bright_yellow]โ ธ[/dim bright_yellow] [bright_white]{clean_name}[/bright_white]") + Text.from_markup( + f"[dim bright_yellow]โ ธ[/dim bright_yellow] [bright_white]{clean_name}[/bright_white]" + ) ), - refresh_per_second=8 + refresh_per_second=8, ) - + def show_success(self, message: str): """Show minimal success message""" - self.console.print(f"[dim bright_green]โœ“[/dim bright_green] [bright_white]{message}[/bright_white]") - + self.console.print( + f"[dim bright_green]โœ“[/dim bright_green] [bright_white]{message}[/bright_white]" + ) + def show_error(self, message: str): """Show minimal error message""" - self.console.print(f"[dim red]โœ—[/dim red] [bright_white]{message}[/bright_white]") - + self.console.print( + f"[dim red]โœ—[/dim red] [bright_white]{message}[/bright_white]" + ) + def show_warning(self, message: str): """Show minimal warning message""" - self.console.print(f"[dim yellow]โš [/dim yellow] [bright_white]{message}[/bright_white]") - - def show_data_table(self, data: List[Dict], title: str = "Results", max_rows: int = 10): + self.console.print( + f"[dim yellow]โš [/dim yellow] [bright_white]{message}[/bright_white]" + ) + + def show_data_table( + self, data: List[Dict], title: str = "Results", max_rows: int = 10 + ): """Display data in a beautiful table format""" if not data: - self.console.print(Panel( - Text("No data to display", style="italic bright_white"), - title=title, - border_style="bright_blue", - box=ROUNDED - )) + self.console.print( + Panel( + Text("No data to display", style="italic bright_white"), + title=title, + border_style="bright_blue", + box=ROUNDED, + ) + ) return - + # Create table table = Table(box=ROUNDED, show_lines=True, header_style="bold bright_cyan") - + # Add columns if data: for key in data[0].keys(): - table.add_column(str(key).replace('_', ' ').title(), style="bright_white") - + table.add_column( + str(key).replace("_", " ").title(), style="bright_white" + ) + # Add rows (limit to max_rows) for i, row in enumerate(data[:max_rows]): table.add_row(*[str(value) for value in row.values()]) - + if len(data) > max_rows: table.add_row(*["..." for _ in data[0].keys()], style="dim") - + # Show in panel data_panel = Panel( table, title=f"[bold bright_cyan]{title}[/bold bright_cyan]", border_style="bright_cyan", box=ROUNDED, - padding=(1, 2) + padding=(1, 2), ) - + self.console.print(data_panel) - + if len(data) > max_rows: self.console.print(f"[dim]Showing {max_rows} of {len(data)} rows[/dim]") - - def show_query_result(self, query: str, result_data: List[Dict], execution_time: float = None): + + def show_query_result( + self, query: str, result_data: List[Dict], execution_time: float = None + ): """Show query and its results in a beautiful format""" - + # Show the query query_panel = Panel( Markdown(f"```sql\n{query}\n```"), title="[bold bright_magenta]Query[/bold bright_magenta]", border_style="bright_magenta", box=ROUNDED, - padding=(1, 2) + padding=(1, 2), ) self.console.print(query_panel) - + # Show execution time if provided if execution_time is not None: self.console.print(f"[dim]โฑ๏ธ Executed in {execution_time:.2f}s[/dim]") - + # Show results self.show_data_table(result_data, "Query Results") - + def show_agent_response(self, message: str): """Show minimal agent response""" - self.console.print(f"[dim bright_blue]โ†’[/dim bright_blue] [bright_white]{message}[/bright_white]") - + self.console.print( + f"[dim bright_blue]โ†’[/dim bright_blue] [bright_white]{message}[/bright_white]" + ) + def show_user_input_prompt(self): """Show beautiful input prompt""" self.console.print() rule = Rule(style="dim") self.console.print(rule) - + prompt_text = Text.assemble( - ("๐Ÿ’ฌ ", "bright_cyan"), - ("You: ", "bold bright_cyan") + ("๐Ÿ’ฌ ", "bright_cyan"), ("You: ", "bold bright_cyan") ) self.console.print(prompt_text, end="") - + def show_goodbye(self): """Show beautiful goodbye message""" goodbye_panel = Panel( Align.center( Text.assemble( ("๐Ÿ‘‹ ", "bold bright_yellow"), - ("Thank you for using Proto ClickHouse AI Agent!", "bold bright_white"), + ( + "Thank you for using Proto ClickHouse AI Agent!", + "bold bright_white", + ), ("\n\n", ""), - ("Hope to see you again soon! โœจ", "italic bright_cyan") + ("Hope to see you again soon! โœจ", "italic bright_cyan"), ) ), border_style="bright_yellow", box=ROUNDED, - padding=(2, 4) + padding=(2, 4), ) self.console.print() self.console.print(goodbye_panel) self.console.print() - + def show_statistics(self, stats: Dict[str, Any]): """Show session statistics in a beautiful format""" stats_table = Table(show_header=False, box=None, padding=(0, 2)) stats_table.add_column("Metric", style="bright_white") stats_table.add_column("Value", style="bright_cyan") - + for key, value in stats.items(): - display_key = key.replace('_', ' ').title() + display_key = key.replace("_", " ").title() stats_table.add_row(f"๐Ÿ“Š {display_key}:", str(value)) - + stats_panel = Panel( stats_table, title="[bold bright_green]Session Statistics[/bold bright_green]", border_style="bright_green", box=ROUNDED, - padding=(1, 2) + padding=(1, 2), ) - + self.console.print(stats_panel) - + def create_progress_bar(self, description: str, total: int = 100): """Create a beautiful progress bar""" progress = Progress( @@ -249,12 +285,12 @@ def create_progress_bar(self, description: str, total: int = 100): TextColumn("[progress.description]{task.description}"), BarColumn(bar_width=40), TaskProgressColumn(), - console=self.console + console=self.console, ) - + task = progress.add_task(description, total=total) return progress, task - + def animate_typing(self, text: str, delay: float = 0.03): """Animate typing effect for text""" for char in text: @@ -262,5 +298,6 @@ def animate_typing(self, text: str, delay: float = 0.03): time.sleep(delay) self.console.print() # New line at the end + # Global instance -ui = BeautifulInterface() \ No newline at end of file +ui = BeautifulInterface() diff --git a/ui/minimal_interface.py b/ui/minimal_interface.py index a3d47fe..24c993c 100644 --- a/ui/minimal_interface.py +++ b/ui/minimal_interface.py @@ -3,19 +3,20 @@ Clean animations, creative loading messages, no boxy layouts """ +import asyncio +import random +import time +from typing import Any, Dict, List, Optional + +from rich.align import Align +from rich.box import SIMPLE from rich.console import Console from rich.live import Live -from rich.text import Text -from rich.align import Align -from rich.table import Table +from rich.markdown import Markdown from rich.rule import Rule -from rich.box import SIMPLE from rich.spinner import Spinner -from rich.markdown import Markdown -import time -import random -from typing import Optional, List, Dict, Any -import asyncio +from rich.table import Table +from rich.text import Text console = Console() @@ -48,35 +49,38 @@ # Spinning characters for smooth animation SPINNERS = ["โ ‹", "โ ™", "โ น", "โ ธ", "โ ผ", "โ ด", "โ ฆ", "โ ง", "โ ‡", "โ "] + class MinimalInterface: """Claude Code inspired minimal interface""" - + def __init__(self): self.console = Console() self.spinner_index = 0 - + def _get_spinner(self) -> str: """Get next spinner character""" char = SPINNERS[self.spinner_index % len(SPINNERS)] self.spinner_index += 1 return char - + def show_welcome_screen(self): """Show minimal welcome - no boxes""" self.console.clear() self.console.print() self.console.print() - + # Simple centered title with better spacing title = Text("โœจ Proto", style="bold bright_cyan") subtitle = Text("ClickHouse AI Agent", style="dim bright_white") - + self.console.print(Align.center(title)) self.console.print(Align.center(subtitle)) self.console.print() self.console.print() - - def show_connection_status(self, host: str, port: int, database: str, connected: bool = True): + + def show_connection_status( + self, host: str, port: int, database: str, connected: bool = True + ): """Minimal connection status with better spacing""" if connected: status = f"[dim bright_green]โ—[/dim bright_green] [bright_white]Connected to {host}:{port}[/bright_white]" @@ -84,60 +88,76 @@ def show_connection_status(self, host: str, port: int, database: str, connected: else: status = f"[dim red]โ—[/dim red] [bright_white]Disconnected[/bright_white]" db_info = "" - + self.console.print(status) if db_info: self.console.print(db_info) self.console.print() self.console.print() - + def show_thinking_animation(self, message: str = None): """Smooth thinking animation like Claude Code""" if not message: message = random.choice(THINKING_MESSAGES) - + # Create simple spinner with message using Rich's built-in spinner spinner_with_text = Spinner("dots", text=message, style="dim bright_cyan") - + return Live(spinner_with_text, refresh_per_second=10, transient=True) - - def show_tool_execution(self, tool_name: str, description: str = "", arguments: dict = None): + + def show_tool_execution( + self, tool_name: str, description: str = "", arguments: dict = None + ): """Smooth tool execution animation with arguments display""" - message = TOOL_MESSAGES.get(tool_name, f"Running {tool_name.replace('_', ' ')}...") - + message = TOOL_MESSAGES.get( + tool_name, f"Running {tool_name.replace('_', ' ')}..." + ) + # Show tool call with arguments before starting animation - self.console.print(f"[dim bright_yellow]๐Ÿ”ง[/dim bright_yellow] [bright_white]Executing tool: {tool_name}[/bright_white]") + self.console.print( + f"[dim bright_yellow]๐Ÿ”ง[/dim bright_yellow] [bright_white]Executing tool: {tool_name}[/bright_white]" + ) if arguments: # Show key arguments (truncate if too long) for key, value in arguments.items(): display_value = str(value) if len(display_value) > 100: display_value = display_value[:97] + "..." - self.console.print(f"[dim]{key}>[/dim] [dim bright_white]{display_value}[/dim bright_white]") - + self.console.print( + f"[dim]{key}>[/dim] [dim bright_white]{display_value}[/dim bright_white]" + ) + # Create simple spinner with message using Rich's built-in spinner spinner_with_text = Spinner("dots", text=message, style="dim bright_yellow") - + return Live(spinner_with_text, refresh_per_second=10, transient=True) - + def show_success(self, message: str): """Minimal success message with spacing""" - self.console.print(f"[dim bright_green]โœ“[/dim bright_green] [bright_white]{message}[/bright_white]") - + self.console.print( + f"[dim bright_green]โœ“[/dim bright_green] [bright_white]{message}[/bright_white]" + ) + def show_error(self, message: str): """Minimal error message with spacing""" - self.console.print(f"[dim red]โœ—[/dim red] [bright_white]{message}[/bright_white]") - + self.console.print( + f"[dim red]โœ—[/dim red] [bright_white]{message}[/bright_white]" + ) + def show_warning(self, message: str): """Minimal warning message with spacing""" - self.console.print(f"[dim yellow]โš [/dim yellow] [bright_white]{message}[/bright_white]") - + self.console.print( + f"[dim yellow]โš [/dim yellow] [bright_white]{message}[/bright_white]" + ) + def show_agent_response(self, message: str): """Clean agent response with spacing""" # Add a subtle bullet point like Claude Code - self.console.print(f"[dim bright_blue]โ—[/dim bright_blue] [bright_white]{message}[/bright_white]") + self.console.print( + f"[dim bright_blue]โ—[/dim bright_blue] [bright_white]{message}[/bright_white]" + ) self.console.print() - + def show_markdown(self, markdown_text: str): """Render markdown text beautifully in the terminal""" try: @@ -148,88 +168,126 @@ def show_markdown(self, markdown_text: str): # Fallback to plain text if markdown parsing fails self.console.print(f"[bright_white]{markdown_text}[/bright_white]") self.console.print() - + def show_agent_response_markdown(self, message: str): - """Show agent response with markdown rendering if it contains markdown""" + """Show agent response with markdown rendering if it contains markdown""" # Check if the message contains markdown-like syntax - if any(marker in message for marker in ['##', '**', '|', '```', '- ', '1. ', '*']): + if any( + marker in message for marker in ["##", "**", "|", "```", "- ", "1. ", "*"] + ): # Add a subtle bullet point and render as markdown - self.console.print(f"[dim bright_blue]โ—[/dim bright_blue] [dim]Response:[/dim]") + self.console.print( + f"[dim bright_blue]โ—[/dim bright_blue] [dim]Response:[/dim]" + ) self.console.print() self.show_markdown(message) else: # Fall back to regular response self.show_agent_response(message) - + def show_reasoning(self, reasoning_text: str): """Show reasoning in muted quote style like Notion""" if not reasoning_text: return - + # Split reasoning into lines for proper quote formatting - lines = reasoning_text.split('\n') - - self.console.print(f"[dim bright_yellow]๐Ÿ’ญ[/dim bright_yellow] [dim italic]Thinking:[/dim italic]") - + lines = reasoning_text.split("\n") + + self.console.print( + f"[dim bright_yellow]๐Ÿ’ญ[/dim bright_yellow] [dim italic]Thinking:[/dim italic]" + ) + for line in lines: if line.strip(): # Only show non-empty lines # Quote-style formatting with muted colors and smaller appearance - self.console.print(f"[dim]โ”‚[/dim] [dim italic]{line.strip()}[/dim italic]") - + self.console.print( + f"[dim]โ”‚[/dim] [dim italic]{line.strip()}[/dim italic]" + ) + self.console.print() # Add spacing after reasoning - + def show_query_execution(self, query: str): """Show the actual SQL query being executed""" - self.console.print(f"[dim bright_magenta]SQL>[/dim bright_magenta] [dim bright_white]{query}[/dim bright_white]") + self.console.print( + f"[dim bright_magenta]SQL>[/dim bright_magenta] [dim bright_white]{query}[/dim bright_white]" + ) self.console.print() - - def show_data_table(self, data: List[Dict], title: str = "Results", max_rows: int = 10, total_rows: int = None): + + def show_data_table( + self, + data: List[Dict], + title: str = "Results", + max_rows: int = 10, + total_rows: int = None, + ): """Smart data table display with large dataset handling""" if not data: self.console.print(f"[dim]No data to display[/dim]") return - + total_rows = total_rows or len(data) - + # Handle large datasets with warnings and tips if total_rows > 1000: self.console.print() - self.console.print(f"[dim bright_yellow]๐Ÿ“Š[/dim bright_yellow] [bright_white]Large dataset: {total_rows:,} total rows[/bright_white]") - self.console.print(f"[dim]Showing first {min(max_rows, len(data))} rows. Use LIMIT for specific ranges.[/dim]") - + self.console.print( + f"[dim bright_yellow]๐Ÿ“Š[/dim bright_yellow] [bright_white]Large dataset: {total_rows:,} total rows[/bright_white]" + ) + self.console.print( + f"[dim]Showing first {min(max_rows, len(data))} rows. Use LIMIT for specific ranges.[/dim]" + ) + # Handle wide tables (many columns) columns = list(data[0].keys()) if data else [] display_columns = columns - + if len(columns) > 8: self.console.print() - self.console.print(f"[dim bright_blue]๐Ÿ“‹[/dim bright_blue] [bright_white]Wide table: {len(columns)} columns[/bright_white]") - self.console.print(f"[dim]Showing first 8 columns. Use SELECT col1, col2... for specific columns.[/dim]") + self.console.print( + f"[dim bright_blue]๐Ÿ“‹[/dim bright_blue] [bright_white]Wide table: {len(columns)} columns[/bright_white]" + ) + self.console.print( + f"[dim]Showing first 8 columns. Use SELECT col1, col2... for specific columns.[/dim]" + ) display_columns = columns[:8] - + # Create minimal table with smart title - rows_text = f"{min(len(data), max_rows):,} of {total_rows:,}" if total_rows != len(data) else f"{len(data):,}" + rows_text = ( + f"{min(len(data), max_rows):,} of {total_rows:,}" + if total_rows != len(data) + else f"{len(data):,}" + ) table_title = f"{title} ({rows_text} rows)" - - table = Table(title=table_title, box=SIMPLE, show_header=True, header_style="dim bright_cyan") - + + table = Table( + title=table_title, + box=SIMPLE, + show_header=True, + header_style="dim bright_cyan", + ) + # Add columns with smart truncation for key in display_columns: - display_key = str(key).replace('_', ' ').title() + display_key = str(key).replace("_", " ").title() if len(display_key) > 15: display_key = display_key[:12] + "..." table.add_column(display_key, style="bright_white") - + # Add rows with smart content truncation displayed_rows = min(max_rows, len(data)) for i, row in enumerate(data[:displayed_rows]): row_values = [] for key in display_columns: - value = str(row.get(key, '')) - + value = str(row.get(key, "")) + # Smart truncation based on content if len(value) > 50: - if value.replace('.', '').replace('-', '').replace(',', '').isdigit(): + if ( + value.replace(".", "") + .replace("-", "") + .replace(",", "") + .isdigit() + ): # Numeric - show with ellipsis row_values.append(value[:20] + "...") elif len(value) > 100: @@ -240,39 +298,49 @@ def show_data_table(self, data: List[Dict], title: str = "Results", max_rows: in row_values.append(value[:45] + "...") else: row_values.append(value) - + table.add_row(*row_values) - + # Show pagination indicator if len(data) > displayed_rows or total_rows > len(data): remaining = total_rows - displayed_rows if remaining > 0: pagination_text = f"... and {remaining:,} more rows" - table.add_row(*[f"[dim]{pagination_text}[/dim]" if i == 0 else "[dim]...[/dim]" - for i in range(len(display_columns))]) - + table.add_row( + *[ + f"[dim]{pagination_text}[/dim]" if i == 0 else "[dim]...[/dim]" + for i in range(len(display_columns)) + ] + ) + self.console.print() self.console.print(table) - + # Show helpful tips for large datasets if total_rows > 100: self.console.print() - self.console.print(f"[dim bright_cyan]๐Ÿ’ก[/dim bright_cyan] [dim]Tips: Use LIMIT, WHERE clauses, or ask to export results[/dim]") + self.console.print( + f"[dim bright_cyan]๐Ÿ’ก[/dim bright_cyan] [dim]Tips: Use LIMIT, WHERE clauses, or ask to export results[/dim]" + ) elif len(columns) > 8: self.console.print() - self.console.print(f"[dim bright_cyan]๐Ÿ’ก[/dim bright_cyan] [dim]Tip: Use SELECT specific_columns FROM table for focused results[/dim]") - + self.console.print( + f"[dim bright_cyan]๐Ÿ’ก[/dim bright_cyan] [dim]Tip: Use SELECT specific_columns FROM table for focused results[/dim]" + ) + self.console.print() - - def show_query_result(self, query: str, result_data: List[Dict], execution_time: float = None): + + def show_query_result( + self, query: str, result_data: List[Dict], execution_time: float = None + ): """Show query and results cleanly""" # Show execution time if provided if execution_time is not None: self.console.print(f"[dim]Query executed in {execution_time:.2f}s[/dim]") - + # Show results self.show_data_table(result_data, "Results") - + def show_user_input_prompt(self): """Minimal input prompt like Claude Code""" self.console.print() @@ -280,26 +348,30 @@ def show_user_input_prompt(self): rule = Rule(style="dim") self.console.print(rule) self.console.print() - + # Simple prompt prompt_text = Text.assemble( ("> ", "dim bright_cyan"), ) self.console.print(prompt_text, end="") - + def show_goodbye(self): """Minimal goodbye with spacing""" self.console.print() self.console.print() - self.console.print("[dim bright_cyan]โœจ[/dim bright_cyan] [bright_white]Thanks for using Proto![/bright_white]") + self.console.print( + "[dim bright_cyan]โœจ[/dim bright_cyan] [bright_white]Thanks for using Proto![/bright_white]" + ) self.console.print() - + def show_statistics(self, stats: Dict[str, Any]): """Minimal stats display""" for key, value in stats.items(): - display_key = key.replace('_', ' ').title() - self.console.print(f"[dim]{display_key}:[/dim] [bright_white]{value}[/bright_white]") - + display_key = key.replace("_", " ").title() + self.console.print( + f"[dim]{display_key}:[/dim] [bright_white]{value}[/bright_white]" + ) + def animate_typing(self, text: str, delay: float = 0.02): """Smooth typing animation""" for char in text: @@ -307,5 +379,6 @@ def animate_typing(self, text: str, delay: float = 0.02): time.sleep(delay) self.console.print() + # Global instance -ui = MinimalInterface() \ No newline at end of file +ui = MinimalInterface() diff --git a/ui/onboarding.py b/ui/onboarding.py index 1a76be3..a1eb36e 100644 --- a/ui/onboarding.py +++ b/ui/onboarding.py @@ -3,312 +3,359 @@ Handles provider selection, API key setup, and initial configuration """ -from rich.console import Console -from rich.panel import Panel -from rich.prompt import Prompt, Confirm -from rich.text import Text +import json +from pathlib import Path +from typing import Any, Dict, Optional + from rich.align import Align +from rich.box import ROUNDED from rich.columns import Columns +from rich.console import Console +from rich.panel import Panel +from rich.prompt import Confirm, Prompt from rich.table import Table -from rich.box import ROUNDED -from pathlib import Path -import json -from typing import Dict, Any, Optional +from rich.text import Text + from ui.minimal_interface import ui console = Console() + class OnboardingFlow: """Beautiful onboarding flow for new users""" - + def __init__(self): self.config = {} # Store config in user config directory (e.g., ~/.config/proto/proto-config.json) config_dir = Path.home() / ".config" / "proto" config_dir.mkdir(parents=True, exist_ok=True) self.config_file = config_dir / "proto-config.json" - + def show_welcome(self): """Show welcome message for first-time users""" welcome_text = Text() welcome_text.append("๐ŸŽ‰ ", style="bold bright_yellow") welcome_text.append("Welcome to Proto!", style="bold bright_cyan") welcome_text.append(" ๐ŸŽ‰", style="bold bright_yellow") - - subtitle = Text("Let's get you set up with your ClickHouse AI Agent", style="italic bright_white") - + + subtitle = Text( + "Let's get you set up with your ClickHouse AI Agent", + style="italic bright_white", + ) + welcome_panel = Panel( Align.center( Text.assemble( - welcome_text, "\n\n", - subtitle, "\n\n", - ("This will only take a minute! โšก", "bold green") + welcome_text, + "\n\n", + subtitle, + "\n\n", + ("This will only take a minute! โšก", "bold green"), ) ), box=ROUNDED, border_style="bright_cyan", padding=(2, 4), title="[bold bright_yellow]First Time Setup[/bold bright_yellow]", - title_align="center" + title_align="center", ) - + console.print() console.print(welcome_panel) console.print() - + # Removed choose_ai_provider method - no longer needed - + def setup_openrouter(self) -> Dict[str, Any]: """Setup OpenRouter configuration""" console.print() info_panel = Panel( Text.assemble( - ("๐ŸŒ OpenRouter Setup", "bold bright_cyan"), "\n\n", - ("OpenRouter gives you access to multiple AI models including:", "bright_white"), "\n", - ("โ€ข GPT-4o, GPT-4o-mini", "green"), "\n", - ("โ€ข Claude 3.5 Sonnet", "green"), "\n", - ("โ€ข Llama, Mistral, and more", "green"), "\n\n", + ("๐ŸŒ OpenRouter Setup", "bold bright_cyan"), + "\n\n", + ( + "OpenRouter gives you access to multiple AI models including:", + "bright_white", + ), + "\n", + ("โ€ข GPT-4o, GPT-4o-mini", "green"), + "\n", + ("โ€ข Claude 3.5 Sonnet", "green"), + "\n", + ("โ€ข Llama, Mistral, and more", "green"), + "\n\n", ("You'll need an API key from: ", "bright_white"), - ("https://openrouter.ai", "bright_blue underline") + ("https://openrouter.ai", "bright_blue underline"), ), border_style="bright_cyan", - padding=(1, 2) + padding=(1, 2), ) console.print(info_panel) console.print() - + # Get API key api_key = Prompt.ask( "[bold bright_cyan]Enter your OpenRouter API key[/bold bright_cyan]", - password=True + password=True, ) - + # Choose model model_table = Table(show_header=False, box=ROUNDED, padding=(0, 1)) model_table.add_column("Option", style="bold bright_cyan") model_table.add_column("Model", style="bright_white") model_table.add_column("Cost", style="bright_green") - + model_table.add_row("1", "GPT-4o-mini", "Cheapest, fast") model_table.add_row("2", "GPT-4o", "Most capable") model_table.add_row("3", "Claude 3.5 Sonnet", "Great for analysis") - + model_panel = Panel( model_table, title="[bold bright_green]Choose Model[/bold bright_green]", - border_style="bright_green" + border_style="bright_green", ) console.print(model_panel) - + model_choice = Prompt.ask( "[bold bright_green]Select model[/bold bright_green]", choices=["1", "2", "3"], - default="1" + default="1", ) - + models = { "1": "openai/gpt-4o-mini", - "2": "openai/gpt-4o", - "3": "anthropic/claude-3.5-sonnet" + "2": "openai/gpt-4o", + "3": "anthropic/claude-3.5-sonnet", } - + return { "provider": "openrouter", "openrouter_api_key": api_key, "openrouter_model": models[model_choice], - "openrouter_provider_only": "openai" if model_choice in ["1", "2"] else "anthropic", - "openrouter_data_collection": "deny" + "openrouter_provider_only": ( + "openai" if model_choice in ["1", "2"] else "anthropic" + ), + "openrouter_data_collection": "deny", } - + def setup_local(self) -> Dict[str, Any]: """Setup local ClickHouse AI Agent configuration""" console.print() info_panel = Panel( Text.assemble( - ("๐Ÿค– ClickHouse AI Agent Setup", "bold bright_green"), "\n\n", - ("Your ClickHouse AI Agent will be automatically downloaded and started.", "bright_white"), "\n", - ("No API keys needed - everything is handled automatically.", "green"), "\n\n", - ("The model will be downloaded on first run (~3.5GB).", "dim") + ("๐Ÿค– ClickHouse AI Agent Setup", "bold bright_green"), + "\n\n", + ( + "Your ClickHouse AI Agent will be automatically downloaded and started.", + "bright_white", + ), + "\n", + ("No API keys needed - everything is handled automatically.", "green"), + "\n\n", + ("The model will be downloaded on first run (~3.5GB).", "dim"), ), border_style="bright_green", - padding=(1, 2) + padding=(1, 2), ) console.print(info_panel) console.print() - + return { "provider": "local", "local_llm_base_url": "http://127.0.0.1:8000/v1", - "local_llm_model": "vishprometa/clickhouse-qwen3-1.7b-gguf" + "local_llm_model": "vishprometa/clickhouse-qwen3-1.7b-gguf", } - + def setup_lmstudio(self) -> Dict[str, Any]: """Setup LM Studio configuration""" console.print() info_panel = Panel( Text.assemble( - ("๐Ÿ  LM Studio Setup", "bold bright_blue"), "\n\n", - ("LM Studio provides a local server for AI models.", "bright_white"), "\n", - ("Benefits: Free, private, easy GUI", "green"), "\n\n", - ("Make sure LM Studio is running with local server enabled:", "bright_white"), "\n", - ("โ€ข Download from: ", "bright_white"), ("https://lmstudio.ai", "bright_blue underline"), "\n", - ("โ€ข Start local server on port 1234", "bright_yellow") + ("๐Ÿ  LM Studio Setup", "bold bright_blue"), + "\n\n", + ("LM Studio provides a local server for AI models.", "bright_white"), + "\n", + ("Benefits: Free, private, easy GUI", "green"), + "\n\n", + ( + "Make sure LM Studio is running with local server enabled:", + "bright_white", + ), + "\n", + ("โ€ข Download from: ", "bright_white"), + ("https://lmstudio.ai", "bright_blue underline"), + "\n", + ("โ€ข Start local server on port 1234", "bright_yellow"), ), border_style="bright_blue", - padding=(1, 2) + padding=(1, 2), ) console.print(info_panel) console.print() - + base_url = Prompt.ask( "[bold bright_blue]LM Studio base URL[/bold bright_blue]", - default="http://localhost:1234" + default="http://localhost:1234", ) - - return { - "provider": "lmstudio", - "lmstudio_base_url": base_url - } - + + return {"provider": "lmstudio", "lmstudio_base_url": base_url} + def setup_openai(self) -> Dict[str, Any]: """Setup direct OpenAI configuration""" console.print() info_panel = Panel( Text.assemble( - ("โ˜๏ธ OpenAI Direct Setup", "bold bright_red"), "\n\n", - ("Connect directly to OpenAI's API.", "bright_white"), "\n", + ("โ˜๏ธ OpenAI Direct Setup", "bold bright_red"), + "\n\n", + ("Connect directly to OpenAI's API.", "bright_white"), + "\n", ("You'll need an API key from: ", "bright_white"), - ("https://platform.openai.com", "bright_blue underline") + ("https://platform.openai.com", "bright_blue underline"), ), border_style="bright_red", - padding=(1, 2) + padding=(1, 2), ) console.print(info_panel) console.print() - + api_key = Prompt.ask( "[bold bright_red]Enter your OpenAI API key[/bold bright_red]", - password=True + password=True, ) - + model_choice = Prompt.ask( "[bold bright_red]Choose model[/bold bright_red]", choices=["gpt-4o-mini", "gpt-4o", "gpt-4"], - default="gpt-4o-mini" + default="gpt-4o-mini", ) - + return { "provider": "openai", "openai_api_key": api_key, - "openai_model": model_choice + "openai_model": model_choice, } - + def setup_clickhouse(self) -> Dict[str, Any]: """Setup ClickHouse connection""" console.print() - + # Ask if they want local or cloud connection_type = Prompt.ask( "[bold bright_cyan]ClickHouse connection type[/bold bright_cyan]", choices=["local", "cloud"], - default="local" + default="local", ) - + if connection_type == "local": return { "clickhouse_host": "localhost", "clickhouse_port": 8123, - "clickhouse_username": "default", + "clickhouse_username": "default", "clickhouse_password": "", "clickhouse_database": "default", - "clickhouse_secure": False + "clickhouse_secure": False, } else: console.print() cloud_panel = Panel( Text.assemble( - ("โ˜๏ธ ClickHouse Cloud Setup", "bold bright_cyan"), "\n\n", - ("Enter your ClickHouse Cloud connection details:", "bright_white"), "\n\n", - ("๐Ÿ’ก Common ports:", "blue"), "\n", - ("โ€ข 8123 (HTTP)", "dim"), "\n", - ("โ€ข 8443 (HTTPS)", "dim") + ("โ˜๏ธ ClickHouse Cloud Setup", "bold bright_cyan"), + "\n\n", + ("Enter your ClickHouse Cloud connection details:", "bright_white"), + "\n\n", + ("๐Ÿ’ก Common ports:", "blue"), + "\n", + ("โ€ข 8123 (HTTP)", "dim"), + "\n", + ("โ€ข 8443 (HTTPS)", "dim"), ), border_style="bright_cyan", - padding=(1, 2) + padding=(1, 2), ) console.print(cloud_panel) - + host = Prompt.ask("[bold bright_cyan]Host[/bold bright_cyan]") - port = int(Prompt.ask("[bold bright_cyan]Port[/bold bright_cyan]", default="8123")) - username = Prompt.ask("[bold bright_cyan]Username[/bold bright_cyan]", default="default") - password = Prompt.ask("[bold bright_cyan]Password[/bold bright_cyan]", password=True) - database = Prompt.ask("[bold bright_cyan]Database[/bold bright_cyan]", default="default") - + port = int( + Prompt.ask("[bold bright_cyan]Port[/bold bright_cyan]", default="8123") + ) + username = Prompt.ask( + "[bold bright_cyan]Username[/bold bright_cyan]", default="default" + ) + password = Prompt.ask( + "[bold bright_cyan]Password[/bold bright_cyan]", password=True + ) + database = Prompt.ask( + "[bold bright_cyan]Database[/bold bright_cyan]", default="default" + ) + # Ask about secure connection secure = Confirm.ask( "[bold bright_cyan]Use secure connection (HTTPS)?[/bold bright_cyan]", - default=True if port == 8443 else False + default=True if port == 8443 else False, ) - + return { "clickhouse_host": host, "clickhouse_port": port, "clickhouse_username": username, "clickhouse_password": password, "clickhouse_database": database, - "clickhouse_secure": secure + "clickhouse_secure": secure, } - + def save_config(self, config: Dict[str, Any]): """Save configuration to file""" self.config_file.parent.mkdir(parents=True, exist_ok=True) - with open(self.config_file, 'w') as f: + with open(self.config_file, "w") as f: json.dump(config, f, indent=2) - + ui.show_success(f"Configuration saved to {self.config_file}") - + def run_onboarding(self) -> Dict[str, Any]: """Run the complete onboarding flow""" self.show_welcome() - + # Automatically use local ClickHouse AI Agent (no provider choice) ai_config = self.setup_local() - + # Setup ClickHouse clickhouse_config = self.setup_clickhouse() - + # Combine configs final_config = { **ai_config, **clickhouse_config, "temperature": 0.1, "max_tokens": 4000, - "max_tool_calls": 35 + "max_tool_calls": 35, } - + # Save config self.save_config(final_config) - + # Show completion completion_panel = Panel( Align.center( Text.assemble( - ("๐ŸŽ‰ Setup Complete! ๐ŸŽ‰", "bold bright_green"), "\n\n", - ("Your Proto ClickHouse AI Agent is ready to use!", "bright_white"), "\n", + ("๐ŸŽ‰ Setup Complete! ๐ŸŽ‰", "bold bright_green"), + "\n\n", + ("Your Proto ClickHouse AI Agent is ready to use!", "bright_white"), + "\n", ("You can change these settings anytime with: ", "dim"), - ("proto settings", "bright_cyan") + ("proto settings", "bright_cyan"), ) ), border_style="bright_green", box=ROUNDED, - padding=(2, 4) + padding=(2, 4), ) console.print() console.print(completion_panel) console.print() - + return final_config + def needs_onboarding() -> bool: """Check if user needs onboarding""" # Preferred config location @@ -316,32 +363,40 @@ def needs_onboarding() -> bool: # Legacy location fallback (cwd) legacy_config_file = Path("proto-config.json") env_file = Path(".env") - + # If neither config file exists, needs onboarding if not (config_file.exists() or legacy_config_file.exists() or env_file.exists()): return True - + # If config exists but is empty or invalid - active_config_file = config_file if config_file.exists() else legacy_config_file if legacy_config_file.exists() else None + active_config_file = ( + config_file + if config_file.exists() + else legacy_config_file if legacy_config_file.exists() else None + ) if active_config_file and active_config_file.exists(): try: - with open(active_config_file, 'r') as f: + with open(active_config_file, "r") as f: config = json.load(f) # Check if essential keys exist - has_provider = any( - k in config for k in [ - 'local_llm_base_url', - 'ollama_base_url', - 'lmstudio_base_url', - 'openai_api_key', - # local provider keys - 'local_llm_base_url', - 'local_llm_model', - ] - ) or config.get('provider') == 'local' # Also check if provider is explicitly set to local - has_clickhouse = 'clickhouse_host' in config + has_provider = ( + any( + k in config + for k in [ + "local_llm_base_url", + "ollama_base_url", + "lmstudio_base_url", + "openai_api_key", + # local provider keys + "local_llm_base_url", + "local_llm_model", + ] + ) + or config.get("provider") == "local" + ) # Also check if provider is explicitly set to local + has_clickhouse = "clickhouse_host" in config return not (has_provider and has_clickhouse) except (json.JSONDecodeError, KeyError): return True - - return False \ No newline at end of file + + return False diff --git a/ui/settings_manager.py b/ui/settings_manager.py index f144914..fb71777 100644 --- a/ui/settings_manager.py +++ b/ui/settings_manager.py @@ -3,55 +3,60 @@ Allows users to view, edit, and switch between configurations """ +import json +from pathlib import Path +from typing import Any, Dict, Optional + +from rich.box import ROUNDED from rich.console import Console from rich.panel import Panel -from rich.prompt import Prompt, Confirm -from rich.text import Text +from rich.prompt import Confirm, Prompt from rich.table import Table -from rich.box import ROUNDED -from pathlib import Path -import json -from typing import Dict, Any, Optional +from rich.text import Text + from ui.minimal_interface import ui from ui.onboarding import OnboardingFlow console = Console() + class SettingsManager: """Manage user settings and configurations""" - + def __init__(self): self.config_file = Path("proto-config.json") self.env_file = Path(".env") - + def load_current_config(self) -> Dict[str, Any]: """Load current configuration""" if self.config_file.exists(): try: - with open(self.config_file, 'r') as f: + with open(self.config_file, "r") as f: return json.load(f) except (json.JSONDecodeError, FileNotFoundError): return {} return {} - + def show_current_settings(self): """Display current settings in a beautiful format""" config = self.load_current_config() - + if not config: ui.show_warning("No configuration found. Run onboarding first!") return - + # AI Provider Section ai_table = Table(show_header=False, box=None, padding=(0, 2)) ai_table.add_column("Setting", style="bright_white") ai_table.add_column("Value", style="bright_cyan") - + provider = config.get("provider", "unknown") ai_table.add_row("๐Ÿค– Provider:", provider.title()) - + if provider == "local": - ai_table.add_row("๐ŸŒ Base URL:", config.get("local_llm_base_url", "Not set")) + ai_table.add_row( + "๐ŸŒ Base URL:", config.get("local_llm_base_url", "Not set") + ) ai_table.add_row("๐Ÿง  Model:", config.get("local_llm_model", "Not set")) elif provider == "ollama": ai_table.add_row("๐ŸŒ Base URL:", config.get("ollama_base_url", "Not set")) @@ -59,175 +64,200 @@ def show_current_settings(self): elif provider == "lmstudio": ai_table.add_row("๐ŸŒ Base URL:", config.get("lmstudio_base_url", "Not set")) elif provider == "openai": - ai_table.add_row("๐Ÿ”‘ API Key:", "โ€ขโ€ขโ€ขโ€ขโ€ขโ€ขโ€ขโ€ข" if config.get("openai_api_key") else "Not set") + ai_table.add_row( + "๐Ÿ”‘ API Key:", "โ€ขโ€ขโ€ขโ€ขโ€ขโ€ขโ€ขโ€ข" if config.get("openai_api_key") else "Not set" + ) ai_table.add_row("๐Ÿง  Model:", config.get("openai_model", "Not set")) - + ai_panel = Panel( ai_table, title="[bold bright_magenta]AI Provider Settings[/bold bright_magenta]", border_style="bright_magenta", - padding=(1, 2) + padding=(1, 2), ) - + # ClickHouse Section ch_table = Table(show_header=False, box=None, padding=(0, 2)) ch_table.add_column("Setting", style="bright_white") ch_table.add_column("Value", style="bright_cyan") - + ch_table.add_row("๐Ÿ  Host:", config.get("clickhouse_host", "Not set")) ch_table.add_row("๐Ÿ”Œ Port:", str(config.get("clickhouse_port", "Not set"))) ch_table.add_row("๐Ÿ‘ค Username:", config.get("clickhouse_username", "Not set")) - ch_table.add_row("๐Ÿ”’ Password:", "โ€ขโ€ขโ€ขโ€ขโ€ขโ€ขโ€ขโ€ข" if config.get("clickhouse_password") else "Not set") + ch_table.add_row( + "๐Ÿ”’ Password:", + "โ€ขโ€ขโ€ขโ€ขโ€ขโ€ขโ€ขโ€ข" if config.get("clickhouse_password") else "Not set", + ) ch_table.add_row("๐Ÿ—„๏ธ Database:", config.get("clickhouse_database", "Not set")) - ch_table.add_row("๐Ÿ” Secure:", "Yes" if config.get("clickhouse_secure") else "No") - + ch_table.add_row( + "๐Ÿ” Secure:", "Yes" if config.get("clickhouse_secure") else "No" + ) + ch_panel = Panel( ch_table, title="[bold bright_cyan]ClickHouse Settings[/bold bright_cyan]", border_style="bright_cyan", - padding=(1, 2) + padding=(1, 2), ) - + # Agent Settings agent_table = Table(show_header=False, box=None, padding=(0, 2)) agent_table.add_column("Setting", style="bright_white") agent_table.add_column("Value", style="bright_cyan") - - agent_table.add_row("๐ŸŒก๏ธ Temperature:", str(config.get("temperature", "Not set"))) + + agent_table.add_row( + "๐ŸŒก๏ธ Temperature:", str(config.get("temperature", "Not set")) + ) agent_table.add_row("๐Ÿ“ Max Tokens:", str(config.get("max_tokens", "Not set"))) - agent_table.add_row("๐Ÿ”ง Max Tool Calls:", str(config.get("max_tool_calls", "Not set"))) - + agent_table.add_row( + "๐Ÿ”ง Max Tool Calls:", str(config.get("max_tool_calls", "Not set")) + ) + agent_panel = Panel( agent_table, title="[bold bright_green]Agent Settings[/bold bright_green]", border_style="bright_green", - padding=(1, 2) + padding=(1, 2), ) - + console.print() console.print(ai_panel) - console.print(ch_panel) + console.print(ch_panel) console.print(agent_panel) console.print() - + def edit_ai_provider(self): """Edit AI provider settings""" config = self.load_current_config() - + console.print() - ui.show_info("ClickHouse AI Agent uses your own local model. No provider changes needed.") + ui.show_info( + "ClickHouse AI Agent uses your own local model. No provider changes needed." + ) console.print() - + def edit_clickhouse(self): """Edit ClickHouse connection settings""" config = self.load_current_config() - + console.print() - if Confirm.ask("[bold bright_cyan]Update ClickHouse settings?[/bold bright_cyan]"): + if Confirm.ask( + "[bold bright_cyan]Update ClickHouse settings?[/bold bright_cyan]" + ): onboarding = OnboardingFlow() ch_config = onboarding.setup_clickhouse() config.update(ch_config) self.save_config(config) ui.show_success("ClickHouse settings updated!") - + def edit_agent_settings(self): """Edit agent behavior settings""" config = self.load_current_config() - + console.print() if Confirm.ask("[bold bright_green]Update agent settings?[/bold bright_green]"): - temp = float(Prompt.ask( - "[bold bright_green]Temperature (0.0-2.0)[/bold bright_green]", - default=str(config.get("temperature", 0.1)) - )) - - max_tokens = int(Prompt.ask( - "[bold bright_green]Max tokens per response[/bold bright_green]", - default=str(config.get("max_tokens", 4000)) - )) - - max_tool_calls = int(Prompt.ask( - "[bold bright_green]Max tool calls per conversation[/bold bright_green]", - default=str(config.get("max_tool_calls", 35)) - )) - - config.update({ - "temperature": temp, - "max_tokens": max_tokens, - "max_tool_calls": max_tool_calls - }) - + temp = float( + Prompt.ask( + "[bold bright_green]Temperature (0.0-2.0)[/bold bright_green]", + default=str(config.get("temperature", 0.1)), + ) + ) + + max_tokens = int( + Prompt.ask( + "[bold bright_green]Max tokens per response[/bold bright_green]", + default=str(config.get("max_tokens", 4000)), + ) + ) + + max_tool_calls = int( + Prompt.ask( + "[bold bright_green]Max tool calls per conversation[/bold bright_green]", + default=str(config.get("max_tool_calls", 35)), + ) + ) + + config.update( + { + "temperature": temp, + "max_tokens": max_tokens, + "max_tool_calls": max_tool_calls, + } + ) + self.save_config(config) ui.show_success("Agent settings updated!") - + def reset_config(self): """Reset configuration and run onboarding""" console.print() - if Confirm.ask("[bold red]โš ๏ธ Reset all settings and run setup again?[/bold red]"): + if Confirm.ask( + "[bold red]โš ๏ธ Reset all settings and run setup again?[/bold red]" + ): if self.config_file.exists(): self.config_file.unlink() if self.env_file.exists(): self.env_file.unlink() - + ui.show_success("Configuration reset! Running onboarding...") onboarding = OnboardingFlow() return onboarding.run_onboarding() return None - + def save_config(self, config: Dict[str, Any]): """Save configuration to file""" - with open(self.config_file, 'w') as f: + with open(self.config_file, "w") as f: json.dump(config, f, indent=2) - + def run_settings_menu(self): """Run interactive settings menu""" while True: console.clear() - + # Show title title_panel = Panel( Align.center( Text.assemble( ("โš™๏ธ ", "bold bright_yellow"), ("Proto Settings", "bold bright_cyan"), - (" โš™๏ธ", "bold bright_yellow") + (" โš™๏ธ", "bold bright_yellow"), ) ), border_style="bright_cyan", box=ROUNDED, - padding=(1, 2) + padding=(1, 2), ) console.print() console.print(title_panel) - + # Show current settings self.show_current_settings() - + # Show menu options menu_table = Table(show_header=False, box=ROUNDED, padding=(1, 2)) menu_table.add_column("Option", style="bold bright_cyan") menu_table.add_column("Action", style="bright_white") - + menu_table.add_row("1", "๐Ÿค– AI Provider Info") - menu_table.add_row("2", "๐Ÿ—„๏ธ Update ClickHouse Connection") + menu_table.add_row("2", "๐Ÿ—„๏ธ Update ClickHouse Connection") menu_table.add_row("3", "โš™๏ธ Adjust Agent Settings") menu_table.add_row("4", "๐Ÿ”„ Reset All Settings") menu_table.add_row("5", "โŒ Exit Settings") - + menu_panel = Panel( menu_table, title="[bold bright_yellow]Settings Menu[/bold bright_yellow]", - border_style="bright_yellow" + border_style="bright_yellow", ) console.print(menu_panel) console.print() - + choice = Prompt.ask( "[bold bright_cyan]What would you like to do?[/bold bright_cyan]", choices=["1", "2", "3", "4", "5"], - default="5" + default="5", ) - + if choice == "1": self.edit_ai_provider() elif choice == "2": @@ -241,9 +271,10 @@ def run_settings_menu(self): elif choice == "5": ui.show_success("Settings saved! ๐Ÿ‘‹") break - + if choice != "5": console.print() Prompt.ask("[dim]Press Enter to continue...[/dim]", default="") -from rich.align import Align \ No newline at end of file + +from rich.align import Align diff --git a/utils/__init__.py b/utils/__init__.py index 67b9db6..dd7ee44 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -1 +1 @@ -# Utils package \ No newline at end of file +# Utils package diff --git a/utils/logging.py b/utils/logging.py index 915e0e0..44d8703 100644 --- a/utils/logging.py +++ b/utils/logging.py @@ -12,13 +12,14 @@ console = Console() + def setup_logging(verbose: bool = False, quiet_mode: bool = False) -> None: """Setup structured logging with rich formatting""" - + # Create logs directory logs_dir = Path("logs") logs_dir.mkdir(exist_ok=True) - + # Configure structlog structlog.configure( processors=[ @@ -30,78 +31,85 @@ def setup_logging(verbose: bool = False, quiet_mode: bool = False) -> None: structlog.processors.StackInfoRenderer(), structlog.processors.format_exc_info, structlog.processors.UnicodeDecoder(), - structlog.processors.JSONRenderer() + structlog.processors.JSONRenderer(), ], context_class=dict, logger_factory=structlog.stdlib.LoggerFactory(), wrapper_class=structlog.stdlib.BoundLogger, cache_logger_on_first_use=True, ) - + # Setup Python logging import logging - + # File handler for detailed logs file_handler = logging.FileHandler(logs_dir / "proto.log") file_handler.setLevel(logging.DEBUG) file_formatter = logging.Formatter( - '%(asctime)s - %(name)s - %(levelname)s - %(message)s' + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" ) file_handler.setFormatter(file_formatter) - + # Rich handler for console output (only if not in quiet mode) rich_handler = RichHandler( - console=console, - show_time=False, - show_path=False, - rich_tracebacks=True + console=console, show_time=False, show_path=False, rich_tracebacks=True ) rich_handler.setLevel(logging.DEBUG if verbose else logging.INFO) - + # Root logger configuration root_logger = logging.getLogger() root_logger.setLevel(logging.DEBUG) root_logger.addHandler(file_handler) - + # Only add console handler if not in quiet mode if not quiet_mode: root_logger.addHandler(rich_handler) else: # In quiet mode, set console logging to ERROR level only root_logger.setLevel(logging.ERROR) - + # Suppress noisy third-party loggers logging.getLogger("httpx").setLevel(logging.WARNING) logging.getLogger("urllib3").setLevel(logging.WARNING) logging.getLogger("clickhouse_connect").setLevel(logging.INFO) + def get_logger(name: str) -> Any: """Get a structured logger instance""" return structlog.get_logger(name) + class ToolExecutionLogger: """Logger for tool execution with rich formatting""" - + def __init__(self, console: Console): self.console = console self.logger = get_logger("tool_execution") - + def log_tool_start(self, tool_name: str, arguments: Dict[str, Any]): """Log tool execution start""" self.console.print(f"[blue]๐Ÿ”ง Executing tool:[/blue] [bold]{tool_name}[/bold]") self.logger.info("tool_execution_start", tool=tool_name, arguments=arguments) - + def log_tool_success(self, tool_name: str, result: str): """Log successful tool execution""" self.console.print(f"[green]โœ“ Tool completed:[/green] [bold]{tool_name}[/bold]") - self.logger.info("tool_execution_success", tool=tool_name, result_length=len(result)) - + self.logger.info( + "tool_execution_success", tool=tool_name, result_length=len(result) + ) + def log_tool_error(self, tool_name: str, error: str): """Log tool execution error""" - self.console.print(f"[red]โŒ Tool failed:[/red] [bold]{tool_name}[/bold] - {error}") + self.console.print( + f"[red]โŒ Tool failed:[/red] [bold]{tool_name}[/bold] - {error}" + ) self.logger.error("tool_execution_error", tool=tool_name, error=error) - + def log_query_execution(self, query: str, duration: float, rows: int): """Log ClickHouse query execution""" - self.console.print(f"[cyan]๐Ÿ“Š Query executed:[/cyan] {duration:.2f}s, {rows} rows") - self.logger.info("query_execution", query=query[:100], duration=duration, rows=rows) \ No newline at end of file + self.console.print( + f"[cyan]๐Ÿ“Š Query executed:[/cyan] {duration:.2f}s, {rows} rows" + ) + self.logger.info( + "query_execution", query=query[:100], duration=duration, rows=rows + ) From 30bd32dc8ffa141c68a29f56fd1d2137a7a5fc64 Mon Sep 17 00:00:00 2001 From: protosphinx <133899485+protosphinx@users.noreply.github.com> Date: Sat, 2 May 2026 12:57:46 -0700 Subject: [PATCH 2/6] test: add pytest suite for config, data_tools, and CLI Bootstraps clickr's first real test suite. The CI workflow at `.github/workflows/test.yml` already invokes `pytest tests/`, but the directory did not exist - so every "passing" green check on main was a no-op. This commit makes the green meaningful: 52 tests across three files, all running in <1s on Python 3.9. `tests/conftest.py` - shared fixtures. The autouse `_isolate_env` fixture strips `CLICKHOUSE_*` and `OPENROUTER_*` from `os.environ` and points `$HOME` at a `tmp_path` so the loader's probe for `~/.config/proto/proto-config.json` never finds the developer's real config. Without this, tests pass locally but fail in CI (or vice versa) depending on whose machine they run on. `tests/test_settings.py` (28 tests) - covers `ClickHouseConfig` defaults and validation, plus every input source `load_config` merges (file, env vars, CLI args, the legacy `clickhouse_*` key mapping from the onboarding flow), and the precedence between them (CLI > env > file). Also covers the silent `provider != "local"` coercion (so an old config doesn't break boot) and the `create_sample_config` round-trip. `tests/test_data_tools.py` (15 tests) - exercises every dtype branch of `_generate_create_table_sql` (UInt64 / Int64 / Float64 / Bool / DateTime / String fallback for mixed-object columns) and the structural invariants (IF NOT EXISTS, MergeTree engine, ORDER BY tuple(), backtick quoting for reserved-word columns, multi-column composition, empty-DataFrame degenerate case). The function does not touch `self.client`, so the fixture passes `None`. `tests/test_cli.py` (9 tests) - Typer `CliRunner` smoke tests: top-level `--help`, that documented subcommands appear in help (catches silent renames), every subcommand has a working `--help`, `version` returns 0, and unknown commands return non-zero. Does not exercise the LLM provider or ClickHouse client - those need real services. Methodology choices worth flagging: - No mocks for ClickHouse or the LLM. Mocked tests for an HTTP client tend to lock in the mock's view of reality and rot when the real service moves. The pure-data-validation tests cover what is actually testable without a server. - Tests use the public API (`load_config`, `ClickHouseConfig`, `_generate_create_table_sql`) rather than asserting against internal helpers. A refactor of internals should not need to touch this file. - Each test class groups behaviour, not function-under-test, so a reader can scan the test names for the contract instead of reading the code. --- tests/__init__.py | 0 tests/conftest.py | 47 +++++++ tests/test_cli.py | 56 ++++++++ tests/test_data_tools.py | 131 ++++++++++++++++++ tests/test_settings.py | 286 +++++++++++++++++++++++++++++++++++++++ 5 files changed, 520 insertions(+) create mode 100644 tests/__init__.py create mode 100644 tests/conftest.py create mode 100644 tests/test_cli.py create mode 100644 tests/test_data_tools.py create mode 100644 tests/test_settings.py diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..1395206 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,47 @@ +"""Shared pytest fixtures. + +The biggest hazard for clickr's test suite is the ambient environment: +``load_config`` reads ``CLICKHOUSE_*`` and ``OPENROUTER_*`` env vars and +also probes for a default config file under ``~/.config/proto/``. Both +make tests order-dependent and machine-dependent. The fixtures here +sandbox each test from both. +""" + +from __future__ import annotations + +import os +from pathlib import Path +from typing import Iterator + +import pytest + +_RELEVANT_PREFIXES = ("CLICKHOUSE_", "OPENROUTER_") + + +@pytest.fixture(autouse=True) +def _isolate_env(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> Iterator[None]: + """Strip clickr-relevant env vars and point ``$HOME`` at a temp dir. + + Every test runs against an empty environment for the prefixes that + ``load_config`` reads, and against a fresh ``$HOME`` so the + ``~/.config/proto/proto-config.json`` probe never finds the real + user's config. Without this, a developer who has clickr configured + locally would see different test results than CI. + """ + for key in list(os.environ): + if key.startswith(_RELEVANT_PREFIXES): + monkeypatch.delenv(key, raising=False) + monkeypatch.setenv("HOME", str(tmp_path)) + yield + + +@pytest.fixture +def cwd(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Path: + """Run the test from a clean temp working directory. + + ``load_config`` also probes for a legacy ``proto-config.json`` in the + current directory. Tests that exercise the file-discovery logic want + a known empty cwd. + """ + monkeypatch.chdir(tmp_path) + return tmp_path diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 0000000..c3b25c7 --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,56 @@ +"""CLI smoke tests via Typer's ``CliRunner``. + +These do not exercise the LLM provider or the ClickHouse client โ€” that +would require a real model server and a real database. They check that +the Typer app is wired up correctly: subcommands resolve, ``--help`` +works on each, and the version flag returns a sane string. The goal +is to catch import-time breakage and command-registration regressions +that would otherwise only surface in the user's first session. +""" + +from __future__ import annotations + +import pytest +from typer.testing import CliRunner + +from main import app + + +@pytest.fixture +def runner() -> CliRunner: + return CliRunner() + + +class TestCliSurface: + def test_help_returns_zero(self, runner: CliRunner) -> None: + result = runner.invoke(app, ["--help"]) + assert result.exit_code == 0 + assert "Usage:" in result.output + + def test_help_lists_documented_subcommands(self, runner: CliRunner) -> None: + # Spot-check the subcommands that are part of the documented + # surface โ€” a rename on any of these is a user-visible + # breaking change and should fail this test. + result = runner.invoke(app, ["--help"]) + for cmd in ("chat", "query", "analyze", "load-data", "settings", "version"): + assert cmd in result.output, f"subcommand '{cmd}' missing from help" + + @pytest.mark.parametrize( + "cmd", + ["chat", "query", "analyze", "load-data", "settings", "clear", "version"], + ) + def test_subcommand_help_returns_zero(self, runner: CliRunner, cmd: str) -> None: + result = runner.invoke(app, [cmd, "--help"]) + assert result.exit_code == 0 + assert "Usage:" in result.output + + def test_version_subcommand_returns_zero(self, runner: CliRunner) -> None: + result = runner.invoke(app, ["version"]) + assert result.exit_code == 0 + + def test_unknown_subcommand_returns_nonzero(self, runner: CliRunner) -> None: + # Belt-and-braces: Typer normally returns 2 for unknown + # commands. If a future change swallows the error and exits 0, + # this test catches it. + result = runner.invoke(app, ["definitely-not-a-real-command"]) + assert result.exit_code != 0 diff --git a/tests/test_data_tools.py b/tests/test_data_tools.py new file mode 100644 index 0000000..16e3be4 --- /dev/null +++ b/tests/test_data_tools.py @@ -0,0 +1,131 @@ +"""Tests for ``tools.data_tools._generate_create_table_sql``. + +This is the only pure function in ``data_tools`` โ€” every other entry +point either calls a real ClickHouse client or runs a real Plotly / +matplotlib pipeline. Keeping the test surface tight here avoids +slow, flaky tests; the CREATE TABLE generator is the bug-prone piece +because pandas dtype inference changes between releases. +""" + +from __future__ import annotations + +from datetime import datetime + +import pandas as pd +import pytest + +from tools.data_tools import DataLoader + + +@pytest.fixture +def loader() -> DataLoader: + """A DataLoader whose client is None. + + ``_generate_create_table_sql`` does not touch ``self.client``, so + passing None is safe and keeps the test independent of the + ClickHouse client library. + """ + return DataLoader(client=None) + + +class TestColumnTypeMapping: + def test_unsigned_integer_column_maps_to_uint64(self, loader: DataLoader) -> None: + df = pd.DataFrame({"x": [1, 2, 3]}) + sql = loader._generate_create_table_sql(df, "t") + assert "`x` UInt64" in sql + + def test_signed_integer_column_maps_to_int64(self, loader: DataLoader) -> None: + df = pd.DataFrame({"x": [-1, 2, 3]}) + sql = loader._generate_create_table_sql(df, "t") + assert "`x` Int64" in sql + + def test_float_column_maps_to_float64(self, loader: DataLoader) -> None: + df = pd.DataFrame({"x": [1.5, 2.5, 3.5]}) + sql = loader._generate_create_table_sql(df, "t") + assert "`x` Float64" in sql + + def test_bool_column_maps_to_bool(self, loader: DataLoader) -> None: + df = pd.DataFrame({"x": [True, False, True]}) + sql = loader._generate_create_table_sql(df, "t") + assert "`x` Bool" in sql + + def test_datetime_column_maps_to_datetime(self, loader: DataLoader) -> None: + df = pd.DataFrame( + {"x": pd.to_datetime(["2024-01-01", "2024-01-02", "2024-01-03"])} + ) + sql = loader._generate_create_table_sql(df, "t") + assert "`x` DateTime" in sql + + def test_object_column_maps_to_string(self, loader: DataLoader) -> None: + df = pd.DataFrame({"x": ["a", "b", "c"]}) + sql = loader._generate_create_table_sql(df, "t") + assert "`x` String" in sql + + def test_mixed_object_column_falls_back_to_string(self, loader: DataLoader) -> None: + # Mixed-type object columns are common in CSV ingestion; they + # must not crash the generator. + df = pd.DataFrame({"x": ["a", 1, datetime(2024, 1, 1)]}, dtype=object) + sql = loader._generate_create_table_sql(df, "t") + assert "`x` String" in sql + + +class TestSqlStructure: + def test_includes_if_not_exists_clause(self, loader: DataLoader) -> None: + # The generator is called repeatedly during CSV ingestion; + # IF NOT EXISTS is what makes that idempotent. + df = pd.DataFrame({"x": [1]}) + sql = loader._generate_create_table_sql(df, "t") + assert "CREATE TABLE IF NOT EXISTS" in sql + + def test_uses_mergetree_engine(self, loader: DataLoader) -> None: + df = pd.DataFrame({"x": [1]}) + sql = loader._generate_create_table_sql(df, "t") + assert "ENGINE = MergeTree()" in sql + + def test_orders_by_tuple_when_no_key_supplied(self, loader: DataLoader) -> None: + # ORDER BY tuple() is ClickHouse's "I have no opinion about + # ordering" idiom; the generator must emit it (an empty + # ORDER BY would be a syntax error). + df = pd.DataFrame({"x": [1]}) + sql = loader._generate_create_table_sql(df, "t") + assert "ORDER BY tuple()" in sql + + def test_table_name_is_substituted(self, loader: DataLoader) -> None: + df = pd.DataFrame({"x": [1]}) + sql = loader._generate_create_table_sql(df, "my_events") + assert "my_events" in sql + + def test_columns_are_backtick_quoted(self, loader: DataLoader) -> None: + # Backticks let column names contain reserved words and + # punctuation. The generator must use them unconditionally. + df = pd.DataFrame({"order": [1], "select": ["x"]}) + sql = loader._generate_create_table_sql(df, "t") + assert "`order`" in sql + assert "`select`" in sql + + +class TestMixedSchemas: + def test_multi_column_schema_emits_each_column(self, loader: DataLoader) -> None: + df = pd.DataFrame( + { + "id": [1, 2, 3], + "amount": [1.5, 2.5, 3.5], + "active": [True, False, True], + "name": ["a", "b", "c"], + } + ) + sql = loader._generate_create_table_sql(df, "t") + assert "`id` UInt64" in sql + assert "`amount` Float64" in sql + assert "`active` Bool" in sql + assert "`name` String" in sql + + def test_empty_dataframe_emits_a_table_with_no_columns( + self, loader: DataLoader + ) -> None: + # An empty DataFrame is a degenerate case but must not crash; + # ClickHouse will reject the resulting CREATE TABLE itself + # with a clearer error than a Python traceback. + df = pd.DataFrame() + sql = loader._generate_create_table_sql(df, "t") + assert "CREATE TABLE IF NOT EXISTS t" in sql diff --git a/tests/test_settings.py b/tests/test_settings.py new file mode 100644 index 0000000..a054913 --- /dev/null +++ b/tests/test_settings.py @@ -0,0 +1,286 @@ +"""Tests for ``config.settings.ClickHouseConfig`` and ``load_config``. + +The Pydantic model itself is mostly validated by Pydantic, so the model +tests focus on defaults and the small amount of custom behaviour +clickr layers on top. The ``load_config`` tests exercise the four +input sources it merges โ€” file, environment, CLI args, and the +"clickhouse_*" key normalisation that comes from the onboarding flow โ€” +and the precedence between them. +""" + +from __future__ import annotations + +import json +from pathlib import Path + +import pytest + +from config.settings import ClickHouseConfig, create_sample_config, load_config + +# --------------------------------------------------------------------------- +# ClickHouseConfig defaults +# --------------------------------------------------------------------------- + + +class TestClickHouseConfigDefaults: + def test_host_defaults_to_localhost(self) -> None: + assert ClickHouseConfig().host == "localhost" + + def test_port_defaults_to_8123(self) -> None: + assert ClickHouseConfig().port == 8123 + + def test_username_defaults_to_default(self) -> None: + assert ClickHouseConfig().username == "default" + + def test_password_defaults_to_empty_string(self) -> None: + assert ClickHouseConfig().password == "" + + def test_database_defaults_to_default(self) -> None: + assert ClickHouseConfig().database == "default" + + def test_secure_defaults_to_false(self) -> None: + assert ClickHouseConfig().secure is False + + def test_provider_defaults_to_local(self) -> None: + assert ClickHouseConfig().provider == "local" + + def test_local_llm_url_defaults_to_loopback(self) -> None: + # Loopback so that out-of-the-box clickr never accidentally + # reaches a non-local LLM. + cfg = ClickHouseConfig() + assert cfg.local_llm_base_url.startswith("http://127.0.0.1") + + def test_temperature_defaults_low(self) -> None: + # Low temperature is intentional for SQL generation determinism. + assert ClickHouseConfig().temperature <= 0.2 + + +class TestClickHouseConfigValidation: + def test_explicit_values_are_kept(self) -> None: + cfg = ClickHouseConfig( + host="ch.example.com", + port=9000, + username="reader", + password="secret", + database="analytics", + secure=True, + ) + assert cfg.host == "ch.example.com" + assert cfg.port == 9000 + assert cfg.username == "reader" + assert cfg.password == "secret" + assert cfg.database == "analytics" + assert cfg.secure is True + + def test_port_must_be_an_integer(self) -> None: + from pydantic import ValidationError + + with pytest.raises(ValidationError): + ClickHouseConfig(port="not-a-port") # type: ignore[arg-type] + + def test_secure_string_coerces_via_pydantic(self) -> None: + # Pydantic v2 coerces "true"/"false" strings to bool by default; + # this asserts that contract so an upstream Pydantic change + # that breaks env-var parsing surfaces here. + assert ClickHouseConfig(secure="true").secure is True # type: ignore[arg-type] + assert ClickHouseConfig(secure="false").secure is False # type: ignore[arg-type] + + +# --------------------------------------------------------------------------- +# load_config: file source +# --------------------------------------------------------------------------- + + +class TestLoadConfigFromFile: + def test_explicit_config_file_overrides_defaults( + self, tmp_path: Path, cwd: Path + ) -> None: + cfg_path = tmp_path / "clickr.json" + cfg_path.write_text( + json.dumps( + { + "host": "from-file.example.com", + "port": 9001, + "database": "events", + } + ) + ) + cfg = load_config(config_file=cfg_path) + assert cfg.host == "from-file.example.com" + assert cfg.port == 9001 + assert cfg.database == "events" + + def test_legacy_proto_config_in_cwd_is_picked_up(self, cwd: Path) -> None: + (cwd / "proto-config.json").write_text( + json.dumps({"host": "legacy.example.com", "port": 9002}) + ) + cfg = load_config() + assert cfg.host == "legacy.example.com" + assert cfg.port == 9002 + + def test_default_config_in_xdg_home_is_picked_up( + self, tmp_path: Path, cwd: Path + ) -> None: + # ``$HOME`` is patched by the autouse fixture to ``tmp_path``. + target = tmp_path / ".config" / "proto" + target.mkdir(parents=True) + (target / "proto-config.json").write_text( + json.dumps({"host": "xdg.example.com"}) + ) + cfg = load_config() + assert cfg.host == "xdg.example.com" + + def test_missing_explicit_file_falls_through_to_default(self, cwd: Path) -> None: + # Pointing at a non-existent path must not raise; the loader + # silently falls through to the next source. + cfg = load_config(config_file=Path("/nonexistent/clickr.json")) + assert cfg.host == "localhost" + + +# --------------------------------------------------------------------------- +# load_config: environment variables +# --------------------------------------------------------------------------- + + +class TestLoadConfigFromEnv: + def test_clickhouse_env_vars_are_read( + self, monkeypatch: pytest.MonkeyPatch, cwd: Path + ) -> None: + monkeypatch.setenv("CLICKHOUSE_HOST", "env.example.com") + monkeypatch.setenv("CLICKHOUSE_PORT", "9003") + monkeypatch.setenv("CLICKHOUSE_DATABASE", "metrics") + cfg = load_config() + assert cfg.host == "env.example.com" + assert cfg.port == 9003 + assert cfg.database == "metrics" + + def test_openrouter_api_key_env_var_is_read( + self, monkeypatch: pytest.MonkeyPatch, cwd: Path + ) -> None: + monkeypatch.setenv("OPENROUTER_API_KEY", "sk-test-12345") + cfg = load_config() + assert cfg.openrouter_api_key == "sk-test-12345" + + def test_openrouter_model_env_var_is_read( + self, monkeypatch: pytest.MonkeyPatch, cwd: Path + ) -> None: + monkeypatch.setenv("OPENROUTER_MODEL", "anthropic/claude-haiku") + cfg = load_config() + assert cfg.openrouter_model == "anthropic/claude-haiku" + + +# --------------------------------------------------------------------------- +# load_config: precedence +# --------------------------------------------------------------------------- + + +class TestLoadConfigPrecedence: + def test_cli_args_override_env( + self, monkeypatch: pytest.MonkeyPatch, cwd: Path + ) -> None: + monkeypatch.setenv("CLICKHOUSE_HOST", "from-env") + cfg = load_config(host="from-cli") + assert cfg.host == "from-cli" + + def test_cli_args_override_file(self, tmp_path: Path, cwd: Path) -> None: + cfg_path = tmp_path / "clickr.json" + cfg_path.write_text(json.dumps({"host": "from-file", "port": 9000})) + cfg = load_config(config_file=cfg_path, port=9999) + assert cfg.host == "from-file" # not overridden + assert cfg.port == 9999 # CLI wins + + def test_env_overrides_file( + self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch, cwd: Path + ) -> None: + cfg_path = tmp_path / "clickr.json" + cfg_path.write_text(json.dumps({"host": "from-file"})) + monkeypatch.setenv("CLICKHOUSE_HOST", "from-env") + cfg = load_config(config_file=cfg_path) + assert cfg.host == "from-env" + + +# --------------------------------------------------------------------------- +# load_config: onboarding key normalisation +# --------------------------------------------------------------------------- + + +class TestKeyNormalization: + def test_clickhouse_prefixed_keys_in_file_are_normalised( + self, tmp_path: Path, cwd: Path + ) -> None: + # The onboarding flow writes "clickhouse_host" etc. into the + # config file; the loader must normalise these to bare keys + # (matching the ``ClickHouseConfig`` field names). + cfg_path = tmp_path / "clickr.json" + cfg_path.write_text( + json.dumps( + { + "clickhouse_host": "norm.example.com", + "clickhouse_port": 9004, + "clickhouse_database": "raw", + "clickhouse_secure": True, + } + ) + ) + cfg = load_config(config_file=cfg_path) + assert cfg.host == "norm.example.com" + assert cfg.port == 9004 + assert cfg.database == "raw" + assert cfg.secure is True + + def test_normalisation_does_not_clobber_already_normalised_key( + self, tmp_path: Path, cwd: Path + ) -> None: + # If both the prefixed and the bare key are present, the bare + # key wins (loader treats the prefixed one as a fallback). + cfg_path = tmp_path / "clickr.json" + cfg_path.write_text( + json.dumps({"clickhouse_host": "fallback", "host": "primary"}) + ) + cfg = load_config(config_file=cfg_path) + assert cfg.host == "primary" + + +# --------------------------------------------------------------------------- +# load_config: provider enforcement +# --------------------------------------------------------------------------- + + +class TestProviderEnforcement: + def test_explicit_local_provider_is_kept(self, tmp_path: Path, cwd: Path) -> None: + cfg_path = tmp_path / "clickr.json" + cfg_path.write_text(json.dumps({"provider": "local"})) + cfg = load_config(config_file=cfg_path) + assert cfg.provider == "local" + + def test_non_local_provider_is_force_switched_to_local( + self, tmp_path: Path, cwd: Path + ) -> None: + # Only the local provider is supported today; the loader must + # silently coerce other values rather than raise (so an old + # config file with provider=openrouter still boots). + cfg_path = tmp_path / "clickr.json" + cfg_path.write_text(json.dumps({"provider": "openrouter"})) + cfg = load_config(config_file=cfg_path) + assert cfg.provider == "local" + + +# --------------------------------------------------------------------------- +# create_sample_config +# --------------------------------------------------------------------------- + + +class TestCreateSampleConfig: + def test_writes_a_loadable_json_file(self, cwd: Path) -> None: + # The sample config should round-trip through load_config + # without errors โ€” it is the on-ramp for new users. + create_sample_config() + sample_path = cwd / "proto-config.json" + assert sample_path.exists() + data = json.loads(sample_path.read_text()) + # Spot-check a few fields the sample is documented to set. + assert data["host"] == "localhost" + assert data["port"] == 8123 + # The sample should load cleanly into a ClickHouseConfig. + cfg = load_config(config_file=sample_path) + assert cfg.host == "localhost" From 99590ce563f6720f4d1cdc66a1e9a0b965930dc3 Mon Sep 17 00:00:00 2001 From: protosphinx <133899485+protosphinx@users.noreply.github.com> Date: Sat, 2 May 2026 12:59:53 -0700 Subject: [PATCH 3/6] fix(packaging): use legacy license-table syntax so install works on Python 3.8 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `license = "MIT"` is the new PEP 639 SPDX-string form, only understood by setuptools >= 77 (released Feb 2025). The CI matrix includes Python 3.8, whose bundled setuptools predates that โ€” so `pip install -e ".[dev]"` fails on the ubuntu-latest 3.8 job with: ValueError: invalid pyproject.toml config: `project.license`. configuration error: `project.license` must be valid exactly by one definition (2 matches found): - keys: 'file': {type: string} - keys: 'text': {type: string} `license = { text = "MIT" }` is the pre-PEP-639 table form, understood by every setuptools version that supports pyproject.toml, so it works across the full 3.8 - 3.12 matrix without forcing a setuptools-upgrade step into the CI install. Same metadata reaches PyPI either way (`License: MIT` classifier). --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index da74774..e0dc994 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ name = "clickr" version = "1.0.3" description = "Natural-language CLI for ClickHouseยฎ โ€” text-to-SQL with local or cloud LLMs. Not affiliated with ClickHouse, Inc." readme = "README.md" -license = "MIT" +license = { text = "MIT" } authors = [ {name = "ERPโ€ขAI", email = "team@erp.ai"} ] From acd2a9d68e545253c5db588b9845dd89602c70b7 Mon Sep 17 00:00:00 2001 From: protosphinx <133899485+protosphinx@users.noreply.github.com> Date: Sat, 2 May 2026 13:02:17 -0700 Subject: [PATCH 4/6] chore: drop Python 3.8 (EOL Oct 2024) from CI matrix and metadata MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The latest `clickhouse-connect` (a transitive dep) uses PEP 585 generic-subscript syntax (`list[IPv6Address]`) without bumping its own `requires-python` past 3.8 โ€” so on Python 3.8 pip resolves to a version of clickhouse-connect that fails to import with `TypeError: 'type' object is not subscriptable`. The Python 3.8 test job was failing before this PR for that reason; the matrix ran for years on luck (no test ever imported clickhouse-connect). Python 3.8 reached end-of-life on 2024-10-07. Dropping it across: - `.github/workflows/test.yml` matrix: 3.8 โ†’ 3.9 floor - `pyproject.toml`: `requires-python`, classifier, mypy `python_version`, black `target-version` - `setup.py`: `python_requires`, classifier Cuts the matrix from 15 jobs to 12 and matches the practical reality of the dep tree. Anyone genuinely on 3.8 can still install older clickr versions from PyPI; new 3.9+ becomes the supported floor. --- .github/workflows/test.yml | 2 +- pyproject.toml | 7 +++---- setup.py | 3 +-- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index a611992..be72715 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -12,7 +12,7 @@ jobs: strategy: matrix: os: [ubuntu-latest, macos-latest, windows-latest] - python-version: [3.8, 3.9, '3.10', '3.11', '3.12'] + python-version: ['3.9', '3.10', '3.11', '3.12'] steps: - uses: actions/checkout@v4 diff --git a/pyproject.toml b/pyproject.toml index e0dc994..b2872e0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,6 @@ classifiers = [ "Operating System :: OS Independent", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", @@ -31,7 +30,7 @@ classifiers = [ "Topic :: Scientific/Engineering :: Information Analysis", "Topic :: Software Development :: Libraries :: Python Modules", ] -requires-python = ">=3.8" +requires-python = ">=3.9" dependencies = [ "typer[all]>=0.9.0", "rich>=13.0.0", @@ -82,7 +81,7 @@ py-modules = ["main"] [tool.black] line-length = 88 -target-version = ['py38'] +target-version = ['py39'] include = '\.pyi?$' extend-exclude = ''' /( @@ -105,7 +104,7 @@ line_length = 88 known_first_party = ["agent", "config", "providers", "tools", "ui", "utils"] [tool.mypy] -python_version = "3.8" +python_version = "3.9" warn_return_any = true warn_unused_configs = true disallow_untyped_defs = true diff --git a/setup.py b/setup.py index ffc0fdd..f3b783f 100644 --- a/setup.py +++ b/setup.py @@ -30,7 +30,6 @@ "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", @@ -38,7 +37,7 @@ "Topic :: Scientific/Engineering :: Information Analysis", "Topic :: Software Development :: Libraries :: Python Modules", ], - python_requires=">=3.8", + python_requires=">=3.9", install_requires=[ "typer[all]>=0.9.0", "rich>=13.0.0", From d941066d9974259046d24c251a3dc60c513d6dd0 Mon Sep 17 00:00:00 2001 From: protosphinx <133899485+protosphinx@users.noreply.github.com> Date: Sat, 2 May 2026 13:04:09 -0700 Subject: [PATCH 5/6] fix(packaging): read README.md with explicit utf-8 encoding for Windows MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `Path.read_text()` without `encoding=` falls back to `locale.getpreferredencoding()` โ€” UTF-8 on macOS/Linux, but cp1252 on Windows. README.md contains the ERPโ€ขAI brand mark (`โ€ข`) and em-dashes, which fail to decode as cp1252 and crash the install: UnicodeDecodeError: 'charmap' codec can't decode byte 0x9d in position 480: character maps to Pinning `encoding="utf-8"` makes the install reliable across the matrix. The same bug pattern bites every Python project that does `open("README.md").read()` without encoding; would not surface until someone runs CI (or installs) on Windows. --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index f3b783f..fe7eb5c 100644 --- a/setup.py +++ b/setup.py @@ -9,7 +9,7 @@ # Read the README file this_directory = Path(__file__).parent -long_description = (this_directory / "README.md").read_text() +long_description = (this_directory / "README.md").read_text(encoding="utf-8") setup( name="clickr", From f7ead4241c6400301aafd0327a5a3d81c82e5164 Mon Sep 17 00:00:00 2001 From: protosphinx <133899485+protosphinx@users.noreply.github.com> Date: Sat, 2 May 2026 13:06:42 -0700 Subject: [PATCH 6/6] test(conftest): point USERPROFILE + HOMEDRIVE/HOMEPATH at tmp_path for Windows MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `Path.home()` reads HOME on POSIX but USERPROFILE on Windows (with HOMEDRIVE + HOMEPATH as a fallback). The autouse fixture only set HOME, so on Windows runners the loader's `~/.config/proto/...` probe escaped the sandbox and resolved to the real runner's home โ€” where there is no proto-config.json โ€” and test_default_config_in_xdg_home_is_picked_up failed with `assert 'localhost' == 'xdg.example.com'`. Setting all three env vars (HOME, USERPROFILE, HOMEDRIVE+HOMEPATH) makes the fixture portable across the full ubuntu/macos/windows matrix without forking the test or skipping on Windows. --- tests/conftest.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/conftest.py b/tests/conftest.py index 1395206..912b575 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -31,7 +31,14 @@ def _isolate_env(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> Iterator[No for key in list(os.environ): if key.startswith(_RELEVANT_PREFIXES): monkeypatch.delenv(key, raising=False) + # `Path.home()` reads HOME on POSIX but USERPROFILE on Windows + # (with HOMEDRIVE + HOMEPATH as a fallback). Set all three to + # tmp_path so the loader's `~/.config/proto/...` probe lands in + # the sandbox on every OS the CI matrix runs. monkeypatch.setenv("HOME", str(tmp_path)) + monkeypatch.setenv("USERPROFILE", str(tmp_path)) + monkeypatch.setenv("HOMEDRIVE", str(tmp_path.drive) if tmp_path.drive else "") + monkeypatch.setenv("HOMEPATH", str(tmp_path)[len(tmp_path.drive) :]) yield