diff --git a/databao/agent/databases/snowflake_adapter.py b/databao/agent/databases/snowflake_adapter.py index 14c3d898..b3d15c3c 100644 --- a/databao/agent/databases/snowflake_adapter.py +++ b/databao/agent/databases/snowflake_adapter.py @@ -113,26 +113,28 @@ def create_config_from_content(cls, content: dict[str, Any]) -> DBConnectionConf config_file = SnowflakeConfigFile.model_validate({"name": "", **content}) return config_file.connection - # TODO: url and name should be escaped properly @classmethod def register_in_duckdb(cls, shared_conn: DuckDBPyConnection, config: DBConnectionConfig, name: str) -> None: if not isinstance(config, SnowflakeConnectionProperties): raise ValueError( f"Invalid connection config type: expected SnowflakeConnectionProperties, got {type(config)}." ) - connection_string = cls._create_connection_string(config) - # Build the secret SQL before ATTACH so that preparation errors (e.g. unreadable key file) + # Build the secret params before ATTACH so that preparation errors (e.g. unreadable key file) # don't leave the connection in a partially-registered state. - secret_sql = cls._create_secret_sql(config, name) + secret_params = cls._create_secret_params(config) + formatted_secret_params = cls._format_sql_params(secret_params) + safe_name = cls._escape(name, '"') + shared_conn.execute("INSTALL snowflake FROM community;") shared_conn.execute("LOAD snowflake;") - shared_conn.execute(f"ATTACH '{connection_string}' AS \"{name}\" (TYPE snowflake, READ_ONLY);") - shared_conn.execute(secret_sql) + shared_conn.execute(f'CREATE OR REPLACE SECRET "{safe_name}" (TYPE snowflake, {formatted_secret_params});') + shared_conn.execute(f'ATTACH \'\' AS "{safe_name}" (TYPE snowflake, SECRET "{safe_name}", READ_ONLY);') @staticmethod - def _create_secret_sql(config: SnowflakeConnectionProperties, name: str) -> str: + def _create_secret_params(config: SnowflakeConnectionProperties) -> dict[str, str]: params: dict[str, str] = { ACCOUNT_KEY: config.account, + **{k: str(v) for k, v in config.additional_properties.items()}, } if config.user: params[USER_KEY] = config.user @@ -172,11 +174,7 @@ def _create_secret_sql(config: SnowflakeConnectionProperties, name: str) -> str: else: raise ValueError("Unsupported Snowflake authentication type.") - def _escape(v: str) -> str: - return v.replace("'", "''") - - kv = ", ".join(f"{k} '{_escape(v)}'" for k, v in params.items()) - return f'CREATE OR REPLACE SECRET "{name}" (TYPE snowflake, {kv});' + return params @staticmethod def _create_auth(content: dict[str, Any]) -> SnowflakePasswordAuth | SnowflakeKeyPairAuth | SnowflakeSSOAuth: @@ -195,38 +193,12 @@ def _create_auth(content: dict[str, Any]) -> SnowflakePasswordAuth | SnowflakeKe raise ValueError("Unsupported Snowflake authentication type.") @staticmethod - def _create_connection_string(config: SnowflakeConnectionProperties) -> str: - connection_parameters: dict[str, Any] = { - ACCOUNT_KEY: config.account, - WAREHOUSE_KEY: config.warehouse, - DATABASE_KEY: config.database, - USER_KEY: config.user, - **config.additional_properties, - } + def _escape(value: str, quote: str) -> str: + return value.replace(quote, quote + quote) - auth = config.auth - if isinstance(auth, SnowflakePasswordAuth): - connection_parameters[PASSWORD_KEY] = auth.password - elif isinstance(auth, SnowflakeKeyPairAuth): - connection_parameters[AUTH_TYPE_KEY] = AUTH_TYPE_KEY_PAIR - connection_parameters[PRIVATE_KEY_PASSPHRASE_KEY] = auth.private_key_file_pwd - if auth.private_key: - connection_parameters[PRIVATE_KEY_KEY] = auth.private_key - elif auth.private_key_file: - connection_parameters[PRIVATE_KEY_KEY] = Path(auth.private_key_file).absolute() - else: - raise ValueError("No private key provided.") - elif isinstance(auth, SnowflakeSSOAuth): - authenticator = auth.authenticator - if SnowflakeAdapter._is_okta_url(authenticator): - connection_parameters[AUTH_TYPE_KEY] = AUTH_TYPE_OKTA - connection_parameters[OKTA_URL_KEY] = authenticator - else: - connection_parameters[AUTH_TYPE_KEY] = authenticator - else: - raise ValueError("Unsupported Snowflake authentication type.") - - return ";".join(f"{k}={v!s}" for k, v in connection_parameters.items() if v is not None) + @classmethod + def _format_sql_params(cls, params: dict[str, str]) -> str: + return ", ".join(f"""{k} '{cls._escape(v, "'")}'""" for k, v in params.items()) @staticmethod def _is_okta_url(authenticator: str) -> bool: diff --git a/examples/snowflake-example.py b/examples/snowflake-example.py index 35d078fc..eaeb1045 100644 --- a/examples/snowflake-example.py +++ b/examples/snowflake-example.py @@ -1,13 +1,10 @@ import os -from pathlib import Path from typing import NoReturn -from sqlalchemy import create_engine, text +from sqlalchemy import create_engine import databao.agent as bao -FILE_DIR = Path(__file__).parent - def fail(message: str) -> NoReturn: raise RuntimeError(message) @@ -19,29 +16,17 @@ def from_env(key: str) -> str: def main() -> None: engine = create_engine( - "snowflake://{user}@{account_identifier}/{database}?private_key_file={private_key_file}".format( + "snowflake://{user}@{account_identifier}/{database}?private_key_file={private_key_file}&warehouse={warehouse}".format( user=from_env("SNOWFLAKE_USER"), # password=from_env("SNOWFLAKE_PASSWORD"), account_identifier=from_env("SNOWFLAKE_ACCOUNT"), - database="CALIFORNIA_TRAFFIC_COLLISION", + database=from_env("SNOWFLAKE_DATABASE"), private_key_file=from_env("SNOWFLAKE_PRIVATE_KEY_FILE"), + warehouse=from_env("SNOWFLAKE_WAREHOUSE"), ) ) - with engine.connect() as db_connection: - result = db_connection.execute(text("select current_version();")).fetchone() - - if result is None: - fail("Failed to execute query") - - print(f"Snowflake version: {result[0]}") - - project_dir = Path(FILE_DIR, "example-dce-project") - - if not project_dir.is_dir(): - project_dir.mkdir(parents=True) - - domain = bao.domain(project_dir) + domain = bao.domain() domain.add_db(engine) agent = bao.agent(domain=domain, name="my_agent", llm_config=bao.LLMConfig(name="gpt-5.1", temperature=0)) diff --git a/tests/test_snowflake_adapter.py b/tests/test_snowflake_adapter.py index b6c579a3..ef8ff90d 100644 --- a/tests/test_snowflake_adapter.py +++ b/tests/test_snowflake_adapter.py @@ -23,35 +23,14 @@ def _make_config(auth: Any, **kwargs: Any) -> SnowflakeConnectionProperties: return SnowflakeConnectionProperties(**{**BASE_CONFIG, **kwargs}, auth=auth) -def _parse_secret_sql(sql: str) -> dict[str, str]: - """Parse 'CREATE OR REPLACE SECRET "name" (TYPE snowflake, k 'v', ...)' into a dict.""" - import re - - inner = sql[sql.index("(") + 1 : sql.rindex(")")] - # Drop the leading "TYPE snowflake, " prefix - inner = inner.split(", ", 1)[1] - result: dict[str, str] = {} - for m in re.finditer(r"(\w+) '((?:[^']|'')*)'", inner): - result[m.group(1)] = m.group(2).replace("''", "'") - return result - - # --------------------------------------------------------------------------- -# _create_secret_sql — password auth +# _create_secret_params — password auth # --------------------------------------------------------------------------- -def test_secret_sql_password_auth_structure() -> None: - config = _make_config(SnowflakePasswordAuth(password="s3cr3t")) - sql = SnowflakeAdapter._create_secret_sql(config, "mydb") - - assert sql.startswith('CREATE OR REPLACE SECRET "mydb" (TYPE snowflake,') - assert sql.endswith(");") - - -def test_secret_sql_password_auth_params() -> None: +def test_secret_params_password_auth() -> None: config = _make_config(SnowflakePasswordAuth(password="s3cr3t")) - params = _parse_secret_sql(SnowflakeAdapter._create_secret_sql(config, "mydb")) + params = SnowflakeAdapter._create_secret_params(config) assert params["account"] == "myaccount" assert params["user"] == "myuser" @@ -61,43 +40,43 @@ def test_secret_sql_password_auth_params() -> None: assert "auth_type" not in params -def test_secret_sql_password_auth_no_role_by_default() -> None: +def test_secret_params_password_auth_no_role_by_default() -> None: config = _make_config(SnowflakePasswordAuth(password="s3cr3t")) - params = _parse_secret_sql(SnowflakeAdapter._create_secret_sql(config, "mydb")) + params = SnowflakeAdapter._create_secret_params(config) assert "role" not in params -def test_secret_sql_password_auth_includes_role_when_set() -> None: +def test_secret_params_password_auth_includes_role_when_set() -> None: config = _make_config(SnowflakePasswordAuth(password="s3cr3t"), role="ANALYST") - params = _parse_secret_sql(SnowflakeAdapter._create_secret_sql(config, "mydb")) + params = SnowflakeAdapter._create_secret_params(config) assert params["role"] == "ANALYST" -def test_secret_sql_omits_database_when_none() -> None: +def test_secret_params_omits_database_when_none() -> None: config = SnowflakeConnectionProperties( account="acct", user="usr", database=None, warehouse="wh", auth=SnowflakePasswordAuth(password="pw") ) - params = _parse_secret_sql(SnowflakeAdapter._create_secret_sql(config, "s")) + params = SnowflakeAdapter._create_secret_params(config) assert "database" not in params -def test_secret_sql_omits_warehouse_when_none() -> None: +def test_secret_params_omits_warehouse_when_none() -> None: config = SnowflakeConnectionProperties( account="acct", user="usr", database="db", warehouse=None, auth=SnowflakePasswordAuth(password="pw") ) - params = _parse_secret_sql(SnowflakeAdapter._create_secret_sql(config, "s")) + params = SnowflakeAdapter._create_secret_params(config) assert "warehouse" not in params # --------------------------------------------------------------------------- -# _create_secret_sql — key pair auth (inline key) +# _create_secret_params — key pair auth (inline key) # --------------------------------------------------------------------------- -def test_secret_sql_key_pair_inline_key() -> None: +def test_secret_params_key_pair_inline_key() -> None: auth = SnowflakeKeyPairAuth(private_key="-----BEGIN PRIVATE KEY-----\nABC\n-----END PRIVATE KEY-----\n") config = _make_config(auth) - params = _parse_secret_sql(SnowflakeAdapter._create_secret_sql(config, "mydb")) + params = SnowflakeAdapter._create_secret_params(config) assert params["auth_type"] == "key_pair" assert "BEGIN PRIVATE KEY" in params["private_key"] @@ -105,121 +84,147 @@ def test_secret_sql_key_pair_inline_key() -> None: assert "private_key_passphrase" not in params -def test_secret_sql_key_pair_inline_key_with_passphrase() -> None: +def test_secret_params_key_pair_inline_key_with_passphrase() -> None: auth = SnowflakeKeyPairAuth( private_key="-----BEGIN ENCRYPTED PRIVATE KEY-----\nXYZ\n-----END ENCRYPTED PRIVATE KEY-----\n", private_key_file_pwd="mypassphrase", ) config = _make_config(auth) - params = _parse_secret_sql(SnowflakeAdapter._create_secret_sql(config, "mydb")) + params = SnowflakeAdapter._create_secret_params(config) assert params["auth_type"] == "key_pair" assert params["private_key_passphrase"] == "mypassphrase" # --------------------------------------------------------------------------- -# _create_secret_sql — key pair auth (file path) +# _create_secret_params — key pair auth (file path) # --------------------------------------------------------------------------- -def test_secret_sql_key_pair_file_reads_content(tmp_path: Path) -> None: +def test_secret_params_key_pair_file_reads_content(tmp_path: Path) -> None: key_content = "-----BEGIN PRIVATE KEY-----\nFILE_KEY\n-----END PRIVATE KEY-----\n" key_file = tmp_path / "rsa_key.p8" key_file.write_text(key_content) auth = SnowflakeKeyPairAuth(private_key_file=str(key_file)) config = _make_config(auth) - params = _parse_secret_sql(SnowflakeAdapter._create_secret_sql(config, "mydb")) + params = SnowflakeAdapter._create_secret_params(config) assert params["auth_type"] == "key_pair" assert params["private_key"] == key_content -def test_secret_sql_key_pair_file_with_passphrase(tmp_path: Path) -> None: +def test_secret_params_key_pair_file_with_passphrase(tmp_path: Path) -> None: key_file = tmp_path / "rsa_key.p8" key_file.write_text("key") auth = SnowflakeKeyPairAuth(private_key_file=str(key_file), private_key_file_pwd="phrase") config = _make_config(auth) - params = _parse_secret_sql(SnowflakeAdapter._create_secret_sql(config, "mydb")) + params = SnowflakeAdapter._create_secret_params(config) assert params["private_key_passphrase"] == "phrase" # --------------------------------------------------------------------------- -# _create_secret_sql — SSO auth +# _create_secret_params — SSO auth # --------------------------------------------------------------------------- -def test_secret_sql_sso_externalbrowser() -> None: +def test_secret_params_sso_externalbrowser() -> None: auth = SnowflakeSSOAuth(authenticator="externalbrowser") config = _make_config(auth) - params = _parse_secret_sql(SnowflakeAdapter._create_secret_sql(config, "mydb")) + params = SnowflakeAdapter._create_secret_params(config) assert params["auth_type"] == "externalbrowser" assert "okta_url" not in params assert "password" not in params -def test_secret_sql_sso_okta_url() -> None: +def test_secret_params_sso_okta_url() -> None: okta_url = "https://myorg.okta.com" auth = SnowflakeSSOAuth(authenticator=okta_url) config = _make_config(auth) - params = _parse_secret_sql(SnowflakeAdapter._create_secret_sql(config, "mydb")) + params = SnowflakeAdapter._create_secret_params(config) assert params["auth_type"] == "okta" assert params["okta_url"] == okta_url -def test_secret_sql_sso_oauth() -> None: +def test_secret_params_sso_oauth() -> None: auth = SnowflakeSSOAuth(authenticator="oauth") config = _make_config(auth) - params = _parse_secret_sql(SnowflakeAdapter._create_secret_sql(config, "mydb")) + params = SnowflakeAdapter._create_secret_params(config) assert params["auth_type"] == "oauth" # --------------------------------------------------------------------------- -# _create_secret_sql — secret name quoting +# _create_secret_params — values with special characters # --------------------------------------------------------------------------- -def test_secret_sql_name_used_correctly() -> None: - config = _make_config(SnowflakePasswordAuth(password="pw")) - sql = SnowflakeAdapter._create_secret_sql(config, "my_secret_name") - assert 'SECRET "my_secret_name"' in sql +def test_secret_params_preserves_single_quotes_in_password() -> None: + config = _make_config(SnowflakePasswordAuth(password="my'password")) + params = SnowflakeAdapter._create_secret_params(config) + assert params["password"] == "my'password" + + +def test_secret_params_includes_additional_properties() -> None: + config = _make_config(SnowflakePasswordAuth(password="pw"), additional_properties={"timeout": 30, "custom": "val"}) + params = SnowflakeAdapter._create_secret_params(config) + assert params["timeout"] == "30" + assert params["custom"] == "val" # --------------------------------------------------------------------------- -# _create_secret_sql — single-quote escaping +# _format_sql_params — SQL formatting and escaping # --------------------------------------------------------------------------- -def test_secret_sql_escapes_single_quotes_in_password() -> None: - config = _make_config(SnowflakePasswordAuth(password="my'password")) - sql = SnowflakeAdapter._create_secret_sql(config, "s") - assert "my''password" in sql - params = _parse_secret_sql(sql) - assert params["password"] == "my'password" +def test_format_sql_params_basic() -> None: + assert SnowflakeAdapter._format_sql_params({"account": "acct", "user": "me"}) == "account 'acct', user 'me'" + + +def test_format_sql_params_escapes_single_quotes() -> None: + assert SnowflakeAdapter._format_sql_params({"password": "my'pass"}) == "password 'my''pass'" # --------------------------------------------------------------------------- -# _create_secret_sql — error handling +# _create_secret_params — error handling # --------------------------------------------------------------------------- -def test_secret_sql_key_pair_no_key_raises() -> None: +def test_secret_params_key_pair_no_key_raises() -> None: auth = SnowflakeKeyPairAuth(private_key=None, private_key_file=None) config = _make_config(auth) with pytest.raises(ValueError, match="No private key provided"): - SnowflakeAdapter._create_secret_sql(config, "s") + SnowflakeAdapter._create_secret_params(config) -def test_secret_sql_key_pair_file_not_found_raises() -> None: +def test_secret_params_key_pair_file_not_found_raises() -> None: auth = SnowflakeKeyPairAuth(private_key_file="/nonexistent/path/key.p8") config = _make_config(auth) with pytest.raises(ValueError, match="Unable to read Snowflake private key file"): - SnowflakeAdapter._create_secret_sql(config, "s") + SnowflakeAdapter._create_secret_params(config) + + +# --------------------------------------------------------------------------- +# register_in_duckdb — statement ordering +# --------------------------------------------------------------------------- + + +def test_register_in_duckdb_executes_statements_in_order() -> None: + config = _make_config(SnowflakePasswordAuth(password="s3cr3t")) + conn = MagicMock() + + SnowflakeAdapter.register_in_duckdb(conn, config, "mydb") + + calls = [c.args[0] for c in conn.execute.call_args_list] + assert len(calls) == 4 + assert calls[0] == "INSTALL snowflake FROM community;" + assert calls[1] == "LOAD snowflake;" + assert calls[2].startswith('CREATE OR REPLACE SECRET "mydb" (TYPE snowflake,') + assert calls[3] == """ATTACH '' AS "mydb" (TYPE snowflake, SECRET "mydb", READ_ONLY);""" # ---------------------------------------------------------------------------