diff --git a/duckdb/experimental/spark/sql/dataframe.py b/duckdb/experimental/spark/sql/dataframe.py index 83b2dd09..dcedc0b9 100644 --- a/duckdb/experimental/spark/sql/dataframe.py +++ b/duckdb/experimental/spark/sql/dataframe.py @@ -1,5 +1,6 @@ +import itertools import uuid -from collections.abc import Callable +from collections.abc import Callable, Iterable, Iterator from functools import reduce from keyword import iskeyword from typing import ( @@ -19,6 +20,8 @@ from .type_utils import duckdb_to_spark_schema from .types import Row, StructType +_LOCAL_ITERATOR_BATCH_SIZE = 10_000 + if TYPE_CHECKING: import pyarrow as pa from pandas.core.frame import DataFrame as PandasDataFrame @@ -27,7 +30,13 @@ from .group import GroupedData from .session import SparkSession -from duckdb.experimental.spark.sql import functions as spark_sql_functions +from duckdb.experimental.spark.sql import functions as spark_sql_functions # noqa: E402 + + +def _construct_row(values: Iterable, names: list[str]) -> Row: + row = tuple.__new__(Row, list(values)) + row.__fields__ = list(names) + return row class DataFrame: # noqa: D101 @@ -70,6 +79,149 @@ def toArrow(self) -> "pa.Table": """ return self.relation.to_arrow_table() + def toLocalIterator(self, prefetchPartitions: bool = False) -> Iterator[Row]: + """Returns an iterator that contains all of the rows in this :class:`DataFrame`. + + The iterator will consume as much memory as the largest partition in this + :class:`DataFrame`. With prefetch it may consume up to the memory of the 2 largest + partitions. + + .. versionadded:: 2.0.0 + + .. versionchanged:: 3.4.0 + Supports Spark Connect. + + Parameters + ---------- + prefetchPartitions : bool, optional + If Spark should pre-fetch the next partition before it is needed. + + .. versionchanged:: 3.4.0 + This argument does not take effect for Spark Connect. + + Returns: + ------- + Iterator + Iterator of rows. 
+
+        Examples:
+        --------
+        >>> df = spark.createDataFrame([(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"])
+        >>> list(df.toLocalIterator())
+        [Row(age=14, name='Tom'), Row(age=23, name='Alice'), Row(age=16, name='Bob')]
+        """
+        columns = self.relation.columns
+        cur = self.relation.execute()
+
+        try:
+            while rows := cur.fetchmany(_LOCAL_ITERATOR_BATCH_SIZE):
+                yield from (_construct_row(x, columns) for x in rows)
+        finally:
+            cur.close()
+
+    def foreach(self, f: Callable[[Row], None]) -> None:
+        """Applies the ``f`` function to all :class:`Row` of this :class:`DataFrame`.
+
+        This is a shorthand for ``df.rdd.foreach()``.
+
+        .. versionadded:: 1.3.0
+
+        .. versionchanged:: 4.0.0
+            Supports Spark Connect.
+
+        Parameters
+        ----------
+        f : function
+            A function that accepts one parameter which will
+            receive each row to process.
+
+        Examples:
+        --------
+        >>> df = spark.createDataFrame([(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"])
+        >>> def func(person):
+        ...     print(person.name)
+        >>> df.foreach(func)
+        """
+        for row in self.toLocalIterator():
+            f(row)
+
+    def foreachPartition(self, f: Callable[[Iterable[Row]], None]) -> None:
+        """Applies the ``f`` function to each partition of this :class:`DataFrame`.
+
+        This a shorthand for ``df.rdd.foreachPartition()``.
+
+        .. versionadded:: 1.3.0
+
+        .. versionchanged:: 4.0.0
+            Supports Spark Connect.
+
+        Parameters
+        ----------
+        f : function
+            A function that accepts one parameter which will receive
+            each partition to process.
+
+        Examples:
+        --------
+        >>> df = spark.createDataFrame([(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"])
+        >>> def func(itr):
+        ...     for person in itr:
+        ...         print(person.name)
+        >>> df.foreachPartition(func)
+        """
+        rows_generator = self.toLocalIterator()
+        while rows := list(itertools.islice(rows_generator, _LOCAL_ITERATOR_BATCH_SIZE)):
+            f(rows)
+
+    def isEmpty(self) -> bool:
+        """Checks if the :class:`DataFrame` is empty and returns a boolean value.
+
+        ..
versionadded:: 3.3.0 + + .. versionchanged:: 3.4.0 + Supports Spark Connect. + + Returns: + ------- + bool + Returns ``True`` if the DataFrame is empty, ``False`` otherwise. + + See Also: + -------- + DataFrame.count : Counts the number of rows in DataFrame. + + Notes: + ----- + - An empty DataFrame has no rows. It may have columns, but no data. + + Examples: + -------- + Example 1: Checking if an empty DataFrame is empty + + >>> df_empty = spark.createDataFrame([], "a STRING") + >>> df_empty.isEmpty() + True + + Example 2: Checking if a non-empty DataFrame is empty + + >>> df_non_empty = spark.createDataFrame(["a"], "STRING") + >>> df_non_empty.isEmpty() + False + + Example 3: Checking if a DataFrame with null values is empty + + >>> df_nulls = spark.createDataFrame([(None, None)], "a STRING, b INT") + >>> df_nulls.isEmpty() + False + + Example 4: Checking if a DataFrame with no rows but with columns is empty + + >>> df_no_rows = spark.createDataFrame([], "id INT, value STRING") + >>> df_no_rows.isEmpty() + True + """ + return self.first() is None + def createOrReplaceTempView(self, name: str) -> None: """Creates or replaces a local temporary view with this :class:`DataFrame`. 
@@ -1392,12 +1544,7 @@ def collect(self) -> list[Row]: # noqa: D102 columns = self.relation.columns result = self.relation.fetchall() - def construct_row(values: list, names: list[str]) -> Row: - row = tuple.__new__(Row, list(values)) - row.__fields__ = list(names) - return row - - rows = [construct_row(x, columns) for x in result] + rows = [_construct_row(x, columns) for x in result] return rows def cache(self) -> "DataFrame": diff --git a/external/duckdb b/external/duckdb index 376cb731..2e305aac 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit 376cb731d34122262308eaf3ecc7781862f8eab5 +Subproject commit 2e305aac809ef22e818024be1e86a1d0ee0d2863 diff --git a/tests/fast/spark/test_spark_dataframe.py b/tests/fast/spark/test_spark_dataframe.py index e242092e..19e99921 100644 --- a/tests/fast/spark/test_spark_dataframe.py +++ b/tests/fast/spark/test_spark_dataframe.py @@ -1,3 +1,5 @@ +from unittest import mock + import pytest _ = pytest.importorskip("duckdb.experimental.spark") @@ -597,3 +599,40 @@ def test_treeString_array_type(self, spark): assert " |-- name:" in tree assert " |-- hobbies: array<" in tree assert "(nullable = true)" in tree + + def test_method_is_empty(self, spark): + data = [(1, "Alice"), (2, "Bob")] + df = spark.createDataFrame(data, ["id", "name"]) + empty_df = spark.createDataFrame([], schema=df.schema) + + assert not df.isEmpty() + assert empty_df.isEmpty() + + def test_dataframe_foreach(self, spark): + data = [(56, "Carol"), (20, "Alice"), (3, "Dave")] + df = spark.createDataFrame(data, ["age", "name"]) + expected = [Row(age=56, name="Carol"), Row(age=20, name="Alice"), Row(age=3, name="Dave")] + + mock_callable = mock.MagicMock() + df.foreach(mock_callable) + mock_callable.assert_has_calls( + [mock.call(expected[0]), mock.call(expected[1]), mock.call(expected[2])], + any_order=True, + ) + + def test_dataframe_foreach_partition(self, spark): + data = [(56, "Carol"), (20, "Alice"), (3, "Dave")] + df = 
spark.createDataFrame(data, ["age", "name"]) + expected = [Row(age=56, name="Carol"), Row(age=20, name="Alice"), Row(age=3, name="Dave")] + + mock_callable = mock.MagicMock() + df.foreachPartition(mock_callable) + mock_callable.assert_called_once_with(expected) + + def test_to_local_iterator(self, spark): + data = [(56, "Carol"), (20, "Alice"), (3, "Dave")] + df = spark.createDataFrame(data, ["age", "name"]) + expected = [Row(age=56, name="Carol"), Row(age=20, name="Alice"), Row(age=3, name="Dave")] + + res = list(df.toLocalIterator()) + assert res == expected