From a5ee30d4d7c56f325079fbb100bc51444597c3df Mon Sep 17 00:00:00 2001
From: AMOOOMA
Date: Mon, 15 Dec 2025 21:45:16 +0000
Subject: [PATCH 01/13] Allow MultiProcessShared to spawn process and delete
 directly with obj

---
 .../apache_beam/utils/multi_process_shared.py | 272 ++++++++++++++++--
 .../utils/multi_process_shared_test.py        | 218 ++++++++++++++
 2 files changed, 463 insertions(+), 27 deletions(-)

diff --git a/sdks/python/apache_beam/utils/multi_process_shared.py b/sdks/python/apache_beam/utils/multi_process_shared.py
index aecb1284a1d4..0efa01f45570 100644
--- a/sdks/python/apache_beam/utils/multi_process_shared.py
+++ b/sdks/python/apache_beam/utils/multi_process_shared.py
@@ -25,6 +25,10 @@
 import logging
 import multiprocessing.managers
 import os
+import time
+import traceback
+import atexit
+import sys
 import tempfile
 import threading
 from typing import Any
@@ -79,6 +83,10 @@ def singletonProxy_release(self):
     assert self._SingletonProxy_valid
     self._SingletonProxy_valid = False
 
+  def unsafe_hard_delete(self):
+    assert self._SingletonProxy_valid
+    self._SingletonProxy_entry.unsafe_hard_delete()
+
   def __getattr__(self, name):
     if not self._SingletonProxy_valid:
       raise RuntimeError('Entry was released.')
@@ -105,17 +113,39 @@ def __dir__(self):
     dir = self._SingletonProxy_entry.obj.__dir__()
     dir.append('singletonProxy_call__')
     dir.append('singletonProxy_release')
+    dir.append('unsafe_hard_delete')
     return dir
 
 
+def _run_with_oom_protection(func, *args, **kwargs):
+  try:
+    return func(*args, **kwargs)
+  except Exception as e:
+    # Check string to avoid hard import dependency
+    if 'CUDA out of memory' in str(e):
+      logging.warning("Caught CUDA OOM during operation. Cleaning memory.")
+      try:
+        import gc
+        import torch
+        gc.collect()
+        torch.cuda.empty_cache()
+      except ImportError:
+        pass
+      except Exception as cleanup_error:
+        logging.error("Failed to clean up CUDA memory: %s", cleanup_error)
+    raise e
+
+
 class _SingletonEntry:
   """Represents a single, refcounted entry in this process."""
-  def __init__(self, constructor, initialize_eagerly=True):
+  def __init__(
+      self, constructor, initialize_eagerly=True, hard_delete_callback=None):
     self.constructor = constructor
+    self._hard_delete_callback = hard_delete_callback
     self.refcount = 0
     self.lock = threading.Lock()
     if initialize_eagerly:
-      self.obj = constructor()
+      self.obj = _run_with_oom_protection(constructor)
       self.initialied = True
     else:
       self.initialied = False
@@ -123,7 +153,7 @@ def __init__(self, constructor, initialize_eagerly=True):
   def acquire(self):
     with self.lock:
       if not self.initialied:
-        self.obj = self.constructor()
+        self.obj = _run_with_oom_protection(self.constructor)
         self.initialied = True
       self.refcount += 1
     return _SingletonProxy(self)
@@ -141,14 +171,28 @@ def unsafe_hard_delete(self):
     if self.initialied:
       del self.obj
       self.initialied = False
+      if self._hard_delete_callback:
+        self._hard_delete_callback()
 
 
 class _SingletonManager:
   entries: Dict[Any, Any] = {}
 
-  def register_singleton(self, constructor, tag, initialize_eagerly=True):
+  def __init__(self):
+    self._hard_delete_callback = None
+
+  def set_hard_delete_callback(self, callback):
+    self._hard_delete_callback = callback
+
+  def register_singleton(
+      self,
+      constructor,
+      tag,
+      initialize_eagerly=True,
+      hard_delete_callback=None):
     assert tag not in self.entries, tag
-    self.entries[tag] = _SingletonEntry(constructor, initialize_eagerly)
+    self.entries[tag] = _SingletonEntry(
+        constructor, initialize_eagerly, hard_delete_callback)
 
   def has_singleton(self, tag):
     return tag in self.entries
@@ -160,7 +204,8 @@ def release_singleton(self, tag, obj):
     return self.entries[tag].release(obj)
 
   def unsafe_hard_delete_singleton(self, tag):
-    return self.entries[tag].unsafe_hard_delete()
+    self.entries[tag].unsafe_hard_delete()
+    self._hard_delete_callback()
 
 
 _process_level_singleton_manager = _SingletonManager()
@@ -200,9 +245,99 @@ def __call__(self, *args, **kwargs):
   def __getattr__(self, name):
     return getattr(self._proxyObject, name)
 
+  def __setstate__(self, state):
+    self.__dict__.update(state)
+
+  def __getstate__(self):
+    return self.__dict__
+
   def get_auto_proxy_object(self):
     return self._proxyObject
 
+  def unsafe_hard_delete(self):
+    try:
+      self._proxyObject.unsafe_hard_delete()
+    except (EOFError, ConnectionResetError, BrokenPipeError):
+      pass
+    except Exception as e:
+      logging.warning(
+          "Exception %s when trying to hard delete shared object proxy", e)
+
+
+def _run_server_process(address_file, tag, constructor, authkey):
+  """
+  Runs in a separate process.
+  Includes a 'Suicide Pact' monitor: If parent dies, I die.
+  """
+  parent_pid = os.getppid()
+
+  def cleanup_files():
+    logging.info("Server process exiting. Deleting files for %s", tag)
+    try:
+      if os.path.exists(address_file):
+        os.remove(address_file)
+      if os.path.exists(address_file + ".error"):
+        os.remove(address_file + ".error")
+    except Exception:
+      pass
+
+  def handle_unsafe_hard_delete():
+    cleanup_files()
+    os._exit(0)
+
+  def _monitor_parent():
+    """Checks if parent is alive every second."""
+    while True:
+      try:
+        os.kill(parent_pid, 0)
+      except OSError:
+        logging.warning(
+            "Process %s detected Parent %s died. Self-destructing.",
+            os.getpid(),
+            parent_pid)
+        cleanup_files()
+        os._exit(0)
+      time.sleep(0.5)
+
+  atexit.register(cleanup_files)
+
+  try:
+    t = threading.Thread(target=_monitor_parent, daemon=True)
+    t.start()
+
+    logging.getLogger().setLevel(logging.INFO)
+    multiprocessing.current_process().authkey = authkey
+
+    serving_manager = _SingletonRegistrar(
+        address=('localhost', 0), authkey=authkey)
+    _process_level_singleton_manager.set_hard_delete_callback(
+        handle_unsafe_hard_delete)
+    _process_level_singleton_manager.register_singleton(
+        constructor,
+        tag,
+        initialize_eagerly=True,
+        hard_delete_callback=handle_unsafe_hard_delete)
+
+    server = serving_manager.get_server()
+    logging.info(
+        'Process %s: Proxy serving %s at %s', os.getpid(), tag, server.address)
+
+    with open(address_file + '.tmp', 'w') as fout:
+      fout.write('%s:%d' % server.address)
+    os.rename(address_file + '.tmp', address_file)
+
+    server.serve_forever()
+
+  except Exception:
+    tb = traceback.format_exc()
+    try:
+      with open(address_file + ".error.tmp", 'w') as fout:
+        fout.write(tb)
+      os.rename(address_file + ".error.tmp", address_file + ".error")
+    except Exception:
+      print(f"CRITICAL ERROR IN SHARED SERVER:\n{tb}", file=sys.stderr)
+    os._exit(1)
+
 
 class MultiProcessShared(Generic[T]):
   """MultiProcessShared is used to share a single object across processes.
@@ -252,7 +387,8 @@ def __init__(
       tag: Any,
       *,
       path: str = tempfile.gettempdir(),
-      always_proxy: Optional[bool] = None):
+      always_proxy: Optional[bool] = None,
+      spawn_process: bool = False):
     self._constructor = constructor
     self._tag = tag
     self._path = path
@@ -262,6 +398,7 @@ def __init__(
     self._rpc_address = None
     self._cross_process_lock = fasteners.InterProcessLock(
         os.path.join(self._path, self._tag) + '.lock')
+    self._spawn_process = spawn_process
 
   def _get_manager(self):
     if self._manager is None:
@@ -301,6 +438,10 @@ def acquire(self):
     # Caveat: They must always agree, as they will be ignored if the object
     # is already constructed.
     singleton = self._get_manager().acquire_singleton(self._tag)
+    # Trigger a sweep of zombie processes.
+    # calling active_children() has the side-effect of joining any finished
+    # processes, effectively reaping zombies from previous unsafe_hard_deletes.
+    if self._spawn_process: multiprocessing.active_children()
     return _AutoProxyWrapper(singleton)
 
   def release(self, obj):
@@ -315,25 +456,102 @@ def unsafe_hard_delete(self):
     to this object exist, or (b) you are ok with all existing references
     to this object throwing strange errors when derefrenced.
     """
-    self._get_manager().unsafe_hard_delete_singleton(self._tag)
+    try:
+      self._get_manager().unsafe_hard_delete_singleton(self._tag)
+    except (EOFError, ConnectionResetError, BrokenPipeError):
+      pass
+    except Exception as e:
+      logging.warning(
+          "Exception %s when trying to hard delete shared object %s",
+          e,
+          self._tag)
 
   def _create_server(self, address_file):
-    # We need to be able to authenticate with both the manager and the process.
-    self._serving_manager = _SingletonRegistrar(
-        address=('localhost', 0), authkey=AUTH_KEY)
-    multiprocessing.current_process().authkey = AUTH_KEY
-    # Initialize eagerly to avoid acting as the server if there are issues.
-    # Note, however, that _create_server itself is called lazily.
-    _process_level_singleton_manager.register_singleton(
-        self._constructor, self._tag, initialize_eagerly=True)
-    self._server = self._serving_manager.get_server()
-    logging.info(
-        'Starting proxy server at %s for shared %s',
-        self._server.address,
-        self._tag)
-    with open(address_file + '.tmp', 'w') as fout:
-      fout.write('%s:%d' % self._server.address)
-    os.rename(address_file + '.tmp', address_file)
-    t = threading.Thread(target=self._server.serve_forever, daemon=True)
-    t.start()
-    logging.info('Done starting server')
+    if self._spawn_process:
+      error_file = address_file + ".error"
+
+      if os.path.exists(error_file):
+        try:
+          os.remove(error_file)
+        except OSError:
+          pass
+
+      ctx = multiprocessing.get_context('spawn')
+      p = ctx.Process(
+          target=_run_server_process,
+          args=(address_file, self._tag, self._constructor, AUTH_KEY),
+          daemon=False  # Must be False for nested proxies
+      )
+      p.start()
+      logging.info("Parent: Waiting for %s to write address file...", self._tag)
+
+      def cleanup_process():
+        if p.is_alive():
+          logging.info(
+              "Parent: Terminating server process %s for %s", p.pid, self._tag)
+          p.terminate()
+          p.join()
+        try:
+          if os.path.exists(address_file):
+            os.remove(address_file)
+          if os.path.exists(error_file):
+            os.remove(error_file)
+        except Exception:
+          pass
+
+      atexit.register(cleanup_process)
+
+      start_time = time.time()
+      last_log = start_time
+      while True:
+        if os.path.exists(address_file):
+          break
+
+        if os.path.exists(error_file):
+          with open(error_file, 'r') as f:
+            error_msg = f.read()
+          try:
+            os.remove(error_file)
+          except OSError:
+            pass
+
+          if p.is_alive(): p.terminate()
+          raise RuntimeError(f"Shared Server Process crashed:\n{error_msg}")
+
+        if not p.is_alive():
+          exit_code = p.exitcode
+          raise RuntimeError(
+              "Shared Server Process died unexpectedly"
+              f" with exit code {exit_code}")
+
+        if time.time() - last_log > 300:
+          logging.warning(
+              "Still waiting for %s to initialize... (%ss elapsed)",
+              self._tag,
+              int(time.time() - start_time))
+          last_log = time.time()
+
+        time.sleep(0.05)
+
+      logging.info('External process successfully started for %s', self._tag)
+    else:
+      # We need to be able to authenticate with both the manager
+      # and the process.
+      self._serving_manager = _SingletonRegistrar(
+          address=('localhost', 0), authkey=AUTH_KEY)
+      multiprocessing.current_process().authkey = AUTH_KEY
+      # Initialize eagerly to avoid acting as the server if there are issues.
+      # Note, however, that _create_server itself is called lazily.
+      _process_level_singleton_manager.register_singleton(
+          self._constructor, self._tag, initialize_eagerly=True)
+      self._server = self._serving_manager.get_server()
+      logging.info(
+          'Starting proxy server at %s for shared %s',
+          self._server.address,
+          self._tag)
+      with open(address_file + '.tmp', 'w') as fout:
+        fout.write('%s:%d' % self._server.address)
+      os.rename(address_file + '.tmp', address_file)
+      t = threading.Thread(target=self._server.serve_forever, daemon=True)
+      t.start()
+      logging.info('Done starting server')
diff --git a/sdks/python/apache_beam/utils/multi_process_shared_test.py b/sdks/python/apache_beam/utils/multi_process_shared_test.py
index 0b7957632368..f3258cf0a968 100644
--- a/sdks/python/apache_beam/utils/multi_process_shared_test.py
+++ b/sdks/python/apache_beam/utils/multi_process_shared_test.py
@@ -18,6 +18,9 @@
 
 import logging
 import threading
+import tempfile
+import os
+import multiprocessing
 import unittest
 from typing import Any
 
@@ -82,6 +85,14 @@ def __getattribute__(self, __name: str) -> Any:
     return object.__getattribute__(self, __name)
 
 
+class SimpleClass:
+  def make_proxy(
+      self, tag: str = 'proxy_on_proxy', spawn_process: bool = False):
+    return multi_process_shared.MultiProcessShared(
+        Counter, tag=tag, always_proxy=True,
+        spawn_process=spawn_process).acquire()
+
+
 class MultiProcessSharedTest(unittest.TestCase):
   @classmethod
   def setUpClass(cls):
@@ -193,6 +204,34 @@ def test_unsafe_hard_delete(self):
 
     self.assertEqual(counter3.increment(), 1)
 
+  def test_unsafe_hard_delete_autoproxywrapper(self):
+    shared1 = multi_process_shared.MultiProcessShared(
+        Counter,
+        tag='test_unsafe_hard_delete_autoproxywrapper',
+        always_proxy=True)
+    shared2 = multi_process_shared.MultiProcessShared(
+        Counter,
+        tag='test_unsafe_hard_delete_autoproxywrapper',
+        always_proxy=True)
+
+    counter1 = shared1.acquire()
+    counter2 = shared2.acquire()
+    self.assertEqual(counter1.increment(), 1)
+    self.assertEqual(counter2.increment(), 2)
+
+    counter2.unsafe_hard_delete()
+
+    with self.assertRaises(Exception):
+      counter1.get()
+    with self.assertRaises(Exception):
+      counter2.get()
+
+    counter3 = multi_process_shared.MultiProcessShared(
+        Counter,
+        tag='test_unsafe_hard_delete_autoproxywrapper',
+        always_proxy=True).acquire()
+    self.assertEqual(counter3.increment(), 1)
+
   def test_unsafe_hard_delete_no_op(self):
     shared1 = multi_process_shared.MultiProcessShared(
         Counter, tag='test_unsafe_hard_delete_no_op', always_proxy=True)
@@ -242,6 +281,185 @@ def test_release_always_proxy(self):
     with self.assertRaisesRegex(Exception, 'released'):
       counter1.get()
 
+  def test_proxy_on_proxy(self):
+    shared1 = multi_process_shared.MultiProcessShared(
+        SimpleClass, tag='proxy_on_proxy_main', always_proxy=True)
+    instance = shared1.acquire()
+    proxy_instance = instance.make_proxy()
+    self.assertEqual(proxy_instance.increment(), 1)
+
+
+class MultiProcessSharedSpawnProcessTest(unittest.TestCase):
+  def setUp(self):
+    tempdir = tempfile.gettempdir()
+    for tag in ['basic',
+                'proxy_on_proxy',
+                'proxy_on_proxy_main',
+                'main',
+                'to_delete',
+                'mix1',
+                'mix2'
+                'test_process_exit']:
+      for ext in ['', '.address', '.address.error']:
+        try:
+          os.remove(os.path.join(tempdir, tag + ext))
+        except OSError:
+          pass
+
+  def tearDown(self):
+    for p in multiprocessing.active_children():
+      p.terminate()
+      p.join()
+
+  def test_call(self):
+    shared = multi_process_shared.MultiProcessShared(
+        Counter, tag='basic', always_proxy=True, spawn_process=True).acquire()
+    self.assertEqual(shared.get(), 0)
+    self.assertEqual(shared.increment(), 1)
+    self.assertEqual(shared.increment(10), 11)
+    self.assertEqual(shared.increment(value=10), 21)
+    self.assertEqual(shared.get(), 21)
+
+  def test_proxy_on_proxy(self):
+    shared1 = multi_process_shared.MultiProcessShared(
+        SimpleClass, tag='main', always_proxy=True)
+    instance = shared1.acquire()
+    proxy_instance = instance.make_proxy(spawn_process=True)
+    self.assertEqual(proxy_instance.increment(), 1)
+    proxy_instance.unsafe_hard_delete()
+
+    proxy_instance2 = instance.make_proxy(tag='proxy_2', spawn_process=True)
+    self.assertEqual(proxy_instance2.increment(), 1)
+
+  def test_unsafe_hard_delete_autoproxywrapper(self):
+    shared1 = multi_process_shared.MultiProcessShared(
+        Counter, tag='to_delete', always_proxy=True, spawn_process=True)
+    shared2 = multi_process_shared.MultiProcessShared(
+        Counter, tag='to_delete', always_proxy=True, spawn_process=True)
+    counter3 = multi_process_shared.MultiProcessShared(
+        Counter, tag='basic', always_proxy=True, spawn_process=True).acquire()
+
+    counter1 = shared1.acquire()
+    counter2 = shared2.acquire()
+    self.assertEqual(counter1.increment(), 1)
+    self.assertEqual(counter2.increment(), 2)
+
+    counter2.unsafe_hard_delete()
+
+    with self.assertRaises(Exception):
+      counter1.get()
+    with self.assertRaises(Exception):
+      counter2.get()
+
+    counter4 = multi_process_shared.MultiProcessShared(
+        Counter, tag='to_delete', always_proxy=True,
+        spawn_process=True).acquire()
+
+    self.assertEqual(counter3.increment(), 1)
+    self.assertEqual(counter4.increment(), 1)
+
+  def test_mix_usage(self):
+    shared1 = multi_process_shared.MultiProcessShared(
+        Counter, tag='mix1', always_proxy=True, spawn_process=False).acquire()
+    shared2 = multi_process_shared.MultiProcessShared(
+        Counter, tag='mix2', always_proxy=True, spawn_process=True).acquire()
+
+    self.assertEqual(shared1.get(), 0)
+    self.assertEqual(shared1.increment(), 1)
+    self.assertEqual(shared2.get(), 0)
+    self.assertEqual(shared2.increment(), 1)
+
+  def test_process_exits_on_unsafe_hard_delete(self):
+    shared = multi_process_shared.MultiProcessShared(
+        Counter, tag='test_process_exit', always_proxy=True, spawn_process=True)
+    obj = shared.acquire()
+
+    self.assertEqual(obj.increment(), 1)
+
+    children = multiprocessing.active_children()
+    server_process = None
+    for p in children:
+      if p.pid != os.getpid() and p.is_alive():
+        server_process = p
+        break
+
+    self.assertIsNotNone(
+        server_process, "Could not find spawned server process")
+    obj.unsafe_hard_delete()
+    server_process.join(timeout=5)
+
+    self.assertFalse(
+        server_process.is_alive(),
+        f"Server process {server_process.pid} is still alive after hard delete")
+    self.assertIsNotNone(
+        server_process.exitcode, "Process has no exit code (did not exit)")
+
+    with self.assertRaises(Exception):
+      obj.get()
+
+  def test_process_exits_on_unsafe_hard_delete_with_manager(self):
+    shared = multi_process_shared.MultiProcessShared(
+        Counter, tag='test_process_exit', always_proxy=True, spawn_process=True)
+    obj = shared.acquire()
+
+    self.assertEqual(obj.increment(), 1)
+
+    children = multiprocessing.active_children()
+    server_process = None
+    for p in children:
+      if p.pid != os.getpid() and p.is_alive():
+        server_process = p
+        break
+
+    self.assertIsNotNone(
+        server_process, "Could not find spawned server process")
+    shared.unsafe_hard_delete()
+    server_process.join(timeout=5)
+
+    self.assertFalse(
+        server_process.is_alive(),
+        f"Server process {server_process.pid} is still alive after hard delete")
+    self.assertIsNotNone(
+        server_process.exitcode, "Process has no exit code (did not exit)")
+
+    with self.assertRaises(Exception):
+      obj.get()
+
+  def test_zombie_reaping_on_acquire(self):
+    shared1 = multi_process_shared.MultiProcessShared(
+        Counter, tag='test_zombie_reap', always_proxy=True, spawn_process=True)
+    obj = shared1.acquire()
+
+    children = multiprocessing.active_children()
+    server_pid = next(
+        p.pid for p in children if p.is_alive() and p.pid != os.getpid())
+
+    obj.unsafe_hard_delete()
+
+    try:
+      os.kill(server_pid, 0)
+      is_zombie = True
+    except OSError:
+      is_zombie = False
+    self.assertTrue(
+        is_zombie,
+        f"Server process {server_pid} was reaped too early before acquire()")
+
+    shared2 = multi_process_shared.MultiProcessShared(
+        Counter, tag='unrelated_tag', always_proxy=True, spawn_process=True)
+    _ = shared2.acquire()
+
+    pid_exists = True
+    try:
+      os.kill(server_pid, 0)
+    except OSError:
+      pid_exists = False
+
+    self.assertFalse(
+        pid_exists,
+        f"Old server process {server_pid} was not reaped by acquire() sweep")
+    shared2.unsafe_hard_delete()
+
 
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)

From 0e8456ebf5289615f9951e603e8d056b92386b84 Mon Sep 17 00:00:00 2001
From: AMOOOMA
Date: Mon, 15 Dec 2025 21:47:09 +0000
Subject: [PATCH 02/13] Remove OOM protection

---
 .../apache_beam/utils/multi_process_shared.py | 23 ++-----------------
 1 file changed, 2 insertions(+), 21 deletions(-)

diff --git a/sdks/python/apache_beam/utils/multi_process_shared.py b/sdks/python/apache_beam/utils/multi_process_shared.py
index 0efa01f45570..1a7a751dba89 100644
--- a/sdks/python/apache_beam/utils/multi_process_shared.py
+++ b/sdks/python/apache_beam/utils/multi_process_shared.py
@@ -117,25 +117,6 @@ def __dir__(self):
     return dir
 
 
-def _run_with_oom_protection(func, *args, **kwargs):
-  try:
-    return func(*args, **kwargs)
-  except Exception as e:
-    # Check string to avoid hard import dependency
-    if 'CUDA out of memory' in str(e):
-      logging.warning("Caught CUDA OOM during operation. Cleaning memory.")
-      try:
-        import gc
-        import torch
-        gc.collect()
-        torch.cuda.empty_cache()
-      except ImportError:
-        pass
-      except Exception as cleanup_error:
-        logging.error("Failed to clean up CUDA memory: %s", cleanup_error)
-    raise e
-
-
 class _SingletonEntry:
   """Represents a single, refcounted entry in this process."""
   def __init__(
@@ -145,7 +126,7 @@ def __init__(
     self.refcount = 0
     self.lock = threading.Lock()
     if initialize_eagerly:
-      self.obj = _run_with_oom_protection(constructor)
+      self.obj = constructor()
       self.initialied = True
     else:
       self.initialied = False
@@ -153,7 +134,7 @@ def __init__(
   def acquire(self):
     with self.lock:
       if not self.initialied:
-        self.obj = _run_with_oom_protection(self.constructor)
+        self.obj = self.constructor()
         self.initialied = True
       self.refcount += 1
     return _SingletonProxy(self)

From 3c4ef28895856d0b68ae3eaa04dea0e299d044f5 Mon Sep 17 00:00:00 2001
From: AMOOOMA
Date: Wed, 14 Jan 2026 01:19:21 +0000
Subject: [PATCH 03/13] Resolve comments

---
 .../apache_beam/utils/multi_process_shared.py | 40 ++++++--------
 .../utils/multi_process_shared_test.py        | 54 ++++++++++++++-----
 2 files changed, 58 insertions(+), 36 deletions(-)

diff --git a/sdks/python/apache_beam/utils/multi_process_shared.py b/sdks/python/apache_beam/utils/multi_process_shared.py
index 1a7a751dba89..f8274d6d525d 100644
--- a/sdks/python/apache_beam/utils/multi_process_shared.py
+++ b/sdks/python/apache_beam/utils/multi_process_shared.py
@@ -186,7 +186,6 @@ def release_singleton(self, tag, obj):
 
   def unsafe_hard_delete_singleton(self, tag):
     self.entries[tag].unsafe_hard_delete()
-    self._hard_delete_callback()
 
 
 _process_level_singleton_manager = _SingletonManager()
@@ -236,13 +235,7 @@ def get_auto_proxy_object(self):
     return self._proxyObject
 
   def unsafe_hard_delete(self):
-    try:
-      self._proxyObject.unsafe_hard_delete()
-    except (EOFError, ConnectionResetError, BrokenPipeError):
-      pass
-    except Exception as e:
-      logging.warning(
-          "Exception %s when trying to hard delete shared object proxy", e)
+    self._proxyObject.unsafe_hard_delete()
 
 
 def _run_server_process(address_file, tag, constructor, authkey):
@@ -259,8 +252,8 @@ def cleanup_files():
         os.remove(address_file)
       if os.path.exists(address_file + ".error"):
         os.remove(address_file + ".error")
-    except Exception:
-      pass
+    except Exception as e:
+      logging.warning('Failed to cleanup files for tag %s: %s', tag, e)
 
   def handle_unsafe_hard_delete():
     cleanup_files()
@@ -270,6 +263,9 @@ def _monitor_parent():
     """Checks if parent is alive every second."""
     while True:
       try:
+        # Sends a check to see if parent_pid is still alive,
+        # this call will fail with OSError if the parent has died
+        # and no-op if alive.
         os.kill(parent_pid, 0)
       except OSError:
         logging.warning(
@@ -284,7 +280,6 @@ def _monitor_parent():
 
   try:
     t = threading.Thread(target=_monitor_parent, daemon=True)
-    t.start()
 
     logging.getLogger().setLevel(logging.INFO)
     multiprocessing.current_process().authkey = authkey
@@ -298,6 +293,9 @@ def _monitor_parent():
         tag,
         initialize_eagerly=True,
         hard_delete_callback=handle_unsafe_hard_delete)
+    # Start monitoring parent after initialisation is done to avoid
+    # potential race conditions.
+    t.start()
 
     server = serving_manager.get_server()
     logging.info(
@@ -422,7 +420,8 @@ def acquire(self):
     # Trigger a sweep of zombie processes.
     # calling active_children() has the side-effect of joining any finished
     # processes, effectively reaping zombies from previous unsafe_hard_deletes.
-    if self._spawn_process: multiprocessing.active_children()
+    if self._spawn_process:
+      multiprocessing.active_children()
     return _AutoProxyWrapper(singleton)
 
   def release(self, obj):
@@ -437,15 +436,7 @@ def unsafe_hard_delete(self):
     to this object exist, or (b) you are ok with all existing references
     to this object throwing strange errors when derefrenced.
     """
-    try:
-      self._get_manager().unsafe_hard_delete_singleton(self._tag)
-    except (EOFError, ConnectionResetError, BrokenPipeError):
-      pass
-    except Exception as e:
-      logging.warning(
-          "Exception %s when trying to hard delete shared object %s",
-          e,
-          self._tag)
+    self._get_manager().unsafe_hard_delete_singleton(self._tag)
 
   def _create_server(self, address_file):
     if self._spawn_process:
@@ -477,8 +468,11 @@ def cleanup_process():
             os.remove(address_file)
           if os.path.exists(error_file):
             os.remove(error_file)
-        except Exception:
-          pass
+        except Exception as e:
+          logging.warning(
+              'Failed to cleanup files for tag %s in atexit handler: %s',
+              self._tag,
+              e)
 
       atexit.register(cleanup_process)
 
diff --git a/sdks/python/apache_beam/utils/multi_process_shared_test.py b/sdks/python/apache_beam/utils/multi_process_shared_test.py
index f3258cf0a968..4905509b52d0 100644
--- a/sdks/python/apache_beam/utils/multi_process_shared_test.py
+++ b/sdks/python/apache_beam/utils/multi_process_shared_test.py
@@ -189,8 +189,11 @@ def test_unsafe_hard_delete(self):
     self.assertEqual(counter1.increment(), 1)
     self.assertEqual(counter2.increment(), 2)
 
-    multi_process_shared.MultiProcessShared(
-        Counter, tag='test_unsafe_hard_delete').unsafe_hard_delete()
+    try:
+      multi_process_shared.MultiProcessShared(
+          Counter, tag='test_unsafe_hard_delete').unsafe_hard_delete()
+    except Exception:
+      pass
 
     with self.assertRaises(Exception):
       counter1.get()
@@ -219,7 +222,10 @@ def test_unsafe_hard_delete_autoproxywrapper(self):
     self.assertEqual(counter1.increment(), 1)
     self.assertEqual(counter2.increment(), 2)
 
-    counter2.unsafe_hard_delete()
+    try:
+      counter2.unsafe_hard_delete()
+    except Exception:
+      pass
 
     with self.assertRaises(Exception):
       counter1.get()
@@ -243,8 +249,11 @@ def test_unsafe_hard_delete_no_op(self):
     self.assertEqual(counter1.increment(), 1)
     self.assertEqual(counter2.increment(), 2)
 
-    multi_process_shared.MultiProcessShared(
-        Counter, tag='no_tag_to_delete').unsafe_hard_delete()
+    try:
+      multi_process_shared.MultiProcessShared(
+          Counter, tag='no_tag_to_delete').unsafe_hard_delete()
+    except Exception:
+      pass
 
     self.assertEqual(counter1.increment(), 3)
     self.assertEqual(counter2.increment(), 4)
@@ -298,8 +307,9 @@ def setUp(self):
                 'main',
                 'to_delete',
                 'mix1',
-                'mix2'
-                'test_process_exit']:
+                'mix2',
+                'test_process_exit',
+                'thundering_herd_test']:
       for ext in ['', '.address', '.address.error']:
         try:
           os.remove(os.path.join(tempdir, tag + ext))
@@ -326,7 +336,10 @@ def test_proxy_on_proxy(self):
     instance = shared1.acquire()
     proxy_instance = instance.make_proxy(spawn_process=True)
     self.assertEqual(proxy_instance.increment(), 1)
-    proxy_instance.unsafe_hard_delete()
+    try:
+      proxy_instance.unsafe_hard_delete()
+    except Exception:
+      pass
 
     proxy_instance2 = instance.make_proxy(tag='proxy_2', spawn_process=True)
     self.assertEqual(proxy_instance2.increment(), 1)
@@ -344,7 +357,10 @@ def test_unsafe_hard_delete_autoproxywrapper(self):
     self.assertEqual(counter1.increment(), 1)
     self.assertEqual(counter2.increment(), 2)
 
-    counter2.unsafe_hard_delete()
+    try:
+      counter2.unsafe_hard_delete()
+    except Exception:
+      pass
 
     with self.assertRaises(Exception):
       counter1.get()
@@ -385,7 +401,10 @@ def test_process_exits_on_unsafe_hard_delete(self):
 
     self.assertIsNotNone(
         server_process, "Could not find spawned server process")
-    obj.unsafe_hard_delete()
+    try:
+      obj.unsafe_hard_delete()
+    except Exception:
+      pass
     server_process.join(timeout=5)
 
     self.assertFalse(
@@ -413,7 +432,10 @@ def test_process_exits_on_unsafe_hard_delete_with_manager(self):
 
     self.assertIsNotNone(
         server_process, "Could not find spawned server process")
-    shared.unsafe_hard_delete()
+    try:
+      shared.unsafe_hard_delete()
+    except Exception:
+      pass
     server_process.join(timeout=5)
 
     self.assertFalse(
@@ -434,7 +456,10 @@ def test_zombie_reaping_on_acquire(self):
     server_pid = next(
         p.pid for p in children if p.is_alive() and p.pid != os.getpid())
 
-    obj.unsafe_hard_delete()
+    try:
+      obj.unsafe_hard_delete()
+    except Exception:
+      pass
 
     try:
       os.kill(server_pid, 0)
@@ -458,7 +483,10 @@ def test_zombie_reaping_on_acquire(self):
     self.assertFalse(
         pid_exists,
         f"Old server process {server_pid} was not reaped by acquire() sweep")
-    shared2.unsafe_hard_delete()
+    try:
+      shared2.unsafe_hard_delete()
+    except Exception:
+      pass
 
 
 if __name__ == '__main__':

From dc5d6c783105d0e5e2011d183c4e2eece3649703 Mon Sep 17 00:00:00 2001
From: AMOOOMA
Date: Tue, 20 Jan 2026 21:34:46 +0000
Subject: [PATCH 04/13] Rename unsafe_hard_delete for the proxy object to
 prevent collision

---
 sdks/python/apache_beam/utils/multi_process_shared.py      | 4 ++--
 .../python/apache_beam/utils/multi_process_shared_test.py  | 8 ++++----
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/sdks/python/apache_beam/utils/multi_process_shared.py b/sdks/python/apache_beam/utils/multi_process_shared.py
index f8274d6d525d..23fe03b31786 100644
--- a/sdks/python/apache_beam/utils/multi_process_shared.py
+++ b/sdks/python/apache_beam/utils/multi_process_shared.py
@@ -83,7 +83,7 @@ def singletonProxy_release(self):
     assert self._SingletonProxy_valid
     self._SingletonProxy_valid = False
 
-  def unsafe_hard_delete(self):
+  def singletonProxy_unsafe_hard_delete(self):
     assert self._SingletonProxy_valid
     self._SingletonProxy_entry.unsafe_hard_delete()
 
@@ -113,7 +113,7 @@ def __dir__(self):
     dir = self._SingletonProxy_entry.obj.__dir__()
     dir.append('singletonProxy_call__')
     dir.append('singletonProxy_release')
-    dir.append('unsafe_hard_delete')
+    dir.append('singletonProxy_unsafe_hard_delete')
     return dir
 
 
diff --git a/sdks/python/apache_beam/utils/multi_process_shared_test.py b/sdks/python/apache_beam/utils/multi_process_shared_test.py
index 4905509b52d0..5f794b70f60d 100644
--- a/sdks/python/apache_beam/utils/multi_process_shared_test.py
+++ b/sdks/python/apache_beam/utils/multi_process_shared_test.py
@@ -223,7 +223,7 @@ def test_unsafe_hard_delete_autoproxywrapper(self):
 
     try:
-      counter2.unsafe_hard_delete()
+      counter2.singletonProxy_unsafe_hard_delete()
     except Exception:
       pass
 
@@ -358,7 +358,7 @@ def test_unsafe_hard_delete_autoproxywrapper(self):
 
     try:
-      counter2.unsafe_hard_delete()
+      counter2.singletonProxy_unsafe_hard_delete()
     except Exception:
       pass
 
@@ -402,7 +402,7 @@ def test_process_exits_on_unsafe_hard_delete(self):
         server_process, "Could not find spawned server process")
     try:
-      obj.unsafe_hard_delete()
+      obj.singletonProxy_unsafe_hard_delete()
     except Exception:
       pass
     server_process.join(timeout=5)
@@ -457,7 +457,7 @@ def test_zombie_reaping_on_acquire(self):
         p.pid for p in children if p.is_alive() and p.pid != os.getpid())
 
     try:
-      obj.unsafe_hard_delete()
+      obj.singletonProxy_unsafe_hard_delete()
     except Exception:
       pass
 

From 63c71077f2fe0d414909185286a859d3294ae5b4 Mon Sep 17 00:00:00 2001
From: AMOOOMA
Date: Thu, 29 Jan 2026 20:33:28 +0000
Subject: [PATCH 05/13] Remove support for proxy on proxy to avoid complexity

---
 .../apache_beam/utils/multi_process_shared.py |  6 ----
 .../utils/multi_process_shared_test.py        | 31 -------------------
 2 files changed, 37 deletions(-)

diff --git a/sdks/python/apache_beam/utils/multi_process_shared.py b/sdks/python/apache_beam/utils/multi_process_shared.py
index 23fe03b31786..9430554ff2b5 100644
--- a/sdks/python/apache_beam/utils/multi_process_shared.py
+++ b/sdks/python/apache_beam/utils/multi_process_shared.py
@@ -225,12 +225,6 @@ def __call__(self, *args, **kwargs):
   def __getattr__(self, name):
     return getattr(self._proxyObject, name)
 
-  def __setstate__(self, state):
-    self.__dict__.update(state)
-
-  def __getstate__(self):
-    return self.__dict__
-
   def get_auto_proxy_object(self):
     return self._proxyObject
 
diff --git a/sdks/python/apache_beam/utils/multi_process_shared_test.py b/sdks/python/apache_beam/utils/multi_process_shared_test.py
index 5f794b70f60d..63aa164b02ff 100644
--- a/sdks/python/apache_beam/utils/multi_process_shared_test.py
+++ b/sdks/python/apache_beam/utils/multi_process_shared_test.py
@@ -85,14 +85,6 @@ def __getattribute__(self, __name: str) -> Any:
     return object.__getattribute__(self, __name)
 
 
-class SimpleClass:
-  def make_proxy(
-      self, tag: str = 'proxy_on_proxy', spawn_process: bool = False):
-    return multi_process_shared.MultiProcessShared(
-        Counter, tag=tag, always_proxy=True,
-        spawn_process=spawn_process).acquire()
-
-
 class MultiProcessSharedTest(unittest.TestCase):
   @classmethod
   def setUpClass(cls):
@@ -290,20 +282,11 @@ def test_release_always_proxy(self):
     with self.assertRaisesRegex(Exception, 'released'):
       counter1.get()
 
-  def test_proxy_on_proxy(self):
-    shared1 = multi_process_shared.MultiProcessShared(
-        SimpleClass, tag='proxy_on_proxy_main', always_proxy=True)
-    instance = shared1.acquire()
-    proxy_instance = instance.make_proxy()
-    self.assertEqual(proxy_instance.increment(), 1)
-
 
 class MultiProcessSharedSpawnProcessTest(unittest.TestCase):
   def setUp(self):
     tempdir = tempfile.gettempdir()
     for tag in ['basic',
-                'proxy_on_proxy',
-                'proxy_on_proxy_main',
                 'main',
                 'to_delete',
                 'mix1',
@@ -330,20 +313,6 @@ def test_call(self):
     self.assertEqual(shared.increment(value=10), 21)
     self.assertEqual(shared.get(), 21)
 
-  def test_proxy_on_proxy(self):
-    shared1 = multi_process_shared.MultiProcessShared(
-        SimpleClass, tag='main', always_proxy=True)
-    instance = shared1.acquire()
-    proxy_instance = instance.make_proxy(spawn_process=True)
-    self.assertEqual(proxy_instance.increment(), 1)
-    try:
-      proxy_instance.unsafe_hard_delete()
-    except Exception:
-      pass
-
-    proxy_instance2 = instance.make_proxy(tag='proxy_2', spawn_process=True)
-    self.assertEqual(proxy_instance2.increment(), 1)
-
   def test_unsafe_hard_delete_autoproxywrapper(self):
     shared1 = multi_process_shared.MultiProcessShared(
         Counter, tag='to_delete', always_proxy=True, spawn_process=True)

From 6e9bcbe1cf2a8be46c9b0a38719eb4e08d1578d7 Mon Sep 17 00:00:00 2001
From: AMOOOMA
Date: Thu, 29 Jan 2026 23:53:50 +0000
Subject: [PATCH 06/13] Fix import order

---
 sdks/python/apache_beam/utils/multi_process_shared.py      | 6 +++---
 sdks/python/apache_beam/utils/multi_process_shared_test.py | 6 +++---
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/sdks/python/apache_beam/utils/multi_process_shared.py b/sdks/python/apache_beam/utils/multi_process_shared.py
index 9430554ff2b5..46226a3d402c 100644
--- a/sdks/python/apache_beam/utils/multi_process_shared.py
+++ b/sdks/python/apache_beam/utils/multi_process_shared.py
@@ -22,15 +22,15 @@
 """
 # pytype: skip-file
 
+import atexit
 import logging
 import multiprocessing.managers
 import os
-import time
-import traceback
-import atexit
 import sys
 import tempfile
 import threading
+import time
+import traceback
 from typing import Any
 from typing import Callable
 from typing import Dict
diff --git a/sdks/python/apache_beam/utils/multi_process_shared_test.py b/sdks/python/apache_beam/utils/multi_process_shared_test.py
index 63aa164b02ff..c1ffb56066e0 100644
--- a/sdks/python/apache_beam/utils/multi_process_shared_test.py
+++ b/sdks/python/apache_beam/utils/multi_process_shared_test.py
@@ -17,10 +17,10 @@
 # pytype: skip-file
 
 import logging
-import threading
-import tempfile
-import os
 import multiprocessing
+import os
+import tempfile
+import threading
 import unittest
 from typing import Any
 

From aee8f39988406acbcd5e8778cfe8b11142094e5b Mon Sep 17 00:00:00 2001
From: AMOOOMA
Date: Fri, 30 Jan 2026 17:44:59 +0000
Subject: [PATCH 07/13] Update reap test to be compatible with Windows

---
 .../apache_beam/utils/multi_process_shared_test.py | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)

diff --git a/sdks/python/apache_beam/utils/multi_process_shared_test.py b/sdks/python/apache_beam/utils/multi_process_shared_test.py
index c1ffb56066e0..22a495be7339 100644
--- a/sdks/python/apache_beam/utils/multi_process_shared_test.py
+++ b/sdks/python/apache_beam/utils/multi_process_shared_test.py
@@ -443,14 +443,12 @@ def test_zombie_reaping_on_acquire(self):
         Counter, tag='unrelated_tag', always_proxy=True, spawn_process=True)
     _ = shared2.acquire()
 
-    pid_exists = True
-    try:
-      os.kill(server_pid, 0)
-    except OSError:
-      pid_exists = False
+    # If reaping worked, our old server_pid should NOT be in this list.
+    current_children_pids = [p.pid for p in multiprocessing.active_children()]
 
-    self.assertFalse(
-        pid_exists,
+    self.assertNotIn(
+        server_pid,
+        current_children_pids,
         f"Old server process {server_pid} was not reaped by acquire() sweep")
     try:
       shared2.unsafe_hard_delete()

From f1c7e6f54a43a69508ce74c9d8e7a219f884a05c Mon Sep 17 00:00:00 2001
From: AMOOOMA
Date: Fri, 30 Jan 2026 19:29:21 +0000
Subject: [PATCH 08/13] Update print to logging

---
 sdks/python/apache_beam/utils/multi_process_shared.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/sdks/python/apache_beam/utils/multi_process_shared.py b/sdks/python/apache_beam/utils/multi_process_shared.py
index 46226a3d402c..a73751c73f59 100644
--- a/sdks/python/apache_beam/utils/multi_process_shared.py
+++ b/sdks/python/apache_beam/utils/multi_process_shared.py
@@ -26,7 +26,6 @@
 import logging
 import multiprocessing.managers
 import os
-import sys
 import tempfile
 import threading
 import time
@@ -308,7 +307,7 @@ def _monitor_parent():
         fout.write(tb)
       os.rename(address_file + ".error.tmp", address_file + ".error")
     except Exception:
-      print(f"CRITICAL ERROR IN SHARED SERVER:\n{tb}", file=sys.stderr)
+      logging.error("CRITICAL ERROR IN SHARED SERVER:\n%s", tb)
     os._exit(1)

From f479ea3ac83db7a161b3a453b91bbbe1c16df15b Mon Sep 17 00:00:00 2001
From: AMOOOMA
Date: Fri, 30 Jan 2026 22:28:38 +0000
Subject: [PATCH 09/13] Try to tearDown test in a cleaner way

---
 .../apache_beam/utils/multi_process_shared_test.py | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/sdks/python/apache_beam/utils/multi_process_shared_test.py b/sdks/python/apache_beam/utils/multi_process_shared_test.py
index 22a495be7339..2cf240c26f5d 100644
--- a/sdks/python/apache_beam/utils/multi_process_shared_test.py
+++ b/sdks/python/apache_beam/utils/multi_process_shared_test.py
@@ -301,8 +301,17 @@ def setUp(self):
 
   def tearDown(self):
     for p in multiprocessing.active_children():
-      p.terminate()
-      p.join()
+      if p.is_alive():
+        try:
+          p.terminate()
+          p.join(timeout=0.5)
+
+          if p.is_alive():
+            # Force kill if still alive
+            p.kill()
+            p.join(timeout=0.1)
+        except Exception:
+          pass
 
   def test_call(self):
     shared = multi_process_shared.MultiProcessShared(

From cbdb511d823599c58c82f2109db68c3b13c21463 Mon Sep 17 00:00:00 2001
From: AMOOOMA
Date: Fri, 30 Jan 2026 23:13:20 +0000
Subject: [PATCH 10/13] Try patching atexit call to prevent hanging on Windows

---
 .../utils/multi_process_shared_test.py | 30 +++++++++++++++----------
 1 file changed, 18 insertions(+), 12 deletions(-)

diff --git a/sdks/python/apache_beam/utils/multi_process_shared_test.py b/sdks/python/apache_beam/utils/multi_process_shared_test.py
index 2cf240c26f5d..efe2e90c1d1b 100644
--- a/sdks/python/apache_beam/utils/multi_process_shared_test.py
+++ b/sdks/python/apache_beam/utils/multi_process_shared_test.py
@@ -22,6 +22,7 @@
 import tempfile
 import threading
 import unittest
+from unittest.mock import patch
 from typing import Any
 
 from apache_beam.utils import multi_process_shared
@@ -299,19 +300,24 @@ def setUp(self):
         except OSError:
           pass
 
+    # Patch atexit.register to prevent hanging on exit
+    # on Windows due to multiprocessing issues.
+    self.atexit_patcher = patch('atexit.register')
+    self.mock_atexit = self.atexit_patcher.start()
+    self.captured_handlers = []
+
+    def capture_handler(func, *args, **kwargs):
+      self.captured_handlers.append((func, args, kwargs))
+
+    self.mock_atexit.side_effect = capture_handler
+
   def tearDown(self):
-    for p in multiprocessing.active_children():
-      if p.is_alive():
-        try:
-          p.terminate()
-          p.join(timeout=0.5)
-
-          if p.is_alive():
-            # Force kill if still alive
-            p.kill()
-            p.join(timeout=0.1)
-        except Exception:
-          pass
+    for func, args, kwargs in reversed(self.captured_handlers):
+      try:
+        func(*args, **kwargs)
+      except Exception:
+        pass
+    self.atexit_patcher.stop()
 
   def test_call(self):
     shared = multi_process_shared.MultiProcessShared(

From 9767486665248e50dff752e5d3ca300b7363821f Mon Sep 17 00:00:00 2001
From: AMOOOMA
Date: Sat, 31 Jan 2026 00:18:49 +0000
Subject: [PATCH 11/13] Try weakref so Windows can GC the process

---
 .../apache_beam/utils/multi_process_shared.py | 16 +++++++++---
 .../utils/multi_process_shared_test.py        | 25 ++++++-------------
 2 files changed, 19 insertions(+), 22 deletions(-)

diff --git a/sdks/python/apache_beam/utils/multi_process_shared.py b/sdks/python/apache_beam/utils/multi_process_shared.py
index a73751c73f59..59faed9ca8a2 100644
--- a/sdks/python/apache_beam/utils/multi_process_shared.py
+++ b/sdks/python/apache_beam/utils/multi_process_shared.py
@@ -36,6 +36,7 @@
 from typing import Generic
 from typing import Optional
 from typing import TypeVar
+import weakref
 
 import fasteners
 
@@ -450,12 +451,19 @@ def _create_server(self, address_file):
       p.start()
       logging.info("Parent: Waiting for %s to write address file...", self._tag)
 
+      # Make a weakref so that on Windows we don't
+      # prevent the process from being GCed.
+      weakref_p = weakref.ref(p)
+
       def cleanup_process():
-        if p.is_alive():
+        proc = weakref_p()
+        if proc and proc.is_alive():
           logging.info(
-              "Parent: Terminating server process %s for %s", p.pid, self._tag)
-          p.terminate()
-          p.join()
+              "Parent: Terminating server process %s for %s",
+              proc.pid,
+              self._tag)
+          proc.terminate()
+          proc.join()
         try:
           if os.path.exists(address_file):
             os.remove(address_file)
diff --git a/sdks/python/apache_beam/utils/multi_process_shared_test.py b/sdks/python/apache_beam/utils/multi_process_shared_test.py
index efe2e90c1d1b..7b2b11857bfd 100644
--- a/sdks/python/apache_beam/utils/multi_process_shared_test.py
+++ b/sdks/python/apache_beam/utils/multi_process_shared_test.py
@@ -22,7 +22,6 @@
 import tempfile
 import threading
 import unittest
-from unittest.mock import patch
 from typing import Any
 
 from apache_beam.utils import multi_process_shared
@@ -299,24 +298,14 @@ def setUp(self):
         except OSError:
           pass
 
-    # Patch atexit.register to prevent hanging on exit
-    # on Windows due to multiprocessing issues.
-    self.atexit_patcher = patch('atexit.register')
-    self.mock_atexit = self.atexit_patcher.start()
-    self.captured_handlers = []
-
-    def capture_handler(func, *args, **kwargs):
-      self.captured_handlers.append((func, args, kwargs))
-
-    self.mock_atexit.side_effect = capture_handler
-
   def tearDown(self):
-    for func, args, kwargs in reversed(self.captured_handlers):
-      try:
-        func(*args, **kwargs)
-      except Exception:
-        pass
-    self.atexit_patcher.stop()
+    for p in multiprocessing.active_children():
+      if p.is_alive():
+        try:
+          p.kill()
+          p.join(timeout=1.0)
+        except Exception:
+          pass
 
   def test_call(self):
     shared = multi_process_shared.MultiProcessShared(

From 3a663e26e11eaec10b4873abb20cb8b52b8aa4c2 Mon Sep 17 00:00:00 2001
From: AMOOOMA
Date: Sat, 31 Jan 2026 01:01:37 +0000
Subject: [PATCH 12/13] Try GC manually to make sure p is cleaned up

---
 .../apache_beam/utils/multi_process_shared.py |  2 +-
 .../utils/multi_process_shared_test.py        | 18 ++++++++++--------
 2 files changed, 11 insertions(+), 9 deletions(-)

diff --git a/sdks/python/apache_beam/utils/multi_process_shared.py b/sdks/python/apache_beam/utils/multi_process_shared.py
index 59faed9ca8a2..1896d6cbf942 100644
--- a/sdks/python/apache_beam/utils/multi_process_shared.py
+++ b/sdks/python/apache_beam/utils/multi_process_shared.py
@@ -30,13 +30,13 @@
 import threading
 import time
 import traceback
+import weakref
 from typing import Any
 from typing import Callable
 from typing import Dict
 from typing import Generic
 from typing import Optional
 from typing import TypeVar
-import weakref
 
 import fasteners
 
diff --git a/sdks/python/apache_beam/utils/multi_process_shared_test.py b/sdks/python/apache_beam/utils/multi_process_shared_test.py
index 7b2b11857bfd..6f7bca929ef2 100644
--- a/sdks/python/apache_beam/utils/multi_process_shared_test.py
+++ b/sdks/python/apache_beam/utils/multi_process_shared_test.py
@@ -16,6 +16,7 @@
 #
 # pytype: skip-file
 
+import gc
 import logging
 import multiprocessing
 import os
@@ -301,6 +301,10 @@ def setUp(self):
         pass
 
   def tearDown(self):
+    # Force a garbage collection to ensure that any leftover references
+    # to the old server process proxy are cleaned up.
+    gc.collect()
+
     for p in multiprocessing.active_children():
       if p.is_alive():
         try:
@@ -434,14 +439,11 @@ def test_zombie_reaping_on_acquire(self):
       except Exception:
         pass
 
-    try:
-      os.kill(server_pid, 0)
-      is_zombie = True
-    except OSError:
-      is_zombie = False
-    self.assertTrue(
-        is_zombie,
-        f"Server process {server_pid} was reaped too early before acquire()")
+    current_children_pids = [p.pid for p in multiprocessing.active_children()]
+    self.assertIn(
+        server_pid,
+        current_children_pids,
+        f"Server process {server_pid} was prematurely reaped before acquire()")
 
     shared2 = multi_process_shared.MultiProcessShared(
         Counter, tag='unrelated_tag', always_proxy=True, spawn_process=True)

From 3c8a3f81b2a2cc5bbe65319d117258055d21efaf Mon Sep 17 00:00:00 2001
From: AMOOOMA
Date: Sat, 31 Jan 2026 02:47:35 +0000
Subject: [PATCH 13/13] Use a different way to check if parent is alive

---
 .../apache_beam/utils/multi_process_shared.py | 33 ++++++++-----------
 .../utils/multi_process_shared_test.py        | 18 +++++-----
 2 files changed, 22 insertions(+), 29 deletions(-)

diff --git a/sdks/python/apache_beam/utils/multi_process_shared.py b/sdks/python/apache_beam/utils/multi_process_shared.py
index 1896d6cbf942..c576f8a78e9d 100644
--- a/sdks/python/apache_beam/utils/multi_process_shared.py
+++ b/sdks/python/apache_beam/utils/multi_process_shared.py
@@ -30,7 +30,6 @@
 import threading
 import time
 import traceback
-import weakref
 from typing import Any
 from typing import Callable
 from typing import Dict
@@ -232,7 +231,7 @@ def unsafe_hard_delete(self):
     self._proxyObject.unsafe_hard_delete()
 
 
-def _run_server_process(address_file, tag, constructor, authkey):
+def _run_server_process(address_file, tag, constructor, authkey, life_line):
   """
   Runs in a separate process.
   Includes a 'Suicide Pact' monitor: If parent dies, I die.
@@ -257,11 +256,9 @@ def _monitor_parent():
     """Checks if parent is alive every second."""
     while True:
       try:
-        # Sends a check to see if parent_pid is still alive,
-        # this call will fail with OSError if the parent has died
-        # and no-op if alive.
-        os.kill(parent_pid, 0)
-      except OSError:
+        # This will break if parent dies.
+        life_line.recv_bytes()
+      except (EOFError, OSError, BrokenPipeError):
         logging.warning(
             "Process %s detected Parent %s died. Self-destructing.",
             os.getpid(),
@@ -442,26 +439,26 @@ def _create_server(self, address_file):
         except OSError:
           pass
 
+      # Create a pipe to connect with child process
+      # used to clean up child process if parent dies
+      reader, writer = multiprocessing.Pipe(duplex=False)
+      self._life_line = writer
+
       ctx = multiprocessing.get_context('spawn')
       p = ctx.Process(
          target=_run_server_process,
-          args=(address_file, self._tag, self._constructor, AUTH_KEY),
+          args=(address_file, self._tag, self._constructor, AUTH_KEY, reader),
          daemon=False  # Must be False for nested proxies
      )
      p.start()
      logging.info("Parent: Waiting for %s to write address file...", self._tag)
 
-      # Make a weakref so that on Windows we don't
-      # prevent the process from being GCed.
-      weakref_p = weakref.ref(p)
-
       def cleanup_process():
-        proc = weakref_p()
-        if proc and proc.is_alive():
+        if p.is_alive():
           logging.info(
-              "Parent: Terminating server process %s for %s",
-              proc.pid,
-              self._tag)
-          proc.terminate()
-          proc.join()
+              "Parent: Terminating server process %s for %s", p.pid, self._tag)
+          p.terminate()
+          p.join()
         try:
           if os.path.exists(address_file):
             os.remove(address_file)
diff --git a/sdks/python/apache_beam/utils/multi_process_shared_test.py b/sdks/python/apache_beam/utils/multi_process_shared_test.py
index 6f7bca929ef2..7b2b11857bfd 100644
--- a/sdks/python/apache_beam/utils/multi_process_shared_test.py
+++ b/sdks/python/apache_beam/utils/multi_process_shared_test.py
@@ -16,7 +16,6 @@
 #
 # pytype: skip-file
 
-import gc
 import logging
 import multiprocessing
 import os
@@ -301,10 +300,6 @@ def setUp(self):
         pass
 
   def tearDown(self):
-    # Force a garbage collection to ensure that any leftover references
-    # to the old server process proxy are cleaned up.
-    gc.collect()
-
     for p in multiprocessing.active_children():
       if p.is_alive():
         try:
@@ -434,14 +429,14 @@ def test_zombie_reaping_on_acquire(self):
       except Exception:
         pass
 
-    current_children_pids = [p.pid for p in multiprocessing.active_children()]
-    self.assertIn(
-        server_pid,
-        current_children_pids,
-        f"Server process {server_pid} was prematurely reaped before acquire()")
+    try:
+      os.kill(server_pid, 0)
+      is_zombie = True
+    except OSError:
+      is_zombie = False
+    self.assertTrue(
+        is_zombie,
+        f"Server process {server_pid} was reaped too early before acquire()")
 
     shared2 = multi_process_shared.MultiProcessShared(
         Counter, tag='unrelated_tag', always_proxy=True, spawn_process=True)
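
A minimal usage sketch of the spawn_process mode this series adds. The ExpensiveModel class and 'my_model' tag are illustrative, not part of the patch; only MultiProcessShared, acquire, and unsafe_hard_delete come from the changed module:

import logging
from apache_beam.utils.multi_process_shared import MultiProcessShared


class ExpensiveModel:
  """Stand-in for an object too costly to build once per worker process."""
  def predict(self, x):
    return x * 2


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  # With spawn_process=True the singleton is constructed and served from a
  # dedicated spawned process rather than a background thread of whichever
  # worker happened to create it first.
  shared = MultiProcessShared(
      ExpensiveModel, tag='my_model', always_proxy=True, spawn_process=True)
  model = shared.acquire()
  print(model.predict(21))  # call is proxied into the server process

  # Tears down the shared instance; with spawn_process=True the server
  # process exits too. The tests wrap this in try/except because the RPC
  # can fail as the server goes away, and existing proxies raise afterwards.
  shared.unsafe_hard_delete()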
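
The life_line mechanism from the final commit can be demonstrated in isolation. A sketch of the same idea, with an explicit close() standing in for the parent process dying; all names here are illustrative:

import multiprocessing
import os


def child_main(life_line):
  # recv_bytes() blocks while the parent holds the write end open and
  # raises EOFError/OSError/BrokenPipeError once that end is gone.
  try:
    life_line.recv_bytes()
  except (EOFError, OSError, BrokenPipeError):
    print('parent went away, child %s exiting' % os.getpid())
    os._exit(0)


if __name__ == '__main__':
  reader, writer = multiprocessing.Pipe(duplex=False)
  p = multiprocessing.get_context('spawn').Process(
      target=child_main, args=(reader,))
  p.start()
  reader.close()  # the child owns the read end now
  writer.close()  # simulates the parent dying: the child sees EOF
  p.join()

Unlike polling os.getppid()/os.kill(pid, 0), which behaves differently across platforms, a broken pipe is delivered the same way on POSIX and Windows, which appears to be the motivation for the switch.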
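
The address-file handshake used by _run_server_process and _create_server follows a write-to-temp-then-rename pattern, so a polling reader either sees no file or a complete "host:port" string, never a partial write. A generic sketch of that pattern; the helper names are hypothetical:

import os
import time


def publish_address(address_file, host, port):
  # Write to a temp name first, then atomically rename into place.
  with open(address_file + '.tmp', 'w') as fout:
    fout.write('%s:%d' % (host, port))
  os.rename(address_file + '.tmp', address_file)


def wait_for_address(address_file, timeout=30.0):
  start = time.time()
  while not os.path.exists(address_file):
    if time.time() - start > timeout:
      raise TimeoutError('server never published %s' % address_file)
    time.sleep(0.05)
  with open(address_file) as fin:
    host, port = fin.read().rsplit(':', 1)
  return host, int(port)

The companion ".error" file in the patch plays the same role for failure: the spawned server renames a fully written traceback into place, and the waiting parent surfaces it as a RuntimeError instead of polling forever.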