From b694e2fb48b206026d71da878da1f75e1670ebf5 Mon Sep 17 00:00:00 2001 From: Yijun Xie Date: Wed, 22 Oct 2025 21:23:29 -0700 Subject: [PATCH 1/8] Loosen flattening rules for sort and filter --- .../_internal/analyzer/select_statement.py | 52 +++++++++----- tests/integ/test_simplifier_suite.py | 71 ++++++++++++++----- 2 files changed, 89 insertions(+), 34 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/select_statement.py b/src/snowflake/snowpark/_internal/analyzer/select_statement.py index 723277a31d..b6d92411d7 100644 --- a/src/snowflake/snowpark/_internal/analyzer/select_statement.py +++ b/src/snowflake/snowpark/_internal/analyzer/select_statement.py @@ -20,6 +20,7 @@ Sequence, Set, Union, + Literal, ) import snowflake.snowpark._internal.utils @@ -1362,9 +1363,9 @@ def select(self, cols: List[Expression]) -> "SelectStatement": ): # TODO: Clean up, this entire if case is parameter protection can_be_flattened = False - elif (self.where or self.order_by or self.limit_) and has_data_generator_exp( - cols - ): + elif ( + self.where or self.order_by or self.limit_ + ) and has_data_generator_or_window_function_exp(cols): can_be_flattened = False elif self.where and ( (subquery_dependent_columns := derive_dependent_columns(self.where)) @@ -1453,9 +1454,9 @@ def filter(self, col: Expression) -> "SelectStatement": can_be_flattened = ( (not self.flatten_disabled) and can_clause_dependent_columns_flatten( - derive_dependent_columns(col), self.column_states + derive_dependent_columns(col), self.column_states, "filter" ) - and not has_data_generator_exp(self.projection) + and not has_data_generator_or_window_function_exp(self.projection) and not (self.order_by and self.limit_ is not None) ) if can_be_flattened: @@ -1490,7 +1491,7 @@ def sort(self, cols: List[Expression]) -> "SelectStatement": and (not self.limit_) and (not self.offset) and can_clause_dependent_columns_flatten( - derive_dependent_columns(*cols), self.column_states + 
derive_dependent_columns(*cols), self.column_states, "sort" ) and not has_data_generator_exp(self.projection) ) @@ -1529,7 +1530,7 @@ def distinct(self) -> "SelectStatement": # .order_by(col1).select(col2).distinct() cannot be flattened because # SELECT DISTINCT B FROM TABLE ORDER BY A is not valid SQL and (not (self.order_by and self.has_projection)) - and not has_data_generator_exp(self.projection) + and not has_data_generator_or_window_function_exp(self.projection) ) if can_be_flattened: new = copy(self) @@ -2020,7 +2021,12 @@ def can_projection_dependent_columns_be_flattened( def can_clause_dependent_columns_flatten( dependent_columns: Optional[AbstractSet[str]], subquery_column_states: ColumnStateDict, + clause: Literal["filter", "sort"], ) -> bool: + if clause not in ["filter", "sort"]: + raise ValueError( + f"Invalid clause called in can_clause_dependent_columns_flatten: {clause}" + ) if dependent_columns == COLUMN_DEPENDENCY_DOLLAR: return False elif ( @@ -2034,15 +2040,10 @@ def can_clause_dependent_columns_flatten( for dc in dependent_columns: dc_state = subquery_column_states.get(dc) if dc_state: - if dc_state.change_state == ColumnChangeState.CHANGED_EXP: - return False - elif dc_state.change_state == ColumnChangeState.NEW: - # Most of the time this can be flattened. But if a new column uses window function and this column - # is used in a clause, the sql doesn't work in Snowflake. - # For instance `select a, rank() over(order by b) as d from test_table where d = 1` doesn't work. - # But `select a, b as d from test_table where d = 1` works - # We can inspect whether the referenced new column uses window function. Here we are being - # conservative for now to not flatten the SQL. 
+ if ( + dc_state.change_state == ColumnChangeState.CHANGED_EXP + and clause == "filter" + ): return False return True @@ -2264,8 +2265,6 @@ def has_data_generator_exp(expressions: Optional[List["Expression"]]) -> bool: if expressions is None: return False for exp in expressions: - if isinstance(exp, WindowExpression): - return True if isinstance(exp, FunctionExpression) and ( exp.is_data_generator or exp.name.lower() in SEQUENCE_DEPENDENT_DATA_GENERATION @@ -2275,3 +2274,20 @@ def has_data_generator_exp(expressions: Optional[List["Expression"]]) -> bool: if exp is not None and has_data_generator_exp(exp.children): return True return False + + +def has_window_function_exp(expressions: Optional[List["Expression"]]) -> bool: + if expressions is None: + return False + for exp in expressions: + if isinstance(exp, WindowExpression): + return True + if exp is not None and has_window_function_exp(exp.children): + return True + return False + + +def has_data_generator_or_window_function_exp( + expressions: Optional[List["Expression"]], +) -> bool: + return has_data_generator_exp(expressions) or has_window_function_exp(expressions) diff --git a/tests/integ/test_simplifier_suite.py b/tests/integ/test_simplifier_suite.py index f42cb176cd..55942d31ed 100644 --- a/tests/integ/test_simplifier_suite.py +++ b/tests/integ/test_simplifier_suite.py @@ -9,7 +9,7 @@ import pytest -from snowflake.snowpark import Row +from snowflake.snowpark import Row, Window from snowflake.snowpark._internal.analyzer.select_statement import ( SET_EXCEPT, SET_INTERSECT, @@ -30,6 +30,7 @@ sum as sum_, table_function, udtf, + rank, ) from tests.utils import TestData, Utils @@ -754,21 +755,35 @@ def test_order_by(setup_reduce_cast, session, simplifier_table): f'SELECT "A", "B" FROM {simplifier_table} ORDER BY "A" ASC NULLS FIRST, "B" ASC NULLS FIRST' ) - # no flatten because c is a new column + # flatten if a new column is used in the order by clause df3 = df.select("a", "b", (col("a") - 
col("b")).as_("c")).sort("a", "b", "c") assert Utils.normalize_sql(df3.queries["queries"][-1]) == Utils.normalize_sql( - f'SELECT * FROM ( SELECT "A", "B", ("A" - "B") AS "C" FROM {simplifier_table} ) ORDER BY "A" ASC NULLS FIRST, "B" ASC NULLS FIRST, "C" ASC NULLS FIRST' + f'SELECT "A", "B", ("A" - "B") AS "C" FROM {simplifier_table} ORDER BY "A" ASC NULLS FIRST, "B" ASC NULLS FIRST, "C" ASC NULLS FIRST' ) - # no flatten because a and be are changed + # still flatten even if a is changed because it's used in the order by clause df4 = df.select((col("a") + 1).as_("a"), ((col("b") + 1).as_("b"))).sort("a", "b") assert Utils.normalize_sql(df4.queries["queries"][-1]) == Utils.normalize_sql( - f'SELECT * FROM ( SELECT ("A" + 1{integer_literal_postfix}) AS "A", ("B" + 1{integer_literal_postfix}) AS "B" FROM {simplifier_table} ) ORDER BY "A" ASC NULLS FIRST, "B" ASC NULLS FIRST' + f'SELECT ("A" + 1{integer_literal_postfix}) AS "A", ("B" + 1{integer_literal_postfix}) AS "B" FROM {simplifier_table} ORDER BY "A" ASC NULLS FIRST, "B" ASC NULLS FIRST' ) - # subquery has sql text so unable to figure out same-level dependency, so assuming d depends on c. No flatten. 
- df5 = df.select("a", "b", lit(3).as_("c"), sql_expr("1 + 1 as d")).sort("a", "b") + # still flatten if a window function is used in the projection + df5 = df.select("a", "b", rank().over(Window.order_by("b")).alias("c")).sort( + "a", "b" + ) assert Utils.normalize_sql(df5.queries["queries"][-1]) == Utils.normalize_sql( + f'SELECT "A", "B", rank() OVER (ORDER BY "B" ASC NULLS FIRST) AS "C" FROM {simplifier_table} ORDER BY "A" ASC NULLS FIRST, "B" ASC NULLS FIRST' + ) + + # No flatten if a data generator is used in the projection + df6 = df.select("a", "b", seq1().alias("c")).sort("a", "b") + assert Utils.normalize_sql(df6.queries["queries"][-1]) == Utils.normalize_sql( + f'SELECT * FROM ( SELECT "A", "B", seq1(0) AS "C" FROM {simplifier_table}) ORDER BY "A" ASC NULLS FIRST, "B" ASC NULLS FIRST' + ) + + # subquery has sql text so unable to figure out if a data generator is used in the projection. No flatten. + df7 = df.select("a", "b", lit(3).as_("c"), sql_expr("1 + 1 as d")).sort("a", "b") + assert Utils.normalize_sql(df7.queries["queries"][-1]) == Utils.normalize_sql( f'SELECT * FROM ( SELECT "A", "B", 3 :: INT AS "C", 1 + 1 as d FROM ( SELECT * FROM {simplifier_table} ) ) ORDER BY "A" ASC NULLS FIRST, "B" ASC NULLS FIRST' ) @@ -791,32 +806,56 @@ def test_filter(setup_reduce_cast, session, simplifier_table): f'SELECT "A", "B" FROM {simplifier_table} WHERE (("A" > 1{integer_literal_postfix}) AND ("B" > 2{integer_literal_postfix}))' ) - # no flatten because c is a new column + # flatten if a regular new column is in the projection df3 = df.select("a", "b", (col("a") - col("b")).as_("c")).filter( - (col("a") > 1) & (col("b") > 2) & (col("c") < 1) + (col("a") > 1) & (col("b") > 2) ) assert Utils.normalize_sql(df3.queries["queries"][-1]) == Utils.normalize_sql( - f'SELECT * FROM ( SELECT "A", "B", ("A" - "B") AS "C" FROM {simplifier_table} ) WHERE ((("A" > 1{integer_literal_postfix}) AND ("B" > 2{integer_literal_postfix})) AND ("C" < 1{integer_literal_postfix}))' + 
f'SELECT "A", "B", ("A" - "B") AS "C" FROM {simplifier_table} WHERE (("A" > 1{integer_literal_postfix}) AND ("B" > 2{integer_literal_postfix}))' + ) + + # flatten if a regular new column is used in the filter clause + df4 = df.select("a", "b", (col("a") - col("b")).as_("c")).filter( + (col("a") > 1) & (col("b") > 2) & (col("c") < 1) + ) + assert Utils.normalize_sql(df4.queries["queries"][-1]) == Utils.normalize_sql( + f'SELECT "A", "B", ("A" - "B") AS "C" FROM {simplifier_table} WHERE ((("A" > 1{integer_literal_postfix}) AND ("B" > 2{integer_literal_postfix})) AND ("C" < 1{integer_literal_postfix}))' + ) + + # no flatten if a window function is used in the projection + df5 = df.select("a", "b", rank().over(Window.order_by("b")).alias("c")).filter( + (col("a") > 1) & (col("b") > 2) & (col("c") < 1) + ) + assert Utils.normalize_sql(df5.queries["queries"][-1]) == Utils.normalize_sql( + f'SELECT * FROM ( SELECT "A", "B", rank() OVER (ORDER BY "B" ASC NULLS FIRST) AS "C" FROM {simplifier_table} ) WHERE ((("A" > 1{integer_literal_postfix}) AND ("B" > 2{integer_literal_postfix})) AND ("C" < 1{integer_literal_postfix}))' + ) + + # no flatten if a data generator is used in the projection + df6 = df.select("a", "b", seq1().alias("c")).filter( + (col("a") > 1) & (col("b") > 2) & (col("c") < 1) + ) + assert Utils.normalize_sql(df6.queries["queries"][-1]) == Utils.normalize_sql( + f'SELECT * FROM ( SELECT "A", "B", seq1(0) AS "C" FROM {simplifier_table} ) WHERE ((("A" > 1{integer_literal_postfix}) AND ("B" > 2{integer_literal_postfix})) AND ("C" < 1{integer_literal_postfix}))' ) # no flatten because a and be are changed - df4 = df.select((col("a") + 1).as_("a"), (col("b") + 1).as_("b")).filter( + df7 = df.select((col("a") + 1).as_("a"), (col("b") + 1).as_("b")).filter( (col("a") > 1) & (col("b") > 2) ) - assert Utils.normalize_sql(df4.queries["queries"][-1]) == Utils.normalize_sql( + assert Utils.normalize_sql(df7.queries["queries"][-1]) == Utils.normalize_sql( f'SELECT * FROM 
( SELECT ("A" + 1{integer_literal_postfix}) AS "A", ("B" + 1{integer_literal_postfix}) AS "B" FROM {simplifier_table} ) WHERE (("A" > 1{integer_literal_postfix}) AND ("B" > 2{integer_literal_postfix}))' ) - df5 = df4.select("a") - assert Utils.normalize_sql(df5.queries["queries"][-1]) == Utils.normalize_sql( + df8 = df7.select("a") + assert Utils.normalize_sql(df8.queries["queries"][-1]) == Utils.normalize_sql( f'SELECT "A" FROM ( SELECT ("A" + 1{integer_literal_postfix}) AS "A", ("B" + 1{integer_literal_postfix}) AS "B" FROM {simplifier_table} ) WHERE (("A" > 1{integer_literal_postfix}) AND ("B" > 2{integer_literal_postfix}))' ) # subquery has sql text so unable to figure out same-level dependency, so assuming d depends on c. No flatten. - df6 = df.select("a", "b", lit(3).as_("c"), sql_expr("1 + 1 as d")).filter( + df9 = df.select("a", "b", lit(3).as_("c"), sql_expr("1 + 1 as d")).filter( col("a") > 1 ) - assert Utils.normalize_sql(df6.queries["queries"][-1]) == Utils.normalize_sql( + assert Utils.normalize_sql(df9.queries["queries"][-1]) == Utils.normalize_sql( f'SELECT * FROM ( SELECT "A", "B", 3 :: INT AS "C", 1 + 1 as d FROM ( SELECT * FROM {simplifier_table} ) ) WHERE ("A" > 1{integer_literal_postfix})' ) From a25942b13dac2303da4d423e10d563dc9f38a9c4 Mon Sep 17 00:00:00 2001 From: Yijun Xie Date: Mon, 27 Oct 2025 14:02:16 -0700 Subject: [PATCH 2/8] No flatten if dropped columns after order by / sort --- .../_internal/analyzer/select_statement.py | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/src/snowflake/snowpark/_internal/analyzer/select_statement.py b/src/snowflake/snowpark/_internal/analyzer/select_statement.py index b6d92411d7..40467d8fba 100644 --- a/src/snowflake/snowpark/_internal/analyzer/select_statement.py +++ b/src/snowflake/snowpark/_internal/analyzer/select_statement.py @@ -1376,7 +1376,18 @@ def select(self, cols: List[Expression]) -> "SelectStatement": subquery_dependent_columns & new_column_states.active_columns ) ) + 
or ( + new_column_states.dropped_columns + and any( + new_column_states[_col].change_state == ColumnChangeState.DROPPED + and self.column_states[_col].change_state + in (ColumnChangeState.NEW, ColumnChangeState.CHANGED_EXP) + and _col in subquery_dependent_columns + for _col in (new_column_states.dropped_columns) + ) + ) ): + # or (new_column_states[_col].change_state == ColumnChangeState.DROPPED and self.column_states[_col].change_state in (ColumnChangeState.NEW, ColumnChangeState.CHANGED_EXP)) can_be_flattened = False elif self.order_by and ( (subquery_dependent_columns := derive_dependent_columns(*self.order_by)) @@ -1388,6 +1399,16 @@ def select(self, cols: List[Expression]) -> "SelectStatement": subquery_dependent_columns & new_column_states.active_columns ) ) + or ( + new_column_states.dropped_columns + and any( + new_column_states[_col].change_state == ColumnChangeState.DROPPED + and self.column_states[_col].change_state + in (ColumnChangeState.NEW, ColumnChangeState.CHANGED_EXP) + and _col in subquery_dependent_columns + for _col in (new_column_states.dropped_columns) + ) + ) ): can_be_flattened = False elif self.distinct_: From 1df1e96fa7c30979e5feacc6c91d7b3f176289c4 Mon Sep 17 00:00:00 2001 From: Yijun Xie Date: Mon, 27 Oct 2025 23:46:30 -0700 Subject: [PATCH 3/8] Add tests in simplifier --- tests/integ/test_simplifier_suite.py | 79 ++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) diff --git a/tests/integ/test_simplifier_suite.py b/tests/integ/test_simplifier_suite.py index 55942d31ed..bc9d08b068 100644 --- a/tests/integ/test_simplifier_suite.py +++ b/tests/integ/test_simplifier_suite.py @@ -1650,6 +1650,22 @@ def test_chained_sort(session): .filter(col("A") > 2), 'SELECT "A", "B", 12 :: INT AS "TWELVE" FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT) ) WHERE (("A" > 1{POSTFIX}) AND ("A" > 2{POSTFIX}))', ), + # Flattened if the dropped columns are not used in filter + ( + lambda df: 
df.filter(col("A") >= 1) + .select(col("A").alias("C"), col("B").alias("D")) + .filter(col("C") > 2) + .select(col("C")), + 'SELECT "A" AS "C" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT)) WHERE (("A" >= 1{POSTFIX}) AND ("C" > 2{POSTFIX}))', + ), + # Flattened if the dropped columns are not in the filter clause's dependent columns + ( + lambda df: df.filter(col("A") >= 1) + .select(col("A").alias("C"), col("B").alias("D")) + .filter((col("C") + 1) > 2) + .select(col("C")), + 'SELECT "A" AS "C" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT)) WHERE (("A" >= 1{POSTFIX}) AND (("C" + 1{POSTFIX}) > 2{POSTFIX}))', + ), # Not fully flattened, since col("A") > 1 and col("A") > 2 are referring to different columns ( lambda df: df.filter(col("A") > 1) @@ -1672,6 +1688,29 @@ def test_chained_sort(session): lambda df: df.filter(col("$1") > 1).select(col("B"), col("A")), 'SELECT "B", "A" FROM ( SELECT * FROM ( SELECT "A", "B" FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT) ) ) WHERE ("$1" > 1{POSTFIX}) )', ), + # Not flattened if a dropped column is used in the filter clause + ( + lambda df: df.filter(col("A") >= 1) + .select(col("A"), col("B").alias("D")) + .filter(col("D") > -3) + .select(col("A").alias("E")), + 'SELECT "A" AS "E" FROM (SELECT "A", "B" AS "D" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT)) WHERE (("A" >= 1{POSTFIX}) AND ("D" > -3{POSTFIX})))', + ), + # Not flattened if a dropped column is used in the select clause's dependent columns + ( + lambda df: df.filter(col("A") >= 1) + .select(col("A"), col("B").alias("D")) + .filter((col("D") - 1) > -4) + .select((col("A") + 1).alias("E")), + 'SELECT ("A" + 1{POSTFIX}) AS "E" FROM (SELECT "A", "B" AS "D" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT)) WHERE (("A" >= 1{POSTFIX}) AND (("D" - 
1{POSTFIX}) > -4{POSTFIX})))', + ), + # Not flattened if a dropped column that was changed expression is used in the select clause's dependent columns + ( + lambda df: df.select(col("A"), (col("B") + 1).alias("B")) + .filter((col("B") - 1) > -4) + .select((col("A") + 1).alias("E")), + 'SELECT ("A" + 1{POSTFIX}) AS "E" FROM (SELECT "A", ("B" + 1{POSTFIX}) AS "B" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT))) WHERE (("B" - 1{POSTFIX}) > -4{POSTFIX})', + ), ], ) def test_select_after_filter(setup_reduce_cast, session, operation, simplified_query): @@ -1742,6 +1781,46 @@ def test_select_after_filter(setup_reduce_cast, session, operation, simplified_q 'SELECT "A" FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT) ) ORDER BY "B" ASC NULLS FIRST, "A" ASC NULLS FIRST', True, ), + # Flattened if the dropped columns are not used in filter + ( + lambda df: df.select(col("A").alias("C"), col("B").alias("D")) + .order_by(col("C")) + .select(col("C")), + 'SELECT "A" AS "C" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT)) ORDER BY "C" ASC NULLS FIRST', + True, + ), + # Flattened if the dropped columns are not in the order by clause's dependent columns + ( + lambda df: df.select(col("A").alias("C"), col("B").alias("D")) + .order_by(col("C") + 1) + .select(col("C")), + 'SELECT "A" AS "C" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT)) ORDER BY ("C" + 1{POSTFIX}) ASC NULLS FIRST', + True, + ), + # Not flattened if a dropped new column is used in the order by clause + ( + lambda df: df.select(col("A"), col("B").alias("D")) + .order_by(col("D")) + .select(col("A").alias("E")), + 'SELECT "A" AS "E" FROM (SELECT "A", "B" AS "D" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT)) ORDER BY "D" ASC NULLS FIRST)', + True, + ), + # Not flattened if a dropped new column is used in 
the order by clause's dependent columns + ( + lambda df: df.select(col("A"), col("B").alias("D")) + .order_by(col("D") - 1) + .select((col("A") + 1).alias("E")), + 'SELECT ("A" + 1{POSTFIX}) AS "E" FROM (SELECT "A", "B" AS "D" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT)) ORDER BY ("D" - 1{POSTFIX}) ASC NULLS FIRST)', + True, + ), + # Not flattened if a dropped column that was changed expression is used in the select clause's dependent columns + ( + lambda df: df.select(col("A"), (col("B") + 1).alias("B")) + .filter((col("B") - 1) > -4) + .select((col("A") + 1).alias("E")), + 'SELECT ("A" + 1{POSTFIX}) AS "E" FROM (SELECT "A", ("B" + 1{POSTFIX}) AS "B" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT))) WHERE (("B" - 1{POSTFIX}) > -4{POSTFIX})', + True, + ), ], ) def test_select_after_orderby( From 11043611e01874f80f8084a7edb0561853813ae9 Mon Sep 17 00:00:00 2001 From: Yijun Xie Date: Tue, 28 Oct 2025 10:53:52 -0700 Subject: [PATCH 4/8] Update test after flattening --- tests/integ/test_query_line_intervals.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integ/test_query_line_intervals.py b/tests/integ/test_query_line_intervals.py index 2852a7c9e3..d60f951bf6 100644 --- a/tests/integ/test_query_line_intervals.py +++ b/tests/integ/test_query_line_intervals.py @@ -73,7 +73,7 @@ def generate_test_data(session, sql_simplifier_enabled): lambda data: data["df1"].filter(data["df1"].value > 150), True, { - 8: 'SELECT $1 AS "_1", $2 AS "_2", $3 AS "_3" FROM VALUES (1 :: INT, \'A\' :: STRING, 100 :: INT), (2 :: INT, \'B\' :: STRING, 200 :: INT)', + 8: """SELECT "_1" AS "ID", "_2" AS "NAME", "_3" AS "VALUE" FROM (SELECT $1 AS "_1", $2 AS "_2", $3 AS "_3" FROM VALUES (1 :: INT, 'A' :: STRING, 100 :: INT), (2 :: INT, 'B' :: STRING, 200 :: INT)) WHERE ("VALUE" > 150)""", }, ), ( From d67f5df9039fc6997c1853f67667449be0a23ce4 Mon Sep 17 00:00:00 2001 From: Adam Ling 
Date: Wed, 29 Oct 2025 23:49:29 -0700 Subject: [PATCH 5/8] parameter protection and agg function check for filter --- .../_internal/analyzer/select_statement.py | 79 ++++++++++++++----- src/snowflake/snowpark/context.py | 1 + src/snowflake/snowpark/session.py | 14 ++++ 3 files changed, 74 insertions(+), 20 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/select_statement.py b/src/snowflake/snowpark/_internal/analyzer/select_statement.py index 40467d8fba..36bb6f125d 100644 --- a/src/snowflake/snowpark/_internal/analyzer/select_statement.py +++ b/src/snowflake/snowpark/_internal/analyzer/select_statement.py @@ -87,6 +87,7 @@ is_sql_select_statement, ExprAliasUpdateDict, ) +import snowflake.snowpark.context as context # Python 3.8 needs to use typing.Iterable because collections.abc.Iterable is not subscriptable # Python 3.9 can use both @@ -1377,17 +1378,20 @@ def select(self, cols: List[Expression]) -> "SelectStatement": ) ) or ( - new_column_states.dropped_columns + # unflattenable condition: dropped column is used in subquery WHERE clause and dropped column status is NEW or CHANGED in the subquery + # reason: we should not flatten because the dropped column is not available in the new query, leading to WHERE clause error + # sample query: 'select "b" from (select "a" as "c", "b" from table where "c" > 1)' cannot be flattened to 'select "b" from table where "c" > 1' + context._is_snowpark_connect_compatible_mode + and new_column_states.dropped_columns and any( - new_column_states[_col].change_state == ColumnChangeState.DROPPED - and self.column_states[_col].change_state + self.column_states[_col].change_state in (ColumnChangeState.NEW, ColumnChangeState.CHANGED_EXP) - and _col in subquery_dependent_columns - for _col in (new_column_states.dropped_columns) + for _col in ( + subquery_dependent_columns & new_column_states.dropped_columns + ) ) ) ): - # or (new_column_states[_col].change_state == ColumnChangeState.DROPPED and 
self.column_states[_col].change_state in (ColumnChangeState.NEW, ColumnChangeState.CHANGED_EXP)) can_be_flattened = False elif self.order_by and ( (subquery_dependent_columns := derive_dependent_columns(*self.order_by)) @@ -1400,13 +1404,17 @@ def select(self, cols: List[Expression]) -> "SelectStatement": ) ) or ( - new_column_states.dropped_columns + # unflattenable condition: dropped column is used in subquery ORDER BY clause and dropped column status is NEW or CHANGED in the subquery + # reason: we should not flatten because the dropped column is not available in the new query, leading to ORDER BY clause error + # sample query: 'select "b" from (select "a" as "c", "b" order by "c")' can not be flatten to 'select "b" from table order by "c"' + context._is_snowpark_connect_compatible_mode + and new_column_states.dropped_columns and any( - new_column_states[_col].change_state == ColumnChangeState.DROPPED - and self.column_states[_col].change_state + self.column_states[_col].change_state in (ColumnChangeState.NEW, ColumnChangeState.CHANGED_EXP) - and _col in subquery_dependent_columns - for _col in (new_column_states.dropped_columns) + for _col in ( + subquery_dependent_columns & new_column_states.dropped_columns + ) ) ) ): @@ -1478,6 +1486,10 @@ def filter(self, col: Expression) -> "SelectStatement": derive_dependent_columns(col), self.column_states, "filter" ) and not has_data_generator_or_window_function_exp(self.projection) + and not ( + context._is_snowpark_connect_compatible_mode + and has_aggregation_function_exp(self.projection) + ) # sum(col) as new_col, new_col can not be flattened in where clause and not (self.order_by and self.limit_ is not None) ) if can_be_flattened: @@ -2044,10 +2056,10 @@ def can_clause_dependent_columns_flatten( subquery_column_states: ColumnStateDict, clause: Literal["filter", "sort"], ) -> bool: - if clause not in ["filter", "sort"]: - raise ValueError( - f"Invalid clause called in can_clause_dependent_columns_flatten: {clause}" - 
) + assert clause in ( + "filter", + "sort", + ), f"Invalid clause called in can_clause_dependent_columns_flatten: {clause}" if dependent_columns == COLUMN_DEPENDENCY_DOLLAR: return False elif ( @@ -2061,11 +2073,19 @@ def can_clause_dependent_columns_flatten( for dc in dependent_columns: dc_state = subquery_column_states.get(dc) if dc_state: - if ( - dc_state.change_state == ColumnChangeState.CHANGED_EXP - and clause == "filter" - ): - return False + if dc_state.change_state == ColumnChangeState.CHANGED_EXP: + if ( + clause == "filter" + ): # where can not be flattened because 'where' is evaluated before projection, flattening leads to wrong result + # df.select((col('a') + 1).alias('a')).filter(col('a') > 5) -- this should be applied to the new 'a', flattening will use the old 'a' to evaluated + return False + else: # clause == 'sort' + # df.select((col('a') + 1).alias('a')).sort(col('a')) -- this is valid to flatten because 'order by' is evaluated after projection + # however, if the order by is a data generator, it should not be flattened because generator is evaluated dynamically according to the order. 
+ return context._is_snowpark_connect_compatible_mode + elif dc_state.change_state == ColumnChangeState.NEW: + return context._is_snowpark_connect_compatible_mode + return True @@ -2286,6 +2306,10 @@ def has_data_generator_exp(expressions: Optional[List["Expression"]]) -> bool: if expressions is None: return False for exp in expressions: + if not context._is_snowpark_connect_compatible_mode and isinstance( + exp, WindowExpression + ): + return True if isinstance(exp, FunctionExpression) and ( exp.is_data_generator or exp.name.lower() in SEQUENCE_DEPENDENT_DATA_GENERATION @@ -2311,4 +2335,19 @@ def has_window_function_exp(expressions: Optional[List["Expression"]]) -> bool: def has_data_generator_or_window_function_exp( expressions: Optional[List["Expression"]], ) -> bool: + if not context._is_snowpark_connect_compatible_mode: + return has_data_generator_exp(expressions) return has_data_generator_exp(expressions) or has_window_function_exp(expressions) + + +def has_aggregation_function_exp(expressions: Optional[List["Expression"]]) -> bool: + if expressions is None: + return False + for exp in expressions: + if isinstance(exp, FunctionExpression) and ( + exp.name.lower() in context._aggregation_function_set + ): + return True + if exp is not None and has_aggregation_function_exp(exp.children): + return True + return False diff --git a/src/snowflake/snowpark/context.py b/src/snowflake/snowpark/context.py index 86e92b6aa4..cffd79fc52 100644 --- a/src/snowflake/snowpark/context.py +++ b/src/snowflake/snowpark/context.py @@ -31,6 +31,7 @@ # This is an internal-only global flag, used to determine whether the api code which will be executed is compatible with snowflake.snowpark_connect _is_snowpark_connect_compatible_mode = False +_aggregation_function_set = set() # Following are internal-only global flags, used to enable development features. 
_enable_dataframe_trace_on_error = False diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 1afe626720..1067fd4471 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -518,6 +518,20 @@ def create(self) -> "Session": _add_session(session) else: session = self._create_internal(self._options.get("connection")) + if context._is_snowpark_connect_compatible_mode: + for sql in [ + """select function_name from information_schema.functions where is_aggregate = 'YES'""", + """show functions ->> select "name" from $1 where "is_aggregate" = 'Y'""", + ]: + try: + context._aggregation_function_set.update( + {r[0] for r in session.sql(sql).collect()} + ) + except BaseException as e: + _logger.debug( + "Unable to get aggregation functions from the database: %s", + e, + ) if self._app_name: if self._format_json: From 43c3316fa4f6189bbd8caee271a3bce6514d704f Mon Sep 17 00:00:00 2001 From: Adam Ling Date: Fri, 31 Oct 2025 12:12:59 -0700 Subject: [PATCH 6/8] fix tests --- .../_internal/analyzer/select_statement.py | 4 + src/snowflake/snowpark/context.py | 5 +- src/snowflake/snowpark/session.py | 39 ++-- tests/integ/test_query_line_intervals.py | 24 ++- tests/integ/test_simplifier_suite.py | 176 +++++++++++++++--- 5 files changed, 207 insertions(+), 41 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/select_statement.py b/src/snowflake/snowpark/_internal/analyzer/select_statement.py index 36bb6f125d..c0bf4ce207 100644 --- a/src/snowflake/snowpark/_internal/analyzer/select_statement.py +++ b/src/snowflake/snowpark/_internal/analyzer/select_statement.py @@ -1480,6 +1480,7 @@ def select(self, cols: List[Expression]) -> "SelectStatement": return new def filter(self, col: Expression) -> "SelectStatement": + self._session._retrieve_aggregation_function_list() can_be_flattened = ( (not self.flatten_disabled) and can_clause_dependent_columns_flatten( @@ -1527,6 +1528,9 @@ def sort(self, cols: 
List[Expression]) -> "SelectStatement": derive_dependent_columns(*cols), self.column_states, "sort" ) and not has_data_generator_exp(self.projection) + # we do not check aggregation function here like filter + # in the case when aggregation function is in the projection + # order by is evaluated after aggregation, row info are not taken in the calculation ) if can_be_flattened: new = copy(self) diff --git a/src/snowflake/snowpark/context.py b/src/snowflake/snowpark/context.py index cffd79fc52..ed1d15c5f2 100644 --- a/src/snowflake/snowpark/context.py +++ b/src/snowflake/snowpark/context.py @@ -31,7 +31,10 @@ # This is an internal-only global flag, used to determine whether the api code which will be executed is compatible with snowflake.snowpark_connect _is_snowpark_connect_compatible_mode = False -_aggregation_function_set = set() +_aggregation_function_set = ( + set() +) # lower cased names of aggregation functions, used in sql simplification +_aggregation_function_set_lock = threading.RLock() # Following are internal-only global flags, used to enable development features. 
_enable_dataframe_trace_on_error = False diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 1067fd4471..6fea3308ae 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -518,20 +518,6 @@ def create(self) -> "Session": _add_session(session) else: session = self._create_internal(self._options.get("connection")) - if context._is_snowpark_connect_compatible_mode: - for sql in [ - """select function_name from information_schema.functions where is_aggregate = 'YES'""", - """show functions ->> select "name" from $1 where "is_aggregate" = 'Y'""", - ]: - try: - context._aggregation_function_set.update( - {r[0] for r in session.sql(sql).collect()} - ) - except BaseException as e: - _logger.debug( - "Unable to get aggregation functions from the database: %s", - e, - ) if self._app_name: if self._format_json: @@ -4874,6 +4860,31 @@ def _execute_sproc_internal( # Note the collect is implicit within the stored procedure call, so should not emit_ast here. 
return df.collect(statement_params=statement_params, _emit_ast=False)[0][0] + def _retrieve_aggregation_function_list(self) -> None: + """Retrieve the list of aggregation functions which will later be used in sql simplifier.""" + if ( + not context._is_snowpark_connect_compatible_mode + or context._aggregation_function_set + ): + return + + retrieved_set = set() + + for sql in [ + """select function_name from information_schema.functions where is_aggregate = 'YES'""", + """show functions ->> select "name" from $1 where "is_aggregate" = 'Y'""", + ]: + try: + retrieved_set.update({r[0].lower() for r in self.sql(sql).collect()}) + except BaseException as e: + _logger.debug( + "Unable to get aggregation functions from the database: %s", + e, + ) + + with context._aggregation_function_set_lock: + context._aggregation_function_set.update(retrieved_set) + def directory(self, stage_name: str, _emit_ast: bool = True) -> DataFrame: """ Returns a DataFrame representing the results of a directory table query on the specified stage. 
diff --git a/tests/integ/test_query_line_intervals.py b/tests/integ/test_query_line_intervals.py index d60f951bf6..a95fec4855 100644 --- a/tests/integ/test_query_line_intervals.py +++ b/tests/integ/test_query_line_intervals.py @@ -57,8 +57,9 @@ def generate_test_data(session, sql_simplifier_enabled): } +@pytest.mark.parametrize("snowpark_connect_compatible_mode", [True, False]) @pytest.mark.parametrize( - "op,sql_simplifier,line_to_expected_sql", + "op,sql_simplifier,line_to_expected_sql,snowpark_connect_compatible_mode_sql", [ ( lambda data: data["df1"].union(data["df2"]), @@ -68,10 +69,14 @@ def generate_test_data(session, sql_simplifier_enabled): 6: 'SELECT $1 AS "_1", $2 AS "_2", $3 AS "_3" FROM VALUES (1 :: INT, \'A\' :: STRING, 100 :: INT), (2 :: INT, \'B\' :: STRING, 200 :: INT)', 10: 'SELECT "_1" AS "ID", "_2" AS "NAME", "_3" AS "VALUE" FROM ( SELECT $1 AS "_1", $2 AS "_2", $3 AS "_3" FROM VALUES (3 :: INT, \'C\' :: STRING, 300 :: INT), (4 :: INT, \'D\' :: STRING, 400 :: INT) )', }, + None, ), ( lambda data: data["df1"].filter(data["df1"].value > 150), True, + { + 8: 'SELECT $1 AS "_1", $2 AS "_2", $3 AS "_3" FROM VALUES (1 :: INT, \'A\' :: STRING, 100 :: INT), (2 :: INT, \'B\' :: STRING, 200 :: INT)' + }, { 8: """SELECT "_1" AS "ID", "_2" AS "NAME", "_3" AS "VALUE" FROM (SELECT $1 AS "_1", $2 AS "_2", $3 AS "_3" FROM VALUES (1 :: INT, 'A' :: STRING, 100 :: INT), (2 :: INT, 'B' :: STRING, 200 :: INT)) WHERE ("VALUE" > 150)""", }, @@ -83,6 +88,7 @@ def generate_test_data(session, sql_simplifier_enabled): 1: 'SELECT "_1" AS "ID", "_2" AS "NAME" FROM ( SELECT $1 AS "_1", $2 AS "_2", $3 AS "_3" FROM VALUES (1 :: INT, \'A\' :: STRING, 100 :: INT), (2 :: INT, \'B\' :: STRING, 200 :: INT) )', 4: 'SELECT $1 AS "_1", $2 AS "_2", $3 AS "_3" FROM VALUES (1 :: INT, \'A\' :: STRING, 100 :: INT), (2 :: INT, \'B\' :: STRING, 200 :: INT)', }, + None, ), ( lambda data: data["df1"].pivot(F.col("name")).sum(F.col("value")), @@ -92,12 +98,26 @@ def generate_test_data(session, 
sql_simplifier_enabled): 6: 'SELECT $1 AS "_1", $2 AS "_2", $3 AS "_3" FROM VALUES (1 :: INT, \'A\' :: STRING, 100 :: INT), (2 :: INT, \'B\' :: STRING, 200 :: INT)', 9: 'SELECT * FROM ( SELECT "_1" AS "ID", "_2" AS "NAME", "_3" AS "VALUE" FROM ( SELECT $1 AS "_1", $2 AS "_2", $3 AS "_3" FROM VALUES (1 :: INT, \'A\' :: STRING, 100 :: INT), (2 :: INT, \'B\' :: STRING, 200 :: INT) ) ) PIVOT ( sum("VALUE") FOR "NAME" IN ( ANY ) )', }, + None, ), ], ) def test_get_plan_from_line_numbers_sql_content( - session, op, sql_simplifier, line_to_expected_sql + session, + op, + sql_simplifier, + line_to_expected_sql, + snowpark_connect_compatible_mode_sql, + snowpark_connect_compatible_mode, + monkeypatch, ): + if snowpark_connect_compatible_mode: + import snowflake.snowpark.context as context + + monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", True) + line_to_expected_sql = ( + snowpark_connect_compatible_mode_sql or line_to_expected_sql + ) session.sql_simplifier_enabled = sql_simplifier df = op(generate_test_data(session, sql_simplifier)) diff --git a/tests/integ/test_simplifier_suite.py b/tests/integ/test_simplifier_suite.py index bc9d08b068..044e16f4f8 100644 --- a/tests/integ/test_simplifier_suite.py +++ b/tests/integ/test_simplifier_suite.py @@ -737,7 +737,19 @@ def test_reference_non_exist_columns(session, simplifier_table): df.select(col("c") + 1).collect() -def test_order_by(setup_reduce_cast, session, simplifier_table): +@pytest.mark.parametrize("is_snowpark_connect_compatible_mode", [True, False]) +def test_order_by( + setup_reduce_cast, + session, + simplifier_table, + is_snowpark_connect_compatible_mode, + monkeypatch, +): + if is_snowpark_connect_compatible_mode: + import snowflake.snowpark.context as context + + monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", True) + df = session.table(simplifier_table) # flatten @@ -755,24 +767,42 @@ def test_order_by(setup_reduce_cast, session, simplifier_table): f'SELECT "A", "B" FROM 
{simplifier_table} ORDER BY "A" ASC NULLS FIRST, "B" ASC NULLS FIRST' ) - # flatten if a new column is used in the order by clause + # snowpark connect compatible mode: flatten if a new column is used in the order by clause + # snowflake mode: no flatten because c is a new column df3 = df.select("a", "b", (col("a") - col("b")).as_("c")).sort("a", "b", "c") - assert Utils.normalize_sql(df3.queries["queries"][-1]) == Utils.normalize_sql( + compare_sql = ( f'SELECT "A", "B", ("A" - "B") AS "C" FROM {simplifier_table} ORDER BY "A" ASC NULLS FIRST, "B" ASC NULLS FIRST, "C" ASC NULLS FIRST' + if is_snowpark_connect_compatible_mode + else f'SELECT * FROM ( SELECT "A", "B", ("A" - "B") AS "C" FROM {simplifier_table} ) ORDER BY "A" ASC NULLS FIRST, "B" ASC NULLS FIRST, "C" ASC NULLS FIRST' + ) + assert Utils.normalize_sql(df3.queries["queries"][-1]) == Utils.normalize_sql( + compare_sql ) - # still flatten even if a is changed because it's used in the order by clause + # snowpark connect compatible mode: flatten even if a is changed because it's used in the order by clause + # snowflake mode: no flatten because a and be are changed df4 = df.select((col("a") + 1).as_("a"), ((col("b") + 1).as_("b"))).sort("a", "b") - assert Utils.normalize_sql(df4.queries["queries"][-1]) == Utils.normalize_sql( + compare_sql = ( f'SELECT ("A" + 1{integer_literal_postfix}) AS "A", ("B" + 1{integer_literal_postfix}) AS "B" FROM {simplifier_table} ORDER BY "A" ASC NULLS FIRST, "B" ASC NULLS FIRST' + if is_snowpark_connect_compatible_mode + else f'SELECT * FROM ( SELECT ("A" + 1{integer_literal_postfix}) AS "A", ("B" + 1{integer_literal_postfix}) AS "B" FROM {simplifier_table} ) ORDER BY "A" ASC NULLS FIRST, "B" ASC NULLS FIRST' + ) + assert Utils.normalize_sql(df4.queries["queries"][-1]) == Utils.normalize_sql( + compare_sql ) - # still flatten if a window function is used in the projection + # snowpark connect compatible mode: flatten if a window function is used in the projection + # snowflake 
mode: no flatten because c is a new column df5 = df.select("a", "b", rank().over(Window.order_by("b")).alias("c")).sort( "a", "b" ) - assert Utils.normalize_sql(df5.queries["queries"][-1]) == Utils.normalize_sql( + compare_sql = ( f'SELECT "A", "B", rank() OVER (ORDER BY "B" ASC NULLS FIRST) AS "C" FROM {simplifier_table} ORDER BY "A" ASC NULLS FIRST, "B" ASC NULLS FIRST' + if is_snowpark_connect_compatible_mode + else f'SELECT * FROM ( SELECT "A", "B", rank() OVER (ORDER BY "B" ASC NULLS FIRST) AS "C" FROM {simplifier_table} ) ORDER BY "A" ASC NULLS FIRST, "B" ASC NULLS FIRST' + ) + assert Utils.normalize_sql(df5.queries["queries"][-1]) == Utils.normalize_sql( + compare_sql ) # No flatten if a data generator is used in the projection @@ -787,8 +817,30 @@ def test_order_by(setup_reduce_cast, session, simplifier_table): f'SELECT * FROM ( SELECT "A", "B", 3 :: INT AS "C", 1 + 1 as d FROM ( SELECT * FROM {simplifier_table} ) ) ORDER BY "A" ASC NULLS FIRST, "B" ASC NULLS FIRST' ) + df8 = df.select("a", sum_(col("b")).alias("c")).sort("c") + compare_sql = ( + f'SELECT "A", sum("B") AS "C" FROM {simplifier_table} ORDER BY "C" ASC NULLS FIRST' + if is_snowpark_connect_compatible_mode + else f'SELECT * FROM ( SELECT "A", sum("B") AS "C" FROM {simplifier_table} ) ORDER BY "C" ASC NULLS FIRST' + ) + assert Utils.normalize_sql(df8.queries["queries"][-1]) == Utils.normalize_sql( + compare_sql + ) + + +@pytest.mark.parametrize("is_snowpark_connect_compatible_mode", [True, False]) +def test_filter( + setup_reduce_cast, + session, + simplifier_table, + is_snowpark_connect_compatible_mode, + monkeypatch, +): + if is_snowpark_connect_compatible_mode: + import snowflake.snowpark.context as context + + monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", True) -def test_filter(setup_reduce_cast, session, simplifier_table): df = session.table(simplifier_table) integer_literal_postfix = ( "" if session.eliminate_numeric_sql_value_cast_enabled else " :: INT" @@ -808,18 
+860,15 @@ def test_filter(setup_reduce_cast, session, simplifier_table): # flatten if a regular new column is in the projection df3 = df.select("a", "b", (col("a") - col("b")).as_("c")).filter( - (col("a") > 1) & (col("b") > 2) - ) - assert Utils.normalize_sql(df3.queries["queries"][-1]) == Utils.normalize_sql( - f'SELECT "A", "B", ("A" - "B") AS "C" FROM {simplifier_table} WHERE (("A" > 1{integer_literal_postfix}) AND ("B" > 2{integer_literal_postfix}))' - ) - - # flatten if a regular new column is used in the filter clause - df4 = df.select("a", "b", (col("a") - col("b")).as_("c")).filter( (col("a") > 1) & (col("b") > 2) & (col("c") < 1) ) - assert Utils.normalize_sql(df4.queries["queries"][-1]) == Utils.normalize_sql( + compare_sql = ( f'SELECT "A", "B", ("A" - "B") AS "C" FROM {simplifier_table} WHERE ((("A" > 1{integer_literal_postfix}) AND ("B" > 2{integer_literal_postfix})) AND ("C" < 1{integer_literal_postfix}))' + if is_snowpark_connect_compatible_mode + else f'SELECT * FROM ( SELECT "A", "B", ("A" - "B") AS "C" FROM {simplifier_table} ) WHERE ((("A" > 1{integer_literal_postfix}) AND ("B" > 2{integer_literal_postfix})) AND ("C" < 1{integer_literal_postfix}))' + ) + assert Utils.normalize_sql(df3.queries["queries"][-1]) == Utils.normalize_sql( + compare_sql ) # no flatten if a window function is used in the projection @@ -859,6 +908,12 @@ def test_filter(setup_reduce_cast, session, simplifier_table): f'SELECT * FROM ( SELECT "A", "B", 3 :: INT AS "C", 1 + 1 as d FROM ( SELECT * FROM {simplifier_table} ) ) WHERE ("A" > 1{integer_literal_postfix})' ) + # no flatten if a aggregation function is used in the projection + df10 = df.select("a", sum_(col("b")).alias("c")).filter(col("c") < 1) + assert Utils.normalize_sql(df10.queries["queries"][-1]) == Utils.normalize_sql( + f'SELECT * FROM ( SELECT "A", sum("B") AS "C" FROM {simplifier_table} ) WHERE ("C" < 1{integer_literal_postfix})' + ) + def test_limit(setup_reduce_cast, session, simplifier_table): df = 
session.table(simplifier_table) @@ -1630,18 +1685,21 @@ def test_chained_sort(session): ) +@pytest.mark.parametrize("snowpark_connect_compatible_mode", [True, False]) @pytest.mark.parametrize( - "operation,simplified_query", + "operation,simplified_query,snowpark_connect_simplified_query", [ # Flattened ( lambda df: df.filter(col("A") > 1).select(col("B") + 1), 'SELECT ("B" + 1{POSTFIX}) FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT) ) WHERE ("A" > 1{POSTFIX})', + None, ), # Flattened, if there are duplicate column names across the parent/child, WHERE is evaluated on subquery first, so we could flatten in this case ( lambda df: df.filter(col("A") > 1).select((col("B") + 1).alias("A")), 'SELECT ("B" + 1{POSTFIX}) AS "A" FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT) ) WHERE ("A" > 1{POSTFIX})', + None, ), # Flattened ( @@ -1649,21 +1707,26 @@ def test_chained_sort(session): .select(col("A"), col("B"), lit(12).alias("TWELVE")) .filter(col("A") > 2), 'SELECT "A", "B", 12 :: INT AS "TWELVE" FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT) ) WHERE (("A" > 1{POSTFIX}) AND ("A" > 2{POSTFIX}))', + None, ), - # Flattened if the dropped columns are not used in filter + # Flattened if the dropped columns are not used in filter in snowpark connect compatible mode + # Notice the local inner flattening happening to WHERE clauses because in snowpark connect compatible mode, NEW column "D" can be flattened into the new query ( lambda df: df.filter(col("A") >= 1) .select(col("A").alias("C"), col("B").alias("D")) .filter(col("C") > 2) .select(col("C")), + 'SELECT "C" FROM (SELECT "A" AS "C", "B" AS "D" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT)) WHERE ("A" >= 1{POSTFIX})) WHERE ("C" > 2{POSTFIX})', 'SELECT "A" AS "C" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: 
INT)) WHERE (("A" >= 1{POSTFIX}) AND ("C" > 2{POSTFIX}))', ), - # Flattened if the dropped columns are not in the filter clause's dependent columns + # Flattened if the dropped columns are not in the filter clause's dependent columns in snowpark connect compatible mode + # Notice the local inner flattening happening to WHERE clauses because in snowpark connect compatible mode, NEW column "D" can be flattened into the new query ( lambda df: df.filter(col("A") >= 1) .select(col("A").alias("C"), col("B").alias("D")) .filter((col("C") + 1) > 2) .select(col("C")), + 'SELECT "C" FROM (SELECT "A" AS "C", "B" AS "D" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT)) WHERE ("A" >= 1{POSTFIX})) WHERE (("C" + 1{POSTFIX}) > 2{POSTFIX})', 'SELECT "A" AS "C" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT)) WHERE (("A" >= 1{POSTFIX}) AND (("C" + 1{POSTFIX}) > 2{POSTFIX}))', ), # Not fully flattened, since col("A") > 1 and col("A") > 2 are referring to different columns @@ -1672,36 +1735,44 @@ def test_chained_sort(session): .select((col("B") + 1).alias("A")) .filter(col("A") > 2), 'SELECT * FROM ( SELECT ("B" + 1{POSTFIX}) AS "A" FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT) ) WHERE ("A" > 1{POSTFIX}) ) WHERE ("A" > 2{POSTFIX})', + None, ), # Not flattened, since A is updated in the select after filter. 
( lambda df: df.filter(col("A") > 1).select("A", seq1(0)), 'SELECT "A", seq1(0) FROM ( SELECT "A", "B" FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT) ) WHERE ("A" > 1{POSTFIX}) )', + None, ), # Not flattened, since we cannot detect dependent columns from sql_expr ( lambda df: df.filter(sql_expr("A > 1")).select(col("B"), col("A")), 'SELECT "B", "A" FROM ( SELECT "A", "B" FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT) ) WHERE A > 1 )', + None, ), # Not flattened, since we cannot flatten when the subquery uses positional parameter ($1) ( lambda df: df.filter(col("$1") > 1).select(col("B"), col("A")), 'SELECT "B", "A" FROM ( SELECT * FROM ( SELECT "A", "B" FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT) ) ) WHERE ("$1" > 1{POSTFIX}) )', + None, ), # Not flattened if a dropped column is used in the filter clause + # Notice the local inner flattening happening to WHERE clauses because in snowpark connect compatible mode, NEW column "D" can be flattened into the new query ( lambda df: df.filter(col("A") >= 1) .select(col("A"), col("B").alias("D")) .filter(col("D") > -3) .select(col("A").alias("E")), + 'SELECT "A" AS "E" FROM (SELECT "A", "B" AS "D" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT)) WHERE ("A" >= 1{POSTFIX})) WHERE ("D" > -3{POSTFIX})', 'SELECT "A" AS "E" FROM (SELECT "A", "B" AS "D" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT)) WHERE (("A" >= 1{POSTFIX}) AND ("D" > -3{POSTFIX})))', ), # Not flattened if a dropped column is used in the select clause's dependent columns + # Notice the local inner flattening happening to WHERE clauses because in snowpark connect compatible mode, NEW column "D" can be flattened into the new query ( lambda df: df.filter(col("A") >= 1) .select(col("A"), col("B").alias("D")) .filter((col("D") - 1) > -4) 
.select((col("A") + 1).alias("E")), + 'SELECT ("A" + 1{POSTFIX}) AS "E" FROM (SELECT "A", "B" AS "D" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT)) WHERE ("A" >= 1{POSTFIX})) WHERE (("D" - 1{POSTFIX}) > -4{POSTFIX})', 'SELECT ("A" + 1{POSTFIX}) AS "E" FROM (SELECT "A", "B" AS "D" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT)) WHERE (("A" >= 1{POSTFIX}) AND (("D" - 1{POSTFIX}) > -4{POSTFIX})))', ), # Not flattened if a dropped column that was changed expression is used in the select clause's dependent columns @@ -1710,10 +1781,25 @@ def test_chained_sort(session): .filter((col("B") - 1) > -4) .select((col("A") + 1).alias("E")), 'SELECT ("A" + 1{POSTFIX}) AS "E" FROM (SELECT "A", ("B" + 1{POSTFIX}) AS "B" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT))) WHERE (("B" - 1{POSTFIX}) > -4{POSTFIX})', + None, ), ], ) -def test_select_after_filter(setup_reduce_cast, session, operation, simplified_query): +def test_select_after_filter( + setup_reduce_cast, + session, + operation, + simplified_query, + snowpark_connect_compatible_mode, + monkeypatch, + snowpark_connect_simplified_query, +): + if snowpark_connect_compatible_mode: + import snowflake.snowpark.context as context + + monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", True) + simplified_query = snowpark_connect_simplified_query or simplified_query + session.sql_simplifier_enabled = False df1 = session.create_dataframe([[1, -2], [3, -4]], schema=["a", "b"]) @@ -1733,43 +1819,50 @@ def test_select_after_filter(setup_reduce_cast, session, operation, simplified_q ) == Utils.normalize_sql(simplified_query) +@pytest.mark.parametrize("snowpark_connect_compatible_mode", [True, False]) @pytest.mark.parametrize( - "operation,simplified_query,execute_sql", + "operation,simplified_query,snowpark_connect_simplified_query,execute_sql", [ # Flattened ( lambda df: 
df.order_by(col("A")).select(col("B") + 1), 'SELECT ("B" + 1{POSTFIX}) FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT) ) ORDER BY "A" ASC NULLS FIRST', + None, True, ), # Not flattened because SEQ1() is a data generator. ( lambda df: df.order_by(col("A")).select(seq1(0)), 'SELECT seq1(0) FROM ( SELECT "A", "B" FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT) ) ORDER BY "A" ASC NULLS FIRST )', + None, True, ), # Not flattened, unlike filter, current query takes precendence when there are duplicate column names from a ORDERBY clause ( lambda df: df.order_by(col("A")).select((col("B") + 1).alias("A")), 'SELECT ("B" + 1{POSTFIX}) AS "A" FROM ( SELECT "A", "B" FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT) ) ORDER BY "A" ASC NULLS FIRST )', + None, True, ), # Not flattened, since we cannot detect dependent columns from sql_expr ( lambda df: df.order_by(sql_expr("A")).select(col("B")), 'SELECT "B" FROM ( SELECT "A", "B" FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT) ) ORDER BY A ASC NULLS FIRST )', + None, True, ), # Not flattened, since we cannot flatten when the subquery uses positional parameter ($1) ( lambda df: df.order_by(col("$1")).select(col("B")), 'SELECT "B" FROM ( SELECT * FROM ( SELECT "A", "B" FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT) ) ) ORDER BY "$1" ASC NULLS FIRST )', + None, True, ), # Not flattened, skip execution since this would result in SnowparkSQLException ( lambda df: df.order_by(col("C")).select((col("A") + col("B")).alias("C")), 'SELECT ("A" + "B") AS "C" FROM ( SELECT "A", "B" FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT) ) ORDER BY "C" ASC NULLS FIRST )', + None, False, ), # Flattened @@ -1779,13 +1872,15 @@ def test_select_after_filter(setup_reduce_cast, session, operation, 
simplified_q .order_by(col("B")) .select(col("A")), 'SELECT "A" FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT) ) ORDER BY "B" ASC NULLS FIRST, "A" ASC NULLS FIRST', + None, True, ), - # Flattened if the dropped columns are not used in filter + # Flattened if the dropped columns are not used in filter in the snowpark connect compatible mode ( lambda df: df.select(col("A").alias("C"), col("B").alias("D")) .order_by(col("C")) .select(col("C")), + 'SELECT "C" FROM (SELECT "A" AS "C", "B" AS "D" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT))) ORDER BY "C" ASC NULLS FIRST', 'SELECT "A" AS "C" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT)) ORDER BY "C" ASC NULLS FIRST', True, ), @@ -1794,6 +1889,7 @@ def test_select_after_filter(setup_reduce_cast, session, operation, simplified_q lambda df: df.select(col("A").alias("C"), col("B").alias("D")) .order_by(col("C") + 1) .select(col("C")), + 'SELECT "C" FROM (SELECT "A" AS "C", "B" AS "D" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT))) ORDER BY ("C" + 1{POSTFIX}) ASC NULLS FIRST', 'SELECT "A" AS "C" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT)) ORDER BY ("C" + 1{POSTFIX}) ASC NULLS FIRST', True, ), @@ -1802,6 +1898,7 @@ def test_select_after_filter(setup_reduce_cast, session, operation, simplified_q lambda df: df.select(col("A"), col("B").alias("D")) .order_by(col("D")) .select(col("A").alias("E")), + 'SELECT "A" AS "E" FROM (SELECT "A", "B" AS "D" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT))) ORDER BY "D" ASC NULLS FIRST', 'SELECT "A" AS "E" FROM (SELECT "A", "B" AS "D" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT)) ORDER BY "D" ASC NULLS FIRST)', True, ), @@ -1810,6 +1907,7 @@ def 
test_select_after_filter(setup_reduce_cast, session, operation, simplified_q lambda df: df.select(col("A"), col("B").alias("D")) .order_by(col("D") - 1) .select((col("A") + 1).alias("E")), + 'SELECT ("A" + 1{POSTFIX}) AS "E" FROM (SELECT "A", "B" AS "D" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT))) ORDER BY ("D" - 1{POSTFIX}) ASC NULLS FIRST', 'SELECT ("A" + 1{POSTFIX}) AS "E" FROM (SELECT "A", "B" AS "D" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT)) ORDER BY ("D" - 1{POSTFIX}) ASC NULLS FIRST)', True, ), @@ -1819,13 +1917,27 @@ def test_select_after_filter(setup_reduce_cast, session, operation, simplified_q .filter((col("B") - 1) > -4) .select((col("A") + 1).alias("E")), 'SELECT ("A" + 1{POSTFIX}) AS "E" FROM (SELECT "A", ("B" + 1{POSTFIX}) AS "B" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT))) WHERE (("B" - 1{POSTFIX}) > -4{POSTFIX})', + None, True, ), ], ) def test_select_after_orderby( - setup_reduce_cast, session, operation, simplified_query, execute_sql + setup_reduce_cast, + session, + operation, + simplified_query, + execute_sql, + snowpark_connect_compatible_mode, + monkeypatch, + snowpark_connect_simplified_query, ): + if snowpark_connect_compatible_mode: + import snowflake.snowpark.context as context + + monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", True) + simplified_query = snowpark_connect_simplified_query or simplified_query + session.sql_simplifier_enabled = False df1 = session.create_dataframe([[1, -2], [3, -4]], schema=["a", "b"]) @@ -2012,3 +2124,19 @@ def test_select_distinct( ) finally: session.conf.set("use_simplified_query_generation", original) + + +def test_retrieving_aggregation_funcs(session, monkeypatch): + import snowflake.snowpark.context as context + + monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", True) + monkeypatch.setattr(context, 
"_aggregation_function_set", set()) + assert not context._aggregation_function_set + session._retrieve_aggregation_function_list() + assert context._aggregation_function_set + + monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", False) + monkeypatch.setattr(context, "_aggregation_function_set", set()) + assert not context._aggregation_function_set + session._retrieve_aggregation_function_list() + assert not context._aggregation_function_set From 3f7a98e08e2985888fcdf60203d3f6e038e80013 Mon Sep 17 00:00:00 2001 From: Adam Ling Date: Fri, 31 Oct 2025 13:18:06 -0700 Subject: [PATCH 7/8] fix local testing --- src/snowflake/snowpark/mock/_select_statement.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/snowflake/snowpark/mock/_select_statement.py b/src/snowflake/snowpark/mock/_select_statement.py index d21a86aeda..2a149db59e 100644 --- a/src/snowflake/snowpark/mock/_select_statement.py +++ b/src/snowflake/snowpark/mock/_select_statement.py @@ -412,7 +412,7 @@ def filter(self, col: Expression) -> "MockSelectStatement": else: dependent_columns = derive_dependent_columns(col) can_be_flattened = can_clause_dependent_columns_flatten( - dependent_columns, self.column_states + dependent_columns, self.column_states, "filter" ) if can_be_flattened: new = copy(self) @@ -433,7 +433,7 @@ def sort(self, cols: List[Expression]) -> "MockSelectStatement": else: dependent_columns = derive_dependent_columns(*cols) can_be_flattened = can_clause_dependent_columns_flatten( - dependent_columns, self.column_states + dependent_columns, self.column_states, "sort" ) if can_be_flattened: new = copy(self) From 8035696a19c40011d34813dd41be9491b1851328 Mon Sep 17 00:00:00 2001 From: Adam Ling Date: Fri, 31 Oct 2025 16:31:22 -0700 Subject: [PATCH 8/8] update --- .../_internal/analyzer/select_statement.py | 99 +++++++++++++------ src/snowflake/snowpark/session.py | 4 + 2 files changed, 72 insertions(+), 31 deletions(-) diff --git 
a/src/snowflake/snowpark/_internal/analyzer/select_statement.py b/src/snowflake/snowpark/_internal/analyzer/select_statement.py index c0bf4ce207..ca0b677a78 100644 --- a/src/snowflake/snowpark/_internal/analyzer/select_statement.py +++ b/src/snowflake/snowpark/_internal/analyzer/select_statement.py @@ -2306,52 +2306,89 @@ def derive_column_states_from_subquery( return column_states -def has_data_generator_exp(expressions: Optional[List["Expression"]]) -> bool: +def _check_expressions_for_types( + expressions: Optional[List["Expression"]], + check_data_gen: bool = False, + check_window: bool = False, + check_aggregation: bool = False, +) -> bool: + """Efficiently check if expressions contain specific types in a single pass. + + Args: + expressions: List of expressions to check + check_data_gen: Check for data generator functions + check_window: Check for window functions + check_aggregation: Check for aggregation functions + + Returns: + True if any requested type is found + """ if expressions is None: return False + for exp in expressions: - if not context._is_snowpark_connect_compatible_mode and isinstance( - exp, WindowExpression - ): + if exp is None: + continue + + # Check window functions + if check_window and isinstance(exp, WindowExpression): return True - if isinstance(exp, FunctionExpression) and ( - exp.is_data_generator - or exp.name.lower() in SEQUENCE_DEPENDENT_DATA_GENERATION + + # Check data generators (including window in non-connect mode) + if check_data_gen: + # In non-connect mode, windows are treated as data generators + if not context._is_snowpark_connect_compatible_mode and isinstance( + exp, WindowExpression + ): + return True + # Check actual data generator functions + if isinstance(exp, FunctionExpression) and ( + exp.is_data_generator + or exp.name.lower() in SEQUENCE_DEPENDENT_DATA_GENERATION + ): + # https://docs.snowflake.com/en/sql-reference/functions-data-generation + return True + + # Check aggregation functions + if 
check_aggregation and isinstance(exp, FunctionExpression): + if exp.name.lower() in context._aggregation_function_set: + return True + + # Recursively check children + if _check_expressions_for_types( + exp.children, check_data_gen, check_window, check_aggregation ): - # https://docs.snowflake.com/en/sql-reference/functions-data-generation - return True - if exp is not None and has_data_generator_exp(exp.children): return True + return False -def has_window_function_exp(expressions: Optional[List["Expression"]]) -> bool: - if expressions is None: - return False - for exp in expressions: - if isinstance(exp, WindowExpression): - return True - if exp is not None and has_window_function_exp(exp.children): - return True - return False +def has_data_generator_exp(expressions: Optional[List["Expression"]]) -> bool: + """Check if expressions contain data generator functions. + + Note: + In non-connect mode, check_data_gen check both data generator and window expressions for backward compatibility. + In connect mode, check_data_gen only checks data generator expressions. + """ + return _check_expressions_for_types(expressions, check_data_gen=True) def has_data_generator_or_window_function_exp( expressions: Optional[List["Expression"]], ) -> bool: + """Check if expressions contain data generators or window functions. + + Optimized to do a single pass checking both types simultaneously. 
+ """ if not context._is_snowpark_connect_compatible_mode: - return has_data_generator_exp(expressions) - return has_data_generator_exp(expressions) or has_window_function_exp(expressions) + # In non-connect mode, windows are already treated as data generators + return _check_expressions_for_types(expressions, check_data_gen=True) + # In connect mode, check both in a single pass + return _check_expressions_for_types( + expressions, check_data_gen=True, check_window=True + ) def has_aggregation_function_exp(expressions: Optional[List["Expression"]]) -> bool: - if expressions is None: - return False - for exp in expressions: - if isinstance(exp, FunctionExpression) and ( - exp.name.lower() in context._aggregation_function_set - ): - return True - if exp is not None and has_aggregation_function_exp(exp.children): - return True - return False + """Check if expressions contain aggregation functions.""" + return _check_expressions_for_types(expressions, check_aggregation=True) diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 6fea3308ae..85915c082b 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -4881,6 +4881,10 @@ def _retrieve_aggregation_function_list(self) -> None: "Unable to get aggregation functions from the database: %s", e, ) + # we raise error here as a pessimistic tactics + # the reason is that if we fail to retrieve the aggregation function list, we have empty set + # the simplifier will flatten the query which contains aggregation functions leading to incorrect results + raise with context._aggregation_function_set_lock: context._aggregation_function_set.update(retrieved_set)