diff --git a/metaflow/multicore_utils.py b/metaflow/multicore_utils.py index 165f9837e94..f1309fb7329 100644 --- a/metaflow/multicore_utils.py +++ b/metaflow/multicore_utils.py @@ -22,7 +22,7 @@ try: # Python 2 import cPickle as pickle -except: +except ImportError: # Python 3 import pickle @@ -66,7 +66,7 @@ def _spawn( with open(output_file, "wb") as f: pickle.dump(ret, f, protocol=pickle.HIGHEST_PROTOCOL) exit_code = 0 - except: + except BaseException: # we must not let any exceptions escape this function # which might trigger unintended side-effects traceback.print_exc() diff --git a/test/unit/test_multicore_utils.py b/test/unit/test_multicore_utils.py index 9b14a20b99a..3ec3ef69f6d 100644 --- a/test/unit/test_multicore_utils.py +++ b/test/unit/test_multicore_utils.py @@ -1,7 +1,9 @@ -from metaflow.multicore_utils import parallel_map +import pytest +from metaflow.multicore_utils import MulticoreException, parallel_imap_unordered, parallel_map -def test_parallel_map(): + +def test_parallel_map_basic(): assert parallel_map(lambda s: s.upper(), ["a", "b", "c", "d", "e", "f"]) == [ "A", "B", @@ -10,3 +12,70 @@ def test_parallel_map(): "E", "F", ] + + +def test_parallel_map_preserves_order(): + result = parallel_map(lambda x: x * 2, range(10)) + assert result == [0, 2, 4, 6, 8, 10, 12, 14, 16, 18] + + +def test_parallel_map_empty_input(): + assert parallel_map(lambda x: x, []) == [] + + +def test_parallel_map_single_element(): + assert parallel_map(lambda x: x + 1, [41]) == [42] + + +def test_parallel_map_with_max_parallel(): + result = parallel_map(lambda x: x**2, range(5), max_parallel=2) + assert result == [0, 1, 4, 9, 16] + + +def test_parallel_map_with_large_dataset(): + result = parallel_map(lambda x: x * 3, range(100)) + assert result == [x * 3 for x in range(100)] + + +def test_parallel_imap_unordered_basic(): + results = list(parallel_imap_unordered(lambda x: x * 2, range(4))) + assert sorted(results) == [0, 2, 4, 6] + + +def test_parallel_imap_unordered_empty(): + assert list(parallel_imap_unordered(lambda x: x, [])) == [] + + +def test_parallel_imap_unordered_with_max_parallel(): + results = list(parallel_imap_unordered(lambda x: x + 1, range(3), max_parallel=1)) + assert sorted(results) == [1, 2, 3] + + +def test_parallel_map_raises_on_child_failure(): + def failing_func(x): + if x == 2: + raise ValueError("Child process failure") + return x + + with pytest.raises(MulticoreException, match="Child failed"): + parallel_map(failing_func, [1, 2, 3]) + + +def test_parallel_map_returns_complex_objects(): + def make_dict(x): + return {"key": x, "nested": {"value": x * 2}} + + result = parallel_map(make_dict, [1, 2, 3]) + assert result == [ + {"key": 1, "nested": {"value": 2}}, + {"key": 2, "nested": {"value": 4}}, + {"key": 3, "nested": {"value": 6}}, + ] + + +def test_parallel_map_with_named_function(): + def square(x): + return x * x + + result = parallel_map(square, [1, 2, 3, 4, 5]) + assert result == [1, 4, 9, 16, 25]