From ea28569da862d068b0582e693fb8006ebf24bda3 Mon Sep 17 00:00:00 2001 From: Mario Taddeucci Date: Mon, 16 Feb 2026 14:24:38 -0300 Subject: [PATCH 1/6] [PySpark] - Add window function support --- duckdb/experimental/spark/sql/column.py | 44 +- duckdb/experimental/spark/sql/functions.py | 185 +++++++ duckdb/experimental/spark/sql/window.py | 468 ++++++++++++++++++ .../fast/spark/test_spark_functions_window.py | 132 +++++ tests/spark_namespace/sql/window.py | 6 + 5 files changed, 834 insertions(+), 1 deletion(-) create mode 100644 duckdb/experimental/spark/sql/window.py create mode 100644 tests/fast/spark/test_spark_functions_window.py create mode 100644 tests/spark_namespace/sql/window.py diff --git a/duckdb/experimental/spark/sql/column.py b/duckdb/experimental/spark/sql/column.py index e013a56d..4088c60f 100644 --- a/duckdb/experimental/spark/sql/column.py +++ b/duckdb/experimental/spark/sql/column.py @@ -9,8 +9,9 @@ if TYPE_CHECKING: from ._typing import DateTimeLiteral, DecimalLiteral, LiteralType + from .window import WindowSpec -from duckdb import ColumnExpression, ConstantExpression, Expression, FunctionExpression +from duckdb import ColumnExpression, ConstantExpression, Expression, FunctionExpression, SQLExpression from duckdb.sqltypes import DuckDBPyType __all__ = ["Column"] @@ -362,3 +363,44 @@ def isNull(self) -> "Column": # noqa: D102 def isNotNull(self) -> "Column": # noqa: D102 return Column(self.expr.isnotnull()) + + def over(self, window_spec: "WindowSpec") -> "Column": + """Define a windowing column. + + .. versionadded:: 1.4.0 + + .. versionchanged:: 3.4.0 + Supports Spark Connect. + + Parameters + ---------- + window : :class:`WindowSpec` + + Returns: + ------- + :class:`Column` + + Examples: + -------- + >>> from pyspark.sql import Window + >>> window = ( + ... Window.partitionBy("name") + ... .orderBy("age") + ... .rowsBetween(Window.unboundedPreceding, Window.currentRow) + ... ) + >>> from pyspark.sql.functions import rank, min, desc + >>> df = spark.createDataFrame([(2, "Alice"), (5, "Bob")], ["age", "name"]) + >>> df.withColumn("rank", rank().over(window)).withColumn( + ... "min", min("age").over(window) + ... ).sort(desc("age")).show() + +---+-----+----+---+ + |age| name|rank|min| + +---+-----+----+---+ + | 5| Bob| 1| 5| + | 2|Alice| 1| 2| + +---+-----+----+---+ + """ + col_expr = self.expr + window_expr = window_spec._window_expr() + full_expr = f"{col_expr} OVER ({window_expr})" + return Column(SQLExpression(full_expr)) diff --git a/duckdb/experimental/spark/sql/functions.py b/duckdb/experimental/spark/sql/functions.py index 71ff8c59..2cbd0904 100644 --- a/duckdb/experimental/spark/sql/functions.py +++ b/duckdb/experimental/spark/sql/functions.py @@ -6215,3 +6215,188 @@ def broadcast(df: "DataFrame") -> "DataFrame": or optimizations, since broadcasting is not applicable in the DuckDB context. """ # noqa: D205 return df + + +def row_number() -> Column: + """Window function: returns a sequential number starting at 1 within a window partition. + + .. versionadded:: 1.6.0 + + .. versionchanged:: 3.4.0 + Supports Spark Connect. + + Returns: + ------- + :class:`~pyspark.sql.Column` + the column for calculating row numbers. 
+ + Examples: + -------- + >>> from pyspark.sql import functions as sf + >>> from pyspark.sql import Window + >>> df = spark.range(3) + >>> w = Window.orderBy(df.id.desc()) + >>> df.withColumn("desc_order", sf.row_number().over(w)).show() + +---+----------+ + | id|desc_order| + +---+----------+ + | 2| 1| + | 1| 2| + | 0| 3| + +---+----------+ + """ + return _invoke_function("row_number") + + +def dense_rank() -> Column: + """Window function: returns the rank of rows within a window partition, without any gaps. + + The difference between rank and dense_rank is that dense_rank leaves no gaps in ranking + sequence when there are ties. That is, if you were ranking a competition using dense_rank + and had three people tie for second place, you would say that all three were in second + place and that the next person came in third. Rank would give me sequential numbers, making + the person that came in third place (after the ties) would register as coming in fifth. + + This is equivalent to the DENSE_RANK function in SQL. + + .. versionadded:: 1.6.0 + + .. versionchanged:: 3.4.0 + Supports Spark Connect. + + Returns: + ------- + :class:`~pyspark.sql.Column` + the column for calculating ranks. + + Examples: + -------- + >>> from pyspark.sql import functions as sf + >>> from pyspark.sql import Window + >>> df = spark.createDataFrame([1, 1, 2, 3, 3, 4], "int") + >>> w = Window.orderBy("value") + >>> df.withColumn("drank", sf.dense_rank().over(w)).show() + +-----+-----+ + |value|drank| + +-----+-----+ + | 1| 1| + | 1| 1| + | 2| 2| + | 3| 3| + | 3| 3| + | 4| 4| + +-----+-----+ + """ + return _invoke_function("dense_rank") + + +def rank() -> Column: + """Window function: returns the rank of rows within a window partition. + + The difference between rank and dense_rank is that dense_rank leaves no gaps in ranking + sequence when there are ties. That is, if you were ranking a competition using dense_rank + and had three people tie for second place, you would say that all three were in second + place and that the next person came in third. Rank would give me sequential numbers, making + the person that came in third place (after the ties) would register as coming in fifth. + + This is equivalent to the RANK function in SQL. + + .. versionadded:: 1.6.0 + + .. versionchanged:: 3.4.0 + Supports Spark Connect. + + Returns: + ------- + :class:`~pyspark.sql.Column` + the column for calculating ranks. + + Examples: + -------- + >>> from pyspark.sql import functions as sf + >>> from pyspark.sql import Window + >>> df = spark.createDataFrame([1, 1, 2, 3, 3, 4], "int") + >>> w = Window.orderBy("value") + >>> df.withColumn("drank", sf.rank().over(w)).show() + +-----+-----+ + |value|drank| + +-----+-----+ + | 1| 1| + | 1| 1| + | 2| 3| + | 3| 4| + | 3| 4| + | 4| 6| + +-----+-----+ + """ + return _invoke_function("rank") + + +def cume_dist() -> Column: + """Window function: returns the cumulative distribution of values within a window partition. + + Window function: returns the cumulative distribution of values within a window partition + i.e. the fraction of rows that are below the current row. + + .. versionadded:: 1.6.0 + + .. versionchanged:: 3.4.0 + Supports Spark Connect. + + Returns: + ------- + :class:`~pyspark.sql.Column` + the column for calculating cumulative distribution. 
+ + Examples: + -------- + >>> from pyspark.sql import functions as sf + >>> from pyspark.sql import Window + >>> df = spark.createDataFrame([1, 2, 3, 3, 4], "int") + >>> w = Window.orderBy("value") + >>> df.withColumn("cd", sf.cume_dist().over(w)).show() + +-----+---+ + |value| cd| + +-----+---+ + | 1|0.2| + | 2|0.4| + | 3|0.8| + | 3|0.8| + | 4|1.0| + +-----+---+ + """ + return _invoke_function("cume_dist") + + +def percent_rank() -> Column: + """Window function: returns the relative rank (i.e. percentile) of rows within a window partition. + + .. versionadded:: 1.6.0 + + .. versionchanged:: 3.4.0 + Supports Spark Connect. + + Returns: + ------- + :class:`~pyspark.sql.Column` + the column for calculating relative rank. + + Examples: + -------- + >>> from pyspark.sql import functions as sf + >>> from pyspark.sql import Window + >>> df = spark.createDataFrame([1, 1, 2, 3, 3, 4], "int") + >>> w = Window.orderBy("value") + >>> df.withColumn("pr", sf.percent_rank().over(w)).show() + +-----+---+ + |value| pr| + +-----+---+ + | 1|0.0| + | 1|0.0| + | 2|0.4| + | 3|0.6| + | 3|0.6| + | 4|1.0| + +-----+---+ + """ + return _invoke_function("percent_rank") diff --git a/duckdb/experimental/spark/sql/window.py b/duckdb/experimental/spark/sql/window.py new file mode 100644 index 00000000..a2865c3e --- /dev/null +++ b/duckdb/experimental/spark/sql/window.py @@ -0,0 +1,468 @@ +from collections.abc import Sequence +from typing import List, Optional, Tuple, Union + +from ..errors import PySparkTypeError +from ..exception import ContributionsAcceptedError +from ._typing import ColumnOrName +from .column import Column + + +class WindowSpec: + """A window specification that defines the partitioning, ordering, and frame boundaries. + + Use the static methods in :class:`Window` to create a :class:`WindowSpec`. + + .. versionadded:: 1.4.0 + + .. versionchanged:: 3.4.0 + Supports Spark Connect. + """ + + def __init__(self) -> None: + self._partition_by: List[ColumnOrName] = [] + self._order_by: List[ColumnOrName] = [] + self._rows_between: Optional[Tuple[int, int]] = None + self._range_between: Optional[Tuple[int, int]] = None + + def _copy(self) -> "WindowSpec": + new_window = WindowSpec() + new_window._partition_by = self._partition_by.copy() + new_window._order_by = self._order_by.copy() + new_window._rows_between = self._rows_between + new_window._range_between = self._range_between + return new_window + + def partitionBy(self, *cols: Union[ColumnOrName, Sequence[ColumnOrName]]) -> "WindowSpec": + """Defines the partitioning columns in a :class:`WindowSpec`. + + .. versionadded:: 1.4.0 + + Parameters + ---------- + cols : str, :class:`Column` or list + names of columns or expressions + """ + all_cols: Union[List[ColumnOrName], List[List[ColumnOrName]]] = list(cols) # type: ignore[assignment] + + if isinstance(all_cols[0], list): + all_cols = all_cols[0] + + new_window = self._copy() + new_window._partition_by = all_cols + return new_window + + def orderBy(self, *cols: Union[ColumnOrName, Sequence[ColumnOrName]]) -> "WindowSpec": + """Defines the ordering columns in a :class:`WindowSpec`. + + .. 
versionadded:: 1.4.0 + + Parameters + ---------- + cols : str, :class:`Column` or list + names of columns or expressions + """ + all_cols: Union[List[ColumnOrName], List[List[ColumnOrName]]] = list(cols) # type: ignore[assignment] + + if isinstance(all_cols[0], list): + all_cols = all_cols[0] + + new_window = self._copy() + new_window._order_by = all_cols + return new_window + + def rowsBetween(self, start: int, end: int) -> "WindowSpec": + """Defines the frame boundaries, from `start` (inclusive) to `end` (inclusive). + + Both `start` and `end` are relative positions from the current row. + For example, "0" means "current row", while "-1" means the row before + the current row, and "5" means the fifth row after the current row. + + We recommend users use ``Window.unboundedPreceding``, ``Window.unboundedFollowing``, + and ``Window.currentRow`` to specify special boundary values, rather than using integral + values directly. + + .. versionadded:: 1.4.0 + + Parameters + ---------- + start : int + boundary start, inclusive. + The frame is unbounded if this is ``Window.unboundedPreceding``, or + any value less than or equal to max(-sys.maxsize, -9223372036854775808). + end : int + boundary end, inclusive. + The frame is unbounded if this is ``Window.unboundedFollowing``, or + any value greater than or equal to min(sys.maxsize, 9223372036854775807). + """ + new_window = self._copy() + new_window._rows_between = (start, end) + return new_window + + def rangeBetween(self, start: int, end: int) -> "WindowSpec": + """Defines the frame boundaries, from `start` (inclusive) to `end` (inclusive). + + Both `start` and `end` are relative from the current row. For example, + "0" means "current row", while "-1" means one off before the current row, + and "5" means the five off after the current row. + + We recommend users use ``Window.unboundedPreceding``, ``Window.unboundedFollowing``, + and ``Window.currentRow`` to specify special boundary values, rather than using integral + values directly. + + .. versionadded:: 1.4.0 + + Parameters + ---------- + start : int + boundary start, inclusive. + The frame is unbounded if this is ``Window.unboundedPreceding``, or + any value less than or equal to max(-sys.maxsize, -9223372036854775808). + end : int + boundary end, inclusive. + The frame is unbounded if this is ``Window.unboundedFollowing``, or + any value greater than or equal to min(sys.maxsize, 9223372036854775807). 
+ """ + new_window = self._copy() + new_window._range_between = (start, end) + return new_window + + def _columns_as_str(self, *, cols: List[ColumnOrName], include_order_direction: bool) -> list[str]: + expressions = [] + for col in cols: + if isinstance(col, str): + expressions.append(col) + elif isinstance(col, Column): + if include_order_direction: + # TODO: Handle ascending/descending order if needed + raise ContributionsAcceptedError("Column Expression is not supported in WindowSpec.orderBy yet") + + else: + expressions.append(str(col.expr)) + else: + raise PySparkTypeError(f"Invalid column type: {type(col)}") + return expressions + + @staticmethod + def _generate_window_interval_expr(start: int, end: int) -> str: + if start == Window.currentRow and end == Window.currentRow: + return "CURRENT ROW AND CURRENT ROW" + + if start == Window.currentRow: + return f"CURRENT ROW AND {end} FOLLOWING" + + if end == Window.currentRow: + return f"{start} PRECEDING AND CURRENT ROW" + if start == Window.unboundedPreceding and end == Window.unboundedFollowing: + return "UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" + + if start == Window.unboundedPreceding: + return f"UNBOUNDED PRECEDING AND {end} FOLLOWING" + if end == Window.unboundedFollowing: + return f"{start} PRECEDING AND UNBOUNDED FOLLOWING" + + return f"{start} PRECEDING AND {end} FOLLOWING" + + def _window_expr(self) -> str: + parts = [] + if self._partition_by: + parts.append( + "PARTITION BY " + + ", ".join(self._columns_as_str(cols=self._partition_by, include_order_direction=False)) + ) + if self._order_by: + parts.append( + "ORDER BY " + ", ".join(self._columns_as_str(cols=self._order_by, include_order_direction=True)) + ) + if self._rows_between is not None: + parts.append(f"ROWS BETWEEN {self._generate_window_interval_expr(*self._rows_between)}") + if self._range_between is not None: + parts.append(f"RANGE BETWEEN {self._generate_window_interval_expr(*self._range_between)}") + sql = " ".join(parts) + return sql + + +class Window: + """Utility functions for defining window in DataFrames. + + .. versionadded:: 1.4.0 + + .. versionchanged:: 3.4.0 + Supports Spark Connect. + + Notes: + ----- + When ordering is not defined, an unbounded window frame (rowFrame, + unboundedPreceding, unboundedFollowing) is used by default. When ordering is defined, + a growing window frame (rangeFrame, unboundedPreceding, currentRow) is used by default. + + Examples: + -------- + >>> # ORDER BY date ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW + >>> window = Window.orderBy("date").rowsBetween( + ... Window.unboundedPreceding, Window.currentRow + ... ) + + >>> # PARTITION BY country ORDER BY date RANGE BETWEEN 3 PRECEDING AND 3 FOLLOWING + >>> window = Window.orderBy("date").partitionBy("country").rangeBetween(-3, 3) + """ + + currentRow = 0 + unboundedPreceding: int = -(1 << 63) # -9223372036854775808 - equivalent to Java's Long.MIN_VALUE + unboundedFollowing: int = (1 << 63) - 1 # 9223372036854775807 - equivalent to Java's Long.MAX_VALUE + + @classmethod + def partitionBy(cls, *cols: Union[ColumnOrName, Sequence[ColumnOrName]]) -> WindowSpec: + """Creates a :class:`WindowSpec` with the partitioning defined. + + .. versionadded:: 1.4.0 + + Parameters + ---------- + cols : str, :class:`Column` or list + names of columns or expressions + + Returns: + ------- + :class: `WindowSpec` + A :class:`WindowSpec` with the partitioning defined. 
+ + Examples: + -------- + >>> from pyspark.sql import Window + >>> from pyspark.sql.functions import row_number + >>> df = spark.createDataFrame( + ... [(1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b")], ["id", "category"] + ... ) + >>> df.show() + +---+--------+ + | id|category| + +---+--------+ + | 1| a| + | 1| a| + | 2| a| + | 1| b| + | 2| b| + | 3| b| + +---+--------+ + + Show row number order by ``id`` in partition ``category``. + + >>> window = Window.partitionBy("category").orderBy("id") + >>> df.withColumn("row_number", row_number().over(window)).show() + +---+--------+----------+ + | id|category|row_number| + +---+--------+----------+ + | 1| a| 1| + | 1| a| 2| + | 2| a| 3| + | 1| b| 1| + | 2| b| 2| + | 3| b| 3| + +---+--------+----------+ + """ + return WindowSpec().partitionBy(*cols) + + @classmethod + def orderBy(cls, *cols: Union[ColumnOrName, Sequence[ColumnOrName]]) -> WindowSpec: + """Creates a :class:`WindowSpec` with the ordering defined. + + .. versionadded:: 1.4.0 + + Parameters + ---------- + cols : str, :class:`Column` or list + names of columns or expressions + + Returns: + ------- + :class: `WindowSpec` + A :class:`WindowSpec` with the ordering defined. + + Examples: + -------- + >>> from pyspark.sql import Window + >>> from pyspark.sql.functions import row_number + >>> df = spark.createDataFrame( + ... [(1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b")], ["id", "category"] + ... ) + >>> df.show() + +---+--------+ + | id|category| + +---+--------+ + | 1| a| + | 1| a| + | 2| a| + | 1| b| + | 2| b| + | 3| b| + +---+--------+ + + Show row number order by ``category`` in partition ``id``. + + >>> window = Window.partitionBy("id").orderBy("category") + >>> df.withColumn("row_number", row_number().over(window)).show() + +---+--------+----------+ + | id|category|row_number| + +---+--------+----------+ + | 1| a| 1| + | 1| a| 2| + | 1| b| 3| + | 2| a| 1| + | 2| b| 2| + | 3| b| 1| + +---+--------+----------+ + """ + return WindowSpec().orderBy(*cols) + + @classmethod + def rowsBetween(cls, start: int, end: int) -> WindowSpec: + """Creates a :class:`WindowSpec` with the frame boundaries defined, from `start` (inclusive) to `end` (inclusive). + + Both `start` and `end` are relative positions from the current row. + For example, "0" means "current row", while "-1" means the row before + the current row, and "5" means the fifth row after the current row. + + We recommend users use ``Window.unboundedPreceding``, ``Window.unboundedFollowing``, + and ``Window.currentRow`` to specify special boundary values, rather than using integral + values directly. + + A row based boundary is based on the position of the row within the partition. + An offset indicates the number of rows above or below the current row, the frame for the + current row starts or ends. For instance, given a row based sliding frame with a lower bound + offset of -1 and a upper bound offset of +2. The frame for row with index 5 would range from + index 4 to index 7. + + .. versionadded:: 2.1.0 + + Parameters + ---------- + start : int + boundary start, inclusive. + The frame is unbounded if this is ``Window.unboundedPreceding``, or + any value less than or equal to -9223372036854775808. + end : int + boundary end, inclusive. + The frame is unbounded if this is ``Window.unboundedFollowing``, or + any value greater than or equal to 9223372036854775807. 
+ + Returns: + ------- + :class: `WindowSpec` + A :class:`WindowSpec` with the frame boundaries defined, + from `start` (inclusive) to `end` (inclusive). + + Examples: + -------- + >>> from pyspark.sql import Window + >>> from pyspark.sql import functions as func + >>> df = spark.createDataFrame( + ... [(1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b")], ["id", "category"] + ... ) + >>> df.show() + +---+--------+ + | id|category| + +---+--------+ + | 1| a| + | 1| a| + | 2| a| + | 1| b| + | 2| b| + | 3| b| + +---+--------+ + + Calculate sum of ``id`` in the range from currentRow to currentRow + 1 + in partition ``category`` + + >>> window = Window.partitionBy("category").orderBy("id").rowsBetween(Window.currentRow, 1) + >>> df.withColumn("sum", func.sum("id").over(window)).sort("id", "category", "sum").show() + +---+--------+---+ + | id|category|sum| + +---+--------+---+ + | 1| a| 2| + | 1| a| 3| + | 1| b| 3| + | 2| a| 2| + | 2| b| 5| + | 3| b| 3| + +---+--------+---+ + + """ + return WindowSpec().rowsBetween(start, end) + + @classmethod + def rangeBetween(cls, start: int, end: int) -> WindowSpec: + """Creates a :class:`WindowSpec` with the frame boundaries defined, from `start` (inclusive) to `end` (inclusive). + + Both `start` and `end` are relative from the current row. For example, + "0" means "current row", while "-1" means one off before the current row, + and "5" means the five off after the current row. + + We recommend users use ``Window.unboundedPreceding``, ``Window.unboundedFollowing``, + and ``Window.currentRow`` to specify special boundary values, rather than using integral + values directly. + + A range-based boundary is based on the actual value of the ORDER BY + expression(s). An offset is used to alter the value of the ORDER BY expression, for + instance if the current ORDER BY expression has a value of 10 and the lower bound offset + is -3, the resulting lower bound for the current row will be 10 - 3 = 7. This however puts a + number of constraints on the ORDER BY expressions: there can be only one expression and this + expression must have a numerical data type. An exception can be made when the offset is + unbounded, because no value modification is needed, in this case multiple and non-numeric + ORDER BY expression are allowed. + + .. versionadded:: 2.1.0 + + Parameters + ---------- + start : int + boundary start, inclusive. + The frame is unbounded if this is ``Window.unboundedPreceding``, or + any value less than or equal to max(-sys.maxsize, -9223372036854775808). + end : int + boundary end, inclusive. + The frame is unbounded if this is ``Window.unboundedFollowing``, or + any value greater than or equal to min(sys.maxsize, 9223372036854775807). + + Returns: + ------- + :class: `WindowSpec` + A :class:`WindowSpec` with the frame boundaries defined, + from `start` (inclusive) to `end` (inclusive). + + Examples: + -------- + >>> from pyspark.sql import Window + >>> from pyspark.sql import functions as func + >>> df = spark.createDataFrame( + ... [(1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b")], ["id", "category"] + ... 
) + >>> df.show() + +---+--------+ + | id|category| + +---+--------+ + | 1| a| + | 1| a| + | 2| a| + | 1| b| + | 2| b| + | 3| b| + +---+--------+ + + Calculate sum of ``id`` in the range from ``id`` of currentRow to ``id`` of currentRow + 1 + in partition ``category`` + + >>> window = Window.partitionBy("category").orderBy("id").rangeBetween(Window.currentRow, 1) + >>> df.withColumn("sum", func.sum("id").over(window)).sort("id", "category").show() + +---+--------+---+ + | id|category|sum| + +---+--------+---+ + | 1| a| 4| + | 1| a| 4| + | 1| b| 3| + | 2| a| 2| + | 2| b| 5| + | 3| b| 3| + +---+--------+---+ + + """ + return WindowSpec().rangeBetween(start, end) diff --git a/tests/fast/spark/test_spark_functions_window.py b/tests/fast/spark/test_spark_functions_window.py new file mode 100644 index 00000000..fc9bb0c1 --- /dev/null +++ b/tests/fast/spark/test_spark_functions_window.py @@ -0,0 +1,132 @@ +import pytest + +_ = pytest.importorskip("duckdb.experimental.spark") + +from spark_namespace.sql import functions as F +from spark_namespace.sql.types import Row +from spark_namespace.sql.window import Window + +from duckdb.experimental.spark import ContributionsAcceptedError + + +class TestDataFrameWindowFunction: + def test_order_by(self, spark): + simpleData = [ + ("Sales", "NY", 2024, 10000), + ("Sales", "NY", 2025, 20000), + ("Sales", "CA", 2024, 23000), + ("Finance", "CA", 2024, 23000), + ("Finance", "CA", 2025, 24000), + ("Finance", "NY", 2025, 19000), + ("Finance", "NY", 2024, 15000), + ("Marketing", "CA", 2024, 18000), + ("Marketing", "NY", 2025, 21000), + ] + columns = ["department", "state", "year", "bonus"] + df = spark.createDataFrame(data=simpleData, schema=columns) + df = df.withColumn( + "cumulative_bonus", F.sum("bonus").over(Window.partitionBy("department", "state").orderBy("year")) + ) + df = df.sort("department", "state", "year") + res1 = df.collect() + assert res1 == [ + Row(department="Finance", state="CA", year=2024, bonus=23000, cumulative_bonus=23000), + Row(department="Finance", state="CA", year=2025, bonus=24000, cumulative_bonus=47000), + Row(department="Finance", state="NY", year=2024, bonus=15000, cumulative_bonus=15000), + Row(department="Finance", state="NY", year=2025, bonus=19000, cumulative_bonus=34000), + Row(department="Marketing", state="CA", year=2024, bonus=18000, cumulative_bonus=18000), + Row(department="Marketing", state="NY", year=2025, bonus=21000, cumulative_bonus=21000), + Row(department="Sales", state="CA", year=2024, bonus=23000, cumulative_bonus=23000), + Row(department="Sales", state="NY", year=2024, bonus=10000, cumulative_bonus=10000), + Row(department="Sales", state="NY", year=2025, bonus=20000, cumulative_bonus=30000), + ] + + def test_percent_rank(self, spark): + df = spark.createDataFrame(data=[(1,), (1,), (2,), (3,), (3,), (4,)], schema=["value"]) + w = Window.orderBy("value") + df = df.withColumn("pr", F.percent_rank().over(w)) + res = df.sort("value").collect() + + assert res == [ + Row(value=1, pr=0.0), + Row(value=1, pr=0.0), + Row(value=2, pr=0.4), + Row(value=3, pr=0.6), + Row(value=3, pr=0.6), + Row(value=4, pr=1.0), + ] + + def test_cume_dist(self, spark): + df = spark.createDataFrame(data=[(1,), (2,), (3,), (3,), (4,)], schema=["value"]) + w = Window.orderBy("value") + df = df.withColumn("cd", F.cume_dist().over(w)) + df = df.sort("value") + res = df.collect() + + assert res == [ + Row(value=1, cd=0.2), + Row(value=2, cd=0.4), + Row(value=3, cd=0.8), + Row(value=3, cd=0.8), + Row(value=4, cd=1.0), + ] + + def 
test_simple_row_number(self, spark): + df = spark.createDataFrame( + data=[(2, "A"), (4, "A"), (3, "A"), (2, "B"), (1, "B"), (3, "B")], schema=["value", "grp"] + ) + w = Window.partitionBy("grp").orderBy("value") + df = df.withColumn("rn", F.row_number().over(w)) + res = df.sort("grp", "value").collect() + + assert res == [ + Row(value=2, grp="A", rn=1), + Row(value=3, grp="A", rn=2), + Row(value=4, grp="A", rn=3), + Row(value=1, grp="B", rn=1), + Row(value=2, grp="B", rn=2), + Row(value=3, grp="B", rn=3), + ] + + def test_deduplicate_rows(self, spark): + df = spark.createDataFrame( + data=[(2, "A"), (4, "A"), (3, "A"), (2, "B"), (1, "B"), (3, "B")], schema=["value", "grp"] + ) + w = Window.partitionBy(F.col("grp")).orderBy(F.col("value").desc()) + + with pytest.raises( + ContributionsAcceptedError, match="Column Expression is not supported in WindowSpec.orderBy yet" + ): + df = df.withColumn("rn", F.row_number().over(w)) + + def test_moving_average_last_3_points(self, spark): + data = [(1, 10), (2, 20), (3, 30), (4, 40), (5, 50)] + df = spark.createDataFrame(data=data, schema=["idx", "value"]) + w = Window.orderBy("idx").rowsBetween(2, Window.currentRow) + df = df.withColumn("ma3", F.avg("value").over(w)) + res = df.sort("idx").collect() + + assert res == [ + Row(idx=1, value=10, ma3=10.0), + Row(idx=2, value=20, ma3=15.0), + Row(idx=3, value=30, ma3=20.0), + Row(idx=4, value=40, ma3=30.0), + Row(idx=5, value=50, ma3=40.0), + ] + + def test_range_between(self, spark): + # rangeBetween uses the ordering column's values; here we include + # rows within a value distance of 2 up to the current row. + data = [(1, 10), (2, 20), (3, 30), (4, 40), (6, 60)] + df = spark.createDataFrame(data=data, schema=["idx", "value"]) + w = Window.orderBy("idx").rangeBetween(2, Window.currentRow) + df = df.withColumn("ma_range2", F.avg("value").over(w)) + res = df.sort("idx").collect() + + assert res == [ + Row(idx=1, value=10, ma_range2=10.0), + Row(idx=2, value=20, ma_range2=15.0), + Row(idx=3, value=30, ma_range2=20.0), + Row(idx=4, value=40, ma_range2=30.0), + Row(idx=6, value=60, ma_range2=50.0), + ] diff --git a/tests/spark_namespace/sql/window.py b/tests/spark_namespace/sql/window.py new file mode 100644 index 00000000..d80c263c --- /dev/null +++ b/tests/spark_namespace/sql/window.py @@ -0,0 +1,6 @@ +from .. import USE_ACTUAL_SPARK + +if USE_ACTUAL_SPARK: + from pyspark.sql.window import * +else: + from duckdb.experimental.spark.sql.window import * From 0252e92f13fe627e8100886edbdb2e2e141dd5b9 Mon Sep 17 00:00:00 2001 From: Mario Taddeucci Date: Tue, 17 Feb 2026 09:42:34 -0300 Subject: [PATCH 2/6] Refactor WindowSpec to use built-in list and tuple types; update test for ContributionsAcceptedError message regex --- duckdb/experimental/spark/sql/window.py | 32 ++++++++++--------- .../fast/spark/test_spark_functions_window.py | 2 +- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/duckdb/experimental/spark/sql/window.py b/duckdb/experimental/spark/sql/window.py index a2865c3e..a0a6a141 100644 --- a/duckdb/experimental/spark/sql/window.py +++ b/duckdb/experimental/spark/sql/window.py @@ -1,5 +1,5 @@ -from collections.abc import Sequence -from typing import List, Optional, Tuple, Union +from collections.abc import Sequence # noqa: D100 +from typing import Optional, Union from ..errors import PySparkTypeError from ..exception import ContributionsAcceptedError @@ -18,11 +18,11 @@ class WindowSpec: Supports Spark Connect. 
""" - def __init__(self) -> None: - self._partition_by: List[ColumnOrName] = [] - self._order_by: List[ColumnOrName] = [] - self._rows_between: Optional[Tuple[int, int]] = None - self._range_between: Optional[Tuple[int, int]] = None + def __init__(self) -> None: # noqa: D107 + self._partition_by: list[ColumnOrName] = [] + self._order_by: list[ColumnOrName] = [] + self._rows_between: Optional[tuple[int, int]] = None + self._range_between: Optional[tuple[int, int]] = None def _copy(self) -> "WindowSpec": new_window = WindowSpec() @@ -42,7 +42,7 @@ def partitionBy(self, *cols: Union[ColumnOrName, Sequence[ColumnOrName]]) -> "Wi cols : str, :class:`Column` or list names of columns or expressions """ - all_cols: Union[List[ColumnOrName], List[List[ColumnOrName]]] = list(cols) # type: ignore[assignment] + all_cols: Union[list[ColumnOrName], list[list[ColumnOrName]]] = list(cols) # type: ignore[assignment] if isinstance(all_cols[0], list): all_cols = all_cols[0] @@ -61,7 +61,7 @@ def orderBy(self, *cols: Union[ColumnOrName, Sequence[ColumnOrName]]) -> "Window cols : str, :class:`Column` or list names of columns or expressions """ - all_cols: Union[List[ColumnOrName], List[List[ColumnOrName]]] = list(cols) # type: ignore[assignment] + all_cols: Union[list[ColumnOrName], list[list[ColumnOrName]]] = list(cols) # type: ignore[assignment] if isinstance(all_cols[0], list): all_cols = all_cols[0] @@ -126,20 +126,22 @@ def rangeBetween(self, start: int, end: int) -> "WindowSpec": new_window._range_between = (start, end) return new_window - def _columns_as_str(self, *, cols: List[ColumnOrName], include_order_direction: bool) -> list[str]: + def _columns_as_str(self, *, cols: list[ColumnOrName], include_order_direction: bool) -> list[str]: expressions = [] for col in cols: if isinstance(col, str): expressions.append(col) elif isinstance(col, Column): if include_order_direction: - # TODO: Handle ascending/descending order if needed - raise ContributionsAcceptedError("Column Expression is not supported in WindowSpec.orderBy yet") + # TODO: Handle ascending/descending order if needed # noqa: TD002, TD003 + msg = "Column Expression is not supported in WindowSpec.orderBy yet" + raise ContributionsAcceptedError(msg) else: expressions.append(str(col.expr)) else: - raise PySparkTypeError(f"Invalid column type: {type(col)}") + msg = f"Invalid column type: {type(col)}" + raise PySparkTypeError(msg) return expressions @staticmethod @@ -316,7 +318,7 @@ def orderBy(cls, *cols: Union[ColumnOrName, Sequence[ColumnOrName]]) -> WindowSp @classmethod def rowsBetween(cls, start: int, end: int) -> WindowSpec: - """Creates a :class:`WindowSpec` with the frame boundaries defined, from `start` (inclusive) to `end` (inclusive). + """Creates a :class:`WindowSpec` with the frame boundaries defined, from start (inclusive) to end (inclusive). Both `start` and `end` are relative positions from the current row. For example, "0" means "current row", while "-1" means the row before @@ -391,7 +393,7 @@ def rowsBetween(cls, start: int, end: int) -> WindowSpec: @classmethod def rangeBetween(cls, start: int, end: int) -> WindowSpec: - """Creates a :class:`WindowSpec` with the frame boundaries defined, from `start` (inclusive) to `end` (inclusive). + """Creates a :class:`WindowSpec` with the frame boundaries defined, from start (inclusive) to end (inclusive). Both `start` and `end` are relative from the current row. 
For example, "0" means "current row", while "-1" means one off before the current row, diff --git a/tests/fast/spark/test_spark_functions_window.py b/tests/fast/spark/test_spark_functions_window.py index fc9bb0c1..20073105 100644 --- a/tests/fast/spark/test_spark_functions_window.py +++ b/tests/fast/spark/test_spark_functions_window.py @@ -95,7 +95,7 @@ def test_deduplicate_rows(self, spark): w = Window.partitionBy(F.col("grp")).orderBy(F.col("value").desc()) with pytest.raises( - ContributionsAcceptedError, match="Column Expression is not supported in WindowSpec.orderBy yet" + ContributionsAcceptedError, match=r"Column Expression is not supported in WindowSpec.orderBy yet" ): df = df.withColumn("rn", F.row_number().over(w)) From 2f254fc0f426ab9edaa165db34d2b7fbfc7ebe72 Mon Sep 17 00:00:00 2001 From: Mario Taddeucci Date: Tue, 17 Feb 2026 12:05:30 -0300 Subject: [PATCH 3/6] Add lag and lead window functions with tests --- duckdb/experimental/spark/sql/functions.py | 171 ++++++++++++++++++ .../fast/spark/test_spark_functions_window.py | 32 ++++ 2 files changed, 203 insertions(+) diff --git a/duckdb/experimental/spark/sql/functions.py b/duckdb/experimental/spark/sql/functions.py index 2cbd0904..8a9cd734 100644 --- a/duckdb/experimental/spark/sql/functions.py +++ b/duckdb/experimental/spark/sql/functions.py @@ -6400,3 +6400,174 @@ def percent_rank() -> Column: +-----+---+ """ return _invoke_function("percent_rank") + + +def lag(col: "ColumnOrName", offset: int = 1, default: Optional[Any] = None) -> Column: # noqa: ANN401 + """Window function: returns the value that is `offset` rows before the current row, and + `default` if there is less than `offset` rows before the current row. For example, + an `offset` of one will return the previous row at any given point in the window partition. + + This is equivalent to the LAG function in SQL. + + .. versionadded:: 1.4.0 + + .. versionchanged:: 3.4.0 + Supports Spark Connect. + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or column name + name of column or expression + offset : int, optional default 1 + number of row to extend + default : optional + default value + + Returns: + ------- + :class:`~pyspark.sql.Column` + value before current row based on `offset`. + + See Also: + -------- + :meth:`pyspark.sql.functions.lead` + + Examples: + -------- + >>> from pyspark.sql import functions as sf + >>> from pyspark.sql import Window + >>> df = spark.createDataFrame( + ... [("a", 1), ("a", 2), ("a", 3), ("b", 8), ("b", 2)], ["c1", "c2"] + ... 
) + >>> df.show() + +---+---+ + | c1| c2| + +---+---+ + | a| 1| + | a| 2| + | a| 3| + | b| 8| + | b| 2| + +---+---+ + + >>> w = Window.partitionBy("c1").orderBy("c2") + >>> df.withColumn("previous_value", sf.lag("c2").over(w)).show() + +---+---+--------------+ + | c1| c2|previous_value| + +---+---+--------------+ + | a| 1| NULL| + | a| 2| 1| + | a| 3| 2| + | b| 2| NULL| + | b| 8| 2| + +---+---+--------------+ + + >>> df.withColumn("previous_value", sf.lag("c2", 1, 0).over(w)).show() + +---+---+--------------+ + | c1| c2|previous_value| + +---+---+--------------+ + | a| 1| 0| + | a| 2| 1| + | a| 3| 2| + | b| 2| 0| + | b| 8| 2| + +---+---+--------------+ + + >>> df.withColumn("previous_value", sf.lag("c2", 2, -1).over(w)).show() + +---+---+--------------+ + | c1| c2|previous_value| + +---+---+--------------+ + | a| 1| -1| + | a| 2| -1| + | a| 3| 1| + | b| 2| -1| + | b| 8| -1| + +---+---+--------------+ + """ # noqa: D205 + return _invoke_function("lag", _to_column_expr(col), ConstantExpression(offset), ConstantExpression(default)) + + +def lead(col: "ColumnOrName", offset: int = 1, default: Optional[Any] = None) -> Column: # noqa: ANN401 + """ + Window function: returns the value that is `offset` rows after the current row, and + `default` if there is less than `offset` rows after the current row. For example, + an `offset` of one will return the next row at any given point in the window partition. + + This is equivalent to the LEAD function in SQL. + + .. versionadded:: 1.4.0 + + .. versionchanged:: 3.4.0 + Supports Spark Connect. + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or column name + name of column or expression + offset : int, optional default 1 + number of row to extend + default : optional + default value + + Returns: + ------- + :class:`~pyspark.sql.Column` + value after current row based on `offset`. + + See Also: + -------- + :meth:`pyspark.sql.functions.lag` + + Examples: + -------- + >>> from pyspark.sql import functions as sf + >>> from pyspark.sql import Window + >>> df = spark.createDataFrame( + ... [("a", 1), ("a", 2), ("a", 3), ("b", 8), ("b", 2)], ["c1", "c2"] + ... 
) + >>> df.show() + +---+---+ + | c1| c2| + +---+---+ + | a| 1| + | a| 2| + | a| 3| + | b| 8| + | b| 2| + +---+---+ + + >>> w = Window.partitionBy("c1").orderBy("c2") + >>> df.withColumn("next_value", sf.lead("c2").over(w)).show() + +---+---+----------+ + | c1| c2|next_value| + +---+---+----------+ + | a| 1| 2| + | a| 2| 3| + | a| 3| NULL| + | b| 2| 8| + | b| 8| NULL| + +---+---+----------+ + + >>> df.withColumn("next_value", sf.lead("c2", 1, 0).over(w)).show() + +---+---+----------+ + | c1| c2|next_value| + +---+---+----------+ + | a| 1| 2| + | a| 2| 3| + | a| 3| 0| + | b| 2| 8| + | b| 8| 0| + +---+---+----------+ + + >>> df.withColumn("next_value", sf.lead("c2", 2, -1).over(w)).show() + +---+---+----------+ + | c1| c2|next_value| + +---+---+----------+ + | a| 1| 3| + | a| 2| -1| + | a| 3| -1| + | b| 2| -1| + | b| 8| -1| + +---+---+----------+ + """ # noqa: D205, D212 + return _invoke_function("lead", _to_column_expr(col), ConstantExpression(offset), ConstantExpression(default)) diff --git a/tests/fast/spark/test_spark_functions_window.py b/tests/fast/spark/test_spark_functions_window.py index 20073105..d8a87914 100644 --- a/tests/fast/spark/test_spark_functions_window.py +++ b/tests/fast/spark/test_spark_functions_window.py @@ -130,3 +130,35 @@ def test_range_between(self, spark): Row(idx=4, value=40, ma_range2=30.0), Row(idx=6, value=60, ma_range2=50.0), ] + + def test_lag(self, spark): + df = spark.createDataFrame(data=[("a", 1), ("a", 2), ("a", 3), ("b", 8), ("b", 2)], schema=["c1", "c2"]) + w = Window.partitionBy("c1").orderBy("c2") + df = df.withColumn("previous_value", F.lag("c2").over(w)) + df = df.withColumn("previous_value_default", F.lag("c2", 1, 0).over(w)) + df = df.withColumn("previous_value_offset2", F.lag("c2", 2, -1).over(w)) + res = df.sort("c1", "c2").collect() + + assert res == [ + Row(c1="a", c2=1, previous_value=None, previous_value_default=0, previous_value_offset2=-1), + Row(c1="a", c2=2, previous_value=1, previous_value_default=1, previous_value_offset2=-1), + Row(c1="a", c2=3, previous_value=2, previous_value_default=2, previous_value_offset2=1), + Row(c1="b", c2=2, previous_value=None, previous_value_default=0, previous_value_offset2=-1), + Row(c1="b", c2=8, previous_value=2, previous_value_default=2, previous_value_offset2=-1), + ] + + def test_lead(self, spark): + df = spark.createDataFrame(data=[("a", 1), ("a", 2), ("a", 3), ("b", 8), ("b", 2)], schema=["c1", "c2"]) + w = Window.partitionBy("c1").orderBy("c2") + df = df.withColumn("next_value", F.lead("c2").over(w)) + df = df.withColumn("next_value_default", F.lead("c2", 1, 0).over(w)) + df = df.withColumn("next_value_offset2", F.lead("c2", 2, -1).over(w)) + res = df.sort("c1", "c2").collect() + + assert res == [ + Row(c1="a", c2=1, next_value=2, next_value_default=2, next_value_offset2=3), + Row(c1="a", c2=2, next_value=3, next_value_default=3, next_value_offset2=-1), + Row(c1="a", c2=3, next_value=None, next_value_default=0, next_value_offset2=-1), + Row(c1="b", c2=2, next_value=8, next_value_default=8, next_value_offset2=-1), + Row(c1="b", c2=8, next_value=None, next_value_default=0, next_value_offset2=-1), + ] From 9718dcdbcd313cbd28abad1ed8c6400a0f5e2804 Mon Sep 17 00:00:00 2001 From: Mario Taddeucci Date: Tue, 17 Feb 2026 12:12:07 -0300 Subject: [PATCH 4/6] Add nth_value window function with tests --- duckdb/experimental/spark/sql/functions.py | 76 +++++++++++++++++++ .../fast/spark/test_spark_functions_window.py | 15 ++++ 2 files changed, 91 insertions(+) diff --git 
a/duckdb/experimental/spark/sql/functions.py b/duckdb/experimental/spark/sql/functions.py index 8a9cd734..27adf05f 100644 --- a/duckdb/experimental/spark/sql/functions.py +++ b/duckdb/experimental/spark/sql/functions.py @@ -6571,3 +6571,79 @@ def lead(col: "ColumnOrName", offset: int = 1, default: Optional[Any] = None) -> +---+---+----------+ """ # noqa: D205, D212 return _invoke_function("lead", _to_column_expr(col), ConstantExpression(offset), ConstantExpression(default)) + + +def nth_value(col: "ColumnOrName", offset: int, ignoreNulls: Optional[bool] = False) -> Column: + """Window function: returns the value that is the `offset`\\th row of the window frame + (counting from 1), and `null` if the size of window frame is less than `offset` rows. + + It will return the `offset`\\th non-null value it sees when `ignoreNulls` is set to + true. If all values are null, then null is returned. + + This is equivalent to the nth_value function in SQL. + + .. versionadded:: 3.1.0 + + .. versionchanged:: 3.4.0 + Supports Spark Connect. + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or column name + name of column or expression + offset : int + number of row to use as the value + ignoreNulls : bool, optional + indicates the Nth value should skip null in the + determination of which row to use + + Returns: + ------- + :class:`~pyspark.sql.Column` + value of nth row. + + Examples: + -------- + >>> from pyspark.sql import functions as sf + >>> from pyspark.sql import Window + >>> df = spark.createDataFrame( + ... [("a", 1), ("a", 2), ("a", 3), ("b", 8), ("b", 2)], ["c1", "c2"] + ... ) + >>> df.show() + +---+---+ + | c1| c2| + +---+---+ + | a| 1| + | a| 2| + | a| 3| + | b| 8| + | b| 2| + +---+---+ + + >>> w = Window.partitionBy("c1").orderBy("c2") + >>> df.withColumn("nth_value", sf.nth_value("c2", 1).over(w)).show() + +---+---+---------+ + | c1| c2|nth_value| + +---+---+---------+ + | a| 1| 1| + | a| 2| 1| + | a| 3| 1| + | b| 2| 2| + | b| 8| 2| + +---+---+---------+ + + >>> df.withColumn("nth_value", sf.nth_value("c2", 2).over(w)).show() + +---+---+---------+ + | c1| c2|nth_value| + +---+---+---------+ + | a| 1| NULL| + | a| 2| 2| + | a| 3| 2| + | b| 2| NULL| + | b| 8| 8| + +---+---+---------+ + """ # noqa: D205, D301 + if ignoreNulls: + msg = "The ignoreNulls option of nth_value is not supported yet." 
+ raise ContributionsAcceptedError(msg) + return _invoke_function("nth_value", _to_column_expr(col), ConstantExpression(offset)) diff --git a/tests/fast/spark/test_spark_functions_window.py b/tests/fast/spark/test_spark_functions_window.py index d8a87914..41930094 100644 --- a/tests/fast/spark/test_spark_functions_window.py +++ b/tests/fast/spark/test_spark_functions_window.py @@ -162,3 +162,18 @@ def test_lead(self, spark): Row(c1="b", c2=2, next_value=8, next_value_default=8, next_value_offset2=-1), Row(c1="b", c2=8, next_value=None, next_value_default=0, next_value_offset2=-1), ] + + def test_nth_value(self, spark): + df = spark.createDataFrame(data=[("a", 1), ("a", 2), ("a", 3), ("b", 8), ("b", 2)], schema=["c1", "c2"]) + w = Window.partitionBy("c1").orderBy("c2") + df = df.withColumn("nth1", F.nth_value("c2", 1).over(w)) + df = df.withColumn("nth2", F.nth_value("c2", 2).over(w)) + res = df.sort("c1", "c2").collect() + + assert res == [ + Row(c1="a", c2=1, nth1=1, nth2=None), + Row(c1="a", c2=2, nth1=1, nth2=2), + Row(c1="a", c2=3, nth1=1, nth2=2), + Row(c1="b", c2=2, nth1=2, nth2=None), + Row(c1="b", c2=8, nth1=2, nth2=8), + ] From e8546f8b425603043f739ecbec44b9105b71795b Mon Sep 17 00:00:00 2001 From: Mario Taddeucci Date: Wed, 18 Mar 2026 22:23:56 -0300 Subject: [PATCH 5/6] fix: apply ruff auto-fixes for linting issues --- duckdb/experimental/spark/sql/functions.py | 6 +++--- duckdb/experimental/spark/sql/window.py | 19 +++++++++---------- external/duckdb | 2 +- 3 files changed, 13 insertions(+), 14 deletions(-) diff --git a/duckdb/experimental/spark/sql/functions.py b/duckdb/experimental/spark/sql/functions.py index 27adf05f..a9db714a 100644 --- a/duckdb/experimental/spark/sql/functions.py +++ b/duckdb/experimental/spark/sql/functions.py @@ -6402,7 +6402,7 @@ def percent_rank() -> Column: return _invoke_function("percent_rank") -def lag(col: "ColumnOrName", offset: int = 1, default: Optional[Any] = None) -> Column: # noqa: ANN401 +def lag(col: "ColumnOrName", offset: int = 1, default: Any | None = None) -> Column: # noqa: ANN401 """Window function: returns the value that is `offset` rows before the current row, and `default` if there is less than `offset` rows before the current row. For example, an `offset` of one will return the previous row at any given point in the window partition. @@ -6487,7 +6487,7 @@ def lag(col: "ColumnOrName", offset: int = 1, default: Optional[Any] = None) -> return _invoke_function("lag", _to_column_expr(col), ConstantExpression(offset), ConstantExpression(default)) -def lead(col: "ColumnOrName", offset: int = 1, default: Optional[Any] = None) -> Column: # noqa: ANN401 +def lead(col: "ColumnOrName", offset: int = 1, default: Any | None = None) -> Column: # noqa: ANN401 """ Window function: returns the value that is `offset` rows after the current row, and `default` if there is less than `offset` rows after the current row. For example, @@ -6573,7 +6573,7 @@ def lead(col: "ColumnOrName", offset: int = 1, default: Optional[Any] = None) -> return _invoke_function("lead", _to_column_expr(col), ConstantExpression(offset), ConstantExpression(default)) -def nth_value(col: "ColumnOrName", offset: int, ignoreNulls: Optional[bool] = False) -> Column: +def nth_value(col: "ColumnOrName", offset: int, ignoreNulls: bool | None = False) -> Column: """Window function: returns the value that is the `offset`\\th row of the window frame (counting from 1), and `null` if the size of window frame is less than `offset` rows. 
diff --git a/duckdb/experimental/spark/sql/window.py b/duckdb/experimental/spark/sql/window.py index a0a6a141..af7eef2d 100644 --- a/duckdb/experimental/spark/sql/window.py +++ b/duckdb/experimental/spark/sql/window.py @@ -1,5 +1,4 @@ -from collections.abc import Sequence # noqa: D100 -from typing import Optional, Union +from collections.abc import Sequence from ..errors import PySparkTypeError from ..exception import ContributionsAcceptedError @@ -21,8 +20,8 @@ class WindowSpec: def __init__(self) -> None: # noqa: D107 self._partition_by: list[ColumnOrName] = [] self._order_by: list[ColumnOrName] = [] - self._rows_between: Optional[tuple[int, int]] = None - self._range_between: Optional[tuple[int, int]] = None + self._rows_between: tuple[int, int] | None = None + self._range_between: tuple[int, int] | None = None def _copy(self) -> "WindowSpec": new_window = WindowSpec() @@ -32,7 +31,7 @@ def _copy(self) -> "WindowSpec": new_window._range_between = self._range_between return new_window - def partitionBy(self, *cols: Union[ColumnOrName, Sequence[ColumnOrName]]) -> "WindowSpec": + def partitionBy(self, *cols: ColumnOrName | Sequence[ColumnOrName]) -> "WindowSpec": """Defines the partitioning columns in a :class:`WindowSpec`. .. versionadded:: 1.4.0 @@ -42,7 +41,7 @@ def partitionBy(self, *cols: Union[ColumnOrName, Sequence[ColumnOrName]]) -> "Wi cols : str, :class:`Column` or list names of columns or expressions """ - all_cols: Union[list[ColumnOrName], list[list[ColumnOrName]]] = list(cols) # type: ignore[assignment] + all_cols: list[ColumnOrName] | list[list[ColumnOrName]] = list(cols) # type: ignore[assignment] if isinstance(all_cols[0], list): all_cols = all_cols[0] @@ -51,7 +50,7 @@ def partitionBy(self, *cols: Union[ColumnOrName, Sequence[ColumnOrName]]) -> "Wi new_window._partition_by = all_cols return new_window - def orderBy(self, *cols: Union[ColumnOrName, Sequence[ColumnOrName]]) -> "WindowSpec": + def orderBy(self, *cols: ColumnOrName | Sequence[ColumnOrName]) -> "WindowSpec": """Defines the ordering columns in a :class:`WindowSpec`. .. versionadded:: 1.4.0 @@ -61,7 +60,7 @@ def orderBy(self, *cols: Union[ColumnOrName, Sequence[ColumnOrName]]) -> "Window cols : str, :class:`Column` or list names of columns or expressions """ - all_cols: Union[list[ColumnOrName], list[list[ColumnOrName]]] = list(cols) # type: ignore[assignment] + all_cols: list[ColumnOrName] | list[list[ColumnOrName]] = list(cols) # type: ignore[assignment] if isinstance(all_cols[0], list): all_cols = all_cols[0] @@ -213,7 +212,7 @@ class Window: unboundedFollowing: int = (1 << 63) - 1 # 9223372036854775807 - equivalent to Java's Long.MAX_VALUE @classmethod - def partitionBy(cls, *cols: Union[ColumnOrName, Sequence[ColumnOrName]]) -> WindowSpec: + def partitionBy(cls, *cols: ColumnOrName | Sequence[ColumnOrName]) -> WindowSpec: """Creates a :class:`WindowSpec` with the partitioning defined. .. versionadded:: 1.4.0 @@ -265,7 +264,7 @@ def partitionBy(cls, *cols: Union[ColumnOrName, Sequence[ColumnOrName]]) -> Wind return WindowSpec().partitionBy(*cols) @classmethod - def orderBy(cls, *cols: Union[ColumnOrName, Sequence[ColumnOrName]]) -> WindowSpec: + def orderBy(cls, *cols: ColumnOrName | Sequence[ColumnOrName]) -> WindowSpec: """Creates a :class:`WindowSpec` with the ordering defined. .. 
versionadded:: 1.4.0 diff --git a/external/duckdb b/external/duckdb index 376cb731..2e305aac 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit 376cb731d34122262308eaf3ecc7781862f8eab5 +Subproject commit 2e305aac809ef22e818024be1e86a1d0ee0d2863 From 92618299b8a4d132f8a8fd38b9080fbb8949e502 Mon Sep 17 00:00:00 2001 From: Mario Taddeucci Date: Wed, 18 Mar 2026 22:37:35 -0300 Subject: [PATCH 6/6] fix: support createDataFrame with list of dicts in Spark API Port schema inference from duckdb/duckdb#18051 to fix #183. When calling spark.createDataFrame([{"col": value}, ...]), the Spark API now infers the schema from dict keys, matching PySpark behavior. Changes: - Add _type_mappings, _array_type_mappings, _has_nulltype, _merge_type, _infer_type, and _infer_schema functions to types.py - Update session.py to handle dict rows in _combine_data_and_schema and add schema inference branch in createDataFrame for list[dict] - Add _inferSchemaFromList method to SparkSession - Fix test_struct_column to use inferred field names instead of col0/col1 - Add test_dataframe_from_list_dicts test case --- duckdb/experimental/spark/sql/session.py | 63 ++++- duckdb/experimental/spark/sql/types.py | 319 ++++++++++++++++++++++- tests/fast/spark/test_spark_column.py | 18 +- tests/fast/spark/test_spark_dataframe.py | 14 + 4 files changed, 394 insertions(+), 20 deletions(-) diff --git a/duckdb/experimental/spark/sql/session.py b/duckdb/experimental/spark/sql/session.py index c407a9f1..40fbaa74 100644 --- a/duckdb/experimental/spark/sql/session.py +++ b/duckdb/experimental/spark/sql/session.py @@ -1,6 +1,7 @@ import uuid from collections.abc import Iterable, Sized -from typing import TYPE_CHECKING, Any, NoReturn, Union +from functools import reduce +from typing import TYPE_CHECKING, Any, List, NoReturn, Optional, Union import duckdb @@ -12,13 +13,13 @@ from ..conf import SparkConf from ..context import SparkContext -from ..errors import PySparkTypeError +from ..errors import PySparkTypeError, PySparkValueError from ..exception import ContributionsAcceptedError from .conf import RuntimeConfig from .dataframe import DataFrame from .readwriter import DataFrameReader from .streaming import DataStreamReader -from .types import StructType +from .types import StructType, _has_nulltype, _infer_schema, _merge_type from .udf import UDFRegistration # In spark: @@ -38,7 +39,11 @@ def _combine_data_and_schema(data: Iterable[Any], schema: StructType) -> list[du new_data = [] for row in data: - new_row = [Value(x, dtype.duckdb_type) for x, dtype in zip(row, [y.dataType for y in schema], strict=False)] + if isinstance(row, dict): + row_values = list(map(row.get, schema.fieldNames())) + else: + row_values = list(row) + new_row = [Value(x, dtype.duckdb_type) for x, dtype in zip(row_values, [y.dataType for y in schema], strict=False)] new_data.append(new_row) return new_data @@ -150,6 +155,9 @@ def createDataFrame( # noqa: D102 types, names = schema.extract_types_and_names() else: names = schema + elif isinstance(data, list) and data: + schema = self._inferSchemaFromList(data) + types, names = schema.extract_types_and_names() try: import pandas @@ -188,6 +196,53 @@ def createDataFrame( # noqa: D102 df = df.toDF(*names) return df + def _inferSchemaFromList( + self, data: Iterable[Any], names: Optional[List[str]] = None + ) -> StructType: + """Infer schema from list of Row, dict, or tuple. 
+ + Parameters + ---------- + data : iterable + list of Row, dict, or tuple + names : list, optional + list of column names + + Returns + ------- + :class:`duckdb.experimental.spark.sql.types.StructType` + """ + if not data: + raise PySparkValueError( + error_class="CANNOT_INFER_EMPTY_SCHEMA", + message_parameters={}, + ) + + # TODO: These should be configurable + infer_dict_as_struct = False + infer_array_from_first_element = False + prefer_timestamp_ntz = False + + schema = reduce( + _merge_type, + ( + _infer_schema( + row, + names, + infer_dict_as_struct=infer_dict_as_struct, + infer_array_from_first_element=infer_array_from_first_element, + prefer_timestamp_ntz=prefer_timestamp_ntz, + ) + for row in data + ), + ) + if _has_nulltype(schema): + raise PySparkValueError( + error_class="CANNOT_DETERMINE_TYPE", + message_parameters={}, + ) + return schema + def newSession(self) -> "SparkSession": # noqa: D102 return SparkSession(self._context) diff --git a/duckdb/experimental/spark/sql/types.py b/duckdb/experimental/spark/sql/types.py index 5bfff09f..b4eed95b 100644 --- a/duckdb/experimental/spark/sql/types.py +++ b/duckdb/experimental/spark/sql/types.py @@ -1,21 +1,25 @@ # This code is based on code from Apache Spark under the license found in the LICENSE # file located in the 'spark' folder. +import array import calendar import datetime +import decimal import math import re import time from builtins import tuple -from collections.abc import Iterator, Mapping +from collections.abc import Iterable, Iterator, Mapping +from functools import reduce from types import MappingProxyType -from typing import Any, ClassVar, NoReturn, TypeVar, Union, cast, overload +from typing import Any, ClassVar, Dict, List, NoReturn, Optional, Tuple, Type, TypeVar, Union, cast, overload from typing_extensions import Self import duckdb from duckdb.sqltypes import DuckDBPyType +from ..errors.exceptions.base import PySparkTypeError from ..exception import ContributionsAcceptedError T = TypeVar("T") @@ -1137,6 +1141,317 @@ def __eq__(self, other: object) -> bool: _INTERVAL_DAYTIME = re.compile(r"interval (day|hour|minute|second)( to (day|hour|minute|second))?") +# Mapping Python types to Spark SQL DataType +_type_mappings = { + type(None): NullType, + bool: BooleanType, + int: LongType, + float: DoubleType, + str: StringType, + bytearray: BinaryType, + decimal.Decimal: DecimalType, + datetime.date: DateType, + datetime.datetime: TimestampType, # can be TimestampNTZType + datetime.time: TimestampType, # can be TimestampNTZType + datetime.timedelta: DayTimeIntervalType, + bytes: BinaryType, +} + + +# The list of all supported array typecodes, is stored here +_array_type_mappings: Dict[str, Type[DataType]] = { + # Warning: Actual properties for float and double in C is not specified in C. + # On almost every system supported by both python and JVM, they are IEEE 754 + # single-precision binary floating-point format and IEEE 754 double-precision + # binary floating-point format. And we do assume the same thing here for now. 
+ "f": FloatType, + "d": DoubleType, +} + + +def _has_nulltype(dt: DataType) -> bool: + """Return whether there is a NullType in `dt` or not""" + if isinstance(dt, StructType): + return any(_has_nulltype(f.dataType) for f in dt.fields) + elif isinstance(dt, ArrayType): + return _has_nulltype((dt.elementType)) + elif isinstance(dt, MapType): + return _has_nulltype(dt.keyType) or _has_nulltype(dt.valueType) + else: + return isinstance(dt, NullType) + + +@overload +def _merge_type( + a: StructType, b: StructType, name: Optional[str] = None +) -> StructType: ... + + +@overload +def _merge_type( + a: ArrayType, b: ArrayType, name: Optional[str] = None +) -> ArrayType: ... + + +@overload +def _merge_type(a: MapType, b: MapType, name: Optional[str] = None) -> MapType: ... + + +@overload +def _merge_type(a: DataType, b: DataType, name: Optional[str] = None) -> DataType: ... + + +def _merge_type( + a: Union[StructType, ArrayType, MapType, DataType], + b: Union[StructType, ArrayType, MapType, DataType], + name: Optional[str] = None, +) -> Union[StructType, ArrayType, MapType, DataType]: + if name is None: + + def new_msg(msg: str) -> str: + return msg + + def new_name(n: str) -> str: + return "field %s" % n + + else: + + def new_msg(msg: str) -> str: + return "%s: %s" % (name, msg) + + def new_name(n: str) -> str: + return "field %s in %s" % (n, name) + + if isinstance(a, NullType): + return b + elif isinstance(b, NullType): + return a + elif isinstance(a, TimestampType) and isinstance(b, TimestampNTZType): + return a + elif isinstance(a, TimestampNTZType) and isinstance(b, TimestampType): + return b + elif isinstance(a, AtomicType) and isinstance(b, StringType): + return b + elif isinstance(a, StringType) and isinstance(b, AtomicType): + return a + elif type(a) is not type(b): + # TODO: type cast (such as int -> long) + raise PySparkTypeError( + error_class="CANNOT_MERGE_TYPE", + message_parameters={ + "data_type1": type(a).__name__, + "data_type2": type(b).__name__, + }, + ) + + # same type + if isinstance(a, StructType): + nfs = dict((f.name, f.dataType) for f in cast(StructType, b).fields) + fields = [ + StructField( + f.name, + _merge_type( + f.dataType, nfs.get(f.name, NullType()), name=new_name(f.name) + ), + ) + for f in a.fields + ] + names = set([f.name for f in fields]) + for n in nfs: + if n not in names: + fields.append(StructField(n, nfs[n])) + return StructType(fields) + + elif isinstance(a, ArrayType): + return ArrayType( + _merge_type( + a.elementType, + cast(ArrayType, b).elementType, + name="element in array %s" % name, + ), + True, + ) + + elif isinstance(a, MapType): + return MapType( + _merge_type( + a.keyType, cast(MapType, b).keyType, name="key of map %s" % name + ), + _merge_type( + a.valueType, cast(MapType, b).valueType, name="value of map %s" % name + ), + True, + ) + else: + return a + + +def _infer_type( + obj: Any, + infer_dict_as_struct: bool = False, + infer_array_from_first_element: bool = False, + prefer_timestamp_ntz: bool = False, +) -> DataType: + """Infer the DataType from obj""" + if obj is None: + return NullType() + + if hasattr(obj, "__UDT__"): + return obj.__UDT__ + + dataType = _type_mappings.get(type(obj)) + if dataType is DecimalType: + # the precision and scale of `obj` may be different from row to row. 
+ return DecimalType(38, 18) + if dataType is TimestampType and prefer_timestamp_ntz and obj.tzinfo is None: + return TimestampNTZType() + if dataType is DayTimeIntervalType: + return DayTimeIntervalType() + elif dataType is not None: + return dataType() + + if isinstance(obj, dict): + if infer_dict_as_struct: + struct = StructType() + for key, value in obj.items(): + if key is not None and value is not None: + struct.add( + key, + _infer_type( + value, + infer_dict_as_struct, + infer_array_from_first_element, + prefer_timestamp_ntz, + ), + True, + ) + return struct + else: + for key, value in obj.items(): + if key is not None and value is not None: + return MapType( + _infer_type( + key, + infer_dict_as_struct, + infer_array_from_first_element, + prefer_timestamp_ntz, + ), + _infer_type( + value, + infer_dict_as_struct, + infer_array_from_first_element, + prefer_timestamp_ntz, + ), + True, + ) + return MapType(NullType(), NullType(), True) + elif isinstance(obj, list): + if len(obj) > 0: + if infer_array_from_first_element: + return ArrayType( + _infer_type( + obj[0], + infer_dict_as_struct, + infer_array_from_first_element, + prefer_timestamp_ntz, + ), + True, + ) + else: + return ArrayType( + reduce( + _merge_type, + ( + _infer_type( + v, + infer_dict_as_struct, + infer_array_from_first_element, + prefer_timestamp_ntz, + ) + for v in obj + ), + ), + True, + ) + return ArrayType(NullType(), True) + elif isinstance(obj, array.array): + if obj.typecode in _array_type_mappings: + return ArrayType(_array_type_mappings[obj.typecode](), False) + else: + raise PySparkTypeError( + error_class="UNSUPPORTED_DATA_TYPE", + message_parameters={"data_type": f"array({obj.typecode})"}, + ) + else: + try: + return _infer_schema( + obj, + infer_dict_as_struct=infer_dict_as_struct, + infer_array_from_first_element=infer_array_from_first_element, + ) + except TypeError: + raise PySparkTypeError( + error_class="UNSUPPORTED_DATA_TYPE", + message_parameters={"data_type": type(obj).__name__}, + ) + + +def _infer_schema( + row: Any, + names: Optional[List[str]] = None, + infer_dict_as_struct: bool = False, + infer_array_from_first_element: bool = False, + prefer_timestamp_ntz: bool = False, +) -> StructType: + """Infer the schema from dict/namedtuple/object""" + items: Iterable[Tuple[str, Any]] + if isinstance(row, dict): + items = sorted(row.items()) + + elif isinstance(row, (tuple, list)): + if hasattr(row, "__fields__"): # Row + items = zip(row.__fields__, tuple(row)) # type: ignore[union-attr] + elif hasattr(row, "_fields"): # namedtuple + items = zip(row._fields, tuple(row)) # type: ignore[union-attr] + else: + if names is None: + names = ["_%d" % i for i in range(1, len(row) + 1)] + elif len(names) < len(row): + names.extend("_%d" % i for i in range(len(names) + 1, len(row) + 1)) + items = zip(names, row) + + elif hasattr(row, "__dict__"): # object + items = sorted(row.__dict__.items()) + + else: + raise PySparkTypeError( + error_class="CANNOT_INFER_SCHEMA_FOR_TYPE", + message_parameters={"data_type": type(row).__name__}, + ) + + fields = [] + for k, v in items: + try: + fields.append( + StructField( + k, + _infer_type( + v, + infer_dict_as_struct, + infer_array_from_first_element, + prefer_timestamp_ntz, + ), + True, + ) + ) + except TypeError: + raise PySparkTypeError( + error_class="CANNOT_INFER_TYPE_FOR_FIELD", + message_parameters={"field_name": k}, + ) + return StructType(fields) + + def _create_row(fields: Union["Row", list[str]], values: tuple[Any, ...] 
| list[Any]) -> "Row": row = Row(*values) row.__fields__ = fields diff --git a/tests/fast/spark/test_spark_column.py b/tests/fast/spark/test_spark_column.py index b2656643..409cabd0 100644 --- a/tests/fast/spark/test_spark_column.py +++ b/tests/fast/spark/test_spark_column.py @@ -14,21 +14,11 @@ class TestSparkColumn: def test_struct_column(self, spark): df = spark.createDataFrame([Row(a=1, b=2, c=3, d=4)]) - # TODO: column names should be set explicitly using the Row, rather than letting duckdb # noqa: TD002, TD003 - # assign defaults(col0, col1, etc..) - if USE_ACTUAL_SPARK: - df = df.withColumn("struct", struct(df.a, df.b)) - else: - df = df.withColumn("struct", struct(df.col0, df.col1)) - assert "struct" in df - new_col = df.schema["struct"] + df = df.withColumn("struct", struct(df.a, df.b)) + assert "struct" in df - if USE_ACTUAL_SPARK: - assert "a" in df.schema["struct"].dataType.fieldNames() - assert "b" in df.schema["struct"].dataType.fieldNames() - else: - assert "col0" in new_col.dataType - assert "col1" in new_col.dataType + assert "a" in df.schema["struct"].dataType.fieldNames() + assert "b" in df.schema["struct"].dataType.fieldNames() with pytest.raises( PySparkTypeError, match=re.escape("[NOT_COLUMN] Argument `col` should be a Column, got str.") diff --git a/tests/fast/spark/test_spark_dataframe.py b/tests/fast/spark/test_spark_dataframe.py index e242092e..6a2f0079 100644 --- a/tests/fast/spark/test_spark_dataframe.py +++ b/tests/fast/spark/test_spark_dataframe.py @@ -121,6 +121,20 @@ def test_dataframe_from_list_of_tuples(self, spark): with pytest.raises(TypeError, match="must be an iterable, not int"): spark.createDataFrame(address, 5) + def test_dataframe_from_list_dicts(self, spark): + data = [ + {"id": 1, "name": "Alice", "age": 25}, + {"id": 2, "age": 30, "name": "Bob"}, + {"age": 35, "id": 3, "name": "Charlie", "city": "New York"}, + ] + df = spark.createDataFrame(data) + res = df.collect() + assert res == [ + Row(age=25, id=1, name='Alice', city=None), + Row(age=30, id=2, name='Bob', city=None), + Row(age=35, id=3, name='Charlie', city='New York'), + ] + def test_dataframe(self, spark): # Create DataFrame df = spark.createDataFrame([("Scala", 25000), ("Spark", 35000), ("PHP", 21000)])
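
For context, a minimal sketch of the inference pipeline that the new
test_dataframe_from_list_dicts test relies on. The import path follows the file
layout in this patch, and the expected results are shown as comments (assumed
from the code added above, not taken from a run):

    from decimal import Decimal

    from duckdb.experimental.spark.sql.types import (
        LongType,
        NullType,
        _infer_schema,
        _infer_type,
        _merge_type,
    )

    _infer_type(1)                       # -> LongType()
    _infer_type(Decimal("1.5"))          # -> DecimalType(38, 18)
    _merge_type(LongType(), NullType())  # -> LongType()

    # Dict rows are inferred key by key (keys sorted), then merged across rows,
    # so a key that is missing from some rows still becomes a nullable field.
    _infer_schema({"id": 1, "name": "Alice"})
    # -> StructType([StructField("id", LongType(), True),
    #                StructField("name", StringType(), True)])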