diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6024959c69..6241f7838f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -41,6 +41,7 @@
 
 #### Bug Fixes
 - Fixed a bug in `DataFrameGroupBy.agg` where func is a list of tuples used to set the names of the output columns.
+- Preserve index values when using `df.apply(axis=1)`.
 - Fixed a bug where converting a modin datetime index with a timezone to a numpy array with `np.asarray` would cause a `TypeError`.
 
 #### Improvements
diff --git a/src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py
index ef8b4aa33e..8e04b7c313 100644
--- a/src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py
+++ b/src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py
@@ -424,6 +424,7 @@ def create_udtf_for_apply_axis_1(
     column_index: native_pd.Index,
     input_types: list[DataType],
     session: Session,
+    index_column_labels: list[Hashable] | None = None,
     **kwargs: Any,
 ) -> UserDefinedTableFunction:
     """
@@ -443,7 +444,8 @@ def create_udtf_for_apply_axis_1(
         result_type: pandas parameter controlling apply within the UDTF.
         args: pandas parameter controlling apply within the UDTF.
         column_index: The columns of the callee DataFrame, i.e. df.columns as pd.Index object.
-        input_types: Snowpark column types of the input data columns.
+        input_types: Snowpark column types of the input data columns (including index columns).
+        index_column_labels: index column labels, assuming this is not a RangeIndex
         **kwargs: pandas parameter controlling apply within the UDTF.
 
     Returns:
@@ -458,8 +460,35 @@ def create_udtf_for_apply_axis_1(
 
     class ApplyFunc:
         def end_partition(self, df):  # type: ignore[no-untyped-def] # pragma: no cover
-            # First column is row position, set as index.
-            df = df.set_index(df.columns[0])
+            # First column is row position, extract it for later use
+            row_positions = df.iloc[:, 0]
+
+            # If we have index columns, set them as the index
+            num_index_columns = (
+                0 if index_column_labels is None else len(index_column_labels)
+            )
+            if num_index_columns > 0:
+                # Columns after row position are index columns, then data columns
+                index_cols = df.iloc[:, 1 : 1 + num_index_columns]
+                data_cols = df.iloc[:, 1 + num_index_columns :]
+
+                # Set the index using the index columns
+                if num_index_columns == 1:
+                    index = index_cols.iloc[:, 0]
+                    if index_column_labels:
+                        index.name = index_column_labels[0]
+                else:
+                    # Multi-index case
+                    index = native_pd.MultiIndex.from_arrays(
+                        [index_cols.iloc[:, i] for i in range(num_index_columns)],
+                        names=index_column_labels if index_column_labels else None,
+                    )
+                data_cols.set_index(index, inplace=True)
+                df = data_cols
+            else:
+                # No index columns, use row position as index (original behavior)
+                df = df.iloc[:, 1:]
+                df.set_index(row_positions, inplace=True)
 
             df.columns = column_index
             df = df.apply(
@@ -495,7 +524,9 @@ def end_partition(self, df):  # type: ignore[no-untyped-def] # pragma: no cover
             # - VALUE contains the result at this position.
             if isinstance(df, native_pd.DataFrame):
                 result = []
-                for row_position_index, series in df.iterrows():
+                for idx, (_row_position_index, series) in enumerate(df.iterrows()):
+                    # Use the actual row position from row_positions, not the index value
+                    actual_row_position = row_positions.iloc[idx]
                     for i, (label, value) in enumerate(series.items()):
                         # If this is a tuple then we store each component with a 0-based
@@ -508,7 +539,7 @@ def end_partition(self, df):  # type: ignore[no-untyped-def] # pragma: no cover
                         obj_label["pos"] = i
                         result.append(
                             [
-                                row_position_index,
+                                actual_row_position,
                                 json.dumps(obj_label),
                                 value,
                             ]
@@ -531,7 +562,9 @@ def end_partition(self, df):  # type: ignore[no-untyped-def] # pragma: no cover
 
             elif isinstance(df, native_pd.Series):
                 result = df.to_frame(name="value")
                 result.insert(0, "label", json.dumps({"0": MODIN_UNNAMED_SERIES_LABEL}))
-                result.reset_index(names="__row__", inplace=True)
+                # Use the row_positions (integer positions) rather than the index values
+                result.insert(0, "__row__", row_positions.values)
+                result = result[["__row__", "label", "value"]]
             else:
                 raise TypeError(f"Unsupported data type {df} from df.apply")
diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py
index ea22aacc5d..7c18ed8ed4 100644
--- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py
+++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py
@@ -9809,6 +9809,13 @@ def _apply_with_udtf_and_dynamic_pivot_along_axis_1(
         )
 
         # The apply function is encapsulated in a UDTF and run as a stored procedure on the pandas dataframe.
+        # Determine if we should pass index columns to the UDTF
+        # We pass index columns when the index is not the row position itself
+        index_columns_for_udtf = (
+            new_internal_df.index_column_snowflake_quoted_identifiers
+        )
+        index_column_pandas_labels_for_udtf = new_internal_df.index_column_pandas_labels
+
         func_udtf = create_udtf_for_apply_axis_1(
             row_position_snowflake_quoted_identifier,
             func,
@@ -9818,6 +9825,7 @@ def _apply_with_udtf_and_dynamic_pivot_along_axis_1(
             column_index,
             input_types,
             self._modin_frame.ordered_dataframe.session,
+            index_column_labels=index_column_pandas_labels_for_udtf,
             **kwargs,
         )
 
@@ -9857,13 +9865,16 @@ def _apply_with_udtf_and_dynamic_pivot_along_axis_1(
                 _emit_ast=self._modin_frame.ordered_dataframe.session.ast_enabled,
             )
         ).as_(partition_identifier)
+        # Select columns to pass to UDTF: partition, row_position, index columns (if any), data columns
         udtf_dataframe = new_internal_df.ordered_dataframe.select(
             partition_expression,
             row_position_snowflake_quoted_identifier,
+            *index_columns_for_udtf,
             *new_internal_df.data_column_snowflake_quoted_identifiers,
         ).select(
             func_udtf(
                 row_position_snowflake_quoted_identifier,
+                *index_columns_for_udtf,
                 *new_internal_df.data_column_snowflake_quoted_identifiers,
             ).over(partition_by=[partition_identifier]),
         )
@@ -10590,11 +10601,6 @@ def wrapped_func(*args, **kwargs):  # type: ignore[no-untyped-def] # pragma: no cover
                 return qc_result
         else:
-            # get input types of all data columns from the dataframe directly
-            input_types = self._modin_frame.get_snowflake_type(
-                self._modin_frame.data_column_snowflake_quoted_identifiers
-            )
-
             from snowflake.snowpark.modin.plugin.extensions.utils import (
                 try_convert_index_to_native,
             )
 
@@ -10604,6 +10610,16 @@ def wrapped_func(*args, **kwargs):  # type: ignore[no-untyped-def] # pragma: no cover
                 self._modin_frame.data_columns_index
             )
 
+            # get input types of index and data columns from the dataframe
+            data_input_types = self._modin_frame.get_snowflake_type(
+                self._modin_frame.data_column_snowflake_quoted_identifiers
+            )
+            index_input_types = self._modin_frame.get_snowflake_type(
+                self._modin_frame.index_column_snowflake_quoted_identifiers
+            )
+            # Combine index types + data types for UDTF input
+            input_types = index_input_types + data_input_types
+
             # Extract return type from annotations (or lookup for known pandas functions) for func object,
             # if no return type could be extracted the variable will hold None.
             return_type = deduce_return_type_from_function(func, None)
@@ -10618,7 +10634,7 @@ def wrapped_func(*args, **kwargs):  # type: ignore[no-untyped-def] # pragma: no cover
                 return self._apply_udf_row_wise_and_reduce_to_series_along_axis_1(
                     func,
                     column_index,
-                    input_types,
+                    data_input_types,
                     return_type,
                     udf_args=args,
                     udf_kwargs=kwargs,
diff --git a/tests/integ/modin/frame/test_apply.py b/tests/integ/modin/frame/test_apply.py
index 73afef8f60..0420b20d24 100644
--- a/tests/integ/modin/frame/test_apply.py
+++ b/tests/integ/modin/frame/test_apply.py
@@ -148,6 +148,48 @@ def foo(row) -> str:
     eval_snowpark_pandas_result(snow_df, df, lambda x: x.apply(foo, axis=1))
 
 
+@pytest.mark.parametrize("index", [None, ["a", "b"], [100, 200]])
+@sql_count_checker(query_count=5, join_count=2, udtf_count=1)
+def test_apply_axis_1_index_preservation(index):
+    """Test that apply(axis=1) preserves index values correctly."""
+    native_df = native_pd.DataFrame([[1, 2], [3, 4]], index=index)
+    snow_df = pd.DataFrame(native_df)
+
+    eval_snowpark_pandas_result(
+        snow_df, native_df, lambda x: x.apply(lambda row: row.name, axis=1)
+    )
+
+
+@sql_count_checker(query_count=5, join_count=2, udtf_count=1)
+def test_apply_axis_1_index_from_col():
+    """Test that apply(axis=1) preserves an index when set from a column"""
+    native_df = native_pd.DataFrame(
+        [[1, 2, 3], [4, 5, 6], [7, 8, 9]], columns=["a", "b", "c"]
+    )
+    snow_df = pd.DataFrame(native_df)
+    snow_df = snow_df.set_index("a")
+    native_df = native_df.set_index("a")
+
+    eval_snowpark_pandas_result(
+        snow_df, native_df, lambda x: x.apply(lambda row: row.name, axis=1)
+    )
+
+
+@sql_count_checker(query_count=5, join_count=2, udtf_count=1)
+def test_apply_axis_1_multiindex_preservation():
+    """Test that apply(axis=1) preserves MultiIndex values correctly."""
+    # Test with MultiIndex
+    multi_index = pd.MultiIndex.from_tuples(
+        [("A", 1), ("B", 2), ("C", 3)], names=["letter", "number"]
+    )
+    native_df = native_pd.DataFrame([[1, 2], [3, 4], [5, 6]], index=multi_index)
+    snow_df = pd.DataFrame(native_df)
+
+    eval_snowpark_pandas_result(
+        snow_df, native_df, lambda x: x.apply(lambda row: row.name, axis=1)
+    )
+
+
 @pytest.mark.xfail(strict=True, raises=NotImplementedError)
 @sql_count_checker(query_count=0)
 def test_frame_with_timedelta_index():