From 1a0342aaf000551b4a82b7adf4424ca208f5767e Mon Sep 17 00:00:00 2001 From: ZeuZ Agent Date: Tue, 3 Mar 2026 00:39:33 +0000 Subject: [PATCH] Implement REQ-6: Graceful Shutdown Handling --- Framework/Utilities/CommonUtil.py | 4 +- Framework/Utilities/RequestFormatter.py | 8 + Framework/deploy_handler/long_poll_handler.py | 5 +- .../install_handler/long_poll_handler.py | 10 +- node_cli.py | 162 ++++++++++++++++-- server/mobile.py | 45 ++++- server/node_operator.py | 3 +- 7 files changed, 204 insertions(+), 33 deletions(-) diff --git a/Framework/Utilities/CommonUtil.py b/Framework/Utilities/CommonUtil.py index 2edf195fc..e29a1ce80 100644 --- a/Framework/Utilities/CommonUtil.py +++ b/Framework/Utilities/CommonUtil.py @@ -213,8 +213,8 @@ def GetExecutor(): return executor -def ShutdownExecutor(): - executor.shutdown() +def ShutdownExecutor(wait: bool = True): + executor.shutdown(wait=wait) def SaveThread(key, thread): diff --git a/Framework/Utilities/RequestFormatter.py b/Framework/Utilities/RequestFormatter.py index 2c979e1e8..7c23f725b 100644 --- a/Framework/Utilities/RequestFormatter.py +++ b/Framework/Utilities/RequestFormatter.py @@ -28,6 +28,14 @@ SESSION_FILE_NAME = "session.bin" ACCESS_TOKEN_EXPIRES_AT = datetime.now() + +def close_session(): + global session + try: + session.close() + except Exception: + pass + def save_cookies(session: requests.Session, filename: str): try: with open(filename, 'wb') as f: diff --git a/Framework/deploy_handler/long_poll_handler.py b/Framework/deploy_handler/long_poll_handler.py index 1a6de8c6f..25942d9e9 100644 --- a/Framework/deploy_handler/long_poll_handler.py +++ b/Framework/deploy_handler/long_poll_handler.py @@ -255,6 +255,8 @@ async def run(self, host: str) -> None: reconnect = False print_online = False while True: + if CommonUtil.run_cancelled: + break if STATE.reconnect_with_credentials is not None: break @@ -303,6 +305,8 @@ async def run(self, host: str) -> None: break reconnect = False + except asyncio.CancelledError: + break except ( requests.exceptions.ConnectTimeout, requests.exceptions.ReadTimeout, @@ -323,4 +327,3 @@ async def run(self, host: str) -> None: print(e) print(Fore.YELLOW + "Retrying after 30s") await asyncio.sleep(30) - diff --git a/Framework/install_handler/long_poll_handler.py b/Framework/install_handler/long_poll_handler.py index 87882fa0d..3716c6ecb 100644 --- a/Framework/install_handler/long_poll_handler.py +++ b/Framework/install_handler/long_poll_handler.py @@ -12,7 +12,7 @@ read_node_id, generate_services_list, ) -from Framework.Utilities import RequestFormatter, ConfigModule +from Framework.Utilities import RequestFormatter, ConfigModule, CommonUtil from Framework.node_server_state import STATE from Framework.install_handler.android.emulator import ( check_emulator_list, @@ -339,6 +339,10 @@ async def run(self) -> None: print(f"[installer] Started running") while not self.cancel_: + if CommonUtil.run_cancelled: + if debug: + print("[installer] Shutdown requested, stopping...") + break if STATE.reconnect_with_credentials is not None: if debug: print("[installer] Reconnection requested, stopping...") @@ -375,6 +379,10 @@ async def run(self) -> None: print(f"[installer] Type Error in parsing response: {e}") continue + except asyncio.CancelledError: + if debug: + print("[installer] Cancelled.") + break except Exception: if debug: traceback.print_exc() diff --git a/node_cli.py b/node_cli.py index d4e1ce890..8e7a3301a 100755 --- a/node_cli.py +++ b/node_cli.py @@ -71,6 +71,7 @@ def adjust_python_path(): async def start_server(): + global _uvicorn_server def is_port_in_use(port): with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: return s.connect_ex(("127.0.0.1", port)) == 0 @@ -92,8 +93,11 @@ def is_port_in_use(port): log_level="warning", ) server = uvicorn.Server(config) + _uvicorn_server = server await server.serve() + except asyncio.CancelledError: + return except Exception as e: traceback.print_exc() print(f"[WARN] Failed to launch node-server: {str(e)}") @@ -144,6 +148,14 @@ def setup_nodejs_appium(): TMP_INI_FILE = None +SHUTDOWN_TIMEOUT_SECONDS = 5 +shutdown_event = asyncio.Event() +_shutdown_started = False +_shutdown_complete = asyncio.Event() +_main_loop: asyncio.AbstractEventLoop | None = None +_background_tasks: set[asyncio.Task] = set() +_uvicorn_server: uvicorn.Server | None = None +_install_handler: InstallHandler | None = None """Constants""" AUTHENTICATION_TAG = "Authentication" @@ -152,25 +164,120 @@ def setup_nodejs_appium(): device_dict: dict[str, Any] = {} -def kill_child_processes(): +def track_task(task: asyncio.Task) -> asyncio.Task: + _background_tasks.add(task) + task.add_done_callback(lambda done_task: _background_tasks.discard(done_task)) + return task + + +def terminate_child_processes(timeout_seconds: int = 2) -> None: try: parent = psutil.Process() children = parent.children(recursive=True) for child in children: try: - child.kill() + child.terminate() except psutil.NoSuchProcess: pass + + if children: + _, still_alive = psutil.wait_procs(children, timeout=timeout_seconds) + for child in still_alive: + try: + child.kill() + except psutil.NoSuchProcess: + pass except Exception: pass +def request_shutdown(reason: str) -> None: + global _shutdown_started + if _shutdown_started: + return + _shutdown_started = True + if _main_loop is None: + os._exit(0) + _main_loop.call_soon_threadsafe( + lambda: asyncio.create_task(shutdown(reason)) + ) + + def signal_handler(sig, frame): - print("\n--- SIGINT received, quitting ---\n") + try: + signal_name = signal.Signals(sig).name + except Exception: + signal_name = str(sig) + print(f"\n--- {signal_name} received, shutting down ---\n") + request_shutdown(signal_name) + + +async def shutdown(reason: str) -> None: + if shutdown_event.is_set(): + return + shutdown_event.set() CommonUtil.run_cancelled = True - CommonUtil.ShutdownExecutor() - kill_child_processes() - os._exit(0) + loop = asyncio.get_running_loop() + hard_exit = loop.call_later(SHUTDOWN_TIMEOUT_SECONDS, lambda: os._exit(1)) + try: + if _install_handler is not None: + try: + await _install_handler.cancel_run() + except Exception: + pass + + if _uvicorn_server is not None: + _uvicorn_server.should_exit = True + + try: + live_log_service.close() + except Exception: + pass + + try: + RequestFormatter.close_session() + except Exception: + pass + + try: + CommonUtil.ShutdownExecutor(wait=False) + except Exception: + pass + + pending_tasks = [ + task + for task in list(_background_tasks) + if not task.done() and task is not asyncio.current_task() + ] + for task in pending_tasks: + task.cancel() + + if pending_tasks: + try: + await asyncio.wait_for( + asyncio.gather(*pending_tasks, return_exceptions=True), + timeout=SHUTDOWN_TIMEOUT_SECONDS, + ) + except asyncio.TimeoutError: + pass + + await asyncio.to_thread(terminate_child_processes, 2) + finally: + hard_exit.cancel() + _shutdown_complete.set() + + +def setup_signal_handlers() -> None: + signals_to_handle = [signal.SIGINT] + if hasattr(signal, "SIGTERM"): + signals_to_handle.append(signal.SIGTERM) + if hasattr(signal, "SIGBREAK"): + signals_to_handle.append(signal.SIGBREAK) + for sig in signals_to_handle: + try: + signal.signal(sig, signal_handler) + except Exception: + pass async def destroy_session(): @@ -369,7 +476,10 @@ def notify_complete(message="Run completed"): async def RunProcess(node_id, log_dir=None): + global _install_handler try: + if shutdown_event.is_set(): + return False # --- START websocket service connections --- # def live_log_service_addr(): @@ -401,10 +511,13 @@ def deploy_srv_addr(): from Framework import node_server_state install_handler = InstallHandler() - install_task = asyncio.create_task(install_handler.run()) + _install_handler = install_handler + install_task = track_task(asyncio.create_task(install_handler.run())) async def response_callback(response: str): node_server_state.STATE.state = "in_progress" + if shutdown_event.is_set(): + return nonlocal node_json nonlocal log_dir if log_dir is None: @@ -460,7 +573,8 @@ async def done_callback() -> bool: print("[deploy] Run complete.") notify_complete("Run completed") - asyncio.create_task(install_handler.run()) + if not shutdown_event.is_set(): + track_task(asyncio.create_task(install_handler.run())) return False @@ -471,7 +585,8 @@ async def cancel_callback(): print("[deploy] Run cancelled.") notify_complete("Run cancelled") CommonUtil.run_cancelled = True - asyncio.create_task(install_handler.run()) + if not shutdown_event.is_set(): + track_task(asyncio.create_task(install_handler.run())) deploy_handler = long_poll_handler.DeployHandler( on_connect_callback=on_connect_callback, @@ -480,12 +595,14 @@ async def cancel_callback(): done_callback=done_callback, ) - deploy_task = asyncio.create_task(deploy_handler.run(deploy_srv_addr())) + deploy_task = track_task(asyncio.create_task(deploy_handler.run(deploy_srv_addr()))) await asyncio.gather(install_task, deploy_task, return_exceptions=True) return False + except asyncio.CancelledError: + return False except Exception: exc_type, exc_obj, exc_tb = sys.exc_info() fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1] @@ -1317,7 +1434,9 @@ def create_temp_ini_automation_log(): async def main(): + global _main_loop print_system_info_version() + _main_loop = asyncio.get_running_loop() load_dotenv() adjust_python_path() ConfigModule.remove_settings_lock_file() @@ -1346,13 +1465,13 @@ async def main(): update_android_sdk_path() update_outdated_modules() - asyncio.create_task(start_server()) - asyncio.create_task(upload_android_ui_dump()) - asyncio.create_task(upload_ios_ui_dump()) - asyncio.create_task(delete_old_automationlog_folders()) + track_task(asyncio.create_task(start_server())) + track_task(asyncio.create_task(upload_android_ui_dump(shutdown_event))) + track_task(asyncio.create_task(upload_ios_ui_dump(shutdown_event))) + track_task(asyncio.create_task(delete_old_automationlog_folders())) await destroy_session() - signal.signal(signal.SIGINT, signal_handler) + setup_signal_handlers() print("Press Ctrl-C or Ctrl-Break to disconnect and quit.") console = Console() @@ -1382,13 +1501,15 @@ async def main(): await asyncio.sleep(1) else: - asyncio.create_task( + track_task(asyncio.create_task( Login( server_name=server_name, log_dir=log_dir, ) - ) + )) while True: + if shutdown_event.is_set(): + break if STATE.reconnect_with_credentials is not None: await destroy_session() server_name = STATE.reconnect_with_credentials.server @@ -1414,13 +1535,16 @@ async def main(): ) console.print("Please log in to ZeuZ server and connect.") - asyncio.create_task( + track_task(asyncio.create_task( Login( server_name=server_name, log_dir=log_dir, ) - ) + )) await asyncio.sleep(1) + if shutdown_event.is_set(): + await _shutdown_complete.wait() + asyncio.run(main()) diff --git a/server/mobile.py b/server/mobile.py index 68154b452..a0583d981 100644 --- a/server/mobile.py +++ b/server/mobile.py @@ -469,9 +469,24 @@ def capture_ios_ui_dump(device_udid: str): pass -async def upload_android_ui_dump(): +async def _sleep_or_shutdown(shutdown_event: asyncio.Event | None, seconds: int) -> bool: + if shutdown_event is None: + await asyncio.sleep(seconds) + return False + try: + await asyncio.wait_for(shutdown_event.wait(), timeout=seconds) + except asyncio.TimeoutError: + return False + return True + + +async def upload_android_ui_dump(shutdown_event: asyncio.Event | None = None): prev_xml_hash = "" while True: + if shutdown_event is not None and shutdown_event.is_set(): + return + if CommonUtil.run_cancelled: + return try: await asyncio.to_thread(capture_ui_dump) try: @@ -487,12 +502,14 @@ async def upload_android_ui_dump(): ).hexdigest() # Don't upload if the content hasn't changed if prev_xml_hash == new_xml_hash: - await asyncio.sleep(5) + if await _sleep_or_shutdown(shutdown_event, 5): + return continue prev_xml_hash = new_xml_hash except FileNotFoundError: - await asyncio.sleep(5) + if await _sleep_or_shutdown(shutdown_event, 5): + return continue url = ( ConfigModule.get_config_value( @@ -515,16 +532,22 @@ async def upload_android_ui_dump(): CommonUtil.ExecLog("", "UI dump uploaded successfully", iLogLevel=1) except Exception as e: CommonUtil.ExecLog("", f"Error uploading UI dump: {str(e)}", iLogLevel=3) - await asyncio.sleep(5) + if await _sleep_or_shutdown(shutdown_event, 5): + return -async def upload_ios_ui_dump(): +async def upload_ios_ui_dump(shutdown_event: asyncio.Event | None = None): prev_xml_hash = "" while True: + if shutdown_event is not None and shutdown_event.is_set(): + return + if CommonUtil.run_cancelled: + return try: ios_devices = get_ios_devices() if not ios_devices: - await asyncio.sleep(5) + if await _sleep_or_shutdown(shutdown_event, 5): + return continue device_udid = ios_devices[0].udid @@ -537,12 +560,14 @@ async def upload_ios_ui_dump(): new_xml_hash = hashlib.sha256(xml_content.encode('utf-8')).hexdigest() # Don't upload if the content hasn't changed if prev_xml_hash == new_xml_hash: - await asyncio.sleep(5) + if await _sleep_or_shutdown(shutdown_event, 5): + return continue prev_xml_hash = new_xml_hash except FileNotFoundError: - await asyncio.sleep(5) + if await _sleep_or_shutdown(shutdown_event, 5): + return continue url = ConfigModule.get_config_value("Authentication", "server_address").strip() + "/node_ai_contents/" @@ -558,6 +583,8 @@ async def upload_ios_ui_dump(): CommonUtil.ExecLog("", "UI dump uploaded successfully", iLogLevel=1) except Exception as e: CommonUtil.ExecLog("", f"Error uploading iOS UI dump: {str(e)}", iLogLevel=3) + if await _sleep_or_shutdown(shutdown_event, 5): + return await asyncio.sleep(5) @@ -778,4 +805,4 @@ def is_ios_app_installed(sim_udid: str, bundle_id: str): return {"installed": False} except Exception as e: - return {"installed": False, "error": str(e)} \ No newline at end of file + return {"installed": False, "error": str(e)} diff --git a/server/node_operator.py b/server/node_operator.py index a80bc2d74..cea7bc99a 100644 --- a/server/node_operator.py +++ b/server/node_operator.py @@ -16,7 +16,8 @@ class OperatorResponse(BaseModel): @router.post("/kill") def kill(): print("[Node server] Kill signal received. Shutting down.") - os.kill(os.getpid(), signal.SIGINT) + shutdown_signal = signal.SIGTERM if hasattr(signal, "SIGTERM") else signal.SIGINT + os.kill(os.getpid(), shutdown_signal) return OperatorResponse(status="ok")