diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 49da6ed..09fa802 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -26,9 +26,9 @@ jobs: - name: Install python packages run: | pip install joblib==${{ matrix.JOBLIB_VERSION }} scikit-learn>=0.23.1 pytest pylint "pyspark[connect]==${{ matrix.PYSPARK_VERSION }}" pandas - - name: Run pylint - run: | - ./run-pylint.sh - name: Run test suites run: | TEST_SPARK_CONNECT=${{ matrix.SPARK_CONNECT_MODE }} PYSPARK_VERSION=${{ matrix.PYSPARK_VERSION }} ./run-tests.sh + - name: Run pylint + run: | + ./run-pylint.sh diff --git a/joblibspark/backend.py b/joblibspark/backend.py index c3ea2ec..61583f5 100644 --- a/joblibspark/backend.py +++ b/joblibspark/backend.py @@ -153,8 +153,15 @@ def _create_resource_profile(self, if self._support_stage_scheduling: self.using_stage_scheduling = True - default_cpus_per_task = int(self._spark.conf.get("spark.task.cpus", "1")) - default_gpus_per_task = int(self._spark.conf.get("spark.task.resource.gpu.amount", "0")) + if is_spark_connect_mode(): + # In Spark Connect mode, we can't read Spark cluster configures. + default_cpus_per_task = 1 + default_gpus_per_task = 0 + else: + default_cpus_per_task = int(self._spark.conf.get("spark.task.cpus", "1")) + default_gpus_per_task = int( + self._spark.conf.get("spark.task.resource.gpu.amount", "0") + ) num_cpus_per_spark_task = num_cpus_per_spark_task or default_cpus_per_task num_gpus_per_spark_task = num_gpus_per_spark_task or default_gpus_per_task @@ -331,17 +338,14 @@ def mapper_fn(_): return cloudpickle.loads(ser_res) try: - # pylint: disable=no-name-in-module,import-outside-toplevel - from pyspark import inheritable_thread_target - if Version(pyspark.__version__).major >= 4 and is_spark_connect_mode(): # pylint: disable=fixme # TODO: remove this patch once Spark 4.0.0 is released. # the patch is for propagating the Spark session to current thread. - def patched_inheritable_thread_target(f): # pylint: disable=invalid-name - import functools - import copy - from typing import Any + def inheritable_thread_target(f): # pylint: disable=invalid-name + import functools # pylint: disable=C0415 + import copy # pylint: disable=C0415 + from typing import Any # pylint: disable=C0415 session = f assert session is not None, "Spark Connect session must be provided." @@ -359,6 +363,7 @@ def outer(ff: Any) -> Any: # pylint: disable=invalid-name @functools.wraps(ff) def inner(*args: Any, **kwargs: Any) -> Any: # Propagates the active spark session to the current thread + # pylint: disable=C0415 from pyspark.sql.connect.session import SparkSession as SCS # pylint: disable=protected-access,no-member @@ -374,7 +379,10 @@ def inner(*args: Any, **kwargs: Any) -> Any: return outer - inheritable_thread_target = patched_inheritable_thread_target(self._spark) + inheritable_thread_target = inheritable_thread_target(self._spark) + else: + # pylint: disable=no-name-in-module + from pyspark import inheritable_thread_target # pylint: disable=C0415 run_on_worker_and_fetch_result = \ inheritable_thread_target(run_on_worker_and_fetch_result) diff --git a/test/test_spark.py b/test/test_spark.py index e871b45..6a776c9 100644 --- a/test/test_spark.py +++ b/test/test_spark.py @@ -203,8 +203,12 @@ def get_spark_context(x): assert len(taskcontext.resources().get("gpu").addresses) == 1 return TaskContext.get() - with parallel_backend('spark') as (ba, _): - Parallel(n_jobs=5)(delayed(get_spark_context)(i) for i in range(10)) + if is_spark_connect_mode: + with parallel_backend('spark', num_gpus_per_spark_task=1) as (ba, _): + Parallel(n_jobs=5)(delayed(get_spark_context)(i) for i in range(10)) + else: + with parallel_backend('spark') as (ba, _): + Parallel(n_jobs=5)(delayed(get_spark_context)(i) for i in range(10)) def test_customized_resource_group(self): def get_spark_context(x):