From e085f8307badf6a9786b84495a12fb2ca3b57a44 Mon Sep 17 00:00:00 2001 From: John Kew Date: Fri, 24 Oct 2025 11:02:12 -0700 Subject: [PATCH 1/9] Failed tests --- tests/integ/modin/frame/test_apply.py | 31 +++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/tests/integ/modin/frame/test_apply.py b/tests/integ/modin/frame/test_apply.py index 4938e695d9..e4243bae4a 100644 --- a/tests/integ/modin/frame/test_apply.py +++ b/tests/integ/modin/frame/test_apply.py @@ -147,6 +147,37 @@ def foo(row) -> str: with SqlCounter(query_count=4, join_count=0, udtf_count=0): eval_snowpark_pandas_result(snow_df, df, lambda x: x.apply(foo, axis=1)) +@pytest.mark.parametrize("index", [ + None, + ['a', 'b'], + [100, 200] +]) +@sql_count_checker(query_count=5, join_count=2, udtf_count=1) +def test_apply_axis_1_index_preservation(index): + """Test that apply(axis=1) preserves index values correctly.""" + # Test with default RangeIndex + native_df = native_pd.DataFrame([[1, 2], [3, 4]], index=index) + snow_df = pd.DataFrame(native_df) + + eval_snowpark_pandas_result( + snow_df, native_df, lambda x: x.apply(lambda row : row.name, axis=1) + ) + + +@sql_count_checker(query_count=4, join_count=0, udf_count=1) +def test_apply_axis_1_multiindex_preservation(): + """Test that apply(axis=1) preserves MultiIndex values correctly.""" + # Test with MultiIndex + multi_index = pd.MultiIndex.from_tuples([('A', 1), ('B', 2), ('C', 3)], names=['letter', 'number']) + native_df = native_pd.DataFrame([[1, 2], [3, 4], [5, 6]], index=multi_index) + snow_df = pd.DataFrame(native_df) + + + eval_snowpark_pandas_result( + snow_df, native_df, lambda x: x.apply(lambda row : row.name, axis=1) + ) + + @pytest.mark.xfail(strict=True, raises=NotImplementedError) @sql_count_checker(query_count=0) From ab4a197c58a57473d0acaaef481815ce34d9f142 Mon Sep 17 00:00:00 2001 From: John Kew Date: Fri, 24 Oct 2025 12:15:10 -0700 Subject: [PATCH 2/9] Tests partially pass test_apply --- .../modin/plugin/_internal/apply_utils.py | 48 +++++++++++++++---- .../compiler/snowflake_query_compiler.py | 35 +++++++++++--- tests/integ/modin/frame/test_apply.py | 2 +- 3 files changed, 70 insertions(+), 15 deletions(-) diff --git a/src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py index ef8b4aa33e..7ab42c4a06 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py @@ -424,6 +424,8 @@ def create_udtf_for_apply_axis_1( column_index: native_pd.Index, input_types: list[DataType], session: Session, + index_column_pandas_labels: list[Hashable] | None = None, + num_index_columns: int = 0, **kwargs: Any, ) -> UserDefinedTableFunction: """ @@ -443,7 +445,9 @@ def create_udtf_for_apply_axis_1( result_type: pandas parameter controlling apply within the UDTF. args: pandas parameter controlling apply within the UDTF. column_index: The columns of the callee DataFrame, i.e. df.columns as pd.Index object. - input_types: Snowpark column types of the input data columns. + input_types: Snowpark column types of the input data columns (including index columns). + index_column_pandas_labels: The pandas labels for the index columns, if any. + num_index_columns: Number of index columns being passed into the UDTF. **kwargs: pandas parameter controlling apply within the UDTF. Returns: @@ -458,9 +462,33 @@ def create_udtf_for_apply_axis_1( class ApplyFunc: def end_partition(self, df): # type: ignore[no-untyped-def] # pragma: no cover - # First column is row position, set as index. - df = df.set_index(df.columns[0]) - + # First column is row position, extract it for later use + row_positions = df.iloc[:, 0] + + # If we have index columns, set them as the index + if num_index_columns > 0: + # Columns after row position are index columns, then data columns + index_cols = df.iloc[:, 1:1+num_index_columns] + data_cols = df.iloc[:, 1+num_index_columns:] + + # Set the index using the index columns + if num_index_columns == 1: + index = index_cols.iloc[:, 0] + if index_column_pandas_labels: + index.name = index_column_pandas_labels[0] + else: + # Multi-index case + index = native_pd.MultiIndex.from_arrays( + [index_cols.iloc[:, i] for i in range(num_index_columns)], + names=index_column_pandas_labels if index_column_pandas_labels else None + ) + data_cols.index = index + df = data_cols + else: + # No index columns, use row position as index (original behavior) + df = df.iloc[:, 1:] + df.index = row_positions + df.columns = column_index df = df.apply( func, axis=1, raw=raw, result_type=result_type, args=args, **kwargs @@ -495,8 +523,10 @@ def end_partition(self, df): # type: ignore[no-untyped-def] # pragma: no cover # - VALUE contains the result at this position. if isinstance(df, native_pd.DataFrame): result = [] - for row_position_index, series in df.iterrows(): - + for idx, (row_position_index, series) in enumerate(df.iterrows()): + # Use the actual row position from row_positions, not the index value + actual_row_position = row_positions.iloc[idx] + for i, (label, value) in enumerate(series.items()): # If this is a tuple then we store each component with a 0-based # lookup. For example, (a,b,c) is stored as (0:a, 1:b, 2:c). @@ -508,7 +538,7 @@ def end_partition(self, df): # type: ignore[no-untyped-def] # pragma: no cover obj_label["pos"] = i result.append( [ - row_position_index, + actual_row_position, json.dumps(obj_label), value, ] @@ -531,7 +561,9 @@ def end_partition(self, df): # type: ignore[no-untyped-def] # pragma: no cover elif isinstance(df, native_pd.Series): result = df.to_frame(name="value") result.insert(0, "label", json.dumps({"0": MODIN_UNNAMED_SERIES_LABEL})) - result.reset_index(names="__row__", inplace=True) + # Use the row_positions (integer positions) rather than the index values + result.insert(0, "__row__", row_positions.values) + result = result[["__row__", "label", "value"]] else: raise TypeError(f"Unsupported data type {df} from df.apply") diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py index 8c54e56e1a..20541ca262 100644 --- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py +++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py @@ -9656,6 +9656,19 @@ def _apply_with_udtf_and_dynamic_pivot_along_axis_1( ) # The apply function is encapsulated in a UDTF and run as a stored procedure on the pandas dataframe. + # Determine if we should pass index columns to the UDTF + # We pass index columns when the index is not the row position itself + index_columns_for_udtf = new_internal_df.index_column_snowflake_quoted_identifiers + if row_position_snowflake_quoted_identifier in index_columns_for_udtf: + # The row position IS the index (e.g., RangeIndex), don't pass index columns + index_columns_for_udtf = [] + num_index_columns = 0 + index_column_pandas_labels_for_udtf = None + else: + # Pass the actual index columns to the UDTF + num_index_columns = len(index_columns_for_udtf) + index_column_pandas_labels_for_udtf = new_internal_df.index_column_pandas_labels + func_udtf = create_udtf_for_apply_axis_1( row_position_snowflake_quoted_identifier, func, @@ -9665,6 +9678,8 @@ def _apply_with_udtf_and_dynamic_pivot_along_axis_1( column_index, input_types, self._modin_frame.ordered_dataframe.session, + index_column_pandas_labels=index_column_pandas_labels_for_udtf, + num_index_columns=num_index_columns, **kwargs, ) @@ -9704,13 +9719,16 @@ def _apply_with_udtf_and_dynamic_pivot_along_axis_1( _emit_ast=self._modin_frame.ordered_dataframe.session.ast_enabled, ) ).as_(partition_identifier) + # Select columns to pass to UDTF: partition, row_position, index columns (if any), data columns udtf_dataframe = new_internal_df.ordered_dataframe.select( partition_expression, row_position_snowflake_quoted_identifier, + *index_columns_for_udtf, *new_internal_df.data_column_snowflake_quoted_identifiers, ).select( func_udtf( row_position_snowflake_quoted_identifier, + *index_columns_for_udtf, *new_internal_df.data_column_snowflake_quoted_identifiers, ).over(partition_by=[partition_identifier]), ) @@ -10428,11 +10446,6 @@ def wrapped_func(*args, **kwargs): # type: ignore[no-untyped-def] # pragma: no return qc_result else: - # get input types of all data columns from the dataframe directly - input_types = self._modin_frame.get_snowflake_type( - self._modin_frame.data_column_snowflake_quoted_identifiers - ) - from snowflake.snowpark.modin.plugin.extensions.utils import ( try_convert_index_to_native, ) @@ -10441,6 +10454,16 @@ def wrapped_func(*args, **kwargs): # type: ignore[no-untyped-def] # pragma: no column_index = try_convert_index_to_native( self._modin_frame.data_columns_index ) + + # get input types of index and data columns from the dataframe + data_input_types = self._modin_frame.get_snowflake_type( + self._modin_frame.data_column_snowflake_quoted_identifiers + ) + index_input_types = self._modin_frame.get_snowflake_type( + self._modin_frame.index_column_snowflake_quoted_identifiers + ) + # Combine index types + data types for UDTF input + input_types = index_input_types + data_input_types # Extract return type from annotations (or lookup for known pandas functions) for func object, # if no return type could be extracted the variable will hold None. @@ -10456,7 +10479,7 @@ def wrapped_func(*args, **kwargs): # type: ignore[no-untyped-def] # pragma: no return self._apply_udf_row_wise_and_reduce_to_series_along_axis_1( func, column_index, - input_types, + data_input_types, return_type, udf_args=args, udf_kwargs=kwargs, diff --git a/tests/integ/modin/frame/test_apply.py b/tests/integ/modin/frame/test_apply.py index e4243bae4a..5894a6fe3a 100644 --- a/tests/integ/modin/frame/test_apply.py +++ b/tests/integ/modin/frame/test_apply.py @@ -164,7 +164,7 @@ def test_apply_axis_1_index_preservation(index): ) -@sql_count_checker(query_count=4, join_count=0, udf_count=1) +@sql_count_checker(query_count=5, join_count=2, udtf_count=1) def test_apply_axis_1_multiindex_preservation(): """Test that apply(axis=1) preserves MultiIndex values correctly.""" # Test with MultiIndex From 47ebc94609a567cecdfd1d911be0873befc7cae1 Mon Sep 17 00:00:00 2001 From: John Kew Date: Fri, 24 Oct 2025 12:20:12 -0700 Subject: [PATCH 3/9] lint --- .../modin/plugin/_internal/apply_utils.py | 18 +++++++++------- .../compiler/snowflake_query_compiler.py | 12 +++++++---- tests/integ/modin/frame/test_apply.py | 21 ++++++++----------- 3 files changed, 27 insertions(+), 24 deletions(-) diff --git a/src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py index 7ab42c4a06..3495702350 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py @@ -464,13 +464,13 @@ class ApplyFunc: def end_partition(self, df): # type: ignore[no-untyped-def] # pragma: no cover # First column is row position, extract it for later use row_positions = df.iloc[:, 0] - + # If we have index columns, set them as the index if num_index_columns > 0: # Columns after row position are index columns, then data columns - index_cols = df.iloc[:, 1:1+num_index_columns] - data_cols = df.iloc[:, 1+num_index_columns:] - + index_cols = df.iloc[:, 1 : 1 + num_index_columns] + data_cols = df.iloc[:, 1 + num_index_columns :] + # Set the index using the index columns if num_index_columns == 1: index = index_cols.iloc[:, 0] @@ -480,7 +480,9 @@ def end_partition(self, df): # type: ignore[no-untyped-def] # pragma: no cover # Multi-index case index = native_pd.MultiIndex.from_arrays( [index_cols.iloc[:, i] for i in range(num_index_columns)], - names=index_column_pandas_labels if index_column_pandas_labels else None + names=index_column_pandas_labels + if index_column_pandas_labels + else None, ) data_cols.index = index df = data_cols @@ -488,7 +490,7 @@ def end_partition(self, df): # type: ignore[no-untyped-def] # pragma: no cover # No index columns, use row position as index (original behavior) df = df.iloc[:, 1:] df.index = row_positions - + df.columns = column_index df = df.apply( func, axis=1, raw=raw, result_type=result_type, args=args, **kwargs @@ -523,10 +525,10 @@ def end_partition(self, df): # type: ignore[no-untyped-def] # pragma: no cover # - VALUE contains the result at this position. if isinstance(df, native_pd.DataFrame): result = [] - for idx, (row_position_index, series) in enumerate(df.iterrows()): + for idx, (_row_position_index, series) in enumerate(df.iterrows()): # Use the actual row position from row_positions, not the index value actual_row_position = row_positions.iloc[idx] - + for i, (label, value) in enumerate(series.items()): # If this is a tuple then we store each component with a 0-based # lookup. For example, (a,b,c) is stored as (0:a, 1:b, 2:c). diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py index 7f68a21a74..ff923401a1 100644 --- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py +++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py @@ -9602,7 +9602,9 @@ def _apply_with_udtf_and_dynamic_pivot_along_axis_1( # The apply function is encapsulated in a UDTF and run as a stored procedure on the pandas dataframe. # Determine if we should pass index columns to the UDTF # We pass index columns when the index is not the row position itself - index_columns_for_udtf = new_internal_df.index_column_snowflake_quoted_identifiers + index_columns_for_udtf = ( + new_internal_df.index_column_snowflake_quoted_identifiers + ) if row_position_snowflake_quoted_identifier in index_columns_for_udtf: # The row position IS the index (e.g., RangeIndex), don't pass index columns index_columns_for_udtf = [] @@ -9611,8 +9613,10 @@ def _apply_with_udtf_and_dynamic_pivot_along_axis_1( else: # Pass the actual index columns to the UDTF num_index_columns = len(index_columns_for_udtf) - index_column_pandas_labels_for_udtf = new_internal_df.index_column_pandas_labels - + index_column_pandas_labels_for_udtf = ( + new_internal_df.index_column_pandas_labels + ) + func_udtf = create_udtf_for_apply_axis_1( row_position_snowflake_quoted_identifier, func, @@ -10386,7 +10390,7 @@ def wrapped_func(*args, **kwargs): # type: ignore[no-untyped-def] # pragma: no column_index = try_convert_index_to_native( self._modin_frame.data_columns_index ) - + # get input types of index and data columns from the dataframe data_input_types = self._modin_frame.get_snowflake_type( self._modin_frame.data_column_snowflake_quoted_identifiers diff --git a/tests/integ/modin/frame/test_apply.py b/tests/integ/modin/frame/test_apply.py index 5894a6fe3a..e6a1d327c4 100644 --- a/tests/integ/modin/frame/test_apply.py +++ b/tests/integ/modin/frame/test_apply.py @@ -147,20 +147,17 @@ def foo(row) -> str: with SqlCounter(query_count=4, join_count=0, udtf_count=0): eval_snowpark_pandas_result(snow_df, df, lambda x: x.apply(foo, axis=1)) -@pytest.mark.parametrize("index", [ - None, - ['a', 'b'], - [100, 200] -]) + +@pytest.mark.parametrize("index", [None, ["a", "b"], [100, 200]]) @sql_count_checker(query_count=5, join_count=2, udtf_count=1) def test_apply_axis_1_index_preservation(index): """Test that apply(axis=1) preserves index values correctly.""" # Test with default RangeIndex native_df = native_pd.DataFrame([[1, 2], [3, 4]], index=index) snow_df = pd.DataFrame(native_df) - + eval_snowpark_pandas_result( - snow_df, native_df, lambda x: x.apply(lambda row : row.name, axis=1) + snow_df, native_df, lambda x: x.apply(lambda row: row.name, axis=1) ) @@ -168,17 +165,17 @@ def test_apply_axis_1_index_preservation(index): def test_apply_axis_1_multiindex_preservation(): """Test that apply(axis=1) preserves MultiIndex values correctly.""" # Test with MultiIndex - multi_index = pd.MultiIndex.from_tuples([('A', 1), ('B', 2), ('C', 3)], names=['letter', 'number']) + multi_index = pd.MultiIndex.from_tuples( + [("A", 1), ("B", 2), ("C", 3)], names=["letter", "number"] + ) native_df = native_pd.DataFrame([[1, 2], [3, 4], [5, 6]], index=multi_index) snow_df = pd.DataFrame(native_df) - - + eval_snowpark_pandas_result( - snow_df, native_df, lambda x: x.apply(lambda row : row.name, axis=1) + snow_df, native_df, lambda x: x.apply(lambda row: row.name, axis=1) ) - @pytest.mark.xfail(strict=True, raises=NotImplementedError) @sql_count_checker(query_count=0) def test_frame_with_timedelta_index(): From a6c70abda4f6c2f20febbc2cba6c0c0996380ab7 Mon Sep 17 00:00:00 2001 From: John Kew Date: Fri, 24 Oct 2025 13:06:49 -0700 Subject: [PATCH 4/9] Changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 21e53ee372..2648d7697e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,7 @@ #### Bug Fixes - Fixed a bug in `DataFrameGroupBy.agg` where func is a list of tuples used to set the names of the output columns. +- Preserve index values when using `df.apply(axis=1)`. #### Improvements From 8fda9275f3088e92cdba23fb1351b44d9b1a952f Mon Sep 17 00:00:00 2001 From: John Kew Date: Fri, 24 Oct 2025 13:17:12 -0700 Subject: [PATCH 5/9] Add test for an index from column --- tests/integ/modin/frame/test_apply.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/tests/integ/modin/frame/test_apply.py b/tests/integ/modin/frame/test_apply.py index e6a1d327c4..c1517cf775 100644 --- a/tests/integ/modin/frame/test_apply.py +++ b/tests/integ/modin/frame/test_apply.py @@ -152,7 +152,6 @@ def foo(row) -> str: @sql_count_checker(query_count=5, join_count=2, udtf_count=1) def test_apply_axis_1_index_preservation(index): """Test that apply(axis=1) preserves index values correctly.""" - # Test with default RangeIndex native_df = native_pd.DataFrame([[1, 2], [3, 4]], index=index) snow_df = pd.DataFrame(native_df) @@ -161,6 +160,21 @@ def test_apply_axis_1_index_preservation(index): ) +@sql_count_checker(query_count=5, join_count=2, udtf_count=1) +def test_apply_axis_1_index_from_col(): + """Test that apply(axis=1) preserves an index when set from a column""" + native_df = native_pd.DataFrame( + [[1, 2, 3], [4, 5, 6], [7, 8, 9]], columns=["a", "b", "c"] + ) + snow_df = pd.DataFrame(native_df) + snow_df = snow_df.set_index("a") + native_df = native_df.set_index("a") + + eval_snowpark_pandas_result( + snow_df, native_df, lambda x: x.apply(lambda row: row.name, axis=1) + ) + + @sql_count_checker(query_count=5, join_count=2, udtf_count=1) def test_apply_axis_1_multiindex_preservation(): """Test that apply(axis=1) preserves MultiIndex values correctly.""" From 93b282b26a635f0b128b2cc7b8bea0d1abb7e095 Mon Sep 17 00:00:00 2001 From: John Kew Date: Fri, 24 Oct 2025 15:30:27 -0700 Subject: [PATCH 6/9] Remove stale branch --- .../modin/plugin/compiler/snowflake_query_compiler.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py index ff923401a1..ad10a0ea3a 100644 --- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py +++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py @@ -9605,12 +9605,7 @@ def _apply_with_udtf_and_dynamic_pivot_along_axis_1( index_columns_for_udtf = ( new_internal_df.index_column_snowflake_quoted_identifiers ) - if row_position_snowflake_quoted_identifier in index_columns_for_udtf: - # The row position IS the index (e.g., RangeIndex), don't pass index columns - index_columns_for_udtf = [] - num_index_columns = 0 - index_column_pandas_labels_for_udtf = None - else: + if row_position_snowflake_quoted_identifier not in index_columns_for_udtf: # Pass the actual index columns to the UDTF num_index_columns = len(index_columns_for_udtf) index_column_pandas_labels_for_udtf = ( From 056e54d3df632204384032755f52618e1181559f Mon Sep 17 00:00:00 2001 From: John Kew Date: Mon, 27 Oct 2025 14:05:48 -0700 Subject: [PATCH 7/9] Clean up some AI stuff --- .../snowpark/modin/plugin/_internal/apply_utils.py | 2 +- .../modin/plugin/compiler/snowflake_query_compiler.py | 8 +------- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py index 3495702350..ec9424db91 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py @@ -425,7 +425,6 @@ def create_udtf_for_apply_axis_1( input_types: list[DataType], session: Session, index_column_pandas_labels: list[Hashable] | None = None, - num_index_columns: int = 0, **kwargs: Any, ) -> UserDefinedTableFunction: """ @@ -466,6 +465,7 @@ def end_partition(self, df): # type: ignore[no-untyped-def] # pragma: no cover row_positions = df.iloc[:, 0] # If we have index columns, set them as the index + num_index_columns = len(index_column_pandas_labels) if num_index_columns > 0: # Columns after row position are index columns, then data columns index_cols = df.iloc[:, 1 : 1 + num_index_columns] diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py index 05abe51d8b..07fe046383 100644 --- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py +++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py @@ -9623,12 +9623,7 @@ def _apply_with_udtf_and_dynamic_pivot_along_axis_1( index_columns_for_udtf = ( new_internal_df.index_column_snowflake_quoted_identifiers ) - if row_position_snowflake_quoted_identifier not in index_columns_for_udtf: - # Pass the actual index columns to the UDTF - num_index_columns = len(index_columns_for_udtf) - index_column_pandas_labels_for_udtf = ( - new_internal_df.index_column_pandas_labels - ) + index_column_pandas_labels_for_udtf = new_internal_df.index_column_pandas_labels func_udtf = create_udtf_for_apply_axis_1( row_position_snowflake_quoted_identifier, @@ -9640,7 +9635,6 @@ def _apply_with_udtf_and_dynamic_pivot_along_axis_1( input_types, self._modin_frame.ordered_dataframe.session, index_column_pandas_labels=index_column_pandas_labels_for_udtf, - num_index_columns=num_index_columns, **kwargs, ) From e5b352cc4e2eaf1d9d1595b9412a5e8ab5bcb49d Mon Sep 17 00:00:00 2001 From: John Kew Date: Mon, 27 Oct 2025 14:23:49 -0700 Subject: [PATCH 8/9] Use set_index --- src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py index ec9424db91..7588007265 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py @@ -484,12 +484,12 @@ def end_partition(self, df): # type: ignore[no-untyped-def] # pragma: no cover if index_column_pandas_labels else None, ) - data_cols.index = index + data_cols.set_index(index, inplace=True) df = data_cols else: # No index columns, use row position as index (original behavior) df = df.iloc[:, 1:] - df.index = row_positions + df.set_index(row_positions, inplace=True) df.columns = column_index df = df.apply( From 73961bdb11323ce08b6cfe4adda125d8f3fe5aa5 Mon Sep 17 00:00:00 2001 From: John Kew Date: Tue, 28 Oct 2025 09:12:29 -0700 Subject: [PATCH 9/9] More cleanup --- .../modin/plugin/_internal/apply_utils.py | 17 ++++++++--------- .../plugin/compiler/snowflake_query_compiler.py | 2 +- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py index 7588007265..8e04b7c313 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py @@ -424,7 +424,7 @@ def create_udtf_for_apply_axis_1( column_index: native_pd.Index, input_types: list[DataType], session: Session, - index_column_pandas_labels: list[Hashable] | None = None, + index_column_labels: list[Hashable] | None = None, **kwargs: Any, ) -> UserDefinedTableFunction: """ @@ -445,8 +445,7 @@ def create_udtf_for_apply_axis_1( args: pandas parameter controlling apply within the UDTF. column_index: The columns of the callee DataFrame, i.e. df.columns as pd.Index object. input_types: Snowpark column types of the input data columns (including index columns). - index_column_pandas_labels: The pandas labels for the index columns, if any. - num_index_columns: Number of index columns being passed into the UDTF. + index_column_labels: index column labels, assuming this is not a RangeIndex **kwargs: pandas parameter controlling apply within the UDTF. Returns: @@ -465,7 +464,9 @@ def end_partition(self, df): # type: ignore[no-untyped-def] # pragma: no cover row_positions = df.iloc[:, 0] # If we have index columns, set them as the index - num_index_columns = len(index_column_pandas_labels) + num_index_columns = ( + 0 if index_column_labels is None else len(index_column_labels) + ) if num_index_columns > 0: # Columns after row position are index columns, then data columns index_cols = df.iloc[:, 1 : 1 + num_index_columns] @@ -474,15 +475,13 @@ def end_partition(self, df): # type: ignore[no-untyped-def] # pragma: no cover # Set the index using the index columns if num_index_columns == 1: index = index_cols.iloc[:, 0] - if index_column_pandas_labels: - index.name = index_column_pandas_labels[0] + if index_column_labels: + index.name = index_column_labels[0] else: # Multi-index case index = native_pd.MultiIndex.from_arrays( [index_cols.iloc[:, i] for i in range(num_index_columns)], - names=index_column_pandas_labels - if index_column_pandas_labels - else None, + names=index_column_labels if index_column_labels else None, ) data_cols.set_index(index, inplace=True) df = data_cols diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py index d007c68944..8d8fddcc4b 100644 --- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py +++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py @@ -9644,7 +9644,7 @@ def _apply_with_udtf_and_dynamic_pivot_along_axis_1( column_index, input_types, self._modin_frame.ordered_dataframe.session, - index_column_pandas_labels=index_column_pandas_labels_for_udtf, + index_column_labels=index_column_pandas_labels_for_udtf, **kwargs, )