From 0e233ca45bd5220a9496eb3adb1e68f489e0f054 Mon Sep 17 00:00:00 2001 From: Tian Gao Date: Mon, 2 Mar 2026 14:52:03 -0800 Subject: [PATCH 1/2] Add datasource tests with simple worker --- python/pyspark/logger/worker_io.py | 9 +++-- .../sql/tests/test_python_datasource.py | 39 ++++++++++++++----- 2 files changed, 36 insertions(+), 12 deletions(-) diff --git a/python/pyspark/logger/worker_io.py b/python/pyspark/logger/worker_io.py index 79684b7aca624..3522c58e8b75a 100644 --- a/python/pyspark/logger/worker_io.py +++ b/python/pyspark/logger/worker_io.py @@ -223,7 +223,11 @@ def context_provider() -> dict[str, str]: - class_name: Name of the class that initiated the logging if available """ - def is_pyspark_module(module_name: str) -> bool: + def is_pyspark_module(frame: FrameType) -> bool: + module_name = frame.f_globals.get("__name__", "") + if module_name == "__main__": + if mod := sys.modules.get("__main__", None): + module_name = mod.__spec__.name return module_name.startswith("pyspark.") and ".tests." not in module_name bottom: Optional[FrameType] = None @@ -236,9 +240,8 @@ def is_pyspark_module(module_name: str) -> bool: if frame: while frame.f_back: f_back = frame.f_back - module_name = f_back.f_globals.get("__name__", "") - if is_pyspark_module(module_name): + if is_pyspark_module(f_back): if not is_in_pyspark_module: bottom = frame is_in_pyspark_module = True diff --git a/python/pyspark/sql/tests/test_python_datasource.py b/python/pyspark/sql/tests/test_python_datasource.py index 1bdb7a5395e1b..9d90082c654d7 100644 --- a/python/pyspark/sql/tests/test_python_datasource.py +++ b/python/pyspark/sql/tests/test_python_datasource.py @@ -1237,8 +1237,20 @@ def writer(self, schema, overwrite): logs = self.spark.tvf.python_worker_logs() + # We could get either 1 or 2 "TestJsonWriter.write: abort test" logs because + # the operation is time sensitive. When the first partition gets aborted, + # the executor will cancel the rest of the tasks. 
Whether we are able to get + # the second log depends on whether the second partition starts before the + # cancellation. When we use simple worker, the second log is often missing + # because the spawn overhead is large. + non_abort_logs = logs.select("level", "msg", "context", "logger").filter( + "msg != 'TestJsonWriter.write: abort test'" + ) + abort_logs = logs.select("level", "msg", "context", "logger").filter( + "msg == 'TestJsonWriter.write: abort test'" + ) assertDataFrameEqual( - logs.select("level", "msg", "context", "logger"), + non_abort_logs, [ Row( level="WARNING", @@ -1283,14 +1295,6 @@ def writer(self, schema, overwrite): "TestJsonWriter.__init__: ['abort', 'path']", {"class_name": "TestJsonDataSource", "func_name": "writer"}, ), - ( - "TestJsonWriter.write: abort test", - {"class_name": "TestJsonWriter", "func_name": "write"}, - ), - ( - "TestJsonWriter.write: abort test", - {"class_name": "TestJsonWriter", "func_name": "write"}, - ), ( "TestJsonWriter.abort", {"class_name": "TestJsonWriter", "func_name": "abort"}, @@ -1298,6 +1302,17 @@ def writer(self, schema, overwrite): ] ], ) + assertDataFrameEqual( + abort_logs.dropDuplicates(["msg"]), + [ + Row( + level="WARNING", + msg="TestJsonWriter.write: abort test", + context={"class_name": "TestJsonWriter", "func_name": "write"}, + logger="test_datasource_writer", + ) + ], + ) def test_data_source_perf_profiler(self): with self.sql_conf({"spark.sql.pyspark.dataSource.profiler": "perf"}): @@ -1345,6 +1360,12 @@ class PythonDataSourceTests(BasePythonDataSourceTestsMixin, ReusedSQLTestCase): ... 
+class PythonDataSourceTestsWithSimpleWorker(PythonDataSourceTests): + @classmethod + def conf(cls): + return super().conf().set("spark.python.use.daemon", "false") + + if __name__ == "__main__": from pyspark.testing import main From b8eb0da27e5d78a8f738658c511d7a41378b6962 Mon Sep 17 00:00:00 2001 From: Tian Gao Date: Mon, 2 Mar 2026 16:19:33 -0800 Subject: [PATCH 2/2] Fix type hint --- python/pyspark/logger/worker_io.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/logger/worker_io.py b/python/pyspark/logger/worker_io.py index 3522c58e8b75a..7843b43e10ee9 100644 --- a/python/pyspark/logger/worker_io.py +++ b/python/pyspark/logger/worker_io.py @@ -226,7 +226,7 @@ def context_provider() -> dict[str, str]: def is_pyspark_module(frame: FrameType) -> bool: module_name = frame.f_globals.get("__name__", "") if module_name == "__main__": - if mod := sys.modules.get("__main__", None): + if (mod := sys.modules.get("__main__", None)) and mod.__spec__: module_name = mod.__spec__.name return module_name.startswith("pyspark.") and ".tests." not in module_name