diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index 06296eaea2..4918652dc3 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -1882,6 +1882,7 @@ def _create_xml_query( lit(ignore_surrounding_whitespace), lit(row_validation_xsd_path), lit(schema_string), + lit(context._is_snowpark_connect_compatible_mode), ), ) diff --git a/src/snowflake/snowpark/_internal/utils.py b/src/snowflake/snowpark/_internal/utils.py index 12b2b13794..453b838a7e 100644 --- a/src/snowflake/snowpark/_internal/utils.py +++ b/src/snowflake/snowpark/_internal/utils.py @@ -199,6 +199,7 @@ # The following are not copy into SQL command options but client side options. "INFER_SCHEMA", "INFER_SCHEMA_OPTIONS", + "SAMPLING_RATIO", "FORMAT_TYPE_OPTIONS", "TARGET_COLUMNS", "TRANSFORMATIONS", @@ -210,6 +211,9 @@ XML_ROW_TAG_STRING = "ROWTAG" XML_ROW_DATA_COLUMN_NAME = "ROW_DATA" XML_READER_FILE_PATH = os.path.join(os.path.dirname(__file__), "xml_reader.py") +XML_SCHEMA_INFERENCE_FILE_PATH = os.path.join( + os.path.dirname(__file__), "xml_schema_inference.py" +) XML_READER_API_SIGNATURE = "DataFrameReader.xml[rowTag]" XML_READER_SQL_COMMENT = f"/* Python:snowflake.snowpark.{XML_READER_API_SIGNATURE} */" diff --git a/src/snowflake/snowpark/_internal/xml_reader.py b/src/snowflake/snowpark/_internal/xml_reader.py index 96cf5b723e..4a5300ff0b 100644 --- a/src/snowflake/snowpark/_internal/xml_reader.py +++ b/src/snowflake/snowpark/_internal/xml_reader.py @@ -2,6 +2,7 @@ # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. 
# +import datetime import os import re import html.entities @@ -12,7 +13,18 @@ from snowflake.snowpark._internal.analyzer.analyzer_utils import unquote_if_quoted from snowflake.snowpark._internal.type_utils import type_string_to_type_object from snowflake.snowpark.files import SnowflakeFile -from snowflake.snowpark.types import StructType, ArrayType, DataType, MapType +from snowflake.snowpark.types import ( + ArrayType, + BooleanType, + DataType, + DateType, + DoubleType, + LongType, + MapType, + StringType, + StructType, + TimestampType, +) # lxml is only a dev dependency so use try/except to import it if available try: @@ -102,13 +114,97 @@ def _restore_colons_in_template(template: Optional[dict]) -> Optional[dict]: return restored -def schema_string_to_result_dict_and_struct_type(schema_string: str) -> Optional[dict]: +def schema_string_to_result_dict_and_struct_type( + schema_string: str, +) -> Tuple[Optional[dict], Optional[StructType]]: if schema_string == "": - return None + return None, None safe_string = _escape_colons_in_quotes(schema_string) schema = type_string_to_type_object(safe_string) template = struct_type_to_result_template(schema) - return _restore_colons_in_template(template) + return _restore_colons_in_template(template), schema + + +def _can_cast_to_type(value: str, target_type: DataType) -> bool: + if isinstance(target_type, StringType): + return True + if isinstance(target_type, LongType): + try: + int(value) + return True + except (ValueError, OverflowError): + return False + if isinstance(target_type, DoubleType): + try: + float(value) + return True + except ValueError: + return False + if isinstance(target_type, BooleanType): + return value.lower() in ("true", "false", "1", "0") + if isinstance(target_type, DateType): + try: + datetime.date.fromisoformat(value) + return True + except (ValueError, TypeError): + return False + if isinstance(target_type, TimestampType): + try: + datetime.datetime.fromisoformat(value) + return True + except 
(ValueError, TypeError): + return False + return True + + +def _validate_row_for_type_mismatch( + row: dict, + schema: StructType, + mode: str, + record_str: str = "", + column_name_of_corrupt_record: str = "_corrupt_record", +) -> Optional[dict]: + """Validate a parsed row dict against the expected schema types for mode handling: + - PERMISSIVE: set mismatched fields to ``None`` and store the raw XML + record in *column_name_of_corrupt_record* (Spark compatible). + - FAILFAST: raise immediately on the first mismatch. + - DROPMALFORMED: return ``None`` so the caller skips the row. + + Only top-level primitive fields are validated because complex types + are kept as VARIANT and never cast downstream. + """ + had_error = False + for field in schema.fields: + field_name = unquote_if_quoted(field.name) + if field_name not in row: + continue + + value = row[field_name] + if value is None: + continue + + # Skip complex types as these are kept as VARIANT + if isinstance(field.datatype, (StructType, ArrayType, MapType)): + continue + + castable = isinstance(value, str) and _can_cast_to_type(value, field.datatype) + if not castable: + if mode == "FAILFAST": + raise RuntimeError( + f"Failed to cast value '{value}' to " + f"{field.datatype.simple_string()} for field " + f"'{field_name}'.\nXML record: {record_str}" + ) + if mode == "DROPMALFORMED": + return None + # PERMISSIVE: null the bad field, continue checking remaining fields + row[field_name] = None + had_error = True + + if had_error and mode == "PERMISSIVE": + row[column_name_of_corrupt_record] = record_str + + return row def struct_type_to_result_template(dt: DataType) -> Optional[dict]: @@ -367,7 +463,18 @@ def get_text(element: ET.Element) -> Optional[str]: children = list(element) if not children and (not element.attrib or exclude_attributes): - # it's a value element with no attributes or excluded attributes, so return the text + # When the schema (result_template) expects a struct for this element, + # wrap the 
text in the template so the output shape matches the schema. + # e.g. Some Publisher with template + # {"_VALUE": None, "_country": None, "_language": None} becomes + # {"_VALUE": "Some Publisher", "_country": None, "_language": None} + # instead of the raw string "Some Publisher". + if result_template is not None and isinstance(result_template, dict): + result = copy.deepcopy(result_template) + text = get_text(element) + if text is not None: + result[value_tag] = text + return result return get_text(element) result = copy.deepcopy(result_template) if result_template is not None else {} @@ -445,6 +552,8 @@ def process_xml_range( row_validation_xsd_path: str, chunk_size: int = DEFAULT_CHUNK_SIZE, result_template: Optional[dict] = None, + schema_type: Optional[StructType] = None, + is_snowpark_connect_compatible: bool = False, ) -> Iterator[Optional[Dict[str, Any]]]: """ Processes an XML file within a given approximate byte range. @@ -475,6 +584,8 @@ def process_xml_range( row_validation_xsd_path (str): Path to XSD file for row validation. chunk_size (int): Size of chunks to read. result_template(dict): a result template generate from user input schema + schema_type(StructType): the parsed StructType for row validation + is_snowpark_connect_compatible(bool): context._is_snowpark_connect_compatible_mode Yields: Optional[Dict[str, Any]]: Dictionary representation of the parsed XML element. @@ -607,10 +718,22 @@ def process_xml_range( ignore_surrounding_whitespace=ignore_surrounding_whitespace, result_template=copy.deepcopy(result_template), ) - if isinstance(result, dict): - yield result - else: - yield {value_tag: result} + row = result if isinstance(result, dict) else {value_tag: result} + + # Validate primitive field values against schema types in Snowpark Connect mode only + if schema_type is not None and is_snowpark_connect_compatible: + # Mode handling for type mismatch errors. 
+ row = _validate_row_for_type_mismatch( + row, + schema_type, + mode, + record_str, + column_name_of_corrupt_record, + ) + + if row is not None: + yield row + # Mode handling for malformed XML records that fail to parse. except ET.ParseError as e: if mode == "PERMISSIVE": yield {column_name_of_corrupt_record: record_str} @@ -645,6 +768,7 @@ def process( ignore_surrounding_whitespace: bool, row_validation_xsd_path: str, custom_schema: str, + is_snowpark_connect_compatible: bool, ): """ Splits the file into byte ranges—one per worker—by starting with an even @@ -668,12 +792,15 @@ def process( ignore_surrounding_whitespace (bool): Whether or not whitespaces surrounding values should be skipped. row_validation_xsd_path (str): Path to XSD file for row validation. custom_schema: User input schema for xml, must be used together with row tag. + is_snowpark_connect_compatible (bool): context._is_snowpark_connect_compatible_mode """ file_size = get_file_size(filename) approx_chunk_size = file_size // num_workers approx_start = approx_chunk_size * i approx_end = approx_chunk_size * (i + 1) if i < num_workers - 1 else file_size - result_template = schema_string_to_result_dict_and_struct_type(custom_schema) + result_template, schema_type = schema_string_to_result_dict_and_struct_type( + custom_schema + ) for element in process_xml_range( filename, row_tag, @@ -690,5 +817,7 @@ def process( ignore_surrounding_whitespace, row_validation_xsd_path=row_validation_xsd_path, result_template=result_template, + schema_type=schema_type, + is_snowpark_connect_compatible=is_snowpark_connect_compatible, ): yield (element,) diff --git a/src/snowflake/snowpark/_internal/xml_schema_inference.py b/src/snowflake/snowpark/_internal/xml_schema_inference.py new file mode 100644 index 0000000000..15ee78d6b8 --- /dev/null +++ b/src/snowflake/snowpark/_internal/xml_schema_inference.py @@ -0,0 +1,711 @@ +# +# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. 
+# + +import re +import random +from datetime import datetime, date +from typing import Optional, Dict, List + +from snowflake.snowpark._internal.xml_reader import ( + DEFAULT_CHUNK_SIZE, + get_file_size, + find_next_opening_tag_pos, + tag_is_self_closing, + find_next_closing_tag_pos, + strip_xml_namespaces, + replace_entity, +) +from snowflake.snowpark.files import SnowflakeFile +from snowflake.snowpark.types import ( + StructType, + ArrayType, + DataType, + NullType, + StringType, + BooleanType, + LongType, + DoubleType, + DecimalType, + DateType, + TimestampType, + StructField, +) + +# lxml is only a dev dependency so use try/except to import it if available +try: + import lxml.etree as ET + + lxml_installed = True +except ImportError: + import xml.etree.ElementTree as ET + + lxml_installed = False + + +# --------------------------------------------------------------------------- +# Stage 1 – Per-record type inference +# --------------------------------------------------------------------------- + + +def _normalize_text( + text: Optional[str], ignore_surrounding_whitespace: bool +) -> Optional[str]: + """Normalize text by stripping whitespace if configured.""" + if text is None: + return None + return text.strip() if ignore_surrounding_whitespace else text + + +def _infer_primitive_type(text: str) -> DataType: + """ + Infer the DataType from a single string value with below priority order: + - null/empty -> NullType + - parseable as Long -> LongType + - parseable as Double -> DoubleType + - "true"/"false" -> BooleanType + - parseable as Date -> DateType (ISO format: yyyy-mm-dd) + - parseable as Timestamp -> TimestampType (ISO format) + - anything else -> StringType + """ + # Long (matching Spark: infers integers as LongType directly) + try: + sign_safe = text.lstrip("+-") + if sign_safe and sign_safe[0].isdigit() and "." 
not in text: + val = int(text) + # Spark's Long is 64-bit; Python's int is unbounded + if -(2**63) <= val <= 2**63 - 1: + return LongType() + # Numbers outside Long range fall through to Double + except (ValueError, OverflowError): + pass + + # Double + try: + sign_safe = text.lstrip("+-") + is_numeric_start = sign_safe and (sign_safe[0].isdigit() or sign_safe[0] == ".") + is_special_float = text.lower() in ( + "nan", + "infinity", + "+infinity", + "-infinity", + "inf", + "+inf", + "-inf", + ) + if is_numeric_start or is_special_float: + # Reject strings ending in d/D/f/F + if text[-1] not in ("d", "D", "f", "F"): + float(text) + return DoubleType() + except (ValueError, OverflowError): + pass + + # Boolean + if text.lower() in ("true", "false"): + return BooleanType() + + # Date (Spark default pattern: yyyy-MM-dd via DateFormatter) + try: + date.fromisoformat(text) + return DateType() + except (ValueError, TypeError): + pass + + # Timestamp (Spark default: TimestampFormatter with CAST logic) + # Backward compatibility: Python 3.9 fromisoformat doesn't support 'Z' suffix; replace with +00:00 + try: + ts_text = text.replace("Z", "+00:00") if text.endswith("Z") else text + datetime.fromisoformat(ts_text) + return TimestampType() + except (ValueError, TypeError): + pass + + return StringType() + + +def infer_type( + value: Optional[str], + ignore_surrounding_whitespace: bool = False, + null_value: str = "", +) -> DataType: + """ + Infer the DataType from a single string value. + Normalizes *value*, checks for null / empty / null_value, then delegates + to :func:`_infer_primitive_type`. 
+ """ + text = _normalize_text(value, ignore_surrounding_whitespace) + if text is None or text == null_value or text == "": + return NullType() + return _infer_primitive_type(text) + + +def infer_element_schema( + element: ET.Element, + attribute_prefix: str = "_", + exclude_attributes: bool = False, + value_tag: str = "_VALUE", + null_value: str = "", + ignore_surrounding_whitespace: bool = False, + ignore_namespace: bool = False, + is_root: bool = True, +) -> DataType: + """ + Infer the schema (DataType) from a parsed XML Element. + - Elements with no children and no attributes (or excluded) -> infer from text (primitive/NullType) + - Elements with children -> StructType with child field types + - Elements with attributes -> StructType with attribute fields + - Mixed content (text + children/attributes) -> StructType with _VALUE field + - Repeated child tags -> ArrayType detection via add_or_update_type + + is_root: True for the row-tag element (top-level), False for child elements. + Spark treats these differently: at root level, self-closing attribute-only + elements do NOT get _VALUE. At child level, they always get _VALUE + (NullType -> StringType after canonicalization). 
+ """ + children = list(element) + has_attributes = bool(element.attrib) and not exclude_attributes + + # Case: leaf element with no attributes -> infer from text content + if not children and not has_attributes: + return infer_type(element.text, ignore_surrounding_whitespace, null_value) + + # This element will become a StructType + # Use a dict to track field names and types (for array detection) + name_to_type: Dict[str, DataType] = {} + field_order: List[str] = [] + + # Process attributes first + if has_attributes: + for attr_name, attr_value in element.attrib.items(): + prefixed_name = f"{attribute_prefix}{attr_name}" + attr_type = infer_type( + attr_value, ignore_surrounding_whitespace, null_value + ) + if prefixed_name not in name_to_type: + field_order.append(prefixed_name) + add_or_update_type(name_to_type, prefixed_name, attr_type, value_tag) + + if children: + for child in children: + child_tag = child.tag + # Ignore namespace in tag if configured (both Clark and prefix notation) + if ignore_namespace: + if "}" in child_tag: + child_tag = child_tag.split("}", 1)[1] + elif ":" in child_tag: + child_tag = child_tag.split(":", 1)[1] + + # Check if child has attributes + child_has_attrs = bool(child.attrib) and not exclude_attributes + + # inferField dispatch + child_children = list(child) + if not child_children and not child_has_attrs: + # Leaf element + child_type = infer_type( + child.text, ignore_surrounding_whitespace, null_value + ) + else: + # Non-leaf element: recurse into inferObject + child_type = infer_element_schema( + child, + attribute_prefix=attribute_prefix, + exclude_attributes=exclude_attributes, + value_tag=value_tag, + null_value=null_value, + ignore_surrounding_whitespace=ignore_surrounding_whitespace, + ignore_namespace=ignore_namespace, + is_root=False, + ) + + # When child_has_attrs is True, the recursive infer_element_schema call + # above already processes child attributes and includes them in the + # returned StructType. 
No additional attribute processing needed here. + if child_tag not in name_to_type: + field_order.append(child_tag) + add_or_update_type(name_to_type, child_tag, child_type, value_tag) + + # Handle mixed content: text + child elements + text = _normalize_text(element.text, ignore_surrounding_whitespace) + if text is not None and text != null_value and text.strip() != "": + text_type = _infer_primitive_type(text) + if value_tag not in name_to_type: + field_order.append(value_tag) + add_or_update_type(name_to_type, value_tag, text_type, value_tag) + else: + # No children but has attributes -> conditionally include _VALUE. + # [SPARK PARITY] Spark's behavior differs by element level: + # - Root/row-tag level: _VALUE only added when actual text exists + # - Child level: _VALUE always added (NullType if no text, canonicalized + # to StringType later). This covers cases like self-closing + # . + text = _normalize_text(element.text, ignore_surrounding_whitespace) + if text is not None and text != null_value and text != "": + text_type = _infer_primitive_type(text) + if value_tag not in name_to_type: + field_order.append(value_tag) + add_or_update_type(name_to_type, value_tag, text_type, value_tag) + elif not is_root: + # Child-level attribute-only element: add _VALUE as NullType + if value_tag not in name_to_type: + field_order.append(value_tag) + add_or_update_type(name_to_type, value_tag, NullType(), value_tag) + + # Build the StructType with sorted fields to match Spark behavior. 
+ result_fields = sorted( + (StructField(name, name_to_type[name], nullable=True) for name in field_order), + key=lambda f: f.name, + ) + return StructType(result_fields) + + +# --------------------------------------------------------------------------- +# Stage 2 – Cross-partition merge +# --------------------------------------------------------------------------- + + +def compatible_type(t1: DataType, t2: DataType, value_tag: str = "_VALUE") -> DataType: + """ + Returns the most general data type for two given data types. + """ + # Same type + if type(t1) == type(t2): + if isinstance(t1, StructType): + return merge_struct_types(t1, t2, value_tag) + if isinstance(t1, ArrayType): + return ArrayType( + compatible_type(t1.element_type, t2.element_type, value_tag), + ) + if isinstance(t1, DecimalType): + # Widen decimal + scale = max(t1.scale, t2.scale) + range_ = max(t1.precision - t1.scale, t2.precision - t2.scale) + if range_ + scale > 38: + return DoubleType() + return DecimalType(range_ + scale, scale) + return t1 + + # NullType + T -> T + if isinstance(t1, NullType): + return t2 + if isinstance(t2, NullType): + return t1 + + # Numeric widening: Long < Double + if (isinstance(t1, LongType) and isinstance(t2, DoubleType)) or ( + isinstance(t1, DoubleType) and isinstance(t2, LongType) + ): + return DoubleType() + + # Double + Decimal -> Double + if (isinstance(t1, DoubleType) and isinstance(t2, DecimalType)) or ( + isinstance(t1, DecimalType) and isinstance(t2, DoubleType) + ): + return DoubleType() + + # Long + Decimal -> Decimal (widened) + if isinstance(t1, LongType) and isinstance(t2, DecimalType): + # DecimalType.forType(LongType) in Spark is Decimal(20, 0) + return compatible_type(DecimalType(20, 0), t2, value_tag) + if isinstance(t1, DecimalType) and isinstance(t2, LongType): + return compatible_type(t1, DecimalType(20, 0), value_tag) + + # Timestamp + Date -> Timestamp + if (isinstance(t1, TimestampType) and isinstance(t2, DateType)) or ( + isinstance(t1, 
DateType) and isinstance(t2, TimestampType) + ): + return TimestampType() + + # Array + non-Array -> ArrayType(compatible) + if isinstance(t1, ArrayType): + return ArrayType(compatible_type(t1.element_type, t2, value_tag)) + if isinstance(t2, ArrayType): + return ArrayType(compatible_type(t1, t2.element_type, value_tag)) + + # Struct with _VALUE tag + Primitive -> widen _VALUE field + if isinstance(t1, StructType) and _struct_has_value_tag(t1, value_tag): + return _merge_struct_with_primitive(t1, t2, value_tag) + if isinstance(t2, StructType) and _struct_has_value_tag(t2, value_tag): + return _merge_struct_with_primitive(t2, t1, value_tag) + + # Fallback: anything else -> StringType + return StringType() + + +def _struct_has_value_tag(st: StructType, value_tag: str) -> bool: + """Check if a StructType has a field with the given value_tag name.""" + return any(f.name == value_tag for f in st.fields) + + +def _merge_struct_with_primitive( + st: StructType, primitive: DataType, value_tag: str +) -> StructType: + """ + Merge a StructType containing a value_tag field with a primitive type. + The value_tag field's type is widened to be compatible with the primitive. + """ + new_fields = [] + for f in st.fields: + if f.name == value_tag: + new_type = compatible_type(f.datatype, primitive, value_tag) + new_fields.append(StructField(f.name, new_type, nullable=True)) + else: + new_fields.append(f) + return StructType(new_fields) + + +def merge_struct_types( + a: StructType, b: StructType, value_tag: str = "_VALUE" +) -> StructType: + """ + Merge two StructTypes field-by-field (case-sensitive). + Fields present in both are merged via compatible_type. + Fields present in only one are included as-is (nullable). + + Uses f._name (original case from XML) rather than f.name (uppercased by ColumnIdentifier). 
+ """ + field_map: Dict[str, DataType] = {} + field_order: List[str] = [] + + for f in a.fields: + field_map[f._name] = f.datatype + field_order.append(f._name) + + for f in b.fields: + if f._name in field_map: + field_map[f._name] = compatible_type( + field_map[f._name], f.datatype, value_tag + ) + else: + field_map[f._name] = f.datatype + field_order.append(f._name) + + # Sort fields by name to match Spark behavior + return StructType( + sorted( + (StructField(name, field_map[name], nullable=True) for name in field_order), + key=lambda f: f._name, + ), + ) + + +def add_or_update_type( + name_to_type: Dict[str, DataType], + field_name: str, + new_type: DataType, + value_tag: str = "_VALUE", +) -> None: + """ + Array detection logic: + - 1st occurrence of field_name -> store the type as-is + - 2nd occurrence -> wrap into ArrayType(compatible(old, new)) + - Nth occurrence -> the existing type is already ArrayType, merge via compatible_type + """ + if field_name in name_to_type: + old_type = name_to_type[field_name] + if not isinstance(old_type, ArrayType): + # 2nd occurrence: promote to ArrayType + name_to_type[field_name] = ArrayType( + compatible_type(old_type, new_type, value_tag), + ) + else: + # Already an ArrayType: merge element types + name_to_type[field_name] = compatible_type(old_type, new_type, value_tag) + else: + name_to_type[field_name] = new_type + + +# --------------------------------------------------------------------------- +# Stage 3 – Canonicalization +# --------------------------------------------------------------------------- + + +def canonicalize_type(dt: DataType) -> Optional[DataType]: + """ + Convert NullType to StringType and remove StructTypes with no fields: + - NullType -> StringType + - Empty StructType -> None + - ArrayType -> recurse on element type + - StructType -> recurse, remove empty-name fields + - Other -> kept as-is + """ + if isinstance(dt, NullType): + return StringType() + + if isinstance(dt, ArrayType): + 
canonical_element = canonicalize_type(dt.element_type) + if canonical_element is not None: + return ArrayType(canonical_element) + return None + + if isinstance(dt, StructType): + canonical_fields = [] + for f in dt.fields: + if f._name == "": + continue + canonical_child = canonicalize_type(f.datatype) + if canonical_child is not None: + canonical_fields.append( + StructField(f._name, canonical_child, nullable=True) + ) + if canonical_fields: + return StructType(canonical_fields) + # empty structs should be deleted + return None + + return dt + + +# --------------------------------------------------------------------------- +# Schema string serialization +# --------------------------------------------------------------------------- + + +def _case_preserving_simple_string(dt: DataType) -> str: + """ + Serialize a DataType to a simple string, preserving original field name case. + Wrap field names with colons in double quotes so that the schema string parser + can correctly find the top-level colon separating name from type. + The consumer of these strings must strip the outer quotes after parsing. + """ + if isinstance(dt, StructType): + parts = [] + for f in dt.fields: + name = f._name + # Wrap names with colons in double quotes so the parser can + # distinguish the name:type separator from colons in the name. 
+ if ":" in name: + name = f'"{name}"' + parts.append(f"{name}:{_case_preserving_simple_string(f.datatype)}") + return f"struct<{','.join(parts)}>" + elif isinstance(dt, ArrayType): + return f"array<{_case_preserving_simple_string(dt.element_type)}>" + else: + return dt.simple_string() + + +# --------------------------------------------------------------------------- +# XMLSchemaInference UDTF +# --------------------------------------------------------------------------- + + +def infer_schema_for_xml_range( + file_path: str, + row_tag: str, + approx_start: int, + approx_end: int, + sampling_ratio: float, + ignore_namespace: bool, + attribute_prefix: str, + exclude_attributes: bool, + value_tag: str, + null_value: str, + charset: str, + ignore_surrounding_whitespace: bool, + chunk_size: int = DEFAULT_CHUNK_SIZE, +) -> Optional[StructType]: + """ + Infer the merged XML schema for all records within a byte range. + + Scans the file from *approx_start* to *approx_end*, parses each + XML record delimited by *row_tag*, infers per-record schemas, and + merges them into a single StructType. + + Returns: + The merged StructType, or None if no records were found. 
+ """ + tag_start_1 = f"<{row_tag}>".encode() + tag_start_2 = f"<{row_tag} ".encode() + closing_tag = f"".encode() + + merged_schema: Optional[StructType] = None + + with SnowflakeFile.open(file_path, "rb", require_scoped_url=False) as f: + f.seek(approx_start) + + while True: + try: + open_pos = find_next_opening_tag_pos( + f, tag_start_1, tag_start_2, approx_end, chunk_size + ) + except EOFError: + break + + if open_pos >= approx_end: + break + + record_start = open_pos + f.seek(record_start) + + try: + is_self_close, tag_end = tag_is_self_closing(f, chunk_size) + if is_self_close: + record_end = tag_end + else: + f.seek(tag_end) + record_end = find_next_closing_tag_pos(f, closing_tag, chunk_size) + except Exception: + try: + f.seek(min(record_start + 1, approx_end)) + except Exception: + break + continue + + if sampling_ratio < 1.0 and random.random() > sampling_ratio: + if record_end > approx_end: + break + try: + f.seek(min(record_end, approx_end)) + except Exception: + break + continue + + try: + f.seek(record_start) + record_bytes = f.read(record_end - record_start) + record_str = record_bytes.decode(charset, errors="replace") + record_str = re.sub(r"&(\w+);", replace_entity, record_str) + + if lxml_installed: + recover = bool(":" in row_tag) + parser = ET.XMLParser(recover=recover, ns_clean=True) + try: + element = ET.fromstring(record_str, parser) + except ET.XMLSyntaxError: + if ignore_namespace: + cleaned = re.sub(r"\s+(\w+):(\w+)=", r" \2=", record_str) + element = ET.fromstring(cleaned, parser) + else: + raise + else: + element = ET.fromstring(record_str) + + if ignore_namespace: + element = strip_xml_namespaces(element) + except Exception: + if record_end > approx_end: + break + try: + f.seek(min(record_end, approx_end)) + except Exception: + break + continue + + record_schema = infer_element_schema( + element, + attribute_prefix=attribute_prefix, + exclude_attributes=exclude_attributes, + value_tag=value_tag, + null_value=null_value, + 
ignore_surrounding_whitespace=ignore_surrounding_whitespace, + ignore_namespace=ignore_namespace, + ) + + if not isinstance(record_schema, StructType): + record_schema = StructType( + [StructField(value_tag, record_schema, nullable=True)], + ) + + if merged_schema is None: + merged_schema = record_schema + else: + merged_schema = merge_struct_types( + merged_schema, record_schema, value_tag + ) + + if record_end > approx_end: + break + try: + f.seek(min(record_end, approx_end)) + except Exception: + break + + return merged_schema + + +class XMLSchemaInference: + """ + UDTF handler for parallelized XML schema inference. + + Each worker reads its assigned byte range of the XML file, parses + XML records, infers a per-record schema, and merges all schemas + within its partition. The merged schema is yielded as a serialized + string (StructType.simple_string()). + + The parallelization pattern mirrors XMLReader. + """ + + def process( + self, + filename: str, + num_workers: int, + row_tag: str, + i: int, + sampling_ratio: float, + ignore_namespace: bool, + attribute_prefix: str, + exclude_attributes: bool, + value_tag: str, + null_value: str, + charset: str, + ignore_surrounding_whitespace: bool, + ): + """ + Infer XML schema for a byte-range partition of the file. + + Args: + filename: Path to the XML file. + num_workers: Total number of workers. + row_tag: The tag that delimits records. + i: This worker's ID (0-based). + sampling_ratio: Fraction of records to sample (0.0-1.0). + ignore_namespace: Whether to strip namespaces. + attribute_prefix: Prefix for attribute names. + exclude_attributes: Whether to exclude attributes. + value_tag: Tag name for the value column. + null_value: Value to treat as null. + charset: Character encoding of the XML file. + ignore_surrounding_whitespace: Whether to strip whitespace from values. 
+ """ + file_size = get_file_size(filename) + if not file_size or file_size <= 0: + yield ("",) + return + + if num_workers is None or num_workers <= 0: + num_workers = 1 + if i is None or i < 0: + i = 0 + if i >= num_workers: + yield ("",) + return + + approx_chunk_size = file_size // num_workers + approx_start = approx_chunk_size * i + approx_end = approx_chunk_size * (i + 1) if i < num_workers - 1 else file_size + + # Deterministic per-worker seed for Bernoulli sampling: + if sampling_ratio < 1.0: + random.seed(1 + i) + + merged_schema = infer_schema_for_xml_range( + file_path=filename, + row_tag=row_tag, + approx_start=approx_start, + approx_end=approx_end, + sampling_ratio=sampling_ratio, + ignore_namespace=ignore_namespace, + attribute_prefix=attribute_prefix, + exclude_attributes=exclude_attributes, + value_tag=value_tag, + null_value=null_value, + charset=charset, + ignore_surrounding_whitespace=ignore_surrounding_whitespace, + ) + + yield ( + _case_preserving_simple_string(merged_schema) + if merged_schema is not None + else "", + ) diff --git a/src/snowflake/snowpark/dataframe_reader.py b/src/snowflake/snowpark/dataframe_reader.py index 218c086b5e..9c8d7824b4 100644 --- a/src/snowflake/snowpark/dataframe_reader.py +++ b/src/snowflake/snowpark/dataframe_reader.py @@ -12,6 +12,7 @@ from datetime import datetime import snowflake.snowpark +import snowflake.snowpark.context as context import snowflake.snowpark._internal.proto.generated.ast_pb2 as proto from snowflake.snowpark._internal.analyzer.analyzer_utils import ( convert_value_to_sql_option, @@ -51,16 +52,21 @@ convert_sf_to_sp_type, convert_sp_to_sf_type, most_permissive_type, + type_string_to_type_object, ) from snowflake.snowpark._internal.udf_utils import get_types_from_type_hints +from snowflake.snowpark._internal.xml_reader import DEFAULT_CHUNK_SIZE +from snowflake.snowpark._internal.xml_schema_inference import ( + merge_struct_types, + canonicalize_type, +) from snowflake.snowpark._internal.utils 
import ( SNOWURL_PREFIX, - STAGE_PREFIX, XML_ROW_TAG_STRING, XML_ROW_DATA_COLUMN_NAME, XML_READER_FILE_PATH, + XML_SCHEMA_INFERENCE_FILE_PATH, XML_READER_API_SIGNATURE, - XML_READER_SQL_COMMENT, INFER_SCHEMA_FORMAT_TYPES, SNOWFLAKE_PATH_PREFIXES, TempObjectType, @@ -85,6 +91,9 @@ from snowflake.snowpark.mock._connection import MockServerConnection from snowflake.snowpark.table import Table from snowflake.snowpark.types import ( + ArrayType, + MapType, + StringType, StructType, TimestampTimeZone, VariantType, @@ -118,6 +127,7 @@ "PATHGLOBFILTER": "PATTERN", "FILENAMEPATTERN": "PATTERN", "INFERSCHEMA": "INFER_SCHEMA", + "SAMPLINGRATIO": "SAMPLING_RATIO", "SEP": "FIELD_DELIMITER", "LINESEP": "RECORD_DELIMITER", "QUOTE": "FIELD_OPTIONALLY_ENCLOSED_BY", @@ -473,6 +483,7 @@ def __init__( self._infer_schema_target_columns: Optional[List[str]] = None self.__format: Optional[str] = None self._data_source_format = ["jdbc", "dbapi"] + self._xml_inferred_schema: Optional[StructType] = None self._ast = None if _emit_ast: @@ -1070,19 +1081,53 @@ def xml(self, path: str, _emit_ast: bool = True) -> DataFrame: ast.reader.CopyFrom(self._ast) df._ast_id = stmt.uid - # cast to input custom schema type - # TODO: SNOW-2923003: remove single quote after server side BCR is done - if self._user_schema: - cols = [ - df[single_quote(field._name)] - .cast(field.datatype) - .alias(quote_name_without_upper_casing(field._name)) - for field in self._user_schema.fields - ] - return df.select(cols) + # xml_reader returns VARIANT DataFrame, cast to custom or inferred schema type + effective_schema = self._user_schema or self._xml_inferred_schema + if effective_schema is not None: + return self._apply_xml_schema(df, effective_schema) else: return df + def _apply_xml_schema(self, df: DataFrame, schema: StructType) -> DataFrame: + """Apply an XML schema to a VARIANT DataFrame""" + cols = [] + for field in schema.fields: + # TODO: SNOW-2923003: remove single quote after server side BCR is done + col = 
df[single_quote(field._name)] + if isinstance(field.datatype, (StructType, ArrayType, MapType)): + # Complex types: keep as VARIANT to prevent structured cast errors + cols.append(col.alias(quote_name_without_upper_casing(field._name))) + else: + # Primitive types: cast to the inferred datatype + cols.append( + col.cast(field.datatype).alias( + quote_name_without_upper_casing(field._name) + ) + ) + + # In PERMISSIVE mode, append the StringType column for the corrupt record if it + # exists in the DataFrame emitted by the UDTF on malformed or corrupted records. + mode = self._cur_options.get("MODE", "PERMISSIVE").upper() + corrupt_col_name = self._cur_options.get( + "COLUMNNAMEOFCORRUPTRECORD", "_corrupt_record" + ) + if mode == "PERMISSIVE": + df_columns = {c.strip('"').strip("'") for c in df.columns} + if corrupt_col_name in df_columns: + corrupt_ref = df[single_quote(corrupt_col_name)] + cols.append( + corrupt_ref.cast(StringType()).alias( + quote_name_without_upper_casing(corrupt_col_name) + ) + ) + if self._xml_inferred_schema is not None: + self._xml_inferred_schema.fields.append( + StructField(corrupt_col_name, StringType()) + ) + + result = df.select(cols) + return result + @publicapi def option(self, key: str, value: Any, _emit_ast: bool = True) -> "DataFrameReader": """Sets the specified option in the DataFrameReader. 
    def _resolve_xml_file_for_udtf(self, local_file_path: str) -> str:
        """Return the UDTF file path, uploading to a temp stage in stored procedures.

        Outside a stored procedure the local path is usable directly by
        ``register_from_file``; inside one, the file is first uploaded to the
        session stage (uncompressed, overwriting, skipping the upload when the
        staged content already matches).
        """
        if is_in_stored_procedure():  # pragma: no cover
            session_stage = self._session.get_session_stage()
            self._session._conn.upload_file(
                local_file_path,
                session_stage,
                compress_data=False,
                overwrite=True,
                skip_upload_on_content_match=True,
            )
            return f"{session_stage}/{os.path.basename(local_file_path)}"
        return local_file_path

    def _infer_schema_for_xml(self, path: str) -> Optional[StructType]:
        """Infer a ``StructType`` schema for the staged XML file at ``path``.

        Registers the ``XMLSchemaInference`` UDTF, fans it out over byte-range
        partitions of the file, merges the per-partition schema strings the
        workers return, and canonicalizes the merged result.

        Args:
            path: Stage path of the XML file (e.g. ``@stage/file.xml``).

        Returns:
            The inferred ``StructType``, or ``None`` when no worker produced a
            usable schema (or canonicalization yields no struct).

        Raises:
            ValueError: If ``path`` does not exist on the stage, or if the
                ``SAMPLING_RATIO`` option is not greater than 0.
        """
        # Register the XMLSchemaInference UDTF
        handler_name = "XMLSchemaInference"
        _, input_types = get_types_from_type_hints(
            (XML_SCHEMA_INFERENCE_FILE_PATH, handler_name),
            TempObjectType.TABLE_FUNCTION,
        )

        inference_file_path = self._resolve_xml_file_for_udtf(
            XML_SCHEMA_INFERENCE_FILE_PATH
        )

        # Each worker yields a single column: the serialized partition schema.
        output_schema = StructType([StructField("SCHEMA_VALUE", StringType(), True)])
        schema_udtf = self._session.udtf.register_from_file(
            inference_file_path,
            handler_name,
            output_schema=output_schema,
            input_types=input_types,
            packages=["snowflake-snowpark-python", "lxml<6"],
            replace=True,
            _suppress_local_package_warnings=True,
        )

        # Determine number of workers
        try:
            # `ls` on a missing stage path returns zero rows -> IndexError.
            file_size = int(
                self._session.sql(f"ls {path}", _emit_ast=False).collect(
                    _emit_ast=False
                )[0]["size"]
            )
        except IndexError:
            raise ValueError(f"{path} does not exist")

        # One worker per DEFAULT_CHUNK_SIZE bytes, capped at 16.
        num_workers = min(16, file_size // DEFAULT_CHUNK_SIZE + 1)

        row_tag = self._cur_options[XML_ROW_TAG_STRING]
        sampling_ratio = float(self._cur_options.get("SAMPLING_RATIO", 1.0))
        if sampling_ratio <= 0:
            raise ValueError(
                f"samplingRatio ({sampling_ratio}) should be greater than 0"
            )
        ignore_namespace = self._cur_options.get("IGNORENAMESPACE", True)
        attribute_prefix = self._cur_options.get("ATTRIBUTEPREFIX", "_")
        exclude_attributes = self._cur_options.get("EXCLUDEATTRIBUTES", False)
        value_tag = self._cur_options.get("VALUETAG", "_VALUE")
        null_value = self._cur_options.get("NULL_IF", "")
        charset = self._cur_options.get("CHARSET", "utf-8")
        ignore_surrounding_whitespace = self._cur_options.get(
            "IGNORESURROUNDINGWHITESPACE", False
        )

        # Create range DataFrame and apply UDTF
        # (one row per worker id; argument order must match the handler's
        # `process` signature).
        worker_col = "WORKER"
        df = self._session.range(num_workers).to_df(worker_col)
        df = df.select(
            schema_udtf(
                lit(path),
                lit(num_workers),
                lit(row_tag),
                col(worker_col),
                lit(sampling_ratio),
                lit(ignore_namespace),
                lit(attribute_prefix),
                lit(exclude_attributes),
                lit(value_tag),
                lit(null_value),
                lit(charset),
                lit(ignore_surrounding_whitespace),
            )
        )

        # Collect and merge schema results from all workers
        results = df.collect(_emit_ast=False)
        merged_schema: Optional[StructType] = None

        for row in results:
            schema_str = row[0]
            # Empty string marks a partition that saw no records.
            if schema_str is None or schema_str == "":
                continue
            try:
                partial_schema = type_string_to_type_object(schema_str)
            except Exception:
                # Best-effort merge: an unparseable partition schema is
                # dropped rather than failing the whole inference.
                continue
            if not isinstance(partial_schema, StructType):
                # Wrap a bare scalar schema in a one-field struct under the
                # value tag so it can be merged uniformly.
                partial_schema = StructType(
                    [StructField(value_tag, partial_schema, nullable=True)]
                )
            if merged_schema is None:
                merged_schema = partial_schema
            else:
                merged_schema = merge_struct_types(
                    merged_schema, partial_schema, value_tag
                )

        if merged_schema is None:
            return None

        # Canonicalize: NullType -> StringType, remove empty structs
        canonical = canonicalize_type(merged_schema)
        if canonical is None or not isinstance(canonical, StructType):
            return None

        def _strip_outer_quotes(name: str) -> str:
            """Strip one layer of double quotes added by _case_preserving_simple_string()"""
            if len(name) >= 2 and name[0] == '"' and name[-1] == '"':
                return name[1:-1]
            return name

        def _clean_schema_field_names(dt):
            """Strip outer double quotes from field names.

            _case_preserving_simple_string wraps colon-containing names in
            double quotes so the parser can split name:type correctly.
            type_string_to_type_object keeps those quotes as part of the name.
            Strip them here so downstream quote_name_without_upper_casing
            doesn't produce triple-quoted identifiers.
            """
            if isinstance(dt, StructType):
                return StructType(
                    [
                        StructField(
                            _strip_outer_quotes(f._name),
                            _clean_schema_field_names(f.datatype),
                            f.nullable,
                        )
                        for f in dt.fields
                    ]
                )
            if isinstance(dt, ArrayType):
                return ArrayType(
                    _clean_schema_field_names(dt.element_type), dt.contains_null
                )
            if isinstance(dt, MapType):
                return MapType(
                    _clean_schema_field_names(dt.key_type),
                    _clean_schema_field_names(dt.value_type),
                    dt.value_contains_null,
                )
            return dt

        inferred_schema = _clean_schema_field_names(canonical)
        return inferred_schema
skip_upload_on_content_match=True, - ) - python_file_path = f"{STAGE_PREFIX}{temp_stage}/{os.path.basename(XML_READER_FILE_PATH)}" - else: - python_file_path = XML_READER_FILE_PATH + python_file_path = self._resolve_xml_file_for_udtf(XML_READER_FILE_PATH) + if ( + context._is_snowpark_connect_compatible_mode + and not self._user_schema + and self._cur_options.get("INFER_SCHEMA", True) + ): + xml_inferred_schema = self._infer_schema_for_xml(path) + if xml_inferred_schema is not None: + self._xml_inferred_schema = xml_inferred_schema + schema = [ + Attribute( + quote_name_without_upper_casing(f._name), + f.datatype, + f.nullable, + ) + for f in xml_inferred_schema.fields + ] + use_user_schema = True # create udtf handler_name = "XMLReader" @@ -1508,7 +1710,10 @@ def _read_semi_structured_file(self, path: str, format: str) -> DataFrame: set_api_call_source(df, XML_READER_API_SIGNATURE) if self._cur_options.get("CACHERESULT", True): df = df.cache_result() - df._all_variant_cols = True + # When schema is inferred or user-provided, columns are typed + # (not all VARIANT), so don't enable the all-variant dot-notation mode. + if xml_inferred_schema is None and not self._user_schema: + df._all_variant_cols = True else: set_api_call_source(df, f"DataFrameReader.{format.lower()}") return df diff --git a/tests/integ/test_xml_infer_schema.py b/tests/integ/test_xml_infer_schema.py new file mode 100644 index 0000000000..1c6f9e0fc5 --- /dev/null +++ b/tests/integ/test_xml_infer_schema.py @@ -0,0 +1,1189 @@ +# +# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. 
+# + +import datetime +import json +import os +import pytest + +from snowflake.snowpark import Row +from snowflake.snowpark.functions import col +from snowflake.snowpark.types import ( + StructType, + StructField, + StringType, + LongType, + DoubleType, + DateType, + BooleanType, + TimestampType, + ArrayType, + VariantType, +) +import snowflake.snowpark.context as context +from tests.utils import TestFiles, Utils + + +pytestmark = [ + pytest.mark.skipif( + "config.getoption('local_testing_mode', default=False)", + reason="xml not supported in local testing mode", + ), + pytest.mark.udf, +] + +tmp_stage_name = Utils.random_stage_name() + +# Resource XML file names (uploaded from tests/resources/) +RES_BOOKS_XML = "books.xml" +RES_BOOKS2_XML = "books2.xml" +RES_DK_TRACE_XML = "dk_trace_sample.xml" +RES_DBLP_XML = "dblp_6kb.xml" +RES_BOOKS_ATTR_VAL_XML = "books_attribute_value.xml" + +# Inline XML strings uploaded to stage as files for testing +# Each covers specific inference scenarios without needing separate resource files. 
+ +# Primitive type inference: bool, double, long, string, timestamp +PRIMITIVES_XML = """\ + + + + true + +10.1 + -10 + 10 + 8E9D + 2015-01-01 00:00:00 + + +""" + +# Date and timestamp inference +DATE_TIME_XML = """\ + + + + John Smith + 2021-02-01 + 02-01-2021 + + +""" + +TIMESTAMP_XML = """\ + + + + John Smith + + not-a-timestamp + + +""" + +# Root-level _VALUE: attribute-only row element with text content +ROOT_VALUE_XML = """\ + + + value1 + value2 + value3 + +""" + +# Root-level _VALUE with child elements +ROOT_VALUE_MIXED_XML = """\ + + + value1 + value2 + 45 + 67 + + +""" + +# Nested object: struct child element +NESTED_OBJECT_XML = """\ + + + + Book A + 44.95 + + Acme + 2020 + + + + Book B + 29.99 + + Beta + 2021 + + + +""" + +# Nested array: repeated sibling elements +NESTED_ARRAY_XML = """\ + + + + Book A + fiction + classic + + + Book B + science + + +""" + +# Element with attribute on leaf: 5.95 → struct(_VALUE, _unit) +ATTR_ON_LEAF_XML = """\ + + + + Book A + 44.95 + + + Book B + 29.99 + + + Book C + 15.00 + + +""" + +# Missing nested struct: some rows have nested struct, some don't +MISSING_NESTED_XML = """\ + + + + Item A +
+ red +
+
+ + Item B + +
+""" + +# Unbalanced types: same field has different types across rows +UNBALANCED_TYPES_XML = """\ + + + + 123 + hello + + + 45.6 + world + + +""" + +# Mixed content: text + child elements +MIXED_CONTENT_XML = """\ + + + + Simple text + + 1 + + + + Has mixed content + + 2 + + + +""" + +# ExcludeAttributes with inferSchema +EXCLUDE_ATTRS_XML = """\ + + + + Widget + 9.99 + + + Gadget + 19.99 + + +""" + +# Big integer inference +BIG_INT_XML = """\ + + + + 42 + 92233720368547758070 + 3.14 + + +""" + +# Nested element same name as parent +PARENT_NAME_COLLISION_XML = """\ + + + + + Child 1.1 + + Child 1.2 + + + + Child 2.1 + + Child 2.2 + + +""" + +# Complicated nested: struct containing array of structs with attributes +COMPLICATED_NESTED_XML = """\ + + + + Author A + + 1 + Fiction + + + + 2020 + 1 + + + 2020 + 6 + + + + + Author B + + 2 + Science + + + + 2021 + 3 + + + + +""" + +# Sampling heterogeneous: 20 rows where the first 15 have value as integer, +# but rows 16-20 have value as a float string. With a low sampling ratio +# deterministic seed may only see the first chunk → LongType instead of DoubleType. 
def _upload_xml_string(session, stage, filename, xml_content):
    """Materialize ``xml_content`` as a local temp ``.xml`` file, upload it to
    ``stage``, and return the basename the content was staged under."""
    import tempfile

    name_prefix = filename.replace(".xml", "_")
    tmp = tempfile.NamedTemporaryFile(
        mode="w", suffix=".xml", delete=False, prefix=name_prefix
    )
    with tmp as f:
        f.write(xml_content)
    tmp_path = tmp.name
    try:
        Utils.upload_to_stage(session, stage, tmp_path, compress=False)
    finally:
        # Always remove the local scratch file, even when the upload fails.
        os.unlink(tmp_path)
    return os.path.basename(tmp_path)
+ Enable it for every test in this module.""" + original = context._is_snowpark_connect_compatible_mode + context._is_snowpark_connect_compatible_mode = True + yield + context._is_snowpark_connect_compatible_mode = original + + +@pytest.fixture(scope="module", autouse=True) +def setup(session, resources_path, local_testing_mode): + test_files = TestFiles(resources_path) + if not local_testing_mode: + Utils.create_stage(session, tmp_stage_name, is_temporary=True) + + # Upload resource XML files + Utils.upload_to_stage( + session, "@" + tmp_stage_name, test_files.test_xml_infer_types, compress=False + ) + Utils.upload_to_stage( + session, "@" + tmp_stage_name, test_files.test_xml_infer_mixed, compress=False + ) + Utils.upload_to_stage( + session, "@" + tmp_stage_name, test_files.test_books_xml, compress=False + ) + Utils.upload_to_stage( + session, "@" + tmp_stage_name, test_files.test_books2_xml, compress=False + ) + Utils.upload_to_stage( + session, + "@" + tmp_stage_name, + test_files.test_dk_trace_sample_xml, + compress=False, + ) + Utils.upload_to_stage( + session, "@" + tmp_stage_name, test_files.test_dblp_6kb_xml, compress=False + ) + Utils.upload_to_stage( + session, + "@" + tmp_stage_name, + test_files.test_books_attribute_value_xml, + compress=False, + ) + + # Upload inline XML strings as files + inline_xmls = { + "primitives": PRIMITIVES_XML, + "date_time": DATE_TIME_XML, + "timestamp": TIMESTAMP_XML, + "root_value": ROOT_VALUE_XML, + "root_value_mixed": ROOT_VALUE_MIXED_XML, + "nested_object": NESTED_OBJECT_XML, + "nested_array": NESTED_ARRAY_XML, + "attr_on_leaf": ATTR_ON_LEAF_XML, + "missing_nested": MISSING_NESTED_XML, + "unbalanced_types": UNBALANCED_TYPES_XML, + "mixed_content": MIXED_CONTENT_XML, + "exclude_attrs": EXCLUDE_ATTRS_XML, + "big_int": BIG_INT_XML, + "parent_collision": PARENT_NAME_COLLISION_XML, + "complicated_nested": COMPLICATED_NESTED_XML, + "processing_instr": PROCESSING_INSTRUCTION_XML, + "sampling_hetero": SAMPLING_HETERO_XML, + } 
+ for name, xml_str in inline_xmls.items(): + staged = _upload_xml_string( + session, "@" + tmp_stage_name, f"{name}.xml", xml_str + ) + _staged_files[name] = staged + + yield + if not local_testing_mode: + session.sql(f"DROP STAGE IF EXISTS {tmp_stage_name}").collect() + + +def _schema_types(df): + """Return {lowercase_field_name: datatype_class} from df.schema.""" + return {f.name.strip('"').lower(): type(f.datatype) for f in df.schema.fields} + + +# ─── Primitive type inference ─────────────────────────────────────────────── + + +def test_infer_primitives(session): + """Bool, double, long, string, timestamp inferred from inline XML.""" + df = session.read.option("rowTag", "ROW").xml( + f"@{tmp_stage_name}/{_staged_files['primitives']}" + ) + assert df._all_variant_cols is False + types = _schema_types(df) + assert types["bool1"] == BooleanType + assert types["double1"] == DoubleType + assert types["long1"] == LongType + assert types["long2"] == LongType + assert types["string1"] == StringType + assert types["ts1"] == TimestampType + result = df.collect() + assert len(result) == 1 + assert result[0]["bool1"] is True + assert result[0]["double1"] == 10.1 + assert result[0]["long1"] == -10 + assert result[0]["string1"] == "8E9D" + + +def test_infer_date(session): + """ISO date string inferred as DateType, non-ISO stays StringType.""" + df = session.read.option("rowTag", "book").xml( + f"@{tmp_stage_name}/{_staged_files['date_time']}" + ) + assert df._all_variant_cols is False + types = _schema_types(df) + assert types["date"] == DateType + assert types["date2"] == StringType # non-ISO format stays string + assert types["author"] == StringType + result = df.collect() + assert result[0]["date"] == datetime.date(2021, 2, 1) + + +def test_infer_timestamp(session): + """ISO timestamp inferred as TimestampType, non-timestamp stays StringType.""" + df = session.read.option("rowTag", "book").xml( + f"@{tmp_stage_name}/{_staged_files['timestamp']}" + ) + assert 
df._all_variant_cols is False + types = _schema_types(df) + assert types["time"] == TimestampType + assert types["time2"] == StringType + + +# ─── Root-level _VALUE tag ────────────────────────────────────────────────── + + +def test_infer_root_value_attrs_only(session): + """value → _VALUE + _attr columns.""" + df = session.read.option("rowTag", "ROW").xml( + f"@{tmp_stage_name}/{_staged_files['root_value']}" + ) + assert df._all_variant_cols is False + types = _schema_types(df) + assert "_value" in types or "_VALUE" in types.keys() + result = df.collect() + assert len(result) == 3 + + +def test_infer_root_value_with_child_elements(session): + """45 → _VALUE + tag columns.""" + df = session.read.option("rowTag", "ROW").xml( + f"@{tmp_stage_name}/{_staged_files['root_value_mixed']}" + ) + assert df._all_variant_cols is False + types = _schema_types(df) + assert "tag" in types + result = df.collect() + assert len(result) == 5 + + +# ─── Nested structures ───────────────────────────────────────────────────── + + +def test_infer_nested_object(session): + """Child element becomes nested struct: info.publisher, info.year.""" + df = session.read.option("rowTag", "book").xml( + f"@{tmp_stage_name}/{_staged_files['nested_object']}" + ) + assert df._all_variant_cols is False + types = _schema_types(df) + # info should be VariantType (Snowpark represents nested structs as Variant) + assert types["info"] == VariantType + assert types["price"] == DoubleType + result = df.collect() + assert len(result) == 2 + info = json.loads(result[0]["info"]) + assert info["publisher"] in ["Acme", "Beta"] + + +def test_infer_nested_array(session): + """Repeated sibling elements → ArrayType or VariantType.""" + df = session.read.option("rowTag", "book").xml( + f"@{tmp_stage_name}/{_staged_files['nested_array']}" + ) + assert df._all_variant_cols is False + types = _schema_types(df) + # Repeated tag elements: Variant wrapping an array + assert types["tag"] in (VariantType, ArrayType) + result 
= df.collect() + assert len(result) == 2 + # Book with 2 tags should be an array; _id may be Long or String + book1 = [r for r in result if str(r["_id"]).strip('"') == "1"][0] + tag_val = ( + json.loads(book1["tag"]) if isinstance(book1["tag"], str) else book1["tag"] + ) + assert isinstance(tag_val, list) + assert len(tag_val) == 2 + + +def test_infer_complicated_nested(session): + """Struct with sub-struct and array of structs with attributes.""" + df = session.read.option("rowTag", "book").xml( + f"@{tmp_stage_name}/{_staged_files['complicated_nested']}" + ) + assert df._all_variant_cols is False + types = _schema_types(df) + assert types["genre"] == VariantType + assert types["dates"] == VariantType + result = df.collect() + assert len(result) == 2 + # Verify genre nested content + genre1 = json.loads(result[0]["genre"]) + assert "genreid" in genre1 or "name" in genre1 + # Verify dates nested array + dates1 = json.loads(result[0]["dates"]) + assert "date" in dates1 + + +# ─── Attribute on leaf element (_VALUE pattern) ──────────────────────────── + + +def test_infer_attribute_on_leaf(session): + """44.95 → struct with _VALUE + _unit.""" + df = session.read.option("rowTag", "book").xml( + f"@{tmp_stage_name}/{_staged_files['attr_on_leaf']}" + ) + assert df._all_variant_cols is False + types = _schema_types(df) + # price has attr on some rows → becomes VariantType (struct with _VALUE/_unit) + assert types["price"] == VariantType + result = df.collect() + assert len(result) == 3 + # Book 1 has unit="$"; _id may be Long or String + book1 = [r for r in result if str(r["_id"]).strip('"') == "1"][0] + price_data = json.loads(book1["price"]) + assert "_VALUE" in price_data + assert price_data["_unit"] == "$" + + +# ─── Missing nested struct ───────────────────────────────────────────────── + + +def test_infer_missing_nested_struct(session): + """Row missing a nested struct field gets null/empty, not crash.""" + df = session.read.option("rowTag", "item").xml( + 
f"@{tmp_stage_name}/{_staged_files['missing_nested']}" + ) + assert df._all_variant_cols is False + types = _schema_types(df) + assert "details" in types + result = df.collect() + assert len(result) == 2 + # Item B has no details + item_b = [r for r in result if r["name"] == "Item B"][0] + details_val = item_b["details"] + # Either null or struct with null fields + if details_val is not None: + parsed = json.loads(details_val) + assert parsed.get("color") is None + + +# ─── Unbalanced types (type widening across rows) ────────────────────────── + + +def test_infer_unbalanced_types(session): + """Field1 is 123 in one row and 45.6 in another → widened to DoubleType.""" + df = session.read.option("rowTag", "ROW").xml( + f"@{tmp_stage_name}/{_staged_files['unbalanced_types']}" + ) + assert df._all_variant_cols is False + types = _schema_types(df) + assert types["field1"] == DoubleType # Long widened to Double + assert types["field2"] == StringType + result = df.collect() + assert len(result) == 2 + + +# ─── excludeAttributes with inferSchema ───────────────────────────────────── + + +def test_infer_exclude_attributes(session): + """excludeAttributes=true removes id/category from inferred schema.""" + df = ( + session.read.option("rowTag", "item") + .option("excludeAttributes", True) + .xml(f"@{tmp_stage_name}/{_staged_files['exclude_attrs']}") + ) + assert df._all_variant_cols is False + types = _schema_types(df) + assert "_id" not in types + assert "_category" not in types + assert "name" in types + assert "price" in types + result = df.collect() + assert len(result) == 2 + + # Without excludeAttributes, attributes should appear + df2 = session.read.option("rowTag", "item").xml( + f"@{tmp_stage_name}/{_staged_files['exclude_attrs']}" + ) + types2 = _schema_types(df2) + assert "_id" in types2 + assert "_category" in types2 + + +# ─── Big integer handling ─────────────────────────────────────────────────── + + +def test_infer_big_integer(session): + """Very large 
integer doesn't fit in Long → inferred as Double or String.""" + df = session.read.option("rowTag", "ROW").xml( + f"@{tmp_stage_name}/{_staged_files['big_int']}" + ) + assert df._all_variant_cols is False + types = _schema_types(df) + assert types["small_int"] == LongType + # Big int overflows Long → should become Double or String + assert types["big_int"] in (DoubleType, StringType) + assert types["normal_double"] == DoubleType + + +# ─── Nested element same name as parent ───────────────────────────────────── + + +def test_infer_parent_name_collision(session): + """...... + Known limitation: when rowTag matches a child element name, byte-scanning + picks up inner tags as row boundaries, causing schema inference to fall back + to all-variant or produce extra rows. Verify it doesn't crash.""" + df = session.read.option("rowTag", "parent").xml( + f"@{tmp_stage_name}/{_staged_files['parent_collision']}" + ) + result = df.collect() + # May produce 2 or 4 rows depending on inner matching + assert len(result) >= 2 + + +# ─── Processing instruction ──────────────────────────────────────────────── + + +def test_infer_with_processing_instruction(session): + """XML with should parse without error.""" + df = session.read.option("rowTag", "foo").xml( + f"@{tmp_stage_name}/{_staged_files['processing_instr']}" + ) + assert df._all_variant_cols is False + result = df.collect() + assert len(result) == 1 + + +# ─── Mixed content (text + child elements) ────────────────────────────────── + + +def test_infer_mixed_content(session): + """Text mixed with child elements: desc field has both text and .""" + df = session.read.option("rowTag", "item").xml( + f"@{tmp_stage_name}/{_staged_files['mixed_content']}" + ) + assert df._all_variant_cols is False + result = df.collect() + assert len(result) == 2 + + +# ─── inferSchema=false keeps all strings ──────────────────────────────────── + + +@pytest.mark.parametrize( + "staged_key, row_tag, expected_count", + [ + ("primitives", "ROW", 1), + 
("nested_object", "book", 2), + ("unbalanced_types", "ROW", 2), + ], +) +def test_infer_schema_false_all_strings(session, staged_key, row_tag, expected_count): + """inferSchema=false → all fields are variant/string columns.""" + df = ( + session.read.option("rowTag", row_tag) + .option("inferSchema", False) + .xml(f"@{tmp_stage_name}/{_staged_files[staged_key]}") + ) + assert df._all_variant_cols is True + result = df.collect() + assert len(result) == expected_count + + +# ─── Resource file: xml_infer_types.xml ───────────────────────────────────── + + +def test_infer_types_resource_file(session): + """Comprehensive resource file: primitives, nested, array, attributes, mixed.""" + df = session.read.option("rowTag", "item").xml( + f"@{tmp_stage_name}/xml_infer_types.xml" + ) + assert df._all_variant_cols is False + types = _schema_types(df) + + # Flat fields + assert types["_id"] in (LongType, StringType) + assert types["name"] == StringType + assert types["quantity"] == LongType + assert types["in_stock"] == BooleanType + assert types["release_date"] == DateType + assert types["last_updated"] == TimestampType + + # price has attribute on some rows → complex type + assert types["price"] == VariantType + + # tags → nested with array → Variant + assert types["tags"] == VariantType + + # specs → nested struct → Variant + assert types["specs"] == VariantType + + # rating has attribute on some rows → complex + assert types["rating"] == VariantType + + result = df.collect() + assert len(result) == 3 + + # Verify nested data + item1 = [r for r in result if str(r["_id"]) in ("1", '"1"')][0] + specs = json.loads(item1["specs"]) + assert "weight" in specs or "dimensions" in specs + tags = json.loads(item1["tags"]) + assert "tag" in tags + + +def test_infer_mixed_resource_file(session): + """Resource file with name collisions, sparse fields, mixed content.""" + df = session.read.option("rowTag", "record").xml( + f"@{tmp_stage_name}/xml_infer_mixed.xml" + ) + assert 
df._all_variant_cols is False + types = _schema_types(df) + assert "_type" in types + assert "parent" in types + assert "age" in types + assert "value" in types + result = df.collect() + assert len(result) == 3 + + # Verify opt_struct is null for sparse record + sparse = [r for r in result if r["_type"] == "sparse"][0] + try: + opt_val = sparse["opt_struct"] + except (KeyError, IndexError): + opt_val = None + if opt_val is not None: + parsed = json.loads(opt_val) + assert parsed["a"] is None + + +def test_read_xml_infer_schema_books_flat(session): + """Infer schema on books.xml: all flat primitives, 12 rows, 7 columns.""" + expected_schema = StructType( + [ + StructField("_id", StringType()), + StructField("author", StringType()), + StructField("description", StringType()), + StructField("genre", StringType()), + StructField("price", DoubleType()), + StructField("publish_date", DateType()), + StructField("title", StringType()), + ] + ) + df = session.read.option("rowTag", "book").xml(f"@{tmp_stage_name}/{RES_BOOKS_XML}") + assert df._all_variant_cols is False + actual = {f.name.strip('"').lower(): type(f.datatype) for f in df.schema.fields} + expected = {f.name.lower(): type(f.datatype) for f in expected_schema.fields} + assert actual == expected + result = df.collect() + assert len(result) == 12 + assert len(result[0]) == 7 + Utils.check_answer( + df.filter(col('"_id"') == "bk101").select( + col('"price"'), col('"publish_date"'), col('"author"') + ), + [Row(44.95, datetime.date(2000, 10, 1), "Gambardella, Matthew")], + ) + + +def test_read_xml_infer_schema_books2_nested(session): + """Infer schema on books2.xml: complex types become VariantType, verify nested data.""" + expected_schema = StructType( + [ + StructField("_id", LongType()), + StructField("author", StringType()), + StructField("editions", VariantType()), + StructField("price", DoubleType()), + StructField("reviews", VariantType()), + StructField("title", StringType()), + ] + ) + df = 
session.read.option("rowTag", "book").xml( + f"@{tmp_stage_name}/{RES_BOOKS2_XML}" + ) + assert df._all_variant_cols is False + actual = {f.name.strip('"').lower(): type(f.datatype) for f in df.schema.fields} + expected = {f.name.lower(): type(f.datatype) for f in expected_schema.fields} + assert actual == expected + result = df.collect() + assert len(result) == 2 + assert len(result[0]) == 6 + + book1 = df.filter(col('"_id"') == 1).collect() + assert len(book1) == 1 + reviews = json.loads(book1[0]["reviews"]) + assert isinstance(reviews["review"], list) + assert len(reviews["review"]) == 2 + assert reviews["review"][0]["user"] == "tech_guru_87" + editions = json.loads(book1[0]["editions"]) + assert isinstance(editions["edition"], list) + assert len(editions["edition"]) == 2 + + book2 = df.filter(col('"_id"') == 2).collect() + assert len(book2) == 1 + review_data = json.loads(book2[0]["reviews"])["review"] + if isinstance(review_data, dict): + assert review_data["user"] == "xml_master" + else: + assert review_data[0]["user"] == "xml_master" + assert book1[0]["price"] == 29.99 + assert book2[0]["price"] == 35.50 + + +def test_read_xml_namespace_infer_schema(session): + """Infer schema on dk_trace_sample.xml with namespace-prefixed eqTrace:event rowTag.""" + expected_schema = StructType( + [ + StructField("eqTrace:date-time", TimestampType()), + StructField("eqTrace:equipment-cycle-status-changed", VariantType()), + StructField("eqTrace:event-descriptor-list", VariantType()), + StructField("eqTrace:event-id", StringType()), + StructField("eqTrace:event-name", StringType()), + StructField("eqTrace:event-version-number", DoubleType()), + StructField("eqTrace:interchanged", VariantType()), + StructField("eqTrace:is-planned", BooleanType()), + StructField("eqTrace:is-synthetic", BooleanType()), + StructField("eqTrace:location-id", VariantType()), + StructField("eqTrace:placement-at-industry-planned", VariantType()), + StructField("eqTrace:publisher-identification", 
StringType()), + StructField("eqTrace:reporting-detail", VariantType()), + StructField("eqTrace:tag-name-list", VariantType()), + StructField("eqTrace:waybill-applied", VariantType()), + ] + ) + df = ( + session.read.option("rowTag", "eqTrace:event") + .option("ignoreNamespace", False) + .xml(f"@{tmp_stage_name}/{RES_DK_TRACE_XML}") + ) + assert df._all_variant_cols is False + actual = {f.name.strip('"'): type(f.datatype) for f in df.schema.fields} + expected = {f.name.strip('"'): type(f.datatype) for f in expected_schema.fields} + assert actual == expected + result = df.collect() + assert len(result) == 5 + + first_event = df.filter( + col('"eqTrace:event-id"') == "f0e765d9-599b-46bf-9aef-bd33e0c2183f" + ).collect() + assert len(first_event) == 1 + Utils.check_answer( + df.filter( + col('"eqTrace:event-id"') == "f0e765d9-599b-46bf-9aef-bd33e0c2183f" + ).select(col('"eqTrace:event-name"'), col('"eqTrace:is-planned"')), + [Row("equipment/equipment-placement-at-industry-planned", True)], + ) + + second_event = df.filter( + col('"eqTrace:event-id"') == "dd9a4616-e41c-4571-9bda-9a5506a2b78d" + ).collect() + assert len(second_event) == 1 + assert second_event[0]["eqTrace:is-planned"] is False + + +def test_read_xml_dblp_mastersthesis_infer_schema(session): + """Infer schema on dblp_6kb.xml mastersthesis: verify schema, nested ee data, printSchema.""" + expected_schema = StructType( + [ + StructField("_key", StringType()), + StructField("_mdate", DateType()), + StructField("author", StringType()), + StructField("ee", VariantType()), + StructField("note", StringType()), + StructField("school", StringType()), + StructField("title", StringType()), + StructField("year", LongType()), + ] + ) + df = session.read.option("rowTag", "mastersthesis").xml( + f"@{tmp_stage_name}/{RES_DBLP_XML}" + ) + assert df._all_variant_cols is False + actual = {f.name.strip('"').lower(): type(f.datatype) for f in df.schema.fields} + expected = {f.name.lower(): type(f.datatype) for f in 
expected_schema.fields} + assert actual == expected + result = df.collect() + assert len(result) == 6 + + hoffmann = df.filter(col('"_key"') == "ms/Hoffmann2008").collect() + assert len(hoffmann) == 1 + ee = json.loads(hoffmann[0]["ee"]) + assert ee["_type"] == "oa" + assert "dblp.uni-trier.de" in ee["_VALUE"] + + vollmer = df.filter(col('"_key"') == "ms/Vollmer2006").collect() + assert len(vollmer) == 1 + ee_v = json.loads(vollmer[0]["ee"]) + assert ee_v["_type"] is None + assert ee_v["_VALUE"] is not None + + brown = df.filter(col('"_key"') == "ms/Brown92").collect() + assert len(brown) == 1 + assert json.loads(brown[0]["ee"]) == {"_VALUE": None, "_type": None} + + schema_str = df._format_schema() + assert "LongType" in schema_str + assert "StringType" in schema_str + assert "VariantType" in schema_str + + +def test_read_xml_dblp_incollection_infer_schema(session): + """Infer schema on dblp_6kb.xml incollection: author array, ee struct, verify data.""" + expected_schema = StructType( + [ + StructField("_corrupt_record", StringType()), + StructField("_key", StringType()), + StructField("_mdate", DateType()), + StructField("author", VariantType()), + StructField("booktitle", StringType()), + StructField("crossref", StringType()), + StructField("ee", VariantType()), + StructField("pages", StringType()), + StructField("title", StringType()), + StructField("url", StringType()), + StructField("year", LongType()), + ] + ) + df = session.read.option("rowTag", "incollection").xml( + f"@{tmp_stage_name}/{RES_DBLP_XML}" + ) + assert df._all_variant_cols is False + actual = {f.name.strip('"').lower(): type(f.datatype) for f in df.schema.fields} + expected = {f.name.lower(): type(f.datatype) for f in expected_schema.fields} + assert actual == expected + result = df.collect() + assert len(result) == 6 + + parker = df.filter(col('"_key"') == "series/ifip/ParkerD14").collect() + assert len(parker) == 1 + authors = json.loads(parker[0]["author"]) + assert isinstance(authors, 
list) + assert len(authors) == 2 + assert authors[0]["_VALUE"] == "Kevin R. Parker" + assert authors[0]["_orcid"] == "0000-0003-0549-3687" + ee = json.loads(parker[0]["ee"]) + assert ee["_type"] == "oa" + + lecomber = df.filter(col('"_key"') == "series/ifip/Lecomber14").collect() + assert len(lecomber) == 1 + author_val = json.loads(lecomber[0]["author"]) + if isinstance(author_val, dict): + assert author_val["_VALUE"] == "Angela Lecomber" + assert author_val["_orcid"] is None + elif isinstance(author_val, list): + assert len(author_val) == 1 + assert author_val[0]["_VALUE"] == "Angela Lecomber" + + +def test_read_xml_attribute_value_infer_schema(session): + """Infer schema on books_attribute_value.xml: publisher mixed struct + plain text.""" + expected_schema = StructType( + [ + StructField("_id", LongType()), + StructField("author", StringType()), + StructField("pages", LongType()), + StructField("price", DoubleType()), + StructField("publisher", VariantType()), + StructField("title", StringType()), + ] + ) + df = session.read.option("rowTag", "book").xml( + f"@{tmp_stage_name}/{RES_BOOKS_ATTR_VAL_XML}" + ) + assert df._all_variant_cols is False + actual = {f.name.strip('"').lower(): type(f.datatype) for f in df.schema.fields} + expected = {f.name.lower(): type(f.datatype) for f in expected_schema.fields} + assert actual == expected + result = df.collect() + assert len(result) == 5 + assert len(result[0]) == 6 + + book1 = df.filter(col('"_id"') == 1).collect() + assert len(book1) == 1 + pub1 = json.loads(book1[0]["publisher"]) + assert pub1 == {"_VALUE": "O'Reilly Media", "_country": "USA", "_language": None} + + book3 = df.filter(col('"_id"') == 3).collect() + assert len(book3) == 1 + pub3 = json.loads(book3[0]["publisher"]) + assert pub3 == {"_VALUE": "Springer", "_country": "Canada", "_language": "English"} + + book4 = df.filter(col('"_id"') == 4).collect() + assert len(book4) == 1 + pub4 = json.loads(book4[0]["publisher"]) + assert pub4 == {"_VALUE": "Some 
Publisher", "_country": None, "_language": None} + + book5 = df.filter(col('"_id"') == 5).collect() + assert len(book5) == 1 + pub5 = json.loads(book5[0]["publisher"]) + assert pub5 == {"_VALUE": None, "_country": None, "_language": None} + + +# ─── inferSchema vs inferSchema=false comparison ──────────────────────────── + + +@pytest.mark.parametrize( + "staged_file, row_tag", + [ + ("xml_infer_types.xml", "item"), + (RES_BOOKS_XML, "book"), + (RES_BOOKS2_XML, "book"), + ], +) +def test_infer_vs_no_infer_column_count(session, staged_file, row_tag): + """inferSchema produces typed columns; no inferSchema produces variant columns.""" + df_infer = session.read.option("rowTag", row_tag).xml( + f"@{tmp_stage_name}/{staged_file}" + ) + df_no_infer = ( + session.read.option("rowTag", row_tag) + .option("inferSchema", False) + .xml(f"@{tmp_stage_name}/{staged_file}") + ) + assert df_infer._all_variant_cols is False + assert df_no_infer._all_variant_cols is True + assert df_infer.count() == df_no_infer.count() + + +# ─── samplingRatio tests ───────────────────────────────────────────────────── + + +def test_sampling_ratio_schema_books_flat(session): + """samplingRatio=0.5 on homogeneous books.xml: correct schema and deterministic across runs.""" + path = f"@{tmp_stage_name}/{RES_BOOKS_XML}" + schemas = [] + for _ in range(3): + df = ( + session.read.option("rowTag", "book").option("samplingRatio", 0.5).xml(path) + ) + schemas.append(df.schema) + + assert schemas[0] == schemas[1] == schemas[2] + assert df._all_variant_cols is False + assert _schema_types(df) == { + "_id": StringType, + "author": StringType, + "description": StringType, + "genre": StringType, + "price": DoubleType, + "publish_date": DateType, + "title": StringType, + } + assert df.count() == 12 + + +@pytest.mark.parametrize("invalid_ratio", [0, -0.5]) +def test_sampling_ratio_invalid(session, invalid_ratio): + with pytest.raises(ValueError, match="should be greater than 0"): + session.read.option("rowTag", 
"ROW").option("samplingRatio", invalid_ratio).xml( + f"@{tmp_stage_name}/{_staged_files['primitives']}" + ) + + +def test_sampling_ratio_hetero_may_narrow_schema(session): + """Low samplingRatio on heterogeneous data may infer a narrower schema.""" + path = f"@{tmp_stage_name}/{_staged_files['sampling_hetero']}" + + df_full = session.read.option("rowTag", "ROW").xml(path) + assert _schema_types(df_full)["value"] == DoubleType + + df_sampled = ( + session.read.option("rowTag", "ROW").option("samplingRatio", 0.3).xml(path) + ) + types_sampled = _schema_types(df_sampled) + assert types_sampled["value"] in (LongType, DoubleType) + assert types_sampled["name"] == StringType + assert df_sampled.count() == 20 + + +def test_sampling_ratio_nested_schema_preserved(session): + """samplingRatio < 1.0 on nested data still infers nested structure.""" + df = ( + session.read.option("rowTag", "book") + .option("samplingRatio", 0.5) + .xml(f"@{tmp_stage_name}/{RES_BOOKS2_XML}") + ) + assert df._all_variant_cols is False + assert _schema_types(df) == { + "_id": LongType, + "author": StringType, + "editions": VariantType, + "price": DoubleType, + "reviews": VariantType, + "title": StringType, + } + assert df.count() == 2 + book1 = df.filter(col('"_id"') == 1).collect()[0] + reviews = json.loads(book1["reviews"]) + assert isinstance(reviews["review"], list) + assert len(reviews["review"]) == 2 + + +def test_infer_schema_non_existing_file(session): + with pytest.raises(ValueError, match="does not exist"): + session.read.option("rowTag", "row").xml( + f"@{tmp_stage_name}/non_existing_file.xml" + ) + + +def test_infer_schema_use_leaf_row_tag(session): + xml_content = "helloworld" + actual_filename = _upload_xml_string( + session, tmp_stage_name, "leaf_only_infer.xml", xml_content + ) + df = session.read.option("rowTag", "item").xml( + f"@{tmp_stage_name}/{actual_filename}" + ) + assert df.count() == 2 + assert "_VALUE" in [f.name for f in df.schema.fields] or len(df.schema.fields) == 1 
diff --git a/tests/integ/test_xml_reader_row_tag.py b/tests/integ/test_xml_reader_row_tag.py index b63123ae10..4eda2408c4 100644 --- a/tests/integ/test_xml_reader_row_tag.py +++ b/tests/integ/test_xml_reader_row_tag.py @@ -4,6 +4,7 @@ import datetime import logging import json +import os import pytest from snowflake.snowpark import Row @@ -16,10 +17,14 @@ StructType, StructField, StringType, + LongType, DoubleType, DateType, + BooleanType, + TimestampType, ArrayType, ) +import snowflake.snowpark.context as context from tests.utils import TestFiles, Utils @@ -32,6 +37,15 @@ ] +@pytest.fixture() +def enable_scos_compatible_mode(): + """Enable SCOS compatible mode so that type validation runs inside the UDTF.""" + original = context._is_snowpark_connect_compatible_mode + context._is_snowpark_connect_compatible_mode = True + yield + context._is_snowpark_connect_compatible_mode = original + + # XML test file constants test_file_books_xml = "books.xml" test_file_books2_xml = "books2.xml" @@ -46,10 +60,54 @@ test_file_xml_undeclared_namespace = "undeclared_namespace.xml" test_file_null_value_xml = "null_value.xml" test_file_books_xsd = "books.xsd" +test_file_dk_trace_xml = "dk_trace_sample.xml" +test_file_dblp_xml = "dblp_6kb.xml" +test_file_books_attr_val_xml = "books_attribute_value.xml" # Global stage name for uploading test files tmp_stage_name = Utils.random_stage_name() +# Inline XML strings for permissive/failfast/dropmalformed mode tests +SAMPLING_MISMATCH_XML = """\ + + + Alice100 + Bob200 + Carol300 + Dave400 + Eve500 + Frankhello + +""" + +MULTIFIELD_MISMATCH_XML = """\ + + + 42true3.14 + not_a_nummaybe2.72 + 99falsenot_a_dbl + +""" + + +def _upload_xml_string(session, stage, filename, xml_content): + """Write XML string to a temp file and upload to stage.""" + import tempfile + + with tempfile.NamedTemporaryFile( + mode="w", suffix=".xml", delete=False, prefix=filename.replace(".xml", "_") + ) as f: + f.write(xml_content) + tmp_path = f.name + try: + 
Utils.upload_to_stage(session, stage, tmp_path, compress=False) + finally: + os.unlink(tmp_path) + return os.path.basename(tmp_path) + + +_staged_files = {} + @pytest.fixture(scope="module", autouse=True) def setup(session, resources_path, local_testing_mode): @@ -124,6 +182,34 @@ def setup(session, resources_path, local_testing_mode): test_files.test_books_xsd, compress=False, ) + Utils.upload_to_stage( + session, + "@" + tmp_stage_name, + test_files.test_dk_trace_sample_xml, + compress=False, + ) + Utils.upload_to_stage( + session, + "@" + tmp_stage_name, + test_files.test_dblp_6kb_xml, + compress=False, + ) + Utils.upload_to_stage( + session, + "@" + tmp_stage_name, + test_files.test_books_attribute_value_xml, + compress=False, + ) + + # Upload inline XML strings for mode tests + for name, xml_str in { + "sampling_mismatch": SAMPLING_MISMATCH_XML, + "multifield_mismatch": MULTIFIELD_MISMATCH_XML, + }.items(): + staged = _upload_xml_string( + session, "@" + tmp_stage_name, f"{name}.xml", xml_str + ) + _staged_files[name] = staged yield # Clean up resources @@ -138,6 +224,8 @@ def setup(session, resources_path, local_testing_mode): [test_file_books2_xml, "book", 2, 6], [test_file_house_xml, "House", 37, 22], [test_file_house_large_xml, "House", 740, 22], + [test_file_dblp_xml, "mastersthesis", 6, 8], + [test_file_books_attr_val_xml, "book", 5, 6], ], ) def test_read_xml_row_tag( @@ -769,3 +857,292 @@ def test_value_tag_custom_schema(session): .xml(f"@{tmp_stage_name}/{test_file_null_value_xml}") ) Utils.check_answer(df, [Row(num="1", str1="NULL", str2=None, str3="xxx")]) + + +def test_read_xml_namespace_user_schema(session): + """User-provided schema with namespace-prefixed fields (ignoreNamespace=false).""" + user_schema = StructType( + [ + StructField("eqTrace:event-id", StringType(), True), + StructField("eqTrace:event-name", StringType(), True), + StructField("eqTrace:event-version-number", DoubleType(), True), + StructField("eqTrace:is-planned", 
BooleanType(), True), + StructField("eqTrace:is-synthetic", BooleanType(), True), + StructField("eqTrace:date-time", TimestampType(), True), + ] + ) + df = ( + session.read.option("rowTag", "eqTrace:event") + .option("ignoreNamespace", False) + .schema(user_schema) + .xml(f"@{tmp_stage_name}/{test_file_dk_trace_xml}") + ) + result = df.collect() + assert len(result) == 5 + assert len(result[0]) == 6 + col_names = [f.name.strip('"') for f in df.schema.fields] + assert "eqTrace:event-id" in col_names + assert "eqTrace:event-name" in col_names + event_ids = {r[0] for r in result} + assert "f0e765d9-599b-46bf-9aef-bd33e0c2183f" in event_ids + assert "dd9a4616-e41c-4571-9bda-9a5506a2b78d" in event_ids + + +def test_read_xml_dblp_user_schema(session): + """User-provided schema on dblp_6kb.xml mastersthesis with nested ee StructType.""" + user_schema = StructType( + [ + StructField("_mdate", DateType(), True), + StructField("_key", StringType(), True), + StructField("author", StringType(), True), + StructField("title", StringType(), True), + StructField("year", LongType(), True), + StructField("school", StringType(), True), + StructField( + "ee", + StructType( + [ + StructField("_VALUE", StringType(), True), + StructField("_type", StringType(), True), + ] + ), + True, + ), + ] + ) + df = ( + session.read.option("rowTag", "mastersthesis") + .schema(user_schema) + .xml(f"@{tmp_stage_name}/{test_file_dblp_xml}") + ) + result = df.collect() + assert len(result) == 6 + assert len(result[0]) == 7 + brown = df.filter(col('"_key"') == "ms/Brown92").collect() + assert len(brown) == 1 + assert json.loads(brown[0]["ee"]) == {"_VALUE": None, "_type": None} + assert brown[0]["year"] == 1992 + + +def test_read_xml_dblp_incollection_user_schema(session): + """User-provided ArrayType(StructType) for author, StructType for ee on dblp incollection.""" + author_element_schema = StructType( + [ + StructField("_VALUE", StringType(), True), + StructField("_orcid", StringType(), True), + ] + ) 
+ ee_schema = StructType( + [ + StructField("_VALUE", StringType(), True), + StructField("_type", StringType(), True), + ] + ) + user_schema = StructType( + [ + StructField("_key", StringType(), True), + StructField("_mdate", DateType(), True), + StructField("author", ArrayType(author_element_schema, True), True), + StructField("booktitle", StringType(), True), + StructField("crossref", StringType(), True), + StructField("ee", ee_schema, True), + StructField("pages", StringType(), True), + StructField("title", StringType(), True), + StructField("url", StringType(), True), + StructField("year", LongType(), True), + ] + ) + df = ( + session.read.option("rowTag", "incollection") + .schema(user_schema) + .xml(f"@{tmp_stage_name}/{test_file_dblp_xml}") + ) + result = df.collect() + assert len(result) == 6 + assert len(result[0]) == 11 + + parker = df.filter(col('"_key"') == "series/ifip/ParkerD14").collect() + assert len(parker) == 1 + authors = json.loads(parker[0]["author"]) + assert len(authors) == 2 + assert authors[0]["_VALUE"] == "Kevin R. 
Parker" + assert authors[0]["_orcid"] == "0000-0003-0549-3687" + assert authors[1]["_VALUE"] == "Bill Davey" + assert authors[1]["_orcid"] is None + ee = json.loads(parker[0]["ee"]) + assert ee["_VALUE"] == "https://doi.org/10.1007/978-3-642-55119-2_14" + assert ee["_type"] == "oa" + + scheid = df.filter(col('"_key"') == "series/ifip/ScheidRKFRS21").collect() + assert len(scheid) == 1 + scheid_authors = json.loads(scheid[0]["author"]) + assert len(scheid_authors) == 6 + assert all(a["_orcid"] is not None for a in scheid_authors) + + rheingans = df.filter(col('"_key"') == "series/ifip/RheingansL95").collect() + assert len(rheingans) == 1 + ee_no_type = json.loads(rheingans[0]["ee"]) + assert ee_no_type["_type"] is None + assert "doi.org" in ee_no_type["_VALUE"] + assert parker[0]["year"] == 2014 + + +def test_read_xml_attribute_value_user_schema_struct_publisher(session): + """User StructType schema for publisher on books_attribute_value.xml.""" + publisher_schema = StructType( + [ + StructField("_VALUE", StringType(), True), + StructField("_country", StringType(), True), + StructField("_language", StringType(), True), + ] + ) + user_schema = StructType( + [ + StructField("_id", LongType(), True), + StructField("title", StringType(), True), + StructField("author", StringType(), True), + StructField("price", DoubleType(), True), + StructField("publisher", publisher_schema, True), + ] + ) + df = ( + session.read.option("rowTag", "book") + .schema(user_schema) + .xml(f"@{tmp_stage_name}/{test_file_books_attr_val_xml}") + ) + result = df.collect() + assert len(result) == 5 + assert len(result[0]) == 5 + + book1 = df.filter(col('"_id"') == 1).collect() + assert len(book1) == 1 + pub1 = json.loads(book1[0]["publisher"]) + assert pub1 == {"_VALUE": "O'Reilly Media", "_country": "USA", "_language": None} + + book3 = df.filter(col('"_id"') == 3).collect() + assert len(book3) == 1 + pub3 = json.loads(book3[0]["publisher"]) + assert pub3 == {"_VALUE": "Springer", "_country": 
"Canada", "_language": "English"} + + book4 = df.filter(col('"_id"') == 4).collect() + assert len(book4) == 1 + pub4 = json.loads(book4[0]["publisher"]) + assert pub4 == {"_VALUE": "Some Publisher", "_country": None, "_language": None} + assert book1[0]["price"] == 29.99 + assert book1[0]["_id"] == 1 + + +def test_permissive_type_mismatch_user_schema(session, enable_scos_compatible_mode): + schema = StructType( + [ + StructField("name", StringType()), + StructField("value", LongType()), + ] + ) + df = ( + session.read.option("rowTag", "ROW") + .schema(schema) + .xml(f"@{tmp_stage_name}/{_staged_files['sampling_mismatch']}") + ) + result = df.order_by('"name"').collect() + assert len(result) == 6 + + alice = [r for r in result if r["name"] == "Alice"][0] + assert alice["value"] == 100 + assert alice["_corrupt_record"] is None + + frank = [r for r in result if r["name"] == "Frank"][0] + assert frank["value"] is None + assert frank["_corrupt_record"] is not None + assert "hello" in frank["_corrupt_record"] + + col_names = [c.strip('"') for c in df.columns] + assert "_corrupt_record" in col_names + + +def test_permissive_multifield_per_field_granularity( + session, enable_scos_compatible_mode +): + schema = StructType( + [ + StructField("int_col", LongType()), + StructField("bool_col", BooleanType()), + StructField("dbl_col", DoubleType()), + ] + ) + df = ( + session.read.option("rowTag", "ROW") + .schema(schema) + .xml(f"@{tmp_stage_name}/{_staged_files['multifield_mismatch']}") + ) + result = df.order_by('"int_col"').collect() + assert len(result) == 3 + + assert result[0]["int_col"] is None + assert result[0]["bool_col"] is None + assert abs(result[0]["dbl_col"] - 2.72) < 0.001 + + assert result[1]["int_col"] == 42 + assert result[1]["bool_col"] is True + assert abs(result[1]["dbl_col"] - 3.14) < 0.001 + + assert result[2]["int_col"] == 99 + assert result[2]["bool_col"] is False + assert result[2]["dbl_col"] is None + + +def 
test_failfast_type_mismatch_raises(session, enable_scos_compatible_mode): + narrow_schema = StructType( + [ + StructField("name", StringType()), + StructField("value", LongType()), + ] + ) + with pytest.raises(SnowparkSQLException): + session.read.option("rowTag", "ROW").option("mode", "FAILFAST").schema( + narrow_schema + ).xml(f"@{tmp_stage_name}/{_staged_files['sampling_mismatch']}") + + +def test_dropmalformed_type_mismatch_drops_rows(session, enable_scos_compatible_mode): + narrow_schema = StructType( + [ + StructField("name", StringType()), + StructField("value", LongType()), + ] + ) + df = ( + session.read.option("rowTag", "ROW") + .option("mode", "DROPMALFORMED") + .schema(narrow_schema) + .xml(f"@{tmp_stage_name}/{_staged_files['sampling_mismatch']}") + ) + result = df.order_by('"name"').collect() + assert len(result) == 5 + names = [r["name"] for r in result] + assert "Frank" not in names + for r in result: + assert r["value"] is not None + + +def test_dropmalformed_multifield_drops_any_bad_field( + session, enable_scos_compatible_mode +): + schema = StructType( + [ + StructField("int_col", LongType()), + StructField("bool_col", BooleanType()), + StructField("dbl_col", DoubleType()), + ] + ) + df = ( + session.read.option("rowTag", "ROW") + .option("mode", "DROPMALFORMED") + .schema(schema) + .xml(f"@{tmp_stage_name}/{_staged_files['multifield_mismatch']}") + ) + result = df.collect() + assert len(result) == 1 + assert result[0]["int_col"] == 42 + assert result[0]["bool_col"] is True + assert abs(result[0]["dbl_col"] - 3.14) < 0.001 diff --git a/tests/resources/books_attribute_value.xml b/tests/resources/books_attribute_value.xml new file mode 100644 index 0000000000..107be4bedc --- /dev/null +++ b/tests/resources/books_attribute_value.xml @@ -0,0 +1,33 @@ + + + The Art of Snowflake + Jane Doe + 29.99 + O'Reilly Media + + + XML for Data Engineers + John Smith + 35.50 + Springer + + + Book 3 + John Smith + 35.50 + Springer + + + Book 4 + John Smith + 35.50 
+ Some Publisher + + + Book 5 + author5 + 35 + + 5 + + diff --git a/tests/resources/dblp_6kb.xml b/tests/resources/dblp_6kb.xml new file mode 100644 index 0000000000..2fe62000eb --- /dev/null +++ b/tests/resources/dblp_6kb.xml @@ -0,0 +1,126 @@ + + + + +Oliver Hoffmann 0002 +Regelbasierte Extraktion und asymmetrische Fusion bibliographischer Informationen. +2009 +University of Trier +Diplomarbeit, Universität Trier, FB IV, DBIS/DBLP +http://dblp.uni-trier.de/papers/DiplomarbeitOliverHoffmann.pdf + + +Rita Ley +Der Einfluss kleiner naturnaher Retentionsmaßnahmen in der Fläche auf den Hochwasserabfluss - Kleinrückhaltebecken -. +2006 +Diplomarbeit, Universität Trier, FB VI, Physische Geographie +http://dblp.uni-trier.de/papers/DiplomarbeitRitaLey.pdf + + +Kurt P. Brown +PRPL: A Database Workload Specification Language, v1.3. +1992 +University of Wisconsin-Madison + + + + +Stephan Vollmer +Portierung des DBLP-Systems auf ein relationales Datenbanksystem und Evaluation der Performance. +2006 +Diplomarbeit, Universität Trier, FB IV, Informatik +http://dbis.uni-trier.de/Diplomanden/Vollmer/vollmer.shtml + + + + +Vanessa C. Klaas +Who's Who in the World Wide Web: Approaches to Name Disambiguation +2007 +Diplomarbeit, LMU München, Informatik +http://www.pms.ifi.lmu.de/publikationen/diplomarbeiten/Vanessa.Klaas/thesis.pdf + + +Tolga Yurek +Efficient View Maintenance at Data Warehouses. +1997 +University of California at Santa Barbara, Department of Computer Science, CA, USA + + + + +Ulrich Briefs +John Kjaer +Jean-Louis Rigal +Computerization and Work, A Reader on Social Aspects of Computerization +Computerization and Work +Springer +1985 +978-3-540-15367-2 +978-3-642-70453-6 +https://doi.org/10.1007/978-3-642-70453-6 +IFIP State-of-the-Art Reports +db/series/ifip/computerization1985.html + +Kevin R. Parker +Bill Davey +Computers in Schools in the USA: A Social History. 
+203-211 +2014 +Reflections on the History of Computers in Education +https://doi.org/10.1007/978-3-642-55119-2_14 +series/ifip/hedu2014 +db/series/ifip/hedu2014.html#ParkerD14 + +Penny Rheingans +Chris Landreth +Perceptual Principles for Effective Visualizations. +59-73 +1995 +Perceptual Issues in Visualization +https://doi.org/10.1007/978-3-642-79057-7_6 +series/ifip/piv1995 +db/series/ifip/piv1995.html#RheingansL95 + +Rory Butler +Adrie J. Visscher +The Hopes and Realities of the Computer as a School Administration and School Management Tool. +197-202 +2014 +Reflections on the History of Computers in Education +https://doi.org/10.1007/978-3-642-55119-2_13 +series/ifip/hedu2014 +db/series/ifip/hedu2014.html#ButlerV14 + +Eder J. Scheid +Bruno Bastos Rodrigues +Christian Killer +Muriel Figueredo Franco +Sina Rafati 0001 +Burkhard Stiller +Blockchains and Distributed Ledgers Uncovered: Clarifications, Achievements, and Open Issues. +289-317 +2021 +IFIP's Exciting First 60+ Years +https://doi.org/10.1007/978-3-030-81701-5_12 +series/ifip/600 +db/series/ifip/sixty2021.html#ScheidRKFRS21 + +Angela Lecomber +Pioneering the Internet in the Nineties - An Innovative Project Involving UK and Australian Schools. +384-393 +2014 +Reflections on the History of Computers in Education +https://doi.org/10.1007/978-3-642-55119-2_27 +series/ifip/hedu2014 +db/series/ifip/hedu2014.html#Lecomber14 + +Hartmut Ehrig +Hans-Jörg Kreowski +Refinement and Implementation. 
+201-242 +1999 +Algebraic Foundations of Systems Specification +https://doi.org/10.1007/978-3-642-59851-7_7 +series/ifip/afss1999 +db/series/ifip/afss1999.ht \ No newline at end of file diff --git a/tests/resources/dk_trace_sample.xml b/tests/resources/dk_trace_sample.xml new file mode 100644 index 0000000000..466bdcc769 --- /dev/null +++ b/tests/resources/dk_trace_sample.xml @@ -0,0 +1,1015 @@ + + + 124922747 + + CAAU + 573469 + + + + 124922747 + + CAAU + 573469 + + 0 + 000000000000000000000000001402748340 + + Opened + + 2023-11-27T21:48:00Z + 2023-11-27T21:48:00Z + + Load + InYard + + UP + + + 2023-11-27T21:48:00Z + + + true + + + EVER + + 66ef31dc-f71f-45e1-8b1a-6877da033c66 + Industry + 502238 + + false + + 502238 + + + 053f9df9-932f-4795-ab21-cb66546e9314 + equipment/equipment-interchanged + 2023-11-27T21:48:00Z + 2023-11-27T21:49:56Z + + 7609 + + + + EqStopReason + EventCode + InterchangeReceipt + + + + + 7609 + + + InterchangeReceipt + + 47264fad-5172-4de8-b513-4f2f4ea89a9e + InterchangeReceipt + 7609 + + ETS + 2023-11-27T21:48:00Z + + + + true + ETS + OriginSwitch + UP + + 142054 + + + 7609 + TICTF + + + + false + UP + RoadHaul + ETS + + 7609 + TICTF + + + 502238 + + + + + + 66ef31dc-f71f-45e1-8b1a-6877da033c66 + Industry + + 677a8cbc-6de2-4bb5-b778-9b8abbdf935e + equipment/equipment-cycle-status-changed + + 502238 + + ActualPlacement + + CompleteUnload + + + + FinalSystemDest + NextPlannedStop + NextScheduledStop + + + + EqStopReason + EventCode + ActualPlacement + + + EqStopSubReason + EventQualifierCode + CompleteUnload + + + + + + 4611110 + STCC + MIXFRT + OTHER + FREIGHT ALL KINDS, (FAK) OR ALL FREIGHT RATE SHIPMENTS, NEC, OR TRAILER-ON FLATCAR SHIPMENTS, COMMERCIAL (EXCEPT IDENTIFIED BY COMMODITIES, THEN CODE BY COMMODITY) + + false + false + false + + + false + false + + + + + CofcShipment + Reporting + + + ImportShipment + Reporting + + + InBond + Reporting + + + ShprCertfdScaleWgts + Reporting + + + + + Tare + Umler + 8000 + false + + + Net + 
Waybill + 9619 + false + + + Gross + NCEqmtActRptg + 17800 + false + + + + + 869277338 + + 114035 + 2023-11-17 + + + + + true + BillOfLadingNumber + 100123111771505 + + + false + OceanBillOfLading + 003302074967 + + + false + EssCommFlags + ABEKEY + 231117103606 + + + false + EssGrpId + PABOCIRC7 + CS526 + + + false + EssCommFlags + YNNNNNNYYNNYNNNNYNNNYNNNNNNNNN + NNNNNNNNNN + + + false + EssGrpId + PABDCIRC7 + HL011 + + + false + EssGrpId + ESSID=880553922,0 + + + Local + Load + false + false + false + 1 + false + ShipperCert + + 4611110 + STCC + MIXFRT + OTHER + FREIGHT ALL KINDS, (FAK) OR ALL FREIGHT RATE SHIPMENTS, NEC, OR TRAILER-ON FLATCAR SHIPMENTS, COMMERCIAL (EXCEPT IDENTIFIED BY COMMODITIES, THEN CODE BY COMMODITY) + + false + false + false + + + false + false + + + + 7609 + + + 502238 + + + + UP + RoadHaul + + 7609 + + + 502238 + + + + 3 + + + InCareOf1 + + 633493 + 144973435 + EVERGREEN SHIPPING AGENCY + 16000 DALLAS PKWY STE 400 + DALLAS + TX + US + + + Phone + 9722465520 + + + + + + IMMEDIATE + + 831777 + 144973435 + EVERGREEN SHIPPING AGENCY + DALLAS + TX + US + + + + + + Consignee + + 633493 + 144973435 + EVERGREEN SHIPPING AGENCY + 16000 DALLAS PKWY STE 400 + DALLAS + TX + US + + + Phone + 9722465520 + + + + + + IMMEDIATE + + 831777 + 144973435 + EVERGREEN SHIPPING AGENCY + DALLAS + TX + US + + + + + + Notify1 + + 831777 + 144973435 + EVERGREEN SHIPPING AGENCY + 16000 DALLAS PKWY STE 400 + DALLAS + TX + US + + + Phone + 9722465520 + + + Phone + 8665718765 + + + + + + IMMEDIATE + + 993341 + 174299743 + EVERGREEN SHIPPING AGENCY + JERSEY CITY + NJ + US + + + + ULTIMATE + + 993341 + 174299743 + EVERGREEN SHIPPING AGENCY + JERSEY CITY + NJ + US + + + + + + ToReceiveFreight + + 27313 + EVERGREEN SHIPPING AGENCY + CYPRESS + CA + US + + + Phone + 7148226800 + + + + + + Shipper + + 371960 + 174299743 + EVERGREEN SHIPPING AGENCY + ONE EVERTRUST PLZ + JERSEY CITY + NJ + US + + + Phone + 2017613000 + + + + + + IMMEDIATE + + 993341 + 174299743 + EVERGREEN 
SHIPPING AGENCY + JERSEY CITY + NJ + US + + + + + + + EMCJFR5623 + + + + ShippersBondNumber + VVS10930776 + + + 9619 + Net + + + IN SHIPPER BOND + TCS + + + WB + AUTOBILL + + 003302074967 + 1322-018E + LOS ANGELES + D + 2023-11-20 + EVER FOREVER + + + + CT + 52072 + UP + UPC + + + + + + Intermodal + Container + false + false + K4E + U + 903 + 8000 + U + P + 480 + ACT + INSV + CAAU + CAAU + CAAU + N + + + + f0e765d9-599b-46bf-9aef-bd33e0c2183f + equipment/equipment-placement-at-industry-planned + 1.2 + cn=dneq999,ou=uprr,o=up + + 01f14f9f-2464-448b-9051-2da30b5f3fcd + InterchangeReceipt + 2023-11-27T21:49:50Z + 2023-11-27T21:49:50Z + 2023-11-27T21:49:56Z + OIPT223 + UP + InterchangeReceiptUI + + InterchangeReceipt + + + + Stop + + true + false + + 502238 + + + + EqStopReason + EventCode + ActualPlacement + + + EqStopSubReason + EventQualifierCode + CompleteUnload + + + + NEW + + CompleteUnload + + 502238 + + + 66ef31dc-f71f-45e1-8b1a-6877da033c66 + + FinalSystemDest + NextPlannedStop + NextScheduledStop + + + + + dd9a4616-e41c-4571-9bda-9a5506a2b78d + equipment/equipment-waybill-applied + 1.2 + cn=dneq999,ou=uprr,o=up + + 01f14f9f-2464-448b-9051-2da30b5f3fcd + InterchangeReceipt + 2023-11-27T21:49:55Z + 2023-11-27T21:49:50Z + 2023-11-27T21:49:56Z + OIPT223 + UP + InterchangeReceiptUI + + InterchangeReceipt + + + false + false + 2023-11-27T21:49:55Z + + + EqTraceUIEventCode + EventCode + WaybillApplied + + + + NEW + + + 000000000000000000000000001402748340 + Opened + Current + + + + 869277338 + + 114035 + 2023-11-17 + + + + + true + BillOfLadingNumber + 100123111771505 + DESCRIPTION + + + false + OceanBillOfLading + 003302074967 + DESCRIPTION + + + false + EssCommFlags + ABEKEY + DESCRIPTION + + + false + EssGrpId + PABOCIRC7 + DESCRIPTION + + + false + EssCommFlags + YNNNNNNYYNNYNNNNYNNNYNNNNNNNNN + DESCRIPTION + + + false + EssGrpId + PABDCIRC7 + DESCRIPTION + + + false + EssGrpId + ESSID=880553922,0 + DESCRIPTION + + + Local + Load + false + false + false + false 
+ + 4611110 + MIXFRT + OTHER + FREIGHT ALL KINDS, (FAK) OR ALL FREIGHT RATE SHIPMENTS, NEC, OR TRAILER-ON FLATCAR SHIPMENTS, COMMERCIAL (EXCEPT IDENTIFIED BY COMMODITIES, THEN CODE BY COMMODITY) + + false + false + + + + 7609 + + + 502238 + + + + UP + RoadHaul + + 7609 + + + 502238 + + + + + + Shipper + + 371960 + 174299743 + EVERGREEN SHIPPING AGENCY + ONE EVERTRUST PLZ + JERSEY CITY + NJ + US + + + + IMMEDIATE + + 993341 + 174299743 + EVERGREEN SHIPPING AGENCY + JERSEY CITY + NJ + US + + + + + + Consignee + + 633493 + 144973435 + EVERGREEN SHIPPING AGENCY + 16000 DALLAS PKWY STE 400 + DALLAS + TX + US + + + + IMMEDIATE + + 831777 + 144973435 + EVERGREEN SHIPPING AGENCY + DALLAS + TX + US + + + + + + InCareOf + + 633493 + 144973435 + EVERGREEN SHIPPING AGENCY + 16000 DALLAS PKWY STE 400 + DALLAS + TX + US + + + + IMMEDIATE + + 831777 + 144973435 + EVERGREEN SHIPPING AGENCY + DALLAS + TX + US + + + + + + 9619 + Net + + + IN SHIPPER BOND + TCS + + + + 2023-11-27T21:49:55Z + + + + + 677a8cbc-6de2-4bb5-b778-9b8abbdf935e + equipment/equipment-cycle-status-changed + 1.3 + cn=dneq999,ou=uprr,o=up + + 01f14f9f-2464-448b-9051-2da30b5f3fcd + InterchangeReceipt + 2023-11-27T21:48:00Z + 2023-11-27T21:49:50Z + 2023-11-27T21:49:56Z + OIPT223 + UP + InterchangeReceiptUI + + InterchangeReceipt + + + false + false + 2023-11-27T21:48:00Z + + + Opened + + + Current + + + + + 053f9df9-932f-4795-ab21-cb66546e9314 + equipment/equipment-interchanged + 1.3 + cn=dneq999,ou=uprr,o=up + + 01f14f9f-2464-448b-9051-2da30b5f3fcd + InterchangeReceipt + 2023-11-27T21:48:00Z + 2023-11-27T21:49:50Z + 2023-11-27T21:49:56Z + OIPT223 + UP + InterchangeReceiptUI + + InterchangeReceipt + + + + Stop + + false + false + 2023-11-27T21:48:00Z + + 7609 + + + + EqStopReason + EventCode + InterchangeReceipt + + + + false + NEW + false + InterchangeReceipt + false + false + 47264fad-5172-4de8-b513-4f2f4ea89a9e + + OpeningCycleStop + + + + ETS + 7609 + 2023-11-27T21:48:00Z + + + + + + 
4e530f8a-9ccd-445f-a807-6b9ab09e625d + equipment/equipment-cycle-status-changed + 1.3 + cn=dneq999,ou=uprr,o=up + + 01f14f9f-2464-448b-9051-2da30b5f3fcd + InterchangeReceipt + 2023-11-27T21:48:00Z + 2023-11-27T21:49:50Z + 2023-11-27T21:49:56Z + OIPT223 + UP + InterchangeReceiptUI + + InterchangeReceipt + + + false + false + 2023-11-17T16:36:06Z + + + Unopened + Deleted + + + Future + Deleted + + + + + + + + + 502238 + + 502238 + HL011 + 435 + DIT + DIT + DIT + TX + US + America/Chicago + 514167 + 667285000 + 57810 + 046 + + + PRPTYARD + 11609 + + + DRN + 336 + + + BKP + 046 + + + + + + 7609 + + 7609 + CS526 + 17793 + TICTF + TICTF + TCTF + CA + US + America/Los_Angeles + 142054 + 883213000 + 15004 + 240 + + + PRPTYARD + 13975 + + + DRN + 158 + + + BKP + 240 + + + + + + + + 142054 + + 60058 + 142054 + 883213000 + TICTF + TICTF + CA + US + PT + true + + + + + + 2023-11-27T21:50:01.306Z + + true + + + f0e765d9-599b-46bf-9aef-bd33e0c2183f + equipment/equipment-placement-at-industry-planned + + + 053f9df9-932f-4795-ab21-cb66546e9314 + equipment/equipment-interchanged + + + 677a8cbc-6de2-4bb5-b778-9b8abbdf935e + equipment/equipment-cycle-status-changed + + + 4e530f8a-9ccd-445f-a807-6b9ab09e625d + equipment/equipment-cycle-status-changed + + + fb7cd7ed-4d05-4c04-b793-11c51923722f + equipment/key-equipment-properties-changed + + + dd9a4616-e41c-4571-9bda-9a5506a2b78d + equipment/equipment-waybill-applied + + + + + + + Load + + + true + + + false + + + Tare + Umler + 8000 + false + + + + + diff --git a/tests/resources/xml_infer_mixed.xml b/tests/resources/xml_infer_mixed.xml new file mode 100644 index 0000000000..d6abaa04b1 --- /dev/null +++ b/tests/resources/xml_infer_mixed.xml @@ -0,0 +1,41 @@ + + + + + + Child 1.1 + + Child 1.2 + + 25 + text1 + Hello World + + 1 + 2 + + + + + + Child 2.1 + + Child 2.2 + + 30 + text2 + Plain text only + + + + Child 3 + + 35 + 999 + Another Nested + + 3 + 4 + + + diff --git a/tests/resources/xml_infer_types.xml 
b/tests/resources/xml_infer_types.xml new file mode 100644 index 0000000000..b622c447dc --- /dev/null +++ b/tests/resources/xml_infer_types.xml @@ -0,0 +1,57 @@ + + + + Widget A + 44.95 + 100 + true + 2021-02-01 + 2015-01-01 00:00:00 + + electronics + gadget + + + 1.5 + + 10 + 5 + 3 + + + Excellent + + + Widget B + 29.99 + 50 + false + 2022-06-15 + 2023-03-10 12:30:00 + + home + + + 2.3 + + 20 + 10 + 8 + + + Good + + + Widget C + 15.00 + 200 + true + 2020-11-30 + + office + supplies + paper + + Average + + diff --git a/tests/unit/scala/test_utils_suite.py b/tests/unit/scala/test_utils_suite.py index 6b18c6aa9f..02ca8412e0 100644 --- a/tests/unit/scala/test_utils_suite.py +++ b/tests/unit/scala/test_utils_suite.py @@ -313,11 +313,14 @@ def check_zip_files_and_close_stream(input_stream, expected_files): "resources/books.xml", "resources/books.xsd", "resources/books2.xml", + "resources/books_attribute_value.xml", "resources/broken.csv", "resources/cat.jpeg", "resources/conversation.ogg", - "resources/diamonds.csv", + "resources/dblp_6kb.xml", "resources/declared_namespace.xml", + "resources/diamonds.csv", + "resources/dk_trace_sample.xml", "resources/doc.pdf", "resources/dog.jpg", "resources/fias_house.xml", @@ -390,6 +393,8 @@ def check_zip_files_and_close_stream(input_stream, expected_files): "resources/test_udaf_dir/test_udaf_file.py", "resources/undeclared_attr_namespace.xml", "resources/undeclared_namespace.xml", + "resources/xml_infer_mixed.xml", + "resources/xml_infer_types.xml", "resources/xxe.xml", ], ) diff --git a/tests/unit/test_xml_reader.py b/tests/unit/test_xml_reader.py index 85326d87d8..444a65ce95 100644 --- a/tests/unit/test_xml_reader.py +++ b/tests/unit/test_xml_reader.py @@ -28,6 +28,9 @@ _escape_colons_in_quotes, _restore_colons_in_template, _COLON_PLACEHOLDER, + _can_cast_to_type, + _validate_row_for_type_mismatch, + XMLReader, ) from snowflake.snowpark.types import ( StructType, @@ -37,6 +40,9 @@ DateType, ArrayType, MapType, + LongType, + 
BooleanType, + TimestampType, ) @@ -919,7 +925,8 @@ def test_schema_string_to_result_dict_and_struct_type(session): ) attr, _, _ = session.read._get_schema_from_user_input(user_schema) schema_string = attribute_to_schema_string_deep(attr) - assert schema_string_to_result_dict_and_struct_type(schema_string) == { + template, schema_type = schema_string_to_result_dict_and_struct_type(schema_string) + assert template == { "Author": None, "TITLE": None, "GENRE": None, @@ -928,6 +935,7 @@ def test_schema_string_to_result_dict_and_struct_type(session): "description": None, "map_type": None, } + assert isinstance(schema_type, StructType) def test_user_schema_value_tag(): @@ -1048,7 +1056,8 @@ def test_restore_colons_in_template(): def test_schema_string_round_trip_with_colons(): # Flat schema_str = 'struct<"px:name": string, "px:value": string>' - assert schema_string_to_result_dict_and_struct_type(schema_str) == { + template, _ = schema_string_to_result_dict_and_struct_type(schema_str) + assert template == { "px:name": None, "px:value": None, } @@ -1057,14 +1066,16 @@ def test_schema_string_round_trip_with_colons(): schema_str = ( 'struct<"eq:event-id": string,' '"eq:detail": struct<"eq:sub-id": string>>' ) - assert schema_string_to_result_dict_and_struct_type(schema_str) == { + template, _ = schema_string_to_result_dict_and_struct_type(schema_str) + assert template == { "eq:event-id": None, "eq:detail": {"eq:sub-id": None}, } # Mixed schema_str = 'struct<"px:name": string, "Title": string, price: double>' - assert schema_string_to_result_dict_and_struct_type(schema_str) == { + template, _ = schema_string_to_result_dict_and_struct_type(schema_str) + assert template == { "px:name": None, "Title": None, "PRICE": None, @@ -1072,7 +1083,267 @@ def test_schema_string_round_trip_with_colons(): # No colon fields schema_str = 'struct<"Author": string, "TITLE": string>' - assert schema_string_to_result_dict_and_struct_type(schema_str) == { + template, _ = 
schema_string_to_result_dict_and_struct_type(schema_str) + assert template == { "Author": None, "TITLE": None, } + + +@pytest.mark.parametrize( + "value, target_type, expected", + [ + # StringType always passes + ("anything", StringType(), True), + ("", StringType(), True), + # LongType + ("42", LongType(), True), + ("-7", LongType(), True), + ("0", LongType(), True), + ("3.14", LongType(), False), + ("hello", LongType(), False), + ("", LongType(), False), + # DoubleType + ("3.14", DoubleType(), True), + ("-0.5", DoubleType(), True), + ("42", DoubleType(), True), + ("NaN", DoubleType(), True), # Python float("NaN") succeeds + ("hello", DoubleType(), False), + # BooleanType + ("true", BooleanType(), True), + ("false", BooleanType(), True), + ("True", BooleanType(), True), + ("1", BooleanType(), True), + ("0", BooleanType(), True), + ("yes", BooleanType(), False), + ("maybe", BooleanType(), False), + # DateType + ("2024-01-15", DateType(), True), + ("not-a-date", DateType(), False), + ("2024-13-01", DateType(), False), + # TimestampType + ("2024-01-15T10:30:00", TimestampType(), True), + ("2024-01-15", TimestampType(), True), + ("not-a-ts", TimestampType(), False), + ], +) +def test_can_cast_to_type(value, target_type, expected): + assert _can_cast_to_type(value, target_type) == expected + + +# --------------------------------------------------------------------------- +# _validate_row_for_type_mismatch tests +# --------------------------------------------------------------------------- + + +def _make_schema(*fields): + """Helper to build a StructType from (name, datatype) tuples.""" + return StructType([StructField(f'"{n}"', t) for n, t in fields]) + + +def test_validate_permissive_all_valid_no_corrupt(): + schema = _make_schema(("name", StringType()), ("age", LongType())) + row = {"name": "Alice", "age": "30"} + result = _validate_row_for_type_mismatch(row, schema, "PERMISSIVE", "") + assert result["name"] == "Alice" + assert result["age"] == "30" + assert 
"_corrupt_record" not in result + + +def test_validate_permissive_single_field_nulled(): + schema = _make_schema(("name", StringType()), ("age", LongType())) + row = {"name": "Bob", "age": "hello"} + result = _validate_row_for_type_mismatch(row, schema, "PERMISSIVE", "") + assert result["name"] == "Bob" + assert result["age"] is None + assert result["_corrupt_record"] == "" + + +def test_validate_permissive_multiple_fields_nulled(): + schema = _make_schema(("a", LongType()), ("b", BooleanType()), ("c", DoubleType())) + row = {"a": "not_int", "b": "maybe", "c": "1.5"} + result = _validate_row_for_type_mismatch(row, schema, "PERMISSIVE", "") + assert result["a"] is None + assert result["b"] is None + assert result["c"] == "1.5" + assert result["_corrupt_record"] == "" + + +def test_validate_permissive_missing_field_ignored(): + schema = _make_schema(("name", StringType()), ("age", LongType())) + row = {"name": "Carol"} + result = _validate_row_for_type_mismatch(row, schema, "PERMISSIVE", "") + assert result["name"] == "Carol" + assert "_corrupt_record" not in result + + +def test_validate_permissive_none_value_ignored(): + schema = _make_schema(("val", LongType())) + row = {"val": None} + result = _validate_row_for_type_mismatch(row, schema, "PERMISSIVE", "") + assert result["val"] is None + assert "_corrupt_record" not in result + + +def test_validate_permissive_complex_type_skipped(): + schema = _make_schema( + ("data", StructType([StructField("x", StringType())])), + ("tags", ArrayType(StringType())), + ) + row = {"data": {"x": "val"}, "tags": ["a", "b"]} + result = _validate_row_for_type_mismatch(row, schema, "PERMISSIVE", "") + assert result["data"] == {"x": "val"} + assert result["tags"] == ["a", "b"] + assert "_corrupt_record" not in result + + +def test_validate_permissive_custom_corrupt_col_name(): + schema = _make_schema(("val", LongType())) + row = {"val": "bad"} + result = _validate_row_for_type_mismatch( + row, schema, "PERMISSIVE", "", 
column_name_of_corrupt_record="bad_rec" + ) + assert result["val"] is None + assert result["bad_rec"] == "" + assert "_corrupt_record" not in result + + +def test_validate_failfast_raises_on_mismatch(): + schema = _make_schema(("val", LongType())) + row = {"val": "hello"} + with pytest.raises(RuntimeError, match="Failed to cast value 'hello'"): + _validate_row_for_type_mismatch(row, schema, "FAILFAST", "") + + +def test_validate_failfast_no_error_when_valid(): + schema = _make_schema(("val", LongType())) + row = {"val": "42"} + result = _validate_row_for_type_mismatch(row, schema, "FAILFAST", "") + assert result["val"] == "42" + + +def test_validate_dropmalformed_returns_none_on_mismatch(): + schema = _make_schema(("val", LongType())) + row = {"val": "hello"} + assert ( + _validate_row_for_type_mismatch(row, schema, "DROPMALFORMED", "") is None + ) + + +def test_validate_dropmalformed_returns_row_when_valid(): + schema = _make_schema(("val", LongType())) + row = {"val": "42"} + result = _validate_row_for_type_mismatch(row, schema, "DROPMALFORMED", "") + assert result["val"] == "42" + + +def test_schema_string_to_result_dict_empty_string(): + template, schema_type = schema_string_to_result_dict_and_struct_type("") + assert template is None + assert schema_type is None + + +def test_can_cast_to_complex_type_returns_true(): + assert _can_cast_to_type("anything", ArrayType(StringType())) is True + assert _can_cast_to_type("anything", MapType(StringType(), StringType())) is True + assert _can_cast_to_type("anything", StructType([])) is True + + +def test_element_to_dict_leaf_with_template_edge_cases(): + xml_str = "" + element = ET.fromstring(xml_str) + result_template = {"_VALUE": None, "_country": None, "_language": None} + result = element_to_dict_or_str(element, result_template=result_template) + assert isinstance(result, dict) + assert result["_VALUE"] is None + assert result["_country"] is None + assert result["_language"] is None + + xml_str = "Penguin" + element 
= ET.fromstring(xml_str) + result = element_to_dict_or_str(element, result_template=result_template) + assert isinstance(result, dict) + assert result["_VALUE"] == "Penguin" + assert result["_country"] is None + assert result["_language"] is None + + xml_str = "N/A" + element = ET.fromstring(xml_str) + result_template = {"_VALUE": None, "_country": None} + result = element_to_dict_or_str( + element, result_template=result_template, null_value="N/A" + ) + assert isinstance(result, dict) + assert result["_VALUE"] is None + + +def test_process_xml_range_scos_permissive_type_validation(): + """process_xml_range validates types when is_snowpark_connect_compatible=True""" + xml_content = ( + "" + "Alice100" + "Frankhello" + "" + ) + xml_bytes = xml_content.encode("utf-8") + schema = _make_schema(("name", StringType()), ("value", LongType())) + mock_file = io.BytesIO(xml_bytes) + with patch("snowflake.snowpark.files.SnowflakeFile.open", return_value=mock_file): + results = list( + process_xml_range( + file_path="test.xml", + tag_name="ROW", + approx_start=0, + approx_end=len(xml_bytes), + mode="PERMISSIVE", + column_name_of_corrupt_record="_corrupt_record", + ignore_namespace=True, + attribute_prefix="_", + exclude_attributes=False, + value_tag="_VALUE", + null_value="", + charset="utf-8", + ignore_surrounding_whitespace=True, + row_validation_xsd_path="", + result_template={"name": None, "value": None}, + schema_type=schema, + is_snowpark_connect_compatible=True, + ) + ) + assert len(results) == 2 + frank = [r for r in results if r.get("name") == "Frank"][0] + assert frank["value"] is None + assert "_corrupt_record" in frank + + +def test_xml_reader_process_with_scos_compatible_param(): + """XMLReader.process passes is_snowpark_connect_compatible through""" + xml_content = "1" + xml_bytes = xml_content.encode("utf-8") + mock_file = io.BytesIO(xml_bytes) + with patch( + "snowflake.snowpark._internal.xml_reader.get_file_size", + return_value=len(xml_bytes), + ), 
patch("snowflake.snowpark.files.SnowflakeFile.open", return_value=mock_file): + results = list( + XMLReader().process( + "test.xml", + 1, + "record", + 0, + "PERMISSIVE", + "_corrupt_record", + True, + "_", + False, + "_VALUE", + "", + "utf-8", + True, + "", + "", + True, + ) + ) + assert len(results) == 1 + assert results[0][0]["a"] == "1" diff --git a/tests/unit/test_xml_schema_inference.py b/tests/unit/test_xml_schema_inference.py new file mode 100644 index 0000000000..5267cbdf5f --- /dev/null +++ b/tests/unit/test_xml_schema_inference.py @@ -0,0 +1,1893 @@ +# +# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. +# + +import io +from contextlib import ExitStack +from unittest import mock +from unittest.mock import patch + +import lxml.etree as ET +import pytest + +from snowflake.snowpark._internal.xml_schema_inference import ( + _normalize_text, + infer_type, + infer_element_schema, + compatible_type, + merge_struct_types, + add_or_update_type, + canonicalize_type, + _case_preserving_simple_string, + _struct_has_value_tag, + _merge_struct_with_primitive, + infer_schema_for_xml_range, + XMLSchemaInference, +) +from snowflake.snowpark import DataFrameReader +import snowflake.snowpark.dataframe_reader as _dr_mod +from snowflake.snowpark.types import ( + StructType, + StructField, + ArrayType, + MapType, + NullType, + StringType, + BooleanType, + LongType, + DoubleType, + DecimalType, + DateType, + TimestampType, +) + + +def _xml(xml_str: str) -> ET.Element: + """Parse an XML string into an lxml Element.""" + return ET.fromstring(xml_str) + + +def _infer_and_merge(xml_records, **kwargs): + """Infer schema for each record and merge, mimicking the UDTF behavior.""" + merged = None + value_tag = kwargs.get("value_tag", "_VALUE") + for rec_str in xml_records: + rec_str = rec_str.strip() + if not rec_str: + continue + try: + elem = _xml(rec_str) + except Exception: + continue + schema = infer_element_schema(elem, **kwargs) + if not isinstance(schema, 
StructType): + schema = StructType( + [StructField(value_tag, schema, nullable=True)], + structured=False, + ) + if merged is None: + merged = schema + else: + merged = merge_struct_types(merged, schema, value_tag) + return merged + + +def _full_pipeline(xml_records, **kwargs): + """Run the full schema inference pipeline: infer + merge + canonicalize.""" + merged = _infer_and_merge(xml_records, **kwargs) + if merged is not None: + merged = canonicalize_type(merged) + return merged + + +# =========================================================================== +# _normalize_text +# =========================================================================== + + +@pytest.mark.parametrize( + "text, strip, expected", + [ + (None, True, None), + (None, False, None), + (" hello ", True, "hello"), + ("\n\thello\n\t", True, "hello"), + (" hello ", False, " hello "), + ("", True, ""), + ("", False, ""), + (" ", True, ""), + (" ", False, " "), + ], +) +def test_normalize_text(text, strip, expected): + assert _normalize_text(text, strip) == expected + + +# =========================================================================== +# infer_type +# =========================================================================== + + +@pytest.mark.parametrize( + "value, kwargs", + [ + (None, {}), + ("", {}), + ("NULL", {"null_value": "NULL"}), + (" NULL ", {"ignore_surrounding_whitespace": True, "null_value": "NULL"}), + ], +) +def test_infer_type_null(value, kwargs): + """Null/empty/null_value inputs → NullType.""" + assert isinstance(infer_type(value, **kwargs), NullType) + + +@pytest.mark.parametrize( + "value", + [ + "42", + "-42", + "0", + str(2**63 - 1), + str(-(2**63)), + "+123", + "1", + "+0", + "-0", + "007", + ], +) +def test_infer_type_long(value): + """Integer strings within Long range → LongType.""" + assert isinstance(infer_type(value), LongType) + + +@pytest.mark.parametrize( + "value", + [ + "3.14", + "1.5e10", + "-2.5", + ".5", + "0.0", + "+3.14", + "NaN", + 
"Infinity", + "+Infinity", + "-Infinity", + "92233720368547758070", + "-92233720368547758080", + "1.7976931348623157e+308", + ], +) +def test_infer_type_double(value): + """Decimals, scientific notation, special floats, and beyond-Long ints → DoubleType.""" + assert isinstance(infer_type(value), DoubleType) + + +@pytest.mark.parametrize("value", ["true", "false", "True", "FALSE"]) +def test_infer_type_boolean(value): + assert isinstance(infer_type(value), BooleanType) + + +@pytest.mark.parametrize("value", ["2024-01-15", "2000-12-31", "1999-01-01"]) +def test_infer_type_date(value): + assert isinstance(infer_type(value), DateType) + + +@pytest.mark.parametrize( + "value", + ["2024-01-15T10:30:00", "2024-01-15T10:30:00+00:00", "2011-12-03T10:15:30Z"], +) +def test_infer_type_timestamp(value): + assert isinstance(infer_type(value), TimestampType) + + +@pytest.mark.parametrize( + "value", + [ + "hello", + "abc123", + "foo@bar.com", + "inf", + "Inf", + "+inf", + "+Inf", + "1.5d", + "1.5f", + "1.5D", + "1.5F", + ], +) +def test_infer_type_string(value): + """Plain strings and values rejected by d/D/f/F suffix check → StringType.""" + assert isinstance(infer_type(value), StringType) + + +def test_infer_type_whitespace_stripped(): + """With ignore_surrounding_whitespace, ' 42 ' is parsed as Long.""" + assert isinstance( + infer_type(" 42 ", ignore_surrounding_whitespace=True), LongType + ) + assert isinstance( + infer_type(" 42 ", ignore_surrounding_whitespace=False), StringType + ) + + +# =========================================================================== +# infer_element_schema – leaf elements +# =========================================================================== + + +@pytest.mark.parametrize( + "xml_str, expected_type, kwargs", + [ + ("42", LongType, {}), + ("", NullType, {}), + ("", NullType, {}), + ("N/A", NullType, {"null_value": "N/A"}), + ("true", BooleanType, {}), + ("false", BooleanType, {}), + ("3.14", DoubleType, {}), + ("hello", StringType, 
{}), + ("2024-01-15", DateType, {}), + ("2024-01-15T10:30:00", TimestampType, {}), + ("2024-01-15T10:30:00Z", TimestampType, {}), + ("NaN", DoubleType, {}), + ("92233720368547758070", DoubleType, {}), + ], +) +def test_infer_element_schema_leaf(xml_str, expected_type, kwargs): + """Leaf elements (no children, no attributes) infer scalar types.""" + assert isinstance(infer_element_schema(_xml(xml_str), **kwargs), expected_type) + + +# =========================================================================== +# infer_element_schema – attributes +# =========================================================================== + + +def test_infer_element_schema_attributes_only_root(): + """Root-level attribute-only element has no _VALUE.""" + result = infer_element_schema(_xml(''), is_root=True) + assert isinstance(result, StructType) + field_names = {f._name for f in result.fields} + assert "_name" in field_names and "_age" in field_names + assert "_VALUE" not in field_names + + +def test_infer_element_schema_attributes_only_child(): + """Child-level attribute-only element adds _VALUE as NullType.""" + result = infer_element_schema(_xml(''), is_root=False) + field_names = {f._name for f in result.fields} + assert "_name" in field_names and "_VALUE" in field_names + + +def test_infer_element_schema_attributes_with_text(): + """Element with attributes and text adds _VALUE.""" + result = infer_element_schema(_xml('hello'), is_root=True) + field_names = {f._name for f in result.fields} + assert "_name" in field_names and "_VALUE" in field_names + + +def test_infer_element_schema_custom_attribute_prefix(): + result = infer_element_schema( + _xml(''), attribute_prefix="@", is_root=True + ) + assert "@name" in {f._name for f in result.fields} + + +def test_infer_element_schema_exclude_attributes(): + """When exclude_attributes=True, attributes are ignored.""" + assert isinstance( + infer_element_schema( + _xml('hello'), exclude_attributes=True + ), + StringType, + ) + assert 
isinstance( + infer_element_schema(_xml(''), exclude_attributes=True), + NullType, + ) + + +# =========================================================================== +# infer_element_schema – children +# =========================================================================== + + +def test_infer_element_schema_simple_children(): + result = infer_element_schema(_xml("1hello")) + assert isinstance(result, StructType) + assert {"a", "b"} <= {f._name for f in result.fields} + + +def test_infer_element_schema_child_types(): + result = infer_element_schema(_xml("42true3.14")) + field_map = {f._name: f.datatype for f in result.fields} + assert isinstance(field_map["a"], LongType) + assert isinstance(field_map["b"], BooleanType) + assert isinstance(field_map["c"], DoubleType) + + +def test_infer_element_schema_repeated_children_array(): + """Repeated child tags → ArrayType.""" + result = infer_element_schema( + _xml("123") + ) + field_map = {f._name: f.datatype for f in result.fields} + assert isinstance(field_map["item"], ArrayType) + assert isinstance(field_map["item"].element_type, LongType) + + +def test_infer_element_schema_nested_children(): + result = infer_element_schema( + _xml("hello") + ) + outer_field = next(f for f in result.fields if f._name == "outer") + assert isinstance(outer_field.datatype, StructType) + inner_field = next(f for f in outer_field.datatype.fields if f._name == "inner") + assert isinstance(inner_field.datatype, StringType) + + +def test_infer_element_schema_child_with_attributes(): + result = infer_element_schema(_xml('text')) + child_field = next(f for f in result.fields if f._name == "child") + assert isinstance(child_field.datatype, StructType) + child_names = {f._name for f in child_field.datatype.fields} + assert "_id" in child_names and "_VALUE" in child_names + + +# =========================================================================== +# infer_element_schema – mixed content +# 
=========================================================================== + + +def test_infer_element_schema_mixed_content_value_tag(): + """Element with text and children: text goes into _VALUE.""" + result = infer_element_schema(_xml("some text1")) + field_names = {f._name for f in result.fields} + assert "_VALUE" in field_names and "a" in field_names + + +def test_infer_element_schema_whitespace_only_no_value_tag(): + """Whitespace-only text between children does not create _VALUE.""" + result = infer_element_schema(_xml("\n 1\n 2\n")) + assert "_VALUE" not in {f._name for f in result.fields} + + +# =========================================================================== +# infer_element_schema – namespace handling +# =========================================================================== + + +def test_infer_element_schema_namespace_clark(): + """Clark notation {uri}tag → strip namespace part.""" + from snowflake.snowpark._internal.xml_reader import strip_xml_namespaces + + elem = strip_xml_namespaces( + ET.fromstring('val') + ) + result = infer_element_schema(elem, ignore_namespace=True) + assert "child" in {f._name for f in result.fields} + + +def test_infer_element_schema_namespace_prefix(): + parser = ET.XMLParser(recover=True, ns_clean=True) + elem = ET.fromstring("val", parser) + result = infer_element_schema(elem, ignore_namespace=True) + assert "child" in {f._name for f in result.fields} + + +# =========================================================================== +# infer_element_schema – whitespace and sorting +# =========================================================================== + + +def test_infer_element_schema_whitespace_leaf(): + assert isinstance( + infer_element_schema(_xml(" 42 "), ignore_surrounding_whitespace=True), + LongType, + ) + assert isinstance( + infer_element_schema( + _xml(" 42 "), ignore_surrounding_whitespace=False + ), + StringType, + ) + + +def test_infer_element_schema_fields_sorted(): + result = 
infer_element_schema(_xml("123")) + names = [f._name for f in result.fields] + assert names == sorted(names) + + +def test_infer_element_schema_custom_value_tag(): + result = infer_element_schema(_xml('text'), value_tag="myval") + field_names = {f._name for f in result.fields} + assert "myval" in field_names and "_VALUE" not in field_names + + +# =========================================================================== +# Spark TestXmlData-derived tests (multi-record infer + merge) +# =========================================================================== + + +def test_spark_primitive_field_value_type_conflict(): + """Merging records with conflicting types for same fields.""" + records = [ + """ + 111.1 + true13.1str1 + """, + """ + 21474836470.9 + 12true + """, + """ + 2147483647092233720368547758070 + 100false + str1false + """, + """ + 214748365701.1 + 21474836470 + 92233720368547758070 + """, + ] + merged = _infer_and_merge(records) + fm = {f._name: f.datatype for f in canonicalize_type(merged).fields} + assert isinstance(fm["num_bool"], StringType) + assert isinstance(fm["num_num_1"], LongType) + assert isinstance(fm["num_num_2"], DoubleType) + assert isinstance(fm["num_num_3"], DoubleType) + assert isinstance(fm["num_str"], StringType) + assert isinstance(fm["str_bool"], StringType) + + +def test_spark_complex_field_value_type_conflict(): + """Merging arrays with singletons and struct/primitive conflicts.""" + records = [ + """ + 11 + 123 + + """, + """ + false + + """, + """ + + str + 456 + 78 + 9 + + """, + """ + + str1str233 + 7 + true + str + """, + ] + merged = _infer_and_merge( + records, ignore_surrounding_whitespace=True, null_value="" + ) + fm = {f._name: f.datatype for f in canonicalize_type(merged).fields} + assert isinstance(fm["array"], ArrayType) and isinstance( + fm["array"].element_type, LongType + ) + assert isinstance(fm["num_struct"], StringType) + assert isinstance(fm["str_array"], ArrayType) and isinstance( + 
fm["str_array"].element_type, StringType + ) + assert isinstance(fm["struct"], StructType) + assert isinstance(fm["struct_array"], ArrayType) and isinstance( + fm["struct_array"].element_type, StringType + ) + + +def test_spark_missing_fields(): + """Different records have different fields; merge produces superset.""" + records = [ + "true", + "21474836470", + "3344", + "true", + "str", + ] + fm = {f._name: f.datatype for f in _full_pipeline(records).fields} + assert isinstance(fm["a"], BooleanType) + assert isinstance(fm["b"], LongType) + assert isinstance(fm["c"], ArrayType) and isinstance(fm["c"].element_type, LongType) + assert isinstance(fm["d"], StructType) + assert isinstance(fm["e"], StringType) + + +def test_spark_null_struct(): + """Null/empty struct fields merged correctly.""" + records = [ + """27.31.100.29 + 1.abc.comUTF-8""", + "27.31.100.29", + "27.31.100.29", + "27.31.100.29", + ] + fm = {f._name: f.datatype for f in _full_pipeline(records).fields} + assert isinstance(fm["nullstr"], StringType) + assert isinstance(fm["ip"], StringType) + assert isinstance(fm["headers"], StructType) + hf = {f._name: f.datatype for f in fm["headers"].fields} + assert isinstance(hf["Host"], StringType) and isinstance(hf["Charset"], StringType) + + +def test_spark_empty_records(): + """Structs with empty/null children properly handled.""" + records = [ + "", + "", + "", + ] + fm = {f._name: f.datatype for f in _full_pipeline(records).fields} + assert isinstance(fm["a"], StructType) + a_struct = {f._name: f.datatype for f in fm["a"].fields}["struct"] + b_c = { + f._name: f.datatype + for f in {f._name: f.datatype for f in a_struct.fields}["b"].fields + }["c"] + assert isinstance(b_c, StringType) + assert isinstance(fm["b"], StructType) + assert isinstance({f._name: f.datatype for f in fm["b"].fields}["item"], ArrayType) + + +def test_spark_nulls_in_arrays(): + """Null entries in arrays don't disrupt type inference.""" + records = [ + """value1value2 + """, + """1 + """, 
+ "", + ] + fm = {f._name: f.datatype for f in _full_pipeline(records).fields} + assert isinstance(fm["field1"], ArrayType) + assert isinstance(fm["field2"], ArrayType) + + +def test_spark_value_tags_type_conflict(): + """Mixed text and children with type conflicts in _VALUE. + + lxml element.text only captures text BEFORE the first child. + Tail text is not captured, diverging from Spark's StAX parser. + """ + records = [ + "\n13.1\n\n11\n\ntrue\n1", + "\nstring\n\n21474836470\n\nfalse\n2", + "\n12\n3", + ] + fm = { + f._name: f.datatype + for f in _full_pipeline(records, ignore_surrounding_whitespace=True).fields + } + assert isinstance(fm["_VALUE"], StringType) + assert isinstance(fm["a"], StructType) + af = {f._name: f.datatype for f in fm["a"].fields} + assert isinstance(af["_VALUE"], LongType) + bf = {f._name: f.datatype for f in af["b"].fields} + assert isinstance(bf["_VALUE"], StringType) + assert isinstance(bf["c"], LongType) + + +def test_spark_value_tag_conflict_name(): + """When valueTag="a" conflicts with child , they merge. + + lxml element.text only captures text BEFORE the first child. 
+ """ + records_before = ["\n2\n1"] + fm = { + f._name: f.datatype + for f in _full_pipeline( + records_before, value_tag="a", ignore_surrounding_whitespace=True + ).fields + } + assert isinstance(fm["a"], ArrayType) and isinstance(fm["a"].element_type, LongType) + + records_after = ["1\n2\n"] + fm2 = { + f._name: f.datatype + for f in _full_pipeline( + records_after, value_tag="a", ignore_surrounding_whitespace=True + ).fields + } + assert isinstance(fm2["a"], LongType) + + +# =========================================================================== +# compatible_type +# =========================================================================== + + +@pytest.mark.parametrize( + "t1, t2, expected_type", + [ + (LongType(), LongType(), LongType), + (StringType(), StringType(), StringType), + (BooleanType(), BooleanType(), BooleanType), + (DoubleType(), DoubleType(), DoubleType), + (DateType(), DateType(), DateType), + (TimestampType(), TimestampType(), TimestampType), + (NullType(), NullType(), NullType), + ], +) +def test_compatible_type_same(t1, t2, expected_type): + """Same types stay the same.""" + assert isinstance(compatible_type(t1, t2), expected_type) + + +@pytest.mark.parametrize( + "t1, t2, expected_type", + [ + (NullType(), LongType(), LongType), + (LongType(), NullType(), LongType), + (NullType(), StringType(), StringType), + ], +) +def test_compatible_type_null_plus_t(t1, t2, expected_type): + """NullType + T → T.""" + assert isinstance(compatible_type(t1, t2), expected_type) + + +def test_compatible_type_null_plus_struct(): + s = StructType([StructField("a", LongType(), True)]) + assert isinstance(compatible_type(NullType(), s), StructType) + + +def test_compatible_type_null_plus_array(): + a = ArrayType(LongType(), structured=False) + assert isinstance(compatible_type(NullType(), a), ArrayType) + + +@pytest.mark.parametrize( + "t1, t2", + [(LongType(), DoubleType()), (DoubleType(), LongType())], +) +def test_compatible_type_long_double_widening(t1, 
t2): + """Long + Double → Double.""" + assert isinstance(compatible_type(t1, t2), DoubleType) + + +@pytest.mark.parametrize( + "t1, t2", + [(DoubleType(), DecimalType(10, 2)), (DecimalType(10, 2), DoubleType())], +) +def test_compatible_type_double_decimal(t1, t2): + """Double + Decimal → Double.""" + assert isinstance(compatible_type(t1, t2), DoubleType) + + +@pytest.mark.parametrize( + "t1, t2", + [(TimestampType(), DateType()), (DateType(), TimestampType())], +) +def test_compatible_type_timestamp_date(t1, t2): + """Timestamp + Date → Timestamp.""" + assert isinstance(compatible_type(t1, t2), TimestampType) + + +def test_compatible_type_long_plus_decimal(): + result = compatible_type(LongType(), DecimalType(10, 2)) + assert ( + isinstance(result, DecimalType) and result.precision == 22 and result.scale == 2 + ) + + result2 = compatible_type(DecimalType(10, 2), LongType()) + assert ( + isinstance(result2, DecimalType) + and result2.precision == 22 + and result2.scale == 2 + ) + + +def test_compatible_type_same_struct_merge(): + s1 = StructType( + [StructField("a", LongType(), True), StructField("b", StringType(), True)] + ) + s2 = StructType( + [StructField("a", DoubleType(), True), StructField("c", BooleanType(), True)] + ) + result = compatible_type(s1, s2) + fm = {f._name: f.datatype for f in result.fields} + assert ( + isinstance(fm["a"], DoubleType) + and isinstance(fm["b"], StringType) + and isinstance(fm["c"], BooleanType) + ) + + +def test_compatible_type_same_array_merge(): + result = compatible_type( + ArrayType(LongType(), structured=False), + ArrayType(DoubleType(), structured=False), + ) + assert isinstance(result, ArrayType) and isinstance(result.element_type, DoubleType) + + +def test_compatible_type_decimal_widen(): + result = compatible_type(DecimalType(10, 2), DecimalType(15, 5)) + assert ( + isinstance(result, DecimalType) and result.precision == 15 and result.scale == 5 + ) + + +def test_compatible_type_decimal_overflow_to_double(): + 
"""Decimal with precision > 38 falls back to Double.""" + result = compatible_type(DecimalType(38, 10), DecimalType(38, 20)) + assert isinstance(result, DoubleType) + + +def test_compatible_type_array_plus_primitive(): + result = compatible_type(ArrayType(LongType(), structured=False), DoubleType()) + assert isinstance(result, ArrayType) and isinstance(result.element_type, DoubleType) + + result2 = compatible_type(LongType(), ArrayType(StringType(), structured=False)) + assert isinstance(result2, ArrayType) and isinstance( + result2.element_type, StringType + ) + + +def test_compatible_type_struct_value_tag_plus_primitive(): + s = StructType( + [ + StructField("_VALUE", LongType(), True), + StructField("_attr", StringType(), True), + ] + ) + result = compatible_type(s, DoubleType()) + fm = {f._name: f.datatype for f in result.fields} + assert isinstance(fm["_VALUE"], DoubleType) and isinstance(fm["_attr"], StringType) + + s2 = StructType( + [ + StructField("_VALUE", StringType(), True), + StructField("_id", LongType(), True), + ] + ) + result2 = compatible_type(BooleanType(), s2) + assert isinstance( + {f._name: f.datatype for f in result2.fields}["_VALUE"], StringType + ) + + +@pytest.mark.parametrize( + "t1, t2", + [ + (BooleanType(), LongType()), + (StringType(), LongType()), + (BooleanType(), DoubleType()), + (DateType(), LongType()), + (StringType(), BooleanType()), + (DateType(), DoubleType()), + (BooleanType(), DateType()), + (StringType(), DateType()), + (LongType(), BooleanType()), + ], +) +def test_compatible_type_fallback_string(t1, t2): + """Incompatible types → StringType.""" + assert isinstance(compatible_type(t1, t2), StringType) + + +def test_compatible_type_struct_no_value_tag_plus_primitive(): + """Struct without _VALUE + primitive → StringType fallback.""" + s = StructType([StructField("a", LongType(), True)]) + assert isinstance(compatible_type(s, LongType()), StringType) + + +# 
=========================================================================== +# _struct_has_value_tag +# =========================================================================== + + +def test_struct_has_value_tag(): + s_yes = StructType( + [StructField("_VALUE", LongType(), True), StructField("a", StringType(), True)] + ) + assert _struct_has_value_tag(s_yes, "_VALUE") is True + + s_no = StructType([StructField("a", LongType(), True)]) + assert _struct_has_value_tag(s_no, "_VALUE") is False + + s_custom = StructType([StructField("myval", LongType(), True)]) + assert _struct_has_value_tag(s_custom, "MYVAL") is True + assert _struct_has_value_tag(s_custom, "myval") is False + assert _struct_has_value_tag(s_custom, "_VALUE") is False + + +# =========================================================================== +# _merge_struct_with_primitive +# =========================================================================== + + +def test_merge_struct_with_primitive(): + s = StructType( + [ + StructField("_VALUE", LongType(), True), + StructField("_attr", StringType(), True), + ] + ) + result = _merge_struct_with_primitive(s, DoubleType(), "_VALUE") + fm = {f._name: f.datatype for f in result.fields} + assert isinstance(fm["_VALUE"], DoubleType) and isinstance(fm["_attr"], StringType) + + s2 = StructType( + [ + StructField("_VALUE", StringType(), True), + StructField("other", LongType(), True), + ] + ) + result2 = _merge_struct_with_primitive(s2, LongType(), "_VALUE") + fm2 = {f._name: f.datatype for f in result2.fields} + assert isinstance(fm2["_VALUE"], StringType) and isinstance(fm2["other"], LongType) + + +# =========================================================================== +# merge_struct_types +# =========================================================================== + + +def test_merge_struct_types_identical(): + s = StructType( + [StructField("a", LongType(), True), StructField("b", StringType(), True)] + ) + fm = {f._name: f.datatype for f in 
merge_struct_types(s, s).fields} + assert isinstance(fm["a"], LongType) and isinstance(fm["b"], StringType) + + +def test_merge_struct_types_disjoint(): + result = merge_struct_types( + StructType([StructField("a", LongType(), True)]), + StructType([StructField("b", StringType(), True)]), + ) + fm = {f._name: f.datatype for f in result.fields} + assert "a" in fm and "b" in fm + + +def test_merge_struct_types_widening(): + result = merge_struct_types( + StructType([StructField("a", LongType(), True)]), + StructType([StructField("a", DoubleType(), True)]), + ) + assert isinstance({f._name: f.datatype for f in result.fields}["a"], DoubleType) + + +def test_merge_struct_types_overlapping_plus_unique(): + s1 = StructType( + [StructField("a", LongType(), True), StructField("b", StringType(), True)] + ) + s2 = StructType( + [StructField("a", DoubleType(), True), StructField("c", BooleanType(), True)] + ) + fm = {f._name: f.datatype for f in merge_struct_types(s1, s2).fields} + assert ( + isinstance(fm["a"], DoubleType) + and isinstance(fm["b"], StringType) + and isinstance(fm["c"], BooleanType) + ) + + +def test_merge_struct_types_sorted(): + result = merge_struct_types( + StructType([StructField("z", LongType(), True)]), + StructType([StructField("a", StringType(), True)]), + ) + names = [f._name for f in result.fields] + assert names == sorted(names) + + +def test_merge_struct_types_case_sensitive(): + """'Name' and 'name' are separate fields (case-sensitive).""" + result = merge_struct_types( + StructType([StructField("Name", LongType(), True)]), + StructType([StructField("name", StringType(), True)]), + ) + fm = {f._name: f.datatype for f in result.fields} + assert isinstance(fm["Name"], LongType) and isinstance(fm["name"], StringType) + + +# =========================================================================== +# add_or_update_type +# =========================================================================== + + +def test_add_or_update_type_first_occurrence(): 
+ d = {} + add_or_update_type(d, "a", LongType()) + assert isinstance(d["a"], LongType) + + +def test_add_or_update_type_second_promotes_array(): + d = {"a": LongType()} + add_or_update_type(d, "a", LongType()) + assert isinstance(d["a"], ArrayType) and isinstance(d["a"].element_type, LongType) + + +def test_add_or_update_type_second_different_types(): + d = {"a": LongType()} + add_or_update_type(d, "a", DoubleType()) + assert isinstance(d["a"], ArrayType) and isinstance(d["a"].element_type, DoubleType) + + +def test_add_or_update_type_third_merges(): + d = {"a": ArrayType(LongType(), structured=False)} + add_or_update_type(d, "a", DoubleType()) + assert isinstance(d["a"], ArrayType) and isinstance(d["a"].element_type, DoubleType) + + +def test_add_or_update_type_progressive(): + """Multiple occurrences keep widening the array element type.""" + d = {} + add_or_update_type(d, "a", LongType()) + assert isinstance(d["a"], LongType) + add_or_update_type(d, "a", LongType()) + assert isinstance(d["a"], ArrayType) + add_or_update_type(d, "a", DoubleType()) + assert isinstance(d["a"].element_type, DoubleType) + add_or_update_type(d, "a", StringType()) + assert isinstance(d["a"].element_type, StringType) + + +# =========================================================================== +# canonicalize_type +# =========================================================================== + + +@pytest.mark.parametrize( + "dtype, expected_type", + [ + (NullType(), StringType), + (LongType(), LongType), + (StringType(), StringType), + (DoubleType(), DoubleType), + (BooleanType(), BooleanType), + (DateType(), DateType), + (TimestampType(), TimestampType), + (DecimalType(10, 2), DecimalType), + (DecimalType(38, 18), DecimalType), + ], +) +def test_canonicalize_type_primitives(dtype, expected_type): + """NullType → StringType; other primitives unchanged.""" + assert isinstance(canonicalize_type(dtype), expected_type) + + +def test_canonicalize_type_array_of_null(): + result = 
canonicalize_type(ArrayType(NullType(), structured=False)) + assert isinstance(result, ArrayType) and isinstance(result.element_type, StringType) + + +def test_canonicalize_type_array_of_long(): + result = canonicalize_type(ArrayType(LongType(), structured=False)) + assert isinstance(result, ArrayType) and isinstance(result.element_type, LongType) + + +def test_canonicalize_type_array_of_empty_struct(): + """Array containing empty struct → None (removed).""" + result = canonicalize_type( + ArrayType(StructType([], structured=False), structured=False) + ) + assert result is None + + +def test_canonicalize_type_nested_array_null(): + """Array> → Array>.""" + inner = ArrayType(NullType(), structured=False) + result = canonicalize_type(ArrayType(inner, structured=False)) + assert isinstance(result.element_type, ArrayType) + assert isinstance(result.element_type.element_type, StringType) + + +def test_canonicalize_type_struct_null_fields(): + s = StructType( + [StructField("a", NullType(), True), StructField("b", LongType(), True)], + structured=False, + ) + fm = {f._name: f.datatype for f in canonicalize_type(s).fields} + assert isinstance(fm["a"], StringType) and isinstance(fm["b"], LongType) + + +def test_canonicalize_type_empty_struct(): + """Per SPARK-8093: empty structs → None.""" + assert canonicalize_type(StructType([], structured=False)) is None + + +def test_canonicalize_type_struct_empty_name_field(): + """Fields with empty names are removed; if none remain, struct is None.""" + assert ( + canonicalize_type( + StructType([StructField("", LongType(), True)], structured=False) + ) + is None + ) + + +def test_canonicalize_type_struct_mixed_fields(): + s = StructType( + [ + StructField("a", NullType(), True), + StructField("", LongType(), True), + StructField("b", DoubleType(), True), + ], + structured=False, + ) + result = canonicalize_type(s) + assert len(result.fields) == 2 + fm = {f._name: f.datatype for f in result.fields} + assert isinstance(fm["a"], 
StringType) and isinstance(fm["b"], DoubleType) + + +def test_canonicalize_type_nested_struct(): + """Nested struct: inner NullType fields become StringType.""" + inner = StructType([StructField("x", NullType(), True)], structured=False) + outer = StructType([StructField("child", inner, True)], structured=False) + result = canonicalize_type(outer) + assert isinstance(result.fields[0].datatype.fields[0].datatype, StringType) + + +def test_canonicalize_type_nested_struct_empty_inner(): + """Nested struct where inner becomes empty → field is removed.""" + inner = StructType([], structured=False) + outer = StructType( + [StructField("child", inner, True), StructField("other", LongType(), True)], + structured=False, + ) + result = canonicalize_type(outer) + assert len(result.fields) == 1 and result.fields[0]._name == "other" + + +# =========================================================================== +# _case_preserving_simple_string +# =========================================================================== + + +@pytest.mark.parametrize( + "dtype, expected", + [ + (StringType(), "string"), + (LongType(), "bigint"), + (DoubleType(), "double"), + (BooleanType(), "boolean"), + (DateType(), "date"), + (TimestampType(), "timestamp"), + (NullType(), "null"), + ], +) +def test_case_preserving_simple_string_primitives(dtype, expected): + assert _case_preserving_simple_string(dtype) == expected + + +def test_case_preserving_simple_string_struct(): + s = StructType([StructField("author", StringType(), True)], structured=False) + assert _case_preserving_simple_string(s) == "struct" + + +def test_case_preserving_simple_string_preserves_case(): + s = StructType( + [ + StructField("Author", StringType(), True), + StructField("title", LongType(), True), + ], + structured=False, + ) + result = _case_preserving_simple_string(s) + assert "Author:string" in result and "title:bigint" in result + + +def test_case_preserving_simple_string_nested(): + inner = 
StructType([StructField("inner_field", LongType(), True)], structured=False) + outer = StructType([StructField("outer", inner, True)], structured=False) + assert ( + _case_preserving_simple_string(outer) + == "struct>" + ) + + +def test_case_preserving_simple_string_array(): + s = StructType( + [StructField("items", ArrayType(StringType(), structured=False), True)], + structured=False, + ) + assert _case_preserving_simple_string(s) == "struct>" + + +def test_case_preserving_simple_string_array_of_struct(): + inner = StructType([StructField("name", StringType(), True)], structured=False) + s = StructType( + [StructField("people", ArrayType(inner, structured=False), True)], + structured=False, + ) + assert ( + _case_preserving_simple_string(s) == "struct>>" + ) + + +def test_case_preserving_simple_string_colon_in_name(): + """Field names with colons get quoted.""" + s = StructType([StructField("px:name", StringType(), True)], structured=False) + assert _case_preserving_simple_string(s) == 'struct<"px:name":string>' + + s2 = StructType( + [ + StructField("px:name", StringType(), True), + StructField("px:value", LongType(), True), + ], + structured=False, + ) + result = _case_preserving_simple_string(s2) + assert '"px:name":string' in result and '"px:value":bigint' in result + + +def test_case_preserving_simple_string_no_unnecessary_quotes(): + s = StructType([StructField("author", StringType(), True)], structured=False) + assert '"' not in _case_preserving_simple_string(s) + + +# =========================================================================== +# End-to-end: infer + merge + canonicalize (Spark parity) +# =========================================================================== + + +def test_e2e_simple_flat_record(): + records = [ + """The Great GatsbyF. 
Scott Fitzgerald + 192512.99""" + ] + fm = {f._name: f.datatype for f in _full_pipeline(records).fields} + assert isinstance(fm["title"], StringType) and isinstance(fm["author"], StringType) + assert isinstance(fm["year"], LongType) and isinstance(fm["price"], DoubleType) + + +def test_e2e_repeated_elements_array(): + records = [ + "AppleBananaCherry" + ] + fm = {f._name: f.datatype for f in _full_pipeline(records).fields} + assert isinstance(fm["item"], ArrayType) and isinstance( + fm["item"].element_type, StringType + ) + + +def test_e2e_nested_struct_and_array(): + records = [ + """Test Book + user15 + user23 + """ + ] + fm = {f._name: f.datatype for f in _full_pipeline(records).fields} + assert isinstance(fm["title"], StringType) + rev_fields = {f._name: f.datatype for f in fm["reviews"].fields} + assert isinstance(rev_fields["review"], ArrayType) + elem_fields = { + f._name: f.datatype for f in rev_fields["review"].element_type.fields + } + assert isinstance(elem_fields["user"], StringType) and isinstance( + elem_fields["rating"], LongType + ) + + +def test_e2e_attributes_with_text(): + records = [ + """Test + """ + ] + fm = {f._name: f.datatype for f in _full_pipeline(records).fields} + assert isinstance(fm["_id"], LongType) and isinstance(fm["title"], StringType) + ef = {f._name: f.datatype for f in fm["edition"].fields} + assert isinstance(ef["_year"], LongType) and isinstance(ef["_format"], StringType) + assert isinstance(ef["_VALUE"], StringType) + + +def test_e2e_schema_evolution(): + """Field types evolve across records: Long + Double → Double.""" + records = [ + "Widget10", + "Gadget19.99", + ] + fm = {f._name: f.datatype for f in _full_pipeline(records).fields} + assert isinstance(fm["name"], StringType) and isinstance(fm["price"], DoubleType) + + +def test_e2e_schema_evolution_bool_to_string(): + """Boolean + String → String.""" + records = [ + "true", + "maybe", + ] + fm = {f._name: f.datatype for f in _full_pipeline(records).fields} + assert 
isinstance(fm["flag"], StringType) + + +def test_e2e_schema_evolution_date_to_timestamp(): + """Date + Timestamp → Timestamp.""" + records = [ + "2024-01-15", + "2024-01-15T10:30:00", + ] + fm = {f._name: f.datatype for f in _full_pipeline(records).fields} + assert isinstance(fm["when"], TimestampType) + + +def test_e2e_null_fields_canonicalized(): + """All-null fields become StringType after canonicalization.""" + records = ["hello", "world"] + fm = {f._name: f.datatype for f in _full_pipeline(records).fields} + assert isinstance(fm["a"], StringType) and isinstance(fm["b"], StringType) + + +def test_e2e_serialization_round_trip(): + records = [ + """Jane29.99 + fictionbestseller""" + ] + schema_str = _case_preserving_simple_string(_full_pipeline(records)) + assert "author:string" in schema_str and "price:double" in schema_str + assert "tags:struct>" in schema_str + + +def test_e2e_case_sensitive_fields(): + """'b' and 'B' are different fields (case-sensitive).""" + records = [ + "123", + "456", + ] + fm = {f._name: f.datatype for f in _full_pipeline(records).fields} + assert all(isinstance(fm[k], LongType) for k in ["a", "b", "B", "c"]) + + +def test_e2e_case_sensitive_value_tag(): + """ with children vs without → different fields.""" + records = [ + "\n1\n2", + "3", + ] + fm = { + f._name: f.datatype + for f in _full_pipeline(records, ignore_surrounding_whitespace=True).fields + } + assert isinstance(fm["A"], LongType) + assert isinstance(fm["a"], StructType) + af = {f._name: f.datatype for f in fm["a"].fields} + assert isinstance(af["_VALUE"], LongType) and isinstance(af["b"], LongType) + + +def test_e2e_case_sensitive_attributes(): + """'attr' and 'aTtr' are different attribute fields.""" + records = [ + '123', + '456', + ] + fm = {f._name: f.datatype for f in _full_pipeline(records).fields} + assert isinstance(fm["_attr"], LongType) and isinstance(fm["_aTtr"], LongType) + + +def test_e2e_case_sensitive_struct(): + """ and with children → different struct 
fields.""" + records = [ + "13", + "57", + ] + fm = {f._name: f.datatype for f in _full_pipeline(records).fields} + assert isinstance(fm["A"], StructType) and isinstance(fm["a"], StructType) + assert all( + isinstance(v, LongType) + for v in {f._name: f.datatype for f in fm["A"].fields}.values() + ) + + +def test_e2e_case_sensitive_array_complex(): + records = [ + "\n1\n234", + "5", + ] + fm = { + f._name: f.datatype + for f in _full_pipeline(records, ignore_surrounding_whitespace=True).fields + } + assert isinstance(fm["A"], LongType) + assert isinstance(fm["a"], StructType) + af = {f._name: f.datatype for f in fm["a"].fields} + assert isinstance(af["_VALUE"], LongType) and isinstance(af["b"], LongType) + + +def test_e2e_case_sensitive_array_simple(): + records = ["1234"] + fm = {f._name: f.datatype for f in _full_pipeline(records).fields} + assert isinstance(fm["A"], StructType) and isinstance(fm["a"], StructType) + assert isinstance({f._name: f.datatype for f in fm["A"].fields}["B"], LongType) + assert isinstance({f._name: f.datatype for f in fm["a"].fields}["b"], LongType) + + +def test_e2e_value_tags_spaces_and_empty_values(): + """lxml element.text only captures text BEFORE the first child.""" + records = [ + "\n str1\n 1\n ", + " value", + "3 ", + "4 ", + ] + fm_ws = { + f._name: f.datatype + for f in _full_pipeline(records, ignore_surrounding_whitespace=True).fields + } + assert isinstance(fm_ws["_VALUE"], StringType) + assert isinstance(fm_ws["a"], StructType) + assert isinstance({f._name: f.datatype for f in fm_ws["a"].fields}["b"], LongType) + + fm_nws = { + f._name: f.datatype + for f in _full_pipeline(records, ignore_surrounding_whitespace=False).fields + } + assert isinstance(fm_nws["_VALUE"], StringType) + + +def test_e2e_value_tags_multiline(): + """element.text only captures text before first child.""" + records = [ + "\nvalue1\n1", + "\nvalue3\nvalue41", + ] + fm = { + f._name: f.datatype + for f in _full_pipeline(records, 
ignore_surrounding_whitespace=True).fields + } + assert isinstance(fm["_VALUE"], StringType) and isinstance(fm["a"], LongType) + + +def test_e2e_value_tag_comments(): + """Comments should not affect value tag inference.""" + records = ['\n2\n'] + fm = { + f._name: f.datatype + for f in _full_pipeline(records, ignore_surrounding_whitespace=True).fields + } + assert isinstance(fm["_VALUE"], LongType) and isinstance(fm["a"], ArrayType) + + +def test_e2e_value_tag_null_value_option(): + """nullValue option doesn't affect schema inference.""" + fm = { + f._name: f.datatype + for f in _full_pipeline( + ["\n 1\n"], ignore_surrounding_whitespace=True + ).fields + } + assert isinstance(fm["_VALUE"], LongType) + + +# --------------------------------------------------------------------------- +# infer_schema_for_xml_range +# --------------------------------------------------------------------------- + + +def _mock_xml_range(xml_content, row_tag, **kwargs): + """Helper: run infer_schema_for_xml_range against an in-memory XML string.""" + xml_bytes = xml_content.encode("utf-8") + mock_file = kwargs.get("file_obj") or io.BytesIO(xml_bytes) + with patch("snowflake.snowpark.files.SnowflakeFile.open", return_value=mock_file): + return infer_schema_for_xml_range( + file_path="test.xml", + row_tag=row_tag, + approx_start=0, + approx_end=kwargs.get("approx_end", len(xml_bytes)), + sampling_ratio=kwargs.get("sampling_ratio", 1.0), + ignore_namespace=kwargs.get("ignore_namespace", True), + attribute_prefix=kwargs.get("attribute_prefix", "_"), + exclude_attributes=kwargs.get("exclude_attributes", False), + value_tag=kwargs.get("value_tag", "_VALUE"), + null_value=kwargs.get("null_value", ""), + charset="utf-8", + ignore_surrounding_whitespace=kwargs.get( + "ignore_surrounding_whitespace", True + ), + ) + + +def test_infer_range_single_leaf_record(): + """Single leaf record → StructType with _VALUE.""" + schema = _mock_xml_range("hello", "item") + assert schema is not None + assert 
len(schema.fields) == 1 + assert schema.fields[0]._name == "_VALUE" + assert isinstance(schema.fields[0].datatype, StringType) + + +def test_infer_range_multiple_typed_records(): + """Multiple records with different types → merged schema.""" + xml = "1true2false3.14" + schema = _mock_xml_range(xml, "row") + fm = {f._name: f.datatype for f in schema.fields} + assert isinstance(fm["a"], LongType) + assert isinstance(fm["b"], BooleanType) + assert isinstance(fm["c"], DoubleType) + + +def test_infer_range_type_widening(): + """Long in first record, Double in second → merged to DoubleType.""" + xml = "423.14" + schema = _mock_xml_range(xml, "row") + assert isinstance(schema.fields[0].datatype, DoubleType) + + +def test_infer_range_repeated_children_become_array(): + """Repeated child tags → ArrayType.""" + xml = "ab" + schema = _mock_xml_range(xml, "row") + assert isinstance(schema.fields[0].datatype, ArrayType) + assert isinstance(schema.fields[0].datatype.element_type, StringType) + + +def test_infer_range_attributes(): + """Attributes → prefixed struct fields.""" + xml = 'Alice' + schema = _mock_xml_range(xml, "row") + fm = {f._name: f.datatype for f in schema.fields} + assert isinstance(fm["_id"], LongType) + assert isinstance(fm["_active"], BooleanType) + assert isinstance(fm["name"], StringType) + + +def test_infer_range_nested_struct(): + """Nested elements → nested StructType.""" + xml = "
NYC10001
" + schema = _mock_xml_range(xml, "row") + addr = schema.fields[0] + assert addr._name == "address" + assert isinstance(addr.datatype, StructType) + nested_fm = {f._name: f.datatype for f in addr.datatype.fields} + assert isinstance(nested_fm["city"], StringType) + assert isinstance(nested_fm["zip"], LongType) + + +def test_infer_range_self_closing_tag(): + """Self-closing tags → StructType with attributes.""" + xml = '' + schema = _mock_xml_range(xml, "row") + assert len(schema.fields) == 1 + assert schema.fields[0]._name == "_id" + assert isinstance(schema.fields[0].datatype, LongType) + + +def test_infer_range_no_records(): + xml = "data" + schema = _mock_xml_range(xml, "row") + assert schema is None + + +def test_infer_range_malformed_record_skipped(): + """Malformed XML between valid records is skipped.""" + xml = "1<<2" + schema = _mock_xml_range(xml, "row") + assert schema is not None + assert isinstance(schema.fields[0].datatype, LongType) + + +def test_infer_range_namespace_stripped(): + """Namespaces are stripped when ignore_namespace=True.""" + xml = '42' + schema = _mock_xml_range(xml, "ns:row", ignore_namespace=True) + assert schema is not None + + +def test_infer_range_exclude_attributes(): + """exclude_attributes=True omits attribute fields.""" + xml = 'Alice' + schema = _mock_xml_range(xml, "row", exclude_attributes=True) + fm = {f._name for f in schema.fields} + assert "_id" not in fm + assert "name" in fm + + +def test_infer_range_mixed_content(): + xml = "text1" + schema = _mock_xml_range(xml, "row") + fm = {f._name: f.datatype for f in schema.fields} + assert "_VALUE" in fm + assert "child" in fm + + +def test_infer_range_schema_merge_across_records(): + """Fields present only in some records are merged into the union schema.""" + xml = "1hello2true" + schema = _mock_xml_range(xml, "row") + fm = {f._name: f.datatype for f in schema.fields} + assert isinstance(fm["a"], LongType) + assert isinstance(fm["b"], StringType) + assert 
isinstance(fm["c"], BooleanType) + + +def test_infer_range_sampling_ratio_skips(): + """With a very low sampling_ratio, some records may be skipped.""" + import random + + xml = "" + for i in range(20): + xml += f"{i}" + xml += "" + xml_bytes = xml.encode("utf-8") + + random.seed(42) + mock_file = io.BytesIO(xml_bytes) + with patch("snowflake.snowpark.files.SnowflakeFile.open", return_value=mock_file): + schema = infer_schema_for_xml_range( + file_path="test.xml", + row_tag="row", + approx_start=0, + approx_end=len(xml_bytes), + sampling_ratio=0.3, + ignore_namespace=True, + attribute_prefix="_", + exclude_attributes=False, + value_tag="_VALUE", + null_value="", + charset="utf-8", + ignore_surrounding_whitespace=True, + ) + assert schema is not None + assert isinstance(schema.fields[0].datatype, LongType) + + +def test_infer_range_truncated_tag_handled(): + """A truncated opening tag that causes tag_is_self_closing to fail is skipped.""" + xml = "12" + xml_bytes = xml.encode("utf-8") + mock_file = io.BytesIO(xml_bytes) + with patch("snowflake.snowpark.files.SnowflakeFile.open", return_value=mock_file): + schema = infer_schema_for_xml_range( + file_path="test.xml", + row_tag="row", + approx_start=0, + approx_end=len(xml_bytes), + sampling_ratio=1.0, + ignore_namespace=True, + attribute_prefix="_", + exclude_attributes=False, + value_tag="_VALUE", + null_value="", + charset="utf-8", + ignore_surrounding_whitespace=True, + ) + assert schema is not None + + +def test_xml_schema_inference_process_empty_results(): + """Cases where process should yield ("",): empty/None file size, invalid worker id, no records.""" + base = ("test.xml",) + tail = (1.0, True, "_", False, "_VALUE", "", "utf-8", True) + + # empty file (size 0) + with patch( + "snowflake.snowpark._internal.xml_schema_inference.get_file_size", + return_value=0, + ): + assert list(XMLSchemaInference().process(*base, 1, "row", 0, *tail)) == [("",)] + + # None file size + with patch( + 
"snowflake.snowpark._internal.xml_schema_inference.get_file_size", + return_value=None, + ): + assert list(XMLSchemaInference().process(*base, 1, "row", 0, *tail)) == [("",)] + + # worker id >= num_workers + with patch( + "snowflake.snowpark._internal.xml_schema_inference.get_file_size", + return_value=1000, + ): + assert list(XMLSchemaInference().process(*base, 2, "row", 5, *tail)) == [("",)] + + # no matching row tags + xml_bytes = b"data" + with patch( + "snowflake.snowpark._internal.xml_schema_inference.get_file_size", + return_value=len(xml_bytes), + ), patch( + "snowflake.snowpark.files.SnowflakeFile.open", + return_value=io.BytesIO(xml_bytes), + ): + assert list(XMLSchemaInference().process(*base, 1, "row", 0, *tail)) == [("",)] + + +def test_xml_schema_inference_process_param_defaults(): + """None num_workers defaults to 1; negative worker id defaults to 0.""" + xml_bytes = b"42" + tail = (1.0, True, "_", False, "_VALUE", "", "utf-8", True) + + for num_workers, i in [(None, 0), (1, -1)]: + mock_file = io.BytesIO(xml_bytes) + with patch( + "snowflake.snowpark._internal.xml_schema_inference.get_file_size", + return_value=len(xml_bytes), + ), patch("snowflake.snowpark.files.SnowflakeFile.open", return_value=mock_file): + results = list( + XMLSchemaInference().process("test.xml", num_workers, "row", i, *tail) + ) + assert len(results) == 1 + assert "bigint" in results[0][0] + + +def test_xml_schema_inference_process_with_sampling(): + """Sampling ratio < 1.0 sets a deterministic seed.""" + xml = "" + for i in range(10): + xml += f"{i}" + xml += "" + xml_bytes = xml.encode("utf-8") + mock_file = io.BytesIO(xml_bytes) + with patch( + "snowflake.snowpark._internal.xml_schema_inference.get_file_size", + return_value=len(xml_bytes), + ), patch("snowflake.snowpark.files.SnowflakeFile.open", return_value=mock_file): + results = list( + XMLSchemaInference().process( + "test.xml", + 1, + "row", + 0, + 0.5, + True, + "_", + False, + "_VALUE", + "", + "utf-8", + True, + ) 
+ ) + assert len(results) == 1 + assert results[0][0] != "" + + +def test_xml_schema_inference_process_multi_worker(): + """Multi-worker processing partitions the file correctly.""" + xml = "123" + xml_bytes = xml.encode("utf-8") + + all_schemas = [] + for worker_id in range(2): + mock_file = io.BytesIO(xml_bytes) + with patch( + "snowflake.snowpark._internal.xml_schema_inference.get_file_size", + return_value=len(xml_bytes), + ), patch("snowflake.snowpark.files.SnowflakeFile.open", return_value=mock_file): + udtf = XMLSchemaInference() + results = list( + udtf.process( + "test.xml", + 2, + "row", + worker_id, + 1.0, + True, + "_", + False, + "_VALUE", + "", + "utf-8", + True, + ) + ) + all_schemas.append(results[0][0]) + assert any(s != "" for s in all_schemas) + + +def test_infer_element_schema_clark_ns_in_children(): + """Clark {uri}tag stripped with ignore_namespace=True""" + root = ET.Element("row") + ET.SubElement(root, "{http://example.com}child").text = "val" + assert "child" in { + f._name for f in infer_element_schema(root, ignore_namespace=True).fields + } + + +def test_infer_range_approx_end_breaks(): + """open_pos >= approx_end → immediate break.""" + assert _mock_xml_range("1", "row", approx_end=3) is None + + +def test_infer_range_tag_exception_continues(): + """Truncated tag at EOF → tag_is_self_closing raises → seek back, continue""" + assert _mock_xml_range("1long_value", + "row", + approx_end=10, + sampling_ratio=0.01, + ) + assert schema is None + + +def test_infer_range_parse_error_past_approx_end(): + """Parse error + record past approx_end → break""" + schema = _mock_xml_range( + "1", + "row", + approx_end=10, + ignore_namespace=False, + ) + assert schema is None + + +def test_infer_range_stdlib_et_fallback(): + """stdlib ET used when lxml not installed""" + import xml.etree.ElementTree as stdlib_ET + + with patch( + "snowflake.snowpark._internal.xml_schema_inference.lxml_installed", False + ), 
patch("snowflake.snowpark._internal.xml_schema_inference.ET", stdlib_ET): + assert _mock_xml_range("42", "row") is not None + + +def test_infer_range_seek_failures(): + class FailSeek(io.BytesIO): + """BytesIO that raises OSError on the Nth seek to a target position. + + skip=0 means fail on the very first seek to fail_pos. + skip=1 means allow the first seek (e.g. inside find_next_closing_tag_pos) + and fail on the second (the coverage-target seek). + """ + + def __init__(self, data, fail_pos, skip=0) -> None: + super().__init__(data) + self._fail_pos = fail_pos + self._skip = skip + self._hits = 0 + + def seek(self, pos, whence=0): + if whence == 0 and pos == self._fail_pos: + self._hits += 1 + if self._hits > self._skip: + raise OSError("forced") + return super().seek(pos, whence) + + def _end(b): + return b.find(b"
") + len(b"
") + + # successful parse → seek failure after merge + # skip=1: first seek to record_end happens inside find_next_closing_tag_pos + b = b"7" + assert ( + _mock_xml_range(b.decode(), "row", file_obj=FailSeek(b, _end(b), skip=1)) + is not None + ) + + # parse error → seek failure + b = b"1" + assert ( + _mock_xml_range( + b.decode(), + "row", + file_obj=FailSeek(b, _end(b), skip=1), + ignore_namespace=False, + ) + is None + ) + + # sampling skip → seek failure + b = b"123" + with patch("snowflake.snowpark._internal.xml_schema_inference.random") as rng: + rng.random.return_value = 0.99 + rng.seed = lambda *a, **kw: None + assert ( + _mock_xml_range( + b.decode(), + "row", + file_obj=FailSeek(b, _end(b), skip=1), + sampling_ratio=0.01, + ) + is None + ) + + # tag exception → seek failure on record_start + 1 + # skip=0: position record_start+1 is only hit in the except handler + b = b"1",)], + type_string=[StructType([StructField("a", LongType())])], + canonicalize=LongType(), + ) + assert result is None + + +def test_infer_xml_map_type_field_names_cleaned(): + """MapType fields get quoted names stripped.""" + result = _run_infer_xml( + [("x",)], + type_string=[StructType([StructField("a", LongType())])], + canonicalize=StructType( + [ + StructField( + '"m"', + MapType( + StructType([StructField('"k"', StringType())]), + StructType([StructField('"v"', LongType())]), + ), + ) + ] + ), + ) + assert result.fields[0]._name == "m" + assert result.fields[0].datatype.key_type.fields[0]._name == "k" + assert result.fields[0].datatype.value_type.fields[0]._name == "v" diff --git a/tests/utils.py b/tests/utils.py index a309ac0216..5fc5a352d7 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1717,6 +1717,26 @@ def test_null_value_xml(self): def test_dog_image(self): return os.path.join(self.resources_path, "dog.jpg") + @property + def test_dk_trace_sample_xml(self): + return os.path.join(self.resources_path, "dk_trace_sample.xml") + + @property + def test_dblp_6kb_xml(self): + 
return os.path.join(self.resources_path, "dblp_6kb.xml") + + @property + def test_books_attribute_value_xml(self): + return os.path.join(self.resources_path, "books_attribute_value.xml") + + @property + def test_xml_infer_types(self): + return os.path.join(self.resources_path, "xml_infer_types.xml") + + @property + def test_xml_infer_mixed(self): + return os.path.join(self.resources_path, "xml_infer_mixed.xml") + @property def test_books_xsd(self): return os.path.join(self.resources_path, "books.xsd")