diff --git a/duckdb/experimental/spark/sql/session.py b/duckdb/experimental/spark/sql/session.py
index c407a9f1..40fbaa74 100644
--- a/duckdb/experimental/spark/sql/session.py
+++ b/duckdb/experimental/spark/sql/session.py
@@ -1,6 +1,7 @@
 import uuid
 from collections.abc import Iterable, Sized
-from typing import TYPE_CHECKING, Any, NoReturn, Union
+from functools import reduce
+from typing import TYPE_CHECKING, Any, List, NoReturn, Optional, Union
 
 import duckdb
 
@@ -12,13 +13,13 @@
 
 from ..conf import SparkConf
 from ..context import SparkContext
-from ..errors import PySparkTypeError
+from ..errors import PySparkTypeError, PySparkValueError
 from ..exception import ContributionsAcceptedError
 from .conf import RuntimeConfig
 from .dataframe import DataFrame
 from .readwriter import DataFrameReader
 from .streaming import DataStreamReader
-from .types import StructType
+from .types import StructType, _has_nulltype, _infer_schema, _merge_type
 from .udf import UDFRegistration
 
 # In spark:
@@ -38,7 +39,11 @@ def _combine_data_and_schema(data: Iterable[Any], schema: StructType) -> list[du
 
     new_data = []
     for row in data:
-        new_row = [Value(x, dtype.duckdb_type) for x, dtype in zip(row, [y.dataType for y in schema], strict=False)]
+        if isinstance(row, dict):
+            row_values = list(map(row.get, schema.fieldNames()))
+        else:
+            row_values = list(row)
+        new_row = [Value(x, dtype.duckdb_type) for x, dtype in zip(row_values, [y.dataType for y in schema], strict=False)]
         new_data.append(new_row)
     return new_data
 
@@ -150,6 +155,9 @@ def createDataFrame(  # noqa: D102
                 types, names = schema.extract_types_and_names()
             else:
                 names = schema
+        elif isinstance(data, list) and data:
+            schema = self._inferSchemaFromList(data)
+            types, names = schema.extract_types_and_names()
 
         try:
             import pandas
@@ -188,6 +196,65 @@ def createDataFrame(  # noqa: D102
             df = df.toDF(*names)
         return df
 
+    def _inferSchemaFromList(
+        self, data: Iterable[Any], names: Optional[List[str]] = None
+    ) -> StructType:
+        """Infer a schema from an iterable of Row, dict, or tuple records.
+
+        Each record is passed to ``_infer_schema`` and the per-record
+        schemas are folded together with ``_merge_type``.
+
+        Parameters
+        ----------
+        data : iterable
+            Row, dict, or tuple records; consumed in a single pass, so
+            one-shot iterators are accepted
+        names : list, optional
+            list of column names
+
+        Returns
+        -------
+        :class:`duckdb.experimental.spark.sql.types.StructType`
+
+        Raises
+        ------
+        PySparkValueError
+            ``CANNOT_INFER_EMPTY_SCHEMA`` if ``data`` yields no records;
+            ``CANNOT_DETERMINE_TYPE`` if a NullType is left after merging
+        """
+        # TODO: These should be configurable
+        infer_dict_as_struct = False
+        infer_array_from_first_element = False
+        prefer_timestamp_ntz = False
+
+        # ``data or ()`` routes None and empty containers to the empty case;
+        # the ``next`` sentinel also covers iterators that are truthy but
+        # yield nothing, which reduce() would reject with a bare TypeError.
+        row_schemas = (
+            _infer_schema(
+                row,
+                names,
+                infer_dict_as_struct=infer_dict_as_struct,
+                infer_array_from_first_element=infer_array_from_first_element,
+                prefer_timestamp_ntz=prefer_timestamp_ntz,
+            )
+            for row in (data or ())
+        )
+        first_schema = next(row_schemas, None)
+        if first_schema is None:
+            raise PySparkValueError(
+                error_class="CANNOT_INFER_EMPTY_SCHEMA",
+                message_parameters={},
+            )
+
+        schema = reduce(_merge_type, row_schemas, first_schema)
+        if _has_nulltype(schema):
+            raise PySparkValueError(
+                error_class="CANNOT_DETERMINE_TYPE",
+                message_parameters={},
+            )
+        return schema
+
     def newSession(self) -> "SparkSession":  # noqa: D102
         return SparkSession(self._context)
 
diff --git a/duckdb/experimental/spark/sql/types.py b/duckdb/experimental/spark/sql/types.py
index 5bfff09f..b4eed95b 100644
--- a/duckdb/experimental/spark/sql/types.py
+++ b/duckdb/experimental/spark/sql/types.py
@@ -1,21 +1,25 @@
 # This code is based on code from Apache Spark under the license found in the LICENSE
 # file located in the 'spark' folder.
 
+import array
 import calendar
 import datetime
+import decimal
 import math
 import re
 import time
 from builtins import tuple
-from collections.abc import Iterator, Mapping
+from collections.abc import Iterable, Iterator, Mapping
+from functools import reduce
 from types import MappingProxyType
-from typing import Any, ClassVar, NoReturn, TypeVar, Union, cast, overload
+from typing import Any, ClassVar, Dict, List, NoReturn, Optional, Tuple, Type, TypeVar, Union, cast, overload
 
 from typing_extensions import Self
 
 import duckdb
 from duckdb.sqltypes import DuckDBPyType
 
+from ..errors.exceptions.base import PySparkTypeError
 from ..exception import ContributionsAcceptedError
 
 T = TypeVar("T")
@@ -1137,6 +1141,281 @@ def __eq__(self, other: object) -> bool:
 _INTERVAL_DAYTIME = re.compile(r"interval (day|hour|minute|second)( to (day|hour|minute|second))?")
 
 
+# Mapping Python types to Spark SQL DataType
+_type_mappings = {
+    type(None): NullType,
+    bool: BooleanType,
+    int: LongType,
+    float: DoubleType,
+    str: StringType,
+    bytearray: BinaryType,
+    decimal.Decimal: DecimalType,
+    datetime.date: DateType,
+    datetime.datetime: TimestampType,  # can be TimestampNTZType
+    datetime.time: TimestampType,  # can be TimestampNTZType
+    datetime.timedelta: DayTimeIntervalType,
+    bytes: BinaryType,
+}
+
+
+# All supported array typecodes are stored here
+_array_type_mappings: Dict[str, Type[DataType]] = {
+    # Warning: The actual properties of float and double are not specified by C.
+    # On almost every system supported by both python and JVM, they are IEEE 754
+    # single-precision binary floating-point format and IEEE 754 double-precision
+    # binary floating-point format. And we do assume the same thing here for now.
+    "f": FloatType,
+    "d": DoubleType,
+}
+
+
+def _has_nulltype(dt: DataType) -> bool:
+    """Return whether there is a NullType anywhere in `dt` (nested types included)."""
+    if isinstance(dt, StructType):
+        return any(_has_nulltype(f.dataType) for f in dt.fields)
+    elif isinstance(dt, ArrayType):
+        return _has_nulltype(dt.elementType)
+    elif isinstance(dt, MapType):
+        return _has_nulltype(dt.keyType) or _has_nulltype(dt.valueType)
+    else:
+        return isinstance(dt, NullType)
+
+
+@overload
+def _merge_type(
+    a: StructType, b: StructType, name: Optional[str] = None
+) -> StructType: ...
+
+
+@overload
+def _merge_type(
+    a: ArrayType, b: ArrayType, name: Optional[str] = None
+) -> ArrayType: ...
+
+
+@overload
+def _merge_type(a: MapType, b: MapType, name: Optional[str] = None) -> MapType: ...
+
+
+@overload
+def _merge_type(a: DataType, b: DataType, name: Optional[str] = None) -> DataType: ...
+
+
+def _merge_type(
+    a: Union[StructType, ArrayType, MapType, DataType],
+    b: Union[StructType, ArrayType, MapType, DataType],
+    name: Optional[str] = None,
+) -> Union[StructType, ArrayType, MapType, DataType]:
+    """Merge two inferred types into a single type covering both.
+
+    NullType yields to the other side, TimestampType wins over
+    TimestampNTZType, StringType wins over any other AtomicType, and
+    struct, array and map types are merged member by member.
+
+    `name` labels the value being merged and is only threaded through the
+    recursive calls.
+
+    Raises
+    ------
+    PySparkTypeError
+        ``CANNOT_MERGE_TYPE`` if the types differ and no rule above applies.
+    """
+    # Empty at the top level so nested labels never read "... None".
+    suffix = "" if name is None else f" {name}"
+
+    def new_name(n: str) -> str:
+        return f"field {n}" if name is None else f"field {n} in {name}"
+
+    if isinstance(a, NullType):
+        return b
+    elif isinstance(b, NullType):
+        return a
+    elif isinstance(a, TimestampType) and isinstance(b, TimestampNTZType):
+        return a
+    elif isinstance(a, TimestampNTZType) and isinstance(b, TimestampType):
+        return b
+    elif isinstance(a, AtomicType) and isinstance(b, StringType):
+        return b
+    elif isinstance(a, StringType) and isinstance(b, AtomicType):
+        return a
+    elif type(a) is not type(b):
+        # TODO: type cast (such as int -> long)
+        raise PySparkTypeError(
+            error_class="CANNOT_MERGE_TYPE",
+            message_parameters={
+                "data_type1": type(a).__name__,
+                "data_type2": type(b).__name__,
+            },
+        )
+
+    # same type
+    if isinstance(a, StructType):
+        nfs = {f.name: f.dataType for f in cast(StructType, b).fields}
+        fields = [
+            StructField(
+                f.name,
+                _merge_type(f.dataType, nfs.get(f.name, NullType()), name=new_name(f.name)),
+            )
+            for f in a.fields
+        ]
+        names = {f.name for f in fields}
+        # Fields that only exist in `b` are appended after the fields of `a`.
+        fields.extend(StructField(n, dt) for n, dt in nfs.items() if n not in names)
+        return StructType(fields)
+
+    elif isinstance(a, ArrayType):
+        return ArrayType(
+            _merge_type(a.elementType, cast(ArrayType, b).elementType, name=f"element in array{suffix}"),
+            True,
+        )
+
+    elif isinstance(a, MapType):
+        return MapType(
+            _merge_type(a.keyType, cast(MapType, b).keyType, name=f"key of map{suffix}"),
+            _merge_type(a.valueType, cast(MapType, b).valueType, name=f"value of map{suffix}"),
+            True,
+        )
+    else:
+        return a
+
+
+def _infer_type(
+    obj: Any,
+    infer_dict_as_struct: bool = False,
+    infer_array_from_first_element: bool = False,
+    prefer_timestamp_ntz: bool = False,
+) -> DataType:
+    """Infer the DataType from obj.
+
+    Exact Python types are looked up in `_type_mappings`; dicts become a
+    MapType (or a StructType when `infer_dict_as_struct` is set), lists and
+    supported `array.array` values become an ArrayType, and anything else is
+    handed to `_infer_schema`.
+
+    Raises
+    ------
+    PySparkTypeError
+        ``UNSUPPORTED_DATA_TYPE`` for an unsupported array typecode or when
+        the `_infer_schema` fallback fails with a TypeError.
+    """
+
+    def infer(value: Any) -> DataType:
+        # Recurse with the caller's inference options.
+        return _infer_type(value, infer_dict_as_struct, infer_array_from_first_element, prefer_timestamp_ntz)
+
+    if obj is None:
+        return NullType()
+
+    if hasattr(obj, "__UDT__"):
+        return obj.__UDT__
+
+    data_type = _type_mappings.get(type(obj))
+    if data_type is DecimalType:
+        # the precision and scale of `obj` may be different from row to row.
+        return DecimalType(38, 18)
+    if data_type is TimestampType and prefer_timestamp_ntz and obj.tzinfo is None:
+        return TimestampNTZType()
+    if data_type is not None:
+        return data_type()
+
+    if isinstance(obj, dict):
+        # Entries with a None key or value carry no type information.
+        typed_items = [(k, v) for k, v in obj.items() if k is not None and v is not None]
+        if infer_dict_as_struct:
+            struct = StructType()
+            for key, value in typed_items:
+                struct.add(key, infer(value), True)
+            return struct
+        if typed_items:
+            # Only the first usable entry decides the key and value types.
+            key, value = typed_items[0]
+            return MapType(infer(key), infer(value), True)
+        return MapType(NullType(), NullType(), True)
+    elif isinstance(obj, list):
+        if not obj:
+            return ArrayType(NullType(), True)
+        if infer_array_from_first_element:
+            return ArrayType(infer(obj[0]), True)
+        return ArrayType(reduce(_merge_type, (infer(v) for v in obj)), True)
+    elif isinstance(obj, array.array):
+        if obj.typecode in _array_type_mappings:
+            return ArrayType(_array_type_mappings[obj.typecode](), False)
+        raise PySparkTypeError(
+            error_class="UNSUPPORTED_DATA_TYPE",
+            message_parameters={"data_type": f"array({obj.typecode})"},
+        )
+    else:
+        try:
+            return _infer_schema(
+                obj,
+                infer_dict_as_struct=infer_dict_as_struct,
+                infer_array_from_first_element=infer_array_from_first_element,
+                prefer_timestamp_ntz=prefer_timestamp_ntz,
+            )
+        except TypeError as err:
+            raise PySparkTypeError(
+                error_class="UNSUPPORTED_DATA_TYPE",
+                message_parameters={"data_type": type(obj).__name__},
+            ) from err
+
+
+def _infer_schema(
+    row: Any,
+    names: Optional[List[str]] = None,
+    infer_dict_as_struct: bool = False,
+    infer_array_from_first_element: bool = False,
+    prefer_timestamp_ntz: bool = False,
+) -> StructType:
+    """Infer the schema from dict/namedtuple/object.
+
+    Dict keys and object attributes are visited in sorted order; Row and
+    namedtuple values use their own field names; plain tuples and lists use
+    `names`, padded with positional `_N` names. `names` is not modified.
+
+    Raises
+    ------
+    PySparkTypeError
+        ``CANNOT_INFER_SCHEMA_FOR_TYPE`` if `row` is none of the supported
+        kinds; ``CANNOT_INFER_TYPE_FOR_FIELD`` if inferring a field's type
+        fails with a TypeError.
+    """
+    items: Iterable[Tuple[str, Any]]
+    if isinstance(row, dict):
+        items = sorted(row.items())
+
+    elif isinstance(row, (tuple, list)):
+        if hasattr(row, "__fields__"):  # Row
+            items = zip(row.__fields__, tuple(row))  # type: ignore[union-attr]
+        elif hasattr(row, "_fields"):  # namedtuple
+            items = zip(row._fields, tuple(row))  # type: ignore[union-attr]
+        else:
+            # Pad a copy: the same `names` list is shared by callers across rows.
+            names = [] if names is None else list(names)
+            names.extend(f"_{i}" for i in range(len(names) + 1, len(row) + 1))
+            items = zip(names, row)
+
+    elif hasattr(row, "__dict__"):  # object
+        items = sorted(row.__dict__.items())
+
+    else:
+        raise PySparkTypeError(
+            error_class="CANNOT_INFER_SCHEMA_FOR_TYPE",
+            message_parameters={"data_type": type(row).__name__},
+        )
+
+    fields = []
+    for k, v in items:
+        try:
+            field_type = _infer_type(v, infer_dict_as_struct, infer_array_from_first_element, prefer_timestamp_ntz)
+            fields.append(StructField(k, field_type, True))
+        except TypeError as err:
+            raise PySparkTypeError(
+                error_class="CANNOT_INFER_TYPE_FOR_FIELD",
+                message_parameters={"field_name": k},
+            ) from err
+    return StructType(fields)
+
+
 def _create_row(fields: Union["Row", list[str]], values: tuple[Any, ...] | list[Any]) -> "Row":
     row = Row(*values)
     row.__fields__ = fields
diff --git a/tests/fast/spark/test_spark_dataframe.py b/tests/fast/spark/test_spark_dataframe.py
index e242092e..6a2f0079 100644
--- a/tests/fast/spark/test_spark_dataframe.py
+++ b/tests/fast/spark/test_spark_dataframe.py
@@ -121,6 +121,22 @@ def test_dataframe_from_list_of_tuples(self, spark):
         with pytest.raises(TypeError, match="must be an iterable, not int"):
             spark.createDataFrame(address, 5)
 
+    def test_dataframe_from_list_dicts(self, spark):
+        data = [
+            {"id": 1, "name": "Alice", "age": 25},
+            {"id": 2, "age": 30, "name": "Bob"},
+            {"age": 35, "id": 3, "name": "Charlie", "city": "New York"},
+        ]
+        df = spark.createDataFrame(data)
+        # Dict keys are sorted per row; keys first seen in later rows are appended.
+        assert df.columns == ["age", "id", "name", "city"]
+        res = df.collect()
+        assert res == [
+            Row(age=25, id=1, name="Alice", city=None),
+            Row(age=30, id=2, name="Bob", city=None),
+            Row(age=35, id=3, name="Charlie", city="New York"),
+        ]
+
     def test_dataframe(self, spark):
         # Create DataFrame
         df = spark.createDataFrame([("Scala", 25000), ("Spark", 35000), ("PHP", 21000)])