-
Notifications
You must be signed in to change notification settings - Fork 360
Threading Utils (and fix for Native Module flakey test) #1663
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
3d34735
6413d99
ce26d9a
b11eb6e
820be9a
38a31fc
055b7f3
8fb0526
060e049
a20d161
b8bb6c0
44d1b1d
3a4d0c2
a4aba11
5773170
f62a0b0
72f38c4
9836102
a1586f3
218b72e
4240573
bac8488
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -180,8 +180,7 @@ def start(self) -> None: | |
| def stop(self) -> None: | ||
| if self._uvicorn_server: | ||
| self._uvicorn_server.should_exit = True | ||
| loop = self._loop | ||
| if loop is not None and self._serve_future is not None: | ||
| if self._serve_future is not None: | ||
| self._serve_future.result(timeout=5.0) | ||
| self._uvicorn_server = None | ||
| self._serve_future = None | ||
|
|
@@ -250,6 +249,5 @@ def _start_server(self, port: int | None = None) -> None: | |
| config = uvicorn.Config(app, host=_host, port=_port, log_level="info") | ||
| server = uvicorn.Server(config) | ||
| self._uvicorn_server = server | ||
| loop = self._loop | ||
| assert loop is not None | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The loop is always there until stop is called. |
||
| loop = self._async_thread.loop | ||
| self._serve_future = asyncio.run_coroutine_threadsafe(server.serve(), loop) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,9 +16,13 @@ | |
|
|
||
| import asyncio | ||
| import json | ||
| import socket | ||
| import time | ||
| from unittest.mock import MagicMock | ||
|
|
||
| from dimos.agents.mcp.mcp_server import handle_request | ||
| import requests | ||
|
|
||
| from dimos.agents.mcp.mcp_server import McpServer, handle_request | ||
| from dimos.core.module import SkillInfo | ||
|
|
||
|
|
||
|
|
@@ -111,3 +115,56 @@ def test_mcp_module_initialize_and_unknown() -> None: | |
|
|
||
| response = asyncio.run(handle_request({"method": "unknown/method", "id": 2}, [], {})) | ||
| assert response["error"]["code"] == -32601 | ||
|
|
||
|
|
||
| def _free_port() -> int: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This already exists as |
||
| with socket.socket() as s: | ||
| s.bind(("", 0)) | ||
| return s.getsockname()[1] | ||
|
|
||
|
|
||
| def test_mcp_server_lifecycle() -> None: | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. new test |
||
| """Start a real McpServer, hit the HTTP endpoint, then stop it. | ||
|
|
||
| This exercises the AsyncModuleThread event loop integration that the | ||
| unit tests above do not cover. | ||
| """ | ||
| port = _free_port() | ||
|
|
||
| server = McpServer() | ||
| server._start_server(port=port) | ||
| url = f"http://127.0.0.1:{port}/mcp" | ||
|
|
||
| # Wait for the server to be ready | ||
| for _ in range(40): | ||
| try: | ||
| resp = requests.post( | ||
| url, | ||
| json={"jsonrpc": "2.0", "method": "initialize", "id": 1}, | ||
| timeout=0.5, | ||
| ) | ||
| if resp.status_code == 200: | ||
| break | ||
| except requests.ConnectionError: | ||
| time.sleep(0.1) | ||
|
Comment on lines
+139
to
+149
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Stash already has 3 instances of this exact sequence (as functions called wait_for_mcp). Please don't add a 4th that is inlined. |
||
| else: | ||
| server.stop() | ||
| raise AssertionError("McpServer did not become ready") | ||
|
|
||
| # Verify it responds | ||
| data = resp.json() | ||
| assert data["result"]["serverInfo"]["name"] == "dimensional" | ||
|
|
||
| # Stop and verify it shuts down | ||
| server.stop() | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The server should be stopped even if the test fails. The best way to do it is to use a fixture for the mcp server. |
||
| time.sleep(0.3) | ||
|
|
||
| with socket.socket() as s: | ||
| # Port should be released after stop | ||
| try: | ||
| s.connect(("127.0.0.1", port)) | ||
| s.close() | ||
| # If we could connect, the server is still up — that's a bug | ||
| raise AssertionError("McpServer still listening after stop()") | ||
| except ConnectionRefusedError: | ||
| pass # expected — server is down | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,17 +11,16 @@ | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| import asyncio | ||
| from collections.abc import Callable | ||
| from dataclasses import dataclass | ||
| from functools import partial | ||
| import inspect | ||
| import json | ||
| import sys | ||
| import threading | ||
| from typing import ( | ||
| TYPE_CHECKING, | ||
| Any, | ||
| Literal, | ||
| Protocol, | ||
| get_args, | ||
| get_origin, | ||
|
|
@@ -45,6 +44,9 @@ | |
| from dimos.protocol.tf.tf import LCMTF, TFSpec | ||
| from dimos.utils import colors | ||
| from dimos.utils.generic import classproperty | ||
| from dimos.utils.thread_utils import AsyncModuleThread, ThreadSafeVal | ||
|
|
||
| ModState = Literal["init", "started", "stopped"] | ||
|
|
||
| if TYPE_CHECKING: | ||
| from dimos.core.blueprints import Blueprint | ||
|
|
@@ -64,19 +66,6 @@ class SkillInfo: | |
| args_schema: str | ||
|
|
||
|
|
||
| def get_loop() -> tuple[asyncio.AbstractEventLoop, threading.Thread | None]: | ||
| try: | ||
| running_loop = asyncio.get_running_loop() | ||
| return running_loop, None | ||
| except RuntimeError: | ||
| loop = asyncio.new_event_loop() | ||
| asyncio.set_event_loop(loop) | ||
|
|
||
| thr = threading.Thread(target=loop.run_forever, daemon=True) | ||
| thr.start() | ||
| return loop, thr | ||
|
|
||
|
|
||
| class ModuleConfig(BaseConfig): | ||
| rpc_transport: type[RPCSpec] = LCMRPC | ||
| tf_transport: type[TFSpec] = LCMTF # type: ignore[type-arg] | ||
|
|
@@ -98,20 +87,22 @@ class ModuleBase(Configurable[ModuleConfigT], Resource): | |
|
|
||
| _rpc: RPCSpec | None = None | ||
| _tf: TFSpec[Any] | None = None | ||
| _loop: asyncio.AbstractEventLoop | None = None | ||
| _loop_thread: threading.Thread | None | ||
| _async_thread: AsyncModuleThread | ||
| _disposables: CompositeDisposable | ||
| _bound_rpc_calls: dict[str, RpcCall] = {} | ||
| _module_closed: bool = False | ||
| _module_closed_lock: threading.Lock | ||
| mod_state: ThreadSafeVal[ModState] | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't like |
||
|
|
||
| rpc_calls: list[str] = [] | ||
|
|
||
| def __init__(self, config_args: dict[str, Any]): | ||
| self.mod_state = ThreadSafeVal[ModState]("init") | ||
| super().__init__(**config_args) | ||
| self._module_closed_lock = threading.Lock() | ||
| self._loop, self._loop_thread = get_loop() | ||
| self._disposables = CompositeDisposable() | ||
| self._async_thread = ( | ||
| AsyncModuleThread( # NEEDS to be created after self._disposables exists | ||
| module=self | ||
| ) | ||
| ) | ||
| try: | ||
| self.rpc = self.config.rpc_transport() | ||
| self.rpc.serve_module_rpc(self) | ||
|
|
@@ -128,58 +119,43 @@ def frame_id(self) -> str: | |
|
|
||
| @rpc | ||
| def start(self) -> None: | ||
| pass | ||
| with self.mod_state as state: | ||
| if state == "stopped": | ||
| raise RuntimeError(f"{type(self).__name__} cannot be restarted after stop") | ||
| self.mod_state.set("started") | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I know lots of modules don't call super().start(), but they also wouldn't be using mod_state because it's a new thing. Different/off-topic discussion, but I think core2 should have ModuleBase as a class decorator instead of an inherited class (we can basically wrap methods instead of saying "please remember to call super").
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. That's one of the reasons I don't like inheritance. But can you explain what you mean by ModuleBase being a decorator? At first glance that seems more complicated. |
||
| self._async_thread.start() | ||
|
|
||
| @rpc | ||
| def stop(self) -> None: | ||
| self._close_module() | ||
|
|
||
| def _close_module(self) -> None: | ||
| with self._module_closed_lock: | ||
| if self._module_closed: | ||
| with self.mod_state as state: | ||
| if state == "stopped": | ||
| return | ||
| self._module_closed = True | ||
|
|
||
| self._close_rpc() | ||
| self.mod_state.set("stopped") | ||
|
|
||
| # Save into local variables to avoid race when stopping concurrently | ||
| # (from RPC and worker shutdown) | ||
| loop_thread = getattr(self, "_loop_thread", None) | ||
| loop = getattr(self, "_loop", None) | ||
| # dispose of things BEFORE making aspects like rpc and _tf invalid | ||
| if hasattr(self, "_disposables"): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You'll have lots of conflicts here with what Ivan is doing for disposables in his PR: #1682 . |
||
| self._disposables.dispose() # stops _async_thread via disposable | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think it's important to move disposables up before the rpc stop and the tf stop. |
||
|
|
||
| if loop_thread: | ||
| if loop_thread.is_alive(): | ||
| if loop: | ||
| loop.call_soon_threadsafe(loop.stop) | ||
| loop_thread.join(timeout=2) | ||
| self._loop = None | ||
| self._loop_thread = None | ||
| if self.rpc: | ||
| self.rpc.stop() # type: ignore[attr-defined] | ||
| self.rpc = None # type: ignore[assignment] | ||
|
Comment on lines
+140
to
+141
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please avoid |
||
|
|
||
| if hasattr(self, "_tf") and self._tf is not None: | ||
| self._tf.stop() | ||
| self._tf = None | ||
| if hasattr(self, "_disposables"): | ||
| self._disposables.dispose() | ||
|
|
||
| # Break the In/Out -> owner -> self reference cycle so the instance | ||
| # can be freed by refcount instead of waiting for GC. | ||
| for attr in list(vars(self).values()): | ||
| if isinstance(attr, (In, Out)): | ||
| attr.owner = None | ||
|
|
||
| def _close_rpc(self) -> None: | ||
| if self.rpc: | ||
| self.rpc.stop() # type: ignore[attr-defined] | ||
| self.rpc = None # type: ignore[assignment] | ||
|
|
||
| def __getstate__(self): # type: ignore[no-untyped-def] | ||
| """Exclude unpicklable runtime attributes when serializing.""" | ||
| state = self.__dict__.copy() | ||
| # Remove unpicklable attributes | ||
| state.pop("_disposables", None) | ||
| state.pop("_module_closed_lock", None) | ||
| state.pop("_loop", None) | ||
| state.pop("_loop_thread", None) | ||
| state.pop("_async_thread", None) | ||
| state.pop("_rpc", None) | ||
| state.pop("_tf", None) | ||
| return state | ||
|
|
@@ -189,9 +165,7 @@ def __setstate__(self, state) -> None: # type: ignore[no-untyped-def] | |
| self.__dict__.update(state) | ||
| # Reinitialize runtime attributes | ||
| self._disposables = CompositeDisposable() | ||
| self._module_closed_lock = threading.Lock() | ||
| self._loop = None | ||
| self._loop_thread = None | ||
| self._async_thread = None # type: ignore[assignment] | ||
| self._rpc = None | ||
| self._tf = None | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the loop is always there until super().stop() is called