-
Notifications
You must be signed in to change notification settings - Fork 146
SNOW-3192256: Support XML infer schema #4123
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
4acbeb7
87505c8
ca48c9a
49e3078
9520970
97047c9
62c55ba
292318c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,6 +2,7 @@ | |
| # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. | ||
| # | ||
|
|
||
| import datetime | ||
| import os | ||
| import re | ||
| import html.entities | ||
|
|
@@ -12,7 +13,18 @@ | |
| from snowflake.snowpark._internal.analyzer.analyzer_utils import unquote_if_quoted | ||
| from snowflake.snowpark._internal.type_utils import type_string_to_type_object | ||
| from snowflake.snowpark.files import SnowflakeFile | ||
| from snowflake.snowpark.types import StructType, ArrayType, DataType, MapType | ||
| from snowflake.snowpark.types import ( | ||
| ArrayType, | ||
| BooleanType, | ||
| DataType, | ||
| DateType, | ||
| DoubleType, | ||
| LongType, | ||
| MapType, | ||
| StringType, | ||
| StructType, | ||
| TimestampType, | ||
| ) | ||
|
|
||
| # lxml is only a dev dependency so use try/except to import it if available | ||
| try: | ||
|
|
@@ -102,13 +114,97 @@ def _restore_colons_in_template(template: Optional[dict]) -> Optional[dict]: | |
| return restored | ||
|
|
||
|
|
||
| def schema_string_to_result_dict_and_struct_type(schema_string: str) -> Optional[dict]: | ||
| def schema_string_to_result_dict_and_struct_type( | ||
| schema_string: str, | ||
| ) -> Tuple[Optional[dict], Optional[StructType]]: | ||
| if schema_string == "": | ||
| return None | ||
| return None, None | ||
| safe_string = _escape_colons_in_quotes(schema_string) | ||
| schema = type_string_to_type_object(safe_string) | ||
| template = struct_type_to_result_template(schema) | ||
| return _restore_colons_in_template(template) | ||
| return _restore_colons_in_template(template), schema | ||
|
|
||
|
|
||
| def _can_cast_to_type(value: str, target_type: DataType) -> bool: | ||
| if isinstance(target_type, StringType): | ||
| return True | ||
| if isinstance(target_type, LongType): | ||
| try: | ||
| int(value) | ||
| return True | ||
| except (ValueError, OverflowError): | ||
| return False | ||
| if isinstance(target_type, DoubleType): | ||
| try: | ||
| float(value) | ||
| return True | ||
| except ValueError: | ||
| return False | ||
| if isinstance(target_type, BooleanType): | ||
| return value.lower() in ("true", "false", "1", "0") | ||
| if isinstance(target_type, DateType): | ||
| try: | ||
| datetime.date.fromisoformat(value) | ||
sfc-gh-aling marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| return True | ||
| except (ValueError, TypeError): | ||
| return False | ||
| if isinstance(target_type, TimestampType): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. currently, this function will return True when the target type is not covered by the checks (i.e. not StringType, LongType, DoubleType, BooleanType, DateType, or TimestampType)
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good point — when the target type is not checked, we fall back to the original behavior and throw errors upon encountering corrupted data. Since we will guard this type-mismatch permissive-mode behavior with [flag]. In SCOS, since the current defined scope of |
||
| try: | ||
| datetime.datetime.fromisoformat(value) | ||
| return True | ||
| except (ValueError, TypeError): | ||
| return False | ||
| return True | ||
|
|
||
|
|
||
| def _validate_row_for_type_mismatch( | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this function being called per row after each call of
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Discussed offline, From contextual point of view, |
||
| row: dict, | ||
| schema: StructType, | ||
| mode: str, | ||
| record_str: str = "", | ||
| column_name_of_corrupt_record: str = "_corrupt_record", | ||
| ) -> Optional[dict]: | ||
| """Validate a parsed row dict against the expected schema types for mode handling: | ||
| - PERMISSIVE: set mismatched fields to ``None`` and store the raw XML | ||
| record in *column_name_of_corrupt_record* (Spark compatible). | ||
| - FAILFAST: raise immediately on the first mismatch. | ||
| - DROPMALFORMED: return ``None`` so the caller skips the row. | ||
|
|
||
| Only top-level primitive fields are validated because complex types | ||
| are kept as VARIANT and never cast downstream. | ||
| """ | ||
| had_error = False | ||
| for field in schema.fields: | ||
| field_name = unquote_if_quoted(field.name) | ||
| if field_name not in row: | ||
| continue | ||
|
|
||
| value = row[field_name] | ||
| if value is None: | ||
| continue | ||
|
|
||
| # Skip complex types as these are kept as VARIANT | ||
| if isinstance(field.datatype, (StructType, ArrayType, MapType)): | ||
| continue | ||
|
|
||
| castable = isinstance(value, str) and _can_cast_to_type(value, field.datatype) | ||
| if not castable: | ||
| if mode == "FAILFAST": | ||
| raise RuntimeError( | ||
| f"Failed to cast value '{value}' to " | ||
| f"{field.datatype.simple_string()} for field " | ||
| f"'{field_name}'.\nXML record: {record_str}" | ||
| ) | ||
| if mode == "DROPMALFORMED": | ||
| return None | ||
| # PERMISSIVE: null the bad field, continue checking remaining fields | ||
| row[field_name] = None | ||
| had_error = True | ||
|
|
||
| if had_error and mode == "PERMISSIVE": | ||
| row[column_name_of_corrupt_record] = record_str | ||
|
|
||
| return row | ||
|
|
||
|
|
||
| def struct_type_to_result_template(dt: DataType) -> Optional[dict]: | ||
|
|
@@ -367,7 +463,18 @@ def get_text(element: ET.Element) -> Optional[str]: | |
|
|
||
| children = list(element) | ||
| if not children and (not element.attrib or exclude_attributes): | ||
| # it's a value element with no attributes or excluded attributes, so return the text | ||
| # When the schema (result_template) expects a struct for this element, | ||
| # wrap the text in the template so the output shape matches the schema. | ||
| # e.g. <publisher>Some Publisher</publisher> with template | ||
| # {"_VALUE": None, "_country": None, "_language": None} becomes | ||
| # {"_VALUE": "Some Publisher", "_country": None, "_language": None} | ||
| # instead of the raw string "Some Publisher". | ||
| if result_template is not None and isinstance(result_template, dict): | ||
| result = copy.deepcopy(result_template) | ||
| text = get_text(element) | ||
| if text is not None: | ||
| result[value_tag] = text | ||
| return result | ||
| return get_text(element) | ||
|
|
||
| result = copy.deepcopy(result_template) if result_template is not None else {} | ||
|
|
@@ -445,6 +552,8 @@ def process_xml_range( | |
| row_validation_xsd_path: str, | ||
| chunk_size: int = DEFAULT_CHUNK_SIZE, | ||
| result_template: Optional[dict] = None, | ||
| schema_type: Optional[StructType] = None, | ||
| is_snowpark_connect_compatible: bool = False, | ||
| ) -> Iterator[Optional[Dict[str, Any]]]: | ||
| """ | ||
| Processes an XML file within a given approximate byte range. | ||
|
|
@@ -475,6 +584,8 @@ def process_xml_range( | |
| row_validation_xsd_path (str): Path to XSD file for row validation. | ||
| chunk_size (int): Size of chunks to read. | ||
| result_template(dict): a result template generate from user input schema | ||
| schema_type(StructType): the parsed StructType for row validation | ||
| is_snowpark_connect_compatible(bool): context._is_snowpark_connect_compatible_mode | ||
|
|
||
| Yields: | ||
| Optional[Dict[str, Any]]: Dictionary representation of the parsed XML element. | ||
|
|
@@ -607,10 +718,22 @@ def process_xml_range( | |
| ignore_surrounding_whitespace=ignore_surrounding_whitespace, | ||
| result_template=copy.deepcopy(result_template), | ||
| ) | ||
| if isinstance(result, dict): | ||
| yield result | ||
| else: | ||
| yield {value_tag: result} | ||
| row = result if isinstance(result, dict) else {value_tag: result} | ||
|
|
||
| # Validate primitive field values against schema types in Snowpark Connect mode only | ||
| if schema_type is not None and is_snowpark_connect_compatible: | ||
| # Mode handling for type mismatch errors. | ||
| row = _validate_row_for_type_mismatch( | ||
| row, | ||
| schema_type, | ||
| mode, | ||
| record_str, | ||
| column_name_of_corrupt_record, | ||
| ) | ||
|
|
||
| if row is not None: | ||
sfc-gh-aling marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| yield row | ||
| # Mode handling for malformed XML records that fail to parse. | ||
| except ET.ParseError as e: | ||
| if mode == "PERMISSIVE": | ||
| yield {column_name_of_corrupt_record: record_str} | ||
|
|
@@ -645,6 +768,7 @@ def process( | |
| ignore_surrounding_whitespace: bool, | ||
| row_validation_xsd_path: str, | ||
| custom_schema: str, | ||
| is_snowpark_connect_compatible: bool, | ||
| ): | ||
| """ | ||
| Splits the file into byte ranges—one per worker—by starting with an even | ||
|
|
@@ -668,12 +792,15 @@ def process( | |
| ignore_surrounding_whitespace (bool): Whether or not whitespaces surrounding values should be skipped. | ||
| row_validation_xsd_path (str): Path to XSD file for row validation. | ||
| custom_schema: User input schema for xml, must be used together with row tag. | ||
| is_snowpark_connect_compatible (bool): context._is_snowpark_connect_compatible_mode | ||
| """ | ||
| file_size = get_file_size(filename) | ||
| approx_chunk_size = file_size // num_workers | ||
| approx_start = approx_chunk_size * i | ||
| approx_end = approx_chunk_size * (i + 1) if i < num_workers - 1 else file_size | ||
| result_template = schema_string_to_result_dict_and_struct_type(custom_schema) | ||
| result_template, schema_type = schema_string_to_result_dict_and_struct_type( | ||
| custom_schema | ||
| ) | ||
| for element in process_xml_range( | ||
| filename, | ||
| row_tag, | ||
|
|
@@ -690,5 +817,7 @@ def process( | |
| ignore_surrounding_whitespace, | ||
| row_validation_xsd_path=row_validation_xsd_path, | ||
| result_template=result_template, | ||
| schema_type=schema_type, | ||
sfc-gh-aling marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| is_snowpark_connect_compatible=is_snowpark_connect_compatible, | ||
| ): | ||
| yield (element,) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
does this cover all types -- I don't see decimal type here, so does it mean we don't need to handle the decimal type here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Right, Spark by default doesn't infer decimal for XML. Spark's DecimalType inference is gated behind `options.prefersDecimal`, which defaults to false:
`case v if options.prefersDecimal && decimalTry.isDefined => decimalTry.get`
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
did we confirm that
`options.prefersDecimal` is out of the scope for implementation at this stage?