diff --git a/CHANGELOG.md b/CHANGELOG.md index 580d582337..3f154646a6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,8 @@ #### New Features +- Added support for the `INCLUDE_METADATA` copy option in `DataFrame.copy_into_table`, allowing users to include file metadata columns in the target table. + #### Bug Fixes - Fixed a bug in `Session.client_telemetry` that trace does not have snowflake style trace id. diff --git a/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py b/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py index 4466d5ec5d..5353036b01 100644 --- a/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py +++ b/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py @@ -1357,11 +1357,13 @@ def file_operation_statement( def convert_value_to_sql_option( - value: Optional[Union[str, bool, int, float, list, tuple]], + value: Optional[Union[str, bool, int, float, list, tuple, dict]], parse_none_as_string: bool = False, ) -> str: if value is None and parse_none_as_string: value = str(value) + if isinstance(value, dict): + return f"({', '.join(f'{k} = {v}' for k, v in value.items())})" if isinstance(value, str): if len(value) > 1 and is_single_quoted(value): return value diff --git a/src/snowflake/snowpark/_internal/utils.py b/src/snowflake/snowpark/_internal/utils.py index 12b2b13794..71b35881f5 100644 --- a/src/snowflake/snowpark/_internal/utils.py +++ b/src/snowflake/snowpark/_internal/utils.py @@ -180,6 +180,7 @@ "TRUNCATECOLUMNS", "FORCE", "LOAD_UNCERTAIN_FILES", + "INCLUDE_METADATA", } COPY_INTO_LOCATION_COPY_OPTIONS = { diff --git a/src/snowflake/snowpark/dataframe.py b/src/snowflake/snowpark/dataframe.py index bc9ec55426..20fb9beeb3 100644 --- a/src/snowflake/snowpark/dataframe.py +++ b/src/snowflake/snowpark/dataframe.py @@ -179,7 +179,12 @@ track_data_source_statement_params, ) from snowflake.snowpark.async_job import AsyncJob, _AsyncResultType -from snowflake.snowpark.column import Column, _to_col_if_sql_expr, _to_col_if_str +from snowflake.snowpark.column import ( + METADATA_COLUMN_TYPES, + Column, + _to_col_if_sql_expr, + _to_col_if_str, +) from snowflake.snowpark.dataframe_ai_functions import DataFrameAIFunctions from snowflake.snowpark.dataframe_analytics_functions import DataFrameAnalyticsFunctions from snowflake.snowpark.dataframe_na_functions import DataFrameNaFunctions @@ -4865,6 +4870,27 @@ def copy_into_table( else None ) copy_options = copy_options or reader_copy_options + + if copy_options.get("INCLUDE_METADATA", None) is not None: + for metadata_col in copy_options["INCLUDE_METADATA"].values(): + if quote_name(metadata_col.upper()) not in METADATA_COLUMN_TYPES: + raise ValueError( + f"Metadata column {metadata_col} is not supported. Supported columns: {list(METADATA_COLUMN_TYPES.keys())}" + ) + if "MATCH_BY_COLUMN_NAME" not in copy_options: + raise ValueError( + "INCLUDE_METADATA can only be used with the MATCH_BY_COLUMN_NAME copy option." + ) + if self._reader._file_type and self._reader._file_type.upper() == "CSV": + format_type_options = ( + format_type_options.copy() if format_type_options else {} + ) + if format_type_options.get("ERROR_ON_COLUMN_COUNT_MISMATCH", False): + raise ValueError( + "ERROR_ON_COLUMN_COUNT_MISMATCH must be False when INCLUDE_METADATA is used with CSV files." + ) + format_type_options["ERROR_ON_COLUMN_COUNT_MISMATCH"] = False + validation_mode = validation_mode or self._reader._cur_options.get( "VALIDATION_MODE" ) diff --git a/tests/ast/data/DataFrame.create_or_replace.test b/tests/ast/data/DataFrame.create_or_replace.test index 2021527290..5c27573575 100644 --- a/tests/ast/data/DataFrame.create_or_replace.test +++ b/tests/ast/data/DataFrame.create_or_replace.test @@ -27,6 +27,7 @@ df.copy_into_table( statement_params={"foo": "bar"}, iceberg_config={"external_volume": "example_volume", "partition_by": [bucket(10, "n"), truncate(2, col("str"))], "target_file_size": "128MB", "catalog": "my_catalog"}, force=True, + INCLUDE_METADATA={"filename_col": "METADATA$FILENAME", "row_num_col": "METADATA$FILE_ROW_NUMBER"}, ) df3 = df.cache_result() @@ -47,7 +48,7 @@ res3 = df.create_or_replace_temp_view(["test_db", "test_schema", "test_view"], c res4 = df.create_or_replace_temp_view("test_view", statement_params={"foo": "bar"}, copy_grants=True) -df.copy_into_table(["test_db", "test_schema", "table2"], files=["file1", "file2"], pattern="[A-Z]+", validation_mode="RETURN_ERRORS", target_columns=["n", "str"], transformations=[col("n") * 10, col("str")], format_type_options={"COMPRESSION": "GZIP", "RECORD_DELIMITER": "|"}, statement_params={"foo": "bar"}, force=True, iceberg_config={"external_volume": "example_volume", "partition_by": [bucket(10, "n"), truncate(2, col("str"))], "target_file_size": "128MB", "catalog": "my_catalog"}) +df.copy_into_table(["test_db", "test_schema", "table2"], files=["file1", "file2"], pattern="[A-Z]+", validation_mode="RETURN_ERRORS", target_columns=["n", "str"], transformations=[col("n") * 10, col("str")], format_type_options={"COMPRESSION": "GZIP", "RECORD_DELIMITER": "|"}, statement_params={"foo": "bar"}, force=True, INCLUDE_METADATA={"filename_col": "METADATA$FILENAME", "row_num_col": "METADATA$FILE_ROW_NUMBER"}, iceberg_config={"external_volume": "example_volume", "partition_by": [bucket(10, "n"), truncate(2, col("str"))], "target_file_size": "128MB", "catalog": "my_catalog"}) df = session.table("table1") @@ -254,7 +255,7 @@ body { bool_val { src { end_column: 9 - end_line: 52 + end_line: 53 file: 2 start_column: 8 start_line: 41 @@ -263,6 +264,72 @@ body { } } } + copy_options { + _1: "INCLUDE_METADATA" + _2 { + seq_map_val { + kvs { + vs { + string_val { + src { + end_column: 9 + end_line: 53 + file: 2 + start_column: 8 + start_line: 41 + } + v: "filename_col" + } + } + vs { + string_val { + src { + end_column: 9 + end_line: 53 + file: 2 + start_column: 8 + start_line: 41 + } + v: "METADATA$FILENAME" + } + } + } + kvs { + vs { + string_val { + src { + end_column: 9 + end_line: 53 + file: 2 + start_column: 8 + start_line: 41 + } + v: "row_num_col" + } + } + vs { + string_val { + src { + end_column: 9 + end_line: 53 + file: 2 + start_column: 8 + start_line: 41 + } + v: "METADATA$FILE_ROW_NUMBER" + } + } + } + src { + end_column: 9 + end_line: 53 + file: 2 + start_column: 8 + start_line: 41 + } + } + } + } df { dataframe_ref { id: 1 @@ -276,7 +343,7 @@ body { string_val { src { end_column: 9 - end_line: 52 + end_line: 53 file: 2 start_column: 8 start_line: 41 @@ -291,7 +358,7 @@ body { string_val { src { end_column: 9 - end_line: 52 + end_line: 53 file: 2 start_column: 8 start_line: 41 @@ -306,7 +373,7 @@ body { string_val { src { end_column: 9 - end_line: 52 + end_line: 53 file: 2 start_column: 8 start_line: 41 @@ -321,7 +388,7 @@ body { list_val { src { end_column: 9 - end_line: 52 + end_line: 53 file: 2 start_column: 8 start_line: 41 @@ -449,7 +516,7 @@ body { string_val { src { end_column: 9 - end_line: 52 + end_line: 53 file: 2 start_column: 8 start_line: 41 @@ -464,7 +531,7 @@ body { string_val { src { end_column: 9 - end_line: 52 + end_line: 53 file: 2 start_column: 8 start_line: 41 @@ -478,7 +545,7 @@ body { } src { end_column: 9 - end_line: 52 + end_line: 53 file: 2 start_column: 8 start_line: 41 @@ -653,10 +720,10 @@ body { } src { end_column: 31 - end_line: 54 + end_line: 55 file: 2 start_column: 14 - start_line: 54 + start_line: 55 } } } @@ -685,10 +752,10 @@ body { } src { end_column: 62 - end_line: 56 + end_line: 57 file: 2 start_column: 14 - start_line: 56 + start_line: 57 } statement_params { _1: "foo" @@ -729,10 +796,10 @@ body { } src { end_column: 128 - end_line: 58 + end_line: 59 file: 2 start_column: 8 - start_line: 58 + start_line: 59 } warehouse: "test_wh" } diff --git a/tests/integ/scala/test_dataframe_copy_into.py b/tests/integ/scala/test_dataframe_copy_into.py index 29b11ad2a2..4d896f8604 100644 --- a/tests/integ/scala/test_dataframe_copy_into.py +++ b/tests/integ/scala/test_dataframe_copy_into.py @@ -528,6 +528,16 @@ def test_copy_json_write_with_column_names(session, tmp_stage_name1): Utils.drop_table(session, table_name) +special_format_schema = StructType( + [ + StructField("ID", IntegerType()), + StructField("USERNAME", StringType()), + StructField("FIRSTNAME", StringType()), + StructField("LASTNAME", StringType()), + ] +) + + def test_csv_read_format_name(session, tmp_stage_name1): temp_file_fmt_name = Utils.random_name_for_temp_object(TempObjectType.FILE_FORMAT) session.sql( @@ -535,16 +545,7 @@ def test_csv_read_format_name(session, tmp_stage_name1): "null_if = ('none','NA');" ).collect() df = ( - session.read.schema( - StructType( - [ - StructField("ID", IntegerType()), - StructField("USERNAME", StringType()), - StructField("FIRSTNAME", StringType()), - StructField("LASTNAME", StringType()), - ] - ) - ) + session.read.schema(special_format_schema) .option("format_name", temp_file_fmt_name) .csv( f"@{tmp_stage_name1}/{test_file_csv_special_format}", @@ -1442,3 +1443,87 @@ def create_and_append_check_answer(table_name_input): # drop schema Utils.drop_schema(session, schema) Utils.drop_schema(session, double_quoted_schema) + + +def test_copy_into_table_include_metadata_requires_match_by_column_name( + session, tmp_stage_name1 +): + test_file_on_stage = f"@{tmp_stage_name1}/{test_file_csv}" + table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE) + df = session.read.schema(user_schema).csv(test_file_on_stage) + with pytest.raises( + ValueError, + match="INCLUDE_METADATA can only be used with the MATCH_BY_COLUMN_NAME copy option.", + ): + df.copy_into_table( + table_name, + INCLUDE_METADATA={"filename_col": "METADATA$FILENAME"}, + ) + + +def test_copy_into_table_include_metadata_with_error_on_column_count_mismatch( + session, tmp_stage_name1 +): + test_file_on_stage = f"@{tmp_stage_name1}/{test_file_csv}" + table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE) + df = session.read.schema(user_schema).csv(test_file_on_stage) + with pytest.raises( + ValueError, + match="ERROR_ON_COLUMN_COUNT_MISMATCH must be False when INCLUDE_METADATA is used with CSV files", + ): + df.copy_into_table( + table_name, + format_type_options={"ERROR_ON_COLUMN_COUNT_MISMATCH": True}, + INCLUDE_METADATA={"filename_col": "METADATA$FILENAME"}, + MATCH_BY_COLUMN_NAME="CASE_INSENSITIVE", + ) + + +def test_copy_into_table_include_metadata_requires_supported_metadata_column( + session, tmp_stage_name1 +): + test_file_on_stage = f"@{tmp_stage_name1}/{test_file_csv}" + table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE) + df = session.read.schema(user_schema).csv(test_file_on_stage) + with pytest.raises( + ValueError, + match="Metadata column NON_EXISTING_COLUMN is not supported", + ): + df.copy_into_table( + table_name, + INCLUDE_METADATA={"filename_col": "NON_EXISTING_COLUMN"}, + ) + + +def test_copy_into_table_include_metadata_csv(session, tmp_stage_name1): + test_file_on_stage = f"@{tmp_stage_name1}/{test_file_csv_special_format}" + table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE) + Utils.create_table( + session, + table_name, + "id String, username String, firstname String, lastname String, filename_col String, row_num_col Int", + ) + try: + df = ( + session.read.schema(special_format_schema) + .option("PARSE_HEADER", True) + .csv(test_file_on_stage) + ) + df.copy_into_table( + table_name, + MATCH_BY_COLUMN_NAME="CASE_INSENSITIVE", + INCLUDE_METADATA={ + "filename_col": '"metadata$filename"', + "ROW_NUM_COL": "METADATA$FILE_ROW_NUMBER", + }, + force=True, + ) + result = session.table(table_name).sort("ID").collect() + assert len(result) > 0 + for i, row in enumerate(result): + assert all(v is not None for v in row) + assert len(row) == 6 + assert row["FILENAME_COL"] == test_file_csv_special_format + assert row["ROW_NUM_COL"] == i + 1 + finally: + Utils.drop_table(session, table_name) diff --git a/tests/integ/scala/test_udf_suite.py b/tests/integ/scala/test_udf_suite.py index 4b56ca2b3f..6c96d52469 100644 --- a/tests/integ/scala/test_udf_suite.py +++ b/tests/integ/scala/test_udf_suite.py @@ -567,7 +567,7 @@ def test_geometry_type(session): table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE) Utils.create_table(session, table_name, "g geometry", is_temporary=True) session._run_query( - f"insert into {table_name} values ('POINT(30 10)'), ('POINT(50 60)'), (null)" + f"insert into {table_name} values ('POINT(30.5 10.5)'), ('POINT(50.5 60.5)'), (null)" ) df = session.table(table_name) @@ -576,18 +576,18 @@ def geometry(g): return None else: g_str = str(g) - if "[50, 60]" in g_str and "Point" in g_str: + if "[50.5, 60.5]" in g_str and "Point" in g_str: return g_str else: - return g_str.replace("0", "") + return g_str.replace(".5", ".2") geometry_udf = udf(geometry, return_type=StringType(), input_types=[GeometryType()]) Utils.check_answer( df.select(geometry_udf(col("g"))), [ - Row("{'coordinates': [3, 1], 'type': 'Point'}"), - Row("{'coordinates': [50, 60], 'type': 'Point'}"), + Row("{'coordinates': [30.2, 10.2], 'type': 'Point'}"), + Row("{'coordinates': [50.5, 60.5], 'type': 'Point'}"), Row(None), ], ) diff --git a/tests/unit/test_analyzer_util_suite.py b/tests/unit/test_analyzer_util_suite.py index 3cd5c3e60b..e1ad1e28dd 100644 --- a/tests/unit/test_analyzer_util_suite.py +++ b/tests/unit/test_analyzer_util_suite.py @@ -417,6 +417,19 @@ def test_convert_value_to_sql_option(): assert convert_value_to_sql_option(None, parse_none_as_string=True) == "'None'" assert convert_value_to_sql_option((1,)) == "(1)" assert convert_value_to_sql_option((1, 2)) == "(1, 2)" + assert ( + convert_value_to_sql_option({"col1": "METADATA$FILENAME"}) + == "(col1 = METADATA$FILENAME)" + ) + assert ( + convert_value_to_sql_option( + { + "col1": "METADATA$FILENAME", + "col2": "METADATA$FILE_ROW_NUMBER", + } + ) + == "(col1 = METADATA$FILENAME, col2 = METADATA$FILE_ROW_NUMBER)" + ) def test_file_operation_negative(): @@ -835,3 +848,15 @@ def test_get_options_statement(): ) # already single-quoted string should pass through unchanged assert get_options_statement({"P": "'abc'"}) == " P = 'abc' " + # dict handling (e.g. INCLUDE_METADATA) + assert ( + get_options_statement( + { + "INCLUDE_METADATA": { + "col1": "METADATA$FILENAME", + "col2": "METADATA$FILE_ROW_NUMBER", + } + } + ) + == " INCLUDE_METADATA = (col1 = METADATA$FILENAME, col2 = METADATA$FILE_ROW_NUMBER) " + )