Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1882,6 +1882,7 @@ def _create_xml_query(
lit(ignore_surrounding_whitespace),
lit(row_validation_xsd_path),
lit(schema_string),
lit(context._is_snowpark_connect_compatible_mode),
),
)

Expand Down
4 changes: 4 additions & 0 deletions src/snowflake/snowpark/_internal/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@
# The following are not copy into SQL command options but client side options.
"INFER_SCHEMA",
"INFER_SCHEMA_OPTIONS",
"SAMPLING_RATIO",
"FORMAT_TYPE_OPTIONS",
"TARGET_COLUMNS",
"TRANSFORMATIONS",
Expand All @@ -210,6 +211,9 @@
XML_ROW_TAG_STRING = "ROWTAG"
XML_ROW_DATA_COLUMN_NAME = "ROW_DATA"
XML_READER_FILE_PATH = os.path.join(os.path.dirname(__file__), "xml_reader.py")
XML_SCHEMA_INFERENCE_FILE_PATH = os.path.join(
os.path.dirname(__file__), "xml_schema_inference.py"
)
XML_READER_API_SIGNATURE = "DataFrameReader.xml[rowTag]"
XML_READER_SQL_COMMENT = f"/* Python:snowflake.snowpark.{XML_READER_API_SIGNATURE} */"

Expand Down
149 changes: 139 additions & 10 deletions src/snowflake/snowpark/_internal/xml_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
#

import datetime
import os
import re
import html.entities
Expand All @@ -12,7 +13,18 @@
from snowflake.snowpark._internal.analyzer.analyzer_utils import unquote_if_quoted
from snowflake.snowpark._internal.type_utils import type_string_to_type_object
from snowflake.snowpark.files import SnowflakeFile
from snowflake.snowpark.types import StructType, ArrayType, DataType, MapType
from snowflake.snowpark.types import (
ArrayType,
BooleanType,
DataType,
DateType,
DoubleType,
LongType,
MapType,
StringType,
StructType,
TimestampType,
)

# lxml is only a dev dependency so use try/except to import it if available
try:
Expand Down Expand Up @@ -102,13 +114,97 @@ def _restore_colons_in_template(template: Optional[dict]) -> Optional[dict]:
return restored


def schema_string_to_result_dict_and_struct_type(
    schema_string: str,
) -> "Tuple[Optional[dict], Optional[StructType]]":
    """Parse a user-supplied XML schema string into a result template and its type.

    Args:
        schema_string: Schema definition text (e.g. ``"struct<a: string>"``).
            An empty string means the user supplied no custom schema.

    Returns:
        A 2-tuple ``(result_template, schema)`` where ``result_template`` is a
        nested dict skeleton matching the schema (used to shape parsed rows)
        and ``schema`` is the parsed ``StructType``. Both are ``None`` when
        *schema_string* is empty.
    """
    if schema_string == "":
        return None, None
    # Colons inside quoted identifiers would confuse the type-string parser,
    # so escape them before parsing and restore them in the final template.
    safe_string = _escape_colons_in_quotes(schema_string)
    schema = type_string_to_type_object(safe_string)
    template = struct_type_to_result_template(schema)
    return _restore_colons_in_template(template), schema


def _can_cast_to_type(value: str, target_type: DataType) -> bool:
    """Return True if *value* (raw XML text) can be cast to *target_type*.

    Covers the primitive types that Spark's XML schema inference produces:
    string, long, double, boolean, date, and timestamp. DecimalType is
    intentionally not handled here because Spark only infers decimals when
    ``prefersDecimal`` is enabled, which defaults to false.

    For any type not explicitly checked (e.g. TimeType), this returns True so
    the value falls through to the original downstream cast behavior, which
    raises on corrupted data.

    Args:
        value: The string value extracted from an XML element.
        target_type: The expected Snowpark data type from the user schema.

    Returns:
        True if the value is castable (or the type is not validated here),
        False otherwise.
    """
    if isinstance(target_type, StringType):
        # Any text is trivially a valid string.
        return True
    if isinstance(target_type, LongType):
        try:
            int(value)
            return True
        except (ValueError, OverflowError):
            return False
    if isinstance(target_type, DoubleType):
        try:
            float(value)
            return True
        except ValueError:
            return False
    if isinstance(target_type, BooleanType):
        # Accept the same literal spellings Spark treats as booleans.
        return value.lower() in ("true", "false", "1", "0")
    if isinstance(target_type, DateType):
        try:
            datetime.date.fromisoformat(value)
            return True
        except (ValueError, TypeError):
            return False
    if isinstance(target_type, TimestampType):
        try:
            datetime.datetime.fromisoformat(value)
            return True
        except (ValueError, TypeError):
            return False
    # Unchecked types: defer to the original error-raising cast downstream.
    return True


def _validate_row_for_type_mismatch(
    row: dict,
    schema: StructType,
    mode: str,
    record_str: str = "",
    column_name_of_corrupt_record: str = "_corrupt_record",
) -> Optional[dict]:
    """Validate a parsed row dict against the expected schema types for mode handling:
    - PERMISSIVE: set mismatched fields to ``None`` and store the raw XML
      record in *column_name_of_corrupt_record* (Spark compatible).
    - FAILFAST: raise immediately on the first mismatch.
    - DROPMALFORMED: return ``None`` so the caller skips the row.

    Only top-level primitive fields are validated because complex types
    are kept as VARIANT and never cast downstream. Cost is O(number of
    schema fields), which is cheap relative to the recursive element parse.

    Args:
        row: The parsed row as a dict of field name -> value.
        schema: The expected StructType for the row.
        mode: One of ``PERMISSIVE``, ``FAILFAST``, ``DROPMALFORMED``.
        record_str: The raw XML record, stored for corrupt rows.
        column_name_of_corrupt_record: Column receiving the raw record in
            PERMISSIVE mode.

    Returns:
        The (possibly modified) row dict, or ``None`` when the row should be
        dropped in DROPMALFORMED mode.

    Raises:
        RuntimeError: In FAILFAST mode, on the first uncastable field value.
    """
    had_error = False
    for field in schema.fields:
        field_name = unquote_if_quoted(field.name)
        if field_name not in row:
            continue

        value = row[field_name]
        if value is None:
            # Missing/null values are not type mismatches.
            continue

        # Skip complex types as these are kept as VARIANT
        if isinstance(field.datatype, (StructType, ArrayType, MapType)):
            continue

        castable = isinstance(value, str) and _can_cast_to_type(value, field.datatype)
        if not castable:
            if mode == "FAILFAST":
                raise RuntimeError(
                    f"Failed to cast value '{value}' to "
                    f"{field.datatype.simple_string()} for field "
                    f"'{field_name}'.\nXML record: {record_str}"
                )
            if mode == "DROPMALFORMED":
                return None
            # PERMISSIVE: null the bad field, continue checking remaining fields
            row[field_name] = None
            had_error = True

    if had_error and mode == "PERMISSIVE":
        row[column_name_of_corrupt_record] = record_str

    return row


def struct_type_to_result_template(dt: DataType) -> Optional[dict]:
Expand Down Expand Up @@ -367,7 +463,18 @@ def get_text(element: ET.Element) -> Optional[str]:

children = list(element)
if not children and (not element.attrib or exclude_attributes):
# it's a value element with no attributes or excluded attributes, so return the text
# When the schema (result_template) expects a struct for this element,
# wrap the text in the template so the output shape matches the schema.
# e.g. <publisher>Some Publisher</publisher> with template
# {"_VALUE": None, "_country": None, "_language": None} becomes
# {"_VALUE": "Some Publisher", "_country": None, "_language": None}
# instead of the raw string "Some Publisher".
if result_template is not None and isinstance(result_template, dict):
result = copy.deepcopy(result_template)
text = get_text(element)
if text is not None:
result[value_tag] = text
return result
return get_text(element)

result = copy.deepcopy(result_template) if result_template is not None else {}
Expand Down Expand Up @@ -445,6 +552,8 @@ def process_xml_range(
row_validation_xsd_path: str,
chunk_size: int = DEFAULT_CHUNK_SIZE,
result_template: Optional[dict] = None,
schema_type: Optional[StructType] = None,
is_snowpark_connect_compatible: bool = False,
) -> Iterator[Optional[Dict[str, Any]]]:
"""
Processes an XML file within a given approximate byte range.
Expand Down Expand Up @@ -475,6 +584,8 @@ def process_xml_range(
row_validation_xsd_path (str): Path to XSD file for row validation.
chunk_size (int): Size of chunks to read.
result_template(dict): a result template generated from the user input schema
schema_type(StructType): the parsed StructType for row validation
is_snowpark_connect_compatible(bool): context._is_snowpark_connect_compatible_mode

Yields:
Optional[Dict[str, Any]]: Dictionary representation of the parsed XML element.
Expand Down Expand Up @@ -607,10 +718,22 @@ def process_xml_range(
ignore_surrounding_whitespace=ignore_surrounding_whitespace,
result_template=copy.deepcopy(result_template),
)
if isinstance(result, dict):
yield result
else:
yield {value_tag: result}
row = result if isinstance(result, dict) else {value_tag: result}

# Validate primitive field values against schema types in Snowpark Connect mode only
if schema_type is not None and is_snowpark_connect_compatible:
# Mode handling for type mismatch errors.
row = _validate_row_for_type_mismatch(
row,
schema_type,
mode,
record_str,
column_name_of_corrupt_record,
)

if row is not None:
yield row
# Mode handling for malformed XML records that fail to parse.
except ET.ParseError as e:
if mode == "PERMISSIVE":
yield {column_name_of_corrupt_record: record_str}
Expand Down Expand Up @@ -645,6 +768,7 @@ def process(
ignore_surrounding_whitespace: bool,
row_validation_xsd_path: str,
custom_schema: str,
is_snowpark_connect_compatible: bool,
):
"""
Splits the file into byte ranges—one per worker—by starting with an even
Expand All @@ -668,12 +792,15 @@ def process(
ignore_surrounding_whitespace (bool): Whether or not whitespaces surrounding values should be skipped.
row_validation_xsd_path (str): Path to XSD file for row validation.
custom_schema: User input schema for xml, must be used together with row tag.
is_snowpark_connect_compatible (bool): context._is_snowpark_connect_compatible_mode
"""
file_size = get_file_size(filename)
approx_chunk_size = file_size // num_workers
approx_start = approx_chunk_size * i
approx_end = approx_chunk_size * (i + 1) if i < num_workers - 1 else file_size
result_template = schema_string_to_result_dict_and_struct_type(custom_schema)
result_template, schema_type = schema_string_to_result_dict_and_struct_type(
custom_schema
)
for element in process_xml_range(
filename,
row_tag,
Expand All @@ -690,5 +817,7 @@ def process(
ignore_surrounding_whitespace,
row_validation_xsd_path=row_validation_xsd_path,
result_template=result_template,
schema_type=schema_type,
is_snowpark_connect_compatible=is_snowpark_connect_compatible,
):
yield (element,)
Loading
Loading