diff --git a/cli/README.md b/cli/README.md index 8135904ee..265250071 100644 --- a/cli/README.md +++ b/cli/README.md @@ -26,6 +26,7 @@ pip install 'nao-core[athena]' pip install 'nao-core[trino]' pip install 'nao-core[redshift]' pip install 'nao-core[fabric]' +pip install 'nao-core[starrocks]' # LLM providers pip install 'nao-core[openai]' @@ -83,7 +84,7 @@ nao init This will create a new nao project in the current directory. It will prompt you for a project name and ask you to configure: -- **Database connections** (BigQuery, DuckDB, Databricks, Snowflake, PostgreSQL, Redshift, MSSQL, Trino) +- **Database connections** (BigQuery, DuckDB, Databricks, Snowflake, PostgreSQL, Redshift, MSSQL, Trino, StarRocks) - **Git repositories** to sync - **LLM provider** (OpenAI, Anthropic, Mistral, Gemini, OpenRouter, Ollama) - **`ai_summary` template + model** (prompted only when you enable `ai_summary` for databases) diff --git a/cli/nao_core/config/__init__.py b/cli/nao_core/config/__init__.py index 37900b931..7bdfe865b 100644 --- a/cli/nao_core/config/__init__.py +++ b/cli/nao_core/config/__init__.py @@ -10,6 +10,7 @@ PostgresConfig, RedshiftConfig, SnowflakeConfig, + StarRocksConfig, TrinoConfig, ) from .exceptions import InitError @@ -28,6 +29,7 @@ "PostgresConfig", "MssqlConfig", "RedshiftConfig", + "StarRocksConfig", "TrinoConfig", "DatabaseType", "LLMConfig", diff --git a/cli/nao_core/config/databases/__init__.py b/cli/nao_core/config/databases/__init__.py index cab52c7a7..0e3c32c8a 100644 --- a/cli/nao_core/config/databases/__init__.py +++ b/cli/nao_core/config/databases/__init__.py @@ -14,6 +14,7 @@ from .postgres import PostgresConfig from .redshift import RedshiftConfig from .snowflake import SnowflakeConfig +from .starrocks import StarRocksConfig from .trino import TrinoConfig # ============================================================================= @@ -33,6 +34,7 @@ Annotated[MssqlConfig, Tag("mssql")], Annotated[PostgresConfig, Tag("postgres")], Annotated[RedshiftConfig, Tag("redshift")], + Annotated[StarRocksConfig, Tag("starrocks")], Annotated[TrinoConfig, Tag("trino")], ], Discriminator("type"), @@ -52,6 +54,7 @@ DatabaseType.SNOWFLAKE: SnowflakeConfig, DatabaseType.POSTGRES: PostgresConfig, DatabaseType.REDSHIFT: RedshiftConfig, + DatabaseType.STARROCKS: StarRocksConfig, DatabaseType.TRINO: TrinoConfig, } @@ -86,6 +89,7 @@ def parse_database_config(data: dict) -> AnyDatabaseConfig: "MssqlConfig", "MysqlConfig", "SnowflakeConfig", + "StarRocksConfig", "PostgresConfig", "RedshiftConfig", "TrinoConfig", diff --git a/cli/nao_core/config/databases/base.py b/cli/nao_core/config/databases/base.py index bf4e8377a..4b3e690a6 100644 --- a/cli/nao_core/config/databases/base.py +++ b/cli/nao_core/config/databases/base.py @@ -28,6 +28,7 @@ class DatabaseType(str, Enum): MYSQL = "mysql" POSTGRES = "postgres" REDSHIFT = "redshift" + STARROCKS = "starrocks" TRINO = "trino" @classmethod diff --git a/cli/nao_core/config/databases/starrocks.py b/cli/nao_core/config/databases/starrocks.py new file mode 100644 index 000000000..f5c2d1722 --- /dev/null +++ b/cli/nao_core/config/databases/starrocks.py @@ -0,0 +1,345 @@ +from __future__ import annotations + +import fnmatch +import re +from typing import TYPE_CHECKING, Any, Literal + +from pydantic import Field + +from nao_core.config.exceptions import InitError +from nao_core.deps import require_dependency +from nao_core.ui import ask_text + +if TYPE_CHECKING: + from mysql.connector import MySQLConnection + +from .base import DatabaseConfig +from .context import DatabaseContext + +DEFAULT_CATALOG = "default_catalog" +SYSTEM_SCHEMAS = ("information_schema", "sys", "_statistics_") + + +def _quote_identifier(value: str) -> str: + escaped = value.replace("`", "``") + return f"`{escaped}`" + + +def _quote_literal(value: str) -> str: + escaped = value.replace("\\", "\\\\").replace("'", "''") + return f"'{escaped}'" + + +def _extract_table_comment_from_ddl(ddl: str) -> str | None: + match = re.search( + r'^\s*COMMENT\s*(?:\(\s*)?"((?:\\.|[^"\\])*)"(?:\s*\))?\s*,?\s*$', + ddl, + re.IGNORECASE | re.MULTILINE, + ) + if not match: + return None + unescaped = match.group(1).replace('\\"', '"').replace("\\\\", "\\") + return unescaped.strip() or None + + +def _split_schema_identifier(identifier: str, default_catalog: str = DEFAULT_CATALOG) -> tuple[str, str]: + if "." in identifier: + catalog, schema = identifier.split(".", 1) + return catalog, schema + return default_catalog, identifier + + +class StarRocksBackend: + """Lightweight backend adapter over mysql-connector for StarRocks.""" + + def __init__(self, conn: MySQLConnection, default_catalog: str) -> None: + self._conn = conn + self._default_catalog = default_catalog + + def raw_sql(self, sql: str): + cursor = self._conn.cursor() + cursor.execute(sql) + return cursor + + def list_catalogs(self) -> list[str]: + rows = self.raw_sql("SHOW CATALOGS").fetchall() + catalogs = [str(row[0]) for row in rows if row and row[0]] + return sorted(set(catalogs)) + + def list_databases(self, catalog: str) -> list[str]: + rows = self.raw_sql(f"SHOW DATABASES FROM {_quote_identifier(catalog)}").fetchall() + return [str(row[0]) for row in rows if row and row[0]] + + def list_tables(self, database: str) -> list[str]: + catalog, schema = _split_schema_identifier(database, default_catalog=self._default_catalog) + rows = self.raw_sql(f"SHOW TABLES FROM {_quote_identifier(catalog)}.{_quote_identifier(schema)}").fetchall() + return [str(row[0]) for row in rows if row and row[0]] + + def disconnect(self) -> None: + self._conn.close() + + +class StarRocksDatabaseContext(DatabaseContext): + """StarRocks context using information_schema metadata queries.""" + + def __init__(self, conn, schema: str, table_name: str, default_catalog: str = DEFAULT_CATALOG): + catalog, db_schema = _split_schema_identifier(schema, default_catalog=default_catalog) + self._catalog = catalog + super().__init__(conn, db_schema, table_name) + + def _quote(self, name: str) -> str: + return _quote_identifier(name) + + def _qualified_table_sql(self) -> str: + return ".".join((self._quote(self._catalog), self._quote(self._schema), self._quote(self._table_name))) + + def _show_create_table_sql(self) -> str: + return f"SHOW CREATE TABLE {self._qualified_table_sql()}" + + def _description_from_information_schema(self) -> str | None: + query = f""" + SELECT TABLE_COMMENT + FROM information_schema.TABLES + WHERE TABLE_CATALOG = {_quote_literal(self._catalog)} + AND TABLE_SCHEMA = {_quote_literal(self._schema)} + AND TABLE_NAME = {_quote_literal(self._table_name)} + LIMIT 1 + """ + row = self._conn.raw_sql(query).fetchone() # type: ignore[union-attr] + if row and row[0]: + return str(row[0]).strip() or None + return None + + def _description_from_show_create(self) -> str | None: + row = self._conn.raw_sql(self._show_create_table_sql()).fetchone() # type: ignore[union-attr] + if not row: + return None + ddl = str(row[-1]).strip() if row[-1] else "" + if not ddl: + return None + return _extract_table_comment_from_ddl(ddl) + + def description(self) -> str | None: + try: + if desc := self._description_from_information_schema(): + return desc + except Exception: + pass + try: + if desc := self._description_from_show_create(): + return desc + except Exception: + pass + return None + + def _columns_from_information_schema(self) -> list[dict[str, Any]]: + query = f""" + SELECT COLUMN_NAME, DATA_TYPE, IS_NULLABLE, COLUMN_COMMENT + FROM information_schema.COLUMNS + WHERE TABLE_CATALOG = {_quote_literal(self._catalog)} + AND TABLE_SCHEMA = {_quote_literal(self._schema)} + AND TABLE_NAME = {_quote_literal(self._table_name)} + ORDER BY ORDINAL_POSITION + """ + rows = self._conn.raw_sql(query).fetchall() # type: ignore[union-attr] + return [ + { + "name": str(row[0]), + "type": str(row[1]), + "nullable": str(row[2]).upper() == "YES", + "description": str(row[3]).strip() if row[3] else None, + } + for row in rows + ] + + def _columns_from_show_full_columns(self) -> list[dict[str, Any]]: + cursor = self._conn.raw_sql(f"SHOW FULL COLUMNS FROM {self._qualified_table_sql()}") # type: ignore[union-attr] + rows = cursor.fetchall() + description = getattr(cursor, "description", None) or [] + columns = {str(desc[0]).lower(): idx for idx, desc in enumerate(description) if desc and desc[0]} + + nullable_idx = columns.get("null", 3) + comment_idx = columns.get("comment", len(rows[0]) - 1 if rows else -1) + + return [ + { + "name": str(row[0]), + "type": str(row[1]), + "nullable": str(row[nullable_idx]).upper() == "YES", + "description": str(row[comment_idx]).strip() if comment_idx >= 0 and row[comment_idx] else None, + } + for row in rows + if row and row[0] + ] + + def columns(self) -> list[dict[str, Any]]: + try: + columns = self._columns_from_information_schema() + if columns: + return columns + except Exception: + pass + try: + return self._columns_from_show_full_columns() + except Exception: + return [] + + def row_count(self) -> int: + try: + row = self._conn.raw_sql(f"SELECT COUNT(*) FROM {self._qualified_table_sql()}").fetchone() # type: ignore[union-attr] + return int(row[0]) if row and row[0] is not None else 0 + except Exception: + return 0 + + def preview(self, limit: int = 10) -> list[dict[str, Any]]: + safe_limit = max(0, int(limit)) + cursor = self._conn.raw_sql(f"SELECT * FROM {self._qualified_table_sql()} LIMIT {safe_limit}") # type: ignore[union-attr] + rows = cursor.fetchall() + columns = [desc[0] for desc in cursor.description] if cursor.description else [] + out: list[dict[str, Any]] = [] + for row in rows: + record = dict(zip(columns, row, strict=False)) + for key, value in record.items(): + if value is not None and not isinstance(value, (str, int, float, bool, list, dict)): + record[key] = str(value) + out.append(record) + return out + + def _build_profiling_query(self, col: dict) -> str: + col_sql = self._quote(col["name"]) + table_sql = self._qualified_table_sql() + partition_filter = self._partition_filter() + where_clause = f"WHERE {partition_filter}" if partition_filter else "" + frags = self._numeric_agg_fragments(col_sql, col) + extra_aggs = "".join(f"\n , {expr} AS {alias}" for alias, expr in frags) + return f""" + SELECT + {self._null_count_sql(col_sql)} AS null_count, + {self._distinct_count_sql(col_sql)} AS distinct_count{extra_aggs} + FROM {table_sql} + {where_clause} + """.strip() + + def _build_top_values_query(self, col: dict) -> str: + col_sql = self._quote(col["name"]) + table_sql = self._qualified_table_sql() + partition_filter = self._partition_filter() + where_clause = f"WHERE {partition_filter}" if partition_filter else "" + return f""" + SELECT {col_sql} AS value, COUNT(*) AS cnt + FROM {table_sql} + {where_clause} + GROUP BY {col_sql} + ORDER BY cnt DESC, {col_sql} ASC + LIMIT 10 + """.strip() + + def _cast_complex_to_string(self, col_sql: str) -> str: + return f"CAST({col_sql} AS STRING)" + + +class StarRocksConfig(DatabaseConfig): + """StarRocks-specific configuration using mysql-connector-python.""" + + type: Literal["starrocks"] = "starrocks" + host: str = Field(description="StarRocks FE host") + port: int = Field(default=9030, description="StarRocks MySQL protocol port") + user: str = Field(description="Username") + password: str = Field(default="", description="Password") + catalog: str | None = Field(default=None, description="Catalog name (optional, defaults to all catalogs)") + database: str | None = Field(default=None, description="Default database name (optional)") + schema_name: str | None = Field(default=None, description="Specific schema to sync (optional)") + + @classmethod + def promptConfig(cls) -> "StarRocksConfig": + name = ask_text("Connection name:", default="starrocks-prod") or "starrocks-prod" + host = ask_text("Host:", default="localhost") or "localhost" + port_str = ask_text("Port:", default="9030") or "9030" + if not port_str.isdigit(): + raise InitError("Port must be a valid integer.") + + user = ask_text("Username:", required_field=True) + password = ask_text("Password:", password=True) or "" + catalog = ask_text("Catalog (optional, e.g. default_catalog):") or None + database = ask_text("Default database (optional):") or None + schema_name = ask_text("Schema to sync (optional):") or None + + return StarRocksConfig( + name=name, + host=host, + port=int(port_str), + user=user, # type: ignore[arg-type] + password=password, + catalog=catalog, + database=database, + schema_name=schema_name, + ) + + def connect(self): + require_dependency("mysql.connector", "starrocks", "to connect to StarRocks databases") + import mysql.connector + + conn_kwargs: dict[str, Any] = { + "host": self.host, + "port": self.port, + "user": self.user, + "password": self.password, + "autocommit": True, + } + if self.database: + catalog = self.catalog or DEFAULT_CATALOG + conn_kwargs["database"] = f"{catalog}.{self.database}" + + conn = mysql.connector.connect(**conn_kwargs) + return StarRocksBackend(conn, default_catalog=self.catalog or DEFAULT_CATALOG) # type: ignore[invalid-argument-type] + + def get_database_name(self) -> str: + if self.catalog and self.database: + return f"{self.catalog}.{self.database}" + if not self.catalog and self.database: + return self.database + if self.catalog and not self.database: + return self.catalog + return "starrocks" + + def get_schemas(self, conn) -> list[str]: + if self.schema_name: + catalog = self.catalog or DEFAULT_CATALOG + return [f"{catalog}.{self.schema_name}"] + + catalogs = [self.catalog] if self.catalog else conn.list_catalogs() + schemas: list[str] = [] + for catalog in catalogs: + try: + for schema in conn.list_databases(catalog): + if schema.lower() in SYSTEM_SCHEMAS: + continue + schemas.append(f"{catalog}.{schema}") + except Exception: + continue + return sorted(set(schemas)) + + def create_context(self, conn, schema: str, table_name: str): + return StarRocksDatabaseContext(conn, schema, table_name, default_catalog=self.catalog or DEFAULT_CATALOG) + + def check_connection(self) -> tuple[bool, str]: + conn = None + try: + conn = self.connect() + schemas = self.get_schemas(conn) + return True, f"Connected successfully ({len(schemas)} schemas found)" + except Exception as e: + return False, str(e) + finally: + if conn is not None: + conn.disconnect() + + def matches_pattern(self, schema: str, table: str) -> bool: + catalog, schema_name = _split_schema_identifier(schema, default_catalog=self.catalog or DEFAULT_CATALOG) + full_names = (f"{catalog}.{schema_name}.{table}", f"{schema_name}.{table}") + + if self.include and not any(fnmatch.fnmatch(name, pattern) for pattern in self.include for name in full_names): + return False + if self.exclude and any(fnmatch.fnmatch(name, pattern) for pattern in self.exclude for name in full_names): + return False + return True diff --git a/cli/nao_core/deps.py b/cli/nao_core/deps.py index 8c3ebe0a8..d308f2983 100644 --- a/cli/nao_core/deps.py +++ b/cli/nao_core/deps.py @@ -32,6 +32,7 @@ "trino": ["ibis.backends.trino"], "redshift": ["ibis.backends.postgres", "sshtunnel"], "fabric": ["ibis.backends.mssql", "azure.identity"], + "starrocks": ["mysql.connector"], # LLM providers "openai": ["openai"], "anthropic": ["anthropic"], diff --git a/cli/pyproject.toml b/cli/pyproject.toml index ede517e5c..db3d92cd4 100644 --- a/cli/pyproject.toml +++ b/cli/pyproject.toml @@ -59,6 +59,7 @@ athena = ["ibis-framework[athena]>=9.0.0"] trino = ["ibis-framework[trino]>=9.0.0"] redshift = ["ibis-framework[postgres]>=9.0.0", "sshtunnel>=0.4.0", "redshift-connector>=2.1.13"] fabric = ["ibis-framework[mssql]>=9.0.0", "azure-identity>=1.19.0"] +starrocks = ["mysql-connector-python>=9.0.0"] # LLM providers openai = ["openai>=1.0.0"] @@ -72,7 +73,7 @@ notion = ["notion-client>=2.7.0", "notion2md>=2.9.0"] # Convenience groups all-databases = [ - "nao-core[postgres,bigquery,snowflake,duckdb,clickhouse,databricks,mysql,mssql,athena,trino,redshift,fabric]", + "nao-core[postgres,bigquery,snowflake,duckdb,clickhouse,databricks,mysql,mssql,athena,trino,redshift,fabric,starrocks]", ] all-llms = ["nao-core[openai,anthropic,mistral,gemini,ollama]"] all = ["nao-core[all-databases,all-llms,notion]"] diff --git a/cli/tests/nao_core/commands/sync/integration/.env.example b/cli/tests/nao_core/commands/sync/integration/.env.example index 9f50c87d1..dc690ccd5 100644 --- a/cli/tests/nao_core/commands/sync/integration/.env.example +++ b/cli/tests/nao_core/commands/sync/integration/.env.example @@ -43,3 +43,9 @@ TRINO_PORT=8080 TRINO_CATALOG=hive TRINO_USER=nao TRINO_PASSWORD= + +# StarRocks (tests skipped when STARROCKS_HOST is not set) +STARROCKS_HOST=localhost +STARROCKS_PORT=9030 +STARROCKS_USER=root +STARROCKS_PASSWORD= diff --git a/cli/tests/nao_core/commands/sync/integration/dml/starrocks.sql b/cli/tests/nao_core/commands/sync/integration/dml/starrocks.sql new file mode 100644 index 000000000..32ad8fa12 --- /dev/null +++ b/cli/tests/nao_core/commands/sync/integration/dml/starrocks.sql @@ -0,0 +1,21 @@ +CREATE TABLE users ( + id INT NOT NULL, + name VARCHAR(255) NOT NULL, + email VARCHAR(255), + active BOOLEAN NOT NULL +); + +INSERT INTO users (id, name, email, active) VALUES +(1, 'Alice', 'alice@example.com', true), +(2, 'Bob', NULL, false), +(3, 'Charlie', 'charlie@example.com', true); + +CREATE TABLE orders ( + id INT NOT NULL, + user_id INT NOT NULL, + amount DOUBLE NOT NULL +); + +INSERT INTO orders (id, user_id, amount) VALUES +(1, 1, 99.99), +(2, 1, 24.50); diff --git a/cli/tests/nao_core/commands/sync/integration/test_starrocks.py b/cli/tests/nao_core/commands/sync/integration/test_starrocks.py new file mode 100644 index 000000000..85abd0d11 --- /dev/null +++ b/cli/tests/nao_core/commands/sync/integration/test_starrocks.py @@ -0,0 +1,123 @@ +"""Integration tests for StarRocks sync using mysql-connector backend. + +Required environment variables: + STARROCKS_HOST, STARROCKS_USER +Optional: + STARROCKS_PORT (default 9030), STARROCKS_PASSWORD + +The test suite is skipped when STARROCKS_HOST is not set. +""" + +import os +import uuid +from pathlib import Path + +import pytest +from rich.progress import Progress + +from nao_core.commands.sync.providers.databases.provider import sync_database +from nao_core.config.databases.starrocks import DEFAULT_CATALOG, StarRocksConfig + +STARROCKS_HOST = os.environ.get("STARROCKS_HOST") + +pytestmark = pytest.mark.skipif(STARROCKS_HOST is None, reason="STARROCKS_HOST not set - skipping StarRocks tests") + + +def _connect_starrocks(): + import mysql.connector + + return mysql.connector.connect( + host=os.environ["STARROCKS_HOST"], + port=int(os.environ.get("STARROCKS_PORT", "9030")), + user=os.environ["STARROCKS_USER"], + password=os.environ.get("STARROCKS_PASSWORD", ""), + autocommit=True, + ) + + +@pytest.fixture(scope="module") +def temp_databases(): + """Create temporary StarRocks databases and test tables.""" + database = f"nao_sr_{uuid.uuid4().hex[:8]}" + another_database = f"{database}_alt" + conn = _connect_starrocks() + cursor = conn.cursor() + + try: + cursor.execute(f"CREATE DATABASE {database}") + cursor.execute(f"CREATE DATABASE {another_database}") + + sql_file = Path(__file__).parent / "dml" / "starrocks.sql" + sql_content = sql_file.read_text() + + cursor.execute(f"USE {DEFAULT_CATALOG}.{database}") + for statement in sql_content.split(";"): + statement = statement.strip() + if statement: + cursor.execute(statement) + + cursor.execute(f"USE {DEFAULT_CATALOG}.{another_database}") + cursor.execute("CREATE TABLE whatever (id INT NOT NULL, price DOUBLE NOT NULL)") + + yield {"primary": database, "another": another_database} + finally: + cleanup_conn = _connect_starrocks() + cleanup_cursor = cleanup_conn.cursor() + try: + cleanup_cursor.execute(f"DROP DATABASE IF EXISTS {database} FORCE") + cleanup_cursor.execute(f"DROP DATABASE IF EXISTS {another_database} FORCE") + except Exception: + pass + cleanup_cursor.close() + cleanup_conn.close() + cursor.close() + conn.close() + + +@pytest.fixture +def db_config(temp_databases): + return StarRocksConfig( + name="test-starrocks", + host=os.environ["STARROCKS_HOST"], + port=int(os.environ.get("STARROCKS_PORT", "9030")), + user=os.environ["STARROCKS_USER"], + password=os.environ.get("STARROCKS_PASSWORD", ""), + catalog=DEFAULT_CATALOG, + schema_name=temp_databases["primary"], + ) + + +def test_sync_with_explicit_schema(tmp_path, db_config, temp_databases): + output = tmp_path / "sync" + with Progress(transient=True) as progress: + state = sync_database(db_config, output, progress) + + primary_schema = f"{DEFAULT_CATALOG}.{temp_databases['primary']}" + another_schema = f"{DEFAULT_CATALOG}.{temp_databases['another']}" + base = output / "type=starrocks" / f"database={db_config.get_database_name()}" + + assert state.schemas_synced == 1 + assert state.tables_synced == 2 + assert primary_schema in state.synced_schemas + assert (base / f"schema={primary_schema}" / "table=users" / "columns.md").exists() + assert (base / f"schema={primary_schema}" / "table=orders" / "preview.md").exists() + assert not (base / f"schema={another_schema}").exists() + + +def test_get_schemas_supports_catalog_prefix(db_config, temp_databases): + config = db_config.model_copy(update={"schema_name": None}) + conn = config.connect() + try: + schemas = config.get_schemas(conn) + finally: + conn.disconnect() + + assert f"{DEFAULT_CATALOG}.{temp_databases['primary']}" in schemas + assert f"{DEFAULT_CATALOG}.{temp_databases['another']}" in schemas + + +def test_execute_sql_works_with_three_part_names(db_config, temp_databases): + schema = f"{DEFAULT_CATALOG}.{temp_databases['primary']}" + df = db_config.execute_sql(f"SELECT COUNT(*) AS cnt FROM {schema}.users") + assert len(df) == 1 + assert int(df.iloc[0, 0]) == 3 diff --git a/cli/tests/nao_core/config/test_starrocks.py b/cli/tests/nao_core/config/test_starrocks.py new file mode 100644 index 000000000..74dec1eb0 --- /dev/null +++ b/cli/tests/nao_core/config/test_starrocks.py @@ -0,0 +1,42 @@ +from nao_core.config.databases.starrocks import StarRocksConfig + + +class DummyConn: + def list_catalogs(self): + return ["default_catalog", "hive1"] + + def list_databases(self, catalog: str): + return { + "default_catalog": ["information_schema", "sales"], + "hive1": ["analytics"], + }[catalog] + + +def test_starrocks_get_schemas_without_explicit_schema(): + cfg = StarRocksConfig(name="sr", host="localhost", user="root", catalog=None) + schemas = cfg.get_schemas(DummyConn()) + assert schemas == ["default_catalog.sales", "hive1.analytics"] + + +def test_starrocks_matches_pattern_accepts_catalog_and_schema_forms(): + cfg = StarRocksConfig( + name="sr", + host="localhost", + user="root", + catalog="default_catalog", + include=["default_catalog.sales.*"], + exclude=["sales.orders"], + ) + + assert cfg.matches_pattern("default_catalog.sales", "users") is True + assert cfg.matches_pattern("default_catalog.sales", "orders") is False + + +def test_starrocks_get_database_name_variants(): + both = StarRocksConfig(name="sr", host="localhost", user="root", catalog="hive1", database="analytics") + catalog_only = StarRocksConfig(name="sr", host="localhost", user="root", catalog="hive1") + fallback = StarRocksConfig(name="sr", host="localhost", user="root") + + assert both.get_database_name() == "hive1.analytics" + assert catalog_only.get_database_name() == "hive1" + assert fallback.get_database_name() == "starrocks" diff --git a/cli/uv.lock b/cli/uv.lock index 199c34e55..8bb856b3a 100644 --- a/cli/uv.lock +++ b/cli/uv.lock @@ -2382,6 +2382,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/81/08/7036c080d7117f28a4af526d794aab6a84463126db031b007717c1a6676e/multidict-6.7.1-py3-none-any.whl", hash = "sha256:55d97cc6dae627efa6a6e548885712d4864b81110ac76fa4e534c03819fa4a56", size = 12319, upload-time = "2026-01-26T02:46:44.004Z" }, ] +[[package]] +name = "mysql-connector-python" +version = "9.7.0" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7a/7b/bfbe1732bdc413fa29d4431e04f257bed32b0f3efe775ca2e70e9d347008/mysql_connector_python-9.7.0-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:ee90c5f44f706f012be17f03f6ad158ff96e7f2dcc077896fe4537d3d28b3cf4", size = 20265583, upload-time = "2026-04-23T07:15:43.703Z" }, + { url = "https://files.pythonhosted.org/packages/43/40/cba971fdc54522742955f12d4b019e9f3325d9a5c734abf5f012fde7cfff/mysql_connector_python-9.7.0-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:a2f371ab69d65c61136c51ad7026017400166cef3c959cab7a9fb668c7acbfba", size = 19826949, upload-time = "2026-04-23T07:15:46.443Z" }, + { url = "https://files.pythonhosted.org/packages/83/5c/724577da77cd33d056ad48d1e29149f6c123371d651c0d824f6bfd2af28f/mysql_connector_python-9.7.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:9bdfc2d4c4444cd1cc79cc6487c047b28fe2b26d0327b27eb9f5737bb553cb5c", size = 21917561, upload-time = "2026-04-23T07:15:49.077Z" }, + { url = "https://files.pythonhosted.org/packages/f3/40/f0184970f6483a4e5ffcb99028f8402f3789b885872a5779edd3fa53da44/mysql_connector_python-9.7.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:6546e0b60c275409a5add9e3308c3897fcf478d1338cd845b1664c1a8946f72f", size = 21687512, upload-time = "2026-04-23T07:15:51.614Z" }, + { url = "https://files.pythonhosted.org/packages/50/ee/0be8e060376e518897f4b3433e768ccd05bc8bb3d08c436cc2441b44ac0b/mysql_connector_python-9.7.0-cp310-cp310-win_amd64.whl", hash = "sha256:c51be697bfdfdf63bb71c5ecc51f7c6faf4aaa3d14a0136fa16e97cc37df1185", size = 17678391, upload-time = "2026-04-23T07:15:54.626Z" }, + { url = "https://files.pythonhosted.org/packages/70/fa/babe981ec8c24eece7f47dc52c5e3fe3f126bc99cc80d637b49ac2fe50a4/mysql_connector_python-9.7.0-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:b5cb8a3ba42b539f79cd13e4c8376d28506f3180f7079c9b04ea7bfd0424fb03", size = 20265659, upload-time = "2026-04-23T07:15:57.375Z" }, + { url = "https://files.pythonhosted.org/packages/a5/4b/c45b8b601b0270faf1d4384e4c7270af9abb8d95ea39425253217c3c236c/mysql_connector_python-9.7.0-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:5492d57a6a0e5127a928290737fbb91b66b46d31dac8de3e7604e550bf3b3a6e", size = 19826940, upload-time = "2026-04-23T07:16:00.156Z" }, +] + [[package]] name = "mysqlclient" version = "2.2.8" @@ -2434,6 +2448,7 @@ all = [ { name = "google-genai" }, { name = "ibis-framework", extra = ["athena", "bigquery", "clickhouse", "databricks", "duckdb", "mssql", "mysql", "postgres", "snowflake", "trino"] }, { name = "mistralai" }, + { name = "mysql-connector-python" }, { name = "notion-client" }, { name = "notion2md" }, { name = "ollama" }, @@ -2448,6 +2463,7 @@ all-databases = [ { name = "google-cloud" }, { name = "ibis-framework", extra = ["athena", "bigquery", "clickhouse", "databricks", "duckdb", "mssql", "mysql", "postgres", "snowflake", "trino"] }, { name = "redshift-connector" }, + { name = "mysql-connector-python" }, { name = "snowflake-connector-python", extra = ["secure-local-storage"] }, { name = "sshtunnel" }, ] @@ -2520,6 +2536,9 @@ snowflake = [ { name = "ibis-framework", extra = ["snowflake"] }, { name = "snowflake-connector-python", extra = ["secure-local-storage"] }, ] +starrocks = [ + { name = "mysql-connector-python" }, +] trino = [ { name = "ibis-framework", extra = ["trino"] }, ] @@ -2559,9 +2578,10 @@ requires-dist = [ { name = "ibis-framework", extras = ["trino"], marker = "extra == 'trino'", specifier = ">=9.0.0" }, { name = "jinja2", specifier = ">=3.1.0" }, { name = "mistralai", marker = "extra == 'mistral'", specifier = ">=1.11.1,<2.0.0" }, + { name = "mysql-connector-python", marker = "extra == 'starrocks'", specifier = ">=9.0.0" }, { name = "nao-core", extras = ["all-databases", "all-llms", "notion"], marker = "extra == 'all'" }, { name = "nao-core", extras = ["openai", "anthropic", "mistral", "gemini", "ollama"], marker = "extra == 'all-llms'" }, - { name = "nao-core", extras = ["postgres", "bigquery", "snowflake", "duckdb", "clickhouse", "databricks", "mysql", "mssql", "athena", "trino", "redshift", "fabric"], marker = "extra == 'all-databases'" }, + { name = "nao-core", extras = ["postgres", "bigquery", "snowflake", "duckdb", "clickhouse", "databricks", "mysql", "mssql", "athena", "trino", "redshift", "fabric", "starrocks"], marker = "extra == 'all-databases'" }, { name = "notion-client", marker = "extra == 'notion'", specifier = ">=2.7.0" }, { name = "notion2md", marker = "extra == 'notion'", specifier = ">=2.9.0" }, { name = "numpy", specifier = ">=1.26.0" }, @@ -2585,7 +2605,7 @@ requires-dist = [ { name = "sshtunnel", marker = "extra == 'redshift'", specifier = ">=0.4.0" }, { name = "uvicorn", specifier = ">=0.40.0" }, ] -provides-extras = ["postgres", "bigquery", "snowflake", "duckdb", "clickhouse", "databricks", "mysql", "mssql", "athena", "trino", "redshift", "fabric", "openai", "anthropic", "mistral", "gemini", "ollama", "notion", "all-databases", "all-llms", "all", "dev"] +provides-extras = ["postgres", "bigquery", "snowflake", "duckdb", "clickhouse", "databricks", "mysql", "mssql", "athena", "trino", "redshift", "fabric", "starrocks", "openai", "anthropic", "mistral", "gemini", "ollama", "notion", "all-databases", "all-llms", "all", "dev"] [package.metadata.requires-dev] dev = [