Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 15 additions & 43 deletions databao/agent/databases/snowflake_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,26 +113,28 @@ def create_config_from_content(cls, content: dict[str, Any]) -> DBConnectionConf
config_file = SnowflakeConfigFile.model_validate({"name": "", **content})
return config_file.connection

# TODO: url and name should be escaped properly
@classmethod
def register_in_duckdb(cls, shared_conn: DuckDBPyConnection, config: DBConnectionConfig, name: str) -> None:
if not isinstance(config, SnowflakeConnectionProperties):
raise ValueError(
f"Invalid connection config type: expected SnowflakeConnectionProperties, got {type(config)}."
)
connection_string = cls._create_connection_string(config)
# Build the secret SQL before ATTACH so that preparation errors (e.g. unreadable key file)
# Build the secret params before ATTACH so that preparation errors (e.g. unreadable key file)
# don't leave the connection in a partially-registered state.
secret_sql = cls._create_secret_sql(config, name)
secret_params = cls._create_secret_params(config)
formatted_secret_params = cls._format_sql_params(secret_params)
safe_name = cls._escape(name, '"')

shared_conn.execute("INSTALL snowflake FROM community;")
shared_conn.execute("LOAD snowflake;")
shared_conn.execute(f"ATTACH '{connection_string}' AS \"{name}\" (TYPE snowflake, READ_ONLY);")
shared_conn.execute(secret_sql)
shared_conn.execute(f'CREATE OR REPLACE SECRET "{safe_name}" (TYPE snowflake, {formatted_secret_params});')
shared_conn.execute(f'ATTACH \'\' AS "{safe_name}" (TYPE snowflake, SECRET "{safe_name}", READ_ONLY);')

@staticmethod
def _create_secret_sql(config: SnowflakeConnectionProperties, name: str) -> str:
def _create_secret_params(config: SnowflakeConnectionProperties) -> dict[str, str]:
params: dict[str, str] = {
ACCOUNT_KEY: config.account,
**{k: str(v) for k, v in config.additional_properties.items()},
}
Comment thread
catstrike marked this conversation as resolved.
if config.user:
params[USER_KEY] = config.user
Expand Down Expand Up @@ -172,11 +174,7 @@ def _create_secret_sql(config: SnowflakeConnectionProperties, name: str) -> str:
else:
raise ValueError("Unsupported Snowflake authentication type.")

def _escape(v: str) -> str:
return v.replace("'", "''")

kv = ", ".join(f"{k} '{_escape(v)}'" for k, v in params.items())
return f'CREATE OR REPLACE SECRET "{name}" (TYPE snowflake, {kv});'
return params

@staticmethod
def _create_auth(content: dict[str, Any]) -> SnowflakePasswordAuth | SnowflakeKeyPairAuth | SnowflakeSSOAuth:
Expand All @@ -195,38 +193,12 @@ def _create_auth(content: dict[str, Any]) -> SnowflakePasswordAuth | SnowflakeKe
raise ValueError("Unsupported Snowflake authentication type.")

@staticmethod
def _create_connection_string(config: SnowflakeConnectionProperties) -> str:
connection_parameters: dict[str, Any] = {
ACCOUNT_KEY: config.account,
WAREHOUSE_KEY: config.warehouse,
DATABASE_KEY: config.database,
USER_KEY: config.user,
**config.additional_properties,
}
def _escape(value: str, quote: str) -> str:
return value.replace(quote, quote + quote)

auth = config.auth
if isinstance(auth, SnowflakePasswordAuth):
connection_parameters[PASSWORD_KEY] = auth.password
elif isinstance(auth, SnowflakeKeyPairAuth):
connection_parameters[AUTH_TYPE_KEY] = AUTH_TYPE_KEY_PAIR
connection_parameters[PRIVATE_KEY_PASSPHRASE_KEY] = auth.private_key_file_pwd
if auth.private_key:
connection_parameters[PRIVATE_KEY_KEY] = auth.private_key
elif auth.private_key_file:
connection_parameters[PRIVATE_KEY_KEY] = Path(auth.private_key_file).absolute()
else:
raise ValueError("No private key provided.")
elif isinstance(auth, SnowflakeSSOAuth):
authenticator = auth.authenticator
if SnowflakeAdapter._is_okta_url(authenticator):
connection_parameters[AUTH_TYPE_KEY] = AUTH_TYPE_OKTA
connection_parameters[OKTA_URL_KEY] = authenticator
else:
connection_parameters[AUTH_TYPE_KEY] = authenticator
else:
raise ValueError("Unsupported Snowflake authentication type.")

return ";".join(f"{k}={v!s}" for k, v in connection_parameters.items() if v is not None)
@classmethod
def _format_sql_params(cls, params: dict[str, str]) -> str:
return ", ".join(f"""{k} '{cls._escape(v, "'")}'""" for k, v in params.items())

@staticmethod
def _is_okta_url(authenticator: str) -> bool:
Expand Down
25 changes: 5 additions & 20 deletions examples/snowflake-example.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
import os
from pathlib import Path
from typing import NoReturn

from sqlalchemy import create_engine, text
from sqlalchemy import create_engine

import databao.agent as bao

FILE_DIR = Path(__file__).parent


def fail(message: str) -> NoReturn:
raise RuntimeError(message)
Expand All @@ -19,29 +16,17 @@ def from_env(key: str) -> str:

def main() -> None:
engine = create_engine(
"snowflake://{user}@{account_identifier}/{database}?private_key_file={private_key_file}".format(
"snowflake://{user}@{account_identifier}/{database}?private_key_file={private_key_file}&warehouse={warehouse}".format(
user=from_env("SNOWFLAKE_USER"),
# password=from_env("SNOWFLAKE_PASSWORD"),
account_identifier=from_env("SNOWFLAKE_ACCOUNT"),
database="CALIFORNIA_TRAFFIC_COLLISION",
database=from_env("SNOWFLAKE_DATABASE"),
private_key_file=from_env("SNOWFLAKE_PRIVATE_KEY_FILE"),
warehouse=from_env("SNOWFLAKE_WAREHOUSE"),
)
)

with engine.connect() as db_connection:
result = db_connection.execute(text("select current_version();")).fetchone()

if result is None:
fail("Failed to execute query")

print(f"Snowflake version: {result[0]}")

project_dir = Path(FILE_DIR, "example-dce-project")

if not project_dir.is_dir():
project_dir.mkdir(parents=True)

domain = bao.domain(project_dir)
domain = bao.domain()
domain.add_db(engine)

agent = bao.agent(domain=domain, name="my_agent", llm_config=bao.LLMConfig(name="gpt-5.1", temperature=0))
Expand Down
Loading
Loading