diff --git a/databao/agent/databases/database_adapter.py b/databao/agent/databases/database_adapter.py index 2ffd303e..a3505884 100644 --- a/databao/agent/databases/database_adapter.py +++ b/databao/agent/databases/database_adapter.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any +from typing import TYPE_CHECKING, Any from _duckdb import DuckDBPyConnection from databao_context_engine import DatasourceType @@ -11,6 +11,9 @@ DBConnectionRuntime, ) +if TYPE_CHECKING: + from sqlalchemy import Engine + class DatabaseAdapter(ABC): @classmethod @@ -36,3 +39,8 @@ def create_config_from_content(cls, content: dict[str, Any]) -> DBConnectionConf @classmethod @abstractmethod def register_in_duckdb(cls, shared_conn: DuckDBPyConnection, config: DBConnectionConfig, name: str) -> None: ... + + @classmethod + def create_sqlalchemy_engine(cls, config: DBConnectionConfig) -> "Engine | None": + """Create a SQLAlchemy engine from a connection config, or return None if not supported.""" + return None diff --git a/databao/agent/databases/databases.py b/databao/agent/databases/databases.py index 02477f0a..b0de7fbc 100644 --- a/databao/agent/databases/databases.py +++ b/databao/agent/databases/databases.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import TYPE_CHECKING, Any from _duckdb import DuckDBPyConnection from databao_context_engine.pluginlib.build_plugin import AbstractConfigFile, DatasourceType @@ -15,6 +15,9 @@ from databao.agent.databases.snowflake_adapter import SnowflakeAdapter from databao.agent.databases.sqlite_adapter import SQLiteAdapter +if TYPE_CHECKING: + from sqlalchemy import Engine + DATABASE_ADAPTERS: list[DatabaseAdapter] = [ BigQueryAdapter(), DuckDBAdapter(), @@ -59,3 +62,10 @@ def register_db_in_duckdb(shared_conn: DuckDBPyConnection, config: DBConnectionC adapter.register_in_duckdb(shared_conn, config, name) return raise ValueError(f"Cannot register connection for config type {type(config)} in DuckDB.") + + +def 
try_create_sqlalchemy_engine(config: DBConnectionConfig) -> "Engine | None": + for adapter in DATABASE_ADAPTERS: + if adapter.accept(config): + return adapter.create_sqlalchemy_engine(config) + return None diff --git a/databao/agent/databases/snowflake_adapter.py b/databao/agent/databases/snowflake_adapter.py index b3d15c3c..7b50c52b 100644 --- a/databao/agent/databases/snowflake_adapter.py +++ b/databao/agent/databases/snowflake_adapter.py @@ -7,12 +7,13 @@ SnowflakeConfigFile, SnowflakeConnectionProperties, SnowflakeKeyPairAuth, + SnowflakeOAuthAuth, SnowflakePasswordAuth, SnowflakeSSOAuth, ) from databao_context_engine.pluginlib.build_plugin import AbstractConfigFile from snowflake.connector.network import SNOWFLAKE_HOST_SUFFIX -from sqlalchemy import Connection, Engine, make_url +from sqlalchemy import Connection, Engine, create_engine, make_url from databao.agent.databases.database_adapter import DatabaseAdapter from databao.agent.databases.database_connection import DBConnection, DBConnectionConfig, DBConnectionRuntime @@ -44,8 +45,13 @@ PRIVATE_KEY_FILE_KEY, PRIVATE_KEY_PASSPHRASE_KEY, OKTA_URL_KEY, + TOKEN_KEY, } +# Keys injected by SQLAlchemy's Snowflake dialect that are not valid Snowflake connection properties. +# Note: "host" is also dialect-internal but handled separately because its value is used to derive the account. 
+_SQLALCHEMY_INTERNAL_KEYS = {"port", "autocommit"} + EXCLUDED_QUERY_KEYS = {*MAIN_KEYS, *AUTH_KEYS} AUTH_TYPE_KEY = "auth_type" @@ -94,6 +100,8 @@ def create_config_from_runtime(cls, run_conn: DBConnectionRuntime) -> DBConnecti content[DATABASE_KEY] = content.pop("dbname") host: str | None = content.pop("host", None) + for key in _SQLALCHEMY_INTERNAL_KEYS: + content.pop(key, None) account: str = content.get(ACCOUNT_KEY, "") if host and host.endswith(SNOWFLAKE_HOST_SUFFIX): account = host[: -len(SNOWFLAKE_HOST_SUFFIX)] @@ -113,6 +121,64 @@ def create_config_from_content(cls, content: dict[str, Any]) -> DBConnectionConf config_file = SnowflakeConfigFile.model_validate({"name": "", **content}) return config_file.connection + @classmethod + def create_sqlalchemy_engine(cls, config: DBConnectionConfig) -> Engine | None: + if not isinstance(config, SnowflakeConnectionProperties): + return None + + from snowflake.sqlalchemy import URL # type: ignore[import-untyped] + + url_kwargs: dict[str, str] = {"account": config.account} + if config.user: + url_kwargs["user"] = config.user + if config.database: + url_kwargs["database"] = config.database + if config.warehouse: + url_kwargs["warehouse"] = config.warehouse + if config.role: + url_kwargs["role"] = config.role + + connect_args: dict[str, Any] = {k: v for k, v in config.additional_properties.items()} + auth = config.auth + if isinstance(auth, SnowflakePasswordAuth): + url_kwargs["password"] = auth.password + elif isinstance(auth, SnowflakeKeyPairAuth): + connect_args["private_key"] = cls._load_private_key_bytes(auth) + elif isinstance(auth, SnowflakeOAuthAuth): + connect_args["authenticator"] = "oauth" + connect_args["token"] = auth.token + elif isinstance(auth, SnowflakeSSOAuth): + url_kwargs["authenticator"] = auth.authenticator + else: + return None + + if connect_args: + return create_engine(URL(**url_kwargs), connect_args=connect_args) + return create_engine(URL(**url_kwargs)) + + @staticmethod + def 
_load_private_key_bytes(auth: SnowflakeKeyPairAuth) -> bytes: + from cryptography.hazmat.primitives import serialization + + if auth.private_key: + pem_data = auth.private_key.encode() + elif auth.private_key_file: + try: + pem_data = Path(auth.private_key_file).read_bytes() + except OSError as exc: + raise ValueError(f"Failed to read private key file at '{auth.private_key_file}'.") from exc + else: + raise ValueError("No private key provided.") + + passphrase = auth.private_key_file_pwd.encode() if auth.private_key_file_pwd else None + private_key = serialization.load_pem_private_key(pem_data, password=passphrase) + return private_key.private_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + + # TODO: url and name should be escaped properly @classmethod def register_in_duckdb(cls, shared_conn: DuckDBPyConnection, config: DBConnectionConfig, name: str) -> None: if not isinstance(config, SnowflakeConnectionProperties): @@ -164,6 +230,9 @@ def _create_secret_params(config: SnowflakeConnectionProperties) -> dict[str, st raise ValueError("No private key provided.") if auth.private_key_file_pwd: params[PRIVATE_KEY_PASSPHRASE_KEY] = auth.private_key_file_pwd + elif isinstance(auth, SnowflakeOAuthAuth): + params[AUTH_TYPE_KEY] = AUTH_TYPE_OAUTH + params[TOKEN_KEY] = auth.token elif isinstance(auth, SnowflakeSSOAuth): authenticator = auth.authenticator if SnowflakeAdapter._is_okta_url(authenticator): @@ -177,7 +246,9 @@ def _create_secret_params(config: SnowflakeConnectionProperties) -> dict[str, st return params @staticmethod - def _create_auth(content: dict[str, Any]) -> SnowflakePasswordAuth | SnowflakeKeyPairAuth | SnowflakeSSOAuth: + def _create_auth( + content: dict[str, Any], + ) -> SnowflakePasswordAuth | SnowflakeKeyPairAuth | SnowflakeSSOAuth | SnowflakeOAuthAuth: if PASSWORD_KEY in content: return SnowflakePasswordAuth(password=content[PASSWORD_KEY]) if 
content.keys() & {PRIVATE_KEY_KEY, PRIVATE_KEY_FILE_KEY}: @@ -187,7 +258,7 @@ def _create_auth(content: dict[str, Any]) -> SnowflakePasswordAuth | SnowflakeKe private_key=content.get(PRIVATE_KEY_KEY), ) if TOKEN_KEY in content: - return SnowflakeSSOAuth(authenticator=AUTH_TYPE_OAUTH) + return SnowflakeOAuthAuth(token=content[TOKEN_KEY]) if OKTA_URL_KEY in content: return SnowflakeSSOAuth(authenticator=content[OKTA_URL_KEY]) raise ValueError("Unsupported Snowflake authentication type.") diff --git a/databao/agent/executors/separate/__init__.py b/databao/agent/executors/separate/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/databao/agent/executors/separate/graph.py b/databao/agent/executors/separate/graph.py new file mode 100644 index 00000000..c0819069 --- /dev/null +++ b/databao/agent/executors/separate/graph.py @@ -0,0 +1,349 @@ +import json +from typing import Annotated, Any, Literal + +import pandas as pd +from langchain_core.messages import AIMessage, BaseMessage, ToolMessage +from langchain_core.tools import BaseTool, tool +from langgraph.constants import END, START +from langgraph.graph import add_messages +from langgraph.graph.state import CompiledStateGraph, StateGraph +from langgraph.prebuilt import InjectedState +from sqlalchemy import Engine, text +from typing_extensions import TypedDict + +from databao.agent.configs import llm +from databao.agent.configs.agent import AgentConfig +from databao.agent.configs.llm import LLMConfig +from databao.agent.core import Domain, ExecutionResult +from databao.agent.executors.frontend.text_frontend import dataframe_to_markdown +from databao.agent.executors.langchain_tools import make_search_context_tool +from databao.agent.executors.llm import chat, model_bind_tools +from databao.agent.executors.utils import exception_to_string, trim_dataframe_values + + +class AgentState(TypedDict): + messages: Annotated[list[BaseMessage], add_messages] + query_ids: dict[str, ToolMessage] + sql: str | None + 
# Leading keywords that mark a statement as a write / DDL / transaction-control /
# stage operation. Compared case-insensitively against the first keyword of the query.
_FORBIDDEN_SQL_PREFIXES = (
    "INSERT",
    "UPDATE",
    "DELETE",
    "DROP",
    "ALTER",
    "CREATE",
    "TRUNCATE",
    "MERGE",
    "GRANT",
    "REVOKE",
    "CALL",
    "EXEC",
    "EXECUTE",
    "BEGIN",
    "COMMIT",
    "ROLLBACK",
    "COPY",
    "PUT",
    "GET",
    "REMOVE",
)


def _validate_read_only(sql: str) -> None:
    """Raise ValueError if *sql* looks like a write / DDL statement.

    The leading keyword of the statement is compared (case-insensitively) with
    ``_FORBIDDEN_SQL_PREFIXES``. Before the keyword is read, any mix of
    whitespace, opening parentheses, empty statements (``;``) and SQL comments
    (``-- ...``, ``// ...``, ``/* ... */``) is skipped, so a forbidden statement
    cannot be hidden behind a leading comment or nested parentheses.

    This is a best-effort guard, not a security boundary: only the first
    keyword is inspected, so later statements or data-modifying CTEs are not
    detected. The database credentials should be read-only as well.

    Args:
        sql: SQL text to check. Empty or comment-only text is accepted.

    Raises:
        ValueError: If the first keyword is one of the forbidden prefixes.
    """
    pos = 0
    end = len(sql)

    # Skip everything that may legally precede the first keyword.
    while pos < end:
        char = sql[pos]
        if char.isspace() or char in "(;":
            pos += 1
        elif sql.startswith(("--", "//"), pos):
            # Line comment: runs up to (and including) the next newline.
            newline = sql.find("\n", pos)
            pos = end if newline == -1 else newline + 1
        elif sql.startswith("/*", pos):
            # Block comment: an unterminated comment swallows the rest of the text.
            close = sql.find("*/", pos + 2)
            pos = end if close == -1 else close + 2
        else:
            break

    # Read the keyword as an identifier run so that "DROP/**/TABLE" or "EXECUTE(...)"
    # still yield "DROP" / "EXECUTE", while "GET_X" is not mistaken for "GET".
    start = pos
    while pos < end and (sql[pos].isalnum() or sql[pos] == "_"):
        pos += 1
    first_word = sql[start:pos].upper()

    if first_word in _FORBIDDEN_SQL_PREFIXES:
        raise ValueError(f"Only SELECT / read-only queries are allowed. Got statement starting with '{first_word}'.")
+ """ + + MAX_TOOL_ROWS = 12 + MAX_DF_CELL_CHARS = 1024 + + def __init__(self, connections: dict[str, Engine]): + self._connections = connections + + def init_state(self, messages: list[BaseMessage], *, limit_max_rows: int | None = None) -> AgentState: + return AgentState( + messages=messages, + query_ids=get_query_ids_mapping(messages), + sql=None, + df=None, + visualization_prompt=None, + ready_for_user=False, + limit_max_rows=limit_max_rows, + ) + + def get_result(self, state: AgentState) -> ExecutionResult: + last_ai_message = None + for m in reversed(state["messages"]): + if isinstance(m, AIMessage): + last_ai_message = m + break + if last_ai_message is None: + raise RuntimeError("No AI message found in message log") + if len(last_ai_message.tool_calls) == 0: + result = ExecutionResult( + text=last_ai_message.text, + df=state.get("df"), + code=state.get("sql", ""), + meta={ + "visualization_prompt": state.get("visualization_prompt"), + ExecutionResult.META_MESSAGES_KEY: state["messages"], + "submit_called": False, + }, + ) + elif len(last_ai_message.tool_calls) > 1: + raise RuntimeError("Expected exactly one tool call in AI message") + elif last_ai_message.tool_calls[0]["name"] != "submit_result": + raise RuntimeError( + f"Expected submit_result tool call in AI message, got {last_ai_message.tool_calls[0]['name']}" + ) + else: + tool_call = last_ai_message.tool_calls[0] + result = ExecutionResult( + text=tool_call["args"]["result_description"], + df=state.get("df"), + code=state.get("sql", ""), + meta={ + "visualization_prompt": state.get("visualization_prompt", ""), + ExecutionResult.META_MESSAGES_KEY: state["messages"], + "submit_called": True, + }, + ) + return result + + def has_search_context_tool(self, domain: Domain) -> bool: + return make_search_context_tool(domain) is not None + + def make_tools(self, domain: Domain, extra_tools: list[BaseTool] | None = None) -> list[BaseTool]: + @tool(parse_docstring=True) + def run_sql_query( + sql: str, datasource: 
str, graph_state: Annotated[AgentState, InjectedState] + ) -> dict[str, Any]: + """ + Run a SELECT SQL query against a specific datasource. Returns the first 12 rows in csv format. + + Args: + sql: SQL query to execute + datasource: Name of the datasource to run the query against + """ + try: + if datasource not in self._connections: + available = sorted(self._connections.keys()) + return {"error": f"Unknown datasource '{datasource}'. Available: {available}"} + + limit = graph_state["limit_max_rows"] + df = _run_sql(self._connections[datasource], sql, limit) + + df_display = df.head(self.MAX_TOOL_ROWS) + df_display = trim_dataframe_values(df_display, max_cell_chars=self.MAX_DF_CELL_CHARS) + + df_csv = df_display.to_csv(index=False) + df_markdown = dataframe_to_markdown(df_display, index=False) + if len(df) > self.MAX_TOOL_ROWS: + df_csv += f"\nResult is truncated from {len(df)} to {self.MAX_TOOL_ROWS} rows." + df_markdown += f"\nResult is truncated from {len(df)} to {self.MAX_TOOL_ROWS} rows." + return {"df": df, "sql": sql, "csv": df_csv, "markdown": df_markdown} + except Exception as e: + return {"error": exception_to_string(e)} + + @tool(parse_docstring=True) + def submit_result( + query_id: str, + result_description: str, + visualization_prompt: str, + ) -> str: + """ + Call this tool with the ID of the query you want to submit to the user. + This will return control to the user and must always be the last tool call. + The user will see the query result up to the configured maximum row limit (which may be larger than the + 12-row preview shown in tool output). Returns a confirmation message. + + Args: + query_id: The ID of the query to submit (query_ids are automatically generated when you run queries). + result_description: A comment to a final result. This will be included in the final result. + visualization_prompt: Optional visualization prompt. 
If not empty, a Vega-Lite visualization agent + will be asked to plot the submitted query data according to instructions in the prompt. + The instructions should be short and simple. + """ + return f"Query {query_id} submitted successfully. Your response is now visible to the user." + + tools: list[BaseTool] = [run_sql_query, submit_result] + search_context_tool = make_search_context_tool(domain) + if search_context_tool is not None: + tools.append(search_context_tool) + if extra_tools: + tools.extend(extra_tools) + + return tools + + def compile( + self, + model_config: LLMConfig, + agent_config: AgentConfig, + domain: Domain, + extra_tools: list[BaseTool] | None = None, + ) -> CompiledStateGraph[Any]: + tools = self.make_tools(domain, extra_tools=extra_tools) + llm_model = model_config.new_chat_model() + + if llm.is_openai_model(model_config.name): + model_with_tools = model_bind_tools(llm_model, tools, parallel_tool_calls=agent_config.parallel_tool_calls) + else: + model_with_tools = model_bind_tools(llm_model, tools) + + def llm_node(state: AgentState) -> dict[str, Any]: + response = chat(state["messages"], model_config, model_with_tools) + return {"messages": [response[-1]]} + + def tool_executor_node(state: AgentState) -> dict[str, Any]: + last_message = state["messages"][-1] + assert isinstance(last_message, AIMessage) + tool_calls = last_message.tool_calls + tool_messages = [] + + is_ready_for_user = any(tc["name"] == "submit_result" for tc in tool_calls) + if is_ready_for_user: + if len(tool_calls) > 1: + return { + "messages": [ + ToolMessage("submit_result must be the only tool call.", tool_call_id=tc["id"]) + for tc in tool_calls + ], + "ready_for_user": False, + } + tool_call = tool_calls[0] + if "query_ids" not in state or len(state["query_ids"]) == 0: + return { + "messages": [ToolMessage("No queries have been executed yet.", tool_call_id=tool_call["id"])], + "ready_for_user": False, + } + query_id = tool_call["args"]["query_id"] + if query_id not in 
state["query_ids"]: + available_ids = ", ".join(state["query_ids"].keys()) + return { + "messages": [ + ToolMessage( + f"Query ID {query_id} not found. Available query IDs: {available_ids}", + tool_call_id=tool_call["id"], + ) + ], + "ready_for_user": False, + } + target = state["query_ids"][query_id] + if target.artifact is None or "df" not in target.artifact: + return { + "messages": [ + ToolMessage(f"Query {query_id} does not have a valid result.", tool_call_id=tool_call["id"]) + ], + "ready_for_user": False, + } + + query_ids = dict(state.get("query_ids", {})) + sql = state.get("sql") + df = state.get("df") + visualization_prompt = state.get("visualization_prompt", "") + message_index = len(state["messages"]) - 1 + + for idx, tool_call in enumerate(tool_calls): + name = tool_call["name"] + args = tool_call["args"] + tool_call_id = tool_call["id"] + t = next((t for t in tools if t.name == name), None) + if t is None: + tool_messages.append(ToolMessage(content=f"Tool {name} does not exist!", tool_call_id=tool_call_id)) + continue + + try: + result = t.invoke(args | {"graph_state": state}) + except Exception as e: + result = {"error": exception_to_string(e) + f"\nTool: {name}, Args: {args}"} + + content = "" + if name == "run_sql_query": + sql = result.get("sql") + df = result.get("df") + query_id = f"{message_index}-{idx}" + result["query_id"] = query_id + content = result.get("csv", result.get("error", "")) + if "csv" in result: + content = f"query_id='{query_id}'\n\n{content}" + query_ids[query_id] = ToolMessage(content=content, tool_call_id=tool_call_id, artifact=result) + elif name == "submit_result": + content = str(result) + query_id = tool_call["args"]["query_id"] + visualization_prompt = tool_call["args"].get("visualization_prompt", "") + sql = state["query_ids"][query_id].artifact["sql"] + df = state["query_ids"][query_id].artifact["df"] + else: + content = ( + json.dumps(result, ensure_ascii=False, default=str) if isinstance(result, dict) else 
str(result) + ) + + tool_messages.append(ToolMessage(content=content, tool_call_id=tool_call_id, artifact=result)) + if name == "submit_result": + return { + "messages": tool_messages, + "sql": sql, + "df": df, + "visualization_prompt": visualization_prompt, + "ready_for_user": True, + } + + return { + "messages": tool_messages, + "query_ids": query_ids, + "sql": sql, + "df": df, + "visualization_prompt": visualization_prompt, + "ready_for_user": False, + } + + def should_continue(state: AgentState) -> Literal["tool_executor", "end"]: + last_message = state["messages"][-1] + if isinstance(last_message, AIMessage) and last_message.tool_calls: + return "tool_executor" + return "end" + + def should_finish(state: AgentState) -> Literal["llm_node", "end"]: + return "end" if state.get("ready_for_user", False) else "llm_node" + + graph = StateGraph(AgentState) + graph.add_node("llm_node", llm_node) + graph.add_node("tool_executor", tool_executor_node) + graph.add_edge(START, "llm_node") + graph.add_conditional_edges("llm_node", should_continue, {"tool_executor": "tool_executor", "end": END}) + graph.add_conditional_edges("tool_executor", should_finish, {"llm_node": "llm_node", "end": END}) + return graph.compile() diff --git a/databao/agent/executors/separate/separate_executor.py b/databao/agent/executors/separate/separate_executor.py new file mode 100644 index 00000000..46d54102 --- /dev/null +++ b/databao/agent/executors/separate/separate_executor.py @@ -0,0 +1,170 @@ +import logging +from dataclasses import replace +from typing import Any, TextIO, cast + +from langchain_core.tools import BaseTool +from langgraph.graph.state import CompiledStateGraph +from sqlalchemy import Engine + +from databao.agent.configs import LLMConfig +from databao.agent.configs.agent import AgentConfig +from databao.agent.core import Cache, Domain, ExecutionResult, Opa +from databao.agent.core.domain import _Domain +from databao.agent.databases.databases import db_type as get_db_type +from 
databao.agent.databases.databases import try_create_sqlalchemy_engine +from databao.agent.duckdb.schema_inspection import ( + TableInfo, + summarize_duckdb_schema, + summarize_duckdb_schema_overview, +) +from databao.agent.executors.base import GraphExecutor +from databao.agent.executors.prompt import build_context_text, get_today_date_str, load_prompt_template +from databao.agent.executors.separate.graph import SeparateGraph +from databao.agent.sqlalchemy.schema_inspection import inspect_sqlalchemy_schema + +_LOGGER = logging.getLogger(__name__) + + +class SeparateExecutor(GraphExecutor): + """Executor that works directly with each database via its own SQLAlchemy connection. + + SQL queries are routed to the appropriate engine by the ``datasource`` argument of + the ``run_sql_query`` tool. + """ + + def __init__(self, writer: Any = None) -> None: + super().__init__(writer=writer) + self._sa_engines: dict[str, Engine] = {} + self._prompt_template = load_prompt_template("databao.agent.executors.separate", "system_prompt.jinja") + self._graph: SeparateGraph = SeparateGraph(self._sa_engines) + + self._max_columns_per_table: int | None = None + self._max_schema_summary_length: int | None = 250_000 # 1 token ~= 4 characters + + def _init_sources_from_domain(self, domain: Domain, *, register_in_duckdb: bool = True) -> None: + """Register domain sources. + + DB sources are connected via SQLAlchemy engines stored in ``_sa_engines``. + The ``register_in_duckdb`` parameter is accepted for API compatibility but ignored. 
+ """ + if not isinstance(domain, _Domain): + return + sources = domain.sources + + for name, db_source in sources.dbs.items(): + if name not in self._registered_dbs: + engine = try_create_sqlalchemy_engine(db_source.config) + if engine is not None: + self._sa_engines[name] = engine + else: + db_type = get_db_type(db_source.config) + _LOGGER.warning( + "SQLAlchemy engine creation not implemented for database '%s' (type '%s'); " + "continuing without SQLAlchemy engine", + name, + db_type, + ) + self._registered_dbs[name] = db_source + + for name, df_source in sources.dfs.items(): + if name not in self._registered_dfs: + self._registered_dfs[name] = df_source + + for name, dbt_source in sources.dbts.items(): + if name not in self._registered_dbts: + self._registered_dbts[name] = dbt_source + + def _inspect_database_schema(self) -> str: + tables: list[TableInfo] = [] + + for name, _db_source in self._registered_dbs.items(): + engine = self._sa_engines.get(name) + if engine is None: + continue + try: + db_tables = inspect_sqlalchemy_schema(engine) + # Use the registered name as table_catalog so the LLM can derive + # the datasource argument directly from the schema prefix. 
+ tables.extend(replace(t, table_catalog=name, columns_catalog=name) for t in db_tables) + except Exception as e: + _LOGGER.warning("Failed to inspect schema for '%s': %s", name, e) + + db_schema = _summarize(tables, self._max_columns_per_table) + if self._max_schema_summary_length is None: + return db_schema + + if len(db_schema) > self._max_schema_summary_length: + db_schema = _summarize(tables, 0) + + if len(db_schema) > self._max_schema_summary_length: + db_schema = _summarize_overview(tables) + + return db_schema + + def render_system_prompt(self, domain: Domain, recursion_limit: int = 50) -> str: + domain = cast(_Domain, domain) + + db_types = {name: get_db_type(src.config).full_type for name, src in domain.sources.dbs.items()} + db_schema = self._inspect_database_schema() + + sources = domain.sources + context_text = build_context_text(sources, df_label_fn=lambda name: f"DF {name}") + + dce_search_enabled = self._graph.has_search_context_tool(domain) + + prompt = self._prompt_template.render( + date=get_today_date_str(), + db_schema=db_schema, + context=context_text, + tool_limit=recursion_limit // 2, + db_types=db_types, + dce_search_enabled=dce_search_enabled, + ) + return prompt.strip() + + def _compile_graph( + self, llm_config: LLMConfig, agent_config: AgentConfig, domain: Domain, extra_tools: list[BaseTool] | None + ) -> CompiledStateGraph[Any]: + return self._graph.compile(llm_config, agent_config, domain, extra_tools=extra_tools) + + def execute( + self, + opas: list[Opa], + cache: Cache, + llm_config: LLMConfig, + agent_config: AgentConfig, + domain: Domain, + *, + rows_limit: int = 100, + stream: bool = True, + writer: TextIO | None = None, + ) -> ExecutionResult: + self._init_sources_from_domain(domain) + system_prompt = self.render_system_prompt(domain, agent_config.recursion_limit) + init_state = self._graph.init_state([], limit_max_rows=rows_limit) + + execution_result, _ = self._execute_core( + opas, + cache, + llm_config, + agent_config, + 
domain, + system_prompt=system_prompt, + init_state=init_state, + get_result=self._graph.get_result, + stream=stream, + writer=writer, + ) + return execution_result + + +def _summarize(tables: list[TableInfo], max_cols_per_table: int | None = None) -> str: + if not tables: + return "(no tables found)" + return summarize_duckdb_schema(tables, max_cols_per_table=max_cols_per_table, include_original_catalog_name=False) + + +def _summarize_overview(tables: list[TableInfo]) -> str: + if not tables: + return "(no tables found)" + return summarize_duckdb_schema_overview(tables, include_original_catalog_name=False) diff --git a/databao/agent/executors/separate/system_prompt.jinja b/databao/agent/executors/separate/system_prompt.jinja new file mode 100644 index 00000000..caaf1638 --- /dev/null +++ b/databao/agent/executors/separate/system_prompt.jinja @@ -0,0 +1,75 @@ +You are a "Databao" agent that has direct access to one or more databases. +You generate SQL requests, which are executed directly on each database's native connection with no changes. +The task is to request all necessary data and answer the user question. +You can answer with +- text (using plain text with no tool or result_description parameter of submit_result tool) +- a table (using SQL requests and query_id parameter of submit_result tool). It will be visible as a DataFrame. +- a plot (using visualization_prompt parameter of submit_result tool) +or a combination of these. + +Today's date is: {{ date }} (YYYY-MM-DD). + +# Instructions: +- Solve complex requests step by step + - Briefly describe each step before running the query and explain why you are doing it. + - If several similar tables or columns can be used, try both options, determine root cause of the difference in results and choose the best one. + - You can compare approaches by analyzing examples, which are filtered by one approach, but not by another. Probably some missing or corrupted data is causing the difference. 
It can help to find the most robust approach. +- Get DB schema in the 'Database schema' section. Don't waste a tool call on it. +- Pay attention to SQL dialect specific commands +- Cross joins are allowed only for tables that are guaranteed small (< 5 rows), such as enums or static dictionaries. +- When calculating percentages like (a - b) / a * 100, you must perform the multiplication first to avoid rounding errors. Use 100 * (a - b) / a. +- When comparing an unfinished period like the current year to a finished one like last year, use the same date range. Never compare unfinished periods to finished ones. +- Make sure the submitted result answers the user's question and it is not empty + - Result description of submitted result should contain definitions being used, important decisions and analysis of resulting data + - Leave visualization prompt empty if you don't want to visualize the result. Table with few values or table with heterogeneous data don't need visualization + - Time series require visualization +- The user will see only the submitted result - final SQL and DataFrame. The user will not see intermediate results +- Use fewer than {{ tool_limit }} tool calls before submitting the result +{% if dce_search_enabled -%} +- Remember to use the search_context tool to find relevant context (e.g., table and column descriptions) before writing your final SQL query. +{% endif %} + +# Database schema + +Each database is identified by a **datasource name** (e.g. `mydb`). Use this name as the `datasource` argument +when calling `run_sql_query`. Table names in the schema are prefixed with the datasource name: +`<datasource>.<schema>.<table>`. Write SQL using `<schema>.<table>
` (without the datasource prefix). + +{% if db_types -%} +## Available datasources + +{% for name, db_type in db_types.items() -%} + - `{{ name }}` ({{ db_type }}) +{% endfor %} + +{% if "snowflake" in db_types.values() -%} +### Snowflake identifier quoting rules + +**1. Unquoted Identifiers (Default Behavior)** +- **Resolution:** Snowflake converts all unquoted identifiers to **UPPERCASE**. +- *Example:* `select column_a from table_b` is resolved as `SELECT COLUMN_A FROM TABLE_B`. +- **Instruction:** Always quote identifiers exactly as they are provided to you in the Database schema section and according to the rules in the next section. + +**2. Double-Quoted Identifiers (`" "`)** +- **Resolution:** Snowflake treats content inside double quotes as **Case-Sensitive** and preserves it exactly as written. +- *Example:* `SELECT "Column_A" FROM "Table_b"` looks for exactly `Column_A` inside `Table_b`. +- **Mandatory Usage:** You **must** double-quote identifiers if they: + - Contain any **lowercase** characters (and the schema relies on mixed/lowercase). + - Contain **spaces** or **non-alphanumeric** characters (e.g., `.`, `-`, `@`, `%`). + - Start with a **digit**. + - Match a **reserved keyword** (e.g., `"GROUP"`, `"order"`). + +**3. Escaping Quotes** +- **Syntax:** If the identifier name itself contains a double quote, escape it by using **two double quotes**. 
+- *Example:* To query a table named `Client"Data`, write: `SELECT * FROM "Client""Data"` +{% endif %} +{% endif %} + +## Full Database schema +{{ db_schema }} + + +{% if context -%} +# Context +{{ context }} +{% endif %} diff --git a/databao/agent/sqlalchemy/__init__.py b/databao/agent/sqlalchemy/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/databao/agent/sqlalchemy/schema_inspection.py b/databao/agent/sqlalchemy/schema_inspection.py new file mode 100644 index 00000000..1b58090f --- /dev/null +++ b/databao/agent/sqlalchemy/schema_inspection.py @@ -0,0 +1,67 @@ +from collections import defaultdict + +from sqlalchemy import Connection, Engine, text + +from databao.agent.duckdb.schema_inspection import ColumnInfo, TableInfo + + +def inspect_sqlalchemy_schema(conn: Engine | Connection) -> list[TableInfo]: + """Inspect and return structured schema information from a SQLAlchemy connection.""" + if isinstance(conn, Engine): + with conn.connect() as connection: + return _inspect(connection) + return _inspect(conn) + + +def _inspect(conn: Connection) -> list[TableInfo]: + dialect = conn.engine.dialect.name + if dialect.startswith("snowflake"): + return _inspect_snowflake(conn) + raise NotImplementedError(f"SQLAlchemy schema inspection not supported for dialect: {dialect!r}") + + +def _inspect_snowflake(conn: Connection) -> list[TableInfo]: + table_rows = conn.execute( + text(""" + SELECT table_catalog, table_schema, table_name + FROM information_schema.tables + WHERE table_type IN ('BASE TABLE', 'VIEW') + AND table_schema != 'INFORMATION_SCHEMA' + ORDER BY table_catalog, table_schema, table_name + """) + ).fetchall() + + if not table_rows: + return [] + + valid_tables: set[tuple[str, str, str]] = {(r[0], r[1], r[2]) for r in table_rows} + + col_rows = conn.execute( + text(""" + SELECT table_catalog, table_schema, table_name, column_name, data_type + FROM information_schema.columns + WHERE table_schema != 'INFORMATION_SCHEMA' + ORDER BY 
table_catalog, table_schema, table_name, ordinal_position + """) + ).fetchall() + + col_map: dict[tuple[str, str, str], list[ColumnInfo]] = defaultdict(list) + for catalog, schema, table, col_name, data_type in col_rows: + key = (catalog, schema, table) + if key in valid_tables: + col_map[key].append(ColumnInfo(name=col_name, data_type=data_type)) + + result: list[TableInfo] = [] + for catalog, schema, table in table_rows: + key = (catalog, schema, table) + result.append( + TableInfo( + table_catalog=catalog, + columns_catalog=catalog, + schema=schema, + name=table, + columns=col_map[key], + ) + ) + + return result diff --git a/examples/snowflake-oauth.py b/examples/snowflake-oauth.py new file mode 100644 index 00000000..24f7412d --- /dev/null +++ b/examples/snowflake-oauth.py @@ -0,0 +1,40 @@ +import os +from typing import NoReturn + +from databao_context_engine import SnowflakeConnectionProperties, SnowflakeOAuthAuth + +import databao.agent as bao +from databao.agent.executors.separate.separate_executor import SeparateExecutor + + +def fail(message: str) -> NoReturn: + raise RuntimeError(message) + + +def from_env(key: str) -> str: + return os.getenv(key) or fail(f"{key} is not set") + + +def main() -> None: + domain = bao.domain() + domain.add_db( + SnowflakeConnectionProperties( + user=from_env("SNOWFLAKE_USER"), + account=from_env("SNOWFLAKE_ACCOUNT"), + database="CALIFORNIA_TRAFFIC_COLLISION", + auth=SnowflakeOAuthAuth(token=from_env("SNOWFLAKE_OAUTH_TOKEN")), + ) + ) + + agent = bao.agent( + domain=domain, + data_executor=SeparateExecutor(), + name="my_agent", + llm_config=bao.LLMConfig(name="gpt-5.1", temperature=0), + ) + + agent.thread().ask("How many accidents occurred in total?") + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index f140c80b..3a6dd5fc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,7 @@ dependencies = [ "sqlalchemy>=2.0.45", "snowflake-sqlalchemy>=1.8.2", "sqlalchemy-bigquery>=1.11.0", 
- "databao-context-engine[postgresql,mysql,snowflake]~=0.7.0", + "databao-context-engine[postgresql,mysql,snowflake]==0.7.1.dev2", "mcp>=1.0.0,<2", "watchdog>=6.0.0", "pandas-stubs~=2.3.3", diff --git a/tests/test_separate_executor.py b/tests/test_separate_executor.py new file mode 100644 index 00000000..ff44c9c9 --- /dev/null +++ b/tests/test_separate_executor.py @@ -0,0 +1,219 @@ +"""Tests for SeparateExecutor and related helpers.""" + +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest +from databao_context_engine import SnowflakeConnectionProperties, SnowflakePasswordAuth + +from databao.agent.core.data_source import DBDataSource, Sources +from databao.agent.duckdb.schema_inspection import ColumnInfo, TableInfo +from databao.agent.executors.separate.graph import _validate_read_only +from databao.agent.executors.separate.separate_executor import SeparateExecutor, _summarize + +# --------------------------------------------------------------------------- +# _validate_read_only — allowed statements +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "sql", + [ + "SELECT 1", + "select * from t", + " SELECT * FROM t", + "WITH cte AS (SELECT 1) SELECT * FROM cte", + "SHOW TABLES", + "DESCRIBE TABLE t", + "EXPLAIN SELECT 1", + "(SELECT 1 UNION SELECT 2)", + ], +) +def test_validate_read_only_allows_select(sql: str) -> None: + _validate_read_only(sql) # should not raise + + +# --------------------------------------------------------------------------- +# _validate_read_only — forbidden statements +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "sql,keyword", + [ + ("INSERT INTO t VALUES (1)", "INSERT"), + (" insert into t values (1)", "INSERT"), + ("UPDATE t SET x=1", "UPDATE"), + ("DELETE FROM t", "DELETE"), + ("DROP TABLE t", "DROP"), + ("ALTER TABLE t ADD COLUMN x INT", "ALTER"), + ("CREATE TABLE t (x INT)", 
"CREATE"), + ("TRUNCATE TABLE t", "TRUNCATE"), + ("MERGE INTO t USING s ON t.id=s.id WHEN MATCHED THEN DELETE", "MERGE"), + ("GRANT SELECT ON t TO role_x", "GRANT"), + ("REVOKE SELECT ON t FROM role_x", "REVOKE"), + ("CALL my_procedure()", "CALL"), + ("EXEC my_procedure()", "EXEC"), + ("EXECUTE my_procedure()", "EXECUTE"), + ("BEGIN TRANSACTION", "BEGIN"), + ("COMMIT", "COMMIT"), + ("ROLLBACK", "ROLLBACK"), + ("COPY INTO t FROM @stage", "COPY"), + ], +) +def test_validate_read_only_rejects_write_statements(sql: str, keyword: str) -> None: + with pytest.raises(ValueError, match=f"Got statement starting with '{keyword}'"): + _validate_read_only(sql) + + +# --------------------------------------------------------------------------- +# SeparateExecutor._init_sources_from_domain — engine creation routing +# --------------------------------------------------------------------------- + + +def _make_snowflake_config(**kwargs: Any) -> SnowflakeConnectionProperties: + defaults = dict(account="acct", user="usr", database="db", warehouse="wh") + return SnowflakeConnectionProperties(**{**defaults, **kwargs}, auth=SnowflakePasswordAuth(password="pw")) + + +def _make_domain_mock(dbs: dict[str, Any]) -> MagicMock: + """Build a mock _Domain with the given db sources.""" + from databao.agent.core.domain import _Domain + + domain = MagicMock(spec=_Domain) + db_sources: dict[str, DBDataSource] = {} + for name, config in dbs.items(): + src = MagicMock(spec=DBDataSource) + src.config = config + db_sources[name] = src + domain.sources = Sources(dbs=db_sources, dfs={}, dbts={}, additional_description=[]) + return domain + + +def test_init_sources_stores_engine_when_created() -> None: + executor = SeparateExecutor() + config = _make_snowflake_config() + domain = _make_domain_mock({"mydb": config}) + fake_engine = MagicMock() + + with patch( + "databao.agent.executors.separate.separate_executor.try_create_sqlalchemy_engine", + return_value=fake_engine, + ): + 
executor._init_sources_from_domain(domain) + + assert "mydb" in executor._sa_engines + assert executor._sa_engines["mydb"] is fake_engine + + +def test_init_sources_logs_warning_when_engine_is_none() -> None: + executor = SeparateExecutor() + config = _make_snowflake_config() + domain = _make_domain_mock({"mydb": config}) + + with ( + patch( + "databao.agent.executors.separate.separate_executor.try_create_sqlalchemy_engine", + return_value=None, + ), + patch( + "databao.agent.executors.separate.separate_executor.get_db_type", + return_value=MagicMock(full_type="snowflake"), + ), + patch("databao.agent.executors.separate.separate_executor._LOGGER") as mock_logger, + ): + executor._init_sources_from_domain(domain) + + assert "mydb" not in executor._sa_engines + assert "mydb" in executor._registered_dbs + mock_logger.warning.assert_called_once() + + +def test_init_sources_skips_already_registered_dbs() -> None: + executor = SeparateExecutor() + config = _make_snowflake_config() + domain = _make_domain_mock({"mydb": config}) + fake_engine = MagicMock() + + with patch( + "databao.agent.executors.separate.separate_executor.try_create_sqlalchemy_engine", + return_value=fake_engine, + ) as mock_create: + executor._init_sources_from_domain(domain) + executor._init_sources_from_domain(domain) + + # Should only be called once — second call skips already-registered db + assert mock_create.call_count == 1 + + +# --------------------------------------------------------------------------- +# SeparateExecutor._inspect_database_schema — error handling +# --------------------------------------------------------------------------- + + +def test_inspect_schema_returns_no_tables_when_inspection_fails() -> None: + executor = SeparateExecutor() + config = _make_snowflake_config() + + # Manually register a db and engine + src = MagicMock(spec=DBDataSource) + src.config = config + executor._registered_dbs["mydb"] = src + executor._sa_engines["mydb"] = MagicMock() + + with patch( + 
"databao.agent.executors.separate.separate_executor.inspect_sqlalchemy_schema", + side_effect=RuntimeError("connection refused"), + ): + result = executor._inspect_database_schema() + + assert result == "(no tables found)" + + +def test_inspect_schema_prefixes_tables_with_datasource_name() -> None: + executor = SeparateExecutor() + + src = MagicMock(spec=DBDataSource) + executor._registered_dbs["sales_db"] = src + executor._sa_engines["sales_db"] = MagicMock() + + table = TableInfo( + table_catalog="ORIGINAL_CAT", + columns_catalog="ORIGINAL_CAT", + schema="PUBLIC", + name="ORDERS", + columns=[ColumnInfo(name="id", data_type="NUMBER")], + ) + + with patch( + "databao.agent.executors.separate.separate_executor.inspect_sqlalchemy_schema", + return_value=[table], + ): + result = executor._inspect_database_schema() + + # The schema summary should use the datasource name, not the original catalog + assert "sales_db" in result + assert "ORDERS" in result + + +# --------------------------------------------------------------------------- +# _summarize helpers +# --------------------------------------------------------------------------- + + +def test_summarize_empty_returns_placeholder() -> None: + assert _summarize([]) == "(no tables found)" + + +def test_summarize_non_empty_contains_table_name() -> None: + tables = [ + TableInfo( + table_catalog="cat", + columns_catalog="cat", + schema="PUBLIC", + name="MY_TABLE", + columns=[ColumnInfo(name="col1", data_type="TEXT")], + ) + ] + result = _summarize(tables) + assert "MY_TABLE" in result diff --git a/tests/test_snowflake_adapter.py b/tests/test_snowflake_adapter.py index ef8ff90d..550ce8e5 100644 --- a/tests/test_snowflake_adapter.py +++ b/tests/test_snowflake_adapter.py @@ -1,11 +1,12 @@ from pathlib import Path from typing import Any -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import pytest from databao_context_engine import ( SnowflakeConnectionProperties, SnowflakeKeyPairAuth, + 
SnowflakeOAuthAuth, SnowflakePasswordAuth, SnowflakeSSOAuth, ) @@ -158,6 +159,62 @@ def test_secret_params_sso_oauth() -> None: assert params["auth_type"] == "oauth" +# --------------------------------------------------------------------------- +# _create_secret_params — OAuth token auth +# --------------------------------------------------------------------------- + + +def test_secret_params_oauth_token() -> None: + auth = SnowflakeOAuthAuth(token="eyJhbGciOi.test.token") + config = _make_config(auth) + params = SnowflakeAdapter._create_secret_params(config) + + assert params["auth_type"] == "oauth" + assert params["token"] == "eyJhbGciOi.test.token" + assert "password" not in params + + +def test_secret_params_oauth_token_with_special_chars() -> None: + auth = SnowflakeOAuthAuth(token="token'with'quotes") + config = _make_config(auth) + params = SnowflakeAdapter._create_secret_params(config) + + assert params["token"] == "token'with'quotes" + + +# --------------------------------------------------------------------------- +# _create_auth — OAuth token from content dict +# --------------------------------------------------------------------------- + + +def test_create_auth_recognizes_token() -> None: + content = {**BASE_CONFIG, "token": "my_oauth_token"} + auth = SnowflakeAdapter._create_auth(content) + + assert isinstance(auth, SnowflakeOAuthAuth) + assert auth.token == "my_oauth_token" + + +# --------------------------------------------------------------------------- +# create_config_from_content — OAuth round-trip +# --------------------------------------------------------------------------- + + +def test_create_config_from_content_oauth() -> None: + content = { + "type": "snowflake", + "connection": { + **BASE_CONFIG, + "auth": {"token": "my_oauth_token"}, + }, + } + config = SnowflakeAdapter.create_config_from_content(content) + + assert isinstance(config, SnowflakeConnectionProperties) + assert isinstance(config.auth, SnowflakeOAuthAuth) + assert 
config.auth.token == "my_oauth_token" + + # --------------------------------------------------------------------------- # _create_secret_params — values with special characters # --------------------------------------------------------------------------- @@ -292,6 +349,7 @@ def test_create_config_from_runtime_host_not_in_additional_properties() -> None: "account": "nameaccount", "host": "nameaccount.eu-central-1.snowflakecomputing.com", "port": "443", + "autocommit": False, "user": "user", "password": "secret", } @@ -299,3 +357,191 @@ def test_create_config_from_runtime_host_not_in_additional_properties() -> None: config = SnowflakeAdapter.create_config_from_runtime(engine) assert isinstance(config, SnowflakeConnectionProperties) assert "host" not in config.additional_properties + assert "port" not in config.additional_properties + assert "autocommit" not in config.additional_properties + + +# --------------------------------------------------------------------------- +# create_sqlalchemy_engine — helpers +# --------------------------------------------------------------------------- + + +def _call_create_engine(config: SnowflakeConnectionProperties) -> tuple[dict[str, str], dict[str, Any]]: + """Call create_sqlalchemy_engine with mocked URL and create_engine, returning (url_kwargs, connect_args).""" + captured_url_kwargs: dict[str, str] = {} + captured_connect_args: dict[str, Any] = {} + + def fake_url(**kwargs: str) -> str: + captured_url_kwargs.update(kwargs) + return "snowflake://fake" + + def fake_create_engine(url: Any, *, connect_args: dict[str, Any] | None = None) -> MagicMock: + if connect_args: + captured_connect_args.update(connect_args) + return MagicMock() + + with ( + patch("databao.agent.databases.snowflake_adapter.create_engine", side_effect=fake_create_engine), + patch.dict("sys.modules", {"snowflake": MagicMock(), "snowflake.sqlalchemy": MagicMock(URL=fake_url)}), + ): + result = SnowflakeAdapter.create_sqlalchemy_engine(config) + + assert result 
is not None + return captured_url_kwargs, captured_connect_args + + +# --------------------------------------------------------------------------- +# create_sqlalchemy_engine — password auth +# --------------------------------------------------------------------------- + + +def test_create_engine_password_auth() -> None: + config = _make_config(SnowflakePasswordAuth(password="s3cr3t")) + url_kwargs, connect_args = _call_create_engine(config) + + assert url_kwargs["account"] == "myaccount" + assert url_kwargs["user"] == "myuser" + assert url_kwargs["database"] == "mydb" + assert url_kwargs["warehouse"] == "mywh" + assert url_kwargs["password"] == "s3cr3t" + assert "private_key" not in connect_args + assert "token" not in connect_args + + +def test_create_engine_password_auth_with_role() -> None: + config = _make_config(SnowflakePasswordAuth(password="pw"), role="ANALYST") + url_kwargs, _ = _call_create_engine(config) + + assert url_kwargs["role"] == "ANALYST" + + +def test_create_engine_password_auth_omits_none_fields() -> None: + config = SnowflakeConnectionProperties( + account="acct", user=None, database=None, warehouse=None, auth=SnowflakePasswordAuth(password="pw") + ) + url_kwargs, _ = _call_create_engine(config) + + assert "user" not in url_kwargs + assert "database" not in url_kwargs + assert "warehouse" not in url_kwargs + + +# --------------------------------------------------------------------------- +# create_sqlalchemy_engine — key pair auth +# --------------------------------------------------------------------------- + + +def test_create_engine_key_pair_auth(tmp_path: Path) -> None: + from cryptography.hazmat.primitives import serialization + from cryptography.hazmat.primitives.asymmetric import rsa + + private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + pem = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) 
+ key_file = tmp_path / "rsa_key.pem" + key_file.write_bytes(pem) + + auth = SnowflakeKeyPairAuth(private_key_file=str(key_file)) + config = _make_config(auth) + url_kwargs, connect_args = _call_create_engine(config) + + assert "password" not in url_kwargs + assert "private_key" in connect_args + assert isinstance(connect_args["private_key"], bytes) + + +def test_create_engine_key_pair_auth_bad_file_raises() -> None: + auth = SnowflakeKeyPairAuth(private_key_file="/nonexistent/key.pem") + config = _make_config(auth) + + with pytest.raises(ValueError, match="Failed to read private key file"): + _call_create_engine(config) + + +# --------------------------------------------------------------------------- +# create_sqlalchemy_engine — OAuth auth +# --------------------------------------------------------------------------- + + +def test_create_engine_oauth_auth() -> None: + auth = SnowflakeOAuthAuth(token="eyJhbGciOi.test.token") + config = _make_config(auth) + url_kwargs, connect_args = _call_create_engine(config) + + assert "password" not in url_kwargs + assert connect_args["authenticator"] == "oauth" + assert connect_args["token"] == "eyJhbGciOi.test.token" + + +# --------------------------------------------------------------------------- +# create_sqlalchemy_engine — SSO auth +# --------------------------------------------------------------------------- + + +def test_create_engine_sso_externalbrowser() -> None: + auth = SnowflakeSSOAuth(authenticator="externalbrowser") + config = _make_config(auth) + url_kwargs, connect_args = _call_create_engine(config) + + assert url_kwargs["authenticator"] == "externalbrowser" + assert "token" not in connect_args + + +def test_create_engine_sso_okta() -> None: + auth = SnowflakeSSOAuth(authenticator="https://myorg.okta.com") + config = _make_config(auth) + url_kwargs, _ = _call_create_engine(config) + + assert url_kwargs["authenticator"] == "https://myorg.okta.com" + + +# 
--------------------------------------------------------------------------- +# create_sqlalchemy_engine — additional_properties +# --------------------------------------------------------------------------- + + +def test_create_engine_includes_additional_properties() -> None: + config = _make_config( + SnowflakePasswordAuth(password="pw"), + additional_properties={"timeout": 30, "client_session_keep_alive": True}, + ) + _, connect_args = _call_create_engine(config) + + assert connect_args["timeout"] == 30 + assert connect_args["client_session_keep_alive"] is True + + +# --------------------------------------------------------------------------- +# create_sqlalchemy_engine — unsupported config +# --------------------------------------------------------------------------- + + +def test_create_engine_returns_none_for_non_snowflake_config() -> None: + result = SnowflakeAdapter.create_sqlalchemy_engine(MagicMock()) + assert result is None + + +# --------------------------------------------------------------------------- +# create_config_from_runtime — TOKEN_KEY excluded from additional_properties +# --------------------------------------------------------------------------- + + +def test_create_config_from_runtime_excludes_token_from_additional_properties() -> None: + """TOKEN_KEY must be in EXCLUDED_QUERY_KEYS so OAuth tokens don't leak into additional_properties.""" + engine = _make_snowflake_engine( + { + "account": "acct", + "host": "acct.snowflakecomputing.com", + "user": "user", + "token": "secret-oauth-token", + } + ) + config = SnowflakeAdapter.create_config_from_runtime(engine) + assert isinstance(config, SnowflakeConnectionProperties) + assert "token" not in config.additional_properties + # The token should be captured in the auth object + assert isinstance(config.auth, SnowflakeOAuthAuth) + assert config.auth.token == "secret-oauth-token" diff --git a/uv.lock b/uv.lock index 1519f686..1b15ff77 100644 --- a/uv.lock +++ b/uv.lock @@ -793,7 +793,7 @@ 
requires-dist = [ { name = "anywidget", marker = "extra == 'examples'", specifier = ">=0.9.0" }, { name = "anywidget", marker = "extra == 'jupyter'", specifier = ">=0.9.0" }, { name = "claude-agent-sdk", specifier = ">=0.1.48" }, - { name = "databao-context-engine", extras = ["mysql", "postgresql", "snowflake"], specifier = "~=0.7.0" }, + { name = "databao-context-engine", extras = ["mysql", "postgresql", "snowflake"], specifier = "==0.7.1.dev2" }, { name = "dbt-core", specifier = "~=1.9.0" }, { name = "dbt-duckdb", specifier = ">=1.10.0" }, { name = "diskcache", specifier = ">=5.6.3" }, @@ -838,10 +838,9 @@ dev = [ [[package]] name = "databao-context-engine" -version = "0.7.0" +version = "0.7.1.dev2" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "click" }, { name = "duckdb" }, { name = "jinja2" }, { name = "mcp" }, @@ -852,9 +851,9 @@ dependencies = [ { name = "sqlparse" }, { name = "xxhash" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c4/72/040aa0c0c1fa2302bc350ae0dc18cd82dadad9bcab1b23d2bb2b0491b6a8/databao_context_engine-0.7.0.tar.gz", hash = "sha256:447e7c1cf6bbe899a125296e2362838cc7e5bd4770cb1d9bf682c9c12bac20aa", size = 124182, upload-time = "2026-03-18T09:54:43.204Z" } +sdist = { url = "https://files.pythonhosted.org/packages/ca/c9/3770c53321613b76f536c1a354f37c751d5468b659e00d8f371fda2e1177/databao_context_engine-0.7.1.dev2.tar.gz", hash = "sha256:039c6c3a1dfeb23cf92db626e5038e580528c5e0fdb77d24a0a844c95e55e66c", size = 133266, upload-time = "2026-03-30T16:35:59.884Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/27/fd/a271eac8b1486f41a4d40b89aa7212f8f1581b9c97b00e9b60b35ca5a32c/databao_context_engine-0.7.0-py3-none-any.whl", hash = "sha256:d4521c63980e87a9906e890424244e672a0bdc953d13c9e484152dd2a8612fd5", size = 195832, upload-time = "2026-03-18T09:54:42.059Z" }, + { url = 
"https://files.pythonhosted.org/packages/a3/f9/70a8c87682fc9c32dd099ab6e5e2bca8187003375f2273a06b7f86cdc364/databao_context_engine-0.7.1.dev2-py3-none-any.whl", hash = "sha256:64f272f63e2e3de06455007bc17b24878597585181bad6f41e41d53a2940f87e", size = 210998, upload-time = "2026-03-30T16:35:58.641Z" }, ] [package.optional-dependencies] @@ -1374,6 +1373,7 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1f/cb/48e964c452ca2b92175a9b2dca037a553036cb053ba69e284650ce755f13/greenlet-3.3.0-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:e29f3018580e8412d6aaf5641bb7745d38c85228dacf51a73bd4e26ddf2a6a8e", size = 274908, upload-time = "2025-12-04T14:23:26.435Z" }, { url = "https://files.pythonhosted.org/packages/28/da/38d7bff4d0277b594ec557f479d65272a893f1f2a716cad91efeb8680953/greenlet-3.3.0-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a687205fb22794e838f947e2194c0566d3812966b41c78709554aa883183fb62", size = 577113, upload-time = "2025-12-04T14:50:05.493Z" }, { url = "https://files.pythonhosted.org/packages/3c/f2/89c5eb0faddc3ff014f1c04467d67dee0d1d334ab81fadbf3744847f8a8a/greenlet-3.3.0-cp311-cp311-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:4243050a88ba61842186cb9e63c7dfa677ec146160b0efd73b855a3d9c7fcf32", size = 590338, upload-time = "2025-12-04T14:57:41.136Z" }, + { url = "https://files.pythonhosted.org/packages/80/d7/db0a5085035d05134f8c089643da2b44cc9b80647c39e93129c5ef170d8f/greenlet-3.3.0-cp311-cp311-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:670d0f94cd302d81796e37299bcd04b95d62403883b24225c6b5271466612f45", size = 601098, upload-time = "2025-12-04T15:07:11.898Z" }, { url = "https://files.pythonhosted.org/packages/dc/a6/e959a127b630a58e23529972dbc868c107f9d583b5a9f878fb858c46bc1a/greenlet-3.3.0-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6cb3a8ec3db4a3b0eb8a3c25436c2d49e3505821802074969db017b87bc6a948", size = 590206, upload-time = 
"2025-12-04T14:26:01.254Z" }, { url = "https://files.pythonhosted.org/packages/48/60/29035719feb91798693023608447283b266b12efc576ed013dd9442364bb/greenlet-3.3.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:2de5a0b09eab81fc6a382791b995b1ccf2b172a9fec934747a7a23d2ff291794", size = 1550668, upload-time = "2025-12-04T15:04:22.439Z" }, { url = "https://files.pythonhosted.org/packages/0a/5f/783a23754b691bfa86bd72c3033aa107490deac9b2ef190837b860996c9f/greenlet-3.3.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:4449a736606bd30f27f8e1ff4678ee193bc47f6ca810d705981cfffd6ce0d8c5", size = 1615483, upload-time = "2025-12-04T14:27:28.083Z" }, @@ -1381,6 +1381,7 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f8/0a/a3871375c7b9727edaeeea994bfff7c63ff7804c9829c19309ba2e058807/greenlet-3.3.0-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:b01548f6e0b9e9784a2c99c5651e5dc89ffcbe870bc5fb2e5ef864e9cc6b5dcb", size = 276379, upload-time = "2025-12-04T14:23:30.498Z" }, { url = "https://files.pythonhosted.org/packages/43/ab/7ebfe34dce8b87be0d11dae91acbf76f7b8246bf9d6b319c741f99fa59c6/greenlet-3.3.0-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:349345b770dc88f81506c6861d22a6ccd422207829d2c854ae2af8025af303e3", size = 597294, upload-time = "2025-12-04T14:50:06.847Z" }, { url = "https://files.pythonhosted.org/packages/a4/39/f1c8da50024feecd0793dbd5e08f526809b8ab5609224a2da40aad3a7641/greenlet-3.3.0-cp312-cp312-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:e8e18ed6995e9e2c0b4ed264d2cf89260ab3ac7e13555b8032b25a74c6d18655", size = 607742, upload-time = "2025-12-04T14:57:42.349Z" }, + { url = "https://files.pythonhosted.org/packages/77/cb/43692bcd5f7a0da6ec0ec6d58ee7cddb606d055ce94a62ac9b1aa481e969/greenlet-3.3.0-cp312-cp312-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:c024b1e5696626890038e34f76140ed1daf858e37496d33f2af57f06189e70d7", size = 622297, upload-time = 
"2025-12-04T15:07:13.552Z" }, { url = "https://files.pythonhosted.org/packages/75/b0/6bde0b1011a60782108c01de5913c588cf51a839174538d266de15e4bf4d/greenlet-3.3.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:047ab3df20ede6a57c35c14bf5200fcf04039d50f908270d3f9a7a82064f543b", size = 609885, upload-time = "2025-12-04T14:26:02.368Z" }, { url = "https://files.pythonhosted.org/packages/49/0e/49b46ac39f931f59f987b7cd9f34bfec8ef81d2a1e6e00682f55be5de9f4/greenlet-3.3.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:2d9ad37fc657b1102ec880e637cccf20191581f75c64087a549e66c57e1ceb53", size = 1567424, upload-time = "2025-12-04T15:04:23.757Z" }, { url = "https://files.pythonhosted.org/packages/05/f5/49a9ac2dff7f10091935def9165c90236d8f175afb27cbed38fb1d61ab6b/greenlet-3.3.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:83cd0e36932e0e7f36a64b732a6f60c2fc2df28c351bae79fbaf4f8092fe7614", size = 1636017, upload-time = "2025-12-04T14:27:29.688Z" }, @@ -1388,6 +1389,7 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/02/2f/28592176381b9ab2cafa12829ba7b472d177f3acc35d8fbcf3673d966fff/greenlet-3.3.0-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:a1e41a81c7e2825822f4e068c48cb2196002362619e2d70b148f20a831c00739", size = 275140, upload-time = "2025-12-04T14:23:01.282Z" }, { url = "https://files.pythonhosted.org/packages/2c/80/fbe937bf81e9fca98c981fe499e59a3f45df2a04da0baa5c2be0dca0d329/greenlet-3.3.0-cp313-cp313-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9f515a47d02da4d30caaa85b69474cec77b7929b2e936ff7fb853d42f4bf8808", size = 599219, upload-time = "2025-12-04T14:50:08.309Z" }, { url = "https://files.pythonhosted.org/packages/c2/ff/7c985128f0514271b8268476af89aee6866df5eec04ac17dcfbc676213df/greenlet-3.3.0-cp313-cp313-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:7d2d9fd66bfadf230b385fdc90426fcd6eb64db54b40c495b72ac0feb5766c54", size = 610211, upload-time = 
"2025-12-04T14:57:43.968Z" }, + { url = "https://files.pythonhosted.org/packages/79/07/c47a82d881319ec18a4510bb30463ed6891f2ad2c1901ed5ec23d3de351f/greenlet-3.3.0-cp313-cp313-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:30a6e28487a790417d036088b3bcb3f3ac7d8babaa7d0139edbaddebf3af9492", size = 624311, upload-time = "2025-12-04T15:07:14.697Z" }, { url = "https://files.pythonhosted.org/packages/fd/8e/424b8c6e78bd9837d14ff7df01a9829fc883ba2ab4ea787d4f848435f23f/greenlet-3.3.0-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:087ea5e004437321508a8d6f20efc4cfec5e3c30118e1417ea96ed1d93950527", size = 612833, upload-time = "2025-12-04T14:26:03.669Z" }, { url = "https://files.pythonhosted.org/packages/b5/ba/56699ff9b7c76ca12f1cdc27a886d0f81f2189c3455ff9f65246780f713d/greenlet-3.3.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:ab97cf74045343f6c60a39913fa59710e4bd26a536ce7ab2397adf8b27e67c39", size = 1567256, upload-time = "2025-12-04T15:04:25.276Z" }, { url = "https://files.pythonhosted.org/packages/1e/37/f31136132967982d698c71a281a8901daf1a8fbab935dce7c0cf15f942cc/greenlet-3.3.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5375d2e23184629112ca1ea89a53389dddbffcf417dad40125713d88eb5f96e8", size = 1636483, upload-time = "2025-12-04T14:27:30.804Z" }, @@ -1395,6 +1397,7 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d7/7c/f0a6d0ede2c7bf092d00bc83ad5bafb7e6ec9b4aab2fbdfa6f134dc73327/greenlet-3.3.0-cp314-cp314-macosx_11_0_universal2.whl", hash = "sha256:60c2ef0f578afb3c8d92ea07ad327f9a062547137afe91f38408f08aacab667f", size = 275671, upload-time = "2025-12-04T14:23:05.267Z" }, { url = "https://files.pythonhosted.org/packages/44/06/dac639ae1a50f5969d82d2e3dd9767d30d6dbdbab0e1a54010c8fe90263c/greenlet-3.3.0-cp314-cp314-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0a5d554d0712ba1de0a6c94c640f7aeba3f85b3a6e1f2899c11c2c0428da9365", size = 646360, upload-time = 
"2025-12-04T14:50:10.026Z" }, { url = "https://files.pythonhosted.org/packages/e0/94/0fb76fe6c5369fba9bf98529ada6f4c3a1adf19e406a47332245ef0eb357/greenlet-3.3.0-cp314-cp314-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:3a898b1e9c5f7307ebbde4102908e6cbfcb9ea16284a3abe15cab996bee8b9b3", size = 658160, upload-time = "2025-12-04T14:57:45.41Z" }, + { url = "https://files.pythonhosted.org/packages/93/79/d2c70cae6e823fac36c3bbc9077962105052b7ef81db2f01ec3b9bf17e2b/greenlet-3.3.0-cp314-cp314-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:dcd2bdbd444ff340e8d6bdf54d2f206ccddbb3ccfdcd3c25bf4afaa7b8f0cf45", size = 671388, upload-time = "2025-12-04T15:07:15.789Z" }, { url = "https://files.pythonhosted.org/packages/b8/14/bab308fc2c1b5228c3224ec2bf928ce2e4d21d8046c161e44a2012b5203e/greenlet-3.3.0-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5773edda4dc00e173820722711d043799d3adb4f01731f40619e07ea2750b955", size = 660166, upload-time = "2025-12-04T14:26:05.099Z" }, { url = "https://files.pythonhosted.org/packages/4b/d2/91465d39164eaa0085177f61983d80ffe746c5a1860f009811d498e7259c/greenlet-3.3.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:ac0549373982b36d5fd5d30beb8a7a33ee541ff98d2b502714a09f1169f31b55", size = 1615193, upload-time = "2025-12-04T15:04:27.041Z" }, { url = "https://files.pythonhosted.org/packages/42/1b/83d110a37044b92423084d52d5d5a3b3a73cafb51b547e6d7366ff62eff1/greenlet-3.3.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:d198d2d977460358c3b3a4dc844f875d1adb33817f0613f663a656f463764ccc", size = 1683653, upload-time = "2025-12-04T14:27:32.366Z" }, @@ -1402,6 +1405,7 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a0/66/bd6317bc5932accf351fc19f177ffba53712a202f9df10587da8df257c7e/greenlet-3.3.0-cp314-cp314t-macosx_11_0_universal2.whl", hash = "sha256:d6ed6f85fae6cdfdb9ce04c9bf7a08d666cfcfb914e7d006f44f840b46741931", size = 282638, upload-time = 
"2025-12-04T14:25:20.941Z" }, { url = "https://files.pythonhosted.org/packages/30/cf/cc81cb030b40e738d6e69502ccbd0dd1bced0588e958f9e757945de24404/greenlet-3.3.0-cp314-cp314t-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d9125050fcf24554e69c4cacb086b87b3b55dc395a8b3ebe6487b045b2614388", size = 651145, upload-time = "2025-12-04T14:50:11.039Z" }, { url = "https://files.pythonhosted.org/packages/9c/ea/1020037b5ecfe95ca7df8d8549959baceb8186031da83d5ecceff8b08cd2/greenlet-3.3.0-cp314-cp314t-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:87e63ccfa13c0a0f6234ed0add552af24cc67dd886731f2261e46e241608bee3", size = 654236, upload-time = "2025-12-04T14:57:47.007Z" }, + { url = "https://files.pythonhosted.org/packages/69/cc/1e4bae2e45ca2fa55299f4e85854606a78ecc37fead20d69322f96000504/greenlet-3.3.0-cp314-cp314t-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:2662433acbca297c9153a4023fe2161c8dcfdcc91f10433171cf7e7d94ba2221", size = 662506, upload-time = "2025-12-04T15:07:16.906Z" }, { url = "https://files.pythonhosted.org/packages/57/b9/f8025d71a6085c441a7eaff0fd928bbb275a6633773667023d19179fe815/greenlet-3.3.0-cp314-cp314t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3c6e9b9c1527a78520357de498b0e709fb9e2f49c3a513afd5a249007261911b", size = 653783, upload-time = "2025-12-04T14:26:06.225Z" }, { url = "https://files.pythonhosted.org/packages/f6/c7/876a8c7a7485d5d6b5c6821201d542ef28be645aa024cfe1145b35c120c1/greenlet-3.3.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:286d093f95ec98fdd92fcb955003b8a3d054b4e2cab3e2707a5039e7b50520fd", size = 1614857, upload-time = "2025-12-04T15:04:28.484Z" }, { url = "https://files.pythonhosted.org/packages/4f/dc/041be1dff9f23dac5f48a43323cd0789cb798342011c19a248d9c9335536/greenlet-3.3.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:6c10513330af5b8ae16f023e8ddbfb486ab355d04467c4679c5cfe4659975dd9", size = 1676034, upload-time = "2025-12-04T14:27:33.531Z" },