From 3d34735ca30092b516c008480cd9e7adc092c782 Mon Sep 17 00:00:00 2001 From: Summer Yang Date: Wed, 11 Mar 2026 17:30:00 -0700 Subject: [PATCH 01/22] fix(test): resolve thread leak failures in CI - Add mod.stop() to test_process_crash_triggers_stop so watchdog, LCM, and event-loop threads are properly joined from the test thread - Filter third-party daemon threads with generic names (Thread-\d+) in conftest monitor_threads to ignore torch/HF background threads that have no cleanup API --- dimos/conftest.py | 11 +++++++++++ dimos/core/test_native_module.py | 4 ++++ 2 files changed, 15 insertions(+) diff --git a/dimos/conftest.py b/dimos/conftest.py index 4ab8a401f8..5f7f30e882 100644 --- a/dimos/conftest.py +++ b/dimos/conftest.py @@ -14,6 +14,7 @@ import asyncio import os +import re import threading from dotenv import load_dotenv @@ -160,6 +161,16 @@ def monitor_threads(request): if not any(t.name.startswith(prefix) for prefix in expected_persistent_thread_prefixes) ] + # Filter out third-party daemon threads with generic names (e.g. "Thread-109"). + # On Python 3.12+ our own threads include the target function name in parens + # (e.g. "Thread-166 (run_forever)"), so this only matches unnamed threads + # from libraries like torch/HuggingFace that have no cleanup API. + new_threads = [ + t + for t in new_threads + if not (t.daemon and re.fullmatch(r"Thread-\d+", t.name)) + ] + # Filter out threads we've already seen (from previous tests) truly_new = [t for t in new_threads if t.ident not in _seen_threads] diff --git a/dimos/core/test_native_module.py b/dimos/core/test_native_module.py index 5d57c42854..708d3aa48c 100644 --- a/dimos/core/test_native_module.py +++ b/dimos/core/test_native_module.py @@ -106,6 +106,10 @@ def test_process_crash_triggers_stop() -> None: break assert mod._process is None, f"Watchdog did not clean up after process {pid} died" + # Explicitly stop to join watchdog, LCM, and event-loop threads from the + # test thread. 
The watchdog's self.stop() can't join itself, so these + # threads would otherwise leak. stop() is idempotent. + mod.stop() # Wait for background threads (run_forever, _lcm_loop, _watch_process) to finish # after the watchdog-triggered stop(). Without this, monitor_threads catches them. From 6413d9932814e77d6cb8b936960bc0def485203b Mon Sep 17 00:00:00 2001 From: SUMMERxYANG <69720581+SUMMERxYANG@users.noreply.github.com> Date: Thu, 12 Mar 2026 00:35:29 +0000 Subject: [PATCH 02/22] CI code cleanup --- dimos/conftest.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/dimos/conftest.py b/dimos/conftest.py index 5f7f30e882..1a7a4f943b 100644 --- a/dimos/conftest.py +++ b/dimos/conftest.py @@ -166,9 +166,7 @@ def monitor_threads(request): # (e.g. "Thread-166 (run_forever)"), so this only matches unnamed threads # from libraries like torch/HuggingFace that have no cleanup API. new_threads = [ - t - for t in new_threads - if not (t.daemon and re.fullmatch(r"Thread-\d+", t.name)) + t for t in new_threads if not (t.daemon and re.fullmatch(r"Thread-\d+", t.name)) ] # Filter out threads we've already seen (from previous tests) From ce26d9a59501e7ef44e75e7a5c99b3f4b7cc85c8 Mon Sep 17 00:00:00 2001 From: Summer Yang Date: Thu, 12 Mar 2026 14:08:39 -0700 Subject: [PATCH 03/22] fix(test): use fixture for native module crash test cleanup Convert test_process_crash_triggers_stop to use a fixture that calls mod.stop() in teardown. The watchdog thread calls self.stop() but can't join itself, so an explicit stop() from the test thread is needed to properly clean up all threads. Drop the broad conftest regex filter for generic daemon thread names per review feedback. 
--- dimos/conftest.py | 8 -------- dimos/core/test_native_module.py | 33 ++++++++++++++++++++------------ 2 files changed, 21 insertions(+), 20 deletions(-) diff --git a/dimos/conftest.py b/dimos/conftest.py index 1a7a4f943b..29eaf05567 100644 --- a/dimos/conftest.py +++ b/dimos/conftest.py @@ -14,7 +14,6 @@ import asyncio import os -import re import threading from dotenv import load_dotenv @@ -161,13 +160,6 @@ def monitor_threads(request): if not any(t.name.startswith(prefix) for prefix in expected_persistent_thread_prefixes) ] - # Filter out third-party daemon threads with generic names (e.g. "Thread-109"). - # On Python 3.12+ our own threads include the target function name in parens - # (e.g. "Thread-166 (run_forever)"), so this only matches unnamed threads - # from libraries like torch/HuggingFace that have no cleanup API. - new_threads = [ - t for t in new_threads if not (t.daemon and re.fullmatch(r"Thread-\d+", t.name)) - ] # Filter out threads we've already seen (from previous tests) truly_new = [t for t in new_threads if t.ident not in _seen_threads] diff --git a/dimos/core/test_native_module.py b/dimos/core/test_native_module.py index 708d3aa48c..f0310d9109 100644 --- a/dimos/core/test_native_module.py +++ b/dimos/core/test_native_module.py @@ -18,8 +18,11 @@ The echo script writes received CLI args to a temp file for assertions. 
""" +from collections.abc import Generator +from dataclasses import dataclass import json from pathlib import Path +import threading import time import pytest @@ -90,26 +93,32 @@ def start(self) -> None: pass -def test_process_crash_triggers_stop() -> None: - """When the native process dies unexpectedly, the watchdog calls stop().""" +@pytest.fixture +def crash_module() -> Generator[StubNativeModule, None, None]: + """Create a StubNativeModule that dies after 0.2s, ensuring cleanup.""" mod = StubNativeModule(die_after=0.2) - mod.pointcloud.transport = LCMTransport("/pc", PointCloud2) - mod.start() + yield mod + # Join watchdog, LCM, and event-loop threads from the test thread. + # The watchdog's self.stop() can't join itself, so without this the + # threads leak. stop() is idempotent. + mod.stop() - assert mod._process is not None - pid = mod._process.pid + +def test_process_crash_triggers_stop(crash_module: StubNativeModule) -> None: + """When the native process dies unexpectedly, the watchdog calls stop().""" + crash_module.pointcloud.transport = LCMTransport("/pc", PointCloud2) + crash_module.start() + + assert crash_module._process is not None + pid = crash_module._process.pid # Wait for the process to die and the watchdog to call stop() for _ in range(30): time.sleep(0.1) - if mod._process is None: + if crash_module._process is None: break - assert mod._process is None, f"Watchdog did not clean up after process {pid} died" - # Explicitly stop to join watchdog, LCM, and event-loop threads from the - # test thread. The watchdog's self.stop() can't join itself, so these - # threads would otherwise leak. stop() is idempotent. - mod.stop() + assert crash_module._process is None, f"Watchdog did not clean up after process {pid} died" # Wait for background threads (run_forever, _lcm_loop, _watch_process) to finish # after the watchdog-triggered stop(). Without this, monitor_threads catches them. 
From b11eb6e00f79017051e1ec4fe39762ff5e6756e2 Mon Sep 17 00:00:00 2001 From: SUMMERxYANG <69720581+SUMMERxYANG@users.noreply.github.com> Date: Thu, 12 Mar 2026 21:09:29 +0000 Subject: [PATCH 04/22] CI code cleanup --- dimos/conftest.py | 1 - 1 file changed, 1 deletion(-) diff --git a/dimos/conftest.py b/dimos/conftest.py index 29eaf05567..4ab8a401f8 100644 --- a/dimos/conftest.py +++ b/dimos/conftest.py @@ -160,7 +160,6 @@ def monitor_threads(request): if not any(t.name.startswith(prefix) for prefix in expected_persistent_thread_prefixes) ] - # Filter out threads we've already seen (from previous tests) truly_new = [t for t in new_threads if t.ident not in _seen_threads] From 820be9a86a2a3874d66e54c1db065d0061ab6157 Mon Sep 17 00:00:00 2001 From: Summer Yang Date: Thu, 12 Mar 2026 14:13:08 -0700 Subject: [PATCH 05/22] chore: retrigger CI From 38a31fcfe2b0d78d8d013ab6da7a566df8275580 Mon Sep 17 00:00:00 2001 From: Summer Yang Date: Thu, 12 Mar 2026 15:58:14 -0700 Subject: [PATCH 06/22] fix(test): join threads directly in crash_module fixture mod.stop() is a no-op when the watchdog already called it, so capture thread IDs before the test and join new ones in teardown. --- dimos/core/test_native_module.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/dimos/core/test_native_module.py b/dimos/core/test_native_module.py index f0310d9109..af4d7e992b 100644 --- a/dimos/core/test_native_module.py +++ b/dimos/core/test_native_module.py @@ -96,12 +96,15 @@ def start(self) -> None: @pytest.fixture def crash_module() -> Generator[StubNativeModule, None, None]: """Create a StubNativeModule that dies after 0.2s, ensuring cleanup.""" + before = {t.ident for t in threading.enumerate()} mod = StubNativeModule(die_after=0.2) yield mod - # Join watchdog, LCM, and event-loop threads from the test thread. - # The watchdog's self.stop() can't join itself, so without this the - # threads leak. stop() is idempotent. 
- mod.stop() + # The watchdog calls stop() from its own thread, which sets + # _module_closed=True. A second stop() from here is then a no-op, + # so we explicitly join any threads the test created. + for t in threading.enumerate(): + if t.ident not in before and t is not threading.current_thread(): + t.join(timeout=5) def test_process_crash_triggers_stop(crash_module: StubNativeModule) -> None: From 055b7f3d22b43022b2e1e13a0ff870d60f808c89 Mon Sep 17 00:00:00 2001 From: SUMMERxYANG <69720581+SUMMERxYANG@users.noreply.github.com> Date: Thu, 12 Mar 2026 23:09:08 +0000 Subject: [PATCH 07/22] CI code cleanup --- dimos/core/test_native_module.py | 1 - 1 file changed, 1 deletion(-) diff --git a/dimos/core/test_native_module.py b/dimos/core/test_native_module.py index af4d7e992b..d192d7c72a 100644 --- a/dimos/core/test_native_module.py +++ b/dimos/core/test_native_module.py @@ -19,7 +19,6 @@ """ from collections.abc import Generator -from dataclasses import dataclass import json from pathlib import Path import threading From 8fb0526744ffd69f0f31d3a45be287266653a961 Mon Sep 17 00:00:00 2001 From: Summer Yang Date: Mon, 23 Mar 2026 16:58:44 -0700 Subject: [PATCH 08/22] fix(native_module): preserve watchdog reference so second stop() can join it --- dimos/core/native_module.py | 12 ++++++++--- dimos/core/test_native_module.py | 35 +++++++++++--------------------- 2 files changed, 21 insertions(+), 26 deletions(-) diff --git a/dimos/core/native_module.py b/dimos/core/native_module.py index 74471f34d5..eb03bb7acf 100644 --- a/dimos/core/native_module.py +++ b/dimos/core/native_module.py @@ -196,11 +196,17 @@ def stop(self) -> None: ) self._process.kill() self._process.wait(timeout=5) - if self._watchdog is not None and self._watchdog is not threading.current_thread(): - self._watchdog.join(timeout=2) - self._watchdog = None self._process = None super().stop() + # Join the watchdog AFTER super().stop() so all module threads are + # cleaned up first. 
When the watchdog itself is the caller (crash + # path), it skips joining itself — but the thread exits naturally + # right after this returns. A second stop() from external code + # (e.g. test teardown) will reach here and join the now-finished + # watchdog thread, preventing monitor_threads from seeing a leak. + if self._watchdog is not None and self._watchdog is not threading.current_thread(): + self._watchdog.join(timeout=2) + self._watchdog = None def _watch_process(self) -> None: """Block until the native process exits; trigger stop() if it crashed.""" diff --git a/dimos/core/test_native_module.py b/dimos/core/test_native_module.py index d192d7c72a..361b45fdd1 100644 --- a/dimos/core/test_native_module.py +++ b/dimos/core/test_native_module.py @@ -18,10 +18,8 @@ The echo script writes received CLI args to a temp file for assertions. """ -from collections.abc import Generator import json from pathlib import Path -import threading import time import pytest @@ -92,35 +90,26 @@ def start(self) -> None: pass -@pytest.fixture -def crash_module() -> Generator[StubNativeModule, None, None]: - """Create a StubNativeModule that dies after 0.2s, ensuring cleanup.""" - before = {t.ident for t in threading.enumerate()} - mod = StubNativeModule(die_after=0.2) - yield mod - # The watchdog calls stop() from its own thread, which sets - # _module_closed=True. A second stop() from here is then a no-op, - # so we explicitly join any threads the test created. 
- for t in threading.enumerate(): - if t.ident not in before and t is not threading.current_thread(): - t.join(timeout=5) - - -def test_process_crash_triggers_stop(crash_module: StubNativeModule) -> None: +def test_process_crash_triggers_stop() -> None: """When the native process dies unexpectedly, the watchdog calls stop().""" - crash_module.pointcloud.transport = LCMTransport("/pc", PointCloud2) - crash_module.start() + mod = StubNativeModule(die_after=0.2) + mod.pointcloud.transport = LCMTransport("/pc", PointCloud2) + mod.start() - assert crash_module._process is not None - pid = crash_module._process.pid + assert mod._process is not None + pid = mod._process.pid # Wait for the process to die and the watchdog to call stop() for _ in range(30): time.sleep(0.1) - if crash_module._process is None: + if mod._process is None: break - assert crash_module._process is None, f"Watchdog did not clean up after process {pid} died" + assert mod._process is None, f"Watchdog did not clean up after process {pid} died" + + # Join the watchdog thread. stop() is idempotent but will now join the + # watchdog on the second call since the reference is preserved. + mod.stop() # Wait for background threads (run_forever, _lcm_loop, _watch_process) to finish # after the watchdog-triggered stop(). Without this, monitor_threads catches them. 
From 060e049325183967520f70f357775084e82c416c Mon Sep 17 00:00:00 2001 From: Jeff Hykin Date: Tue, 24 Mar 2026 10:39:47 -0700 Subject: [PATCH 09/22] minimal fix --- dimos/core/native_module.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/dimos/core/native_module.py b/dimos/core/native_module.py index eb03bb7acf..d02e7c06bd 100644 --- a/dimos/core/native_module.py +++ b/dimos/core/native_module.py @@ -182,8 +182,8 @@ def start(self) -> None: self._watchdog = threading.Thread(target=self._watch_process, daemon=True) self._watchdog.start() - @rpc - def stop(self) -> None: + def _clean_all_but_watchdog(self) -> None: + """A cleanup helper designed to be called inside of the watchdog and outside of the watchdong""" self._stopping = True if self._process is not None and self._process.poll() is None: logger.info("Stopping native process", pid=self._process.pid) @@ -195,21 +195,23 @@ def stop(self) -> None: "Native process did not exit, sending SIGKILL", pid=self._process.pid ) self._process.kill() - self._process.wait(timeout=5) + try: + self._process.wait(timeout=5) + except Exception as error: + print(f'''error = {error}''') self._process = None + + + @rpc + def stop(self) -> None: + self._clean_all_but_watchdog() super().stop() - # Join the watchdog AFTER super().stop() so all module threads are - # cleaned up first. When the watchdog itself is the caller (crash - # path), it skips joining itself — but the thread exits naturally - # right after this returns. A second stop() from external code - # (e.g. test teardown) will reach here and join the now-finished - # watchdog thread, preventing monitor_threads from seeing a leak. 
- if self._watchdog is not None and self._watchdog is not threading.current_thread(): + if self._watchdog is not None: self._watchdog.join(timeout=2) self._watchdog = None def _watch_process(self) -> None: - """Block until the native process exits; trigger stop() if it crashed.""" + """Block until the native process exits; trigger cleanup if it crashed.""" if self._process is None: return @@ -236,7 +238,7 @@ def _watch_process(self) -> None: returncode=rc, last_stderr=last_stderr[:500] if last_stderr else None, ) - self.stop() + self._clean_all_but_watchdog() def _start_reader(self, stream: IO[bytes] | None, level: str) -> threading.Thread: """Spawn a daemon thread that pipes a subprocess stream through the logger.""" From a20d1618af3222bf1e1c1e5f81cbafd43ac89057 Mon Sep 17 00:00:00 2001 From: Jeff Hykin Date: Tue, 24 Mar 2026 23:03:09 -0700 Subject: [PATCH 10/22] ideal approach, not tested --- dimos/agents/mcp/mcp_server.py | 5 +- dimos/core/module.py | 76 +- dimos/core/native_module.py | 159 +--- dimos/core/test_core.py | 2 +- dimos/perception/detection/conftest.py | 10 +- .../perception/detection/reid/test_module.py | 2 +- dimos/robot/unitree/b1/test_connection.py | 281 +++--- dimos/utils/test_thread_utils.py | 897 ++++++++++++++++++ dimos/utils/thread_utils.py | 542 +++++++++++ dimos/utils/typing_utils.py | 45 + 10 files changed, 1672 insertions(+), 347 deletions(-) create mode 100644 dimos/utils/test_thread_utils.py create mode 100644 dimos/utils/thread_utils.py create mode 100644 dimos/utils/typing_utils.py diff --git a/dimos/agents/mcp/mcp_server.py b/dimos/agents/mcp/mcp_server.py index 9149de06ec..d09202532c 100644 --- a/dimos/agents/mcp/mcp_server.py +++ b/dimos/agents/mcp/mcp_server.py @@ -180,7 +180,7 @@ def start(self) -> None: def stop(self) -> None: if self._uvicorn_server: self._uvicorn_server.should_exit = True - loop = self._loop + loop = self._async_thread.loop if loop is not None and self._serve_future is not None: 
self._serve_future.result(timeout=5.0) self._uvicorn_server = None @@ -250,6 +250,5 @@ def _start_server(self, port: int | None = None) -> None: config = uvicorn.Config(app, host=_host, port=_port, log_level="info") server = uvicorn.Server(config) self._uvicorn_server = server - loop = self._loop - assert loop is not None + loop = self._async_thread.loop self._serve_future = asyncio.run_coroutine_threadsafe(server.serve(), loop) diff --git a/dimos/core/module.py b/dimos/core/module.py index 1c5b311883..d9f8356f38 100644 --- a/dimos/core/module.py +++ b/dimos/core/module.py @@ -22,6 +22,7 @@ from typing import ( TYPE_CHECKING, Any, + Literal, Protocol, get_args, get_origin, @@ -45,6 +46,9 @@ from dimos.protocol.tf.tf import LCMTF, TFSpec from dimos.utils import colors from dimos.utils.generic import classproperty +from dimos.utils.thread_utils import AsyncModuleThread, ThreadSafeVal + +ModState = Literal["init", "started", "stopping", "stopped"] if TYPE_CHECKING: from dimos.core.blueprints import Blueprint @@ -64,19 +68,6 @@ class SkillInfo: args_schema: str -def get_loop() -> tuple[asyncio.AbstractEventLoop, threading.Thread | None]: - try: - running_loop = asyncio.get_running_loop() - return running_loop, None - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - thr = threading.Thread(target=loop.run_forever, daemon=True) - thr.start() - return loop, thr - - class ModuleConfig(BaseConfig): rpc_transport: type[RPCSpec] = LCMRPC tf_transport: type[TFSpec] = LCMTF # type: ignore[type-arg] @@ -98,20 +89,20 @@ class ModuleBase(Configurable[ModuleConfigT], Resource): _rpc: RPCSpec | None = None _tf: TFSpec[Any] | None = None - _loop: asyncio.AbstractEventLoop | None = None - _loop_thread: threading.Thread | None + _async_thread: AsyncModuleThread _disposables: CompositeDisposable _bound_rpc_calls: dict[str, RpcCall] = {} - _module_closed: bool = False - _module_closed_lock: threading.Lock + mod_state: ThreadSafeVal[ModState] 
rpc_calls: list[str] = [] def __init__(self, config_args: dict[str, Any]): super().__init__(**config_args) - self._module_closed_lock = threading.Lock() - self._loop, self._loop_thread = get_loop() self._disposables = CompositeDisposable() + self.mod_state = ThreadSafeVal[ModState]("init") + self._async_thread = AsyncModuleThread( # NEEDS to be created after self._disposables exists + module=self + ) try: self.rpc = self.config.rpc_transport() self.rpc.serve_module_rpc(self) @@ -128,38 +119,30 @@ def frame_id(self) -> str: @rpc def start(self) -> None: - pass + with self.mod_state as state: + if state == "stopped": + raise RuntimeError(f"{type(self).__name__} cannot be restarted after stop") + self.mod_state.set("started") @rpc def stop(self) -> None: - self._close_module() + self._stop() - def _close_module(self) -> None: - with self._module_closed_lock: - if self._module_closed: + def _stop(self) -> None: + with self.mod_state as state: + if state in ("stopping", "stopped"): return - self._module_closed = True - - self._close_rpc() - - # Save into local variables to avoid race when stopping concurrently - # (from RPC and worker shutdown) - loop_thread = getattr(self, "_loop_thread", None) - loop = getattr(self, "_loop", None) + self.mod_state.set("stopping") - if loop_thread: - if loop_thread.is_alive(): - if loop: - loop.call_soon_threadsafe(loop.stop) - loop_thread.join(timeout=2) - self._loop = None - self._loop_thread = None + if self.rpc: + self.rpc.stop() # type: ignore[attr-defined] + self.rpc = None # type: ignore[assignment] if hasattr(self, "_tf") and self._tf is not None: self._tf.stop() self._tf = None if hasattr(self, "_disposables"): - self._disposables.dispose() + self._disposables.dispose() # stops _async_thread via disposable # Break the In/Out -> owner -> self reference cycle so the instance # can be freed by refcount instead of waiting for GC. 
@@ -167,19 +150,12 @@ def _close_module(self) -> None: if isinstance(attr, (In, Out)): attr.owner = None - def _close_rpc(self) -> None: - if self.rpc: - self.rpc.stop() # type: ignore[attr-defined] - self.rpc = None # type: ignore[assignment] - def __getstate__(self): # type: ignore[no-untyped-def] """Exclude unpicklable runtime attributes when serializing.""" state = self.__dict__.copy() # Remove unpicklable attributes state.pop("_disposables", None) - state.pop("_module_closed_lock", None) - state.pop("_loop", None) - state.pop("_loop_thread", None) + state.pop("_async_thread", None) state.pop("_rpc", None) state.pop("_tf", None) return state @@ -189,9 +165,7 @@ def __setstate__(self, state) -> None: # type: ignore[no-untyped-def] self.__dict__.update(state) # Reinitialize runtime attributes self._disposables = CompositeDisposable() - self._module_closed_lock = threading.Lock() - self._loop = None - self._loop_thread = None + self._async_thread = None # type: ignore[assignment] self._rpc = None self._tf = None diff --git a/dimos/core/native_module.py b/dimos/core/native_module.py index d02e7c06bd..bdc46e2cab 100644 --- a/dimos/core/native_module.py +++ b/dimos/core/native_module.py @@ -40,23 +40,20 @@ class MyCppModule(NativeModule): from __future__ import annotations -import collections import enum import inspect -import json import os from pathlib import Path -import signal import subprocess import sys -import threading -from typing import IO, Any +from typing import Any from pydantic import Field from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig from dimos.utils.logging_config import setup_logger +from dimos.utils.thread_utils import ModuleProcess if sys.version_info < (3, 13): from typing_extensions import TypeVar @@ -129,144 +126,30 @@ class NativeModule(Module[_NativeConfig]): """ default_config: type[_NativeConfig] = NativeModuleConfig # type: ignore[assignment] - _process: subprocess.Popen[bytes] | None = None - 
_watchdog: threading.Thread | None = None - _stopping: bool = False - _last_stderr_lines: collections.deque[str] + _proc: ModuleProcess | None = None def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) - self._last_stderr_lines = collections.deque(maxlen=50) self._resolve_paths() @rpc def start(self) -> None: - if self._process is not None and self._process.poll() is None: - logger.warning("Native process already running", pid=self._process.pid) - return - + super().start() self._maybe_build() - - topics = self._collect_topics() - - cmd = [self.config.executable] - for name, topic_str in topics.items(): - cmd.extend([f"--{name}", topic_str]) - cmd.extend(self.config.to_cli_args()) - cmd.extend(self.config.extra_args) - - env = {**os.environ, **self.config.extra_env} - cwd = self.config.cwd or str(Path(self.config.executable).resolve().parent) - - module_name = type(self).__name__ - logger.info( - f"Starting native process: {module_name}", - module=module_name, - cmd=" ".join(cmd), - cwd=cwd, - ) - self._process = subprocess.Popen( - cmd, - env=env, - cwd=cwd, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) - logger.info( - f"Native process started: {module_name}", - module=module_name, - pid=self._process.pid, - ) - - self._stopping = False - self._watchdog = threading.Thread(target=self._watch_process, daemon=True) - self._watchdog.start() - - def _clean_all_but_watchdog(self) -> None: - """A cleanup helper designed to be called inside of the watchdog and outside of the watchdong""" - self._stopping = True - if self._process is not None and self._process.poll() is None: - logger.info("Stopping native process", pid=self._process.pid) - self._process.send_signal(signal.SIGTERM) - try: - self._process.wait(timeout=self.config.shutdown_timeout) - except subprocess.TimeoutExpired: - logger.warning( - "Native process did not exit, sending SIGKILL", pid=self._process.pid - ) - self._process.kill() - try: - self._process.wait(timeout=5) - except 
Exception as error: - print(f'''error = {error}''') - self._process = None - - - @rpc - def stop(self) -> None: - self._clean_all_but_watchdog() - super().stop() - if self._watchdog is not None: - self._watchdog.join(timeout=2) - self._watchdog = None - - def _watch_process(self) -> None: - """Block until the native process exits; trigger cleanup if it crashed.""" - if self._process is None: - return - - stdout_t = self._start_reader(self._process.stdout, "info") - stderr_t = self._start_reader(self._process.stderr, "warning") - rc = self._process.wait() - stdout_t.join(timeout=2) - stderr_t.join(timeout=2) - - if self._stopping: - return - - module_name = type(self).__name__ - exe_name = Path(self.config.executable).name if self.config.executable else "unknown" - - # Use buffered stderr lines from the reader thread for the crash report. - last_stderr = "\n".join(self._last_stderr_lines) - - logger.error( - f"Native process crashed: {module_name} ({exe_name})", - module=module_name, - executable=exe_name, - pid=self._process.pid, - returncode=rc, - last_stderr=last_stderr[:500] if last_stderr else None, + self._proc = ModuleProcess( + module=self, + args=[ + self.config.executable, + *[arg for name, topic in self._collect_topics().items() for arg in (f"--{name}", topic)], + *self.config.to_cli_args(), + *self.config.extra_args, + ], + env={**os.environ, **self.config.extra_env}, + cwd=self.config.cwd or str(Path(self.config.executable).resolve().parent), + on_exit=self.stop, + shutdown_timeout=self.config.shutdown_timeout, + log_json=self.config.log_format == LogFormat.JSON, ) - self._clean_all_but_watchdog() - - def _start_reader(self, stream: IO[bytes] | None, level: str) -> threading.Thread: - """Spawn a daemon thread that pipes a subprocess stream through the logger.""" - t = threading.Thread(target=self._read_log_stream, args=(stream, level), daemon=True) - t.start() - return t - - def _read_log_stream(self, stream: IO[bytes] | None, level: str) -> None: - if 
stream is None: - return - log_fn = getattr(logger, level) - is_stderr = level == "warning" - for raw in stream: - line = raw.decode("utf-8", errors="replace").rstrip() - if not line: - continue - if is_stderr: - self._last_stderr_lines.append(line) - if self.config.log_format == LogFormat.JSON: - try: - data = json.loads(line) - event = data.pop("event", line) - log_fn(event, **data) - continue - except (json.JSONDecodeError, TypeError): - logger.warning("malformed JSON from native module", raw=line) - log_fn(line, pid=self._process.pid if self._process else None) - stream.close() def _resolve_paths(self) -> None: """Resolve relative ``cwd`` and ``executable`` against the subclass's source file.""" @@ -308,16 +191,12 @@ def _maybe_build(self) -> None: if line.strip(): logger.warning(line) if proc.returncode != 0: - stderr_tail = stderr.decode("utf-8", errors="replace").strip()[-1000:] raise RuntimeError( - f"Build command failed (exit {proc.returncode}): {self.config.build_command}\n" - f"stderr: {stderr_tail}" + f"Build command failed (exit {proc.returncode}): {self.config.build_command}" ) if not exe.exists(): raise FileNotFoundError( - f"Build command succeeded but executable still not found: {exe}\n" - f"Build output may have been written to a different path. " - f"Check that build_command produces the executable at the expected location." 
+ f"Build command succeeded but executable still not found: {exe}" ) def _collect_topics(self) -> dict[str, str]: diff --git a/dimos/core/test_core.py b/dimos/core/test_core.py index f9a89829d5..aae167d8c6 100644 --- a/dimos/core/test_core.py +++ b/dimos/core/test_core.py @@ -89,7 +89,7 @@ def test_classmethods() -> None: ) assert hasattr(class_rpcs["start"], "__rpc__"), "start should have __rpc__ attribute" - nav._close_module() + nav._stop() @pytest.mark.slow diff --git a/dimos/perception/detection/conftest.py b/dimos/perception/detection/conftest.py index 5f8f1bc4b9..2040d687be 100644 --- a/dimos/perception/detection/conftest.py +++ b/dimos/perception/detection/conftest.py @@ -221,7 +221,7 @@ def moment_provider(**kwargs) -> Moment2D: yield moment_provider moment_provider.cache_clear() - module._close_module() + module._stop() @pytest.fixture(scope="session") @@ -256,7 +256,7 @@ def moment_provider(**kwargs) -> Moment3D: yield moment_provider moment_provider.cache_clear() if module is not None: - module._close_module() + module._stop() @pytest.fixture(scope="session") @@ -290,9 +290,9 @@ def object_db_module(get_moment): yield moduleDB - module2d._close_module() - module3d._close_module() - moduleDB._close_module() + module2d._stop() + module3d._stop() + moduleDB._stop() @pytest.fixture(scope="session") diff --git a/dimos/perception/detection/reid/test_module.py b/dimos/perception/detection/reid/test_module.py index aac6ba11d1..5fa0eead8d 100644 --- a/dimos/perception/detection/reid/test_module.py +++ b/dimos/perception/detection/reid/test_module.py @@ -40,5 +40,5 @@ def test_reid_ingress(imageDetections2d) -> None: print("Processing detections through ReidModule...") reid_module.annotations._transport = LCMTransport("/annotations", ImageAnnotations) reid_module.ingress(imageDetections2d) - reid_module._close_module() + reid_module._stop() print("✓ ReidModule ingress test completed successfully") diff --git a/dimos/robot/unitree/b1/test_connection.py 
b/dimos/robot/unitree/b1/test_connection.py index 9c7a2867fa..011853d172 100644 --- a/dimos/robot/unitree/b1/test_connection.py +++ b/dimos/robot/unitree/b1/test_connection.py @@ -22,14 +22,10 @@ # should be used and tested. Additionally, tests should always use `try-finally` # to clean up even if the test fails. -import sys import threading import time -_IS_MACOS = sys.platform == "darwin" - -from dimos.msgs.geometry_msgs.TwistStamped import TwistStamped -from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.geometry_msgs import TwistStamped, Vector3 from dimos.msgs.std_msgs.Int32 import Int32 from .connection import MockB1ConnectionModule @@ -62,27 +58,22 @@ def test_watchdog_actually_zeros_commands(self) -> None: assert conn._current_cmd.mode == 2 assert not conn.timeout_active - try: - # Poll for watchdog timeout (generous 2s deadline) - deadline = time.time() + 2.0 - while time.time() < deadline: - if conn.timeout_active: - break - time.sleep(0.05) - - # Verify commands were zeroed by watchdog - assert conn._current_cmd.ly == 0.0 - assert conn._current_cmd.lx == 0.0 - assert conn._current_cmd.rx == 0.0 - assert conn._current_cmd.ry == 0.0 - assert conn._current_cmd.mode == 2 # Mode maintained - assert conn.timeout_active - finally: - conn.running = False - conn.watchdog_running = False - conn.send_thread.join(timeout=1.0) - conn.watchdog_thread.join(timeout=1.0) - conn._close_module() + # Wait for watchdog timeout (200ms + buffer) + time.sleep(0.3) + + # Verify commands were zeroed by watchdog + assert conn._current_cmd.ly == 0.0 + assert conn._current_cmd.lx == 0.0 + assert conn._current_cmd.rx == 0.0 + assert conn._current_cmd.ry == 0.0 + assert conn._current_cmd.mode == 2 # Mode maintained + assert conn.timeout_active + + conn.running = False + conn.watchdog_running = False + conn.send_thread.join(timeout=0.5) + conn.watchdog_thread.join(timeout=0.5) + conn._stop() def test_watchdog_resets_on_new_command(self) -> None: """Test that watchdog 
timeout resets when new command arrives.""" @@ -94,27 +85,43 @@ def test_watchdog_resets_on_new_command(self) -> None: conn.watchdog_thread = threading.Thread(target=conn._watchdog_loop, daemon=True) conn.watchdog_thread.start() - try: - # Send commands in rapid succession — each resets the 200ms watchdog - for val in [1.0, 0.8, 0.6, 0.5]: - twist = TwistStamped( - ts=time.time(), - frame_id="base_link", - linear=Vector3(val, 0, 0), - angular=Vector3(0, 0, 0), - ) - conn.handle_twist_stamped(twist) - time.sleep(0.02) # 20ms between commands, well under timeout - - # Command should be the last one sent and no timeout - assert conn._current_cmd.ly == 0.5 - assert not conn.timeout_active - finally: - conn.running = False - conn.watchdog_running = False - conn.send_thread.join(timeout=1.0) - conn.watchdog_thread.join(timeout=1.0) - conn._close_module() + # Send first command + twist1 = TwistStamped( + ts=time.time(), + frame_id="base_link", + linear=Vector3(1.0, 0, 0), + angular=Vector3(0, 0, 0), + ) + conn.handle_twist_stamped(twist1) + assert conn._current_cmd.ly == 1.0 + + # Wait 150ms (not enough to trigger timeout) + time.sleep(0.15) + + # Send second command before timeout + twist2 = TwistStamped( + ts=time.time(), + frame_id="base_link", + linear=Vector3(0.5, 0, 0), + angular=Vector3(0, 0, 0), + ) + conn.handle_twist_stamped(twist2) + + # Command should be updated and no timeout + assert conn._current_cmd.ly == 0.5 + assert not conn.timeout_active + + # Wait another 150ms (total 300ms from second command) + time.sleep(0.15) + # Should still not timeout since we reset the timer + assert not conn.timeout_active + assert conn._current_cmd.ly == 0.5 + + conn.running = False + conn.watchdog_running = False + conn.send_thread.join(timeout=0.5) + conn.watchdog_thread.join(timeout=0.5) + conn._stop() def test_watchdog_thread_efficiency(self) -> None: """Test that watchdog uses only one thread regardless of command rate.""" @@ -148,7 +155,7 @@ def 
test_watchdog_thread_efficiency(self) -> None: conn.watchdog_running = False conn.send_thread.join(timeout=0.5) conn.watchdog_thread.join(timeout=0.5) - conn._close_module() + conn._stop() def test_watchdog_with_send_loop_blocking(self) -> None: """Test that watchdog still works if send loop blocks.""" @@ -172,35 +179,30 @@ def blocking_send_loop() -> None: conn.watchdog_thread = threading.Thread(target=conn._watchdog_loop, daemon=True) conn.watchdog_thread.start() - try: - # Send command - twist = TwistStamped( - ts=time.time(), - frame_id="base_link", - linear=Vector3(1.0, 0, 0), - angular=Vector3(0, 0, 0), - ) - conn.handle_twist_stamped(twist) - assert conn._current_cmd.ly == 1.0 - - # Poll for watchdog timeout (generous 2s deadline) - deadline = time.time() + 2.0 - while time.time() < deadline: - if conn.timeout_active: - break - time.sleep(0.05) - - # Watchdog should have zeroed commands despite blocked send loop - assert conn._current_cmd.ly == 0.0, "Watchdog should zero commands" - assert conn.timeout_active, "Watchdog should be active" - finally: - # Unblock send loop - block_event.set() - conn.running = False - conn.watchdog_running = False - conn.send_thread.join(timeout=1.0) - conn.watchdog_thread.join(timeout=1.0) - conn._close_module() + # Send command + twist = TwistStamped( + ts=time.time(), + frame_id="base_link", + linear=Vector3(1.0, 0, 0), + angular=Vector3(0, 0, 0), + ) + conn.handle_twist_stamped(twist) + assert conn._current_cmd.ly == 1.0 + + # Wait for watchdog timeout + time.sleep(0.3) + + # Watchdog should have zeroed commands despite blocked send loop + assert conn._current_cmd.ly == 0.0 + assert conn.timeout_active + + # Unblock send loop + block_event.set() + conn.running = False + conn.watchdog_running = False + conn.send_thread.join(timeout=0.5) + conn.watchdog_thread.join(timeout=0.5) + conn._stop() def test_continuous_commands_prevent_timeout(self) -> None: """Test that continuous commands prevent watchdog timeout.""" @@ -212,33 
+214,30 @@ def test_continuous_commands_prevent_timeout(self) -> None: conn.watchdog_thread = threading.Thread(target=conn._watchdog_loop, daemon=True) conn.watchdog_thread.start() - try: - # Send commands continuously for 1s (should prevent timeout) - start = time.time() - commands_sent = 0 - while time.time() - start < 1.0: - twist = TwistStamped( - ts=time.time(), - frame_id="base_link", - linear=Vector3(0.5, 0, 0), - angular=Vector3(0, 0, 0), - ) - conn.handle_twist_stamped(twist) - commands_sent += 1 - time.sleep(0.05) # 50ms between commands (well under 200ms timeout) - - # Should never timeout - assert not conn.timeout_active, "Should not timeout with continuous commands" - assert conn._current_cmd.ly == 0.5, "Commands should still be active" - assert commands_sent >= 3, ( - f"Should send at least 3 commands in 1s, sent {commands_sent}" + # Send commands continuously for 500ms (should prevent timeout) + start = time.time() + commands_sent = 0 + while time.time() - start < 0.5: + twist = TwistStamped( + ts=time.time(), + frame_id="base_link", + linear=Vector3(0.5, 0, 0), + angular=Vector3(0, 0, 0), ) - finally: - conn.running = False - conn.watchdog_running = False - conn.send_thread.join(timeout=1.0) - conn.watchdog_thread.join(timeout=1.0) - conn._close_module() + conn.handle_twist_stamped(twist) + commands_sent += 1 + time.sleep(0.05) # 50ms between commands (well under 200ms timeout) + + # Should never timeout + assert not conn.timeout_active, "Should not timeout with continuous commands" + assert conn._current_cmd.ly == 0.5, "Commands should still be active" + assert commands_sent >= 9, f"Should send at least 9 commands in 500ms, sent {commands_sent}" + + conn.running = False + conn.watchdog_running = False + conn.send_thread.join(timeout=0.5) + conn.watchdog_thread.join(timeout=0.5) + conn._stop() def test_watchdog_timing_accuracy(self) -> None: """Test that watchdog zeros commands at approximately 200ms.""" @@ -273,14 +272,13 @@ def 
test_watchdog_timing_accuracy(self) -> None: # Check timing (should be close to 200ms + up to 50ms watchdog interval) elapsed = timeout_time - start_time print(f"\nWatchdog timeout occurred at exactly {elapsed:.3f} seconds") - _lo, _hi = (0.15, 0.5) if _IS_MACOS else (0.19, 0.3) - assert _lo <= elapsed <= _hi, f"Watchdog timed out at {elapsed:.3f}s, expected {_lo}-{_hi}s" + assert 0.19 <= elapsed <= 0.3, f"Watchdog timed out at {elapsed:.3f}s, expected ~0.2-0.25s" conn.running = False conn.watchdog_running = False conn.send_thread.join(timeout=0.5) conn.watchdog_thread.join(timeout=0.5) - conn._close_module() + conn._stop() def test_mode_changes_with_watchdog(self) -> None: """Test that mode changes work correctly with watchdog.""" @@ -323,7 +321,7 @@ def test_mode_changes_with_watchdog(self) -> None: conn.watchdog_running = False conn.send_thread.join(timeout=0.5) conn.watchdog_thread.join(timeout=0.5) - conn._close_module() + conn._stop() def test_watchdog_stops_movement_when_commands_stop(self) -> None: """Verify watchdog zeros commands when packets stop being sent.""" @@ -352,45 +350,36 @@ def test_watchdog_stops_movement_when_commands_stop(self) -> None: assert conn.current_mode == 2 # WALK mode assert not conn.timeout_active - try: - # Poll for watchdog timeout (generous 2s deadline) - deadline = time.time() + 2.0 - while time.time() < deadline: - if conn.timeout_active: - break - time.sleep(0.05) - - assert conn.timeout_active, "Watchdog should have detected timeout" - assert conn._current_cmd.ly == 0.0, "Forward velocity should be zeroed" - assert conn._current_cmd.lx == 0.0, "Lateral velocity should be zeroed" - assert conn._current_cmd.rx == 0.0, "Rotation X should be zeroed" - assert conn._current_cmd.ry == 0.0, "Rotation Y should be zeroed" - assert conn.current_mode == 2, "Mode should stay as WALK" - - # Verify recovery works - send new command - twist = TwistStamped( - ts=time.time(), - frame_id="base_link", - linear=Vector3(0.5, 0, 0), - 
angular=Vector3(0, 0, 0), - ) - conn.handle_twist_stamped(twist) + # Wait for watchdog to detect timeout (200ms + buffer) + time.sleep(0.3) + + assert conn.timeout_active, "Watchdog should have detected timeout" + assert conn._current_cmd.ly == 0.0, "Forward velocity should be zeroed" + assert conn._current_cmd.lx == 0.0, "Lateral velocity should be zeroed" + assert conn._current_cmd.rx == 0.0, "Rotation X should be zeroed" + assert conn._current_cmd.ry == 0.0, "Rotation Y should be zeroed" + assert conn.current_mode == 2, "Mode should stay as WALK" + + # Verify recovery works - send new command + twist = TwistStamped( + ts=time.time(), + frame_id="base_link", + linear=Vector3(0.5, 0, 0), + angular=Vector3(0, 0, 0), + ) + conn.handle_twist_stamped(twist) - # Poll for recovery (timeout_active should clear) - deadline = time.time() + 2.0 - while time.time() < deadline: - if not conn.timeout_active: - break - time.sleep(0.05) - - assert not conn.timeout_active, "Should recover from timeout" - assert conn._current_cmd.ly == 0.5, "Should accept new commands" - finally: - conn.running = False - conn.watchdog_running = False - conn.send_thread.join(timeout=1.0) - conn.watchdog_thread.join(timeout=1.0) - conn._close_module() + # Give watchdog time to detect recovery + time.sleep(0.1) + + assert not conn.timeout_active, "Should recover from timeout" + assert conn._current_cmd.ly == 0.5, "Should accept new commands" + + conn.running = False + conn.watchdog_running = False + conn.send_thread.join(timeout=0.5) + conn.watchdog_thread.join(timeout=0.5) + conn._stop() def test_rapid_command_thread_safety(self) -> None: """Test thread safety with rapid commands from multiple threads.""" @@ -439,4 +428,4 @@ def send_commands(thread_id) -> None: conn.watchdog_running = False conn.send_thread.join(timeout=0.5) conn.watchdog_thread.join(timeout=0.5) - conn._close_module() + conn._stop() diff --git a/dimos/utils/test_thread_utils.py b/dimos/utils/test_thread_utils.py new file mode 
"""Exhaustive tests for dimos/utils/thread_utils.py

Covers: ThreadSafeVal, ModuleThread, AsyncModuleThread, ModuleProcess, safe_thread_map.
Focuses on deadlocks, race conditions, idempotency, and edge cases under load.
"""

from __future__ import annotations

import asyncio
import os
import pickle
import signal
import subprocess
import sys
import threading
import time
from unittest import mock

import pytest
from reactivex.disposable import CompositeDisposable

from dimos.utils.thread_utils import (
    AsyncModuleThread,
    ModuleProcess,
    ModuleThread,
    ThreadSafeVal,
    safe_thread_map,
)


# ---------------------------------------------------------------------------
# Helpers: fake ModuleBase for testing ModuleThread / AsyncModuleThread / ModuleProcess
# ---------------------------------------------------------------------------


class FakeModule:
    """Minimal stand-in for ModuleBase — just needs _disposables."""

    def __init__(self) -> None:
        # The managed-thread classes register their cleanup Disposables here;
        # this is the only ModuleBase attribute they touch.
        self._disposables = CompositeDisposable()

    def dispose(self) -> None:
        # Mirrors ModuleBase teardown: disposing the composite should stop
        # every ModuleThread/ModuleProcess registered against this module.
        self._disposables.dispose()


# ===================================================================
# ThreadSafeVal Tests
# ===================================================================


class TestThreadSafeVal:
    # Deadlock-prone operations are run in a worker thread with a bounded
    # join/wait so a regression fails the test instead of hanging the suite.

    def test_basic_get_set(self) -> None:
        v = ThreadSafeVal(42)
        assert v.get() == 42
        v.set(99)
        assert v.get() == 99

    def test_bool_truthy(self) -> None:
        v = ThreadSafeVal(True)
        assert bool(v) is True
        v.set(False)
        assert bool(v) is False

    def test_bool_zero(self) -> None:
        # Truthiness follows the wrapped value, not the wrapper's existence.
        v = ThreadSafeVal(0)
        assert bool(v) is False
        v.set(1)
        assert bool(v) is True

    def test_context_manager_returns_value(self) -> None:
        v = ThreadSafeVal("hello")
        with v as val:
            assert val == "hello"

    def test_set_inside_context_manager_no_deadlock(self) -> None:
        """The critical test: set() inside a with block must NOT deadlock.

        This was a confirmed bug when using threading.Lock (non-reentrant).
        Fixed by using threading.RLock.
        """
        v = ThreadSafeVal(0)
        result = threading.Event()

        def do_it() -> None:
            with v as val:
                v.set(val + 1)
            result.set()

        t = threading.Thread(target=do_it)
        t.start()
        t.join(timeout=2)
        assert result.is_set(), "Deadlocked! set() inside with block hung"
        assert v.get() == 1

    def test_get_inside_context_manager_no_deadlock(self) -> None:
        v = ThreadSafeVal(10)
        result = threading.Event()

        def do_it() -> None:
            with v as val:
                _ = v.get()
            result.set()

        t = threading.Thread(target=do_it)
        t.start()
        t.join(timeout=2)
        assert result.is_set(), "Deadlocked! get() inside with block hung"

    def test_bool_inside_context_manager_no_deadlock(self) -> None:
        v = ThreadSafeVal(True)
        result = threading.Event()

        def do_it() -> None:
            with v as val:
                _ = bool(v)
            result.set()

        t = threading.Thread(target=do_it)
        t.start()
        t.join(timeout=2)
        assert result.is_set(), "Deadlocked! bool() inside with block hung"

    def test_context_manager_blocks_other_threads(self) -> None:
        """While one thread holds the lock via `with`, others should block on set()."""
        v = ThreadSafeVal(0)
        gate = threading.Event()
        other_started = threading.Event()
        other_finished = threading.Event()

        def holder() -> None:
            with v as val:
                gate.wait(timeout=5)  # hold the lock until signaled

        def setter() -> None:
            other_started.set()
            v.set(42)  # should block until holder releases
            other_finished.set()

        t1 = threading.Thread(target=holder)
        t2 = threading.Thread(target=setter)
        t1.start()
        time.sleep(0.05)  # let holder acquire lock
        t2.start()
        other_started.wait(timeout=2)
        time.sleep(0.1)
        # setter should be blocked
        assert not other_finished.is_set(), "set() did not block while lock was held"
        gate.set()  # release holder
        t1.join(timeout=2)
        t2.join(timeout=2)
        assert other_finished.is_set()
        assert v.get() == 42

    def test_concurrent_increments(self) -> None:
        """Many threads doing atomic read-modify-write should not lose updates."""
        v = ThreadSafeVal(0)
        n_threads = 50
        n_increments = 100

        def incrementer() -> None:
            for _ in range(n_increments):
                # The `with` block makes the read-modify-write atomic; a lost
                # update here would make the final total come up short.
                with v as val:
                    v.set(val + 1)

        threads = [threading.Thread(target=incrementer) for _ in range(n_threads)]
        for t in threads:
            t.start()
        for t in threads:
            t.join(timeout=10)
        assert v.get() == n_threads * n_increments

    def test_concurrent_increments_stress(self) -> None:
        """Run the concurrent increment test multiple times to catch races."""
        for _ in range(10):
            self.test_concurrent_increments()

    def test_pickle_roundtrip(self) -> None:
        v = ThreadSafeVal({"key": [1, 2, 3]})
        data = pickle.dumps(v)
        v2 = pickle.loads(data)
        assert v2.get() == {"key": [1, 2, 3]}
        # Verify the new instance has a working lock
        with v2 as val:
            v2.set({**val, "new": True})
        assert v2.get()["new"] is True

    def test_repr(self) -> None:
        v = ThreadSafeVal("test")
        assert repr(v) == "ThreadSafeVal('test')"

    def test_dict_type(self) -> None:
        v = ThreadSafeVal({"running": False, "count": 0})
        with v as s:
            v.set({**s, "running": True})
        assert v.get() == {"running": True, "count": 0}

    def test_string_literal_type(self) -> None:
        """Simulates the ModState pattern from module.py."""
        v = ThreadSafeVal("init")
        with v as state:
            if state == "init":
                v.set("started")
        assert v.get() == "started"

        with v as state:
            if state in ("stopping", "stopped"):
                pass  # no-op
            else:
                v.set("stopping")
        assert v.get() == "stopping"

    def test_nested_with_no_deadlock(self) -> None:
        """RLock should allow the same thread to nest with blocks."""
        v = ThreadSafeVal(0)
        result = threading.Event()

        def do_it() -> None:
            with v as val1:
                with v as val2:
                    v.set(val2 + 1)
            result.set()

        t = threading.Thread(target=do_it)
        t.start()
        t.join(timeout=2)
        assert result.is_set(), "Nested with blocks deadlocked!"
# ===================================================================
# ModuleThread Tests
# ===================================================================


class TestModuleThread:
    # Targets that must observe `stopping` receive the ModuleThread via a
    # one-element `holder` list, because the instance does not exist yet when
    # the closure is defined (ModuleThread may auto-start on construction).

    def test_basic_lifecycle(self) -> None:
        mod = FakeModule()
        ran = threading.Event()

        def target() -> None:
            ran.set()

        mt = ModuleThread(module=mod, target=target, name="test-basic")
        ran.wait(timeout=2)
        assert ran.is_set()
        mt.stop()
        assert not mt.is_alive

    def test_auto_start(self) -> None:
        mod = FakeModule()
        started = threading.Event()
        mt = ModuleThread(module=mod, target=started.set, name="test-autostart")
        started.wait(timeout=2)
        assert started.is_set()
        mt.stop()

    def test_deferred_start(self) -> None:
        mod = FakeModule()
        started = threading.Event()
        mt = ModuleThread(module=mod, target=started.set, name="test-deferred", start=False)
        time.sleep(0.1)
        assert not started.is_set()
        mt.start()
        started.wait(timeout=2)
        assert started.is_set()
        mt.stop()

    def test_stopping_property(self) -> None:
        mod = FakeModule()
        saw_stopping = threading.Event()
        holder: list[ModuleThread] = []

        def target() -> None:
            while not holder[0].stopping:
                time.sleep(0.01)
            saw_stopping.set()

        mt = ModuleThread(module=mod, target=target, name="test-stopping", start=False)
        holder.append(mt)
        mt.start()
        time.sleep(0.05)
        mt.stop()
        saw_stopping.wait(timeout=2)
        assert saw_stopping.is_set()

    def test_stop_idempotent(self) -> None:
        mod = FakeModule()
        mt = ModuleThread(module=mod, target=lambda: time.sleep(0.01), name="test-idem")
        time.sleep(0.05)
        mt.stop()
        mt.stop()  # second call should not raise
        mt.stop()  # third call should not raise

    def test_stop_from_managed_thread_no_deadlock(self) -> None:
        """The thread calling stop() on itself should not deadlock."""
        mod = FakeModule()
        result = threading.Event()
        holder: list[ModuleThread] = []

        def target() -> None:
            holder[0].stop()  # stop ourselves — should not deadlock
            result.set()

        mt = ModuleThread(module=mod, target=target, name="test-self-stop", start=False)
        holder.append(mt)
        mt.start()
        result.wait(timeout=3)
        assert result.is_set(), "Deadlocked when thread called stop() on itself"

    def test_dispose_stops_thread(self) -> None:
        """Module dispose should stop the thread via the registered Disposable."""
        mod = FakeModule()
        running = threading.Event()
        holder: list[ModuleThread] = []

        def target() -> None:
            running.set()
            while not holder[0].stopping:
                time.sleep(0.01)

        mt = ModuleThread(module=mod, target=target, name="test-dispose", start=False)
        holder.append(mt)
        mt.start()
        running.wait(timeout=2)
        mod.dispose()
        time.sleep(0.1)
        assert not mt.is_alive

    def test_concurrent_stop_calls(self) -> None:
        """Multiple threads calling stop() concurrently should not crash."""
        mod = FakeModule()
        holder: list[ModuleThread] = []

        def target() -> None:
            while not holder[0].stopping:
                time.sleep(0.01)

        mt = ModuleThread(module=mod, target=target, name="test-concurrent-stop", start=False)
        holder.append(mt)
        mt.start()
        time.sleep(0.05)

        errors = []

        def stop_it() -> None:
            try:
                mt.stop()
            except Exception as e:
                errors.append(e)

        threads = [threading.Thread(target=stop_it) for _ in range(20)]
        for t in threads:
            t.start()
        for t in threads:
            t.join(timeout=5)
        assert not errors, f"Concurrent stop() raised: {errors}"

    def test_close_timeout_respected(self) -> None:
        """If the thread ignores the stop signal, stop() should return after close_timeout."""
        mod = FakeModule()

        def stubborn_target() -> None:
            time.sleep(10)  # ignores stopping signal

        mt = ModuleThread(
            module=mod, target=stubborn_target, name="test-timeout", close_timeout=0.2
        )
        start = time.monotonic()
        mt.stop()
        elapsed = time.monotonic() - start
        assert elapsed < 1.0, f"stop() took {elapsed}s, expected ~0.2s"

    def test_stop_concurrent_with_dispose(self) -> None:
        """Calling stop() and dispose() concurrently should not crash."""
        for _ in range(20):
            mod = FakeModule()
            holder: list[ModuleThread] = []

            def target() -> None:
                while not holder[0].stopping:
                    time.sleep(0.001)

            mt = ModuleThread(module=mod, target=target, name="test-stop-dispose", start=False)
            holder.append(mt)
            mt.start()
            time.sleep(0.02)
            # Race: stop and dispose from different threads
            t1 = threading.Thread(target=mt.stop)
            t2 = threading.Thread(target=mod.dispose)
            t1.start()
            t2.start()
            t1.join(timeout=3)
            t2.join(timeout=3)


# ===================================================================
# AsyncModuleThread Tests
# ===================================================================


class TestAsyncModuleThread:
    def test_creates_loop_and_thread(self) -> None:
        mod = FakeModule()
        amt = AsyncModuleThread(module=mod)
        assert amt.loop is not None
        assert amt.loop.is_running()
        assert amt.is_alive
        amt.stop()
        assert not amt.is_alive

    def test_stop_idempotent(self) -> None:
        mod = FakeModule()
        amt = AsyncModuleThread(module=mod)
        amt.stop()
        amt.stop()  # should not raise
        amt.stop()

    def test_dispose_stops_loop(self) -> None:
        mod = FakeModule()
        amt = AsyncModuleThread(module=mod)
        assert amt.is_alive
        mod.dispose()
        time.sleep(0.1)
        assert not amt.is_alive

    def test_can_schedule_coroutine(self) -> None:
        mod = FakeModule()
        amt = AsyncModuleThread(module=mod)
        result = []

        async def coro() -> None:
            result.append(42)

        # run_coroutine_threadsafe is the supported way to submit work to a
        # loop running on another thread; the future bridges the result back.
        future = asyncio.run_coroutine_threadsafe(coro(), amt.loop)
        future.result(timeout=2)
        assert result == [42]
        amt.stop()

    def test_stop_with_pending_work(self) -> None:
        """Stop should succeed even with long-running tasks on the loop."""
        mod = FakeModule()
        amt = AsyncModuleThread(module=mod)
        started = threading.Event()

        async def slow_coro() -> None:
            started.set()
            await asyncio.sleep(10)

        asyncio.run_coroutine_threadsafe(slow_coro(), amt.loop)
        started.wait(timeout=2)
        # stop() should not hang waiting for the coroutine
        start = time.monotonic()
        amt.stop()
        elapsed = time.monotonic() - start
        assert elapsed < 5.0, f"stop() hung for {elapsed}s with pending coroutine"

    def test_concurrent_stop(self) -> None:
        mod = FakeModule()
        amt = AsyncModuleThread(module=mod)
        errors = []

        def stop_it() -> None:
            try:
                amt.stop()
            except Exception as e:
                errors.append(e)

        threads = [threading.Thread(target=stop_it) for _ in range(20)]
        for t in threads:
            t.start()
        for t in threads:
            t.join(timeout=5)
        assert not errors
# ===================================================================
# ModuleProcess Tests
# ===================================================================


# Helper: path to a python that sleeps or echoes
PYTHON = sys.executable


class TestModuleProcess:
    """Lifecycle, watchdog, and logging tests for ModuleProcess.

    Subprocesses are spawned as short ``python -c`` one-liners so each test
    controls exactly how and when the child exits (sleep, crash, ignore
    SIGTERM, exit immediately).
    """

    def test_basic_lifecycle(self) -> None:
        mod = FakeModule()
        mp = ModuleProcess(
            module=mod,
            args=[PYTHON, "-c", "import time; time.sleep(30)"],
            shutdown_timeout=2.0,
        )
        assert mp.is_alive
        assert mp.pid is not None
        mp.stop()
        assert not mp.is_alive
        assert mp.pid is None

    def test_stop_idempotent(self) -> None:
        mod = FakeModule()
        mp = ModuleProcess(
            module=mod,
            args=[PYTHON, "-c", "import time; time.sleep(30)"],
            shutdown_timeout=1.0,
        )
        mp.stop()
        mp.stop()  # should not raise
        mp.stop()

    def test_dispose_stops_process(self) -> None:
        """Module dispose should terminate the child via the registered Disposable."""
        mod = FakeModule()
        mp = ModuleProcess(
            module=mod,
            args=[PYTHON, "-c", "import time; time.sleep(30)"],
            shutdown_timeout=2.0,
        )
        # Fix: the original assigned `pid = mp.pid` and never used it; assert
        # the process actually started before we dispose it.
        assert mp.pid is not None
        mod.dispose()
        time.sleep(0.5)
        assert not mp.is_alive

    def test_on_exit_fires_on_natural_exit(self) -> None:
        """on_exit should fire when the process exits on its own."""
        mod = FakeModule()
        exit_called = threading.Event()

        mp = ModuleProcess(
            module=mod,
            args=[PYTHON, "-c", "print('done')"],
            on_exit=exit_called.set,
        )
        exit_called.wait(timeout=5)
        assert exit_called.is_set(), "on_exit was not called after natural process exit"

    def test_on_exit_fires_on_crash(self) -> None:
        mod = FakeModule()
        exit_called = threading.Event()

        mp = ModuleProcess(
            module=mod,
            args=[PYTHON, "-c", "import sys; sys.exit(1)"],
            on_exit=exit_called.set,
        )
        exit_called.wait(timeout=5)
        assert exit_called.is_set(), "on_exit was not called after process crash"

    def test_on_exit_not_fired_on_stop(self) -> None:
        """on_exit should NOT fire when stop() kills the process."""
        mod = FakeModule()
        exit_called = threading.Event()

        mp = ModuleProcess(
            module=mod,
            args=[PYTHON, "-c", "import time; time.sleep(30)"],
            on_exit=exit_called.set,
            shutdown_timeout=2.0,
        )
        time.sleep(0.2)  # let watchdog start
        mp.stop()
        time.sleep(1.0)  # give watchdog time to potentially fire
        assert not exit_called.is_set(), "on_exit fired after intentional stop()"

    def test_stdout_logged(self) -> None:
        # Smoke test: exercises the stdout reader thread; has no assertion
        # because output goes to the logger, not back to the test.
        mod = FakeModule()
        mp = ModuleProcess(
            module=mod,
            args=[PYTHON, "-c", "print('hello from subprocess')"],
        )
        time.sleep(1.0)  # let output be read
        mp.stop()

    def test_stderr_logged(self) -> None:
        # Smoke test for the stderr reader path (see test_stdout_logged).
        mod = FakeModule()
        mp = ModuleProcess(
            module=mod,
            args=[PYTHON, "-c", "import sys; sys.stderr.write('error msg\\n')"],
        )
        time.sleep(1.0)
        mp.stop()

    def test_log_json_mode(self) -> None:
        # Smoke test: structured JSON lines should be parsed, not crash the reader.
        mod = FakeModule()
        mp = ModuleProcess(
            module=mod,
            args=[PYTHON, "-c", """import json; print(json.dumps({"event": "test", "key": "val"}))"""],
            log_json=True,
        )
        time.sleep(1.0)
        mp.stop()

    def test_log_json_malformed(self) -> None:
        # Smoke test: non-JSON output in log_json mode must not crash the reader.
        mod = FakeModule()
        mp = ModuleProcess(
            module=mod,
            args=[PYTHON, "-c", "print('not json')"],
            log_json=True,
        )
        time.sleep(1.0)
        mp.stop()

    def test_stop_process_that_ignores_sigterm(self) -> None:
        """Process that ignores SIGTERM should be killed with SIGKILL."""
        mod = FakeModule()
        mp = ModuleProcess(
            module=mod,
            args=[
                PYTHON,
                "-c",
                "import signal, time; signal.signal(signal.SIGTERM, signal.SIG_IGN); time.sleep(60)",
            ],
            shutdown_timeout=0.5,
            kill_timeout=2.0,
        )
        time.sleep(0.2)
        start = time.monotonic()
        mp.stop()
        elapsed = time.monotonic() - start
        assert not mp.is_alive
        # Should take roughly shutdown_timeout (0.5) + a bit for SIGKILL
        assert elapsed < 5.0

    def test_stop_already_dead_process(self) -> None:
        """stop() on a process that already exited should not raise."""
        mod = FakeModule()
        mp = ModuleProcess(
            module=mod,
            args=[PYTHON, "-c", "pass"],  # exits immediately
        )
        time.sleep(1.0)  # let it die
        mp.stop()  # should not raise

    def test_concurrent_stop(self) -> None:
        """Many threads calling stop() at once must not raise or double-kill."""
        mod = FakeModule()
        mp = ModuleProcess(
            module=mod,
            args=[PYTHON, "-c", "import time; time.sleep(30)"],
            shutdown_timeout=2.0,
        )
        errors = []

        def stop_it() -> None:
            try:
                mp.stop()
            except Exception as e:
                errors.append(e)

        threads = [threading.Thread(target=stop_it) for _ in range(20)]
        for t in threads:
            t.start()
        for t in threads:
            t.join(timeout=10)
        assert not errors, f"Concurrent stop() raised: {errors}"

    def test_on_exit_calls_module_stop_no_deadlock(self) -> None:
        """Simulate the real pattern: on_exit=module.stop, which disposes the
        ModuleProcess, which tries to stop its watchdog from inside the watchdog.
        Must not deadlock.
        """
        mod = FakeModule()
        stop_called = threading.Event()

        def fake_module_stop() -> None:
            """Simulates module.stop() -> _stop() -> dispose()"""
            mod.dispose()
            stop_called.set()

        mp = ModuleProcess(
            module=mod,
            args=[PYTHON, "-c", "pass"],  # exits immediately
            on_exit=fake_module_stop,
        )
        stop_called.wait(timeout=5)
        assert stop_called.is_set(), "Deadlocked! on_exit -> dispose -> stop chain hung"

    def test_on_exit_calls_module_stop_no_deadlock_stress(self) -> None:
        """Run the deadlock test multiple times under load."""
        for i in range(10):
            self.test_on_exit_calls_module_stop_no_deadlock()

    def test_deferred_start(self) -> None:
        mod = FakeModule()
        mp = ModuleProcess(
            module=mod,
            args=[PYTHON, "-c", "import time; time.sleep(30)"],
            start=False,
        )
        assert not mp.is_alive
        mp.start()
        assert mp.is_alive
        mp.stop()

    def test_env_passed(self) -> None:
        mod = FakeModule()
        exit_called = threading.Event()

        # The child exits 0 only when MY_VAR reached it through env=.
        mp = ModuleProcess(
            module=mod,
            args=[PYTHON, "-c", "import os, sys; sys.exit(0 if os.environ.get('MY_VAR') == '42' else 1)"],
            env={**os.environ, "MY_VAR": "42"},
            on_exit=exit_called.set,
        )
        exit_called.wait(timeout=5)
        # Process should have exited with 0 (our on_exit fires for all unmanaged exits)
        # NOTE(review): on_exit fires for exit code 1 too, so this does not yet
        # prove the env var arrived — assert the child's return code if
        # ModuleProcess exposes it.
        assert exit_called.is_set()

    def test_cwd_passed(self) -> None:
        # Smoke test: cwd= must not break spawning; the printed cwd goes to
        # the logger rather than back to the test.
        mod = FakeModule()
        mp = ModuleProcess(
            module=mod,
            args=[PYTHON, "-c", "import os; print(os.getcwd())"],
            cwd="/tmp",
        )
        time.sleep(1.0)
        mp.stop()
# ===================================================================
# safe_thread_map Tests
# ===================================================================


class TestSafeThreadMap:
    def test_empty_input(self) -> None:
        assert safe_thread_map([], lambda x: x) == []

    def test_all_succeed(self) -> None:
        result = safe_thread_map([1, 2, 3], lambda x: x * 2)
        assert result == [2, 4, 6]

    def test_preserves_order(self) -> None:
        # Later items finish first (longer sleep for smaller x), so this
        # verifies results are returned in input order, not completion order.
        def slow(x: int) -> int:
            time.sleep(0.01 * (10 - x))
            return x

        result = safe_thread_map(list(range(10)), slow)
        assert result == list(range(10))

    def test_all_fail_raises_exception_group(self) -> None:
        def fail(x: int) -> int:
            raise ValueError(f"fail-{x}")

        with pytest.raises(ExceptionGroup) as exc_info:
            safe_thread_map([1, 2, 3], fail)
        assert len(exc_info.value.exceptions) == 3

    def test_partial_failure(self) -> None:
        def maybe_fail(x: int) -> int:
            if x == 2:
                raise ValueError("fail")
            return x

        with pytest.raises(ExceptionGroup) as exc_info:
            safe_thread_map([1, 2, 3], maybe_fail)
        assert len(exc_info.value.exceptions) == 1

    def test_on_errors_callback(self) -> None:
        def fail(x: int) -> int:
            if x == 2:
                raise ValueError("boom")
            return x * 10

        cleanup_called = False

        def on_errors(outcomes, successes, errors):
            nonlocal cleanup_called
            cleanup_called = True
            assert len(errors) == 1
            assert len(successes) == 2
            return successes  # return successful results

        result = safe_thread_map([1, 2, 3], fail, on_errors)
        assert cleanup_called
        assert sorted(result) == [10, 30]

    def test_on_errors_can_raise(self) -> None:
        def fail(x: int) -> int:
            raise ValueError("boom")

        def on_errors(outcomes, successes, errors):
            raise RuntimeError("custom error")

        with pytest.raises(RuntimeError, match="custom error"):
            safe_thread_map([1], fail, on_errors)

    def test_waits_for_all_before_raising(self) -> None:
        """Even if one fails fast, all others should complete."""
        completed = []

        def work(x: int) -> int:
            if x == 0:
                raise ValueError("fast fail")
            time.sleep(0.2)
            completed.append(x)
            return x

        with pytest.raises(ExceptionGroup):
            safe_thread_map([0, 1, 2, 3], work)
        # All non-failing items should have completed
        assert sorted(completed) == [1, 2, 3]


# ===================================================================
# Integration: ModuleProcess on_exit -> dispose chain (the CI bug scenario)
# ===================================================================


class TestModuleProcessDisposeChain:
    """Tests the exact pattern that caused the CI bug:
    process exits -> watchdog fires on_exit -> module.stop() -> dispose ->
    ModuleProcess.stop() -> tries to stop watchdog from inside watchdog thread.
    """

    def test_chain_no_deadlock_fast_exit(self) -> None:
        """Process exits immediately."""
        for _ in range(20):
            mod = FakeModule()
            done = threading.Event()

            def fake_stop() -> None:
                mod.dispose()
                done.set()

            ModuleProcess(
                module=mod,
                args=[PYTHON, "-c", "pass"],
                on_exit=fake_stop,
            )
            assert done.wait(timeout=5), "Deadlock in dispose chain (fast exit)"

    def test_chain_no_deadlock_slow_exit(self) -> None:
        """Process runs briefly then exits."""
        for _ in range(10):
            mod = FakeModule()
            done = threading.Event()

            def fake_stop() -> None:
                mod.dispose()
                done.set()

            ModuleProcess(
                module=mod,
                args=[PYTHON, "-c", "import time; time.sleep(0.1)"],
                on_exit=fake_stop,
            )
            assert done.wait(timeout=5), "Deadlock in dispose chain (slow exit)"

    def test_chain_concurrent_with_external_stop(self) -> None:
        """Process exits naturally while external code calls stop()."""
        for _ in range(20):
            mod = FakeModule()
            done = threading.Event()

            def fake_stop() -> None:
                mod.dispose()
                done.set()

            mp = ModuleProcess(
                module=mod,
                args=[PYTHON, "-c", "import time; time.sleep(0.05)"],
                on_exit=fake_stop,
                shutdown_timeout=1.0,
            )
            # Race: the process might exit naturally or we might stop it
            time.sleep(0.03)
            mp.stop()
            # Either way, should not deadlock
            time.sleep(1.0)

    def test_dispose_with_artificial_delay(self) -> None:
        """Add artificial delay near cleanup to simulate heavy CPU load."""
        original_stop = ModuleThread.stop

        def slow_stop(self_mt: ModuleThread) -> None:
            time.sleep(0.05)  # simulate load
            original_stop(self_mt)

        for _ in range(10):
            mod = FakeModule()
            done = threading.Event()

            def fake_stop() -> None:
                mod.dispose()
                done.set()

            # Patch stop() to widen the race window between the watchdog's
            # cleanup and the dispose chain.
            with mock.patch.object(ModuleThread, "stop", slow_stop):
                ModuleProcess(
                    module=mod,
                    args=[PYTHON, "-c", "pass"],
                    on_exit=fake_stop,
                )
            assert done.wait(timeout=10), "Deadlock with slow ModuleThread.stop()"


# We need ExceptionGroup for safe_thread_map tests
# (builtin on Python 3.11+; otherwise fall back to the project's backport —
# runs at import time, so the name is bound before any test executes)
try:
    ExceptionGroup
except NameError:
    from dimos.utils.typing_utils import ExceptionGroup
+ +"""Thread utilities: safe values, managed threads, safe parallel map.""" + +from __future__ import annotations + +import asyncio +import json +import signal +import subprocess +import threading +from concurrent.futures import Future, ThreadPoolExecutor, as_completed +from typing import IO, TYPE_CHECKING, Any, Generic + +from reactivex.disposable import Disposable + +from dimos.utils.logging_config import setup_logger +from dimos.utils.typing_utils import ExceptionGroup, TypeVar + +logger = setup_logger() + +if TYPE_CHECKING: + from collections.abc import Callable, Sequence + + from dimos.core.module import ModuleBase + +T = TypeVar("T") +R = TypeVar("R") + + +# --------------------------------------------------------------------------- +# ThreadSafeVal: a lock-protected value with context-manager support +# --------------------------------------------------------------------------- + + +class ThreadSafeVal(Generic[T]): + """A thread-safe value wrapper. + + Wraps any value with a lock and provides atomic read-modify-write + via a context manager:: + + counter = ThreadSafeVal(0) + + # Simple get/set (each acquires the lock briefly): + counter.set(10) + print(counter.get()) # 10 + + # Atomic read-modify-write: + with counter as value: + # Lock is held for the entire block. + # Other threads block on get/set/with until this exits. 
+ if value < 100: + counter.set(value + 1) + + # Works with any type: + status = ThreadSafeVal({"running": False, "count": 0}) + with status as s: + status.set({**s, "running": True}) + + # Bool check (for flag-like usage): + stopping = ThreadSafeVal(False) + stopping.set(True) + if stopping: + print("stopping!") + """ + + def __init__(self, initial: T) -> None: + self._lock = threading.RLock() + self._value = initial + + def get(self) -> T: + """Return the current value (acquires the lock briefly).""" + with self._lock: + return self._value + + def set(self, value: T) -> None: + """Replace the value (acquires the lock briefly).""" + with self._lock: + self._value = value + + def __bool__(self) -> bool: + with self._lock: + return bool(self._value) + + def __enter__(self) -> T: + self._lock.acquire() + return self._value + + def __exit__(self, *exc: object) -> None: + self._lock.release() + + def __getstate__(self) -> dict[str, Any]: + return {"_value": self._value} + + def __setstate__(self, state: dict[str, Any]) -> None: + self._lock = threading.RLock() + self._value = state["_value"] + + def __repr__(self) -> str: + return f"ThreadSafeVal({self._value!r})" + + +# --------------------------------------------------------------------------- +# ModuleThread: a thread that auto-registers with a module's disposables +# --------------------------------------------------------------------------- + + +class ModuleThread: + """A thread that registers cleanup with a module's disposables. + + Passes most kwargs through to ``threading.Thread``. On construction, + registers a disposable with the module so that when the module stops, + the thread is automatically joined. Cleanup is idempotent — safe to + call ``stop()`` manually even if the module also disposes it. 
+ + Example:: + + class MyModule(Module): + @rpc + def start(self) -> None: + self._worker = ModuleThread( + module=self, + target=self._run_loop, + name="my-worker", + ) + + def _run_loop(self) -> None: + while not self._worker.stopping: + do_work() + """ + + def __init__( + self, + module: ModuleBase, + *, + start: bool = True, + close_timeout: float = 2.0, + **thread_kwargs: Any, + ) -> None: + thread_kwargs.setdefault("daemon", True) + self._thread = threading.Thread(**thread_kwargs) + self._stop_event = threading.Event() + self._close_timeout = close_timeout + self._stopped = False + self._stop_lock = threading.Lock() + module._disposables.add(Disposable(self.stop)) + if start: + self.start() + + @property + def stopping(self) -> bool: + """True after ``stop()`` has been called.""" + return self._stop_event.is_set() + + def start(self) -> None: + """Start the underlying thread.""" + self._stop_event.clear() + self._thread.start() + + def stop(self) -> None: + """Signal the thread to stop and join it. + + Safe to call multiple times, from any thread (including the + managed thread itself — it will skip the join in that case). + """ + with self._stop_lock: + if self._stopped: + return + self._stopped = True + + self._stop_event.set() + if self._thread.is_alive() and self._thread is not threading.current_thread(): + self._thread.join(timeout=self._close_timeout) + + def join(self, timeout: float | None = None) -> None: + """Join the underlying thread.""" + self._thread.join(timeout=timeout) + + @property + def is_alive(self) -> bool: + return self._thread.is_alive() + + +# --------------------------------------------------------------------------- +# AsyncModuleThread: a thread running an asyncio event loop, auto-registered +# --------------------------------------------------------------------------- + + +class AsyncModuleThread: + """A thread running an asyncio event loop, registered with a module's disposables. 
+ + If a loop is already running in the current context, reuses it (no thread + created). Otherwise creates a new loop and drives it in a daemon thread. + + On stop (or module dispose), the loop is shut down gracefully and the + thread is joined. Idempotent — safe to call ``stop()`` multiple times. + + Example:: + + class MyModule(Module): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._async = AsyncModuleThread(module=self) + + @rpc + def start(self) -> None: + future = asyncio.run_coroutine_threadsafe( + self._do_work(), self._async.loop + ) + + async def _do_work(self) -> None: + ... + """ + + def __init__( + self, + module: ModuleBase, + *, + close_timeout: float = 2.0, + ) -> None: + self._close_timeout = close_timeout + self._stopped = False + self._stop_lock = threading.Lock() + self._owns_loop = False + self._thread: threading.Thread | None = None + + try: + self._loop = asyncio.get_running_loop() + except RuntimeError: + self._loop = asyncio.new_event_loop() + asyncio.set_event_loop(self._loop) + self._owns_loop = True + self._thread = threading.Thread( + target=self._loop.run_forever, + daemon=True, + name=f"{type(module).__name__}-event-loop", + ) + self._thread.start() + + module._disposables.add(Disposable(self.stop)) + + @property + def loop(self) -> asyncio.AbstractEventLoop: + """The managed event loop.""" + return self._loop + + @property + def is_alive(self) -> bool: + return self._thread is not None and self._thread.is_alive() + + def stop(self) -> None: + """Stop the event loop and join the thread. + + No-op if the loop was not created by this instance (reused an + existing running loop). Safe to call multiple times. 
+ """ + with self._stop_lock: + if self._stopped: + return + self._stopped = True + + if self._owns_loop and self._loop.is_running(): + self._loop.call_soon_threadsafe(self._loop.stop) + + if self._thread is not None and self._thread.is_alive(): + self._thread.join(timeout=self._close_timeout) + + +# --------------------------------------------------------------------------- +# ModuleProcess: managed subprocess with log piping, auto-registered cleanup +# --------------------------------------------------------------------------- + + +class ModuleProcess: + """A managed subprocess that pipes stdout/stderr through the logger. + + Registers with a module's disposables so the process is automatically + stopped on module teardown. A watchdog thread monitors the process and + calls ``on_exit`` if the process exits on its own (i.e. not via + ``ModuleProcess.stop()``). + + Most constructor kwargs mirror ``subprocess.Popen``. ``stdout`` and + ``stderr`` are always captured (set to ``PIPE`` internally). 
+ + Example:: + + class MyModule(Module): + @rpc + def start(self) -> None: + self._proc = ModuleProcess( + module=self, + args=["./my_binary", "--flag"], + cwd="/opt/bin", + on_exit=self.stop, # stops the whole module if process exits on its own + ) + + @rpc + def stop(self) -> None: + # ModuleProcess.stop() is also called automatically via disposables + super().stop() + """ + + def __init__( + self, + module: ModuleBase, + args: list[str] | str, + *, + env: dict[str, str] | None = None, + cwd: str | None = None, + shell: bool = False, + on_exit: Callable[[], Any] | None = None, + shutdown_timeout: float = 10.0, + kill_timeout: float = 5.0, + log_json: bool = False, + start: bool = True, + **popen_kwargs: Any, + ) -> None: + self._args = args + self._env = env + self._cwd = cwd + self._shell = shell + self._on_exit = on_exit + self._shutdown_timeout = shutdown_timeout + self._kill_timeout = kill_timeout + self._log_json = log_json + self._popen_kwargs = popen_kwargs + self._process: subprocess.Popen[bytes] | None = None + self._watchdog: ModuleThread | None = None + self._module = module + self._stopped = False + self._stop_lock = threading.Lock() + + module._disposables.add(Disposable(self.stop)) + if start: + self.start() + + @property + def pid(self) -> int | None: + return self._process.pid if self._process is not None else None + + @property + def returncode(self) -> int | None: + if self._process is None: + return None + return self._process.poll() + + @property + def is_alive(self) -> bool: + return self._process is not None and self._process.poll() is None + + def start(self) -> None: + """Launch the subprocess and start the watchdog.""" + if self._process is not None and self._process.poll() is None: + logger.warning("Process already running", pid=self._process.pid) + return + + with self._stop_lock: + self._stopped = False + + logger.info( + "Starting process", + cmd=self._args if isinstance(self._args, str) else " ".join(self._args), + cwd=self._cwd, + 
) + self._process = subprocess.Popen( + self._args, + env=self._env, + cwd=self._cwd, + shell=self._shell, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + **self._popen_kwargs, + ) + logger.info("Process started", pid=self._process.pid) + + self._watchdog = ModuleThread( + module=self._module, + target=self._watch, + name=f"proc-{self._process.pid}-watchdog", + ) + + def stop(self) -> None: + """Send SIGTERM, wait, escalate to SIGKILL if needed. Idempotent.""" + with self._stop_lock: + if self._stopped: + return + self._stopped = True + + if self._process is not None and self._process.poll() is None: + logger.info("Stopping process", pid=self._process.pid) + try: + self._process.send_signal(signal.SIGTERM) + except OSError: + pass # process already dead (PID recycled or exited between poll and signal) + else: + try: + self._process.wait(timeout=self._shutdown_timeout) + except subprocess.TimeoutExpired: + logger.warning( + "Process did not exit, sending SIGKILL", + pid=self._process.pid, + ) + self._process.kill() + try: + self._process.wait(timeout=self._kill_timeout) + except subprocess.TimeoutExpired: + logger.error( + "Process did not exit after SIGKILL", + pid=self._process.pid, + ) + self._process = None + + def _watch(self) -> None: + """Watchdog: pipe logs, detect crashes.""" + proc = self._process + if proc is None: + return + + stdout_t = self._start_reader(proc.stdout, "info") + stderr_t = self._start_reader(proc.stderr, "warning") + rc = proc.wait() + stdout_t.join(timeout=2) + stderr_t.join(timeout=2) + + with self._stop_lock: + if self._stopped: + return + + logger.error("Process died unexpectedly", pid=proc.pid, returncode=rc) + if self._on_exit is not None: + self._on_exit() + + def _start_reader(self, stream: IO[bytes] | None, level: str) -> threading.Thread: + t = threading.Thread(target=self._read_stream, args=(stream, level), daemon=True) + t.start() + return t + + def _read_stream(self, stream: IO[bytes] | None, level: str) -> None: + if 
stream is None: + return + log_fn = getattr(logger, level) + for raw in stream: + line = raw.decode("utf-8", errors="replace").rstrip() + if not line: + continue + if self._log_json: + try: + data = json.loads(line) + event = data.pop("event", line) + log_fn(event, **data) + continue + except (json.JSONDecodeError, TypeError): + logger.warning("malformed JSON from process", raw=line) + proc = self._process + log_fn(line, pid=proc.pid if proc else None) + stream.close() + + +# --------------------------------------------------------------------------- +# safe_thread_map: parallel map that collects all results before raising +# --------------------------------------------------------------------------- + + +def safe_thread_map( + items: Sequence[T], + fn: Callable[[T], R], + on_errors: Callable[[list[tuple[T, R | Exception]], list[R], list[Exception]], Any] + | None = None, +) -> list[R]: + """Thread-pool map that waits for all items to finish before raising and a cleanup handler + + - Empty *items* → returns ``[]`` immediately. + - All succeed → returns results in input order. + - Any fail → calls ``on_errors(outcomes, successes, errors)`` where + *outcomes* is a list of ``(input, result_or_exception)`` pairs in input + order, *successes* is the list of successful results, and *errors* is + the list of exceptions. If *on_errors* raises, that exception propagates. + If *on_errors* returns normally, its return value is returned from + ``safe_thread_map``. If *on_errors* is ``None``, raises an + ``ExceptionGroup``. 
+ + Example:: + + def start_service(name: str) -> Connection: + return connect(name) + + def cleanup( + outcomes: list[tuple[str, Connection | Exception]], + successes: list[Connection], + errors: list[Exception], + ) -> None: + for conn in successes: + conn.close() + raise ExceptionGroup("failed to start services", errors) + + connections = safe_thread_map( + ["db", "cache", "queue"], + start_service, + cleanup, # called only if any start_service() raises + ) + """ + if not items: + return [] + + outcomes: dict[int, R | Exception] = {} + + with ThreadPoolExecutor(max_workers=len(items)) as pool: + futures: dict[Future[R], int] = {pool.submit(fn, item): i for i, item in enumerate(items)} + for fut in as_completed(futures): + idx = futures[fut] + try: + outcomes[idx] = fut.result() + except Exception as e: + outcomes[idx] = e + + successes: list[R] = [] + errors: list[Exception] = [] + for v in outcomes.values(): + if isinstance(v, Exception): + errors.append(v) + else: + successes.append(v) + + if errors: + if on_errors is not None: + zipped = [(items[i], outcomes[i]) for i in range(len(items))] + return on_errors(zipped, successes, errors) # type: ignore[return-value, no-any-return] + raise ExceptionGroup("safe_thread_map failed", errors) + + return [outcomes[i] for i in range(len(items))] # type: ignore[misc] diff --git a/dimos/utils/typing_utils.py b/dimos/utils/typing_utils.py new file mode 100644 index 0000000000..aa32fff47f --- /dev/null +++ b/dimos/utils/typing_utils.py @@ -0,0 +1,45 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unify typing compatibility across multiple Python versions.""" + +from __future__ import annotations + +import sys +from collections.abc import Sequence + +if sys.version_info < (3, 13): + from typing_extensions import TypeVar +else: + from typing import TypeVar + +if sys.version_info < (3, 11): + + class ExceptionGroup(Exception): # type: ignore[no-redef] # noqa: N818 + """Minimal ExceptionGroup polyfill for Python 3.10.""" + + exceptions: tuple[BaseException, ...] + + def __init__(self, message: str, exceptions: Sequence[BaseException]) -> None: + super().__init__(message) + self.exceptions = tuple(exceptions) +else: + import builtins + + ExceptionGroup = builtins.ExceptionGroup # type: ignore[misc] + +__all__ = [ + "ExceptionGroup", + "TypeVar", +] From b8bb6c0b0db3faf4d50573ea7e1b540c3eb566b1 Mon Sep 17 00:00:00 2001 From: Jeff Hykin Date: Tue, 24 Mar 2026 23:04:44 -0700 Subject: [PATCH 11/22] improve tests --- dimos/core/test_native_module.py | 16 +++++++++------- dimos/utils/test_thread_utils.py | 5 ++++- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/dimos/core/test_native_module.py b/dimos/core/test_native_module.py index 361b45fdd1..7005498ebb 100644 --- a/dimos/core/test_native_module.py +++ b/dimos/core/test_native_module.py @@ -96,19 +96,21 @@ def test_process_crash_triggers_stop() -> None: mod.pointcloud.transport = LCMTransport("/pc", PointCloud2) mod.start() - assert mod._process is not None - pid = mod._process.pid + assert mod._proc is not None + assert mod._proc.is_alive + pid = mod._proc.pid - # Wait for the process 
to die and the watchdog to call stop() + # Wait for the process to die and the on_exit callback to call stop() for _ in range(30): time.sleep(0.1) - if mod._process is None: + if mod._proc is None or not mod._proc.is_alive: break - assert mod._process is None, f"Watchdog did not clean up after process {pid} died" + assert mod._proc is None or not mod._proc.is_alive, ( + f"Watchdog did not clean up after process {pid} died" + ) - # Join the watchdog thread. stop() is idempotent but will now join the - # watchdog on the second call since the reference is preserved. + # stop() is idempotent mod.stop() # Wait for background threads (run_forever, _lcm_loop, _watch_process) to finish diff --git a/dimos/utils/test_thread_utils.py b/dimos/utils/test_thread_utils.py index 87c4a883e2..28152d9242 100644 --- a/dimos/utils/test_thread_utils.py +++ b/dimos/utils/test_thread_utils.py @@ -357,9 +357,10 @@ def stop_it() -> None: def test_close_timeout_respected(self) -> None: """If the thread ignores the stop signal, stop() should return after close_timeout.""" mod = FakeModule() + bail = threading.Event() def stubborn_target() -> None: - time.sleep(10) # ignores stopping signal + bail.wait(timeout=10) # ignores stopping signal, but we can bail it out mt = ModuleThread( module=mod, target=stubborn_target, name="test-timeout", close_timeout=0.2 @@ -368,6 +369,8 @@ def stubborn_target() -> None: mt.stop() elapsed = time.monotonic() - start assert elapsed < 1.0, f"stop() took {elapsed}s, expected ~0.2s" + bail.set() # let the thread exit so conftest thread-leak detector is happy + mt.join(timeout=2) def test_stop_concurrent_with_dispose(self) -> None: """Calling stop() and dispose() concurrently should not crash.""" From 44d1b1d87074bc7d714373f7830d14c65f61d589 Mon Sep 17 00:00:00 2001 From: Jeff Hykin Date: Tue, 24 Mar 2026 22:43:50 -0700 Subject: [PATCH 12/22] fully ideal approach, untested --- dimos/utils/test_thread_utils.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) 
diff --git a/dimos/utils/test_thread_utils.py b/dimos/utils/test_thread_utils.py index 28152d9242..87c4a883e2 100644 --- a/dimos/utils/test_thread_utils.py +++ b/dimos/utils/test_thread_utils.py @@ -357,10 +357,9 @@ def stop_it() -> None: def test_close_timeout_respected(self) -> None: """If the thread ignores the stop signal, stop() should return after close_timeout.""" mod = FakeModule() - bail = threading.Event() def stubborn_target() -> None: - bail.wait(timeout=10) # ignores stopping signal, but we can bail it out + time.sleep(10) # ignores stopping signal mt = ModuleThread( module=mod, target=stubborn_target, name="test-timeout", close_timeout=0.2 @@ -369,8 +368,6 @@ def stubborn_target() -> None: mt.stop() elapsed = time.monotonic() - start assert elapsed < 1.0, f"stop() took {elapsed}s, expected ~0.2s" - bail.set() # let the thread exit so conftest thread-leak detector is happy - mt.join(timeout=2) def test_stop_concurrent_with_dispose(self) -> None: """Calling stop() and dispose() concurrently should not crash.""" From 3a4d0c2800baab775da78b8cd88deba1b9622426 Mon Sep 17 00:00:00 2001 From: Jeff Hykin Date: Tue, 24 Mar 2026 23:28:16 -0700 Subject: [PATCH 13/22] formatting --- dimos/core/module.py | 8 +-- dimos/utils/test_thread_utils.py | 108 ++++++++++++++----------------- dimos/utils/thread_utils.py | 25 +++---- dimos/utils/typing_utils.py | 2 +- 4 files changed, 66 insertions(+), 77 deletions(-) diff --git a/dimos/core/module.py b/dimos/core/module.py index d9f8356f38..af733a19b7 100644 --- a/dimos/core/module.py +++ b/dimos/core/module.py @@ -11,14 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import asyncio from collections.abc import Callable from dataclasses import dataclass from functools import partial import inspect import json import sys -import threading from typing import ( TYPE_CHECKING, Any, @@ -100,8 +98,10 @@ def __init__(self, config_args: dict[str, Any]): super().__init__(**config_args) self._disposables = CompositeDisposable() self.mod_state = ThreadSafeVal[ModState]("init") - self._async_thread = AsyncModuleThread( # NEEDS to be created after self._disposables exists - module=self + self._async_thread = ( + AsyncModuleThread( # NEEDS to be created after self._disposables exists + module=self + ) ) try: self.rpc = self.config.rpc_transport() diff --git a/dimos/utils/test_thread_utils.py b/dimos/utils/test_thread_utils.py index 87c4a883e2..782d5161be 100644 --- a/dimos/utils/test_thread_utils.py +++ b/dimos/utils/test_thread_utils.py @@ -1,3 +1,17 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Exhaustive tests for dimos/utils/thread_utils.py Covers: ThreadSafeVal, ModuleThread, AsyncModuleThread, ModuleProcess, safe_thread_map. 
@@ -9,8 +23,6 @@ import asyncio import os import pickle -import signal -import subprocess import sys import threading import time @@ -27,10 +39,7 @@ safe_thread_map, ) - -# --------------------------------------------------------------------------- # Helpers: fake ModuleBase for testing ModuleThread / AsyncModuleThread / ModuleProcess -# --------------------------------------------------------------------------- class FakeModule: @@ -43,9 +52,7 @@ def dispose(self) -> None: self._disposables.dispose() -# =================================================================== # ThreadSafeVal Tests -# =================================================================== class TestThreadSafeVal: @@ -97,7 +104,7 @@ def test_get_inside_context_manager_no_deadlock(self) -> None: result = threading.Event() def do_it() -> None: - with v as val: + with v: _ = v.get() result.set() @@ -111,7 +118,7 @@ def test_bool_inside_context_manager_no_deadlock(self) -> None: result = threading.Event() def do_it() -> None: - with v as val: + with v: _ = bool(v) result.set() @@ -128,7 +135,7 @@ def test_context_manager_blocks_other_threads(self) -> None: other_finished = threading.Event() def holder() -> None: - with v as val: + with v: gate.wait(timeout=5) # hold the lock until signaled def setter() -> None: @@ -215,7 +222,7 @@ def test_nested_with_no_deadlock(self) -> None: result = threading.Event() def do_it() -> None: - with v as val1: + with v: with v as val2: v.set(val2 + 1) result.set() @@ -226,9 +233,7 @@ def do_it() -> None: assert result.is_set(), "Nested with blocks deadlocked!" 
-# =================================================================== # ModuleThread Tests -# =================================================================== class TestModuleThread: @@ -375,8 +380,8 @@ def test_stop_concurrent_with_dispose(self) -> None: mod = FakeModule() holder: list[ModuleThread] = [] - def target() -> None: - while not holder[0].stopping: + def target(h: list[ModuleThread] = holder) -> None: + while not h[0].stopping: time.sleep(0.001) mt = ModuleThread(module=mod, target=target, name="test-stop-dispose", start=False) @@ -392,9 +397,7 @@ def target() -> None: t2.join(timeout=3) -# =================================================================== # AsyncModuleThread Tests -# =================================================================== class TestAsyncModuleThread: @@ -472,9 +475,7 @@ def stop_it() -> None: assert not errors -# =================================================================== # ModuleProcess Tests -# =================================================================== # Helper: path to a python that sleeps or echoes @@ -513,7 +514,6 @@ def test_dispose_stops_process(self) -> None: args=[PYTHON, "-c", "import time; time.sleep(30)"], shutdown_timeout=2.0, ) - pid = mp.pid mod.dispose() time.sleep(0.5) assert not mp.is_alive @@ -523,7 +523,7 @@ def test_on_exit_fires_on_natural_exit(self) -> None: mod = FakeModule() exit_called = threading.Event() - mp = ModuleProcess( + ModuleProcess( module=mod, args=[PYTHON, "-c", "print('done')"], on_exit=exit_called.set, @@ -535,7 +535,7 @@ def test_on_exit_fires_on_crash(self) -> None: mod = FakeModule() exit_called = threading.Event() - mp = ModuleProcess( + ModuleProcess( module=mod, args=[PYTHON, "-c", "import sys; sys.exit(1)"], on_exit=exit_called.set, @@ -581,7 +581,11 @@ def test_log_json_mode(self) -> None: mod = FakeModule() mp = ModuleProcess( module=mod, - args=[PYTHON, "-c", """import json; print(json.dumps({"event": "test", "key": "val"}))"""], + args=[ + PYTHON, + 
"-c", + """import json; print(json.dumps({"event": "test", "key": "val"}))""", + ], log_json=True, ) time.sleep(1.0) @@ -663,7 +667,7 @@ def fake_module_stop() -> None: mod.dispose() stop_called.set() - mp = ModuleProcess( + ModuleProcess( module=mod, args=[PYTHON, "-c", "pass"], # exits immediately on_exit=fake_module_stop, @@ -673,7 +677,7 @@ def fake_module_stop() -> None: def test_on_exit_calls_module_stop_no_deadlock_stress(self) -> None: """Run the deadlock test multiple times under load.""" - for i in range(10): + for _i in range(10): self.test_on_exit_calls_module_stop_no_deadlock() def test_deferred_start(self) -> None: @@ -692,9 +696,13 @@ def test_env_passed(self) -> None: mod = FakeModule() exit_called = threading.Event() - mp = ModuleProcess( + ModuleProcess( module=mod, - args=[PYTHON, "-c", "import os, sys; sys.exit(0 if os.environ.get('MY_VAR') == '42' else 1)"], + args=[ + PYTHON, + "-c", + "import os, sys; sys.exit(0 if os.environ.get('MY_VAR') == '42' else 1)", + ], env={**os.environ, "MY_VAR": "42"}, on_exit=exit_called.set, ) @@ -713,9 +721,7 @@ def test_cwd_passed(self) -> None: mp.stop() -# =================================================================== # safe_thread_map Tests -# =================================================================== class TestSafeThreadMap: @@ -798,9 +804,7 @@ def work(x: int) -> int: assert sorted(completed) == [1, 2, 3] -# =================================================================== # Integration: ModuleProcess on_exit -> dispose chain (the CI bug scenario) -# =================================================================== class TestModuleProcessDisposeChain: @@ -809,20 +813,23 @@ class TestModuleProcessDisposeChain: ModuleProcess.stop() -> tries to stop watchdog from inside watchdog thread. 
""" + @staticmethod + def _make_fake_stop(mod: FakeModule, done: threading.Event) -> Callable: + def fake_stop() -> None: + mod.dispose() + done.set() + + return fake_stop + def test_chain_no_deadlock_fast_exit(self) -> None: """Process exits immediately.""" for _ in range(20): mod = FakeModule() done = threading.Event() - - def fake_stop() -> None: - mod.dispose() - done.set() - ModuleProcess( module=mod, args=[PYTHON, "-c", "pass"], - on_exit=fake_stop, + on_exit=self._make_fake_stop(mod, done), ) assert done.wait(timeout=5), "Deadlock in dispose chain (fast exit)" @@ -831,15 +838,10 @@ def test_chain_no_deadlock_slow_exit(self) -> None: for _ in range(10): mod = FakeModule() done = threading.Event() - - def fake_stop() -> None: - mod.dispose() - done.set() - ModuleProcess( module=mod, args=[PYTHON, "-c", "import time; time.sleep(0.1)"], - on_exit=fake_stop, + on_exit=self._make_fake_stop(mod, done), ) assert done.wait(timeout=5), "Deadlock in dispose chain (slow exit)" @@ -848,15 +850,10 @@ def test_chain_concurrent_with_external_stop(self) -> None: for _ in range(20): mod = FakeModule() done = threading.Event() - - def fake_stop() -> None: - mod.dispose() - done.set() - mp = ModuleProcess( module=mod, args=[PYTHON, "-c", "import time; time.sleep(0.05)"], - on_exit=fake_stop, + on_exit=self._make_fake_stop(mod, done), shutdown_timeout=1.0, ) # Race: the process might exit naturally or we might stop it @@ -876,22 +873,13 @@ def slow_stop(self_mt: ModuleThread) -> None: for _ in range(10): mod = FakeModule() done = threading.Event() - - def fake_stop() -> None: - mod.dispose() - done.set() - with mock.patch.object(ModuleThread, "stop", slow_stop): ModuleProcess( module=mod, args=[PYTHON, "-c", "pass"], - on_exit=fake_stop, + on_exit=self._make_fake_stop(mod, done), ) assert done.wait(timeout=10), "Deadlock with slow ModuleThread.stop()" -# We need ExceptionGroup for safe_thread_map tests -try: - ExceptionGroup -except NameError: - from dimos.utils.typing_utils 
import ExceptionGroup +from dimos.utils.typing_utils import ExceptionGroup diff --git a/dimos/utils/thread_utils.py b/dimos/utils/thread_utils.py index e83047ca5b..3a53386c50 100644 --- a/dimos/utils/thread_utils.py +++ b/dimos/utils/thread_utils.py @@ -17,11 +17,12 @@ from __future__ import annotations import asyncio +import collections +from concurrent.futures import Future, ThreadPoolExecutor, as_completed import json import signal import subprocess import threading -from concurrent.futures import Future, ThreadPoolExecutor, as_completed from typing import IO, TYPE_CHECKING, Any, Generic from reactivex.disposable import Disposable @@ -40,9 +41,7 @@ R = TypeVar("R") -# --------------------------------------------------------------------------- # ThreadSafeVal: a lock-protected value with context-manager support -# --------------------------------------------------------------------------- class ThreadSafeVal(Generic[T]): @@ -112,9 +111,7 @@ def __repr__(self) -> str: return f"ThreadSafeVal({self._value!r})" -# --------------------------------------------------------------------------- # ModuleThread: a thread that auto-registers with a module's disposables -# --------------------------------------------------------------------------- class ModuleThread: @@ -193,9 +190,7 @@ def is_alive(self) -> bool: return self._thread.is_alive() -# --------------------------------------------------------------------------- # AsyncModuleThread: a thread running an asyncio event loop, auto-registered -# --------------------------------------------------------------------------- class AsyncModuleThread: @@ -278,9 +273,7 @@ def stop(self) -> None: self._thread.join(timeout=self._close_timeout) -# --------------------------------------------------------------------------- # ModuleProcess: managed subprocess with log piping, auto-registered cleanup -# --------------------------------------------------------------------------- class ModuleProcess: @@ -341,6 +334,7 @@ def __init__( 
self._module = module self._stopped = False self._stop_lock = threading.Lock() + self.last_stderr: collections.deque[str] = collections.deque(maxlen=50) module._disposables.add(Disposable(self.stop)) if start: @@ -438,7 +432,13 @@ def _watch(self) -> None: if self._stopped: return - logger.error("Process died unexpectedly", pid=proc.pid, returncode=rc) + last_stderr = "\n".join(self.last_stderr) + logger.error( + "Process died unexpectedly", + pid=proc.pid, + returncode=rc, + last_stderr=last_stderr[:500] if last_stderr else None, + ) if self._on_exit is not None: self._on_exit() @@ -451,10 +451,13 @@ def _read_stream(self, stream: IO[bytes] | None, level: str) -> None: if stream is None: return log_fn = getattr(logger, level) + is_stderr = level == "warning" for raw in stream: line = raw.decode("utf-8", errors="replace").rstrip() if not line: continue + if is_stderr: + self.last_stderr.append(line) if self._log_json: try: data = json.loads(line) @@ -468,9 +471,7 @@ def _read_stream(self, stream: IO[bytes] | None, level: str) -> None: stream.close() -# --------------------------------------------------------------------------- # safe_thread_map: parallel map that collects all results before raising -# --------------------------------------------------------------------------- def safe_thread_map( diff --git a/dimos/utils/typing_utils.py b/dimos/utils/typing_utils.py index aa32fff47f..3592d5fdbb 100644 --- a/dimos/utils/typing_utils.py +++ b/dimos/utils/typing_utils.py @@ -16,8 +16,8 @@ from __future__ import annotations -import sys from collections.abc import Sequence +import sys if sys.version_info < (3, 13): from typing_extensions import TypeVar From a4aba11920710f9431c66e7c2813490e18f93b01 Mon Sep 17 00:00:00 2001 From: Jeff Hykin Date: Tue, 24 Mar 2026 23:32:54 -0700 Subject: [PATCH 14/22] misc improve --- dimos/utils/thread_utils.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/dimos/utils/thread_utils.py 
b/dimos/utils/thread_utils.py index 3a53386c50..e9844326b5 100644 --- a/dimos/utils/thread_utils.py +++ b/dimos/utils/thread_utils.py @@ -140,7 +140,7 @@ def _run_loop(self) -> None: def __init__( self, - module: ModuleBase, + module: ModuleBase[Any], *, start: bool = True, close_timeout: float = 2.0, @@ -221,7 +221,7 @@ async def _do_work(self) -> None: def __init__( self, - module: ModuleBase, + module: ModuleBase[Any], *, close_timeout: float = 2.0, ) -> None: @@ -307,7 +307,7 @@ def stop(self) -> None: def __init__( self, - module: ModuleBase, + module: ModuleBase[Any], args: list[str] | str, *, env: dict[str, str] | None = None, @@ -317,6 +317,7 @@ def __init__( shutdown_timeout: float = 10.0, kill_timeout: float = 5.0, log_json: bool = False, + log_tail_lines: int = 50, start: bool = True, **popen_kwargs: Any, ) -> None: @@ -328,13 +329,15 @@ def __init__( self._shutdown_timeout = shutdown_timeout self._kill_timeout = kill_timeout self._log_json = log_json + self._log_tail_lines = log_tail_lines self._popen_kwargs = popen_kwargs self._process: subprocess.Popen[bytes] | None = None self._watchdog: ModuleThread | None = None self._module = module self._stopped = False self._stop_lock = threading.Lock() - self.last_stderr: collections.deque[str] = collections.deque(maxlen=50) + self.last_stdout: collections.deque[str] = collections.deque(maxlen=log_tail_lines) + self.last_stderr: collections.deque[str] = collections.deque(maxlen=log_tail_lines) module._disposables.add(Disposable(self.stop)) if start: @@ -363,6 +366,9 @@ def start(self) -> None: with self._stop_lock: self._stopped = False + self.last_stdout = collections.deque(maxlen=self._log_tail_lines) + self.last_stderr = collections.deque(maxlen=self._log_tail_lines) + logger.info( "Starting process", cmd=self._args if isinstance(self._args, str) else " ".join(self._args), @@ -432,12 +438,14 @@ def _watch(self) -> None: if self._stopped: return - last_stderr = "\n".join(self.last_stderr) + last_stdout = 
"\n".join(self.last_stdout) or None + last_stderr = "\n".join(self.last_stderr) or None logger.error( "Process died unexpectedly", pid=proc.pid, returncode=rc, - last_stderr=last_stderr[:500] if last_stderr else None, + last_stdout=last_stdout, + last_stderr=last_stderr, ) if self._on_exit is not None: self._on_exit() @@ -452,12 +460,12 @@ def _read_stream(self, stream: IO[bytes] | None, level: str) -> None: return log_fn = getattr(logger, level) is_stderr = level == "warning" + buf = self.last_stderr if is_stderr else self.last_stdout for raw in stream: line = raw.decode("utf-8", errors="replace").rstrip() if not line: continue - if is_stderr: - self.last_stderr.append(line) + buf.append(line) if self._log_json: try: data = json.loads(line) From 57731704d3995f9d25e68b5710b6a3a9d930c854 Mon Sep 17 00:00:00 2001 From: Jeff Hykin Date: Wed, 25 Mar 2026 00:10:43 -0700 Subject: [PATCH 15/22] cleanup --- dimos/agents/mcp/test_mcp_server.py | 59 ++++++++++++++++++++++++++++- dimos/core/module.py | 6 +-- dimos/utils/test_thread_utils.py | 6 +-- dimos/utils/thread_utils.py | 1 - 4 files changed, 64 insertions(+), 8 deletions(-) diff --git a/dimos/agents/mcp/test_mcp_server.py b/dimos/agents/mcp/test_mcp_server.py index 1cbca9e3e4..0424a2b16f 100644 --- a/dimos/agents/mcp/test_mcp_server.py +++ b/dimos/agents/mcp/test_mcp_server.py @@ -16,9 +16,13 @@ import asyncio import json +import socket +import time from unittest.mock import MagicMock -from dimos.agents.mcp.mcp_server import handle_request +import requests + +from dimos.agents.mcp.mcp_server import McpServer, handle_request from dimos.core.module import SkillInfo @@ -111,3 +115,56 @@ def test_mcp_module_initialize_and_unknown() -> None: response = asyncio.run(handle_request({"method": "unknown/method", "id": 2}, [], {})) assert response["error"]["code"] == -32601 + + +def _free_port() -> int: + with socket.socket() as s: + s.bind(("", 0)) + return s.getsockname()[1] + + +def test_mcp_server_lifecycle() -> None: + 
"""Start a real McpServer, hit the HTTP endpoint, then stop it. + + This exercises the AsyncModuleThread event loop integration that the + unit tests above do not cover. + """ + port = _free_port() + + server = McpServer() + server._start_server(port=port) + url = f"http://127.0.0.1:{port}/mcp" + + # Wait for the server to be ready + for _ in range(40): + try: + resp = requests.post( + url, + json={"jsonrpc": "2.0", "method": "initialize", "id": 1}, + timeout=0.5, + ) + if resp.status_code == 200: + break + except requests.ConnectionError: + time.sleep(0.1) + else: + server.stop() + raise AssertionError("McpServer did not become ready") + + # Verify it responds + data = resp.json() + assert data["result"]["serverInfo"]["name"] == "dimensional" + + # Stop and verify it shuts down + server.stop() + time.sleep(0.3) + + with socket.socket() as s: + # Port should be released after stop + try: + s.connect(("127.0.0.1", port)) + s.close() + # If we could connect, the server is still up — that's a bug + raise AssertionError("McpServer still listening after stop()") + except ConnectionRefusedError: + pass # expected — server is down diff --git a/dimos/core/module.py b/dimos/core/module.py index af733a19b7..d96133ec7d 100644 --- a/dimos/core/module.py +++ b/dimos/core/module.py @@ -46,7 +46,7 @@ from dimos.utils.generic import classproperty from dimos.utils.thread_utils import AsyncModuleThread, ThreadSafeVal -ModState = Literal["init", "started", "stopping", "stopped"] +ModState = Literal["init", "started", "stopped"] if TYPE_CHECKING: from dimos.core.blueprints import Blueprint @@ -130,9 +130,9 @@ def stop(self) -> None: def _stop(self) -> None: with self.mod_state as state: - if state in ("stopping", "stopped"): + if state == "stopped": return - self.mod_state.set("stopping") + self.mod_state.set("stopped") if self.rpc: self.rpc.stop() # type: ignore[attr-defined] diff --git a/dimos/utils/test_thread_utils.py b/dimos/utils/test_thread_utils.py index 782d5161be..65987f3431 
100644 --- a/dimos/utils/test_thread_utils.py +++ b/dimos/utils/test_thread_utils.py @@ -210,11 +210,11 @@ def test_string_literal_type(self) -> None: assert v.get() == "started" with v as state: - if state in ("stopping", "stopped"): + if state == "stopped": pass # no-op else: - v.set("stopping") - assert v.get() == "stopping" + v.set("stopped") + assert v.get() == "stopped" def test_nested_with_no_deadlock(self) -> None: """RLock should allow the same thread to nest with blocks.""" diff --git a/dimos/utils/thread_utils.py b/dimos/utils/thread_utils.py index e9844326b5..6d9b7a9e7f 100644 --- a/dimos/utils/thread_utils.py +++ b/dimos/utils/thread_utils.py @@ -301,7 +301,6 @@ def start(self) -> None: @rpc def stop(self) -> None: - # ModuleProcess.stop() is also called automatically via disposables super().stop() """ From f62a0b0579bcee95b330bd8a0d0dfa891e459041 Mon Sep 17 00:00:00 2001 From: Jeff Hykin Date: Wed, 25 Mar 2026 00:09:21 -0700 Subject: [PATCH 16/22] Apply suggestions from code review Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> --- dimos/agents/mcp/mcp_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dimos/agents/mcp/mcp_server.py b/dimos/agents/mcp/mcp_server.py index d09202532c..f858571fac 100644 --- a/dimos/agents/mcp/mcp_server.py +++ b/dimos/agents/mcp/mcp_server.py @@ -181,7 +181,7 @@ def stop(self) -> None: if self._uvicorn_server: self._uvicorn_server.should_exit = True loop = self._async_thread.loop - if loop is not None and self._serve_future is not None: + if self._serve_future is not None: self._serve_future.result(timeout=5.0) self._uvicorn_server = None self._serve_future = None From 72f38c4c76534706bd4f87f001616605353df8ef Mon Sep 17 00:00:00 2001 From: Jeff Hykin Date: Wed, 25 Mar 2026 00:13:01 -0700 Subject: [PATCH 17/22] - --- dimos/core/module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dimos/core/module.py b/dimos/core/module.py index 
d96133ec7d..39800fe2d1 100644 --- a/dimos/core/module.py +++ b/dimos/core/module.py @@ -95,9 +95,9 @@ class ModuleBase(Configurable[ModuleConfigT], Resource): rpc_calls: list[str] = [] def __init__(self, config_args: dict[str, Any]): + self.mod_state = ThreadSafeVal[ModState]("init") super().__init__(**config_args) self._disposables = CompositeDisposable() - self.mod_state = ThreadSafeVal[ModState]("init") self._async_thread = ( AsyncModuleThread( # NEEDS to be created after self._disposables exists module=self From 9836102a18e2c236f659fdd8a680784ee57eacdb Mon Sep 17 00:00:00 2001 From: Jeff Hykin Date: Wed, 25 Mar 2026 00:24:25 -0700 Subject: [PATCH 18/22] fix order of _disposables --- dimos/agents/mcp/mcp_server.py | 1 - dimos/core/module.py | 6 ++++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/dimos/agents/mcp/mcp_server.py b/dimos/agents/mcp/mcp_server.py index f858571fac..1e1d7d9942 100644 --- a/dimos/agents/mcp/mcp_server.py +++ b/dimos/agents/mcp/mcp_server.py @@ -180,7 +180,6 @@ def start(self) -> None: def stop(self) -> None: if self._uvicorn_server: self._uvicorn_server.should_exit = True - loop = self._async_thread.loop if self._serve_future is not None: self._serve_future.result(timeout=5.0) self._uvicorn_server = None diff --git a/dimos/core/module.py b/dimos/core/module.py index 39800fe2d1..4fbffe07b9 100644 --- a/dimos/core/module.py +++ b/dimos/core/module.py @@ -134,6 +134,10 @@ def _stop(self) -> None: return self.mod_state.set("stopped") + # dispose of things BEFORE making aspects like rpc and _tf invalid + if hasattr(self, "_disposables"): + self._disposables.dispose() # stops _async_thread via disposable + if self.rpc: self.rpc.stop() # type: ignore[attr-defined] self.rpc = None # type: ignore[assignment] @@ -141,8 +145,6 @@ def _stop(self) -> None: if hasattr(self, "_tf") and self._tf is not None: self._tf.stop() self._tf = None - if hasattr(self, "_disposables"): - self._disposables.dispose() # stops _async_thread via 
disposable # Break the In/Out -> owner -> self reference cycle so the instance # can be freed by refcount instead of waiting for GC. From a1586f32630894b34d268ccf5e1086ff8a3eb709 Mon Sep 17 00:00:00 2001 From: Jeff Hykin Date: Wed, 25 Mar 2026 23:19:13 -0700 Subject: [PATCH 19/22] pr feedback --- dimos/core/native_module.py | 1 + dimos/utils/test_thread_utils.py | 85 ++++++++++++++-------------- dimos/utils/thread_utils.py | 97 ++++++++++---------------------- 3 files changed, 71 insertions(+), 112 deletions(-) diff --git a/dimos/core/native_module.py b/dimos/core/native_module.py index bdc46e2cab..f2579192ae 100644 --- a/dimos/core/native_module.py +++ b/dimos/core/native_module.py @@ -150,6 +150,7 @@ def start(self) -> None: shutdown_timeout=self.config.shutdown_timeout, log_json=self.config.log_format == LogFormat.JSON, ) + self._proc.start() def _resolve_paths(self) -> None: """Resolve relative ``cwd`` and ``executable`` against the subclass's source file.""" diff --git a/dimos/utils/test_thread_utils.py b/dimos/utils/test_thread_utils.py index 65987f3431..dd799dfd1c 100644 --- a/dimos/utils/test_thread_utils.py +++ b/dimos/utils/test_thread_utils.py @@ -62,18 +62,6 @@ def test_basic_get_set(self) -> None: v.set(99) assert v.get() == 99 - def test_bool_truthy(self) -> None: - v = ThreadSafeVal(True) - assert bool(v) is True - v.set(False) - assert bool(v) is False - - def test_bool_zero(self) -> None: - v = ThreadSafeVal(0) - assert bool(v) is False - v.set(1) - assert bool(v) is True - def test_context_manager_returns_value(self) -> None: v = ThreadSafeVal("hello") with v as val: @@ -113,20 +101,6 @@ def do_it() -> None: t.join(timeout=2) assert result.is_set(), "Deadlocked! 
get() inside with block hung" - def test_bool_inside_context_manager_no_deadlock(self) -> None: - v = ThreadSafeVal(True) - result = threading.Event() - - def do_it() -> None: - with v: - _ = bool(v) - result.set() - - t = threading.Thread(target=do_it) - t.start() - t.join(timeout=2) - assert result.is_set(), "Deadlocked! bool() inside with block hung" - def test_context_manager_blocks_other_threads(self) -> None: """While one thread holds the lock via `with`, others should block on set().""" v = ThreadSafeVal(0) @@ -245,6 +219,7 @@ def target() -> None: ran.set() mt = ModuleThread(module=mod, target=target, name="test-basic") + mt.start() ran.wait(timeout=2) assert ran.is_set() mt.stop() @@ -254,6 +229,7 @@ def test_auto_start(self) -> None: mod = FakeModule() started = threading.Event() mt = ModuleThread(module=mod, target=started.set, name="test-autostart") + mt.start() started.wait(timeout=2) assert started.is_set() mt.stop() @@ -261,7 +237,7 @@ def test_auto_start(self) -> None: def test_deferred_start(self) -> None: mod = FakeModule() started = threading.Event() - mt = ModuleThread(module=mod, target=started.set, name="test-deferred", start=False) + mt = ModuleThread(module=mod, target=started.set, name="test-deferred") time.sleep(0.1) assert not started.is_set() mt.start() @@ -275,11 +251,11 @@ def test_stopping_property(self) -> None: holder: list[ModuleThread] = [] def target() -> None: - while not holder[0].stopping: + while holder[0].status.get() == "running": time.sleep(0.01) saw_stopping.set() - mt = ModuleThread(module=mod, target=target, name="test-stopping", start=False) + mt = ModuleThread(module=mod, target=target, name="test-stopping") holder.append(mt) mt.start() time.sleep(0.05) @@ -290,6 +266,7 @@ def target() -> None: def test_stop_idempotent(self) -> None: mod = FakeModule() mt = ModuleThread(module=mod, target=lambda: time.sleep(0.01), name="test-idem") + mt.start() time.sleep(0.05) mt.stop() mt.stop() # second call should not raise @@ 
-305,7 +282,7 @@ def target() -> None: holder[0].stop() # stop ourselves — should not deadlock result.set() - mt = ModuleThread(module=mod, target=target, name="test-self-stop", start=False) + mt = ModuleThread(module=mod, target=target, name="test-self-stop") holder.append(mt) mt.start() result.wait(timeout=3) @@ -319,10 +296,10 @@ def test_dispose_stops_thread(self) -> None: def target() -> None: running.set() - while not holder[0].stopping: + while holder[0].status.get() == "running": time.sleep(0.01) - mt = ModuleThread(module=mod, target=target, name="test-dispose", start=False) + mt = ModuleThread(module=mod, target=target, name="test-dispose") holder.append(mt) mt.start() running.wait(timeout=2) @@ -336,10 +313,10 @@ def test_concurrent_stop_calls(self) -> None: holder: list[ModuleThread] = [] def target() -> None: - while not holder[0].stopping: + while holder[0].status.get() == "running": time.sleep(0.01) - mt = ModuleThread(module=mod, target=target, name="test-concurrent-stop", start=False) + mt = ModuleThread(module=mod, target=target, name="test-concurrent-stop") holder.append(mt) mt.start() time.sleep(0.05) @@ -369,6 +346,7 @@ def stubborn_target() -> None: mt = ModuleThread( module=mod, target=stubborn_target, name="test-timeout", close_timeout=0.2 ) + mt.start() start = time.monotonic() mt.stop() elapsed = time.monotonic() - start @@ -381,10 +359,10 @@ def test_stop_concurrent_with_dispose(self) -> None: holder: list[ModuleThread] = [] def target(h: list[ModuleThread] = holder) -> None: - while not h[0].stopping: + while h[0].status.get() == "running": time.sleep(0.001) - mt = ModuleThread(module=mod, target=target, name="test-stop-dispose", start=False) + mt = ModuleThread(module=mod, target=target, name="test-stop-dispose") holder.append(mt) mt.start() time.sleep(0.02) @@ -490,6 +468,7 @@ def test_basic_lifecycle(self) -> None: args=[PYTHON, "-c", "import time; time.sleep(30)"], shutdown_timeout=2.0, ) + mp.start() assert mp.is_alive assert mp.pid 
is not None mp.stop() @@ -503,6 +482,7 @@ def test_stop_idempotent(self) -> None: args=[PYTHON, "-c", "import time; time.sleep(30)"], shutdown_timeout=1.0, ) + mp.start() mp.stop() mp.stop() # should not raise mp.stop() @@ -514,6 +494,7 @@ def test_dispose_stops_process(self) -> None: args=[PYTHON, "-c", "import time; time.sleep(30)"], shutdown_timeout=2.0, ) + mp.start() mod.dispose() time.sleep(0.5) assert not mp.is_alive @@ -523,11 +504,12 @@ def test_on_exit_fires_on_natural_exit(self) -> None: mod = FakeModule() exit_called = threading.Event() - ModuleProcess( + mp = ModuleProcess( module=mod, args=[PYTHON, "-c", "print('done')"], on_exit=exit_called.set, ) + mp.start() exit_called.wait(timeout=5) assert exit_called.is_set(), "on_exit was not called after natural process exit" @@ -535,11 +517,12 @@ def test_on_exit_fires_on_crash(self) -> None: mod = FakeModule() exit_called = threading.Event() - ModuleProcess( + mp = ModuleProcess( module=mod, args=[PYTHON, "-c", "import sys; sys.exit(1)"], on_exit=exit_called.set, ) + mp.start() exit_called.wait(timeout=5) assert exit_called.is_set(), "on_exit was not called after process crash" @@ -554,6 +537,7 @@ def test_on_exit_not_fired_on_stop(self) -> None: on_exit=exit_called.set, shutdown_timeout=2.0, ) + mp.start() time.sleep(0.2) # let watchdog start mp.stop() time.sleep(1.0) # give watchdog time to potentially fire @@ -565,6 +549,7 @@ def test_stdout_logged(self) -> None: module=mod, args=[PYTHON, "-c", "print('hello from subprocess')"], ) + mp.start() time.sleep(1.0) # let output be read mp.stop() @@ -574,6 +559,7 @@ def test_stderr_logged(self) -> None: module=mod, args=[PYTHON, "-c", "import sys; sys.stderr.write('error msg\\n')"], ) + mp.start() time.sleep(1.0) mp.stop() @@ -588,6 +574,7 @@ def test_log_json_mode(self) -> None: ], log_json=True, ) + mp.start() time.sleep(1.0) mp.stop() @@ -598,6 +585,7 @@ def test_log_json_malformed(self) -> None: args=[PYTHON, "-c", "print('not json')"], log_json=True, ) + 
mp.start() time.sleep(1.0) mp.stop() @@ -614,6 +602,7 @@ def test_stop_process_that_ignores_sigterm(self) -> None: shutdown_timeout=0.5, kill_timeout=2.0, ) + mp.start() time.sleep(0.2) start = time.monotonic() mp.stop() @@ -629,6 +618,7 @@ def test_stop_already_dead_process(self) -> None: module=mod, args=[PYTHON, "-c", "pass"], # exits immediately ) + mp.start() time.sleep(1.0) # let it die mp.stop() # should not raise @@ -639,6 +629,7 @@ def test_concurrent_stop(self) -> None: args=[PYTHON, "-c", "import time; time.sleep(30)"], shutdown_timeout=2.0, ) + mp.start() errors = [] def stop_it() -> None: @@ -667,11 +658,12 @@ def fake_module_stop() -> None: mod.dispose() stop_called.set() - ModuleProcess( + mp = ModuleProcess( module=mod, args=[PYTHON, "-c", "pass"], # exits immediately on_exit=fake_module_stop, ) + mp.start() stop_called.wait(timeout=5) assert stop_called.is_set(), "Deadlocked! on_exit -> dispose -> stop chain hung" @@ -685,7 +677,6 @@ def test_deferred_start(self) -> None: mp = ModuleProcess( module=mod, args=[PYTHON, "-c", "import time; time.sleep(30)"], - start=False, ) assert not mp.is_alive mp.start() @@ -696,7 +687,7 @@ def test_env_passed(self) -> None: mod = FakeModule() exit_called = threading.Event() - ModuleProcess( + mp = ModuleProcess( module=mod, args=[ PYTHON, @@ -706,6 +697,7 @@ def test_env_passed(self) -> None: env={**os.environ, "MY_VAR": "42"}, on_exit=exit_called.set, ) + mp.start() exit_called.wait(timeout=5) # Process should have exited with 0 (our on_exit fires for all unmanaged exits) assert exit_called.is_set() @@ -717,6 +709,7 @@ def test_cwd_passed(self) -> None: args=[PYTHON, "-c", "import os; print(os.getcwd())"], cwd="/tmp", ) + mp.start() time.sleep(1.0) mp.stop() @@ -826,11 +819,12 @@ def test_chain_no_deadlock_fast_exit(self) -> None: for _ in range(20): mod = FakeModule() done = threading.Event() - ModuleProcess( + mp = ModuleProcess( module=mod, args=[PYTHON, "-c", "pass"], on_exit=self._make_fake_stop(mod, done), 
) + mp.start() assert done.wait(timeout=5), "Deadlock in dispose chain (fast exit)" def test_chain_no_deadlock_slow_exit(self) -> None: @@ -838,11 +832,12 @@ def test_chain_no_deadlock_slow_exit(self) -> None: for _ in range(10): mod = FakeModule() done = threading.Event() - ModuleProcess( + mp = ModuleProcess( module=mod, args=[PYTHON, "-c", "import time; time.sleep(0.1)"], on_exit=self._make_fake_stop(mod, done), ) + mp.start() assert done.wait(timeout=5), "Deadlock in dispose chain (slow exit)" def test_chain_concurrent_with_external_stop(self) -> None: @@ -856,6 +851,7 @@ def test_chain_concurrent_with_external_stop(self) -> None: on_exit=self._make_fake_stop(mod, done), shutdown_timeout=1.0, ) + mp.start() # Race: the process might exit naturally or we might stop it time.sleep(0.03) mp.stop() @@ -874,11 +870,12 @@ def slow_stop(self_mt: ModuleThread) -> None: mod = FakeModule() done = threading.Event() with mock.patch.object(ModuleThread, "stop", slow_stop): - ModuleProcess( + mp = ModuleProcess( module=mod, args=[PYTHON, "-c", "pass"], on_exit=self._make_fake_stop(mod, done), ) + mp.start() assert done.wait(timeout=10), "Deadlock with slow ModuleThread.stop()" diff --git a/dimos/utils/thread_utils.py b/dimos/utils/thread_utils.py index 6d9b7a9e7f..adf6751333 100644 --- a/dimos/utils/thread_utils.py +++ b/dimos/utils/thread_utils.py @@ -23,7 +23,7 @@ import signal import subprocess import threading -from typing import IO, TYPE_CHECKING, Any, Generic +from typing import IO, TYPE_CHECKING, Any, Generic, Literal from reactivex.disposable import Disposable @@ -41,22 +41,14 @@ R = TypeVar("R") -# ThreadSafeVal: a lock-protected value with context-manager support - - class ThreadSafeVal(Generic[T]): """A thread-safe value wrapper. 
- Wraps any value with a lock and provides atomic read-modify-write - via a context manager:: - - counter = ThreadSafeVal(0) - - # Simple get/set (each acquires the lock briefly): - counter.set(10) - print(counter.get()) # 10 - - # Atomic read-modify-write: + Forces lock usage in order to get access to a value (reduces unsafe value access) + Three ways to use: + 1. `.set` + 2. `.get` + 3. via a context manager:: with counter as value: # Lock is held for the entire block. # Other threads block on get/set/with until this exits. @@ -71,7 +63,7 @@ class ThreadSafeVal(Generic[T]): # Bool check (for flag-like usage): stopping = ThreadSafeVal(False) stopping.set(True) - if stopping: + if stopping.get(): print("stopping!") """ @@ -89,10 +81,6 @@ def set(self, value: T) -> None: with self._lock: self._value = value - def __bool__(self) -> bool: - with self._lock: - return bool(self._value) - def __enter__(self) -> T: self._lock.acquire() return self._value @@ -111,9 +99,6 @@ def __repr__(self) -> str: return f"ThreadSafeVal({self._value!r})" -# ModuleThread: a thread that auto-registers with a module's disposables - - class ModuleThread: """A thread that registers cleanup with a module's disposables. 
@@ -132,9 +117,10 @@ def start(self) -> None: target=self._run_loop, name="my-worker", ) + self._worker.start() def _run_loop(self) -> None: - while not self._worker.stopping: + while self._worker.status.get() == "running": do_work() """ @@ -142,28 +128,18 @@ def __init__( self, module: ModuleBase[Any], *, - start: bool = True, close_timeout: float = 2.0, **thread_kwargs: Any, ) -> None: thread_kwargs.setdefault("daemon", True) + thread_kwargs.setdefault("name", f"{type(module).__name__}-thread") self._thread = threading.Thread(**thread_kwargs) - self._stop_event = threading.Event() self._close_timeout = close_timeout - self._stopped = False - self._stop_lock = threading.Lock() + self.status: ThreadSafeVal[Literal["not_started", "running", "stopping", "stopped"]] = ThreadSafeVal("not_started") module._disposables.add(Disposable(self.stop)) - if start: - self.start() - - @property - def stopping(self) -> bool: - """True after ``stop()`` has been called.""" - return self._stop_event.is_set() def start(self) -> None: - """Start the underlying thread.""" - self._stop_event.clear() + self.status.set("running") self._thread.start() def stop(self) -> None: @@ -172,14 +148,13 @@ def stop(self) -> None: Safe to call multiple times, from any thread (including the managed thread itself — it will skip the join in that case). 
""" - with self._stop_lock: - if self._stopped: + with self.status as s: + if s in ("stopping", "stopped"): return - self._stopped = True - - self._stop_event.set() + self.status.set("stopping") if self._thread.is_alive() and self._thread is not threading.current_thread(): self._thread.join(timeout=self._close_timeout) + self.status.set("stopped") def join(self, timeout: float | None = None) -> None: """Join the underlying thread.""" @@ -190,9 +165,6 @@ def is_alive(self) -> bool: return self._thread.is_alive() -# AsyncModuleThread: a thread running an asyncio event loop, auto-registered - - class AsyncModuleThread: """A thread running an asyncio event loop, registered with a module's disposables. @@ -226,8 +198,7 @@ def __init__( close_timeout: float = 2.0, ) -> None: self._close_timeout = close_timeout - self._stopped = False - self._stop_lock = threading.Lock() + self._stopped = ThreadSafeVal(False) self._owns_loop = False self._thread: threading.Thread | None = None @@ -261,10 +232,10 @@ def stop(self) -> None: No-op if the loop was not created by this instance (reused an existing running loop). Safe to call multiple times. """ - with self._stop_lock: - if self._stopped: + with self._stopped as stopped: + if stopped: return - self._stopped = True + self._stopped.set(True) if self._owns_loop and self._loop.is_running(): self._loop.call_soon_threadsafe(self._loop.stop) @@ -273,9 +244,6 @@ def stop(self) -> None: self._thread.join(timeout=self._close_timeout) -# ModuleProcess: managed subprocess with log piping, auto-registered cleanup - - class ModuleProcess: """A managed subprocess that pipes stdout/stderr through the logger. 
@@ -298,6 +266,7 @@ def start(self) -> None: cwd="/opt/bin", on_exit=self.stop, # stops the whole module if process exits on its own ) + self._proc.start() @rpc def stop(self) -> None: @@ -317,7 +286,6 @@ def __init__( kill_timeout: float = 5.0, log_json: bool = False, log_tail_lines: int = 50, - start: bool = True, **popen_kwargs: Any, ) -> None: self._args = args @@ -333,14 +301,11 @@ def __init__( self._process: subprocess.Popen[bytes] | None = None self._watchdog: ModuleThread | None = None self._module = module - self._stopped = False - self._stop_lock = threading.Lock() + self._stopped = ThreadSafeVal(False) self.last_stdout: collections.deque[str] = collections.deque(maxlen=log_tail_lines) self.last_stderr: collections.deque[str] = collections.deque(maxlen=log_tail_lines) module._disposables.add(Disposable(self.stop)) - if start: - self.start() @property def pid(self) -> int | None: @@ -362,8 +327,7 @@ def start(self) -> None: logger.warning("Process already running", pid=self._process.pid) return - with self._stop_lock: - self._stopped = False + self._stopped.set(False) self.last_stdout = collections.deque(maxlen=self._log_tail_lines) self.last_stderr = collections.deque(maxlen=self._log_tail_lines) @@ -389,13 +353,14 @@ def start(self) -> None: target=self._watch, name=f"proc-{self._process.pid}-watchdog", ) + self._watchdog.start() def stop(self) -> None: """Send SIGTERM, wait, escalate to SIGKILL if needed. 
Idempotent.""" - with self._stop_lock: - if self._stopped: + with self._stopped as stopped: + if stopped: return - self._stopped = True + self._stopped.set(True) if self._process is not None and self._process.poll() is None: logger.info("Stopping process", pid=self._process.pid) @@ -433,9 +398,8 @@ def _watch(self) -> None: stdout_t.join(timeout=2) stderr_t.join(timeout=2) - with self._stop_lock: - if self._stopped: - return + if self._stopped.get(): + return last_stdout = "\n".join(self.last_stdout) or None last_stderr = "\n".join(self.last_stderr) or None @@ -478,9 +442,6 @@ def _read_stream(self, stream: IO[bytes] | None, level: str) -> None: stream.close() -# safe_thread_map: parallel map that collects all results before raising - - def safe_thread_map( items: Sequence[T], fn: Callable[[T], R], From 218b72ebcbde7592a05a415ea30f5ed7a27e1696 Mon Sep 17 00:00:00 2001 From: Jeff Hykin Date: Thu, 26 Mar 2026 03:40:18 -0700 Subject: [PATCH 20/22] refactor: merge _stop into stop, add thread_start helper - Merge _stop() into stop() in ModuleBase (removes unnecessary indirection) - Update all callers of _stop() to use stop() directly - Add thread_start() convenience function that creates + starts a ModuleThread --- dimos/core/module.py | 3 - dimos/core/test_core.py | 2 +- dimos/perception/detection/conftest.py | 10 +- .../perception/detection/reid/test_module.py | 2 +- dimos/robot/unitree/b1/test_connection.py | 281 +++++++++--------- .../mujoco/direct_cmd_vel_explorer.py | 2 +- dimos/utils/thread_utils.py | 36 +++ 7 files changed, 190 insertions(+), 146 deletions(-) diff --git a/dimos/core/module.py b/dimos/core/module.py index 4fbffe07b9..1ad934b5ae 100644 --- a/dimos/core/module.py +++ b/dimos/core/module.py @@ -126,9 +126,6 @@ def start(self) -> None: @rpc def stop(self) -> None: - self._stop() - - def _stop(self) -> None: with self.mod_state as state: if state == "stopped": return diff --git a/dimos/core/test_core.py b/dimos/core/test_core.py index 
aae167d8c6..858cc81849 100644 --- a/dimos/core/test_core.py +++ b/dimos/core/test_core.py @@ -89,7 +89,7 @@ def test_classmethods() -> None: ) assert hasattr(class_rpcs["start"], "__rpc__"), "start should have __rpc__ attribute" - nav._stop() + nav.stop() @pytest.mark.slow diff --git a/dimos/perception/detection/conftest.py b/dimos/perception/detection/conftest.py index 2040d687be..cb20e3d06a 100644 --- a/dimos/perception/detection/conftest.py +++ b/dimos/perception/detection/conftest.py @@ -221,7 +221,7 @@ def moment_provider(**kwargs) -> Moment2D: yield moment_provider moment_provider.cache_clear() - module._stop() + module.stop() @pytest.fixture(scope="session") @@ -256,7 +256,7 @@ def moment_provider(**kwargs) -> Moment3D: yield moment_provider moment_provider.cache_clear() if module is not None: - module._stop() + module.stop() @pytest.fixture(scope="session") @@ -290,9 +290,9 @@ def object_db_module(get_moment): yield moduleDB - module2d._stop() - module3d._stop() - moduleDB._stop() + module2d.stop() + module3d.stop() + moduleDB.stop() @pytest.fixture(scope="session") diff --git a/dimos/perception/detection/reid/test_module.py b/dimos/perception/detection/reid/test_module.py index 5fa0eead8d..89d752aae3 100644 --- a/dimos/perception/detection/reid/test_module.py +++ b/dimos/perception/detection/reid/test_module.py @@ -40,5 +40,5 @@ def test_reid_ingress(imageDetections2d) -> None: print("Processing detections through ReidModule...") reid_module.annotations._transport = LCMTransport("/annotations", ImageAnnotations) reid_module.ingress(imageDetections2d) - reid_module._stop() + reid_module.stop() print("✓ ReidModule ingress test completed successfully") diff --git a/dimos/robot/unitree/b1/test_connection.py b/dimos/robot/unitree/b1/test_connection.py index 011853d172..977554eb59 100644 --- a/dimos/robot/unitree/b1/test_connection.py +++ b/dimos/robot/unitree/b1/test_connection.py @@ -22,10 +22,14 @@ # should be used and tested. 
Additionally, tests should always use `try-finally` # to clean up even if the test fails. +import sys import threading import time -from dimos.msgs.geometry_msgs import TwistStamped, Vector3 +_IS_MACOS = sys.platform == "darwin" + +from dimos.msgs.geometry_msgs.TwistStamped import TwistStamped +from dimos.msgs.geometry_msgs.Vector3 import Vector3 from dimos.msgs.std_msgs.Int32 import Int32 from .connection import MockB1ConnectionModule @@ -58,22 +62,27 @@ def test_watchdog_actually_zeros_commands(self) -> None: assert conn._current_cmd.mode == 2 assert not conn.timeout_active - # Wait for watchdog timeout (200ms + buffer) - time.sleep(0.3) - - # Verify commands were zeroed by watchdog - assert conn._current_cmd.ly == 0.0 - assert conn._current_cmd.lx == 0.0 - assert conn._current_cmd.rx == 0.0 - assert conn._current_cmd.ry == 0.0 - assert conn._current_cmd.mode == 2 # Mode maintained - assert conn.timeout_active - - conn.running = False - conn.watchdog_running = False - conn.send_thread.join(timeout=0.5) - conn.watchdog_thread.join(timeout=0.5) - conn._stop() + try: + # Poll for watchdog timeout (generous 2s deadline) + deadline = time.time() + 2.0 + while time.time() < deadline: + if conn.timeout_active: + break + time.sleep(0.05) + + # Verify commands were zeroed by watchdog + assert conn._current_cmd.ly == 0.0 + assert conn._current_cmd.lx == 0.0 + assert conn._current_cmd.rx == 0.0 + assert conn._current_cmd.ry == 0.0 + assert conn._current_cmd.mode == 2 # Mode maintained + assert conn.timeout_active + finally: + conn.running = False + conn.watchdog_running = False + conn.send_thread.join(timeout=1.0) + conn.watchdog_thread.join(timeout=1.0) + conn.stop() def test_watchdog_resets_on_new_command(self) -> None: """Test that watchdog timeout resets when new command arrives.""" @@ -85,43 +94,27 @@ def test_watchdog_resets_on_new_command(self) -> None: conn.watchdog_thread = threading.Thread(target=conn._watchdog_loop, daemon=True) conn.watchdog_thread.start() - # 
Send first command - twist1 = TwistStamped( - ts=time.time(), - frame_id="base_link", - linear=Vector3(1.0, 0, 0), - angular=Vector3(0, 0, 0), - ) - conn.handle_twist_stamped(twist1) - assert conn._current_cmd.ly == 1.0 - - # Wait 150ms (not enough to trigger timeout) - time.sleep(0.15) - - # Send second command before timeout - twist2 = TwistStamped( - ts=time.time(), - frame_id="base_link", - linear=Vector3(0.5, 0, 0), - angular=Vector3(0, 0, 0), - ) - conn.handle_twist_stamped(twist2) - - # Command should be updated and no timeout - assert conn._current_cmd.ly == 0.5 - assert not conn.timeout_active - - # Wait another 150ms (total 300ms from second command) - time.sleep(0.15) - # Should still not timeout since we reset the timer - assert not conn.timeout_active - assert conn._current_cmd.ly == 0.5 - - conn.running = False - conn.watchdog_running = False - conn.send_thread.join(timeout=0.5) - conn.watchdog_thread.join(timeout=0.5) - conn._stop() + try: + # Send commands in rapid succession — each resets the 200ms watchdog + for val in [1.0, 0.8, 0.6, 0.5]: + twist = TwistStamped( + ts=time.time(), + frame_id="base_link", + linear=Vector3(val, 0, 0), + angular=Vector3(0, 0, 0), + ) + conn.handle_twist_stamped(twist) + time.sleep(0.02) # 20ms between commands, well under timeout + + # Command should be the last one sent and no timeout + assert conn._current_cmd.ly == 0.5 + assert not conn.timeout_active + finally: + conn.running = False + conn.watchdog_running = False + conn.send_thread.join(timeout=1.0) + conn.watchdog_thread.join(timeout=1.0) + conn.stop() def test_watchdog_thread_efficiency(self) -> None: """Test that watchdog uses only one thread regardless of command rate.""" @@ -155,7 +148,7 @@ def test_watchdog_thread_efficiency(self) -> None: conn.watchdog_running = False conn.send_thread.join(timeout=0.5) conn.watchdog_thread.join(timeout=0.5) - conn._stop() + conn.stop() def test_watchdog_with_send_loop_blocking(self) -> None: """Test that watchdog still 
works if send loop blocks.""" @@ -179,30 +172,35 @@ def blocking_send_loop() -> None: conn.watchdog_thread = threading.Thread(target=conn._watchdog_loop, daemon=True) conn.watchdog_thread.start() - # Send command - twist = TwistStamped( - ts=time.time(), - frame_id="base_link", - linear=Vector3(1.0, 0, 0), - angular=Vector3(0, 0, 0), - ) - conn.handle_twist_stamped(twist) - assert conn._current_cmd.ly == 1.0 - - # Wait for watchdog timeout - time.sleep(0.3) - - # Watchdog should have zeroed commands despite blocked send loop - assert conn._current_cmd.ly == 0.0 - assert conn.timeout_active - - # Unblock send loop - block_event.set() - conn.running = False - conn.watchdog_running = False - conn.send_thread.join(timeout=0.5) - conn.watchdog_thread.join(timeout=0.5) - conn._stop() + try: + # Send command + twist = TwistStamped( + ts=time.time(), + frame_id="base_link", + linear=Vector3(1.0, 0, 0), + angular=Vector3(0, 0, 0), + ) + conn.handle_twist_stamped(twist) + assert conn._current_cmd.ly == 1.0 + + # Poll for watchdog timeout (generous 2s deadline) + deadline = time.time() + 2.0 + while time.time() < deadline: + if conn.timeout_active: + break + time.sleep(0.05) + + # Watchdog should have zeroed commands despite blocked send loop + assert conn._current_cmd.ly == 0.0, "Watchdog should zero commands" + assert conn.timeout_active, "Watchdog should be active" + finally: + # Unblock send loop + block_event.set() + conn.running = False + conn.watchdog_running = False + conn.send_thread.join(timeout=1.0) + conn.watchdog_thread.join(timeout=1.0) + conn.stop() def test_continuous_commands_prevent_timeout(self) -> None: """Test that continuous commands prevent watchdog timeout.""" @@ -214,30 +212,33 @@ def test_continuous_commands_prevent_timeout(self) -> None: conn.watchdog_thread = threading.Thread(target=conn._watchdog_loop, daemon=True) conn.watchdog_thread.start() - # Send commands continuously for 500ms (should prevent timeout) - start = time.time() - commands_sent = 
0 - while time.time() - start < 0.5: - twist = TwistStamped( - ts=time.time(), - frame_id="base_link", - linear=Vector3(0.5, 0, 0), - angular=Vector3(0, 0, 0), + try: + # Send commands continuously for 1s (should prevent timeout) + start = time.time() + commands_sent = 0 + while time.time() - start < 1.0: + twist = TwistStamped( + ts=time.time(), + frame_id="base_link", + linear=Vector3(0.5, 0, 0), + angular=Vector3(0, 0, 0), + ) + conn.handle_twist_stamped(twist) + commands_sent += 1 + time.sleep(0.05) # 50ms between commands (well under 200ms timeout) + + # Should never timeout + assert not conn.timeout_active, "Should not timeout with continuous commands" + assert conn._current_cmd.ly == 0.5, "Commands should still be active" + assert commands_sent >= 3, ( + f"Should send at least 3 commands in 1s, sent {commands_sent}" ) - conn.handle_twist_stamped(twist) - commands_sent += 1 - time.sleep(0.05) # 50ms between commands (well under 200ms timeout) - - # Should never timeout - assert not conn.timeout_active, "Should not timeout with continuous commands" - assert conn._current_cmd.ly == 0.5, "Commands should still be active" - assert commands_sent >= 9, f"Should send at least 9 commands in 500ms, sent {commands_sent}" - - conn.running = False - conn.watchdog_running = False - conn.send_thread.join(timeout=0.5) - conn.watchdog_thread.join(timeout=0.5) - conn._stop() + finally: + conn.running = False + conn.watchdog_running = False + conn.send_thread.join(timeout=1.0) + conn.watchdog_thread.join(timeout=1.0) + conn.stop() def test_watchdog_timing_accuracy(self) -> None: """Test that watchdog zeros commands at approximately 200ms.""" @@ -272,13 +273,14 @@ def test_watchdog_timing_accuracy(self) -> None: # Check timing (should be close to 200ms + up to 50ms watchdog interval) elapsed = timeout_time - start_time print(f"\nWatchdog timeout occurred at exactly {elapsed:.3f} seconds") - assert 0.19 <= elapsed <= 0.3, f"Watchdog timed out at {elapsed:.3f}s, expected 
~0.2-0.25s" + _lo, _hi = (0.15, 0.5) if _IS_MACOS else (0.19, 0.3) + assert _lo <= elapsed <= _hi, f"Watchdog timed out at {elapsed:.3f}s, expected {_lo}-{_hi}s" conn.running = False conn.watchdog_running = False conn.send_thread.join(timeout=0.5) conn.watchdog_thread.join(timeout=0.5) - conn._stop() + conn.stop() def test_mode_changes_with_watchdog(self) -> None: """Test that mode changes work correctly with watchdog.""" @@ -321,7 +323,7 @@ def test_mode_changes_with_watchdog(self) -> None: conn.watchdog_running = False conn.send_thread.join(timeout=0.5) conn.watchdog_thread.join(timeout=0.5) - conn._stop() + conn.stop() def test_watchdog_stops_movement_when_commands_stop(self) -> None: """Verify watchdog zeros commands when packets stop being sent.""" @@ -350,36 +352,45 @@ def test_watchdog_stops_movement_when_commands_stop(self) -> None: assert conn.current_mode == 2 # WALK mode assert not conn.timeout_active - # Wait for watchdog to detect timeout (200ms + buffer) - time.sleep(0.3) - - assert conn.timeout_active, "Watchdog should have detected timeout" - assert conn._current_cmd.ly == 0.0, "Forward velocity should be zeroed" - assert conn._current_cmd.lx == 0.0, "Lateral velocity should be zeroed" - assert conn._current_cmd.rx == 0.0, "Rotation X should be zeroed" - assert conn._current_cmd.ry == 0.0, "Rotation Y should be zeroed" - assert conn.current_mode == 2, "Mode should stay as WALK" - - # Verify recovery works - send new command - twist = TwistStamped( - ts=time.time(), - frame_id="base_link", - linear=Vector3(0.5, 0, 0), - angular=Vector3(0, 0, 0), - ) - conn.handle_twist_stamped(twist) - - # Give watchdog time to detect recovery - time.sleep(0.1) - - assert not conn.timeout_active, "Should recover from timeout" - assert conn._current_cmd.ly == 0.5, "Should accept new commands" + try: + # Poll for watchdog timeout (generous 2s deadline) + deadline = time.time() + 2.0 + while time.time() < deadline: + if conn.timeout_active: + break + time.sleep(0.05) + 
+ assert conn.timeout_active, "Watchdog should have detected timeout" + assert conn._current_cmd.ly == 0.0, "Forward velocity should be zeroed" + assert conn._current_cmd.lx == 0.0, "Lateral velocity should be zeroed" + assert conn._current_cmd.rx == 0.0, "Rotation X should be zeroed" + assert conn._current_cmd.ry == 0.0, "Rotation Y should be zeroed" + assert conn.current_mode == 2, "Mode should stay as WALK" + + # Verify recovery works - send new command + twist = TwistStamped( + ts=time.time(), + frame_id="base_link", + linear=Vector3(0.5, 0, 0), + angular=Vector3(0, 0, 0), + ) + conn.handle_twist_stamped(twist) - conn.running = False - conn.watchdog_running = False - conn.send_thread.join(timeout=0.5) - conn.watchdog_thread.join(timeout=0.5) - conn._stop() + # Poll for recovery (timeout_active should clear) + deadline = time.time() + 2.0 + while time.time() < deadline: + if not conn.timeout_active: + break + time.sleep(0.05) + + assert not conn.timeout_active, "Should recover from timeout" + assert conn._current_cmd.ly == 0.5, "Should accept new commands" + finally: + conn.running = False + conn.watchdog_running = False + conn.send_thread.join(timeout=1.0) + conn.watchdog_thread.join(timeout=1.0) + conn.stop() def test_rapid_command_thread_safety(self) -> None: """Test thread safety with rapid commands from multiple threads.""" @@ -428,4 +439,4 @@ def send_commands(thread_id) -> None: conn.watchdog_running = False conn.send_thread.join(timeout=0.5) conn.watchdog_thread.join(timeout=0.5) - conn._stop() + conn.stop() diff --git a/dimos/simulation/mujoco/direct_cmd_vel_explorer.py b/dimos/simulation/mujoco/direct_cmd_vel_explorer.py index 58dc91f6b1..f81c644d0a 100644 --- a/dimos/simulation/mujoco/direct_cmd_vel_explorer.py +++ b/dimos/simulation/mujoco/direct_cmd_vel_explorer.py @@ -99,7 +99,7 @@ def _drive_to(self, target_x: float, target_y: float) -> None: None, Twist(linear=Vector3(linear, 0, 0), angular=Vector3(0, 0, angular)), ) - self._stop() + self.stop() 
def follow_points(self, waypoints: list[tuple[float, float]]) -> None: self._wait_for_pose() diff --git a/dimos/utils/thread_utils.py b/dimos/utils/thread_utils.py index adf6751333..5b3f6648f2 100644 --- a/dimos/utils/thread_utils.py +++ b/dimos/utils/thread_utils.py @@ -509,3 +509,39 @@ def cleanup( raise ExceptionGroup("safe_thread_map failed", errors) return [outcomes[i] for i in range(len(items))] # type: ignore[misc] + + +def thread_start( + module: "ModuleBase[Any]", + *, + close_timeout: float = 2.0, + **thread_kwargs: Any, +) -> ModuleThread: + """Create a :class:`ModuleThread`, start it immediately, and return it. + + Convenience wrapper equivalent to:: + + t = ModuleThread(module, close_timeout=close_timeout, **thread_kwargs) + t.start() + return t + + Accepts the same arguments as :class:`ModuleThread`. + + Example:: + + class MyModule(Module): + @rpc + def start(self) -> None: + self._worker = thread_start( + self, + target=self._run_loop, + name="my-worker", + ) + + def _run_loop(self) -> None: + while self._worker.status.get() == "running": + do_work() + """ + t = ModuleThread(module, close_timeout=close_timeout, **thread_kwargs) + t.start() + return t From 4240573b18b1a744c4392766ac14feb9842e00e2 Mon Sep 17 00:00:00 2001 From: Jeff Hykin Date: Thu, 26 Mar 2026 03:48:05 -0700 Subject: [PATCH 21/22] refactor: defer AsyncModuleThread loop creation to start() MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit AsyncModuleThread no longer spawns the event loop thread in __init__. The loop is created on the first call to start(), which ModuleBase.start() now calls. This means module construction no longer has side effects — no threads are spawned until the module is explicitly started. 
--- dimos/core/module.py | 1 + dimos/utils/test_thread_utils.py | 12 +++++++++++- dimos/utils/thread_utils.py | 18 ++++++++++++++---- 3 files changed, 26 insertions(+), 5 deletions(-) diff --git a/dimos/core/module.py b/dimos/core/module.py index 1ad934b5ae..83dfeddcd4 100644 --- a/dimos/core/module.py +++ b/dimos/core/module.py @@ -123,6 +123,7 @@ def start(self) -> None: if state == "stopped": raise RuntimeError(f"{type(self).__name__} cannot be restarted after stop") self.mod_state.set("started") + self._async_thread.start() @rpc def stop(self) -> None: diff --git a/dimos/utils/test_thread_utils.py b/dimos/utils/test_thread_utils.py index dd799dfd1c..ac12936fe0 100644 --- a/dimos/utils/test_thread_utils.py +++ b/dimos/utils/test_thread_utils.py @@ -339,9 +339,10 @@ def stop_it() -> None: def test_close_timeout_respected(self) -> None: """If the thread ignores the stop signal, stop() should return after close_timeout.""" mod = FakeModule() + cancel = threading.Event() def stubborn_target() -> None: - time.sleep(10) # ignores stopping signal + cancel.wait(10) # blocks but can be released for cleanup mt = ModuleThread( module=mod, target=stubborn_target, name="test-timeout", close_timeout=0.2 @@ -351,6 +352,9 @@ def stubborn_target() -> None: mt.stop() elapsed = time.monotonic() - start assert elapsed < 1.0, f"stop() took {elapsed}s, expected ~0.2s" + # Release the thread so it doesn't leak + cancel.set() + mt._thread.join(timeout=1.0) def test_stop_concurrent_with_dispose(self) -> None: """Calling stop() and dispose() concurrently should not crash.""" @@ -382,6 +386,7 @@ class TestAsyncModuleThread: def test_creates_loop_and_thread(self) -> None: mod = FakeModule() amt = AsyncModuleThread(module=mod) + amt.start() assert amt.loop is not None assert amt.loop.is_running() assert amt.is_alive @@ -391,6 +396,7 @@ def test_creates_loop_and_thread(self) -> None: def test_stop_idempotent(self) -> None: mod = FakeModule() amt = AsyncModuleThread(module=mod) + amt.start() 
amt.stop() amt.stop() # should not raise amt.stop() @@ -398,6 +404,7 @@ def test_stop_idempotent(self) -> None: def test_dispose_stops_loop(self) -> None: mod = FakeModule() amt = AsyncModuleThread(module=mod) + amt.start() assert amt.is_alive mod.dispose() time.sleep(0.1) @@ -406,6 +413,7 @@ def test_dispose_stops_loop(self) -> None: def test_can_schedule_coroutine(self) -> None: mod = FakeModule() amt = AsyncModuleThread(module=mod) + amt.start() result = [] async def coro() -> None: @@ -420,6 +428,7 @@ def test_stop_with_pending_work(self) -> None: """Stop should succeed even with long-running tasks on the loop.""" mod = FakeModule() amt = AsyncModuleThread(module=mod) + amt.start() started = threading.Event() async def slow_coro() -> None: @@ -437,6 +446,7 @@ async def slow_coro() -> None: def test_concurrent_stop(self) -> None: mod = FakeModule() amt = AsyncModuleThread(module=mod) + amt.start() errors = [] def stop_it() -> None: diff --git a/dimos/utils/thread_utils.py b/dimos/utils/thread_utils.py index 5b3f6648f2..e1767988c0 100644 --- a/dimos/utils/thread_utils.py +++ b/dimos/utils/thread_utils.py @@ -201,6 +201,18 @@ def __init__( self._stopped = ThreadSafeVal(False) self._owns_loop = False self._thread: threading.Thread | None = None + self._loop: asyncio.AbstractEventLoop | None = None + self._module_name = type(module).__name__ + + module._disposables.add(Disposable(self.stop)) + + def start(self) -> None: + """Create (or reuse) the event loop and start the driver thread. + + Safe to call multiple times — subsequent calls are no-ops. 
+ """ + if self._loop is not None: + return try: self._loop = asyncio.get_running_loop() @@ -211,12 +223,10 @@ def __init__( self._thread = threading.Thread( target=self._loop.run_forever, daemon=True, - name=f"{type(module).__name__}-event-loop", + name=f"{self._module_name}-event-loop", ) self._thread.start() - module._disposables.add(Disposable(self.stop)) - @property def loop(self) -> asyncio.AbstractEventLoop: """The managed event loop.""" @@ -237,7 +247,7 @@ def stop(self) -> None: return self._stopped.set(True) - if self._owns_loop and self._loop.is_running(): + if self._owns_loop and self._loop is not None and self._loop.is_running(): self._loop.call_soon_threadsafe(self._loop.stop) if self._thread is not None and self._thread.is_alive(): From bac8488320e93b108b33d2411fb5a2cbbf6e6f20 Mon Sep 17 00:00:00 2001 From: jeff-hykin <17692058+jeff-hykin@users.noreply.github.com> Date: Fri, 27 Mar 2026 03:15:14 +0000 Subject: [PATCH 22/22] CI code cleanup --- dimos/core/native_module.py | 6 +++++- dimos/utils/thread_utils.py | 6 ++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/dimos/core/native_module.py b/dimos/core/native_module.py index f2579192ae..080ad7df13 100644 --- a/dimos/core/native_module.py +++ b/dimos/core/native_module.py @@ -140,7 +140,11 @@ def start(self) -> None: module=self, args=[ self.config.executable, - *[arg for name, topic in self._collect_topics().items() for arg in (f"--{name}", topic)], + *[ + arg + for name, topic in self._collect_topics().items() + for arg in (f"--{name}", topic) + ], *self.config.to_cli_args(), *self.config.extra_args, ], diff --git a/dimos/utils/thread_utils.py b/dimos/utils/thread_utils.py index e1767988c0..3e3fe6ba11 100644 --- a/dimos/utils/thread_utils.py +++ b/dimos/utils/thread_utils.py @@ -135,7 +135,9 @@ def __init__( thread_kwargs.setdefault("name", f"{type(module).__name__}-thread") self._thread = threading.Thread(**thread_kwargs) self._close_timeout = close_timeout - self.status: 
ThreadSafeVal[Literal["not_started", "running", "stopping", "stopped"]] = ThreadSafeVal("not_started") + self.status: ThreadSafeVal[Literal["not_started", "running", "stopping", "stopped"]] = ( + ThreadSafeVal("not_started") + ) module._disposables.add(Disposable(self.stop)) def start(self) -> None: @@ -522,7 +524,7 @@ def cleanup( def thread_start( - module: "ModuleBase[Any]", + module: ModuleBase[Any], *, close_timeout: float = 2.0, **thread_kwargs: Any,