Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
- Fixed a bug in `DataFrameReader.xml()` where reading XML with a custom schema whose field names contain colons (e.g., `px:name`) raised a `SnowparkColumnException`.
- Fixed a bug in that caused SQL compilation errors in `Session.read.json` when `INFER_SCHEMA` was set to True, and the `USE_RELAXED_TYPES` field of `INFER_SCHEMA_OPTIONS` was also set to True.
- Fixed a bug where passing a DataFrame created from a SQL `SET` command to Streamlit's `st.write` method would raise an exception.
- Fixed a bug where the account-level default artifact repository setting was not reflected in creation of stored procedures/UDFs.

#### Improvements

Expand Down
9 changes: 7 additions & 2 deletions src/snowflake/snowpark/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@
GeneratorTableFunction,
TableFunctionRelation,
)
from snowflake.snowpark._internal.analyzer.analyzer_utils import (
quote_name_without_upper_casing,
)
from snowflake.snowpark._internal.analyzer.unary_expression import Cast
from snowflake.snowpark._internal.ast.batch import AstBatch
from snowflake.snowpark._internal.ast.utils import (
Expand Down Expand Up @@ -2424,7 +2427,6 @@ def _get_default_artifact_repository(self) -> str:
if isinstance(self._conn, MockServerConnection):
return _DEFAULT_ARTIFACT_REPOSITORY

account = self.get_current_account()
database = self.get_current_database()
schema = self.get_current_schema()
cache_key = (database, schema)
Expand All @@ -2442,7 +2444,10 @@ def _get_default_artifact_repository(self) -> str:
if schema
else f"'database', '{database}'"
if database
else f"'account', '{account}'"
# self.get_current_account uses a cached connector field that may not be properly cased, so we need to
# explicitly issue a query for it.
# Since this issues a query, we should compute it only if database/schema are unset.
else f"""'account', '{quote_name_without_upper_casing(self._conn._get_string_datum("SELECT CURRENT_ACCOUNT()"))}'"""
)
result = self._run_query(
f"SELECT SYSTEM$GET_DEFAULT_PYTHON_ARTIFACT_REPOSITORY('{python_version}', {entity_selector_args})"
Expand Down
26 changes: 26 additions & 0 deletions tests/integ/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import io
from functools import partial
from unittest.mock import patch
import logging

import pytest

Expand Down Expand Up @@ -41,6 +42,7 @@
_get_active_session,
_get_active_sessions,
)
from snowflake.snowpark.context import _ANACONDA_SHARED_REPOSITORY
from tests.utils import IS_IN_STORED_PROC, IS_IN_STORED_PROC_LOCALFS, TestFiles, Utils


Expand Down Expand Up @@ -1091,3 +1093,27 @@ def test_get_active_sessions_empty():

with patch.object(session_module, "_active_sessions", return_value=set()):
assert session_module._get_active_sessions(require_at_least_one=False) == set()


def test_default_artifact_repository_with_no_db_schema(session, caplog):
# The reported customer issue covered by this test (SNOW-3230493) occurs when no schema/database
# is set and an account locator is used, so we mock schema/db to be empty for this test.
# Oddly, getting schema and database appear to be fine (for example, schema comes back with
# double-quoted all-caps '"PUBLIC"') while the connector's cached _account field produces a
# double-quoted lowercase value ('"sfctest0') that causes an error when passed to
# SYSTEM$GET_DEFAULT_ARTIFACT_REPOSITORY.
original_conn_implementation = session._conn._get_current_parameter

def mock_session_parameters(param: str, quoted: bool = True):
if param == "schema" or param == "database":
return None
return original_conn_implementation(param, quoted)

with patch.object(
session._conn,
"_get_current_parameter",
mock_session_parameters,
), caplog.at_level(logging.WARNING):
result = session._get_default_artifact_repository()
assert result == _ANACONDA_SHARED_REPOSITORY
assert caplog.text.count("Error getting default artifact repository") == 0
Loading