diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
index 06296eaea2..4918652dc3 100644
--- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
+++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
@@ -1882,6 +1882,7 @@ def _create_xml_query(
lit(ignore_surrounding_whitespace),
lit(row_validation_xsd_path),
lit(schema_string),
+ lit(context._is_snowpark_connect_compatible_mode),
),
)
diff --git a/src/snowflake/snowpark/_internal/utils.py b/src/snowflake/snowpark/_internal/utils.py
index 12b2b13794..453b838a7e 100644
--- a/src/snowflake/snowpark/_internal/utils.py
+++ b/src/snowflake/snowpark/_internal/utils.py
@@ -199,6 +199,7 @@
# The following are not copy into SQL command options but client side options.
"INFER_SCHEMA",
"INFER_SCHEMA_OPTIONS",
+ "SAMPLING_RATIO",
"FORMAT_TYPE_OPTIONS",
"TARGET_COLUMNS",
"TRANSFORMATIONS",
@@ -210,6 +211,9 @@
XML_ROW_TAG_STRING = "ROWTAG"
XML_ROW_DATA_COLUMN_NAME = "ROW_DATA"
XML_READER_FILE_PATH = os.path.join(os.path.dirname(__file__), "xml_reader.py")
+XML_SCHEMA_INFERENCE_FILE_PATH = os.path.join(
+ os.path.dirname(__file__), "xml_schema_inference.py"
+)
XML_READER_API_SIGNATURE = "DataFrameReader.xml[rowTag]"
XML_READER_SQL_COMMENT = f"/* Python:snowflake.snowpark.{XML_READER_API_SIGNATURE} */"
diff --git a/src/snowflake/snowpark/_internal/xml_reader.py b/src/snowflake/snowpark/_internal/xml_reader.py
index 96cf5b723e..4a5300ff0b 100644
--- a/src/snowflake/snowpark/_internal/xml_reader.py
+++ b/src/snowflake/snowpark/_internal/xml_reader.py
@@ -2,6 +2,7 @@
# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
#
+import datetime
import os
import re
import html.entities
@@ -12,7 +13,18 @@
from snowflake.snowpark._internal.analyzer.analyzer_utils import unquote_if_quoted
from snowflake.snowpark._internal.type_utils import type_string_to_type_object
from snowflake.snowpark.files import SnowflakeFile
-from snowflake.snowpark.types import StructType, ArrayType, DataType, MapType
+from snowflake.snowpark.types import (
+ ArrayType,
+ BooleanType,
+ DataType,
+ DateType,
+ DoubleType,
+ LongType,
+ MapType,
+ StringType,
+ StructType,
+ TimestampType,
+)
# lxml is only a dev dependency so use try/except to import it if available
try:
@@ -102,13 +114,97 @@ def _restore_colons_in_template(template: Optional[dict]) -> Optional[dict]:
return restored
-def schema_string_to_result_dict_and_struct_type(schema_string: str) -> Optional[dict]:
+def schema_string_to_result_dict_and_struct_type(
+ schema_string: str,
+) -> Tuple[Optional[dict], Optional[StructType]]:
if schema_string == "":
- return None
+ return None, None
safe_string = _escape_colons_in_quotes(schema_string)
schema = type_string_to_type_object(safe_string)
template = struct_type_to_result_template(schema)
- return _restore_colons_in_template(template)
+ return _restore_colons_in_template(template), schema
+
+
+def _can_cast_to_type(value: str, target_type: DataType) -> bool:
+    if isinstance(target_type, StringType):
+        return True
+    if isinstance(target_type, LongType):
+        try:
+            int(value)
+            return True
+        except (ValueError, OverflowError):
+            return False
+    if isinstance(target_type, DoubleType):
+        try:
+            float(value)
+            return True
+        except ValueError:
+            return False
+    if isinstance(target_type, BooleanType):
+        return value.lower() in ("true", "false", "1", "0")
+    if isinstance(target_type, DateType):
+        try:
+            datetime.date.fromisoformat(value)
+            return True
+        except (ValueError, TypeError):
+            return False
+    if isinstance(target_type, TimestampType):
+        try:
+            # Normalize 'Z' suffix like inference does; Py<3.11 lacks 'Z' support
+            datetime.datetime.fromisoformat(value.replace("Z", "+00:00"))
+            return True
+        except (ValueError, TypeError):
+            return False
+    return True
+
+
+def _validate_row_for_type_mismatch(
+ row: dict,
+ schema: StructType,
+ mode: str,
+ record_str: str = "",
+ column_name_of_corrupt_record: str = "_corrupt_record",
+) -> Optional[dict]:
+ """Validate a parsed row dict against the expected schema types for mode handling:
+ - PERMISSIVE: set mismatched fields to ``None`` and store the raw XML
+ record in *column_name_of_corrupt_record* (Spark compatible).
+ - FAILFAST: raise immediately on the first mismatch.
+ - DROPMALFORMED: return ``None`` so the caller skips the row.
+
+ Only top-level primitive fields are validated because complex types
+ are kept as VARIANT and never cast downstream.
+ """
+ had_error = False
+ for field in schema.fields:
+ field_name = unquote_if_quoted(field.name)
+ if field_name not in row:
+ continue
+
+ value = row[field_name]
+ if value is None:
+ continue
+
+ # Skip complex types as these are kept as VARIANT
+ if isinstance(field.datatype, (StructType, ArrayType, MapType)):
+ continue
+
+ castable = isinstance(value, str) and _can_cast_to_type(value, field.datatype)
+ if not castable:
+ if mode == "FAILFAST":
+ raise RuntimeError(
+ f"Failed to cast value '{value}' to "
+ f"{field.datatype.simple_string()} for field "
+ f"'{field_name}'.\nXML record: {record_str}"
+ )
+ if mode == "DROPMALFORMED":
+ return None
+ # PERMISSIVE: null the bad field, continue checking remaining fields
+ row[field_name] = None
+ had_error = True
+
+ if had_error and mode == "PERMISSIVE":
+ row[column_name_of_corrupt_record] = record_str
+
+ return row
def struct_type_to_result_template(dt: DataType) -> Optional[dict]:
@@ -367,7 +463,18 @@ def get_text(element: ET.Element) -> Optional[str]:
children = list(element)
if not children and (not element.attrib or exclude_attributes):
- # it's a value element with no attributes or excluded attributes, so return the text
+ # When the schema (result_template) expects a struct for this element,
+ # wrap the text in the template so the output shape matches the schema.
+ # e.g. Some Publisher with template
+ # {"_VALUE": None, "_country": None, "_language": None} becomes
+ # {"_VALUE": "Some Publisher", "_country": None, "_language": None}
+ # instead of the raw string "Some Publisher".
+ if result_template is not None and isinstance(result_template, dict):
+ result = copy.deepcopy(result_template)
+ text = get_text(element)
+ if text is not None:
+ result[value_tag] = text
+ return result
return get_text(element)
result = copy.deepcopy(result_template) if result_template is not None else {}
@@ -445,6 +552,8 @@ def process_xml_range(
row_validation_xsd_path: str,
chunk_size: int = DEFAULT_CHUNK_SIZE,
result_template: Optional[dict] = None,
+ schema_type: Optional[StructType] = None,
+ is_snowpark_connect_compatible: bool = False,
) -> Iterator[Optional[Dict[str, Any]]]:
"""
Processes an XML file within a given approximate byte range.
@@ -475,6 +584,8 @@ def process_xml_range(
row_validation_xsd_path (str): Path to XSD file for row validation.
chunk_size (int): Size of chunks to read.
result_template(dict): a result template generate from user input schema
+ schema_type(StructType): the parsed StructType for row validation
+ is_snowpark_connect_compatible(bool): context._is_snowpark_connect_compatible_mode
Yields:
Optional[Dict[str, Any]]: Dictionary representation of the parsed XML element.
@@ -607,10 +718,22 @@ def process_xml_range(
ignore_surrounding_whitespace=ignore_surrounding_whitespace,
result_template=copy.deepcopy(result_template),
)
- if isinstance(result, dict):
- yield result
- else:
- yield {value_tag: result}
+ row = result if isinstance(result, dict) else {value_tag: result}
+
+ # Validate primitive field values against schema types in Snowpark Connect mode only
+ if schema_type is not None and is_snowpark_connect_compatible:
+ # Mode handling for type mismatch errors.
+ row = _validate_row_for_type_mismatch(
+ row,
+ schema_type,
+ mode,
+ record_str,
+ column_name_of_corrupt_record,
+ )
+
+ if row is not None:
+ yield row
+ # Mode handling for malformed XML records that fail to parse.
except ET.ParseError as e:
if mode == "PERMISSIVE":
yield {column_name_of_corrupt_record: record_str}
@@ -645,6 +768,7 @@ def process(
ignore_surrounding_whitespace: bool,
row_validation_xsd_path: str,
custom_schema: str,
+ is_snowpark_connect_compatible: bool,
):
"""
Splits the file into byte ranges—one per worker—by starting with an even
@@ -668,12 +792,15 @@ def process(
ignore_surrounding_whitespace (bool): Whether or not whitespaces surrounding values should be skipped.
row_validation_xsd_path (str): Path to XSD file for row validation.
custom_schema: User input schema for xml, must be used together with row tag.
+ is_snowpark_connect_compatible (bool): context._is_snowpark_connect_compatible_mode
"""
file_size = get_file_size(filename)
approx_chunk_size = file_size // num_workers
approx_start = approx_chunk_size * i
approx_end = approx_chunk_size * (i + 1) if i < num_workers - 1 else file_size
- result_template = schema_string_to_result_dict_and_struct_type(custom_schema)
+ result_template, schema_type = schema_string_to_result_dict_and_struct_type(
+ custom_schema
+ )
for element in process_xml_range(
filename,
row_tag,
@@ -690,5 +817,7 @@ def process(
ignore_surrounding_whitespace,
row_validation_xsd_path=row_validation_xsd_path,
result_template=result_template,
+ schema_type=schema_type,
+ is_snowpark_connect_compatible=is_snowpark_connect_compatible,
):
yield (element,)
diff --git a/src/snowflake/snowpark/_internal/xml_schema_inference.py b/src/snowflake/snowpark/_internal/xml_schema_inference.py
new file mode 100644
index 0000000000..15ee78d6b8
--- /dev/null
+++ b/src/snowflake/snowpark/_internal/xml_schema_inference.py
@@ -0,0 +1,711 @@
+#
+# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
+#
+
+import re
+import random
+from datetime import datetime, date
+from typing import Optional, Dict, List
+
+from snowflake.snowpark._internal.xml_reader import (
+ DEFAULT_CHUNK_SIZE,
+ get_file_size,
+ find_next_opening_tag_pos,
+ tag_is_self_closing,
+ find_next_closing_tag_pos,
+ strip_xml_namespaces,
+ replace_entity,
+)
+from snowflake.snowpark.files import SnowflakeFile
+from snowflake.snowpark.types import (
+ StructType,
+ ArrayType,
+ DataType,
+ NullType,
+ StringType,
+ BooleanType,
+ LongType,
+ DoubleType,
+ DecimalType,
+ DateType,
+ TimestampType,
+ StructField,
+)
+
+# lxml is only a dev dependency so use try/except to import it if available
+try:
+ import lxml.etree as ET
+
+ lxml_installed = True
+except ImportError:
+ import xml.etree.ElementTree as ET
+
+ lxml_installed = False
+
+
+# ---------------------------------------------------------------------------
+# Stage 1 – Per-record type inference
+# ---------------------------------------------------------------------------
+
+
+def _normalize_text(
+ text: Optional[str], ignore_surrounding_whitespace: bool
+) -> Optional[str]:
+ """Normalize text by stripping whitespace if configured."""
+ if text is None:
+ return None
+ return text.strip() if ignore_surrounding_whitespace else text
+
+
+def _infer_primitive_type(text: str) -> DataType:
+    """
+    Infer the DataType from a single string value with below priority order:
+    - null/empty -> NullType
+    - parseable as Long -> LongType
+    - parseable as Double -> DoubleType
+    - "true"/"false" -> BooleanType
+    - parseable as Date -> DateType (ISO format: yyyy-mm-dd)
+    - parseable as Timestamp -> TimestampType (ISO format)
+    - anything else -> StringType
+    """
+    # Long (matching Spark: infers integers as LongType directly)
+    try:
+        sign_safe = text.lstrip("+-")
+        if sign_safe and sign_safe[0].isdigit() and "." not in text:
+            val = int(text)
+            # Spark's Long is 64-bit; Python's int is unbounded
+            if -(2**63) <= val <= 2**63 - 1:
+                return LongType()
+            # Numbers outside Long range fall through to Double
+    except (ValueError, OverflowError):
+        pass
+
+    # Double
+    try:
+        sign_safe = text.lstrip("+-")
+        is_numeric_start = sign_safe and (sign_safe[0].isdigit() or sign_safe[0] == ".")
+        is_special_float = text.lower() in (
+            "nan",
+            "infinity",
+            "+infinity",
+            "-infinity",
+            "inf",
+            "+inf",
+            "-inf",
+        )
+        if is_numeric_start or is_special_float:
+            # Reject d/D/f/F suffixes ("1.0f"), but keep special floats like "inf"
+            if is_special_float or text[-1] not in ("d", "D", "f", "F"):
+                float(text)
+                return DoubleType()
+    except (ValueError, OverflowError):
+        pass
+
+    # Boolean
+    if text.lower() in ("true", "false"):
+        return BooleanType()
+
+    # Date (Spark default pattern: yyyy-MM-dd via DateFormatter)
+    try:
+        date.fromisoformat(text)
+        return DateType()
+    except (ValueError, TypeError):
+        pass
+
+    # Timestamp (Spark default: TimestampFormatter with CAST logic)
+    # Backward compatibility: Python 3.9 fromisoformat doesn't support 'Z' suffix; replace with +00:00
+    try:
+        ts_text = text.replace("Z", "+00:00") if text.endswith("Z") else text
+        datetime.fromisoformat(ts_text)
+        return TimestampType()
+    except (ValueError, TypeError):
+        pass
+
+    return StringType()
+
+
+def infer_type(
+ value: Optional[str],
+ ignore_surrounding_whitespace: bool = False,
+ null_value: str = "",
+) -> DataType:
+ """
+ Infer the DataType from a single string value.
+ Normalizes *value*, checks for null / empty / null_value, then delegates
+ to :func:`_infer_primitive_type`.
+ """
+ text = _normalize_text(value, ignore_surrounding_whitespace)
+ if text is None or text == null_value or text == "":
+ return NullType()
+ return _infer_primitive_type(text)
+
+
+def infer_element_schema(
+ element: ET.Element,
+ attribute_prefix: str = "_",
+ exclude_attributes: bool = False,
+ value_tag: str = "_VALUE",
+ null_value: str = "",
+ ignore_surrounding_whitespace: bool = False,
+ ignore_namespace: bool = False,
+ is_root: bool = True,
+) -> DataType:
+ """
+ Infer the schema (DataType) from a parsed XML Element.
+ - Elements with no children and no attributes (or excluded) -> infer from text (primitive/NullType)
+ - Elements with children -> StructType with child field types
+ - Elements with attributes -> StructType with attribute fields
+ - Mixed content (text + children/attributes) -> StructType with _VALUE field
+ - Repeated child tags -> ArrayType detection via add_or_update_type
+
+ is_root: True for the row-tag element (top-level), False for child elements.
+ Spark treats these differently: at root level, self-closing attribute-only
+ elements do NOT get _VALUE. At child level, they always get _VALUE
+ (NullType -> StringType after canonicalization).
+ """
+ children = list(element)
+ has_attributes = bool(element.attrib) and not exclude_attributes
+
+ # Case: leaf element with no attributes -> infer from text content
+ if not children and not has_attributes:
+ return infer_type(element.text, ignore_surrounding_whitespace, null_value)
+
+ # This element will become a StructType
+ # Use a dict to track field names and types (for array detection)
+ name_to_type: Dict[str, DataType] = {}
+ field_order: List[str] = []
+
+ # Process attributes first
+ if has_attributes:
+ for attr_name, attr_value in element.attrib.items():
+ prefixed_name = f"{attribute_prefix}{attr_name}"
+ attr_type = infer_type(
+ attr_value, ignore_surrounding_whitespace, null_value
+ )
+ if prefixed_name not in name_to_type:
+ field_order.append(prefixed_name)
+ add_or_update_type(name_to_type, prefixed_name, attr_type, value_tag)
+
+ if children:
+ for child in children:
+ child_tag = child.tag
+ # Ignore namespace in tag if configured (both Clark and prefix notation)
+ if ignore_namespace:
+ if "}" in child_tag:
+ child_tag = child_tag.split("}", 1)[1]
+ elif ":" in child_tag:
+ child_tag = child_tag.split(":", 1)[1]
+
+ # Check if child has attributes
+ child_has_attrs = bool(child.attrib) and not exclude_attributes
+
+ # inferField dispatch
+ child_children = list(child)
+ if not child_children and not child_has_attrs:
+ # Leaf element
+ child_type = infer_type(
+ child.text, ignore_surrounding_whitespace, null_value
+ )
+ else:
+ # Non-leaf element: recurse into inferObject
+ child_type = infer_element_schema(
+ child,
+ attribute_prefix=attribute_prefix,
+ exclude_attributes=exclude_attributes,
+ value_tag=value_tag,
+ null_value=null_value,
+ ignore_surrounding_whitespace=ignore_surrounding_whitespace,
+ ignore_namespace=ignore_namespace,
+ is_root=False,
+ )
+
+ # When child_has_attrs is True, the recursive infer_element_schema call
+ # above already processes child attributes and includes them in the
+ # returned StructType. No additional attribute processing needed here.
+ if child_tag not in name_to_type:
+ field_order.append(child_tag)
+ add_or_update_type(name_to_type, child_tag, child_type, value_tag)
+
+ # Handle mixed content: text + child elements
+ text = _normalize_text(element.text, ignore_surrounding_whitespace)
+ if text is not None and text != null_value and text.strip() != "":
+ text_type = _infer_primitive_type(text)
+ if value_tag not in name_to_type:
+ field_order.append(value_tag)
+ add_or_update_type(name_to_type, value_tag, text_type, value_tag)
+ else:
+ # No children but has attributes -> conditionally include _VALUE.
+ # [SPARK PARITY] Spark's behavior differs by element level:
+ # - Root/row-tag level: _VALUE only added when actual text exists
+ # - Child level: _VALUE always added (NullType if no text, canonicalized
+ # to StringType later). This covers cases like self-closing
+ # .
+ text = _normalize_text(element.text, ignore_surrounding_whitespace)
+ if text is not None and text != null_value and text != "":
+ text_type = _infer_primitive_type(text)
+ if value_tag not in name_to_type:
+ field_order.append(value_tag)
+ add_or_update_type(name_to_type, value_tag, text_type, value_tag)
+ elif not is_root:
+ # Child-level attribute-only element: add _VALUE as NullType
+ if value_tag not in name_to_type:
+ field_order.append(value_tag)
+ add_or_update_type(name_to_type, value_tag, NullType(), value_tag)
+
+ # Build the StructType with sorted fields to match Spark behavior.
+ result_fields = sorted(
+ (StructField(name, name_to_type[name], nullable=True) for name in field_order),
+ key=lambda f: f.name,
+ )
+ return StructType(result_fields)
+
+
+# ---------------------------------------------------------------------------
+# Stage 2 – Cross-partition merge
+# ---------------------------------------------------------------------------
+
+
+def compatible_type(t1: DataType, t2: DataType, value_tag: str = "_VALUE") -> DataType:
+ """
+ Returns the most general data type for two given data types.
+ """
+ # Same type
+ if type(t1) == type(t2):
+ if isinstance(t1, StructType):
+ return merge_struct_types(t1, t2, value_tag)
+ if isinstance(t1, ArrayType):
+ return ArrayType(
+ compatible_type(t1.element_type, t2.element_type, value_tag),
+ )
+ if isinstance(t1, DecimalType):
+ # Widen decimal
+ scale = max(t1.scale, t2.scale)
+ range_ = max(t1.precision - t1.scale, t2.precision - t2.scale)
+ if range_ + scale > 38:
+ return DoubleType()
+ return DecimalType(range_ + scale, scale)
+ return t1
+
+ # NullType + T -> T
+ if isinstance(t1, NullType):
+ return t2
+ if isinstance(t2, NullType):
+ return t1
+
+ # Numeric widening: Long < Double
+ if (isinstance(t1, LongType) and isinstance(t2, DoubleType)) or (
+ isinstance(t1, DoubleType) and isinstance(t2, LongType)
+ ):
+ return DoubleType()
+
+ # Double + Decimal -> Double
+ if (isinstance(t1, DoubleType) and isinstance(t2, DecimalType)) or (
+ isinstance(t1, DecimalType) and isinstance(t2, DoubleType)
+ ):
+ return DoubleType()
+
+ # Long + Decimal -> Decimal (widened)
+ if isinstance(t1, LongType) and isinstance(t2, DecimalType):
+ # DecimalType.forType(LongType) in Spark is Decimal(20, 0)
+ return compatible_type(DecimalType(20, 0), t2, value_tag)
+ if isinstance(t1, DecimalType) and isinstance(t2, LongType):
+ return compatible_type(t1, DecimalType(20, 0), value_tag)
+
+ # Timestamp + Date -> Timestamp
+ if (isinstance(t1, TimestampType) and isinstance(t2, DateType)) or (
+ isinstance(t1, DateType) and isinstance(t2, TimestampType)
+ ):
+ return TimestampType()
+
+ # Array + non-Array -> ArrayType(compatible)
+ if isinstance(t1, ArrayType):
+ return ArrayType(compatible_type(t1.element_type, t2, value_tag))
+ if isinstance(t2, ArrayType):
+ return ArrayType(compatible_type(t1, t2.element_type, value_tag))
+
+ # Struct with _VALUE tag + Primitive -> widen _VALUE field
+ if isinstance(t1, StructType) and _struct_has_value_tag(t1, value_tag):
+ return _merge_struct_with_primitive(t1, t2, value_tag)
+ if isinstance(t2, StructType) and _struct_has_value_tag(t2, value_tag):
+ return _merge_struct_with_primitive(t2, t1, value_tag)
+
+ # Fallback: anything else -> StringType
+ return StringType()
+
+
+def _struct_has_value_tag(st: StructType, value_tag: str) -> bool:
+ """Check if a StructType has a field with the given value_tag name."""
+ return any(f.name == value_tag for f in st.fields)
+
+
+def _merge_struct_with_primitive(
+ st: StructType, primitive: DataType, value_tag: str
+) -> StructType:
+ """
+ Merge a StructType containing a value_tag field with a primitive type.
+ The value_tag field's type is widened to be compatible with the primitive.
+ """
+ new_fields = []
+ for f in st.fields:
+ if f.name == value_tag:
+ new_type = compatible_type(f.datatype, primitive, value_tag)
+ new_fields.append(StructField(f.name, new_type, nullable=True))
+ else:
+ new_fields.append(f)
+ return StructType(new_fields)
+
+
+def merge_struct_types(
+ a: StructType, b: StructType, value_tag: str = "_VALUE"
+) -> StructType:
+ """
+ Merge two StructTypes field-by-field (case-sensitive).
+ Fields present in both are merged via compatible_type.
+ Fields present in only one are included as-is (nullable).
+
+ Uses f._name (original case from XML) rather than f.name (uppercased by ColumnIdentifier).
+ """
+ field_map: Dict[str, DataType] = {}
+ field_order: List[str] = []
+
+ for f in a.fields:
+ field_map[f._name] = f.datatype
+ field_order.append(f._name)
+
+ for f in b.fields:
+ if f._name in field_map:
+ field_map[f._name] = compatible_type(
+ field_map[f._name], f.datatype, value_tag
+ )
+ else:
+ field_map[f._name] = f.datatype
+ field_order.append(f._name)
+
+ # Sort fields by name to match Spark behavior
+ return StructType(
+ sorted(
+ (StructField(name, field_map[name], nullable=True) for name in field_order),
+ key=lambda f: f._name,
+ ),
+ )
+
+
+def add_or_update_type(
+ name_to_type: Dict[str, DataType],
+ field_name: str,
+ new_type: DataType,
+ value_tag: str = "_VALUE",
+) -> None:
+ """
+ Array detection logic:
+ - 1st occurrence of field_name -> store the type as-is
+ - 2nd occurrence -> wrap into ArrayType(compatible(old, new))
+ - Nth occurrence -> the existing type is already ArrayType, merge via compatible_type
+ """
+ if field_name in name_to_type:
+ old_type = name_to_type[field_name]
+ if not isinstance(old_type, ArrayType):
+ # 2nd occurrence: promote to ArrayType
+ name_to_type[field_name] = ArrayType(
+ compatible_type(old_type, new_type, value_tag),
+ )
+ else:
+ # Already an ArrayType: merge element types
+ name_to_type[field_name] = compatible_type(old_type, new_type, value_tag)
+ else:
+ name_to_type[field_name] = new_type
+
+
+# ---------------------------------------------------------------------------
+# Stage 3 – Canonicalization
+# ---------------------------------------------------------------------------
+
+
+def canonicalize_type(dt: DataType) -> Optional[DataType]:
+ """
+ Convert NullType to StringType and remove StructTypes with no fields:
+ - NullType -> StringType
+ - Empty StructType -> None
+ - ArrayType -> recurse on element type
+ - StructType -> recurse, remove empty-name fields
+ - Other -> kept as-is
+ """
+ if isinstance(dt, NullType):
+ return StringType()
+
+ if isinstance(dt, ArrayType):
+ canonical_element = canonicalize_type(dt.element_type)
+ if canonical_element is not None:
+ return ArrayType(canonical_element)
+ return None
+
+ if isinstance(dt, StructType):
+ canonical_fields = []
+ for f in dt.fields:
+ if f._name == "":
+ continue
+ canonical_child = canonicalize_type(f.datatype)
+ if canonical_child is not None:
+ canonical_fields.append(
+ StructField(f._name, canonical_child, nullable=True)
+ )
+ if canonical_fields:
+ return StructType(canonical_fields)
+ # empty structs should be deleted
+ return None
+
+ return dt
+
+
+# ---------------------------------------------------------------------------
+# Schema string serialization
+# ---------------------------------------------------------------------------
+
+
+def _case_preserving_simple_string(dt: DataType) -> str:
+ """
+ Serialize a DataType to a simple string, preserving original field name case.
+ Wrap field names with colons in double quotes so that the schema string parser
+ can correctly find the top-level colon separating name from type.
+ The consumer of these strings must strip the outer quotes after parsing.
+ """
+ if isinstance(dt, StructType):
+ parts = []
+ for f in dt.fields:
+ name = f._name
+ # Wrap names with colons in double quotes so the parser can
+ # distinguish the name:type separator from colons in the name.
+ if ":" in name:
+ name = f'"{name}"'
+ parts.append(f"{name}:{_case_preserving_simple_string(f.datatype)}")
+ return f"struct<{','.join(parts)}>"
+ elif isinstance(dt, ArrayType):
+ return f"array<{_case_preserving_simple_string(dt.element_type)}>"
+ else:
+ return dt.simple_string()
+
+
+# ---------------------------------------------------------------------------
+# XMLSchemaInference UDTF
+# ---------------------------------------------------------------------------
+
+
+def infer_schema_for_xml_range(
+ file_path: str,
+ row_tag: str,
+ approx_start: int,
+ approx_end: int,
+ sampling_ratio: float,
+ ignore_namespace: bool,
+ attribute_prefix: str,
+ exclude_attributes: bool,
+ value_tag: str,
+ null_value: str,
+ charset: str,
+ ignore_surrounding_whitespace: bool,
+ chunk_size: int = DEFAULT_CHUNK_SIZE,
+) -> Optional[StructType]:
+ """
+ Infer the merged XML schema for all records within a byte range.
+
+ Scans the file from *approx_start* to *approx_end*, parses each
+ XML record delimited by *row_tag*, infers per-record schemas, and
+ merges them into a single StructType.
+
+ Returns:
+ The merged StructType, or None if no records were found.
+ """
+ tag_start_1 = f"<{row_tag}>".encode()
+ tag_start_2 = f"<{row_tag} ".encode()
+ closing_tag = f"{row_tag}>".encode()
+
+ merged_schema: Optional[StructType] = None
+
+ with SnowflakeFile.open(file_path, "rb", require_scoped_url=False) as f:
+ f.seek(approx_start)
+
+ while True:
+ try:
+ open_pos = find_next_opening_tag_pos(
+ f, tag_start_1, tag_start_2, approx_end, chunk_size
+ )
+ except EOFError:
+ break
+
+ if open_pos >= approx_end:
+ break
+
+ record_start = open_pos
+ f.seek(record_start)
+
+ try:
+ is_self_close, tag_end = tag_is_self_closing(f, chunk_size)
+ if is_self_close:
+ record_end = tag_end
+ else:
+ f.seek(tag_end)
+ record_end = find_next_closing_tag_pos(f, closing_tag, chunk_size)
+ except Exception:
+ try:
+ f.seek(min(record_start + 1, approx_end))
+ except Exception:
+ break
+ continue
+
+ if sampling_ratio < 1.0 and random.random() > sampling_ratio:
+ if record_end > approx_end:
+ break
+ try:
+ f.seek(min(record_end, approx_end))
+ except Exception:
+ break
+ continue
+
+ try:
+ f.seek(record_start)
+ record_bytes = f.read(record_end - record_start)
+ record_str = record_bytes.decode(charset, errors="replace")
+ record_str = re.sub(r"&(\w+);", replace_entity, record_str)
+
+ if lxml_installed:
+ recover = bool(":" in row_tag)
+ parser = ET.XMLParser(recover=recover, ns_clean=True)
+ try:
+ element = ET.fromstring(record_str, parser)
+ except ET.XMLSyntaxError:
+ if ignore_namespace:
+ cleaned = re.sub(r"\s+(\w+):(\w+)=", r" \2=", record_str)
+ element = ET.fromstring(cleaned, parser)
+ else:
+ raise
+ else:
+ element = ET.fromstring(record_str)
+
+ if ignore_namespace:
+ element = strip_xml_namespaces(element)
+ except Exception:
+ if record_end > approx_end:
+ break
+ try:
+ f.seek(min(record_end, approx_end))
+ except Exception:
+ break
+ continue
+
+ record_schema = infer_element_schema(
+ element,
+ attribute_prefix=attribute_prefix,
+ exclude_attributes=exclude_attributes,
+ value_tag=value_tag,
+ null_value=null_value,
+ ignore_surrounding_whitespace=ignore_surrounding_whitespace,
+ ignore_namespace=ignore_namespace,
+ )
+
+ if not isinstance(record_schema, StructType):
+ record_schema = StructType(
+ [StructField(value_tag, record_schema, nullable=True)],
+ )
+
+ if merged_schema is None:
+ merged_schema = record_schema
+ else:
+ merged_schema = merge_struct_types(
+ merged_schema, record_schema, value_tag
+ )
+
+ if record_end > approx_end:
+ break
+ try:
+ f.seek(min(record_end, approx_end))
+ except Exception:
+ break
+
+ return merged_schema
+
+
+class XMLSchemaInference:
+ """
+ UDTF handler for parallelized XML schema inference.
+
+ Each worker reads its assigned byte range of the XML file, parses
+ XML records, infers a per-record schema, and merges all schemas
+ within its partition. The merged schema is yielded as a serialized
+ string (StructType.simple_string()).
+
+ The parallelization pattern mirrors XMLReader.
+ """
+
+ def process(
+ self,
+ filename: str,
+ num_workers: int,
+ row_tag: str,
+ i: int,
+ sampling_ratio: float,
+ ignore_namespace: bool,
+ attribute_prefix: str,
+ exclude_attributes: bool,
+ value_tag: str,
+ null_value: str,
+ charset: str,
+ ignore_surrounding_whitespace: bool,
+ ):
+ """
+ Infer XML schema for a byte-range partition of the file.
+
+ Args:
+ filename: Path to the XML file.
+ num_workers: Total number of workers.
+ row_tag: The tag that delimits records.
+ i: This worker's ID (0-based).
+ sampling_ratio: Fraction of records to sample (0.0-1.0).
+ ignore_namespace: Whether to strip namespaces.
+ attribute_prefix: Prefix for attribute names.
+ exclude_attributes: Whether to exclude attributes.
+ value_tag: Tag name for the value column.
+ null_value: Value to treat as null.
+ charset: Character encoding of the XML file.
+ ignore_surrounding_whitespace: Whether to strip whitespace from values.
+ """
+ file_size = get_file_size(filename)
+ if not file_size or file_size <= 0:
+ yield ("",)
+ return
+
+ if num_workers is None or num_workers <= 0:
+ num_workers = 1
+ if i is None or i < 0:
+ i = 0
+ if i >= num_workers:
+ yield ("",)
+ return
+
+ approx_chunk_size = file_size // num_workers
+ approx_start = approx_chunk_size * i
+ approx_end = approx_chunk_size * (i + 1) if i < num_workers - 1 else file_size
+
+ # Deterministic per-worker seed for Bernoulli sampling:
+ if sampling_ratio < 1.0:
+ random.seed(1 + i)
+
+ merged_schema = infer_schema_for_xml_range(
+ file_path=filename,
+ row_tag=row_tag,
+ approx_start=approx_start,
+ approx_end=approx_end,
+ sampling_ratio=sampling_ratio,
+ ignore_namespace=ignore_namespace,
+ attribute_prefix=attribute_prefix,
+ exclude_attributes=exclude_attributes,
+ value_tag=value_tag,
+ null_value=null_value,
+ charset=charset,
+ ignore_surrounding_whitespace=ignore_surrounding_whitespace,
+ )
+
+ yield (
+ _case_preserving_simple_string(merged_schema)
+ if merged_schema is not None
+ else "",
+ )
diff --git a/src/snowflake/snowpark/dataframe_reader.py b/src/snowflake/snowpark/dataframe_reader.py
index 218c086b5e..9c8d7824b4 100644
--- a/src/snowflake/snowpark/dataframe_reader.py
+++ b/src/snowflake/snowpark/dataframe_reader.py
@@ -12,6 +12,7 @@
from datetime import datetime
import snowflake.snowpark
+import snowflake.snowpark.context as context
import snowflake.snowpark._internal.proto.generated.ast_pb2 as proto
from snowflake.snowpark._internal.analyzer.analyzer_utils import (
convert_value_to_sql_option,
@@ -51,16 +52,21 @@
convert_sf_to_sp_type,
convert_sp_to_sf_type,
most_permissive_type,
+ type_string_to_type_object,
)
from snowflake.snowpark._internal.udf_utils import get_types_from_type_hints
+from snowflake.snowpark._internal.xml_reader import DEFAULT_CHUNK_SIZE
+from snowflake.snowpark._internal.xml_schema_inference import (
+ merge_struct_types,
+ canonicalize_type,
+)
from snowflake.snowpark._internal.utils import (
SNOWURL_PREFIX,
- STAGE_PREFIX,
XML_ROW_TAG_STRING,
XML_ROW_DATA_COLUMN_NAME,
XML_READER_FILE_PATH,
+ XML_SCHEMA_INFERENCE_FILE_PATH,
XML_READER_API_SIGNATURE,
- XML_READER_SQL_COMMENT,
INFER_SCHEMA_FORMAT_TYPES,
SNOWFLAKE_PATH_PREFIXES,
TempObjectType,
@@ -85,6 +91,9 @@
from snowflake.snowpark.mock._connection import MockServerConnection
from snowflake.snowpark.table import Table
from snowflake.snowpark.types import (
+ ArrayType,
+ MapType,
+ StringType,
StructType,
TimestampTimeZone,
VariantType,
@@ -118,6 +127,7 @@
"PATHGLOBFILTER": "PATTERN",
"FILENAMEPATTERN": "PATTERN",
"INFERSCHEMA": "INFER_SCHEMA",
+ "SAMPLINGRATIO": "SAMPLING_RATIO",
"SEP": "FIELD_DELIMITER",
"LINESEP": "RECORD_DELIMITER",
"QUOTE": "FIELD_OPTIONALLY_ENCLOSED_BY",
@@ -473,6 +483,7 @@ def __init__(
self._infer_schema_target_columns: Optional[List[str]] = None
self.__format: Optional[str] = None
self._data_source_format = ["jdbc", "dbapi"]
+ self._xml_inferred_schema: Optional[StructType] = None
self._ast = None
if _emit_ast:
@@ -1070,19 +1081,53 @@ def xml(self, path: str, _emit_ast: bool = True) -> DataFrame:
ast.reader.CopyFrom(self._ast)
df._ast_id = stmt.uid
- # cast to input custom schema type
- # TODO: SNOW-2923003: remove single quote after server side BCR is done
- if self._user_schema:
- cols = [
- df[single_quote(field._name)]
- .cast(field.datatype)
- .alias(quote_name_without_upper_casing(field._name))
- for field in self._user_schema.fields
- ]
- return df.select(cols)
+ # xml_reader returns VARIANT DataFrame, cast to custom or inferred schema type
+ effective_schema = self._user_schema or self._xml_inferred_schema
+ if effective_schema is not None:
+ return self._apply_xml_schema(df, effective_schema)
else:
return df
+ def _apply_xml_schema(self, df: DataFrame, schema: StructType) -> DataFrame:
+ """Apply an XML schema to a VARIANT DataFrame"""
+ cols = []
+ for field in schema.fields:
+ # TODO: SNOW-2923003: remove single quote after server side BCR is done
+ col = df[single_quote(field._name)]
+ if isinstance(field.datatype, (StructType, ArrayType, MapType)):
+ # Complex types: keep as VARIANT to prevent structured cast errors
+ cols.append(col.alias(quote_name_without_upper_casing(field._name)))
+ else:
+ # Primitive types: cast to the inferred datatype
+ cols.append(
+ col.cast(field.datatype).alias(
+ quote_name_without_upper_casing(field._name)
+ )
+ )
+
+ # In PERMISSIVE mode, append the StringType column for the corrupt record if it
+ # exists in the DataFrame emitted by the UDTF on malformed or corrupted records.
+ mode = self._cur_options.get("MODE", "PERMISSIVE").upper()
+ corrupt_col_name = self._cur_options.get(
+ "COLUMNNAMEOFCORRUPTRECORD", "_corrupt_record"
+ )
+ if mode == "PERMISSIVE":
+ df_columns = {c.strip('"').strip("'") for c in df.columns}
+ if corrupt_col_name in df_columns:
+ corrupt_ref = df[single_quote(corrupt_col_name)]
+ cols.append(
+ corrupt_ref.cast(StringType()).alias(
+ quote_name_without_upper_casing(corrupt_col_name)
+ )
+ )
+ if self._xml_inferred_schema is not None:
+ self._xml_inferred_schema.fields.append(
+ StructField(corrupt_col_name, StringType())
+ )
+
+ result = df.select(cols)
+ return result
+
@publicapi
def option(self, key: str, value: Any, _emit_ast: bool = True) -> "DataFrameReader":
"""Sets the specified option in the DataFrameReader.
@@ -1368,6 +1413,163 @@ def _get_schema_from_user_input(
read_file_transformations = [t._expression.sql for t in transformations]
return new_schema, schema_to_cast, read_file_transformations
+ def _resolve_xml_file_for_udtf(self, local_file_path: str) -> str:
+ """Return the UDTF file path, uploading to a temp stage in stored procedures."""
+ if is_in_stored_procedure(): # pragma: no cover
+ session_stage = self._session.get_session_stage()
+ self._session._conn.upload_file(
+ local_file_path,
+ session_stage,
+ compress_data=False,
+ overwrite=True,
+ skip_upload_on_content_match=True,
+ )
+ return f"{session_stage}/{os.path.basename(local_file_path)}"
+ return local_file_path
+
+ def _infer_schema_for_xml(self, path: str) -> Optional[StructType]:
+ # Register the XMLSchemaInference UDTF
+ handler_name = "XMLSchemaInference"
+ _, input_types = get_types_from_type_hints(
+ (XML_SCHEMA_INFERENCE_FILE_PATH, handler_name),
+ TempObjectType.TABLE_FUNCTION,
+ )
+
+ inference_file_path = self._resolve_xml_file_for_udtf(
+ XML_SCHEMA_INFERENCE_FILE_PATH
+ )
+
+ output_schema = StructType([StructField("SCHEMA_VALUE", StringType(), True)])
+ schema_udtf = self._session.udtf.register_from_file(
+ inference_file_path,
+ handler_name,
+ output_schema=output_schema,
+ input_types=input_types,
+ packages=["snowflake-snowpark-python", "lxml<6"],
+ replace=True,
+ _suppress_local_package_warnings=True,
+ )
+
+ # Determine number of workers
+ try:
+ file_size = int(
+ self._session.sql(f"ls {path}", _emit_ast=False).collect(
+ _emit_ast=False
+ )[0]["size"]
+ )
+ except IndexError:
+ raise ValueError(f"{path} does not exist")
+
+ num_workers = min(16, file_size // DEFAULT_CHUNK_SIZE + 1)
+
+ row_tag = self._cur_options[XML_ROW_TAG_STRING]
+ sampling_ratio = float(self._cur_options.get("SAMPLING_RATIO", 1.0))
+ if sampling_ratio <= 0:
+ raise ValueError(
+ f"samplingRatio ({sampling_ratio}) should be greater than 0"
+ )
+ ignore_namespace = self._cur_options.get("IGNORENAMESPACE", True)
+ attribute_prefix = self._cur_options.get("ATTRIBUTEPREFIX", "_")
+ exclude_attributes = self._cur_options.get("EXCLUDEATTRIBUTES", False)
+ value_tag = self._cur_options.get("VALUETAG", "_VALUE")
+ null_value = self._cur_options.get("NULL_IF", "")
+ charset = self._cur_options.get("CHARSET", "utf-8")
+ ignore_surrounding_whitespace = self._cur_options.get(
+ "IGNORESURROUNDINGWHITESPACE", False
+ )
+
+ # Create range DataFrame and apply UDTF
+ worker_col = "WORKER"
+ df = self._session.range(num_workers).to_df(worker_col)
+ df = df.select(
+ schema_udtf(
+ lit(path),
+ lit(num_workers),
+ lit(row_tag),
+ col(worker_col),
+ lit(sampling_ratio),
+ lit(ignore_namespace),
+ lit(attribute_prefix),
+ lit(exclude_attributes),
+ lit(value_tag),
+ lit(null_value),
+ lit(charset),
+ lit(ignore_surrounding_whitespace),
+ )
+ )
+
+ # Collect and merge schema results from all workers
+ results = df.collect(_emit_ast=False)
+ merged_schema: Optional[StructType] = None
+
+ for row in results:
+ schema_str = row[0]
+ if schema_str is None or schema_str == "":
+ continue
+ try:
+ partial_schema = type_string_to_type_object(schema_str)
+ except Exception:
+ continue
+ if not isinstance(partial_schema, StructType):
+ partial_schema = StructType(
+ [StructField(value_tag, partial_schema, nullable=True)]
+ )
+ if merged_schema is None:
+ merged_schema = partial_schema
+ else:
+ merged_schema = merge_struct_types(
+ merged_schema, partial_schema, value_tag
+ )
+
+ if merged_schema is None:
+ return None
+
+ # Canonicalize: NullType -> StringType, remove empty structs
+ canonical = canonicalize_type(merged_schema)
+ if canonical is None or not isinstance(canonical, StructType):
+ return None
+
+ def _strip_outer_quotes(name: str) -> str:
+ """Strip one layer of double quotes added by _case_preserving_simple_string()"""
+ if len(name) >= 2 and name[0] == '"' and name[-1] == '"':
+ return name[1:-1]
+ return name
+
+ def _clean_schema_field_names(dt):
+ """Strip outer double quotes from field names.
+
+ _case_preserving_simple_string wraps colon-containing names in
+ double quotes so the parser can split name:type correctly.
+ type_string_to_type_object keeps those quotes as part of the name.
+ Strip them here so downstream quote_name_without_upper_casing
+ doesn't produce triple-quoted identifiers.
+ """
+ if isinstance(dt, StructType):
+ return StructType(
+ [
+ StructField(
+ _strip_outer_quotes(f._name),
+ _clean_schema_field_names(f.datatype),
+ f.nullable,
+ )
+ for f in dt.fields
+ ]
+ )
+ if isinstance(dt, ArrayType):
+ return ArrayType(
+ _clean_schema_field_names(dt.element_type), dt.contains_null
+ )
+ if isinstance(dt, MapType):
+ return MapType(
+ _clean_schema_field_names(dt.key_type),
+ _clean_schema_field_names(dt.value_type),
+ dt.value_contains_null,
+ )
+ return dt
+
+ inferred_schema = _clean_schema_field_names(canonical)
+ return inferred_schema
+
def _read_semi_structured_file(self, path: str, format: str) -> DataFrame:
if isinstance(self._session._conn, MockServerConnection):
if self._session._conn.is_closed():
@@ -1422,26 +1624,26 @@ def _read_semi_structured_file(self, path: str, format: str) -> DataFrame:
metadata_project, metadata_schema = self._get_metadata_project_and_schema()
+ xml_inferred_schema = None
if format == "XML" and XML_ROW_TAG_STRING in self._cur_options:
- if is_in_stored_procedure(): # pragma: no cover
- # create a temp stage for udtf import files
- # we have to use "temp" object instead of "scoped temp" object in stored procedure
- # so we need to upload the file to the temp stage first to use register_from_file
- temp_stage = random_name_for_temp_object(TempObjectType.STAGE)
- sql_create_temp_stage = f"create temp stage if not exists {temp_stage} {XML_READER_SQL_COMMENT}"
- self._session.sql(sql_create_temp_stage, _emit_ast=False).collect(
- _emit_ast=False
- )
- self._session._conn.upload_file(
- XML_READER_FILE_PATH,
- temp_stage,
- compress_data=False,
- overwrite=True,
- skip_upload_on_content_match=True,
- )
- python_file_path = f"{STAGE_PREFIX}{temp_stage}/{os.path.basename(XML_READER_FILE_PATH)}"
- else:
- python_file_path = XML_READER_FILE_PATH
+ python_file_path = self._resolve_xml_file_for_udtf(XML_READER_FILE_PATH)
+ if (
+ context._is_snowpark_connect_compatible_mode
+ and not self._user_schema
+ and self._cur_options.get("INFER_SCHEMA", True)
+ ):
+ xml_inferred_schema = self._infer_schema_for_xml(path)
+ if xml_inferred_schema is not None:
+ self._xml_inferred_schema = xml_inferred_schema
+ schema = [
+ Attribute(
+ quote_name_without_upper_casing(f._name),
+ f.datatype,
+ f.nullable,
+ )
+ for f in xml_inferred_schema.fields
+ ]
+ use_user_schema = True
# create udtf
handler_name = "XMLReader"
@@ -1508,7 +1710,10 @@ def _read_semi_structured_file(self, path: str, format: str) -> DataFrame:
set_api_call_source(df, XML_READER_API_SIGNATURE)
if self._cur_options.get("CACHERESULT", True):
df = df.cache_result()
- df._all_variant_cols = True
+ # When schema is inferred or user-provided, columns are typed
+ # (not all VARIANT), so don't enable the all-variant dot-notation mode.
+ if xml_inferred_schema is None and not self._user_schema:
+ df._all_variant_cols = True
else:
set_api_call_source(df, f"DataFrameReader.{format.lower()}")
return df
diff --git a/tests/integ/test_xml_infer_schema.py b/tests/integ/test_xml_infer_schema.py
new file mode 100644
index 0000000000..1c6f9e0fc5
--- /dev/null
+++ b/tests/integ/test_xml_infer_schema.py
@@ -0,0 +1,1189 @@
+#
+# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
+#
+
+import datetime
+import json
+import os
+import pytest
+
+from snowflake.snowpark import Row
+from snowflake.snowpark.functions import col
+from snowflake.snowpark.types import (
+ StructType,
+ StructField,
+ StringType,
+ LongType,
+ DoubleType,
+ DateType,
+ BooleanType,
+ TimestampType,
+ ArrayType,
+ VariantType,
+)
+import snowflake.snowpark.context as context
+from tests.utils import TestFiles, Utils
+
+
+pytestmark = [
+ pytest.mark.skipif(
+ "config.getoption('local_testing_mode', default=False)",
+ reason="xml not supported in local testing mode",
+ ),
+ pytest.mark.udf,
+]
+
+tmp_stage_name = Utils.random_stage_name()
+
+# Resource XML file names (uploaded from tests/resources/)
+RES_BOOKS_XML = "books.xml"
+RES_BOOKS2_XML = "books2.xml"
+RES_DK_TRACE_XML = "dk_trace_sample.xml"
+RES_DBLP_XML = "dblp_6kb.xml"
+RES_BOOKS_ATTR_VAL_XML = "books_attribute_value.xml"
+
+# Inline XML strings uploaded to stage as files for testing
+# Each covers specific inference scenarios without needing separate resource files.
+
+# Primitive type inference: bool, double, long, string, timestamp
+PRIMITIVES_XML = """\
+
+
+
+ true
+ +10.1
+ -10
+ 10
+ 8E9D
+ 2015-01-01 00:00:00
+
+
+"""
+
+# Date and timestamp inference
+DATE_TIME_XML = """\
+
+
+
+ John Smith
+ 2021-02-01
+ 02-01-2021
+
+
+"""
+
+TIMESTAMP_XML = """\
+
+
+
+ John Smith
+
+ not-a-timestamp
+
+
+"""
+
+# Root-level _VALUE: attribute-only row element with text content
+ROOT_VALUE_XML = """\
+
+
+ value1
+ value2
+ value3
+
+"""
+
+# Root-level _VALUE with child elements
+ROOT_VALUE_MIXED_XML = """\
+
+
+ value1
+ value2
+ 45
+ 67
+
+
+"""
+
+# Nested object: struct child element
+NESTED_OBJECT_XML = """\
+
+
+
+ Book A
+ 44.95
+
+ Acme
+ 2020
+
+
+
+ Book B
+ 29.99
+
+ Beta
+ 2021
+
+
+
+"""
+
+# Nested array: repeated sibling elements
+NESTED_ARRAY_XML = """\
+
+
+
+ Book A
+ fiction
+ classic
+
+
+ Book B
+ science
+
+
+"""
+
+# Element with attribute on leaf: 5.95 → struct(_VALUE, _unit)
+ATTR_ON_LEAF_XML = """\
+
+
+
+ Book A
+ 44.95
+
+
+ Book B
+ 29.99
+
+
+ Book C
+ 15.00
+
+
+"""
+
+# Missing nested struct: some rows have nested struct, some don't
+MISSING_NESTED_XML = """\
+
+
+ -
+ Item A
+
+ red
+
+
+ -
+ Item B
+
+
+"""
+
+# Unbalanced types: same field has different types across rows
+UNBALANCED_TYPES_XML = """\
+
+
+
+ 123
+ hello
+
+
+ 45.6
+ world
+
+
+"""
+
+# Mixed content: text + child elements
+MIXED_CONTENT_XML = """\
+
+
+ -
+ Simple text
+
+ 1
+
+
+ -
+ Has mixed content
+
+ 2
+
+
+
+"""
+
+# ExcludeAttributes with inferSchema
+EXCLUDE_ATTRS_XML = """\
+
+
+ -
+ Widget
+ 9.99
+
+ -
+ Gadget
+ 19.99
+
+
+"""
+
+# Big integer inference
+BIG_INT_XML = """\
+
+
+
+ 42
+ 92233720368547758070
+ 3.14
+
+
+"""
+
+# Nested element same name as parent
+PARENT_NAME_COLLISION_XML = """\
+
+
+
+
+ Child 1.1
+
+ Child 1.2
+
+
+
+ Child 2.1
+
+ Child 2.2
+
+
+"""
+
+# Complicated nested: struct containing array of structs with attributes
+COMPLICATED_NESTED_XML = """\
+
+
+
+ Author A
+
+ 1
+ Fiction
+
+
+
+ 2020
+ 1
+
+
+ 2020
+ 6
+
+
+
+
+ Author B
+
+ 2
+ Science
+
+
+
+ 2021
+ 3
+
+
+
+
+"""
+
+# Sampling heterogeneous: 20 rows where the first 15 have value as integer,
+# but rows 16-20 have value as a float string. With a low sampling ratio
+# deterministic seed may only see the first chunk → LongType instead of DoubleType.
+SAMPLING_HETERO_XML = """\
+
+
+ r011
+ r022
+ r033
+ r044
+ r055
+ r066
+ r077
+ r088
+ r099
+ r1010
+ r1111
+ r1212
+ r1313
+ r1414
+ r1515
+ r1616.1
+ r1717.2
+ r1818.3
+ r1919.4
+ r2020.5
+
+"""
+
+# Processing instruction XML
+PROCESSING_INSTRUCTION_XML = """\
+
+
+
+
+ 1
+ hello
+
+
+"""
+
+
def _upload_xml_string(session, stage, filename, xml_content):
    """Materialize *xml_content* as a uniquely-named temp file and upload it.

    The temp file name is derived from *filename*; the basename of the
    uploaded temp file is returned.
    """
    import tempfile

    fd, tmp_path = tempfile.mkstemp(
        suffix=".xml", prefix=filename.replace(".xml", "_")
    )
    with os.fdopen(fd, "w") as handle:
        handle.write(xml_content)
    try:
        Utils.upload_to_stage(session, stage, tmp_path, compress=False)
    finally:
        os.unlink(tmp_path)
    return os.path.basename(tmp_path)
+
+
+# Map of logical name -> (xml_content, staged_filename) populated during setup
+_staged_files = {}
+
+
@pytest.fixture(autouse=True)
def enable_scos_compatible_mode():
    """XML inferSchema is gated behind _is_snowpark_connect_compatible_mode.
    Enable it for every test in this module.

    The restore is wrapped in try/finally so the module-global flag is put
    back even if an exception is thrown into the fixture during teardown —
    otherwise a single failure could leak the flag into unrelated test
    modules."""
    original = context._is_snowpark_connect_compatible_mode
    context._is_snowpark_connect_compatible_mode = True
    try:
        yield
    finally:
        context._is_snowpark_connect_compatible_mode = original
+
+
@pytest.fixture(scope="module", autouse=True)
def setup(session, resources_path, local_testing_mode):
    """Create the temp stage and upload all resource and inline XML files."""
    test_files = TestFiles(resources_path)
    if not local_testing_mode:
        Utils.create_stage(session, tmp_stage_name, is_temporary=True)

    # Upload resource XML files: one loop instead of seven copy-pasted calls.
    for resource_path in (
        test_files.test_xml_infer_types,
        test_files.test_xml_infer_mixed,
        test_files.test_books_xml,
        test_files.test_books2_xml,
        test_files.test_dk_trace_sample_xml,
        test_files.test_dblp_6kb_xml,
        test_files.test_books_attribute_value_xml,
    ):
        Utils.upload_to_stage(
            session, "@" + tmp_stage_name, resource_path, compress=False
        )

    # Upload inline XML strings as files
    inline_xmls = {
        "primitives": PRIMITIVES_XML,
        "date_time": DATE_TIME_XML,
        "timestamp": TIMESTAMP_XML,
        "root_value": ROOT_VALUE_XML,
        "root_value_mixed": ROOT_VALUE_MIXED_XML,
        "nested_object": NESTED_OBJECT_XML,
        "nested_array": NESTED_ARRAY_XML,
        "attr_on_leaf": ATTR_ON_LEAF_XML,
        "missing_nested": MISSING_NESTED_XML,
        "unbalanced_types": UNBALANCED_TYPES_XML,
        "mixed_content": MIXED_CONTENT_XML,
        "exclude_attrs": EXCLUDE_ATTRS_XML,
        "big_int": BIG_INT_XML,
        "parent_collision": PARENT_NAME_COLLISION_XML,
        "complicated_nested": COMPLICATED_NESTED_XML,
        "processing_instr": PROCESSING_INSTRUCTION_XML,
        "sampling_hetero": SAMPLING_HETERO_XML,
    }
    for name, xml_str in inline_xmls.items():
        _staged_files[name] = _upload_xml_string(
            session, "@" + tmp_stage_name, f"{name}.xml", xml_str
        )

    yield
    if not local_testing_mode:
        session.sql(f"DROP STAGE IF EXISTS {tmp_stage_name}").collect()
+
+
+def _schema_types(df):
+ """Return {lowercase_field_name: datatype_class} from df.schema."""
+ return {f.name.strip('"').lower(): type(f.datatype) for f in df.schema.fields}
+
+
+# ─── Primitive type inference ───────────────────────────────────────────────
+
+
def test_infer_primitives(session):
    """Bool, double, long, string, timestamp inferred from inline XML."""
    df = session.read.option("rowTag", "ROW").xml(
        f"@{tmp_stage_name}/{_staged_files['primitives']}"
    )
    assert df._all_variant_cols is False
    inferred = _schema_types(df)
    expected = {
        "bool1": BooleanType,
        "double1": DoubleType,
        "long1": LongType,
        "long2": LongType,
        "string1": StringType,
        "ts1": TimestampType,
    }
    for field_name, expected_type in expected.items():
        assert inferred[field_name] == expected_type
    rows = df.collect()
    assert len(rows) == 1
    (row,) = rows
    assert row["bool1"] is True
    assert row["double1"] == 10.1
    assert row["long1"] == -10
    assert row["string1"] == "8E9D"
+
+
def test_infer_date(session):
    """ISO date string inferred as DateType, non-ISO stays StringType."""
    df = session.read.option("rowTag", "book").xml(
        f"@{tmp_stage_name}/{_staged_files['date_time']}"
    )
    assert df._all_variant_cols is False
    inferred = _schema_types(df)
    assert inferred["author"] == StringType
    assert inferred["date"] == DateType
    # Non-ISO format is left as a plain string.
    assert inferred["date2"] == StringType
    first_row = df.collect()[0]
    assert first_row["date"] == datetime.date(2021, 2, 1)


def test_infer_timestamp(session):
    """ISO timestamp inferred as TimestampType, non-timestamp stays StringType."""
    df = session.read.option("rowTag", "book").xml(
        f"@{tmp_stage_name}/{_staged_files['timestamp']}"
    )
    assert df._all_variant_cols is False
    inferred = _schema_types(df)
    assert inferred["time"] == TimestampType
    assert inferred["time2"] == StringType
+
+
+# ─── Root-level _VALUE tag ──────────────────────────────────────────────────
+
+
def test_infer_root_value_attrs_only(session):
    """Attribute-only row elements whose body is bare text infer a _VALUE
    column plus one column per attribute (prefixed with '_')."""
    df = session.read.option("rowTag", "ROW").xml(
        f"@{tmp_stage_name}/{_staged_files['root_value']}"
    )
    assert df._all_variant_cols is False
    types = _schema_types(df)
    # The element's text content surfaces under the value tag column.
    assert "_value" in types or "_VALUE" in types.keys()
    result = df.collect()
    assert len(result) == 3


def test_infer_root_value_with_child_elements(session):
    """Row elements mixing bare text with child elements infer a _VALUE
    column alongside the child-element columns (e.g. 'tag')."""
    df = session.read.option("rowTag", "ROW").xml(
        f"@{tmp_stage_name}/{_staged_files['root_value_mixed']}"
    )
    assert df._all_variant_cols is False
    types = _schema_types(df)
    assert "tag" in types
    result = df.collect()
    assert len(result) == 5
+
+
+# ─── Nested structures ─────────────────────────────────────────────────────
+
+
def test_infer_nested_object(session):
    """Child element becomes nested struct: info.publisher, info.year."""
    df = session.read.option("rowTag", "book").xml(
        f"@{tmp_stage_name}/{_staged_files['nested_object']}"
    )
    assert df._all_variant_cols is False
    types = _schema_types(df)
    # info should be VariantType (Snowpark represents nested structs as Variant)
    assert types["info"] == VariantType
    assert types["price"] == DoubleType
    result = df.collect()
    assert len(result) == 2
    # Nested struct content round-trips as a JSON string in the Variant column.
    info = json.loads(result[0]["info"])
    assert info["publisher"] in ["Acme", "Beta"]


def test_infer_nested_array(session):
    """Repeated sibling elements → ArrayType or VariantType."""
    df = session.read.option("rowTag", "book").xml(
        f"@{tmp_stage_name}/{_staged_files['nested_array']}"
    )
    assert df._all_variant_cols is False
    types = _schema_types(df)
    # Repeated tag elements: Variant wrapping an array
    assert types["tag"] in (VariantType, ArrayType)
    result = df.collect()
    assert len(result) == 2
    # Book with 2 tags should be an array; _id may be Long or String
    book1 = [r for r in result if str(r["_id"]).strip('"') == "1"][0]
    tag_val = (
        json.loads(book1["tag"]) if isinstance(book1["tag"], str) else book1["tag"]
    )
    assert isinstance(tag_val, list)
    assert len(tag_val) == 2


def test_infer_complicated_nested(session):
    """Struct with sub-struct and array of structs with attributes."""
    df = session.read.option("rowTag", "book").xml(
        f"@{tmp_stage_name}/{_staged_files['complicated_nested']}"
    )
    assert df._all_variant_cols is False
    types = _schema_types(df)
    assert types["genre"] == VariantType
    assert types["dates"] == VariantType
    result = df.collect()
    assert len(result) == 2
    # Verify genre nested content
    genre1 = json.loads(result[0]["genre"])
    assert "genreid" in genre1 or "name" in genre1
    # Verify dates nested array
    dates1 = json.loads(result[0]["dates"])
    assert "date" in dates1
+
+
+# ─── Attribute on leaf element (_VALUE pattern) ────────────────────────────
+
+
def test_infer_attribute_on_leaf(session):
    """A leaf element carrying an XML attribute infers a struct: the text
    content under _VALUE plus the attribute column (e.g. _unit)."""
    df = session.read.option("rowTag", "book").xml(
        f"@{tmp_stage_name}/{_staged_files['attr_on_leaf']}"
    )
    assert df._all_variant_cols is False
    types = _schema_types(df)
    # price has attr on some rows → becomes VariantType (struct with _VALUE/_unit)
    assert types["price"] == VariantType
    result = df.collect()
    assert len(result) == 3
    # Book 1 has unit="$"; _id may be Long or String
    book1 = [r for r in result if str(r["_id"]).strip('"') == "1"][0]
    price_data = json.loads(book1["price"])
    assert "_VALUE" in price_data
    assert price_data["_unit"] == "$"


# ─── Missing nested struct ─────────────────────────────────────────────────


def test_infer_missing_nested_struct(session):
    """Row missing a nested struct field gets null/empty, not crash."""
    df = session.read.option("rowTag", "item").xml(
        f"@{tmp_stage_name}/{_staged_files['missing_nested']}"
    )
    assert df._all_variant_cols is False
    types = _schema_types(df)
    assert "details" in types
    result = df.collect()
    assert len(result) == 2
    # Item B has no details
    item_b = [r for r in result if r["name"] == "Item B"][0]
    details_val = item_b["details"]
    # Either null or struct with null fields
    if details_val is not None:
        parsed = json.loads(details_val)
        assert parsed.get("color") is None
+
+
+# ─── Unbalanced types (type widening across rows) ──────────────────────────
+
+
def test_infer_unbalanced_types(session):
    """Field1 is 123 in one row and 45.6 in another → widened to DoubleType."""
    df = session.read.option("rowTag", "ROW").xml(
        f"@{tmp_stage_name}/{_staged_files['unbalanced_types']}"
    )
    assert df._all_variant_cols is False
    types = _schema_types(df)
    assert types["field1"] == DoubleType  # Long widened to Double
    assert types["field2"] == StringType
    result = df.collect()
    assert len(result) == 2


# ─── excludeAttributes with inferSchema ─────────────────────────────────────


def test_infer_exclude_attributes(session):
    """excludeAttributes=true removes id/category from inferred schema."""
    df = (
        session.read.option("rowTag", "item")
        .option("excludeAttributes", True)
        .xml(f"@{tmp_stage_name}/{_staged_files['exclude_attrs']}")
    )
    assert df._all_variant_cols is False
    types = _schema_types(df)
    # Attribute-derived columns (prefixed with '_') must be absent.
    assert "_id" not in types
    assert "_category" not in types
    assert "name" in types
    assert "price" in types
    result = df.collect()
    assert len(result) == 2

    # Without excludeAttributes, attributes should appear
    df2 = session.read.option("rowTag", "item").xml(
        f"@{tmp_stage_name}/{_staged_files['exclude_attrs']}"
    )
    types2 = _schema_types(df2)
    assert "_id" in types2
    assert "_category" in types2
+
+
+# ─── Big integer handling ───────────────────────────────────────────────────
+
+
def test_infer_big_integer(session):
    """Very large integer doesn't fit in Long → inferred as Double or String."""
    df = session.read.option("rowTag", "ROW").xml(
        f"@{tmp_stage_name}/{_staged_files['big_int']}"
    )
    assert df._all_variant_cols is False
    types = _schema_types(df)
    assert types["small_int"] == LongType
    # Big int overflows Long → should become Double or String
    assert types["big_int"] in (DoubleType, StringType)
    assert types["normal_double"] == DoubleType


# ─── Nested element same name as parent ─────────────────────────────────────


def test_infer_parent_name_collision(session):
    """Nested elements that reuse the parent's tag name.

    Known limitation: when rowTag matches a child element name, byte-scanning
    picks up inner tags as row boundaries, causing schema inference to fall back
    to all-variant or produce extra rows. Verify it doesn't crash."""
    df = session.read.option("rowTag", "parent").xml(
        f"@{tmp_stage_name}/{_staged_files['parent_collision']}"
    )
    result = df.collect()
    # May produce 2 or 4 rows depending on inner matching
    assert len(result) >= 2


# ─── Processing instruction ────────────────────────────────────────────────


def test_infer_with_processing_instruction(session):
    """XML containing a processing instruction should parse without error."""
    df = session.read.option("rowTag", "foo").xml(
        f"@{tmp_stage_name}/{_staged_files['processing_instr']}"
    )
    assert df._all_variant_cols is False
    result = df.collect()
    assert len(result) == 1


# ─── Mixed content (text + child elements) ──────────────────────────────────


def test_infer_mixed_content(session):
    """Text mixed with child elements: desc field has both text and child tags."""
    df = session.read.option("rowTag", "item").xml(
        f"@{tmp_stage_name}/{_staged_files['mixed_content']}"
    )
    assert df._all_variant_cols is False
    result = df.collect()
    assert len(result) == 2
+
+
+# ─── inferSchema=false keeps all strings ────────────────────────────────────
+
+
@pytest.mark.parametrize(
    "staged_key, row_tag, expected_count",
    [
        ("primitives", "ROW", 1),
        ("nested_object", "book", 2),
        ("unbalanced_types", "ROW", 2),
    ],
)
def test_infer_schema_false_all_strings(session, staged_key, row_tag, expected_count):
    """inferSchema=false → all fields are variant/string columns."""
    reader = session.read.option("rowTag", row_tag).option("inferSchema", False)
    df = reader.xml(f"@{tmp_stage_name}/{_staged_files[staged_key]}")
    # With inference disabled the reader keeps every column as VARIANT.
    assert df._all_variant_cols is True
    assert len(df.collect()) == expected_count
+
+
+# ─── Resource file: xml_infer_types.xml ─────────────────────────────────────
+
+
def test_infer_types_resource_file(session):
    """Comprehensive resource file: primitives, nested, array, attributes, mixed."""
    df = session.read.option("rowTag", "item").xml(
        f"@{tmp_stage_name}/xml_infer_types.xml"
    )
    assert df._all_variant_cols is False
    types = _schema_types(df)

    # Flat fields
    assert types["_id"] in (LongType, StringType)
    assert types["name"] == StringType
    assert types["quantity"] == LongType
    assert types["in_stock"] == BooleanType
    assert types["release_date"] == DateType
    assert types["last_updated"] == TimestampType

    # price has attribute on some rows → complex type
    assert types["price"] == VariantType

    # tags → nested with array → Variant
    assert types["tags"] == VariantType

    # specs → nested struct → Variant
    assert types["specs"] == VariantType

    # rating has attribute on some rows → complex
    assert types["rating"] == VariantType

    result = df.collect()
    assert len(result) == 3

    # Verify nested data
    item1 = [r for r in result if str(r["_id"]) in ("1", '"1"')][0]
    specs = json.loads(item1["specs"])
    assert "weight" in specs or "dimensions" in specs
    tags = json.loads(item1["tags"])
    assert "tag" in tags


def test_infer_mixed_resource_file(session):
    """Resource file with name collisions, sparse fields, mixed content."""
    df = session.read.option("rowTag", "record").xml(
        f"@{tmp_stage_name}/xml_infer_mixed.xml"
    )
    assert df._all_variant_cols is False
    types = _schema_types(df)
    assert "_type" in types
    assert "parent" in types
    assert "age" in types
    assert "value" in types
    result = df.collect()
    assert len(result) == 3

    # Verify opt_struct is null for sparse record
    sparse = [r for r in result if r["_type"] == "sparse"][0]
    try:
        opt_val = sparse["opt_struct"]
    except (KeyError, IndexError):
        # Column may be absent entirely when no sampled row had it.
        opt_val = None
    if opt_val is not None:
        parsed = json.loads(opt_val)
        assert parsed["a"] is None
+
+
def test_read_xml_infer_schema_books_flat(session):
    """Infer schema on books.xml: all flat primitives, 12 rows, 7 columns."""
    df = session.read.option("rowTag", "book").xml(f"@{tmp_stage_name}/{RES_BOOKS_XML}")
    assert df._all_variant_cols is False
    # Expected lower-cased field name -> datatype class.
    expected = {
        "_id": StringType,
        "author": StringType,
        "description": StringType,
        "genre": StringType,
        "price": DoubleType,
        "publish_date": DateType,
        "title": StringType,
    }
    actual = {f.name.strip('"').lower(): type(f.datatype) for f in df.schema.fields}
    assert actual == expected
    rows = df.collect()
    assert len(rows) == 12
    assert len(rows[0]) == 7
    Utils.check_answer(
        df.filter(col('"_id"') == "bk101").select(
            col('"price"'), col('"publish_date"'), col('"author"')
        ),
        [Row(44.95, datetime.date(2000, 10, 1), "Gambardella, Matthew")],
    )
+
+
def test_read_xml_infer_schema_books2_nested(session):
    """Infer schema on books2.xml: complex types become VariantType, verify nested data."""
    expected_schema = StructType(
        [
            StructField("_id", LongType()),
            StructField("author", StringType()),
            StructField("editions", VariantType()),
            StructField("price", DoubleType()),
            StructField("reviews", VariantType()),
            StructField("title", StringType()),
        ]
    )
    df = session.read.option("rowTag", "book").xml(
        f"@{tmp_stage_name}/{RES_BOOKS2_XML}"
    )
    assert df._all_variant_cols is False
    # Compare name->type maps so quoting/casing differences don't matter.
    actual = {f.name.strip('"').lower(): type(f.datatype) for f in df.schema.fields}
    expected = {f.name.lower(): type(f.datatype) for f in expected_schema.fields}
    assert actual == expected
    result = df.collect()
    assert len(result) == 2
    assert len(result[0]) == 6

    book1 = df.filter(col('"_id"') == 1).collect()
    assert len(book1) == 1
    reviews = json.loads(book1[0]["reviews"])
    assert isinstance(reviews["review"], list)
    assert len(reviews["review"]) == 2
    assert reviews["review"][0]["user"] == "tech_guru_87"
    editions = json.loads(book1[0]["editions"])
    assert isinstance(editions["edition"], list)
    assert len(editions["edition"]) == 2

    book2 = df.filter(col('"_id"') == 2).collect()
    assert len(book2) == 1
    # A single repeated element may surface as a dict instead of a list.
    review_data = json.loads(book2[0]["reviews"])["review"]
    if isinstance(review_data, dict):
        assert review_data["user"] == "xml_master"
    else:
        assert review_data[0]["user"] == "xml_master"
    assert book1[0]["price"] == 29.99
    assert book2[0]["price"] == 35.50
+
+def test_read_xml_namespace_infer_schema(session):
+ """Infer schema on dk_trace_sample.xml with namespace-prefixed eqTrace:event rowTag."""
+ expected_schema = StructType(
+ [
+ StructField("eqTrace:date-time", TimestampType()),
+ StructField("eqTrace:equipment-cycle-status-changed", VariantType()),
+ StructField("eqTrace:event-descriptor-list", VariantType()),
+ StructField("eqTrace:event-id", StringType()),
+ StructField("eqTrace:event-name", StringType()),
+ StructField("eqTrace:event-version-number", DoubleType()),
+ StructField("eqTrace:interchanged", VariantType()),
+ StructField("eqTrace:is-planned", BooleanType()),
+ StructField("eqTrace:is-synthetic", BooleanType()),
+ StructField("eqTrace:location-id", VariantType()),
+ StructField("eqTrace:placement-at-industry-planned", VariantType()),
+ StructField("eqTrace:publisher-identification", StringType()),
+ StructField("eqTrace:reporting-detail", VariantType()),
+ StructField("eqTrace:tag-name-list", VariantType()),
+ StructField("eqTrace:waybill-applied", VariantType()),
+ ]
+ )
+ df = (
+ session.read.option("rowTag", "eqTrace:event")
+ .option("ignoreNamespace", False)
+ .xml(f"@{tmp_stage_name}/{RES_DK_TRACE_XML}")
+ )
+ assert df._all_variant_cols is False
+ actual = {f.name.strip('"'): type(f.datatype) for f in df.schema.fields}
+ expected = {f.name.strip('"'): type(f.datatype) for f in expected_schema.fields}
+ assert actual == expected
+ result = df.collect()
+ assert len(result) == 5
+
+ first_event = df.filter(
+ col('"eqTrace:event-id"') == "f0e765d9-599b-46bf-9aef-bd33e0c2183f"
+ ).collect()
+ assert len(first_event) == 1
+ Utils.check_answer(
+ df.filter(
+ col('"eqTrace:event-id"') == "f0e765d9-599b-46bf-9aef-bd33e0c2183f"
+ ).select(col('"eqTrace:event-name"'), col('"eqTrace:is-planned"')),
+ [Row("equipment/equipment-placement-at-industry-planned", True)],
+ )
+
+ second_event = df.filter(
+ col('"eqTrace:event-id"') == "dd9a4616-e41c-4571-9bda-9a5506a2b78d"
+ ).collect()
+ assert len(second_event) == 1
+ assert second_event[0]["eqTrace:is-planned"] is False
+
+
+def test_read_xml_dblp_mastersthesis_infer_schema(session):
+ """Infer schema on dblp_6kb.xml mastersthesis: verify schema, nested ee data, printSchema."""
+ expected_schema = StructType(
+ [
+ StructField("_key", StringType()),
+ StructField("_mdate", DateType()),
+ StructField("author", StringType()),
+ StructField("ee", VariantType()),
+ StructField("note", StringType()),
+ StructField("school", StringType()),
+ StructField("title", StringType()),
+ StructField("year", LongType()),
+ ]
+ )
+ df = session.read.option("rowTag", "mastersthesis").xml(
+ f"@{tmp_stage_name}/{RES_DBLP_XML}"
+ )
+ assert df._all_variant_cols is False
+ actual = {f.name.strip('"').lower(): type(f.datatype) for f in df.schema.fields}
+ expected = {f.name.lower(): type(f.datatype) for f in expected_schema.fields}
+ assert actual == expected
+ result = df.collect()
+ assert len(result) == 6
+
+ hoffmann = df.filter(col('"_key"') == "ms/Hoffmann2008").collect()
+ assert len(hoffmann) == 1
+ ee = json.loads(hoffmann[0]["ee"])
+ assert ee["_type"] == "oa"
+ assert "dblp.uni-trier.de" in ee["_VALUE"]
+
+ vollmer = df.filter(col('"_key"') == "ms/Vollmer2006").collect()
+ assert len(vollmer) == 1
+ ee_v = json.loads(vollmer[0]["ee"])
+ assert ee_v["_type"] is None
+ assert ee_v["_VALUE"] is not None
+
+ brown = df.filter(col('"_key"') == "ms/Brown92").collect()
+ assert len(brown) == 1
+ assert json.loads(brown[0]["ee"]) == {"_VALUE": None, "_type": None}
+
+ schema_str = df._format_schema()
+ assert "LongType" in schema_str
+ assert "StringType" in schema_str
+ assert "VariantType" in schema_str
+
+
+def test_read_xml_dblp_incollection_infer_schema(session):
+ """Infer schema on dblp_6kb.xml incollection: author array, ee struct, verify data."""
+ expected_schema = StructType(
+ [
+ StructField("_corrupt_record", StringType()),
+ StructField("_key", StringType()),
+ StructField("_mdate", DateType()),
+ StructField("author", VariantType()),
+ StructField("booktitle", StringType()),
+ StructField("crossref", StringType()),
+ StructField("ee", VariantType()),
+ StructField("pages", StringType()),
+ StructField("title", StringType()),
+ StructField("url", StringType()),
+ StructField("year", LongType()),
+ ]
+ )
+ df = session.read.option("rowTag", "incollection").xml(
+ f"@{tmp_stage_name}/{RES_DBLP_XML}"
+ )
+ assert df._all_variant_cols is False
+ actual = {f.name.strip('"').lower(): type(f.datatype) for f in df.schema.fields}
+ expected = {f.name.lower(): type(f.datatype) for f in expected_schema.fields}
+ assert actual == expected
+ result = df.collect()
+ assert len(result) == 6
+
+ parker = df.filter(col('"_key"') == "series/ifip/ParkerD14").collect()
+ assert len(parker) == 1
+ authors = json.loads(parker[0]["author"])
+ assert isinstance(authors, list)
+ assert len(authors) == 2
+ assert authors[0]["_VALUE"] == "Kevin R. Parker"
+ assert authors[0]["_orcid"] == "0000-0003-0549-3687"
+ ee = json.loads(parker[0]["ee"])
+ assert ee["_type"] == "oa"
+
+ lecomber = df.filter(col('"_key"') == "series/ifip/Lecomber14").collect()
+ assert len(lecomber) == 1
+ author_val = json.loads(lecomber[0]["author"])
+ if isinstance(author_val, dict):
+ assert author_val["_VALUE"] == "Angela Lecomber"
+ assert author_val["_orcid"] is None
+ elif isinstance(author_val, list):
+ assert len(author_val) == 1
+ assert author_val[0]["_VALUE"] == "Angela Lecomber"
+
+
+def test_read_xml_attribute_value_infer_schema(session):
+ """Infer schema on books_attribute_value.xml: publisher mixed struct + plain text."""
+ expected_schema = StructType(
+ [
+ StructField("_id", LongType()),
+ StructField("author", StringType()),
+ StructField("pages", LongType()),
+ StructField("price", DoubleType()),
+ StructField("publisher", VariantType()),
+ StructField("title", StringType()),
+ ]
+ )
+ df = session.read.option("rowTag", "book").xml(
+ f"@{tmp_stage_name}/{RES_BOOKS_ATTR_VAL_XML}"
+ )
+ assert df._all_variant_cols is False
+ actual = {f.name.strip('"').lower(): type(f.datatype) for f in df.schema.fields}
+ expected = {f.name.lower(): type(f.datatype) for f in expected_schema.fields}
+ assert actual == expected
+ result = df.collect()
+ assert len(result) == 5
+ assert len(result[0]) == 6
+
+ book1 = df.filter(col('"_id"') == 1).collect()
+ assert len(book1) == 1
+ pub1 = json.loads(book1[0]["publisher"])
+ assert pub1 == {"_VALUE": "O'Reilly Media", "_country": "USA", "_language": None}
+
+ book3 = df.filter(col('"_id"') == 3).collect()
+ assert len(book3) == 1
+ pub3 = json.loads(book3[0]["publisher"])
+ assert pub3 == {"_VALUE": "Springer", "_country": "Canada", "_language": "English"}
+
+ book4 = df.filter(col('"_id"') == 4).collect()
+ assert len(book4) == 1
+ pub4 = json.loads(book4[0]["publisher"])
+ assert pub4 == {"_VALUE": "Some Publisher", "_country": None, "_language": None}
+
+ book5 = df.filter(col('"_id"') == 5).collect()
+ assert len(book5) == 1
+ pub5 = json.loads(book5[0]["publisher"])
+ assert pub5 == {"_VALUE": None, "_country": None, "_language": None}
+
+
+# ─── inferSchema vs inferSchema=false comparison ────────────────────────────
+
+
+@pytest.mark.parametrize(
+ "staged_file, row_tag",
+ [
+ ("xml_infer_types.xml", "item"),
+ (RES_BOOKS_XML, "book"),
+ (RES_BOOKS2_XML, "book"),
+ ],
+)
+def test_infer_vs_no_infer_column_count(session, staged_file, row_tag):
+ """inferSchema produces typed columns; no inferSchema produces variant columns."""
+ df_infer = session.read.option("rowTag", row_tag).xml(
+ f"@{tmp_stage_name}/{staged_file}"
+ )
+ df_no_infer = (
+ session.read.option("rowTag", row_tag)
+ .option("inferSchema", False)
+ .xml(f"@{tmp_stage_name}/{staged_file}")
+ )
+ assert df_infer._all_variant_cols is False
+ assert df_no_infer._all_variant_cols is True
+ assert df_infer.count() == df_no_infer.count()
+
+
+# ─── samplingRatio tests ─────────────────────────────────────────────────────
+
+
+def test_sampling_ratio_schema_books_flat(session):
+ """samplingRatio=0.5 on homogeneous books.xml: correct schema and deterministic across runs."""
+ path = f"@{tmp_stage_name}/{RES_BOOKS_XML}"
+ schemas = []
+ for _ in range(3):
+ df = (
+ session.read.option("rowTag", "book").option("samplingRatio", 0.5).xml(path)
+ )
+ schemas.append(df.schema)
+
+ assert schemas[0] == schemas[1] == schemas[2]
+ assert df._all_variant_cols is False
+ assert _schema_types(df) == {
+ "_id": StringType,
+ "author": StringType,
+ "description": StringType,
+ "genre": StringType,
+ "price": DoubleType,
+ "publish_date": DateType,
+ "title": StringType,
+ }
+ assert df.count() == 12
+
+
+@pytest.mark.parametrize("invalid_ratio", [0, -0.5])
+def test_sampling_ratio_invalid(session, invalid_ratio):
+ with pytest.raises(ValueError, match="should be greater than 0"):
+ session.read.option("rowTag", "ROW").option("samplingRatio", invalid_ratio).xml(
+ f"@{tmp_stage_name}/{_staged_files['primitives']}"
+ )
+
+
+def test_sampling_ratio_hetero_may_narrow_schema(session):
+ """Low samplingRatio on heterogeneous data may infer a narrower schema."""
+ path = f"@{tmp_stage_name}/{_staged_files['sampling_hetero']}"
+
+ df_full = session.read.option("rowTag", "ROW").xml(path)
+ assert _schema_types(df_full)["value"] == DoubleType
+
+ df_sampled = (
+ session.read.option("rowTag", "ROW").option("samplingRatio", 0.3).xml(path)
+ )
+ types_sampled = _schema_types(df_sampled)
+ assert types_sampled["value"] in (LongType, DoubleType)
+ assert types_sampled["name"] == StringType
+ assert df_sampled.count() == 20
+
+
+def test_sampling_ratio_nested_schema_preserved(session):
+ """samplingRatio < 1.0 on nested data still infers nested structure."""
+ df = (
+ session.read.option("rowTag", "book")
+ .option("samplingRatio", 0.5)
+ .xml(f"@{tmp_stage_name}/{RES_BOOKS2_XML}")
+ )
+ assert df._all_variant_cols is False
+ assert _schema_types(df) == {
+ "_id": LongType,
+ "author": StringType,
+ "editions": VariantType,
+ "price": DoubleType,
+ "reviews": VariantType,
+ "title": StringType,
+ }
+ assert df.count() == 2
+ book1 = df.filter(col('"_id"') == 1).collect()[0]
+ reviews = json.loads(book1["reviews"])
+ assert isinstance(reviews["review"], list)
+ assert len(reviews["review"]) == 2
+
+
+def test_infer_schema_non_existing_file(session):
+ with pytest.raises(ValueError, match="does not exist"):
+ session.read.option("rowTag", "row").xml(
+ f"@{tmp_stage_name}/non_existing_file.xml"
+ )
+
+
+def test_infer_schema_use_leaf_row_tag(session):
+ xml_content = "- hello
- world
"
+ actual_filename = _upload_xml_string(
+ session, tmp_stage_name, "leaf_only_infer.xml", xml_content
+ )
+ df = session.read.option("rowTag", "item").xml(
+ f"@{tmp_stage_name}/{actual_filename}"
+ )
+ assert df.count() == 2
+ assert "_VALUE" in [f.name for f in df.schema.fields] or len(df.schema.fields) == 1
diff --git a/tests/integ/test_xml_reader_row_tag.py b/tests/integ/test_xml_reader_row_tag.py
index b63123ae10..4eda2408c4 100644
--- a/tests/integ/test_xml_reader_row_tag.py
+++ b/tests/integ/test_xml_reader_row_tag.py
@@ -4,6 +4,7 @@
import datetime
import logging
import json
+import os
import pytest
from snowflake.snowpark import Row
@@ -16,10 +17,14 @@
StructType,
StructField,
StringType,
+ LongType,
DoubleType,
DateType,
+ BooleanType,
+ TimestampType,
ArrayType,
)
+import snowflake.snowpark.context as context
from tests.utils import TestFiles, Utils
@@ -32,6 +37,15 @@
]
+@pytest.fixture()
+def enable_scos_compatible_mode():
+ """Enable SCOS compatible mode so that type validation runs inside the UDTF."""
+ original = context._is_snowpark_connect_compatible_mode
+ context._is_snowpark_connect_compatible_mode = True
+ yield
+ context._is_snowpark_connect_compatible_mode = original
+
+
# XML test file constants
test_file_books_xml = "books.xml"
test_file_books2_xml = "books2.xml"
@@ -46,10 +60,54 @@
test_file_xml_undeclared_namespace = "undeclared_namespace.xml"
test_file_null_value_xml = "null_value.xml"
test_file_books_xsd = "books.xsd"
+test_file_dk_trace_xml = "dk_trace_sample.xml"
+test_file_dblp_xml = "dblp_6kb.xml"
+test_file_books_attr_val_xml = "books_attribute_value.xml"
# Global stage name for uploading test files
tmp_stage_name = Utils.random_stage_name()
+# Inline XML strings for permissive/failfast/dropmalformed mode tests
+SAMPLING_MISMATCH_XML = """\
+
+
+ Alice100
+ Bob200
+ Carol300
+ Dave400
+ Eve500
+ Frankhello
+
+"""
+
+MULTIFIELD_MISMATCH_XML = """\
+
+
+ 42true3.14
+ not_a_nummaybe2.72
+ 99falsenot_a_dbl
+
+"""
+
+
+def _upload_xml_string(session, stage, filename, xml_content):
+ """Write XML string to a temp file and upload to stage."""
+ import tempfile
+
+ with tempfile.NamedTemporaryFile(
+ mode="w", suffix=".xml", delete=False, prefix=filename.replace(".xml", "_")
+ ) as f:
+ f.write(xml_content)
+ tmp_path = f.name
+ try:
+ Utils.upload_to_stage(session, stage, tmp_path, compress=False)
+ finally:
+ os.unlink(tmp_path)
+ return os.path.basename(tmp_path)
+
+
+_staged_files = {}
+
@pytest.fixture(scope="module", autouse=True)
def setup(session, resources_path, local_testing_mode):
@@ -124,6 +182,34 @@ def setup(session, resources_path, local_testing_mode):
test_files.test_books_xsd,
compress=False,
)
+ Utils.upload_to_stage(
+ session,
+ "@" + tmp_stage_name,
+ test_files.test_dk_trace_sample_xml,
+ compress=False,
+ )
+ Utils.upload_to_stage(
+ session,
+ "@" + tmp_stage_name,
+ test_files.test_dblp_6kb_xml,
+ compress=False,
+ )
+ Utils.upload_to_stage(
+ session,
+ "@" + tmp_stage_name,
+ test_files.test_books_attribute_value_xml,
+ compress=False,
+ )
+
+ # Upload inline XML strings for mode tests
+ for name, xml_str in {
+ "sampling_mismatch": SAMPLING_MISMATCH_XML,
+ "multifield_mismatch": MULTIFIELD_MISMATCH_XML,
+ }.items():
+ staged = _upload_xml_string(
+ session, "@" + tmp_stage_name, f"{name}.xml", xml_str
+ )
+ _staged_files[name] = staged
yield
# Clean up resources
@@ -138,6 +224,8 @@ def setup(session, resources_path, local_testing_mode):
[test_file_books2_xml, "book", 2, 6],
[test_file_house_xml, "House", 37, 22],
[test_file_house_large_xml, "House", 740, 22],
+ [test_file_dblp_xml, "mastersthesis", 6, 8],
+ [test_file_books_attr_val_xml, "book", 5, 6],
],
)
def test_read_xml_row_tag(
@@ -769,3 +857,292 @@ def test_value_tag_custom_schema(session):
.xml(f"@{tmp_stage_name}/{test_file_null_value_xml}")
)
Utils.check_answer(df, [Row(num="1", str1="NULL", str2=None, str3="xxx")])
+
+
+def test_read_xml_namespace_user_schema(session):
+ """User-provided schema with namespace-prefixed fields (ignoreNamespace=false)."""
+ user_schema = StructType(
+ [
+ StructField("eqTrace:event-id", StringType(), True),
+ StructField("eqTrace:event-name", StringType(), True),
+ StructField("eqTrace:event-version-number", DoubleType(), True),
+ StructField("eqTrace:is-planned", BooleanType(), True),
+ StructField("eqTrace:is-synthetic", BooleanType(), True),
+ StructField("eqTrace:date-time", TimestampType(), True),
+ ]
+ )
+ df = (
+ session.read.option("rowTag", "eqTrace:event")
+ .option("ignoreNamespace", False)
+ .schema(user_schema)
+ .xml(f"@{tmp_stage_name}/{test_file_dk_trace_xml}")
+ )
+ result = df.collect()
+ assert len(result) == 5
+ assert len(result[0]) == 6
+ col_names = [f.name.strip('"') for f in df.schema.fields]
+ assert "eqTrace:event-id" in col_names
+ assert "eqTrace:event-name" in col_names
+ event_ids = {r[0] for r in result}
+ assert "f0e765d9-599b-46bf-9aef-bd33e0c2183f" in event_ids
+ assert "dd9a4616-e41c-4571-9bda-9a5506a2b78d" in event_ids
+
+
+def test_read_xml_dblp_user_schema(session):
+ """User-provided schema on dblp_6kb.xml mastersthesis with nested ee StructType."""
+ user_schema = StructType(
+ [
+ StructField("_mdate", DateType(), True),
+ StructField("_key", StringType(), True),
+ StructField("author", StringType(), True),
+ StructField("title", StringType(), True),
+ StructField("year", LongType(), True),
+ StructField("school", StringType(), True),
+ StructField(
+ "ee",
+ StructType(
+ [
+ StructField("_VALUE", StringType(), True),
+ StructField("_type", StringType(), True),
+ ]
+ ),
+ True,
+ ),
+ ]
+ )
+ df = (
+ session.read.option("rowTag", "mastersthesis")
+ .schema(user_schema)
+ .xml(f"@{tmp_stage_name}/{test_file_dblp_xml}")
+ )
+ result = df.collect()
+ assert len(result) == 6
+ assert len(result[0]) == 7
+ brown = df.filter(col('"_key"') == "ms/Brown92").collect()
+ assert len(brown) == 1
+ assert json.loads(brown[0]["ee"]) == {"_VALUE": None, "_type": None}
+ assert brown[0]["year"] == 1992
+
+
+def test_read_xml_dblp_incollection_user_schema(session):
+ """User-provided ArrayType(StructType) for author, StructType for ee on dblp incollection."""
+ author_element_schema = StructType(
+ [
+ StructField("_VALUE", StringType(), True),
+ StructField("_orcid", StringType(), True),
+ ]
+ )
+ ee_schema = StructType(
+ [
+ StructField("_VALUE", StringType(), True),
+ StructField("_type", StringType(), True),
+ ]
+ )
+ user_schema = StructType(
+ [
+ StructField("_key", StringType(), True),
+ StructField("_mdate", DateType(), True),
+ StructField("author", ArrayType(author_element_schema, True), True),
+ StructField("booktitle", StringType(), True),
+ StructField("crossref", StringType(), True),
+ StructField("ee", ee_schema, True),
+ StructField("pages", StringType(), True),
+ StructField("title", StringType(), True),
+ StructField("url", StringType(), True),
+ StructField("year", LongType(), True),
+ ]
+ )
+ df = (
+ session.read.option("rowTag", "incollection")
+ .schema(user_schema)
+ .xml(f"@{tmp_stage_name}/{test_file_dblp_xml}")
+ )
+ result = df.collect()
+ assert len(result) == 6
+ assert len(result[0]) == 11
+
+ parker = df.filter(col('"_key"') == "series/ifip/ParkerD14").collect()
+ assert len(parker) == 1
+ authors = json.loads(parker[0]["author"])
+ assert len(authors) == 2
+ assert authors[0]["_VALUE"] == "Kevin R. Parker"
+ assert authors[0]["_orcid"] == "0000-0003-0549-3687"
+ assert authors[1]["_VALUE"] == "Bill Davey"
+ assert authors[1]["_orcid"] is None
+ ee = json.loads(parker[0]["ee"])
+ assert ee["_VALUE"] == "https://doi.org/10.1007/978-3-642-55119-2_14"
+ assert ee["_type"] == "oa"
+
+ scheid = df.filter(col('"_key"') == "series/ifip/ScheidRKFRS21").collect()
+ assert len(scheid) == 1
+ scheid_authors = json.loads(scheid[0]["author"])
+ assert len(scheid_authors) == 6
+ assert all(a["_orcid"] is not None for a in scheid_authors)
+
+ rheingans = df.filter(col('"_key"') == "series/ifip/RheingansL95").collect()
+ assert len(rheingans) == 1
+ ee_no_type = json.loads(rheingans[0]["ee"])
+ assert ee_no_type["_type"] is None
+ assert "doi.org" in ee_no_type["_VALUE"]
+ assert parker[0]["year"] == 2014
+
+
+def test_read_xml_attribute_value_user_schema_struct_publisher(session):
+ """User StructType schema for publisher on books_attribute_value.xml."""
+ publisher_schema = StructType(
+ [
+ StructField("_VALUE", StringType(), True),
+ StructField("_country", StringType(), True),
+ StructField("_language", StringType(), True),
+ ]
+ )
+ user_schema = StructType(
+ [
+ StructField("_id", LongType(), True),
+ StructField("title", StringType(), True),
+ StructField("author", StringType(), True),
+ StructField("price", DoubleType(), True),
+ StructField("publisher", publisher_schema, True),
+ ]
+ )
+ df = (
+ session.read.option("rowTag", "book")
+ .schema(user_schema)
+ .xml(f"@{tmp_stage_name}/{test_file_books_attr_val_xml}")
+ )
+ result = df.collect()
+ assert len(result) == 5
+ assert len(result[0]) == 5
+
+ book1 = df.filter(col('"_id"') == 1).collect()
+ assert len(book1) == 1
+ pub1 = json.loads(book1[0]["publisher"])
+ assert pub1 == {"_VALUE": "O'Reilly Media", "_country": "USA", "_language": None}
+
+ book3 = df.filter(col('"_id"') == 3).collect()
+ assert len(book3) == 1
+ pub3 = json.loads(book3[0]["publisher"])
+ assert pub3 == {"_VALUE": "Springer", "_country": "Canada", "_language": "English"}
+
+ book4 = df.filter(col('"_id"') == 4).collect()
+ assert len(book4) == 1
+ pub4 = json.loads(book4[0]["publisher"])
+ assert pub4 == {"_VALUE": "Some Publisher", "_country": None, "_language": None}
+ assert book1[0]["price"] == 29.99
+ assert book1[0]["_id"] == 1
+
+
+def test_permissive_type_mismatch_user_schema(session, enable_scos_compatible_mode):
+ schema = StructType(
+ [
+ StructField("name", StringType()),
+ StructField("value", LongType()),
+ ]
+ )
+ df = (
+ session.read.option("rowTag", "ROW")
+ .schema(schema)
+ .xml(f"@{tmp_stage_name}/{_staged_files['sampling_mismatch']}")
+ )
+ result = df.order_by('"name"').collect()
+ assert len(result) == 6
+
+ alice = [r for r in result if r["name"] == "Alice"][0]
+ assert alice["value"] == 100
+ assert alice["_corrupt_record"] is None
+
+ frank = [r for r in result if r["name"] == "Frank"][0]
+ assert frank["value"] is None
+ assert frank["_corrupt_record"] is not None
+ assert "hello" in frank["_corrupt_record"]
+
+ col_names = [c.strip('"') for c in df.columns]
+ assert "_corrupt_record" in col_names
+
+
+def test_permissive_multifield_per_field_granularity(
+ session, enable_scos_compatible_mode
+):
+ schema = StructType(
+ [
+ StructField("int_col", LongType()),
+ StructField("bool_col", BooleanType()),
+ StructField("dbl_col", DoubleType()),
+ ]
+ )
+ df = (
+ session.read.option("rowTag", "ROW")
+ .schema(schema)
+ .xml(f"@{tmp_stage_name}/{_staged_files['multifield_mismatch']}")
+ )
+ result = df.order_by('"int_col"').collect()
+ assert len(result) == 3
+
+ assert result[0]["int_col"] is None
+ assert result[0]["bool_col"] is None
+ assert abs(result[0]["dbl_col"] - 2.72) < 0.001
+
+ assert result[1]["int_col"] == 42
+ assert result[1]["bool_col"] is True
+ assert abs(result[1]["dbl_col"] - 3.14) < 0.001
+
+ assert result[2]["int_col"] == 99
+ assert result[2]["bool_col"] is False
+ assert result[2]["dbl_col"] is None
+
+
+def test_failfast_type_mismatch_raises(session, enable_scos_compatible_mode):
+ narrow_schema = StructType(
+ [
+ StructField("name", StringType()),
+ StructField("value", LongType()),
+ ]
+ )
+ with pytest.raises(SnowparkSQLException):
+ session.read.option("rowTag", "ROW").option("mode", "FAILFAST").schema(
+ narrow_schema
+ ).xml(f"@{tmp_stage_name}/{_staged_files['sampling_mismatch']}")
+
+
+def test_dropmalformed_type_mismatch_drops_rows(session, enable_scos_compatible_mode):
+ narrow_schema = StructType(
+ [
+ StructField("name", StringType()),
+ StructField("value", LongType()),
+ ]
+ )
+ df = (
+ session.read.option("rowTag", "ROW")
+ .option("mode", "DROPMALFORMED")
+ .schema(narrow_schema)
+ .xml(f"@{tmp_stage_name}/{_staged_files['sampling_mismatch']}")
+ )
+ result = df.order_by('"name"').collect()
+ assert len(result) == 5
+ names = [r["name"] for r in result]
+ assert "Frank" not in names
+ for r in result:
+ assert r["value"] is not None
+
+
+def test_dropmalformed_multifield_drops_any_bad_field(
+ session, enable_scos_compatible_mode
+):
+ schema = StructType(
+ [
+ StructField("int_col", LongType()),
+ StructField("bool_col", BooleanType()),
+ StructField("dbl_col", DoubleType()),
+ ]
+ )
+ df = (
+ session.read.option("rowTag", "ROW")
+ .option("mode", "DROPMALFORMED")
+ .schema(schema)
+ .xml(f"@{tmp_stage_name}/{_staged_files['multifield_mismatch']}")
+ )
+ result = df.collect()
+ assert len(result) == 1
+ assert result[0]["int_col"] == 42
+ assert result[0]["bool_col"] is True
+ assert abs(result[0]["dbl_col"] - 3.14) < 0.001
diff --git a/tests/resources/books_attribute_value.xml b/tests/resources/books_attribute_value.xml
new file mode 100644
index 0000000000..107be4bedc
--- /dev/null
+++ b/tests/resources/books_attribute_value.xml
@@ -0,0 +1,33 @@
+
+
+ The Art of Snowflake
+ Jane Doe
+ 29.99
+ O'Reilly Media
+
+
+ XML for Data Engineers
+ John Smith
+ 35.50
+ Springer
+
+
+ Book 3
+ John Smith
+ 35.50
+ Springer
+
+
+ Book 4
+ John Smith
+ 35.50
+ Some Publisher
+
+
+ Book 5
+ author5
+ 35
+
+ 5
+
+
diff --git a/tests/resources/dblp_6kb.xml b/tests/resources/dblp_6kb.xml
new file mode 100644
index 0000000000..2fe62000eb
--- /dev/null
+++ b/tests/resources/dblp_6kb.xml
@@ -0,0 +1,126 @@
+
+
+
+
+Oliver Hoffmann 0002
+Regelbasierte Extraktion und asymmetrische Fusion bibliographischer Informationen.
+2009
+University of Trier
+Diplomarbeit, Universität Trier, FB IV, DBIS/DBLP
+http://dblp.uni-trier.de/papers/DiplomarbeitOliverHoffmann.pdf
+
+
+Rita Ley
+Der Einfluss kleiner naturnaher Retentionsmaßnahmen in der Fläche auf den Hochwasserabfluss - Kleinrückhaltebecken -.
+2006
+Diplomarbeit, Universität Trier, FB VI, Physische Geographie
+http://dblp.uni-trier.de/papers/DiplomarbeitRitaLey.pdf
+
+
+Kurt P. Brown
+PRPL: A Database Workload Specification Language, v1.3.
+1992
+University of Wisconsin-Madison
+
+
+
+
+Stephan Vollmer
+Portierung des DBLP-Systems auf ein relationales Datenbanksystem und Evaluation der Performance.
+2006
+Diplomarbeit, Universität Trier, FB IV, Informatik
+http://dbis.uni-trier.de/Diplomanden/Vollmer/vollmer.shtml
+
+
+
+
+Vanessa C. Klaas
+Who's Who in the World Wide Web: Approaches to Name Disambiguation
+2007
+Diplomarbeit, LMU München, Informatik
+http://www.pms.ifi.lmu.de/publikationen/diplomarbeiten/Vanessa.Klaas/thesis.pdf
+
+
+Tolga Yurek
+Efficient View Maintenance at Data Warehouses.
+1997
+University of California at Santa Barbara, Department of Computer Science, CA, USA
+
+
+
+
+Ulrich Briefs
+John Kjaer
+Jean-Louis Rigal
+Computerization and Work, A Reader on Social Aspects of Computerization
+Computerization and Work
+Springer
+1985
+978-3-540-15367-2
+978-3-642-70453-6
+https://doi.org/10.1007/978-3-642-70453-6
+IFIP State-of-the-Art Reports
+db/series/ifip/computerization1985.html
+
+Kevin R. Parker
+Bill Davey
+Computers in Schools in the USA: A Social History.
+203-211
+2014
+Reflections on the History of Computers in Education
+https://doi.org/10.1007/978-3-642-55119-2_14
+series/ifip/hedu2014
+db/series/ifip/hedu2014.html#ParkerD14
+
+Penny Rheingans
+Chris Landreth
+Perceptual Principles for Effective Visualizations.
+59-73
+1995
+Perceptual Issues in Visualization
+https://doi.org/10.1007/978-3-642-79057-7_6
+series/ifip/piv1995
+db/series/ifip/piv1995.html#RheingansL95
+
+Rory Butler
+Adrie J. Visscher
+The Hopes and Realities of the Computer as a School Administration and School Management Tool.
+197-202
+2014
+Reflections on the History of Computers in Education
+https://doi.org/10.1007/978-3-642-55119-2_13
+series/ifip/hedu2014
+db/series/ifip/hedu2014.html#ButlerV14
+
+Eder J. Scheid
+Bruno Bastos Rodrigues
+Christian Killer
+Muriel Figueredo Franco
+Sina Rafati 0001
+Burkhard Stiller
+Blockchains and Distributed Ledgers Uncovered: Clarifications, Achievements, and Open Issues.
+289-317
+2021
+IFIP's Exciting First 60+ Years
+https://doi.org/10.1007/978-3-030-81701-5_12
+series/ifip/600
+db/series/ifip/sixty2021.html#ScheidRKFRS21
+
+Angela Lecomber
+Pioneering the Internet in the Nineties - An Innovative Project Involving UK and Australian Schools.
+384-393
+2014
+Reflections on the History of Computers in Education
+https://doi.org/10.1007/978-3-642-55119-2_27
+series/ifip/hedu2014
+db/series/ifip/hedu2014.html#Lecomber14
+
+Hartmut Ehrig
+Hans-Jörg Kreowski
+Refinement and Implementation.
+201-242
+1999
+Algebraic Foundations of Systems Specification
+https://doi.org/10.1007/978-3-642-59851-7_7
+series/ifip/afss1999
+db/series/ifip/afss1999.ht
\ No newline at end of file
diff --git a/tests/resources/dk_trace_sample.xml b/tests/resources/dk_trace_sample.xml
new file mode 100644
index 0000000000..466bdcc769
--- /dev/null
+++ b/tests/resources/dk_trace_sample.xml
@@ -0,0 +1,1015 @@
+
+
+ 124922747
+
+ CAAU
+ 573469
+
+
+
+ 124922747
+
+ CAAU
+ 573469
+
+ 0
+ 000000000000000000000000001402748340
+
+ Opened
+
+ 2023-11-27T21:48:00Z
+ 2023-11-27T21:48:00Z
+
+ Load
+ InYard
+
+ UP
+
+
+ 2023-11-27T21:48:00Z
+
+
+ true
+
+
+ EVER
+
+ 66ef31dc-f71f-45e1-8b1a-6877da033c66
+ Industry
+ 502238
+
+ false
+
+ 502238
+
+
+ 053f9df9-932f-4795-ab21-cb66546e9314
+ equipment/equipment-interchanged
+ 2023-11-27T21:48:00Z
+ 2023-11-27T21:49:56Z
+
+ 7609
+
+
+
+ EqStopReason
+ EventCode
+ InterchangeReceipt
+
+
+
+
+ 7609
+
+
+ InterchangeReceipt
+
+ 47264fad-5172-4de8-b513-4f2f4ea89a9e
+ InterchangeReceipt
+ 7609
+
+ ETS
+ 2023-11-27T21:48:00Z
+
+
+
+ true
+ ETS
+ OriginSwitch
+ UP
+
+ 142054
+
+
+ 7609
+ TICTF
+
+
+
+ false
+ UP
+ RoadHaul
+ ETS
+
+ 7609
+ TICTF
+
+
+ 502238
+
+
+
+
+
+ 66ef31dc-f71f-45e1-8b1a-6877da033c66
+ Industry
+
+ 677a8cbc-6de2-4bb5-b778-9b8abbdf935e
+ equipment/equipment-cycle-status-changed
+
+ 502238
+
+ ActualPlacement
+
+ CompleteUnload
+
+
+
+ FinalSystemDest
+ NextPlannedStop
+ NextScheduledStop
+
+
+
+ EqStopReason
+ EventCode
+ ActualPlacement
+
+
+ EqStopSubReason
+ EventQualifierCode
+ CompleteUnload
+
+
+
+
+
+ 4611110
+ STCC
+ MIXFRT
+ OTHER
+ FREIGHT ALL KINDS, (FAK) OR ALL FREIGHT RATE SHIPMENTS, NEC, OR TRAILER-ON FLATCAR SHIPMENTS, COMMERCIAL (EXCEPT IDENTIFIED BY COMMODITIES, THEN CODE BY COMMODITY)
+
+ false
+ false
+ false
+
+
+ false
+ false
+
+
+
+
+ CofcShipment
+ Reporting
+
+
+ ImportShipment
+ Reporting
+
+
+ InBond
+ Reporting
+
+
+ ShprCertfdScaleWgts
+ Reporting
+
+
+
+
+ Tare
+ Umler
+ 8000
+ false
+
+
+ Net
+ Waybill
+ 9619
+ false
+
+
+ Gross
+ NCEqmtActRptg
+ 17800
+ false
+
+
+
+
+ 869277338
+
+ 114035
+ 2023-11-17
+
+
+
+
+ true
+ BillOfLadingNumber
+ 100123111771505
+
+
+ false
+ OceanBillOfLading
+ 003302074967
+
+
+ false
+ EssCommFlags
+ ABEKEY
+ 231117103606
+
+
+ false
+ EssGrpId
+ PABOCIRC7
+ CS526
+
+
+ false
+ EssCommFlags
+ YNNNNNNYYNNYNNNNYNNNYNNNNNNNNN
+ NNNNNNNNNN
+
+
+ false
+ EssGrpId
+ PABDCIRC7
+ HL011
+
+
+ false
+ EssGrpId
+ ESSID=880553922,0
+
+
+ Local
+ Load
+ false
+ false
+ false
+ 1
+ false
+ ShipperCert
+
+ 4611110
+ STCC
+ MIXFRT
+ OTHER
+ FREIGHT ALL KINDS, (FAK) OR ALL FREIGHT RATE SHIPMENTS, NEC, OR TRAILER-ON FLATCAR SHIPMENTS, COMMERCIAL (EXCEPT IDENTIFIED BY COMMODITIES, THEN CODE BY COMMODITY)
+
+ false
+ false
+ false
+
+
+ false
+ false
+
+
+
+ 7609
+
+
+ 502238
+
+
+
+ UP
+ RoadHaul
+
+ 7609
+
+
+ 502238
+
+
+
+ 3
+
+
+ InCareOf1
+
+ 633493
+ 144973435
+ EVERGREEN SHIPPING AGENCY
+ 16000 DALLAS PKWY STE 400
+ DALLAS
+ TX
+ US
+
+
+ Phone
+ 9722465520
+
+
+
+
+
+ IMMEDIATE
+
+ 831777
+ 144973435
+ EVERGREEN SHIPPING AGENCY
+ DALLAS
+ TX
+ US
+
+
+
+
+
+ Consignee
+
+ 633493
+ 144973435
+ EVERGREEN SHIPPING AGENCY
+ 16000 DALLAS PKWY STE 400
+ DALLAS
+ TX
+ US
+
+
+ Phone
+ 9722465520
+
+
+
+
+
+ IMMEDIATE
+
+ 831777
+ 144973435
+ EVERGREEN SHIPPING AGENCY
+ DALLAS
+ TX
+ US
+
+
+
+
+
+ Notify1
+
+ 831777
+ 144973435
+ EVERGREEN SHIPPING AGENCY
+ 16000 DALLAS PKWY STE 400
+ DALLAS
+ TX
+ US
+
+
+ Phone
+ 9722465520
+
+
+ Phone
+ 8665718765
+
+
+
+
+
+ IMMEDIATE
+
+ 993341
+ 174299743
+ EVERGREEN SHIPPING AGENCY
+ JERSEY CITY
+ NJ
+ US
+
+
+
+ ULTIMATE
+
+ 993341
+ 174299743
+ EVERGREEN SHIPPING AGENCY
+ JERSEY CITY
+ NJ
+ US
+
+
+
+
+
+ ToReceiveFreight
+
+ 27313
+ EVERGREEN SHIPPING AGENCY
+ CYPRESS
+ CA
+ US
+
+
+ Phone
+ 7148226800
+
+
+
+
+
+ Shipper
+
+ 371960
+ 174299743
+ EVERGREEN SHIPPING AGENCY
+ ONE EVERTRUST PLZ
+ JERSEY CITY
+ NJ
+ US
+
+
+ Phone
+ 2017613000
+
+
+
+
+
+ IMMEDIATE
+
+ 993341
+ 174299743
+ EVERGREEN SHIPPING AGENCY
+ JERSEY CITY
+ NJ
+ US
+
+
+
+
+
+
+ EMCJFR5623
+
+
+
+ ShippersBondNumber
+ VVS10930776
+
+
+ 9619
+ Net
+
+
+ IN SHIPPER BOND
+ TCS
+
+
+ WB
+ AUTOBILL
+
+ 003302074967
+ 1322-018E
+ LOS ANGELES
+ D
+ 2023-11-20
+ EVER FOREVER
+
+
+
+ CT
+ 52072
+ UP
+ UPC
+
+
+
+
+
+ Intermodal
+ Container
+ false
+ false
+ K4E
+ U
+ 903
+ 8000
+ U
+ P
+ 480
+ ACT
+ INSV
+ CAAU
+ CAAU
+ CAAU
+ N
+
+
+
+ f0e765d9-599b-46bf-9aef-bd33e0c2183f
+ equipment/equipment-placement-at-industry-planned
+ 1.2
+ cn=dneq999,ou=uprr,o=up
+
+ 01f14f9f-2464-448b-9051-2da30b5f3fcd
+ InterchangeReceipt
+ 2023-11-27T21:49:50Z
+ 2023-11-27T21:49:50Z
+ 2023-11-27T21:49:56Z
+ OIPT223
+ UP
+ InterchangeReceiptUI
+
+ InterchangeReceipt
+
+
+
+ Stop
+
+ true
+ false
+
+ 502238
+
+
+
+ EqStopReason
+ EventCode
+ ActualPlacement
+
+
+ EqStopSubReason
+ EventQualifierCode
+ CompleteUnload
+
+
+
+ NEW
+
+ CompleteUnload
+
+ 502238
+
+
+ 66ef31dc-f71f-45e1-8b1a-6877da033c66
+
+ FinalSystemDest
+ NextPlannedStop
+ NextScheduledStop
+
+
+
+
+ dd9a4616-e41c-4571-9bda-9a5506a2b78d
+ equipment/equipment-waybill-applied
+ 1.2
+ cn=dneq999,ou=uprr,o=up
+
+ 01f14f9f-2464-448b-9051-2da30b5f3fcd
+ InterchangeReceipt
+ 2023-11-27T21:49:55Z
+ 2023-11-27T21:49:50Z
+ 2023-11-27T21:49:56Z
+ OIPT223
+ UP
+ InterchangeReceiptUI
+
+ InterchangeReceipt
+
+
+ false
+ false
+ 2023-11-27T21:49:55Z
+
+
+ EqTraceUIEventCode
+ EventCode
+ WaybillApplied
+
+
+
+ NEW
+
+
+ 000000000000000000000000001402748340
+ Opened
+ Current
+
+
+
+ 869277338
+
+ 114035
+ 2023-11-17
+
+
+
+
+ true
+ BillOfLadingNumber
+ 100123111771505
+ DESCRIPTION
+
+
+ false
+ OceanBillOfLading
+ 003302074967
+ DESCRIPTION
+
+
+ false
+ EssCommFlags
+ ABEKEY
+ DESCRIPTION
+
+
+ false
+ EssGrpId
+ PABOCIRC7
+ DESCRIPTION
+
+
+ false
+ EssCommFlags
+ YNNNNNNYYNNYNNNNYNNNYNNNNNNNNN
+ DESCRIPTION
+
+
+ false
+ EssGrpId
+ PABDCIRC7
+ DESCRIPTION
+
+
+ false
+ EssGrpId
+ ESSID=880553922,0
+ DESCRIPTION
+
+
+ Local
+ Load
+ false
+ false
+ false
+ false
+
+ 4611110
+ MIXFRT
+ OTHER
+ FREIGHT ALL KINDS, (FAK) OR ALL FREIGHT RATE SHIPMENTS, NEC, OR TRAILER-ON FLATCAR SHIPMENTS, COMMERCIAL (EXCEPT IDENTIFIED BY COMMODITIES, THEN CODE BY COMMODITY)
+
+ false
+ false
+
+
+
+ 7609
+
+
+ 502238
+
+
+
+ UP
+ RoadHaul
+
+ 7609
+
+
+ 502238
+
+
+
+
+
+ Shipper
+
+ 371960
+ 174299743
+ EVERGREEN SHIPPING AGENCY
+ ONE EVERTRUST PLZ
+ JERSEY CITY
+ NJ
+ US
+
+
+
+ IMMEDIATE
+
+ 993341
+ 174299743
+ EVERGREEN SHIPPING AGENCY
+ JERSEY CITY
+ NJ
+ US
+
+
+
+
+
+ Consignee
+
+ 633493
+ 144973435
+ EVERGREEN SHIPPING AGENCY
+ 16000 DALLAS PKWY STE 400
+ DALLAS
+ TX
+ US
+
+
+
+ IMMEDIATE
+
+ 831777
+ 144973435
+ EVERGREEN SHIPPING AGENCY
+ DALLAS
+ TX
+ US
+
+
+
+
+
+ InCareOf
+
+ 633493
+ 144973435
+ EVERGREEN SHIPPING AGENCY
+ 16000 DALLAS PKWY STE 400
+ DALLAS
+ TX
+ US
+
+
+
+ IMMEDIATE
+
+ 831777
+ 144973435
+ EVERGREEN SHIPPING AGENCY
+ DALLAS
+ TX
+ US
+
+
+
+
+
+ 9619
+ Net
+
+
+ IN SHIPPER BOND
+ TCS
+
+
+
+ 2023-11-27T21:49:55Z
+
+
+
+
+ 677a8cbc-6de2-4bb5-b778-9b8abbdf935e
+ equipment/equipment-cycle-status-changed
+ 1.3
+ cn=dneq999,ou=uprr,o=up
+
+ 01f14f9f-2464-448b-9051-2da30b5f3fcd
+ InterchangeReceipt
+ 2023-11-27T21:48:00Z
+ 2023-11-27T21:49:50Z
+ 2023-11-27T21:49:56Z
+ OIPT223
+ UP
+ InterchangeReceiptUI
+
+ InterchangeReceipt
+
+
+ false
+ false
+ 2023-11-27T21:48:00Z
+
+
+ Opened
+
+
+ Current
+
+
+
+
+ 053f9df9-932f-4795-ab21-cb66546e9314
+ equipment/equipment-interchanged
+ 1.3
+ cn=dneq999,ou=uprr,o=up
+
+ 01f14f9f-2464-448b-9051-2da30b5f3fcd
+ InterchangeReceipt
+ 2023-11-27T21:48:00Z
+ 2023-11-27T21:49:50Z
+ 2023-11-27T21:49:56Z
+ OIPT223
+ UP
+ InterchangeReceiptUI
+
+ InterchangeReceipt
+
+
+
+ Stop
+
+ false
+ false
+ 2023-11-27T21:48:00Z
+
+ 7609
+
+
+
+ EqStopReason
+ EventCode
+ InterchangeReceipt
+
+
+
+ false
+ NEW
+ false
+ InterchangeReceipt
+ false
+ false
+ 47264fad-5172-4de8-b513-4f2f4ea89a9e
+
+ OpeningCycleStop
+
+
+
+ ETS
+ 7609
+ 2023-11-27T21:48:00Z
+
+
+
+
+
+ 4e530f8a-9ccd-445f-a807-6b9ab09e625d
+ equipment/equipment-cycle-status-changed
+ 1.3
+ cn=dneq999,ou=uprr,o=up
+
+ 01f14f9f-2464-448b-9051-2da30b5f3fcd
+ InterchangeReceipt
+ 2023-11-27T21:48:00Z
+ 2023-11-27T21:49:50Z
+ 2023-11-27T21:49:56Z
+ OIPT223
+ UP
+ InterchangeReceiptUI
+
+ InterchangeReceipt
+
+
+ false
+ false
+ 2023-11-17T16:36:06Z
+
+
+ Unopened
+ Deleted
+
+
+ Future
+ Deleted
+
+
+
+
+
+
+
+
+ 502238
+
+ 502238
+ HL011
+ 435
+ DIT
+ DIT
+ DIT
+ TX
+ US
+ America/Chicago
+ 514167
+ 667285000
+ 57810
+ 046
+
+
+ PRPTYARD
+ 11609
+
+
+ DRN
+ 336
+
+
+ BKP
+ 046
+
+
+
+
+
+ 7609
+
+ 7609
+ CS526
+ 17793
+ TICTF
+ TICTF
+ TCTF
+ CA
+ US
+ America/Los_Angeles
+ 142054
+ 883213000
+ 15004
+ 240
+
+
+ PRPTYARD
+ 13975
+
+
+ DRN
+ 158
+
+
+ BKP
+ 240
+
+
+
+
+
+
+
+ 142054
+
+ 60058
+ 142054
+ 883213000
+ TICTF
+ TICTF
+ CA
+ US
+ PT
+ true
+
+
+
+
+
+ 2023-11-27T21:50:01.306Z
+
+ true
+
+
+ f0e765d9-599b-46bf-9aef-bd33e0c2183f
+ equipment/equipment-placement-at-industry-planned
+
+
+ 053f9df9-932f-4795-ab21-cb66546e9314
+ equipment/equipment-interchanged
+
+
+ 677a8cbc-6de2-4bb5-b778-9b8abbdf935e
+ equipment/equipment-cycle-status-changed
+
+
+ 4e530f8a-9ccd-445f-a807-6b9ab09e625d
+ equipment/equipment-cycle-status-changed
+
+
+ fb7cd7ed-4d05-4c04-b793-11c51923722f
+ equipment/key-equipment-properties-changed
+
+
+ dd9a4616-e41c-4571-9bda-9a5506a2b78d
+ equipment/equipment-waybill-applied
+
+
+
+
+
+
+ Load
+
+
+ true
+
+
+ false
+
+
+ Tare
+ Umler
+ 8000
+ false
+
+
+
+
+
diff --git a/tests/resources/xml_infer_mixed.xml b/tests/resources/xml_infer_mixed.xml
new file mode 100644
index 0000000000..d6abaa04b1
--- /dev/null
+++ b/tests/resources/xml_infer_mixed.xml
@@ -0,0 +1,41 @@
+
+
+
+
+
+ Child 1.1
+
+ Child 1.2
+
+ 25
+ text1
+ Hello World
+
+ 1
+ 2
+
+
+
+
+
+ Child 2.1
+
+ Child 2.2
+
+ 30
+ text2
+ Plain text only
+
+
+
+ Child 3
+
+ 35
+ 999
+ Another Nested
+
+ 3
+ 4
+
+
+
diff --git a/tests/resources/xml_infer_types.xml b/tests/resources/xml_infer_types.xml
new file mode 100644
index 0000000000..b622c447dc
--- /dev/null
+++ b/tests/resources/xml_infer_types.xml
@@ -0,0 +1,57 @@
+
+
+ -
+ Widget A
+ 44.95
+ 100
+ true
+ 2021-02-01
+ 2015-01-01 00:00:00
+
+ electronics
+ gadget
+
+
+ 1.5
+
+ 10
+ 5
+ 3
+
+
+ Excellent
+
+ -
+ Widget B
+ 29.99
+ 50
+ false
+ 2022-06-15
+ 2023-03-10 12:30:00
+
+ home
+
+
+ 2.3
+
+ 20
+ 10
+ 8
+
+
+ Good
+
+ -
+ Widget C
+ 15.00
+ 200
+ true
+ 2020-11-30
+
+ office
+ supplies
+ paper
+
+ Average
+
+
diff --git a/tests/unit/scala/test_utils_suite.py b/tests/unit/scala/test_utils_suite.py
index 6b18c6aa9f..02ca8412e0 100644
--- a/tests/unit/scala/test_utils_suite.py
+++ b/tests/unit/scala/test_utils_suite.py
@@ -313,11 +313,14 @@ def check_zip_files_and_close_stream(input_stream, expected_files):
"resources/books.xml",
"resources/books.xsd",
"resources/books2.xml",
+ "resources/books_attribute_value.xml",
"resources/broken.csv",
"resources/cat.jpeg",
"resources/conversation.ogg",
- "resources/diamonds.csv",
+ "resources/dblp_6kb.xml",
"resources/declared_namespace.xml",
+ "resources/diamonds.csv",
+ "resources/dk_trace_sample.xml",
"resources/doc.pdf",
"resources/dog.jpg",
"resources/fias_house.xml",
@@ -390,6 +393,8 @@ def check_zip_files_and_close_stream(input_stream, expected_files):
"resources/test_udaf_dir/test_udaf_file.py",
"resources/undeclared_attr_namespace.xml",
"resources/undeclared_namespace.xml",
+ "resources/xml_infer_mixed.xml",
+ "resources/xml_infer_types.xml",
"resources/xxe.xml",
],
)
diff --git a/tests/unit/test_xml_reader.py b/tests/unit/test_xml_reader.py
index 85326d87d8..444a65ce95 100644
--- a/tests/unit/test_xml_reader.py
+++ b/tests/unit/test_xml_reader.py
@@ -28,6 +28,9 @@
_escape_colons_in_quotes,
_restore_colons_in_template,
_COLON_PLACEHOLDER,
+ _can_cast_to_type,
+ _validate_row_for_type_mismatch,
+ XMLReader,
)
from snowflake.snowpark.types import (
StructType,
@@ -37,6 +40,9 @@
DateType,
ArrayType,
MapType,
+ LongType,
+ BooleanType,
+ TimestampType,
)
@@ -919,7 +925,8 @@ def test_schema_string_to_result_dict_and_struct_type(session):
)
attr, _, _ = session.read._get_schema_from_user_input(user_schema)
schema_string = attribute_to_schema_string_deep(attr)
- assert schema_string_to_result_dict_and_struct_type(schema_string) == {
+ template, schema_type = schema_string_to_result_dict_and_struct_type(schema_string)
+ assert template == {
"Author": None,
"TITLE": None,
"GENRE": None,
@@ -928,6 +935,7 @@ def test_schema_string_to_result_dict_and_struct_type(session):
"description": None,
"map_type": None,
}
+ assert isinstance(schema_type, StructType)
def test_user_schema_value_tag():
@@ -1048,7 +1056,8 @@ def test_restore_colons_in_template():
def test_schema_string_round_trip_with_colons():
# Flat
schema_str = 'struct<"px:name": string, "px:value": string>'
- assert schema_string_to_result_dict_and_struct_type(schema_str) == {
+ template, _ = schema_string_to_result_dict_and_struct_type(schema_str)
+ assert template == {
"px:name": None,
"px:value": None,
}
@@ -1057,14 +1066,16 @@ def test_schema_string_round_trip_with_colons():
schema_str = (
'struct<"eq:event-id": string,' '"eq:detail": struct<"eq:sub-id": string>>'
)
- assert schema_string_to_result_dict_and_struct_type(schema_str) == {
+ template, _ = schema_string_to_result_dict_and_struct_type(schema_str)
+ assert template == {
"eq:event-id": None,
"eq:detail": {"eq:sub-id": None},
}
# Mixed
schema_str = 'struct<"px:name": string, "Title": string, price: double>'
- assert schema_string_to_result_dict_and_struct_type(schema_str) == {
+ template, _ = schema_string_to_result_dict_and_struct_type(schema_str)
+ assert template == {
"px:name": None,
"Title": None,
"PRICE": None,
@@ -1072,7 +1083,267 @@ def test_schema_string_round_trip_with_colons():
# No colon fields
schema_str = 'struct<"Author": string, "TITLE": string>'
- assert schema_string_to_result_dict_and_struct_type(schema_str) == {
+ template, _ = schema_string_to_result_dict_and_struct_type(schema_str)
+ assert template == {
"Author": None,
"TITLE": None,
}
+
+
+@pytest.mark.parametrize(
+ "value, target_type, expected",
+ [
+ # StringType always passes
+ ("anything", StringType(), True),
+ ("", StringType(), True),
+ # LongType
+ ("42", LongType(), True),
+ ("-7", LongType(), True),
+ ("0", LongType(), True),
+ ("3.14", LongType(), False),
+ ("hello", LongType(), False),
+ ("", LongType(), False),
+ # DoubleType
+ ("3.14", DoubleType(), True),
+ ("-0.5", DoubleType(), True),
+ ("42", DoubleType(), True),
+ ("NaN", DoubleType(), True), # Python float("NaN") succeeds
+ ("hello", DoubleType(), False),
+ # BooleanType
+ ("true", BooleanType(), True),
+ ("false", BooleanType(), True),
+ ("True", BooleanType(), True),
+ ("1", BooleanType(), True),
+ ("0", BooleanType(), True),
+ ("yes", BooleanType(), False),
+ ("maybe", BooleanType(), False),
+ # DateType
+ ("2024-01-15", DateType(), True),
+ ("not-a-date", DateType(), False),
+ ("2024-13-01", DateType(), False),
+ # TimestampType
+ ("2024-01-15T10:30:00", TimestampType(), True),
+ ("2024-01-15", TimestampType(), True),
+ ("not-a-ts", TimestampType(), False),
+ ],
+)
+def test_can_cast_to_type(value, target_type, expected):
+ assert _can_cast_to_type(value, target_type) == expected
+
+
+# ---------------------------------------------------------------------------
+# _validate_row_for_type_mismatch tests
+# ---------------------------------------------------------------------------
+
+
+def _make_schema(*fields):
+ """Helper to build a StructType from (name, datatype) tuples."""
+ return StructType([StructField(f'"{n}"', t) for n, t in fields])
+
+
+def test_validate_permissive_all_valid_no_corrupt():
+ schema = _make_schema(("name", StringType()), ("age", LongType()))
+ row = {"name": "Alice", "age": "30"}
+ result = _validate_row_for_type_mismatch(row, schema, "PERMISSIVE", "
")
+ assert result["name"] == "Alice"
+ assert result["age"] == "30"
+ assert "_corrupt_record" not in result
+
+
+def test_validate_permissive_single_field_nulled():
+ schema = _make_schema(("name", StringType()), ("age", LongType()))
+ row = {"name": "Bob", "age": "hello"}
+ result = _validate_row_for_type_mismatch(row, schema, "PERMISSIVE", "
")
+ assert result["name"] == "Bob"
+ assert result["age"] is None
+ assert result["_corrupt_record"] == "
"
+
+
+def test_validate_permissive_multiple_fields_nulled():
+ schema = _make_schema(("a", LongType()), ("b", BooleanType()), ("c", DoubleType()))
+ row = {"a": "not_int", "b": "maybe", "c": "1.5"}
+ result = _validate_row_for_type_mismatch(row, schema, "PERMISSIVE", "")
+ assert result["a"] is None
+ assert result["b"] is None
+ assert result["c"] == "1.5"
+ assert result["_corrupt_record"] == ""
+
+
+def test_validate_permissive_missing_field_ignored():
+ schema = _make_schema(("name", StringType()), ("age", LongType()))
+ row = {"name": "Carol"}
+ result = _validate_row_for_type_mismatch(row, schema, "PERMISSIVE", "
")
+ assert result["name"] == "Carol"
+ assert "_corrupt_record" not in result
+
+
+def test_validate_permissive_none_value_ignored():
+ schema = _make_schema(("val", LongType()))
+ row = {"val": None}
+ result = _validate_row_for_type_mismatch(row, schema, "PERMISSIVE", "
")
+ assert result["val"] is None
+ assert "_corrupt_record" not in result
+
+
+def test_validate_permissive_complex_type_skipped():
+ schema = _make_schema(
+ ("data", StructType([StructField("x", StringType())])),
+ ("tags", ArrayType(StringType())),
+ )
+ row = {"data": {"x": "val"}, "tags": ["a", "b"]}
+ result = _validate_row_for_type_mismatch(row, schema, "PERMISSIVE", "
")
+ assert result["data"] == {"x": "val"}
+ assert result["tags"] == ["a", "b"]
+ assert "_corrupt_record" not in result
+
+
+def test_validate_permissive_custom_corrupt_col_name():
+ schema = _make_schema(("val", LongType()))
+ row = {"val": "bad"}
+ result = _validate_row_for_type_mismatch(
+ row, schema, "PERMISSIVE", "", column_name_of_corrupt_record="bad_rec"
+ )
+ assert result["val"] is None
+ assert result["bad_rec"] == ""
+ assert "_corrupt_record" not in result
+
+
+def test_validate_failfast_raises_on_mismatch():
+ schema = _make_schema(("val", LongType()))
+ row = {"val": "hello"}
+ with pytest.raises(RuntimeError, match="Failed to cast value 'hello'"):
+ _validate_row_for_type_mismatch(row, schema, "FAILFAST", "
")
+
+
+def test_validate_failfast_no_error_when_valid():
+ schema = _make_schema(("val", LongType()))
+ row = {"val": "42"}
+ result = _validate_row_for_type_mismatch(row, schema, "FAILFAST", "
")
+ assert result["val"] == "42"
+
+
+def test_validate_dropmalformed_returns_none_on_mismatch():
+ schema = _make_schema(("val", LongType()))
+ row = {"val": "hello"}
+ assert (
+ _validate_row_for_type_mismatch(row, schema, "DROPMALFORMED", "
") is None
+ )
+
+
+def test_validate_dropmalformed_returns_row_when_valid():
+ schema = _make_schema(("val", LongType()))
+ row = {"val": "42"}
+ result = _validate_row_for_type_mismatch(row, schema, "DROPMALFORMED", "
")
+ assert result["val"] == "42"
+
+
+def test_schema_string_to_result_dict_empty_string():
+ template, schema_type = schema_string_to_result_dict_and_struct_type("")
+ assert template is None
+ assert schema_type is None
+
+
+def test_can_cast_to_complex_type_returns_true():
+ assert _can_cast_to_type("anything", ArrayType(StringType())) is True
+ assert _can_cast_to_type("anything", MapType(StringType(), StringType())) is True
+ assert _can_cast_to_type("anything", StructType([])) is True
+
+
+def test_element_to_dict_leaf_with_template_edge_cases():
+ xml_str = ""
+ element = ET.fromstring(xml_str)
+ result_template = {"_VALUE": None, "_country": None, "_language": None}
+ result = element_to_dict_or_str(element, result_template=result_template)
+ assert isinstance(result, dict)
+ assert result["_VALUE"] is None
+ assert result["_country"] is None
+ assert result["_language"] is None
+
+ xml_str = "Penguin"
+ element = ET.fromstring(xml_str)
+ result = element_to_dict_or_str(element, result_template=result_template)
+ assert isinstance(result, dict)
+ assert result["_VALUE"] == "Penguin"
+ assert result["_country"] is None
+ assert result["_language"] is None
+
+ xml_str = "N/A"
+ element = ET.fromstring(xml_str)
+ result_template = {"_VALUE": None, "_country": None}
+ result = element_to_dict_or_str(
+ element, result_template=result_template, null_value="N/A"
+ )
+ assert isinstance(result, dict)
+ assert result["_VALUE"] is None
+
+
+def test_process_xml_range_scos_permissive_type_validation():
+ """process_xml_range validates types when is_snowpark_connect_compatible=True"""
+ xml_content = (
+ ""
+ "Alice100
"
+ "Frankhello
"
+ ""
+ )
+ xml_bytes = xml_content.encode("utf-8")
+ schema = _make_schema(("name", StringType()), ("value", LongType()))
+ mock_file = io.BytesIO(xml_bytes)
+ with patch("snowflake.snowpark.files.SnowflakeFile.open", return_value=mock_file):
+ results = list(
+ process_xml_range(
+ file_path="test.xml",
+ tag_name="ROW",
+ approx_start=0,
+ approx_end=len(xml_bytes),
+ mode="PERMISSIVE",
+ column_name_of_corrupt_record="_corrupt_record",
+ ignore_namespace=True,
+ attribute_prefix="_",
+ exclude_attributes=False,
+ value_tag="_VALUE",
+ null_value="",
+ charset="utf-8",
+ ignore_surrounding_whitespace=True,
+ row_validation_xsd_path="",
+ result_template={"name": None, "value": None},
+ schema_type=schema,
+ is_snowpark_connect_compatible=True,
+ )
+ )
+ assert len(results) == 2
+ frank = [r for r in results if r.get("name") == "Frank"][0]
+ assert frank["value"] is None
+ assert "_corrupt_record" in frank
+
+
+def test_xml_reader_process_with_scos_compatible_param():
+ """XMLReader.process passes is_snowpark_connect_compatible through"""
+ xml_content = "1"
+ xml_bytes = xml_content.encode("utf-8")
+ mock_file = io.BytesIO(xml_bytes)
+ with patch(
+ "snowflake.snowpark._internal.xml_reader.get_file_size",
+ return_value=len(xml_bytes),
+ ), patch("snowflake.snowpark.files.SnowflakeFile.open", return_value=mock_file):
+ results = list(
+ XMLReader().process(
+ "test.xml",
+ 1,
+ "record",
+ 0,
+ "PERMISSIVE",
+ "_corrupt_record",
+ True,
+ "_",
+ False,
+ "_VALUE",
+ "",
+ "utf-8",
+ True,
+ "",
+ "",
+ True,
+ )
+ )
+ assert len(results) == 1
+ assert results[0][0]["a"] == "1"
diff --git a/tests/unit/test_xml_schema_inference.py b/tests/unit/test_xml_schema_inference.py
new file mode 100644
index 0000000000..5267cbdf5f
--- /dev/null
+++ b/tests/unit/test_xml_schema_inference.py
@@ -0,0 +1,1893 @@
+#
+# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
+#
+
+import io
+from contextlib import ExitStack
+from unittest import mock
+from unittest.mock import patch
+
+import lxml.etree as ET
+import pytest
+
+from snowflake.snowpark._internal.xml_schema_inference import (
+ _normalize_text,
+ infer_type,
+ infer_element_schema,
+ compatible_type,
+ merge_struct_types,
+ add_or_update_type,
+ canonicalize_type,
+ _case_preserving_simple_string,
+ _struct_has_value_tag,
+ _merge_struct_with_primitive,
+ infer_schema_for_xml_range,
+ XMLSchemaInference,
+)
+from snowflake.snowpark import DataFrameReader
+import snowflake.snowpark.dataframe_reader as _dr_mod
+from snowflake.snowpark.types import (
+ StructType,
+ StructField,
+ ArrayType,
+ MapType,
+ NullType,
+ StringType,
+ BooleanType,
+ LongType,
+ DoubleType,
+ DecimalType,
+ DateType,
+ TimestampType,
+)
+
+
+def _xml(xml_str: str) -> ET.Element:
+ """Parse an XML string into an lxml Element."""
+ return ET.fromstring(xml_str)
+
+
+def _infer_and_merge(xml_records, **kwargs):
+ """Infer schema for each record and merge, mimicking the UDTF behavior."""
+ merged = None
+ value_tag = kwargs.get("value_tag", "_VALUE")
+ for rec_str in xml_records:
+ rec_str = rec_str.strip()
+ if not rec_str:
+ continue
+ try:
+ elem = _xml(rec_str)
+ except Exception:
+ continue
+ schema = infer_element_schema(elem, **kwargs)
+ if not isinstance(schema, StructType):
+ schema = StructType(
+ [StructField(value_tag, schema, nullable=True)],
+ structured=False,
+ )
+ if merged is None:
+ merged = schema
+ else:
+ merged = merge_struct_types(merged, schema, value_tag)
+ return merged
+
+
+def _full_pipeline(xml_records, **kwargs):
+ """Run the full schema inference pipeline: infer + merge + canonicalize."""
+ merged = _infer_and_merge(xml_records, **kwargs)
+ if merged is not None:
+ merged = canonicalize_type(merged)
+ return merged
+
+
+# ===========================================================================
+# _normalize_text
+# ===========================================================================
+
+
+@pytest.mark.parametrize(
+ "text, strip, expected",
+ [
+ (None, True, None),
+ (None, False, None),
+ (" hello ", True, "hello"),
+ ("\n\thello\n\t", True, "hello"),
+ (" hello ", False, " hello "),
+ ("", True, ""),
+ ("", False, ""),
+ (" ", True, ""),
+ (" ", False, " "),
+ ],
+)
+def test_normalize_text(text, strip, expected):
+ assert _normalize_text(text, strip) == expected
+
+
+# ===========================================================================
+# infer_type
+# ===========================================================================
+
+
+@pytest.mark.parametrize(
+ "value, kwargs",
+ [
+ (None, {}),
+ ("", {}),
+ ("NULL", {"null_value": "NULL"}),
+ (" NULL ", {"ignore_surrounding_whitespace": True, "null_value": "NULL"}),
+ ],
+)
+def test_infer_type_null(value, kwargs):
+ """Null/empty/null_value inputs → NullType."""
+ assert isinstance(infer_type(value, **kwargs), NullType)
+
+
+@pytest.mark.parametrize(
+ "value",
+ [
+ "42",
+ "-42",
+ "0",
+ str(2**63 - 1),
+ str(-(2**63)),
+ "+123",
+ "1",
+ "+0",
+ "-0",
+ "007",
+ ],
+)
+def test_infer_type_long(value):
+ """Integer strings within Long range → LongType."""
+ assert isinstance(infer_type(value), LongType)
+
+
+@pytest.mark.parametrize(
+ "value",
+ [
+ "3.14",
+ "1.5e10",
+ "-2.5",
+ ".5",
+ "0.0",
+ "+3.14",
+ "NaN",
+ "Infinity",
+ "+Infinity",
+ "-Infinity",
+ "92233720368547758070",
+ "-92233720368547758080",
+ "1.7976931348623157e+308",
+ ],
+)
+def test_infer_type_double(value):
+ """Decimals, scientific notation, special floats, and beyond-Long ints → DoubleType."""
+ assert isinstance(infer_type(value), DoubleType)
+
+
+@pytest.mark.parametrize("value", ["true", "false", "True", "FALSE"])
+def test_infer_type_boolean(value):
+ assert isinstance(infer_type(value), BooleanType)
+
+
+@pytest.mark.parametrize("value", ["2024-01-15", "2000-12-31", "1999-01-01"])
+def test_infer_type_date(value):
+ assert isinstance(infer_type(value), DateType)
+
+
+@pytest.mark.parametrize(
+ "value",
+ ["2024-01-15T10:30:00", "2024-01-15T10:30:00+00:00", "2011-12-03T10:15:30Z"],
+)
+def test_infer_type_timestamp(value):
+ assert isinstance(infer_type(value), TimestampType)
+
+
+@pytest.mark.parametrize(
+ "value",
+ [
+ "hello",
+ "abc123",
+ "foo@bar.com",
+ "inf",
+ "Inf",
+ "+inf",
+ "+Inf",
+ "1.5d",
+ "1.5f",
+ "1.5D",
+ "1.5F",
+ ],
+)
+def test_infer_type_string(value):
+ """Plain strings and values rejected by d/D/f/F suffix check → StringType."""
+ assert isinstance(infer_type(value), StringType)
+
+
+def test_infer_type_whitespace_stripped():
+ """With ignore_surrounding_whitespace, ' 42 ' is parsed as Long."""
+ assert isinstance(
+ infer_type(" 42 ", ignore_surrounding_whitespace=True), LongType
+ )
+ assert isinstance(
+ infer_type(" 42 ", ignore_surrounding_whitespace=False), StringType
+ )
+
+
+# ===========================================================================
+# infer_element_schema – leaf elements
+# ===========================================================================
+
+
+@pytest.mark.parametrize(
+ "xml_str, expected_type, kwargs",
+ [
+ ("42", LongType, {}),
+ ("", NullType, {}),
+ ("", NullType, {}),
+ ("N/A", NullType, {"null_value": "N/A"}),
+ ("true", BooleanType, {}),
+ ("false", BooleanType, {}),
+ ("3.14", DoubleType, {}),
+ ("hello", StringType, {}),
+ ("2024-01-15", DateType, {}),
+ ("2024-01-15T10:30:00", TimestampType, {}),
+ ("2024-01-15T10:30:00Z", TimestampType, {}),
+ ("NaN", DoubleType, {}),
+ ("92233720368547758070", DoubleType, {}),
+ ],
+)
+def test_infer_element_schema_leaf(xml_str, expected_type, kwargs):
+ """Leaf elements (no children, no attributes) infer scalar types."""
+ assert isinstance(infer_element_schema(_xml(xml_str), **kwargs), expected_type)
+
+
+# ===========================================================================
+# infer_element_schema – attributes
+# ===========================================================================
+
+
+def test_infer_element_schema_attributes_only_root():
+ """Root-level attribute-only element has no _VALUE."""
+ result = infer_element_schema(_xml(''), is_root=True)
+ assert isinstance(result, StructType)
+ field_names = {f._name for f in result.fields}
+ assert "_name" in field_names and "_age" in field_names
+ assert "_VALUE" not in field_names
+
+
+def test_infer_element_schema_attributes_only_child():
+ """Child-level attribute-only element adds _VALUE as NullType."""
+ result = infer_element_schema(_xml(''), is_root=False)
+ field_names = {f._name for f in result.fields}
+ assert "_name" in field_names and "_VALUE" in field_names
+
+
+def test_infer_element_schema_attributes_with_text():
+ """Element with attributes and text adds _VALUE."""
+ result = infer_element_schema(_xml('hello'), is_root=True)
+ field_names = {f._name for f in result.fields}
+ assert "_name" in field_names and "_VALUE" in field_names
+
+
+def test_infer_element_schema_custom_attribute_prefix():
+ result = infer_element_schema(
+ _xml(''), attribute_prefix="@", is_root=True
+ )
+ assert "@name" in {f._name for f in result.fields}
+
+
+def test_infer_element_schema_exclude_attributes():
+ """When exclude_attributes=True, attributes are ignored."""
+ assert isinstance(
+ infer_element_schema(
+ _xml('hello'), exclude_attributes=True
+ ),
+ StringType,
+ )
+ assert isinstance(
+ infer_element_schema(_xml(''), exclude_attributes=True),
+ NullType,
+ )
+
+
+# ===========================================================================
+# infer_element_schema – children
+# ===========================================================================
+
+
+def test_infer_element_schema_simple_children():
+ result = infer_element_schema(_xml("1hello
"))
+ assert isinstance(result, StructType)
+ assert {"a", "b"} <= {f._name for f in result.fields}
+
+
+def test_infer_element_schema_child_types():
+ result = infer_element_schema(_xml("42true3.14
"))
+ field_map = {f._name: f.datatype for f in result.fields}
+ assert isinstance(field_map["a"], LongType)
+ assert isinstance(field_map["b"], BooleanType)
+ assert isinstance(field_map["c"], DoubleType)
+
+
+def test_infer_element_schema_repeated_children_array():
+ """Repeated child tags → ArrayType."""
+ result = infer_element_schema(
+ _xml("- 1
- 2
- 3
")
+ )
+ field_map = {f._name: f.datatype for f in result.fields}
+ assert isinstance(field_map["item"], ArrayType)
+ assert isinstance(field_map["item"].element_type, LongType)
+
+
+def test_infer_element_schema_nested_children():
+ result = infer_element_schema(
+ _xml("hello
")
+ )
+ outer_field = next(f for f in result.fields if f._name == "outer")
+ assert isinstance(outer_field.datatype, StructType)
+ inner_field = next(f for f in outer_field.datatype.fields if f._name == "inner")
+ assert isinstance(inner_field.datatype, StringType)
+
+
+def test_infer_element_schema_child_with_attributes():
+ result = infer_element_schema(_xml('text
'))
+ child_field = next(f for f in result.fields if f._name == "child")
+ assert isinstance(child_field.datatype, StructType)
+ child_names = {f._name for f in child_field.datatype.fields}
+ assert "_id" in child_names and "_VALUE" in child_names
+
+
+# ===========================================================================
+# infer_element_schema – mixed content
+# ===========================================================================
+
+
+def test_infer_element_schema_mixed_content_value_tag():
+ """Element with text and children: text goes into _VALUE."""
+ result = infer_element_schema(_xml("some text1
"))
+ field_names = {f._name for f in result.fields}
+ assert "_VALUE" in field_names and "a" in field_names
+
+
+def test_infer_element_schema_whitespace_only_no_value_tag():
+ """Whitespace-only text between children does not create _VALUE."""
+ result = infer_element_schema(_xml("\n 1\n 2\n
"))
+ assert "_VALUE" not in {f._name for f in result.fields}
+
+
+# ===========================================================================
+# infer_element_schema – namespace handling
+# ===========================================================================
+
+
+def test_infer_element_schema_namespace_clark():
+ """Clark notation {uri}tag → strip namespace part."""
+ from snowflake.snowpark._internal.xml_reader import strip_xml_namespaces
+
+ elem = strip_xml_namespaces(
+ ET.fromstring('val
')
+ )
+ result = infer_element_schema(elem, ignore_namespace=True)
+ assert "child" in {f._name for f in result.fields}
+
+
+def test_infer_element_schema_namespace_prefix():
+ parser = ET.XMLParser(recover=True, ns_clean=True)
+ elem = ET.fromstring("val", parser)
+ result = infer_element_schema(elem, ignore_namespace=True)
+ assert "child" in {f._name for f in result.fields}
+
+
+# ===========================================================================
+# infer_element_schema – whitespace and sorting
+# ===========================================================================
+
+
+def test_infer_element_schema_whitespace_leaf():
+ assert isinstance(
+ infer_element_schema(_xml(" 42 "), ignore_surrounding_whitespace=True),
+ LongType,
+ )
+ assert isinstance(
+ infer_element_schema(
+ _xml(" 42 "), ignore_surrounding_whitespace=False
+ ),
+ StringType,
+ )
+
+
+def test_infer_element_schema_fields_sorted():
+ result = infer_element_schema(_xml("123
"))
+ names = [f._name for f in result.fields]
+ assert names == sorted(names)
+
+
+def test_infer_element_schema_custom_value_tag():
+ result = infer_element_schema(_xml('text
'), value_tag="myval")
+ field_names = {f._name for f in result.fields}
+ assert "myval" in field_names and "_VALUE" not in field_names
+
+
+# ===========================================================================
+# Spark TestXmlData-derived tests (multi-record infer + merge)
+# ===========================================================================
+
+
+def test_spark_primitive_field_value_type_conflict():
+ """Merging records with conflicting types for same fields."""
+ records = [
+ """
+ 111.1
+ true13.1str1
+
""",
+ """
+ 21474836470.9
+ 12true
+
""",
+ """
+ 2147483647092233720368547758070
+ 100false
+ str1false
+
""",
+ """
+ 214748365701.1
+ 21474836470
+ 92233720368547758070
+
""",
+ ]
+ merged = _infer_and_merge(records)
+ fm = {f._name: f.datatype for f in canonicalize_type(merged).fields}
+ assert isinstance(fm["num_bool"], StringType)
+ assert isinstance(fm["num_num_1"], LongType)
+ assert isinstance(fm["num_num_2"], DoubleType)
+ assert isinstance(fm["num_num_3"], DoubleType)
+ assert isinstance(fm["num_str"], StringType)
+ assert isinstance(fm["str_bool"], StringType)
+
+
+def test_spark_complex_field_value_type_conflict():
+ """Merging arrays with singletons and struct/primitive conflicts."""
+ records = [
+ """
+ 11
+ 123
+
+
""",
+ """
+ false
+
+
""",
+ """
+
+ str
+ 456
+ 78
+ 9
+
+
""",
+ """
+
+ str1str233
+ 7
+ true
+ str
+
""",
+ ]
+ merged = _infer_and_merge(
+ records, ignore_surrounding_whitespace=True, null_value=""
+ )
+ fm = {f._name: f.datatype for f in canonicalize_type(merged).fields}
+ assert isinstance(fm["array"], ArrayType) and isinstance(
+ fm["array"].element_type, LongType
+ )
+ assert isinstance(fm["num_struct"], StringType)
+ assert isinstance(fm["str_array"], ArrayType) and isinstance(
+ fm["str_array"].element_type, StringType
+ )
+ assert isinstance(fm["struct"], StructType)
+ assert isinstance(fm["struct_array"], ArrayType) and isinstance(
+ fm["struct_array"].element_type, StringType
+ )
+
+
+def test_spark_missing_fields():
+ """Different records have different fields; merge produces superset."""
+ records = [
+ "true
",
+ "21474836470
",
+ "3344
",
+ "true
",
+ "str
",
+ ]
+ fm = {f._name: f.datatype for f in _full_pipeline(records).fields}
+ assert isinstance(fm["a"], BooleanType)
+ assert isinstance(fm["b"], LongType)
+ assert isinstance(fm["c"], ArrayType) and isinstance(fm["c"].element_type, LongType)
+ assert isinstance(fm["d"], StructType)
+ assert isinstance(fm["e"], StringType)
+
+
+def test_spark_null_struct():
+ """Null/empty struct fields merged correctly."""
+ records = [
+ """27.31.100.29
+ 1.abc.comUTF-8
""",
+ "27.31.100.29
",
+ "27.31.100.29
",
+ "27.31.100.29
",
+ ]
+ fm = {f._name: f.datatype for f in _full_pipeline(records).fields}
+ assert isinstance(fm["nullstr"], StringType)
+ assert isinstance(fm["ip"], StringType)
+ assert isinstance(fm["headers"], StructType)
+ hf = {f._name: f.datatype for f in fm["headers"].fields}
+ assert isinstance(hf["Host"], StringType) and isinstance(hf["Charset"], StringType)
+
+
+def test_spark_empty_records():
+ """Structs with empty/null children properly handled."""
+ records = [
+ "
",
+ "
",
+ "
",
+ ]
+ fm = {f._name: f.datatype for f in _full_pipeline(records).fields}
+ assert isinstance(fm["a"], StructType)
+ a_struct = {f._name: f.datatype for f in fm["a"].fields}["struct"]
+ b_c = {
+ f._name: f.datatype
+ for f in {f._name: f.datatype for f in a_struct.fields}["b"].fields
+ }["c"]
+ assert isinstance(b_c, StringType)
+ assert isinstance(fm["b"], StructType)
+ assert isinstance({f._name: f.datatype for f in fm["b"].fields}["item"], ArrayType)
+
+
+def test_spark_nulls_in_arrays():
+ """Null entries in arrays don't disrupt type inference."""
+ records = [
+ """value1value2
+
""",
+ """1
+
""",
+ "
",
+ ]
+ fm = {f._name: f.datatype for f in _full_pipeline(records).fields}
+ assert isinstance(fm["field1"], ArrayType)
+ assert isinstance(fm["field2"], ArrayType)
+
+
+def test_spark_value_tags_type_conflict():
+ """Mixed text and children with type conflicts in _VALUE.
+
+ lxml element.text only captures text BEFORE the first child.
+ Tail text is not captured, diverging from Spark's StAX parser.
+ """
+ records = [
+ "\n13.1\n\n11\n\ntrue\n1
",
+ "\nstring\n\n21474836470\n\nfalse\n2
",
+ "\n12\n3
",
+ ]
+ fm = {
+ f._name: f.datatype
+ for f in _full_pipeline(records, ignore_surrounding_whitespace=True).fields
+ }
+ assert isinstance(fm["_VALUE"], StringType)
+ assert isinstance(fm["a"], StructType)
+ af = {f._name: f.datatype for f in fm["a"].fields}
+ assert isinstance(af["_VALUE"], LongType)
+ bf = {f._name: f.datatype for f in af["b"].fields}
+ assert isinstance(bf["_VALUE"], StringType)
+ assert isinstance(bf["c"], LongType)
+
+
+def test_spark_value_tag_conflict_name():
+ """When valueTag="a" conflicts with child , they merge.
+
+ lxml element.text only captures text BEFORE the first child.
+ """
+ records_before = ["\n2\n1
"]
+ fm = {
+ f._name: f.datatype
+ for f in _full_pipeline(
+ records_before, value_tag="a", ignore_surrounding_whitespace=True
+ ).fields
+ }
+ assert isinstance(fm["a"], ArrayType) and isinstance(fm["a"].element_type, LongType)
+
+ records_after = ["1\n2\n
"]
+ fm2 = {
+ f._name: f.datatype
+ for f in _full_pipeline(
+ records_after, value_tag="a", ignore_surrounding_whitespace=True
+ ).fields
+ }
+ assert isinstance(fm2["a"], LongType)
+
+
+# ===========================================================================
+# compatible_type
+# ===========================================================================
+
+
+@pytest.mark.parametrize(
+ "t1, t2, expected_type",
+ [
+ (LongType(), LongType(), LongType),
+ (StringType(), StringType(), StringType),
+ (BooleanType(), BooleanType(), BooleanType),
+ (DoubleType(), DoubleType(), DoubleType),
+ (DateType(), DateType(), DateType),
+ (TimestampType(), TimestampType(), TimestampType),
+ (NullType(), NullType(), NullType),
+ ],
+)
+def test_compatible_type_same(t1, t2, expected_type):
+ """Same types stay the same."""
+ assert isinstance(compatible_type(t1, t2), expected_type)
+
+
+@pytest.mark.parametrize(
+ "t1, t2, expected_type",
+ [
+ (NullType(), LongType(), LongType),
+ (LongType(), NullType(), LongType),
+ (NullType(), StringType(), StringType),
+ ],
+)
+def test_compatible_type_null_plus_t(t1, t2, expected_type):
+ """NullType + T → T."""
+ assert isinstance(compatible_type(t1, t2), expected_type)
+
+
+def test_compatible_type_null_plus_struct():
+ s = StructType([StructField("a", LongType(), True)])
+ assert isinstance(compatible_type(NullType(), s), StructType)
+
+
+def test_compatible_type_null_plus_array():
+ a = ArrayType(LongType(), structured=False)
+ assert isinstance(compatible_type(NullType(), a), ArrayType)
+
+
+@pytest.mark.parametrize(
+ "t1, t2",
+ [(LongType(), DoubleType()), (DoubleType(), LongType())],
+)
+def test_compatible_type_long_double_widening(t1, t2):
+ """Long + Double → Double."""
+ assert isinstance(compatible_type(t1, t2), DoubleType)
+
+
+@pytest.mark.parametrize(
+ "t1, t2",
+ [(DoubleType(), DecimalType(10, 2)), (DecimalType(10, 2), DoubleType())],
+)
+def test_compatible_type_double_decimal(t1, t2):
+ """Double + Decimal → Double."""
+ assert isinstance(compatible_type(t1, t2), DoubleType)
+
+
+@pytest.mark.parametrize(
+ "t1, t2",
+ [(TimestampType(), DateType()), (DateType(), TimestampType())],
+)
+def test_compatible_type_timestamp_date(t1, t2):
+ """Timestamp + Date → Timestamp."""
+ assert isinstance(compatible_type(t1, t2), TimestampType)
+
+
+def test_compatible_type_long_plus_decimal():
+ result = compatible_type(LongType(), DecimalType(10, 2))
+ assert (
+ isinstance(result, DecimalType) and result.precision == 22 and result.scale == 2
+ )
+
+ result2 = compatible_type(DecimalType(10, 2), LongType())
+ assert (
+ isinstance(result2, DecimalType)
+ and result2.precision == 22
+ and result2.scale == 2
+ )
+
+
+def test_compatible_type_same_struct_merge():
+ s1 = StructType(
+ [StructField("a", LongType(), True), StructField("b", StringType(), True)]
+ )
+ s2 = StructType(
+ [StructField("a", DoubleType(), True), StructField("c", BooleanType(), True)]
+ )
+ result = compatible_type(s1, s2)
+ fm = {f._name: f.datatype for f in result.fields}
+ assert (
+ isinstance(fm["a"], DoubleType)
+ and isinstance(fm["b"], StringType)
+ and isinstance(fm["c"], BooleanType)
+ )
+
+
+def test_compatible_type_same_array_merge():
+ result = compatible_type(
+ ArrayType(LongType(), structured=False),
+ ArrayType(DoubleType(), structured=False),
+ )
+ assert isinstance(result, ArrayType) and isinstance(result.element_type, DoubleType)
+
+
+def test_compatible_type_decimal_widen():
+ result = compatible_type(DecimalType(10, 2), DecimalType(15, 5))
+ assert (
+ isinstance(result, DecimalType) and result.precision == 15 and result.scale == 5
+ )
+
+
+def test_compatible_type_decimal_overflow_to_double():
+ """Decimal with precision > 38 falls back to Double."""
+ result = compatible_type(DecimalType(38, 10), DecimalType(38, 20))
+ assert isinstance(result, DoubleType)
+
+
+def test_compatible_type_array_plus_primitive():
+ result = compatible_type(ArrayType(LongType(), structured=False), DoubleType())
+ assert isinstance(result, ArrayType) and isinstance(result.element_type, DoubleType)
+
+ result2 = compatible_type(LongType(), ArrayType(StringType(), structured=False))
+ assert isinstance(result2, ArrayType) and isinstance(
+ result2.element_type, StringType
+ )
+
+
+def test_compatible_type_struct_value_tag_plus_primitive():
+ s = StructType(
+ [
+ StructField("_VALUE", LongType(), True),
+ StructField("_attr", StringType(), True),
+ ]
+ )
+ result = compatible_type(s, DoubleType())
+ fm = {f._name: f.datatype for f in result.fields}
+ assert isinstance(fm["_VALUE"], DoubleType) and isinstance(fm["_attr"], StringType)
+
+ s2 = StructType(
+ [
+ StructField("_VALUE", StringType(), True),
+ StructField("_id", LongType(), True),
+ ]
+ )
+ result2 = compatible_type(BooleanType(), s2)
+ assert isinstance(
+ {f._name: f.datatype for f in result2.fields}["_VALUE"], StringType
+ )
+
+
+@pytest.mark.parametrize(
+ "t1, t2",
+ [
+ (BooleanType(), LongType()),
+ (StringType(), LongType()),
+ (BooleanType(), DoubleType()),
+ (DateType(), LongType()),
+ (StringType(), BooleanType()),
+ (DateType(), DoubleType()),
+ (BooleanType(), DateType()),
+ (StringType(), DateType()),
+ (LongType(), BooleanType()),
+ ],
+)
+def test_compatible_type_fallback_string(t1, t2):
+ """Incompatible types → StringType."""
+ assert isinstance(compatible_type(t1, t2), StringType)
+
+
+def test_compatible_type_struct_no_value_tag_plus_primitive():
+ """Struct without _VALUE + primitive → StringType fallback."""
+ s = StructType([StructField("a", LongType(), True)])
+ assert isinstance(compatible_type(s, LongType()), StringType)
+
+
+# ===========================================================================
+# _struct_has_value_tag
+# ===========================================================================
+
+
+def test_struct_has_value_tag():
+ s_yes = StructType(
+ [StructField("_VALUE", LongType(), True), StructField("a", StringType(), True)]
+ )
+ assert _struct_has_value_tag(s_yes, "_VALUE") is True
+
+ s_no = StructType([StructField("a", LongType(), True)])
+ assert _struct_has_value_tag(s_no, "_VALUE") is False
+
+ s_custom = StructType([StructField("myval", LongType(), True)])
+ assert _struct_has_value_tag(s_custom, "MYVAL") is True
+ assert _struct_has_value_tag(s_custom, "myval") is False
+ assert _struct_has_value_tag(s_custom, "_VALUE") is False
+
+
+# ===========================================================================
+# _merge_struct_with_primitive
+# ===========================================================================
+
+
+def test_merge_struct_with_primitive():
+ s = StructType(
+ [
+ StructField("_VALUE", LongType(), True),
+ StructField("_attr", StringType(), True),
+ ]
+ )
+ result = _merge_struct_with_primitive(s, DoubleType(), "_VALUE")
+ fm = {f._name: f.datatype for f in result.fields}
+ assert isinstance(fm["_VALUE"], DoubleType) and isinstance(fm["_attr"], StringType)
+
+ s2 = StructType(
+ [
+ StructField("_VALUE", StringType(), True),
+ StructField("other", LongType(), True),
+ ]
+ )
+ result2 = _merge_struct_with_primitive(s2, LongType(), "_VALUE")
+ fm2 = {f._name: f.datatype for f in result2.fields}
+ assert isinstance(fm2["_VALUE"], StringType) and isinstance(fm2["other"], LongType)
+
+
+# ===========================================================================
+# merge_struct_types
+# ===========================================================================
+
+
+def test_merge_struct_types_identical():
+ s = StructType(
+ [StructField("a", LongType(), True), StructField("b", StringType(), True)]
+ )
+ fm = {f._name: f.datatype for f in merge_struct_types(s, s).fields}
+ assert isinstance(fm["a"], LongType) and isinstance(fm["b"], StringType)
+
+
+def test_merge_struct_types_disjoint():
+ result = merge_struct_types(
+ StructType([StructField("a", LongType(), True)]),
+ StructType([StructField("b", StringType(), True)]),
+ )
+ fm = {f._name: f.datatype for f in result.fields}
+ assert "a" in fm and "b" in fm
+
+
+def test_merge_struct_types_widening():
+ result = merge_struct_types(
+ StructType([StructField("a", LongType(), True)]),
+ StructType([StructField("a", DoubleType(), True)]),
+ )
+ assert isinstance({f._name: f.datatype for f in result.fields}["a"], DoubleType)
+
+
+def test_merge_struct_types_overlapping_plus_unique():
+ s1 = StructType(
+ [StructField("a", LongType(), True), StructField("b", StringType(), True)]
+ )
+ s2 = StructType(
+ [StructField("a", DoubleType(), True), StructField("c", BooleanType(), True)]
+ )
+ fm = {f._name: f.datatype for f in merge_struct_types(s1, s2).fields}
+ assert (
+ isinstance(fm["a"], DoubleType)
+ and isinstance(fm["b"], StringType)
+ and isinstance(fm["c"], BooleanType)
+ )
+
+
+def test_merge_struct_types_sorted():
+ result = merge_struct_types(
+ StructType([StructField("z", LongType(), True)]),
+ StructType([StructField("a", StringType(), True)]),
+ )
+ names = [f._name for f in result.fields]
+ assert names == sorted(names)
+
+
+def test_merge_struct_types_case_sensitive():
+ """'Name' and 'name' are separate fields (case-sensitive)."""
+ result = merge_struct_types(
+ StructType([StructField("Name", LongType(), True)]),
+ StructType([StructField("name", StringType(), True)]),
+ )
+ fm = {f._name: f.datatype for f in result.fields}
+ assert isinstance(fm["Name"], LongType) and isinstance(fm["name"], StringType)
+
+
+# ===========================================================================
+# add_or_update_type
+# ===========================================================================
+
+
+def test_add_or_update_type_first_occurrence():
+ d = {}
+ add_or_update_type(d, "a", LongType())
+ assert isinstance(d["a"], LongType)
+
+
+def test_add_or_update_type_second_promotes_array():
+ d = {"a": LongType()}
+ add_or_update_type(d, "a", LongType())
+ assert isinstance(d["a"], ArrayType) and isinstance(d["a"].element_type, LongType)
+
+
+def test_add_or_update_type_second_different_types():
+ d = {"a": LongType()}
+ add_or_update_type(d, "a", DoubleType())
+ assert isinstance(d["a"], ArrayType) and isinstance(d["a"].element_type, DoubleType)
+
+
+def test_add_or_update_type_third_merges():
+ d = {"a": ArrayType(LongType(), structured=False)}
+ add_or_update_type(d, "a", DoubleType())
+ assert isinstance(d["a"], ArrayType) and isinstance(d["a"].element_type, DoubleType)
+
+
+def test_add_or_update_type_progressive():
+ """Multiple occurrences keep widening the array element type."""
+ d = {}
+ add_or_update_type(d, "a", LongType())
+ assert isinstance(d["a"], LongType)
+ add_or_update_type(d, "a", LongType())
+ assert isinstance(d["a"], ArrayType)
+ add_or_update_type(d, "a", DoubleType())
+ assert isinstance(d["a"].element_type, DoubleType)
+ add_or_update_type(d, "a", StringType())
+ assert isinstance(d["a"].element_type, StringType)
+
+
+# ===========================================================================
+# canonicalize_type
+# ===========================================================================
+
+
+@pytest.mark.parametrize(
+ "dtype, expected_type",
+ [
+ (NullType(), StringType),
+ (LongType(), LongType),
+ (StringType(), StringType),
+ (DoubleType(), DoubleType),
+ (BooleanType(), BooleanType),
+ (DateType(), DateType),
+ (TimestampType(), TimestampType),
+ (DecimalType(10, 2), DecimalType),
+ (DecimalType(38, 18), DecimalType),
+ ],
+)
+def test_canonicalize_type_primitives(dtype, expected_type):
+ """NullType → StringType; other primitives unchanged."""
+ assert isinstance(canonicalize_type(dtype), expected_type)
+
+
+def test_canonicalize_type_array_of_null():
+ result = canonicalize_type(ArrayType(NullType(), structured=False))
+ assert isinstance(result, ArrayType) and isinstance(result.element_type, StringType)
+
+
+def test_canonicalize_type_array_of_long():
+ result = canonicalize_type(ArrayType(LongType(), structured=False))
+ assert isinstance(result, ArrayType) and isinstance(result.element_type, LongType)
+
+
+def test_canonicalize_type_array_of_empty_struct():
+ """Array containing empty struct → None (removed)."""
+ result = canonicalize_type(
+ ArrayType(StructType([], structured=False), structured=False)
+ )
+ assert result is None
+
+
+def test_canonicalize_type_nested_array_null():
+    """Array<Array<Null>> → Array<Array<String>>."""
+ inner = ArrayType(NullType(), structured=False)
+ result = canonicalize_type(ArrayType(inner, structured=False))
+ assert isinstance(result.element_type, ArrayType)
+ assert isinstance(result.element_type.element_type, StringType)
+
+
+def test_canonicalize_type_struct_null_fields():
+ s = StructType(
+ [StructField("a", NullType(), True), StructField("b", LongType(), True)],
+ structured=False,
+ )
+ fm = {f._name: f.datatype for f in canonicalize_type(s).fields}
+ assert isinstance(fm["a"], StringType) and isinstance(fm["b"], LongType)
+
+
+def test_canonicalize_type_empty_struct():
+ """Per SPARK-8093: empty structs → None."""
+ assert canonicalize_type(StructType([], structured=False)) is None
+
+
+def test_canonicalize_type_struct_empty_name_field():
+ """Fields with empty names are removed; if none remain, struct is None."""
+ assert (
+ canonicalize_type(
+ StructType([StructField("", LongType(), True)], structured=False)
+ )
+ is None
+ )
+
+
+def test_canonicalize_type_struct_mixed_fields():
+ s = StructType(
+ [
+ StructField("a", NullType(), True),
+ StructField("", LongType(), True),
+ StructField("b", DoubleType(), True),
+ ],
+ structured=False,
+ )
+ result = canonicalize_type(s)
+ assert len(result.fields) == 2
+ fm = {f._name: f.datatype for f in result.fields}
+ assert isinstance(fm["a"], StringType) and isinstance(fm["b"], DoubleType)
+
+
+def test_canonicalize_type_nested_struct():
+ """Nested struct: inner NullType fields become StringType."""
+ inner = StructType([StructField("x", NullType(), True)], structured=False)
+ outer = StructType([StructField("child", inner, True)], structured=False)
+ result = canonicalize_type(outer)
+ assert isinstance(result.fields[0].datatype.fields[0].datatype, StringType)
+
+
+def test_canonicalize_type_nested_struct_empty_inner():
+ """Nested struct where inner becomes empty → field is removed."""
+ inner = StructType([], structured=False)
+ outer = StructType(
+ [StructField("child", inner, True), StructField("other", LongType(), True)],
+ structured=False,
+ )
+ result = canonicalize_type(outer)
+ assert len(result.fields) == 1 and result.fields[0]._name == "other"
+
+
+# ===========================================================================
+# _case_preserving_simple_string
+# ===========================================================================
+
+
+@pytest.mark.parametrize(
+ "dtype, expected",
+ [
+ (StringType(), "string"),
+ (LongType(), "bigint"),
+ (DoubleType(), "double"),
+ (BooleanType(), "boolean"),
+ (DateType(), "date"),
+ (TimestampType(), "timestamp"),
+ (NullType(), "null"),
+ ],
+)
+def test_case_preserving_simple_string_primitives(dtype, expected):
+ assert _case_preserving_simple_string(dtype) == expected
+
+
+def test_case_preserving_simple_string_struct():
+ s = StructType([StructField("author", StringType(), True)], structured=False)
+    assert _case_preserving_simple_string(s) == "struct<author:string>"
+
+
+def test_case_preserving_simple_string_preserves_case():
+ s = StructType(
+ [
+ StructField("Author", StringType(), True),
+ StructField("title", LongType(), True),
+ ],
+ structured=False,
+ )
+ result = _case_preserving_simple_string(s)
+ assert "Author:string" in result and "title:bigint" in result
+
+
+def test_case_preserving_simple_string_nested():
+ inner = StructType([StructField("inner_field", LongType(), True)], structured=False)
+ outer = StructType([StructField("outer", inner, True)], structured=False)
+ assert (
+ _case_preserving_simple_string(outer)
+        == "struct<outer:struct<inner_field:bigint>>"
+ )
+
+
+def test_case_preserving_simple_string_array():
+ s = StructType(
+ [StructField("items", ArrayType(StringType(), structured=False), True)],
+ structured=False,
+ )
+    assert _case_preserving_simple_string(s) == "struct<items:array<string>>"
+
+
+def test_case_preserving_simple_string_array_of_struct():
+ inner = StructType([StructField("name", StringType(), True)], structured=False)
+ s = StructType(
+ [StructField("people", ArrayType(inner, structured=False), True)],
+ structured=False,
+ )
+ assert (
+        _case_preserving_simple_string(s) == "struct<people:array<struct<name:string>>>"
+ )
+
+
+def test_case_preserving_simple_string_colon_in_name():
+ """Field names with colons get quoted."""
+ s = StructType([StructField("px:name", StringType(), True)], structured=False)
+ assert _case_preserving_simple_string(s) == 'struct<"px:name":string>'
+
+ s2 = StructType(
+ [
+ StructField("px:name", StringType(), True),
+ StructField("px:value", LongType(), True),
+ ],
+ structured=False,
+ )
+ result = _case_preserving_simple_string(s2)
+ assert '"px:name":string' in result and '"px:value":bigint' in result
+
+
+def test_case_preserving_simple_string_no_unnecessary_quotes():
+ s = StructType([StructField("author", StringType(), True)], structured=False)
+ assert '"' not in _case_preserving_simple_string(s)
+
+
+# ===========================================================================
+# End-to-end: infer + merge + canonicalize (Spark parity)
+# ===========================================================================
+
+
+def test_e2e_simple_flat_record():
+ records = [
+ """The Great GatsbyF. Scott Fitzgerald
+ 192512.99"""
+ ]
+ fm = {f._name: f.datatype for f in _full_pipeline(records).fields}
+ assert isinstance(fm["title"], StringType) and isinstance(fm["author"], StringType)
+ assert isinstance(fm["year"], LongType) and isinstance(fm["price"], DoubleType)
+
+
+def test_e2e_repeated_elements_array():
+ records = [
+ "- Apple
- Banana
- Cherry
"
+ ]
+ fm = {f._name: f.datatype for f in _full_pipeline(records).fields}
+ assert isinstance(fm["item"], ArrayType) and isinstance(
+ fm["item"].element_type, StringType
+ )
+
+
+def test_e2e_nested_struct_and_array():
+ records = [
+ """Test Book
+ user15
+ user23
+ """
+ ]
+ fm = {f._name: f.datatype for f in _full_pipeline(records).fields}
+ assert isinstance(fm["title"], StringType)
+ rev_fields = {f._name: f.datatype for f in fm["reviews"].fields}
+ assert isinstance(rev_fields["review"], ArrayType)
+ elem_fields = {
+ f._name: f.datatype for f in rev_fields["review"].element_type.fields
+ }
+ assert isinstance(elem_fields["user"], StringType) and isinstance(
+ elem_fields["rating"], LongType
+ )
+
+
+def test_e2e_attributes_with_text():
+ records = [
+ """Test
+ """
+ ]
+ fm = {f._name: f.datatype for f in _full_pipeline(records).fields}
+ assert isinstance(fm["_id"], LongType) and isinstance(fm["title"], StringType)
+ ef = {f._name: f.datatype for f in fm["edition"].fields}
+ assert isinstance(ef["_year"], LongType) and isinstance(ef["_format"], StringType)
+ assert isinstance(ef["_VALUE"], StringType)
+
+
+def test_e2e_schema_evolution():
+ """Field types evolve across records: Long + Double → Double."""
+ records = [
+ "- Widget10
",
+ "- Gadget19.99
",
+ ]
+ fm = {f._name: f.datatype for f in _full_pipeline(records).fields}
+ assert isinstance(fm["name"], StringType) and isinstance(fm["price"], DoubleType)
+
+
+def test_e2e_schema_evolution_bool_to_string():
+ """Boolean + String → String."""
+ records = [
+ "- true
",
+ "- maybe
",
+ ]
+ fm = {f._name: f.datatype for f in _full_pipeline(records).fields}
+ assert isinstance(fm["flag"], StringType)
+
+
+def test_e2e_schema_evolution_date_to_timestamp():
+ """Date + Timestamp → Timestamp."""
+ records = [
+ "- 2024-01-15
",
+ "- 2024-01-15T10:30:00
",
+ ]
+ fm = {f._name: f.datatype for f in _full_pipeline(records).fields}
+ assert isinstance(fm["when"], TimestampType)
+
+
+def test_e2e_null_fields_canonicalized():
+ """All-null fields become StringType after canonicalization."""
+ records = ["hello
", "world
"]
+ fm = {f._name: f.datatype for f in _full_pipeline(records).fields}
+ assert isinstance(fm["a"], StringType) and isinstance(fm["b"], StringType)
+
+
+def test_e2e_serialization_round_trip():
+ records = [
+ """Jane29.99
+ fictionbestseller
"""
+ ]
+ schema_str = _case_preserving_simple_string(_full_pipeline(records))
+ assert "author:string" in schema_str and "price:double" in schema_str
+    assert "tags:struct<tag:array<string>>" in schema_str
+
+
+def test_e2e_case_sensitive_fields():
+ """'b' and 'B' are different fields (case-sensitive)."""
+ records = [
+ "123
",
+ "456
",
+ ]
+ fm = {f._name: f.datatype for f in _full_pipeline(records).fields}
+ assert all(isinstance(fm[k], LongType) for k in ["a", "b", "B", "c"])
+
+
+def test_e2e_case_sensitive_value_tag():
+ """ with children vs without → different fields."""
+ records = [
+ "\n1\n2
",
+ "3
",
+ ]
+ fm = {
+ f._name: f.datatype
+ for f in _full_pipeline(records, ignore_surrounding_whitespace=True).fields
+ }
+ assert isinstance(fm["A"], LongType)
+ assert isinstance(fm["a"], StructType)
+ af = {f._name: f.datatype for f in fm["a"].fields}
+ assert isinstance(af["_VALUE"], LongType) and isinstance(af["b"], LongType)
+
+
+def test_e2e_case_sensitive_attributes():
+ """'attr' and 'aTtr' are different attribute fields."""
+ records = [
+ '123
',
+ '456
',
+ ]
+ fm = {f._name: f.datatype for f in _full_pipeline(records).fields}
+ assert isinstance(fm["_attr"], LongType) and isinstance(fm["_aTtr"], LongType)
+
+
+def test_e2e_case_sensitive_struct():
+ """ and with children → different struct fields."""
+ records = [
+ "13
",
+ "57
",
+ ]
+ fm = {f._name: f.datatype for f in _full_pipeline(records).fields}
+ assert isinstance(fm["A"], StructType) and isinstance(fm["a"], StructType)
+ assert all(
+ isinstance(v, LongType)
+ for v in {f._name: f.datatype for f in fm["A"].fields}.values()
+ )
+
+
+def test_e2e_case_sensitive_array_complex():
+ records = [
+ "\n1\n234
",
+ "5
",
+ ]
+ fm = {
+ f._name: f.datatype
+ for f in _full_pipeline(records, ignore_surrounding_whitespace=True).fields
+ }
+ assert isinstance(fm["A"], LongType)
+ assert isinstance(fm["a"], StructType)
+ af = {f._name: f.datatype for f in fm["a"].fields}
+ assert isinstance(af["_VALUE"], LongType) and isinstance(af["b"], LongType)
+
+
+def test_e2e_case_sensitive_array_simple():
+ records = ["1234
"]
+ fm = {f._name: f.datatype for f in _full_pipeline(records).fields}
+ assert isinstance(fm["A"], StructType) and isinstance(fm["a"], StructType)
+ assert isinstance({f._name: f.datatype for f in fm["A"].fields}["B"], LongType)
+ assert isinstance({f._name: f.datatype for f in fm["a"].fields}["b"], LongType)
+
+
+def test_e2e_value_tags_spaces_and_empty_values():
+ """lxml element.text only captures text BEFORE the first child."""
+ records = [
+ "\n str1\n 1\n
",
+ " value
",
+ "3
",
+ "4
",
+ ]
+ fm_ws = {
+ f._name: f.datatype
+ for f in _full_pipeline(records, ignore_surrounding_whitespace=True).fields
+ }
+ assert isinstance(fm_ws["_VALUE"], StringType)
+ assert isinstance(fm_ws["a"], StructType)
+ assert isinstance({f._name: f.datatype for f in fm_ws["a"].fields}["b"], LongType)
+
+ fm_nws = {
+ f._name: f.datatype
+ for f in _full_pipeline(records, ignore_surrounding_whitespace=False).fields
+ }
+ assert isinstance(fm_nws["_VALUE"], StringType)
+
+
+def test_e2e_value_tags_multiline():
+ """element.text only captures text before first child."""
+ records = [
+ "\nvalue1\n1
",
+ "\nvalue3\nvalue41
",
+ ]
+ fm = {
+ f._name: f.datatype
+ for f in _full_pipeline(records, ignore_surrounding_whitespace=True).fields
+ }
+ assert isinstance(fm["_VALUE"], StringType) and isinstance(fm["a"], LongType)
+
+
+def test_e2e_value_tag_comments():
+ """Comments should not affect value tag inference."""
+ records = ['\n2\n
']
+ fm = {
+ f._name: f.datatype
+ for f in _full_pipeline(records, ignore_surrounding_whitespace=True).fields
+ }
+ assert isinstance(fm["_VALUE"], LongType) and isinstance(fm["a"], ArrayType)
+
+
+def test_e2e_value_tag_null_value_option():
+ """nullValue option doesn't affect schema inference."""
+ fm = {
+ f._name: f.datatype
+ for f in _full_pipeline(
+ ["\n 1\n
"], ignore_surrounding_whitespace=True
+ ).fields
+ }
+ assert isinstance(fm["_VALUE"], LongType)
+
+
+# ---------------------------------------------------------------------------
+# infer_schema_for_xml_range
+# ---------------------------------------------------------------------------
+
+
+def _mock_xml_range(xml_content, row_tag, **kwargs):
+ """Helper: run infer_schema_for_xml_range against an in-memory XML string."""
+ xml_bytes = xml_content.encode("utf-8")
+ mock_file = kwargs.get("file_obj") or io.BytesIO(xml_bytes)
+ with patch("snowflake.snowpark.files.SnowflakeFile.open", return_value=mock_file):
+ return infer_schema_for_xml_range(
+ file_path="test.xml",
+ row_tag=row_tag,
+ approx_start=0,
+ approx_end=kwargs.get("approx_end", len(xml_bytes)),
+ sampling_ratio=kwargs.get("sampling_ratio", 1.0),
+ ignore_namespace=kwargs.get("ignore_namespace", True),
+ attribute_prefix=kwargs.get("attribute_prefix", "_"),
+ exclude_attributes=kwargs.get("exclude_attributes", False),
+ value_tag=kwargs.get("value_tag", "_VALUE"),
+ null_value=kwargs.get("null_value", ""),
+ charset="utf-8",
+ ignore_surrounding_whitespace=kwargs.get(
+ "ignore_surrounding_whitespace", True
+ ),
+ )
+
+
+def test_infer_range_single_leaf_record():
+ """Single leaf record → StructType with _VALUE."""
+ schema = _mock_xml_range("- hello
", "item")
+ assert schema is not None
+ assert len(schema.fields) == 1
+ assert schema.fields[0]._name == "_VALUE"
+ assert isinstance(schema.fields[0].datatype, StringType)
+
+
+def test_infer_range_multiple_typed_records():
+ """Multiple records with different types → merged schema."""
+ xml = "1true
2false3.14
"
+ schema = _mock_xml_range(xml, "row")
+ fm = {f._name: f.datatype for f in schema.fields}
+ assert isinstance(fm["a"], LongType)
+ assert isinstance(fm["b"], BooleanType)
+ assert isinstance(fm["c"], DoubleType)
+
+
+def test_infer_range_type_widening():
+ """Long in first record, Double in second → merged to DoubleType."""
+ xml = "42
3.14
"
+ schema = _mock_xml_range(xml, "row")
+ assert isinstance(schema.fields[0].datatype, DoubleType)
+
+
+def test_infer_range_repeated_children_become_array():
+ """Repeated child tags → ArrayType."""
+ xml = "- a
- b
"
+ schema = _mock_xml_range(xml, "row")
+ assert isinstance(schema.fields[0].datatype, ArrayType)
+ assert isinstance(schema.fields[0].datatype.element_type, StringType)
+
+
+def test_infer_range_attributes():
+ """Attributes → prefixed struct fields."""
+ xml = 'Alice
'
+ schema = _mock_xml_range(xml, "row")
+ fm = {f._name: f.datatype for f in schema.fields}
+ assert isinstance(fm["_id"], LongType)
+ assert isinstance(fm["_active"], BooleanType)
+ assert isinstance(fm["name"], StringType)
+
+
+def test_infer_range_nested_struct():
+ """Nested elements → nested StructType."""
+ xml = "NYC10001
"
+ schema = _mock_xml_range(xml, "row")
+ addr = schema.fields[0]
+ assert addr._name == "address"
+ assert isinstance(addr.datatype, StructType)
+ nested_fm = {f._name: f.datatype for f in addr.datatype.fields}
+ assert isinstance(nested_fm["city"], StringType)
+ assert isinstance(nested_fm["zip"], LongType)
+
+
+def test_infer_range_self_closing_tag():
+ """Self-closing tags → StructType with attributes."""
+ xml = '|
'
+ schema = _mock_xml_range(xml, "row")
+ assert len(schema.fields) == 1
+ assert schema.fields[0]._name == "_id"
+ assert isinstance(schema.fields[0].datatype, LongType)
+
+
+def test_infer_range_no_records():
+ xml = "data"
+ schema = _mock_xml_range(xml, "row")
+ assert schema is None
+
+
+def test_infer_range_malformed_record_skipped():
+ """Malformed XML between valid records is skipped."""
+ xml = "1
<<2
"
+ schema = _mock_xml_range(xml, "row")
+ assert schema is not None
+ assert isinstance(schema.fields[0].datatype, LongType)
+
+
+def test_infer_range_namespace_stripped():
+ """Namespaces are stripped when ignore_namespace=True."""
+ xml = '42'
+ schema = _mock_xml_range(xml, "ns:row", ignore_namespace=True)
+ assert schema is not None
+
+
+def test_infer_range_exclude_attributes():
+ """exclude_attributes=True omits attribute fields."""
+ xml = 'Alice
'
+ schema = _mock_xml_range(xml, "row", exclude_attributes=True)
+ fm = {f._name for f in schema.fields}
+ assert "_id" not in fm
+ assert "name" in fm
+
+
+def test_infer_range_mixed_content():
+ xml = "text1
"
+ schema = _mock_xml_range(xml, "row")
+ fm = {f._name: f.datatype for f in schema.fields}
+ assert "_VALUE" in fm
+ assert "child" in fm
+
+
+def test_infer_range_schema_merge_across_records():
+ """Fields present only in some records are merged into the union schema."""
+ xml = "1
hello
2true
"
+ schema = _mock_xml_range(xml, "row")
+ fm = {f._name: f.datatype for f in schema.fields}
+ assert isinstance(fm["a"], LongType)
+ assert isinstance(fm["b"], StringType)
+ assert isinstance(fm["c"], BooleanType)
+
+
+def test_infer_range_sampling_ratio_skips():
+ """With a very low sampling_ratio, some records may be skipped."""
+ import random
+
+ xml = ""
+ for i in range(20):
+ xml += f"{i}
"
+ xml += ""
+ xml_bytes = xml.encode("utf-8")
+
+ random.seed(42)
+ mock_file = io.BytesIO(xml_bytes)
+ with patch("snowflake.snowpark.files.SnowflakeFile.open", return_value=mock_file):
+ schema = infer_schema_for_xml_range(
+ file_path="test.xml",
+ row_tag="row",
+ approx_start=0,
+ approx_end=len(xml_bytes),
+ sampling_ratio=0.3,
+ ignore_namespace=True,
+ attribute_prefix="_",
+ exclude_attributes=False,
+ value_tag="_VALUE",
+ null_value="",
+ charset="utf-8",
+ ignore_surrounding_whitespace=True,
+ )
+ assert schema is not None
+ assert isinstance(schema.fields[0].datatype, LongType)
+
+
+def test_infer_range_truncated_tag_handled():
+ """A truncated opening tag that causes tag_is_self_closing to fail is skipped."""
+ xml = "1
2
"
+ xml_bytes = xml.encode("utf-8")
+ mock_file = io.BytesIO(xml_bytes)
+ with patch("snowflake.snowpark.files.SnowflakeFile.open", return_value=mock_file):
+ schema = infer_schema_for_xml_range(
+ file_path="test.xml",
+ row_tag="row",
+ approx_start=0,
+ approx_end=len(xml_bytes),
+ sampling_ratio=1.0,
+ ignore_namespace=True,
+ attribute_prefix="_",
+ exclude_attributes=False,
+ value_tag="_VALUE",
+ null_value="",
+ charset="utf-8",
+ ignore_surrounding_whitespace=True,
+ )
+ assert schema is not None
+
+
+def test_xml_schema_inference_process_empty_results():
+ """Cases where process should yield ("",): empty/None file size, invalid worker id, no records."""
+ base = ("test.xml",)
+ tail = (1.0, True, "_", False, "_VALUE", "", "utf-8", True)
+
+ # empty file (size 0)
+ with patch(
+ "snowflake.snowpark._internal.xml_schema_inference.get_file_size",
+ return_value=0,
+ ):
+ assert list(XMLSchemaInference().process(*base, 1, "row", 0, *tail)) == [("",)]
+
+ # None file size
+ with patch(
+ "snowflake.snowpark._internal.xml_schema_inference.get_file_size",
+ return_value=None,
+ ):
+ assert list(XMLSchemaInference().process(*base, 1, "row", 0, *tail)) == [("",)]
+
+ # worker id >= num_workers
+ with patch(
+ "snowflake.snowpark._internal.xml_schema_inference.get_file_size",
+ return_value=1000,
+ ):
+ assert list(XMLSchemaInference().process(*base, 2, "row", 5, *tail)) == [("",)]
+
+ # no matching row tags
+    xml_bytes = b"<root>data</root>"
+ with patch(
+ "snowflake.snowpark._internal.xml_schema_inference.get_file_size",
+ return_value=len(xml_bytes),
+ ), patch(
+ "snowflake.snowpark.files.SnowflakeFile.open",
+ return_value=io.BytesIO(xml_bytes),
+ ):
+ assert list(XMLSchemaInference().process(*base, 1, "row", 0, *tail)) == [("",)]
+
+
+def test_xml_schema_inference_process_param_defaults():
+ """None num_workers defaults to 1; negative worker id defaults to 0."""
+    xml_bytes = b"<root><row><a>42</a></row></root>"
+ tail = (1.0, True, "_", False, "_VALUE", "", "utf-8", True)
+
+ for num_workers, i in [(None, 0), (1, -1)]:
+ mock_file = io.BytesIO(xml_bytes)
+ with patch(
+ "snowflake.snowpark._internal.xml_schema_inference.get_file_size",
+ return_value=len(xml_bytes),
+ ), patch("snowflake.snowpark.files.SnowflakeFile.open", return_value=mock_file):
+ results = list(
+ XMLSchemaInference().process("test.xml", num_workers, "row", i, *tail)
+ )
+ assert len(results) == 1
+ assert "bigint" in results[0][0]
+
+
+def test_xml_schema_inference_process_with_sampling():
+ """Sampling ratio < 1.0 sets a deterministic seed."""
+    xml = "<root>"
+    for i in range(10):
+        xml += f"<row><a>{i}</a></row>"
+    xml += "</root>"
+ xml_bytes = xml.encode("utf-8")
+ mock_file = io.BytesIO(xml_bytes)
+ with patch(
+ "snowflake.snowpark._internal.xml_schema_inference.get_file_size",
+ return_value=len(xml_bytes),
+ ), patch("snowflake.snowpark.files.SnowflakeFile.open", return_value=mock_file):
+ results = list(
+ XMLSchemaInference().process(
+ "test.xml",
+ 1,
+ "row",
+ 0,
+ 0.5,
+ True,
+ "_",
+ False,
+ "_VALUE",
+ "",
+ "utf-8",
+ True,
+ )
+ )
+ assert len(results) == 1
+ assert results[0][0] != ""
+
+
+def test_xml_schema_inference_process_multi_worker():
+ """Multi-worker processing partitions the file correctly."""
+    xml = "<root><row><a>1</a></row><row><a>2</a></row><row><a>3</a></row></root>"
+ xml_bytes = xml.encode("utf-8")
+
+ all_schemas = []
+ for worker_id in range(2):
+ mock_file = io.BytesIO(xml_bytes)
+ with patch(
+ "snowflake.snowpark._internal.xml_schema_inference.get_file_size",
+ return_value=len(xml_bytes),
+ ), patch("snowflake.snowpark.files.SnowflakeFile.open", return_value=mock_file):
+ udtf = XMLSchemaInference()
+ results = list(
+ udtf.process(
+ "test.xml",
+ 2,
+ "row",
+ worker_id,
+ 1.0,
+ True,
+ "_",
+ False,
+ "_VALUE",
+ "",
+ "utf-8",
+ True,
+ )
+ )
+ all_schemas.append(results[0][0])
+ assert any(s != "" for s in all_schemas)
+
+
+def test_infer_element_schema_clark_ns_in_children():
+ """Clark {uri}tag stripped with ignore_namespace=True"""
+ root = ET.Element("row")
+ ET.SubElement(root, "{http://example.com}child").text = "val"
+ assert "child" in {
+ f._name for f in infer_element_schema(root, ignore_namespace=True).fields
+ }
+
+
+def test_infer_range_approx_end_breaks():
+ """open_pos >= approx_end → immediate break."""
+    assert _mock_xml_range("<root><row><a>1</a></row></root>", "row", approx_end=3) is None
+
+
+def test_infer_range_tag_exception_continues():
+ """Truncated tag at EOF → tag_is_self_closing raises → seek back, continue"""
+ assert _mock_xml_range("1
long_value
",
+ "row",
+ approx_end=10,
+ sampling_ratio=0.01,
+ )
+ assert schema is None
+
+
+def test_infer_range_parse_error_past_approx_end():
+ """Parse error + record past approx_end → break"""
+ schema = _mock_xml_range(
+ "1
",
+ "row",
+ approx_end=10,
+ ignore_namespace=False,
+ )
+ assert schema is None
+
+
+def test_infer_range_stdlib_et_fallback():
+ """stdlib ET used when lxml not installed"""
+ import xml.etree.ElementTree as stdlib_ET
+
+ with patch(
+ "snowflake.snowpark._internal.xml_schema_inference.lxml_installed", False
+ ), patch("snowflake.snowpark._internal.xml_schema_inference.ET", stdlib_ET):
+        assert _mock_xml_range("<root><row><a>42</a></row></root>", "row") is not None
+
+
+def test_infer_range_seek_failures():
+ class FailSeek(io.BytesIO):
+ """BytesIO that raises OSError on the Nth seek to a target position.
+
+ skip=0 means fail on the very first seek to fail_pos.
+ skip=1 means allow the first seek (e.g. inside find_next_closing_tag_pos)
+ and fail on the second (the coverage-target seek).
+ """
+
+ def __init__(self, data, fail_pos, skip=0) -> None:
+ super().__init__(data)
+ self._fail_pos = fail_pos
+ self._skip = skip
+ self._hits = 0
+
+ def seek(self, pos, whence=0):
+ if whence == 0 and pos == self._fail_pos:
+ self._hits += 1
+ if self._hits > self._skip:
+ raise OSError("forced")
+ return super().seek(pos, whence)
+
+    def _end(b):
+        return b.find(b"</row>") + len(b"</row>")
+
+ # successful parse → seek failure after merge
+ # skip=1: first seek to record_end happens inside find_next_closing_tag_pos
+    b = b"<root><row><a>7</a></row></root>"
+ assert (
+ _mock_xml_range(b.decode(), "row", file_obj=FailSeek(b, _end(b), skip=1))
+ is not None
+ )
+
+ # parse error → seek failure
+    b = b"<root><row><ns:a>1</ns:a></row></root>"
+ assert (
+ _mock_xml_range(
+ b.decode(),
+ "row",
+ file_obj=FailSeek(b, _end(b), skip=1),
+ ignore_namespace=False,
+ )
+ is None
+ )
+
+ # sampling skip → seek failure
+    b = b"<root><row><a>1</a></row><row><a>2</a></row><row><a>3</a></row></root>"
+ with patch("snowflake.snowpark._internal.xml_schema_inference.random") as rng:
+ rng.random.return_value = 0.99
+ rng.seed = lambda *a, **kw: None
+ assert (
+ _mock_xml_range(
+ b.decode(),
+ "row",
+ file_obj=FailSeek(b, _end(b), skip=1),
+ sampling_ratio=0.01,
+ )
+ is None
+ )
+
+ # tag exception → seek failure on record_start + 1
+ # skip=0: position record_start+1 is only hit in the except handler
+ b = b"1
",)],
+ type_string=[StructType([StructField("a", LongType())])],
+ canonicalize=LongType(),
+ )
+ assert result is None
+
+
+def test_infer_xml_map_type_field_names_cleaned():
+ """MapType fields get quoted names stripped."""
+ result = _run_infer_xml(
+ [("x",)],
+ type_string=[StructType([StructField("a", LongType())])],
+ canonicalize=StructType(
+ [
+ StructField(
+ '"m"',
+ MapType(
+ StructType([StructField('"k"', StringType())]),
+ StructType([StructField('"v"', LongType())]),
+ ),
+ )
+ ]
+ ),
+ )
+ assert result.fields[0]._name == "m"
+ assert result.fields[0].datatype.key_type.fields[0]._name == "k"
+ assert result.fields[0].datatype.value_type.fields[0]._name == "v"
diff --git a/tests/utils.py b/tests/utils.py
index a309ac0216..5fc5a352d7 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -1717,6 +1717,26 @@ def test_null_value_xml(self):
def test_dog_image(self):
return os.path.join(self.resources_path, "dog.jpg")
+ @property
+ def test_dk_trace_sample_xml(self):
+ return os.path.join(self.resources_path, "dk_trace_sample.xml")
+
+ @property
+ def test_dblp_6kb_xml(self):
+ return os.path.join(self.resources_path, "dblp_6kb.xml")
+
+ @property
+ def test_books_attribute_value_xml(self):
+ return os.path.join(self.resources_path, "books_attribute_value.xml")
+
+ @property
+ def test_xml_infer_types(self):
+ return os.path.join(self.resources_path, "xml_infer_types.xml")
+
+ @property
+ def test_xml_infer_mixed(self):
+ return os.path.join(self.resources_path, "xml_infer_mixed.xml")
+
@property
def test_books_xsd(self):
return os.path.join(self.resources_path, "books.xsd")