diff --git a/CHANGELOG.md b/CHANGELOG.md index 580d582337..57befa90c3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ - Fixed a bug in `DataFrameReader.xml()` where reading XML with a custom schema whose field names contain colons (e.g., `px:name`) raised a `SnowparkColumnException`. - Fixed a bug in that caused SQL compilation errors in `Session.read.json` when `INFER_SCHEMA` was set to True, and the `USE_RELAXED_TYPES` field of `INFER_SCHEMA_OPTIONS` was also set to True. - Fixed a bug where passing a DataFrame created from a SQL `SET` command to Streamlit's `st.write` method would raise an exception. +- Fixed a bug where the account-level default artifact repository setting was not reflected in creation of stored procedures/UDFs. #### Improvements diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index db2b90688c..f80255d00d 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -69,6 +69,9 @@ GeneratorTableFunction, TableFunctionRelation, ) +from snowflake.snowpark._internal.analyzer.analyzer_utils import ( + quote_name_without_upper_casing, +) from snowflake.snowpark._internal.analyzer.unary_expression import Cast from snowflake.snowpark._internal.ast.batch import AstBatch from snowflake.snowpark._internal.ast.utils import ( @@ -2424,7 +2427,6 @@ def _get_default_artifact_repository(self) -> str: if isinstance(self._conn, MockServerConnection): return _DEFAULT_ARTIFACT_REPOSITORY - account = self.get_current_account() database = self.get_current_database() schema = self.get_current_schema() cache_key = (database, schema) @@ -2442,7 +2444,10 @@ def _get_default_artifact_repository(self) -> str: if schema else f"'database', '{database}'" if database - else f"'account', '{account}'" + # self.get_current_account uses a cached connector field that may not be properly cased, so we need to + # explicitly issue a query for it. + # Since this issues a query, we should compute it only if database/schema are unset. + else f"""'account', '{quote_name_without_upper_casing(self._conn._get_string_datum("SELECT CURRENT_ACCOUNT()"))}'""" ) result = self._run_query( f"SELECT SYSTEM$GET_DEFAULT_PYTHON_ARTIFACT_REPOSITORY('{python_version}', {entity_selector_args})" diff --git a/tests/integ/test_session.py b/tests/integ/test_session.py index 7cd4017985..128f8013f5 100644 --- a/tests/integ/test_session.py +++ b/tests/integ/test_session.py @@ -8,6 +8,7 @@ import io from functools import partial from unittest.mock import patch +import logging import pytest @@ -41,6 +42,7 @@ _get_active_session, _get_active_sessions, ) +from snowflake.snowpark.context import _ANACONDA_SHARED_REPOSITORY from tests.utils import IS_IN_STORED_PROC, IS_IN_STORED_PROC_LOCALFS, TestFiles, Utils @@ -1091,3 +1093,27 @@ def test_get_active_sessions_empty(): with patch.object(session_module, "_active_sessions", return_value=set()): assert session_module._get_active_sessions(require_at_least_one=False) == set() + + +def test_default_artifact_repository_with_no_db_schema(session, caplog): + # The reported customer issue covered by this test (SNOW-3230493) occurs when no schema/database + # is set and an account locator is used, so we mock schema/db to be empty for this test. + # Oddly, getting schema and database appear to be fine (for example, schema comes back with + # double-quoted all-caps '"PUBLIC"') while the connector's cached _account field produces a + # double-quoted lowercase value ('"sfctest0') that causes an error when passed to + # SYSTEM$GET_DEFAULT_ARTIFACT_REPOSITORY. + original_conn_implementation = session._conn._get_current_parameter + + def mock_session_parameters(param: str, quoted: bool = True): + if param == "schema" or param == "database": + return None + return original_conn_implementation(param, quoted) + + with patch.object( + session._conn, + "_get_current_parameter", + mock_session_parameters, + ), caplog.at_level(logging.WARNING): + result = session._get_default_artifact_repository() + assert result == _ANACONDA_SHARED_REPOSITORY + assert caplog.text.count("Error getting default artifact repository") == 0