Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
## Deprecations

## New additions
* Added support for `private_key_file_pwd` in connection configuration to specify the private key passphrase. The `PRIVATE_KEY_PASSPHRASE` environment variable takes precedence for backward compatibility.

## Fixes and improvements

Expand Down
12 changes: 9 additions & 3 deletions src/snowflake/cli/_app/snow_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
"authenticator",
"workload_identity_provider",
"private_key_file",
"private_key_file_pwd",
"private_key_path",
"private_key_raw",
"database",
Expand Down Expand Up @@ -291,7 +292,8 @@ def _load_private_key(connection_parameters: Dict, private_key_var_name: str) ->
private_key_pem = _load_pem_from_file(
connection_parameters[private_key_var_name]
)
private_key = _load_pem_to_der(private_key_pem)
passphrase = connection_parameters.get("private_key_file_pwd")
private_key = _load_pem_to_der(private_key_pem, passphrase=passphrase)
connection_parameters["private_key"] = private_key.value
del connection_parameters[private_key_var_name]
else:
Expand Down Expand Up @@ -343,12 +345,16 @@ def _load_pem_from_parameters(private_key_raw: str) -> SecretType:
return SecretType(private_key_raw.encode("utf-8"))


def _load_pem_to_der(private_key_pem: SecretType) -> SecretType:
def _load_pem_to_der(
private_key_pem: SecretType, passphrase: Optional[str] = None
) -> SecretType:
"""
Given a private key file path (in PEM format), decode key data into DER
format
"""
private_key_passphrase = SecretType(os.getenv("PRIVATE_KEY_PASSPHRASE", None))
env_passphrase = os.getenv("PRIVATE_KEY_PASSPHRASE")
passphrase_value = env_passphrase if env_passphrase is None else passphrase
private_key_passphrase = SecretType(passphrase_value)
if (
private_key_pem.value.startswith(ENCRYPTED_PKCS8_PK_HEADER)
and private_key_passphrase.value is None
Expand Down
5 changes: 4 additions & 1 deletion src/snowflake/cli/_plugins/connection/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ def _mask_sensitive_parameters(connection_params: dict):
connection_params["password"] = "****"
if "oauth_client_secret" in connection_params:
connection_params["oauth_client_secret"] = "****"
if "private_key_file_pwd" in connection_params:
connection_params["private_key_file_pwd"] = "****"
return connection_params


Expand Down Expand Up @@ -412,7 +414,8 @@ def generate_jwt(
if not connection_details.private_key_file:
raise UsageError(msq_template.format("Private key file"))

passphrase = os.getenv("PRIVATE_KEY_PASSPHRASE", None)
env_passphrase = os.getenv("PRIVATE_KEY_PASSPHRASE")
passphrase = env_passphrase if env_passphrase is not None else connection_details.private_key_file_pwd

def _decrypt(passphrase: str | None):
return connector.auth.get_token_from_private_key(
Expand Down
1 change: 1 addition & 0 deletions src/snowflake/cli/api/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ class ConnectionConfig:
authenticator: Optional[str] = None
workload_identity_provider: Optional[str] = None
private_key_file: Optional[str] = None
private_key_file_pwd: Optional[str] = field(default=None, repr=False)
token_file_path: Optional[str] = None
oauth_client_id: Optional[str] = None
oauth_client_secret: Optional[str] = None
Expand Down
1 change: 1 addition & 0 deletions src/snowflake/cli/api/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class ConnectionContext:
authenticator: Optional[str] = None
workload_identity_provider: Optional[str] = None
private_key_file: Optional[str] = None
private_key_file_pwd: Optional[str] = field(default=None, repr=False)
warehouse: Optional[str] = None
mfa_passcode: Optional[str] = None
token: Optional[str] = None
Expand Down
7 changes: 7 additions & 0 deletions tests/test.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,13 @@ account = "testing_account"
authenticator = "SNOWFLAKE_JWT"
private_key_file = "/private/key"

[connections.jwt_with_pwd]
user = "jdoe"
account = "testing_account"
authenticator = "SNOWFLAKE_JWT"
private_key_file = "/private/key"
private_key_file_pwd = "config_passphrase"

[cli.features]
dummy_flag = true
wrong_type_flag = "not_true"
71 changes: 71 additions & 0 deletions tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,17 @@ def test_lists_connection_information(mock_get_default_conn_name, runner):
"user": "jdoe",
},
},
{
"connection_name": "jwt_with_pwd",
"is_default": False,
"parameters": {
"account": "testing_account",
"authenticator": "SNOWFLAKE_JWT",
"private_key_file": "/private/key",
"private_key_file_pwd": "****", # masked
"user": "jdoe",
},
},
]


Expand Down Expand Up @@ -447,6 +458,17 @@ def test_connection_list_does_not_print_too_many_env_variables(
"user": "jdoe",
},
},
{
"connection_name": "jwt_with_pwd",
"is_default": False,
"parameters": {
"account": "testing_account",
"authenticator": "SNOWFLAKE_JWT",
"private_key_file": "/private/key",
"private_key_file_pwd": "****", # masked
"user": "jdoe",
},
},
]


Expand Down Expand Up @@ -1305,6 +1327,55 @@ def test_generate_jwt_uses_config(mocked_get_token, runner, named_temporary_file
)


@mock.patch(
"snowflake.cli._plugins.connection.commands.connector.auth.get_token_from_private_key"
)
@mock.patch.dict(os.environ, {}, clear=True)
def test_generate_jwt_uses_private_key_file_pwd_from_config(mocked_get_token, runner, named_temporary_file):
"""Test that private_key_file_pwd from config is used for generate_jwt."""
mocked_get_token.return_value = "funny token"

with named_temporary_file() as f:
f.write_text("secret from file")
result = runner.invoke(
["connection", "generate-jwt", "--connection", "jwt_with_pwd"],
)

assert result.exit_code == 0, result.output
assert result.output == "funny token\n"
mocked_get_token.assert_called_once_with(
user="jdoe",
account="testing_account",
privatekey_path="/private/key",
key_password="config_passphrase",
)


@mock.patch(
"snowflake.cli._plugins.connection.commands.connector.auth.get_token_from_private_key"
)
@mock.patch.dict(os.environ, {"PRIVATE_KEY_PASSPHRASE": "env_passphrase"}, clear=True)
def test_generate_jwt_env_passphrase_takes_precedence_over_config(mocked_get_token, runner, named_temporary_file):
"""Test that PRIVATE_KEY_PASSPHRASE env var takes precedence over private_key_file_pwd from config for backward compatibility."""
mocked_get_token.return_value = "funny token"

with named_temporary_file() as f:
f.write_text("secret from file")
result = runner.invoke(
["connection", "generate-jwt", "--connection", "jwt_with_pwd"],
)

assert result.exit_code == 0, result.output
assert result.output == "funny token\n"
# Env var should be used for backward compatibility, not the config passphrase
mocked_get_token.assert_called_once_with(
user="jdoe",
account="testing_account",
privatekey_path="/private/key",
key_password="env_passphrase",
)


@mock.patch(
"snowflake.cli._plugins.connection.commands.connector.auth.get_token_from_private_key"
)
Expand Down
64 changes: 63 additions & 1 deletion tests/test_snow_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,69 @@ def test_private_key_loading_and_aliases(
)
if expected_private_key_file_value is not None:
mock_load_pem_from_file.assert_called_with(expected_private_key_file_value)
mock_load_pem_to_der.assert_called_with(key)
mock_load_pem_to_der.assert_called_with(key, passphrase=None)


@mock.patch("snowflake.connector.connect")
@mock.patch("snowflake.cli._app.snow_connector.command_info")
@mock.patch("snowflake.cli._app.snow_connector._load_pem_to_der")
@mock.patch("snowflake.cli._app.snow_connector._load_pem_from_file")
def test_private_key_file_pwd_from_config(
mock_load_pem_from_file,
mock_load_pem_to_der,
mock_command_info,
mock_connect,
test_snowcli_config,
):
"""
Ensures that private_key_file_pwd from config is passed to _load_pem_to_der.
"""
from snowflake.cli._app.snow_connector import connect_to_snowflake
from snowflake.cli.api.config import config_init

config_init(test_snowcli_config)

key = SecretType(b"bytes")
mock_command_info.return_value = "SNOWCLI.SQL"
mock_load_pem_from_file.return_value = key
mock_load_pem_to_der.return_value = key

with mock.patch.dict(os.environ, {}, clear=True):
connect_to_snowflake(connection_name="jwt_with_pwd")
mock_load_pem_from_file.assert_called_with("/private/key")
mock_load_pem_to_der.assert_called_with(key, passphrase="config_passphrase")


@mock.patch("snowflake.connector.connect")
@mock.patch("snowflake.cli._app.snow_connector.command_info")
@mock.patch("snowflake.cli._app.snow_connector._load_pem_to_der")
@mock.patch("snowflake.cli._app.snow_connector._load_pem_from_file")
def test_private_key_file_pwd_config_fallback(
mock_load_pem_from_file,
mock_load_pem_to_der,
mock_command_info,
mock_connect,
test_snowcli_config,
):
"""
Ensures that private_key_file_pwd from config is used as fallback when PRIVATE_KEY_PASSPHRASE env var is not set.
"""
from snowflake.cli._app.snow_connector import connect_to_snowflake
from snowflake.cli.api.config import config_init

config_init(test_snowcli_config)

key = SecretType(b"bytes")
mock_command_info.return_value = "SNOWCLI.SQL"
mock_load_pem_from_file.return_value = key
mock_load_pem_to_der.return_value = key

# jwt connection does not have private_key_file_pwd, so passphrase=None is passed
with mock.patch.dict(os.environ, {}, clear=True):
connect_to_snowflake(connection_name="jwt")
mock_load_pem_from_file.assert_called_with("/private/key")
mock_load_pem_to_der.assert_called_with(key, passphrase=None)



@mock.patch.dict(os.environ, {}, clear=True)
Expand Down