235 changes: 211 additions & 24 deletions sdks/python/apache_beam/utils/multi_process_shared.py
@@ -22,11 +22,15 @@
"""
# pytype: skip-file

import atexit
import logging
import multiprocessing.managers
import os
import sys
import tempfile
import threading
import time
import traceback
from typing import Any
from typing import Callable
from typing import Dict
@@ -79,6 +83,10 @@ def singletonProxy_release(self):
assert self._SingletonProxy_valid
self._SingletonProxy_valid = False

def singletonProxy_unsafe_hard_delete(self):
assert self._SingletonProxy_valid
self._SingletonProxy_entry.unsafe_hard_delete()
Contributor: Do we still need this piece now that we're not passing around a proxy?

Contributor: If we do, it's fine; it mostly depends on how you now want to track models in the manager. You can leave it for now if unsure.

Contributor (Author): Yep, technically we won't need this now, since we are recreating the MPS object every time. But it's probably good to keep in case we want more flexibility in the future.

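For context, a minimal sketch of how a caller can force-release the shared object through the public wrapper (MyExpensiveObject is a hypothetical stand-in; the other names come from this module). unsafe_hard_delete tears the object down for all holders, regardless of refcounts:

shared = MultiProcessShared(lambda: MyExpensiveObject(), tag='my_model')
obj = shared.acquire()
# ... use obj ...
shared.unsafe_hard_delete()  # deletes the singleton out from under every holder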

def __getattr__(self, name):
if not self._SingletonProxy_valid:
raise RuntimeError('Entry was released.')
@@ -105,13 +113,16 @@ def __dir__(self):
dir = self._SingletonProxy_entry.obj.__dir__()
dir.append('singletonProxy_call__')
dir.append('singletonProxy_release')
dir.append('singletonProxy_unsafe_hard_delete')
return dir


class _SingletonEntry:
"""Represents a single, refcounted entry in this process."""
def __init__(self, constructor, initialize_eagerly=True):
def __init__(
self, constructor, initialize_eagerly=True, hard_delete_callback=None):
self.constructor = constructor
self._hard_delete_callback = hard_delete_callback
self.refcount = 0
self.lock = threading.Lock()
if initialize_eagerly:
@@ -141,14 +152,28 @@ def unsafe_hard_delete(self):
if self.initialied:
del self.obj
self.initialied = False
if self._hard_delete_callback:
self._hard_delete_callback()


class _SingletonManager:
entries: Dict[Any, Any] = {}

def register_singleton(self, constructor, tag, initialize_eagerly=True):
def __init__(self):
self._hard_delete_callback = None

def set_hard_delete_callback(self, callback):
self._hard_delete_callback = callback

def register_singleton(
self,
constructor,
tag,
initialize_eagerly=True,
hard_delete_callback=None):
assert tag not in self.entries, tag
self.entries[tag] = _SingletonEntry(constructor, initialize_eagerly)
self.entries[tag] = _SingletonEntry(
constructor, initialize_eagerly, hard_delete_callback)

def has_singleton(self, tag):
return tag in self.entries
@@ -160,7 +185,7 @@ def release_singleton(self, tag, obj):
return self.entries[tag].release(obj)

def unsafe_hard_delete_singleton(self, tag):
return self.entries[tag].unsafe_hard_delete()
self.entries[tag].unsafe_hard_delete()


_process_level_singleton_manager = _SingletonManager()
@@ -203,6 +228,89 @@ def __getattr__(self, name):
def get_auto_proxy_object(self):
return self._proxyObject

def unsafe_hard_delete(self):
self._proxyObject.unsafe_hard_delete()


def _run_server_process(address_file, tag, constructor, authkey):
"""
Runs in a separate process.
Includes a 'Suicide Pact' monitor: If parent dies, I die.
"""
parent_pid = os.getppid()

def cleanup_files():
logging.info("Server process exiting. Deleting files for %s", tag)
try:
if os.path.exists(address_file):
os.remove(address_file)
if os.path.exists(address_file + ".error"):
os.remove(address_file + ".error")
except Exception as e:
logging.warning('Failed to cleanup files for tag %s: %s', tag, e)

def handle_unsafe_hard_delete():
cleanup_files()
os._exit(0)

def _monitor_parent():
"""Checks if parent is alive every second."""
while True:
try:
# Signal 0 performs a liveness check: os.kill raises OSError if the
# parent has died and is a no-op if it is still alive.
os.kill(parent_pid, 0)
except OSError:
logging.warning(
"Process %s detected Parent %s died. Self-destructing.",
os.getpid(),
parent_pid)
cleanup_files()
os._exit(0)
time.sleep(0.5)

atexit.register(cleanup_files)

try:
t = threading.Thread(target=_monitor_parent, daemon=True)

logging.getLogger().setLevel(logging.INFO)
multiprocessing.current_process().authkey = authkey

serving_manager = _SingletonRegistrar(
address=('localhost', 0), authkey=authkey)
_process_level_singleton_manager.set_hard_delete_callback(
handle_unsafe_hard_delete)
_process_level_singleton_manager.register_singleton(
constructor,
tag,
initialize_eagerly=True,
hard_delete_callback=handle_unsafe_hard_delete)
# Start monitoring the parent only after initialization is done to avoid
# potential race conditions.
t.start()

server = serving_manager.get_server()
logging.info(
'Process %s: Proxy serving %s at %s', os.getpid(), tag, server.address)

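# Publish the address atomically: write to a temp file and rename it, so a
# reader never observes a partially written address file.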
with open(address_file + '.tmp', 'w') as fout:
fout.write('%s:%d' % server.address)
os.rename(address_file + '.tmp', address_file)

server.serve_forever()

except Exception:
tb = traceback.format_exc()
try:
with open(address_file + ".error.tmp", 'w') as fout:
fout.write(tb)
os.rename(address_file + ".error.tmp", address_file + ".error")
except Exception:
print(f"CRITICAL ERROR IN SHARED SERVER:\n{tb}", file=sys.stderr)
os._exit(1)


class MultiProcessShared(Generic[T]):
"""MultiProcessShared is used to share a single object across processes.
@@ -252,7 +360,8 @@ def __init__(
tag: Any,
*,
path: str = tempfile.gettempdir(),
always_proxy: Optional[bool] = None):
always_proxy: Optional[bool] = None,
spawn_process: bool = False):
self._constructor = constructor
self._tag = tag
self._path = path
@@ -262,6 +371,7 @@ def __init__(
self._rpc_address = None
self._cross_process_lock = fasteners.InterProcessLock(
os.path.join(self._path, self._tag) + '.lock')
self._spawn_process = spawn_process

def _get_manager(self):
if self._manager is None:
@@ -301,6 +411,11 @@ def acquire(self):
# Caveat: They must always agree, as they will be ignored if the object
# is already constructed.
singleton = self._get_manager().acquire_singleton(self._tag)
# Trigger a sweep of zombie processes: calling active_children() has the
# side effect of joining any finished processes, effectively reaping
# zombies left behind by previous unsafe_hard_delete calls.
if self._spawn_process:
multiprocessing.active_children()
return _AutoProxyWrapper(singleton)

def release(self, obj):
@@ -318,22 +433,94 @@ def unsafe_hard_delete(self):
self._get_manager().unsafe_hard_delete_singleton(self._tag)

def _create_server(self, address_file):
# We need to be able to authenticate with both the manager and the process.
self._serving_manager = _SingletonRegistrar(
address=('localhost', 0), authkey=AUTH_KEY)
multiprocessing.current_process().authkey = AUTH_KEY
# Initialize eagerly to avoid acting as the server if there are issues.
# Note, however, that _create_server itself is called lazily.
_process_level_singleton_manager.register_singleton(
self._constructor, self._tag, initialize_eagerly=True)
self._server = self._serving_manager.get_server()
logging.info(
'Starting proxy server at %s for shared %s',
self._server.address,
self._tag)
with open(address_file + '.tmp', 'w') as fout:
fout.write('%s:%d' % self._server.address)
os.rename(address_file + '.tmp', address_file)
t = threading.Thread(target=self._server.serve_forever, daemon=True)
t.start()
logging.info('Done starting server')
if self._spawn_process:
error_file = address_file + ".error"

if os.path.exists(error_file):
try:
os.remove(error_file)
except OSError:
pass

ctx = multiprocessing.get_context('spawn')
p = ctx.Process(
target=_run_server_process,
args=(address_file, self._tag, self._constructor, AUTH_KEY),
daemon=False,  # Must be False for nested proxies
)
p.start()
logging.info("Parent: Waiting for %s to write address file...", self._tag)

def cleanup_process():
if p.is_alive():
logging.info(
"Parent: Terminating server process %s for %s", p.pid, self._tag)
p.terminate()
p.join()
try:
if os.path.exists(address_file):
os.remove(address_file)
if os.path.exists(error_file):
os.remove(error_file)
except Exception as e:
logging.warning(
'Failed to cleanup files for tag %s in atexit handler: %s',
self._tag,
e)

atexit.register(cleanup_process)

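# Poll until the child publishes its address file, reports an error, or dies.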
start_time = time.time()
last_log = start_time
while True:
if os.path.exists(address_file):
break

if os.path.exists(error_file):
with open(error_file, 'r') as f:
error_msg = f.read()
try:
os.remove(error_file)
except OSError:
pass

if p.is_alive():
  p.terminate()
raise RuntimeError(f"Shared Server Process crashed:\n{error_msg}")

if not p.is_alive():
exit_code = p.exitcode
raise RuntimeError(
"Shared Server Process died unexpectedly"
f" with exit code {exit_code}")

if time.time() - last_log > 300:
logging.warning(
"Still waiting for %s to initialize... %ss elapsed)",
self._tag,
int(time.time() - start_time))
last_log = time.time()

time.sleep(0.05)

logging.info('External process successfully started for %s', self._tag)
else:
# We need to be able to authenticate with both the manager
# and the process.
self._serving_manager = _SingletonRegistrar(
address=('localhost', 0), authkey=AUTH_KEY)
multiprocessing.current_process().authkey = AUTH_KEY
# Initialize eagerly to avoid acting as the server if there are issues.
# Note, however, that _create_server itself is called lazily.
_process_level_singleton_manager.register_singleton(
self._constructor, self._tag, initialize_eagerly=True)
self._server = self._serving_manager.get_server()
logging.info(
'Starting proxy server at %s for shared %s',
self._server.address,
self._tag)
with open(address_file + '.tmp', 'w') as fout:
fout.write('%s:%d' % self._server.address)
os.rename(address_file + '.tmp', address_file)
t = threading.Thread(target=self._server.serve_forever, daemon=True)
t.start()
logging.info('Done starting server')
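
For reference, a minimal usage sketch of the new spawn_process mode (MyExpensiveObject and some_method are hypothetical stand-ins; with spawn_process=True the singleton is served from a dedicated spawned process instead of a daemon thread in the first acquiring process):

shared = MultiProcessShared(
    MyExpensiveObject, tag='my_model', spawn_process=True)
obj = shared.acquire()  # spawns the server process on first use
result = obj.some_method()  # calls are proxied across processes
shared.release(obj)
# To tear the object down immediately (e.g. to reclaim accelerator memory):
shared.unsafe_hard_delete()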