6 changes: 3 additions & 3 deletions .github/workflows/main.yml
@@ -26,9 +26,9 @@ jobs:
       - name: Install python packages
         run: |
           pip install joblib==${{ matrix.JOBLIB_VERSION }} scikit-learn>=0.23.1 pytest pylint "pyspark[connect]==${{ matrix.PYSPARK_VERSION }}" pandas
-      - name: Run pylint
-        run: |
-          ./run-pylint.sh
       - name: Run test suites
         run: |
           TEST_SPARK_CONNECT=${{ matrix.SPARK_CONNECT_MODE }} PYSPARK_VERSION=${{ matrix.PYSPARK_VERSION }} ./run-tests.sh
+      - name: Run pylint
+        run: |
+          ./run-pylint.sh
28 changes: 18 additions & 10 deletions joblibspark/backend.py
@@ -153,8 +153,15 @@ def _create_resource_profile(self,
         if self._support_stage_scheduling:
             self.using_stage_scheduling = True

-            default_cpus_per_task = int(self._spark.conf.get("spark.task.cpus", "1"))
-            default_gpus_per_task = int(self._spark.conf.get("spark.task.resource.gpu.amount", "0"))
+            if is_spark_connect_mode():
+                # In Spark Connect mode, we can't read the Spark cluster configs.
+                default_cpus_per_task = 1
+                default_gpus_per_task = 0
+            else:
+                default_cpus_per_task = int(self._spark.conf.get("spark.task.cpus", "1"))
+                default_gpus_per_task = int(
+                    self._spark.conf.get("spark.task.resource.gpu.amount", "0")
+                )
             num_cpus_per_spark_task = num_cpus_per_spark_task or default_cpus_per_task
             num_gpus_per_spark_task = num_gpus_per_spark_task or default_gpus_per_task
@@ -331,17 +338,14 @@ def mapper_fn(_):
             return cloudpickle.loads(ser_res)

         try:
-            # pylint: disable=no-name-in-module,import-outside-toplevel
-            from pyspark import inheritable_thread_target
-
             if Version(pyspark.__version__).major >= 4 and is_spark_connect_mode():
                 # pylint: disable=fixme
                 # TODO: remove this patch once Spark 4.0.0 is released.
                 # The patch propagates the Spark session to the current thread.
-                def patched_inheritable_thread_target(f):  # pylint: disable=invalid-name
-                    import functools
-                    import copy
-                    from typing import Any
+                def inheritable_thread_target(f):  # pylint: disable=invalid-name
+                    import functools  # pylint: disable=C0415
+                    import copy  # pylint: disable=C0415
+                    from typing import Any  # pylint: disable=C0415

                     session = f
                     assert session is not None, "Spark Connect session must be provided."
@@ -359,6 +363,7 @@ def outer(ff: Any) -> Any:  # pylint: disable=invalid-name
                         @functools.wraps(ff)
                         def inner(*args: Any, **kwargs: Any) -> Any:
                             # Propagates the active spark session to the current thread
+                            # pylint: disable=C0415
                             from pyspark.sql.connect.session import SparkSession as SCS

                             # pylint: disable=protected-access,no-member
Expand All @@ -374,7 +379,10 @@ def inner(*args: Any, **kwargs: Any) -> Any:

return outer

inheritable_thread_target = patched_inheritable_thread_target(self._spark)
inheritable_thread_target = inheritable_thread_target(self._spark)
else:
# pylint: disable=no-name-in-module
from pyspark import inheritable_thread_target # pylint: disable=C0415

run_on_worker_and_fetch_result = \
inheritable_thread_target(run_on_worker_and_fetch_result)
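Why the patch above is needed: joblib dispatches each batch from a fresh Python thread, and a Spark Connect session lives in thread-local state that new threads do not inherit, so the wrapper must capture the session on the parent thread and re-install it on the worker thread. A minimal sketch of that propagation pattern in plain Python (illustrative only; the names here are made up and this is not the pyspark API):

    import functools
    import threading

    _state = threading.local()  # stand-in for pyspark's per-thread session slot

    def inheritable_target_sketch(fn):
        # Capture thread-local state while still on the creating thread...
        captured = getattr(_state, "session", None)

        @functools.wraps(fn)
        def inner(*args, **kwargs):
            # ...then re-install it inside whichever thread runs the call.
            _state.session = captured
            return fn(*args, **kwargs)

        return inner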
8 changes: 6 additions & 2 deletions test/test_spark.py
Expand Up @@ -203,8 +203,12 @@ def get_spark_context(x):
assert len(taskcontext.resources().get("gpu").addresses) == 1
return TaskContext.get()

with parallel_backend('spark') as (ba, _):
Parallel(n_jobs=5)(delayed(get_spark_context)(i) for i in range(10))
if is_spark_connect_mode:
with parallel_backend('spark', num_gpus_per_spark_task=1) as (ba, _):
Parallel(n_jobs=5)(delayed(get_spark_context)(i) for i in range(10))
else:
with parallel_backend('spark') as (ba, _):
Parallel(n_jobs=5)(delayed(get_spark_context)(i) for i in range(10))

def test_customized_resource_group(self):
def get_spark_context(x):
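End-user takeaway from the test change above: under Spark Connect the backend cannot infer per-task resources from the cluster conf, so callers should request them explicitly. A usage sketch (assumes joblibspark is installed, an active Spark Connect session, and a cluster that actually exposes GPUs):

    from joblib import Parallel, delayed, parallel_backend
    from joblibspark import register_spark

    register_spark()

    def square(x):
        return x * x

    # Spark Connect clients can't read spark.task.cpus or
    # spark.task.resource.gpu.amount, so the backend falls back to
    # 1 CPU / 0 GPUs per task unless told otherwise.
    with parallel_backend('spark', num_cpus_per_spark_task=1,
                          num_gpus_per_spark_task=1):
        results = Parallel(n_jobs=5)(delayed(square)(i) for i in range(10))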