Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion databao/agent/databases/database_adapter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any
from typing import TYPE_CHECKING, Any

from _duckdb import DuckDBPyConnection
from databao_context_engine import DatasourceType
Expand All @@ -11,6 +11,9 @@
DBConnectionRuntime,
)

if TYPE_CHECKING:
from sqlalchemy import Engine


class DatabaseAdapter(ABC):
@classmethod
Expand All @@ -36,3 +39,8 @@ def create_config_from_content(cls, content: dict[str, Any]) -> DBConnectionConf
@classmethod
@abstractmethod
def register_in_duckdb(cls, shared_conn: DuckDBPyConnection, config: DBConnectionConfig, name: str) -> None: ...

@classmethod
def create_sqlalchemy_engine(cls, config: DBConnectionConfig) -> "Engine | None":
"""Create a SQLAlchemy engine from a connection config, or return None if not supported."""
return None
12 changes: 11 additions & 1 deletion databao/agent/databases/databases.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any
from typing import TYPE_CHECKING, Any

from _duckdb import DuckDBPyConnection
from databao_context_engine.pluginlib.build_plugin import AbstractConfigFile, DatasourceType
Expand All @@ -15,6 +15,9 @@
from databao.agent.databases.snowflake_adapter import SnowflakeAdapter
from databao.agent.databases.sqlite_adapter import SQLiteAdapter

if TYPE_CHECKING:
from sqlalchemy import Engine

DATABASE_ADAPTERS: list[DatabaseAdapter] = [
BigQueryAdapter(),
DuckDBAdapter(),
Expand Down Expand Up @@ -59,3 +62,10 @@ def register_db_in_duckdb(shared_conn: DuckDBPyConnection, config: DBConnectionC
adapter.register_in_duckdb(shared_conn, config, name)
return
raise ValueError(f"Cannot register connection for config type {type(config)} in DuckDB.")


def try_create_sqlalchemy_engine(config: DBConnectionConfig) -> "Engine | None":
for adapter in DATABASE_ADAPTERS:
if adapter.accept(config):
return adapter.create_sqlalchemy_engine(config)
return None
77 changes: 74 additions & 3 deletions databao/agent/databases/snowflake_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@
SnowflakeConfigFile,
SnowflakeConnectionProperties,
SnowflakeKeyPairAuth,
SnowflakeOAuthAuth,
SnowflakePasswordAuth,
SnowflakeSSOAuth,
)
from databao_context_engine.pluginlib.build_plugin import AbstractConfigFile
from snowflake.connector.network import SNOWFLAKE_HOST_SUFFIX
from sqlalchemy import Connection, Engine, make_url
from sqlalchemy import Connection, Engine, create_engine, make_url

from databao.agent.databases.database_adapter import DatabaseAdapter
from databao.agent.databases.database_connection import DBConnection, DBConnectionConfig, DBConnectionRuntime
Expand Down Expand Up @@ -44,8 +45,13 @@
PRIVATE_KEY_FILE_KEY,
PRIVATE_KEY_PASSPHRASE_KEY,
OKTA_URL_KEY,
TOKEN_KEY,
}

# Keys injected by SQLAlchemy's Snowflake dialect that are not valid Snowflake connection properties.
# Note: "host" is also dialect-internal but handled separately because its value is used to derive the account.
_SQLALCHEMY_INTERNAL_KEYS = {"port", "autocommit"}

EXCLUDED_QUERY_KEYS = {*MAIN_KEYS, *AUTH_KEYS}
Comment thread
SimonKaran13 marked this conversation as resolved.

AUTH_TYPE_KEY = "auth_type"
Expand Down Expand Up @@ -94,6 +100,8 @@ def create_config_from_runtime(cls, run_conn: DBConnectionRuntime) -> DBConnecti
content[DATABASE_KEY] = content.pop("dbname")

host: str | None = content.pop("host", None)
for key in _SQLALCHEMY_INTERNAL_KEYS:
content.pop(key, None)
account: str = content.get(ACCOUNT_KEY, "")
if host and host.endswith(SNOWFLAKE_HOST_SUFFIX):
account = host[: -len(SNOWFLAKE_HOST_SUFFIX)]
Expand All @@ -113,6 +121,64 @@ def create_config_from_content(cls, content: dict[str, Any]) -> DBConnectionConf
config_file = SnowflakeConfigFile.model_validate({"name": "", **content})
return config_file.connection

@classmethod
def create_sqlalchemy_engine(cls, config: DBConnectionConfig) -> Engine | None:
if not isinstance(config, SnowflakeConnectionProperties):
return None

from snowflake.sqlalchemy import URL # type: ignore[import-untyped]

url_kwargs: dict[str, str] = {"account": config.account}
if config.user:
url_kwargs["user"] = config.user
if config.database:
url_kwargs["database"] = config.database
if config.warehouse:
url_kwargs["warehouse"] = config.warehouse
if config.role:
url_kwargs["role"] = config.role

connect_args: dict[str, Any] = {k: v for k, v in config.additional_properties.items()}
auth = config.auth
if isinstance(auth, SnowflakePasswordAuth):
url_kwargs["password"] = auth.password
elif isinstance(auth, SnowflakeKeyPairAuth):
connect_args["private_key"] = cls._load_private_key_bytes(auth)
elif isinstance(auth, SnowflakeOAuthAuth):
connect_args["authenticator"] = "oauth"
connect_args["token"] = auth.token
elif isinstance(auth, SnowflakeSSOAuth):
url_kwargs["authenticator"] = auth.authenticator
else:
return None

if connect_args:
return create_engine(URL(**url_kwargs), connect_args=connect_args)
return create_engine(URL(**url_kwargs))
Comment thread
SimonKaran13 marked this conversation as resolved.

Comment thread
catstrike marked this conversation as resolved.
@staticmethod
def _load_private_key_bytes(auth: SnowflakeKeyPairAuth) -> bytes:
from cryptography.hazmat.primitives import serialization

if auth.private_key:
pem_data = auth.private_key.encode()
elif auth.private_key_file:
try:
pem_data = Path(auth.private_key_file).read_bytes()
except OSError as exc:
raise ValueError(f"Failed to read private key file at '{auth.private_key_file}'.") from exc
else:
raise ValueError("No private key provided.")

passphrase = auth.private_key_file_pwd.encode() if auth.private_key_file_pwd else None
private_key = serialization.load_pem_private_key(pem_data, password=passphrase)
return private_key.private_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
)

# TODO: url and name should be escaped properly
@classmethod
def register_in_duckdb(cls, shared_conn: DuckDBPyConnection, config: DBConnectionConfig, name: str) -> None:
if not isinstance(config, SnowflakeConnectionProperties):
Expand Down Expand Up @@ -164,6 +230,9 @@ def _create_secret_params(config: SnowflakeConnectionProperties) -> dict[str, st
raise ValueError("No private key provided.")
if auth.private_key_file_pwd:
params[PRIVATE_KEY_PASSPHRASE_KEY] = auth.private_key_file_pwd
elif isinstance(auth, SnowflakeOAuthAuth):
params[AUTH_TYPE_KEY] = AUTH_TYPE_OAUTH
params[TOKEN_KEY] = auth.token
elif isinstance(auth, SnowflakeSSOAuth):
authenticator = auth.authenticator
if SnowflakeAdapter._is_okta_url(authenticator):
Expand All @@ -177,7 +246,9 @@ def _create_secret_params(config: SnowflakeConnectionProperties) -> dict[str, st
return params

@staticmethod
def _create_auth(content: dict[str, Any]) -> SnowflakePasswordAuth | SnowflakeKeyPairAuth | SnowflakeSSOAuth:
def _create_auth(
content: dict[str, Any],
) -> SnowflakePasswordAuth | SnowflakeKeyPairAuth | SnowflakeSSOAuth | SnowflakeOAuthAuth:
if PASSWORD_KEY in content:
return SnowflakePasswordAuth(password=content[PASSWORD_KEY])
if content.keys() & {PRIVATE_KEY_KEY, PRIVATE_KEY_FILE_KEY}:
Expand All @@ -187,7 +258,7 @@ def _create_auth(content: dict[str, Any]) -> SnowflakePasswordAuth | SnowflakeKe
private_key=content.get(PRIVATE_KEY_KEY),
)
if TOKEN_KEY in content:
return SnowflakeSSOAuth(authenticator=AUTH_TYPE_OAUTH)
return SnowflakeOAuthAuth(token=content[TOKEN_KEY])
if OKTA_URL_KEY in content:
return SnowflakeSSOAuth(authenticator=content[OKTA_URL_KEY])
raise ValueError("Unsupported Snowflake authentication type.")
Expand Down
Empty file.
Loading
Loading