Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 59 additions & 4 deletions duckdb/experimental/spark/sql/session.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import uuid
from collections.abc import Iterable, Sized
from typing import TYPE_CHECKING, Any, NoReturn, Union
from functools import reduce
from typing import TYPE_CHECKING, Any, List, NoReturn, Optional, Union

import duckdb

Expand All @@ -12,13 +13,13 @@

from ..conf import SparkConf
from ..context import SparkContext
from ..errors import PySparkTypeError
from ..errors import PySparkTypeError, PySparkValueError
from ..exception import ContributionsAcceptedError
from .conf import RuntimeConfig
from .dataframe import DataFrame
from .readwriter import DataFrameReader
from .streaming import DataStreamReader
from .types import StructType
from .types import StructType, _has_nulltype, _infer_schema, _merge_type
from .udf import UDFRegistration

# In spark:
Expand All @@ -38,7 +39,11 @@ def _combine_data_and_schema(data: Iterable[Any], schema: StructType) -> list[du

new_data = []
for row in data:
new_row = [Value(x, dtype.duckdb_type) for x, dtype in zip(row, [y.dataType for y in schema], strict=False)]
if isinstance(row, dict):
row_values = list(map(row.get, schema.fieldNames()))
else:
row_values = list(row)
new_row = [Value(x, dtype.duckdb_type) for x, dtype in zip(row_values, [y.dataType for y in schema], strict=False)]
new_data.append(new_row)
return new_data

Expand Down Expand Up @@ -150,6 +155,9 @@ def createDataFrame( # noqa: D102
types, names = schema.extract_types_and_names()
else:
names = schema
elif isinstance(data, list) and data:
schema = self._inferSchemaFromList(data)
types, names = schema.extract_types_and_names()
Comment on lines +158 to +160

try:
import pandas
Expand Down Expand Up @@ -188,6 +196,53 @@ def createDataFrame( # noqa: D102
df = df.toDF(*names)
return df

def _inferSchemaFromList(
    self, data: Iterable[Any], names: Optional[List[str]] = None
) -> StructType:
    """Infer a :class:`StructType` schema from a list of Row, dict, or tuple.

    Each element's schema is inferred individually and the results are
    merged pairwise into one schema covering every row.

    Parameters
    ----------
    data : iterable
        list of Row, dict, or tuple
    names : list, optional
        list of column names

    Returns
    -------
    :class:`duckdb.experimental.spark.sql.types.StructType`

    Raises
    ------
    PySparkValueError
        If *data* is empty, or a column's type cannot be determined
        (it is null in every row).
    """
    if not data:
        raise PySparkValueError(
            error_class="CANNOT_INFER_EMPTY_SCHEMA",
            message_parameters={},
        )

    # TODO: These should be configurable
    infer_dict_as_struct = False
    infer_array_from_first_element = False
    prefer_timestamp_ntz = False

    # Fold the per-row schemas together; the first row seeds the merge.
    merged: Optional[StructType] = None
    for row in data:
        row_schema = _infer_schema(
            row,
            names,
            infer_dict_as_struct=infer_dict_as_struct,
            infer_array_from_first_element=infer_array_from_first_element,
            prefer_timestamp_ntz=prefer_timestamp_ntz,
        )
        merged = row_schema if merged is None else _merge_type(merged, row_schema)

    # A remaining null type means some column was null in every row.
    if _has_nulltype(merged):
        raise PySparkValueError(
            error_class="CANNOT_DETERMINE_TYPE",
            message_parameters={},
        )
    return merged

def newSession(self) -> "SparkSession":
    """Return a new :class:`SparkSession` sharing this session's context."""
    fresh_session = SparkSession(self._context)
    return fresh_session

Expand Down
Loading
Loading