Skip to content

Commit 1908c1c

Browse files
Merge pull request #140 from askui/fix/hold-mcp-client-session
fix(chat)!: hold mcp client sessions
2 parents 4ebbd9c + b246baf commit 1908c1c

File tree

12 files changed

+302
-96
lines changed

12 files changed

+302
-96
lines changed

src/askui/chat/api/app.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
from askui.chat.api.dependencies import SetEnvFromHeadersDep, get_settings
1212
from askui.chat.api.files.router import router as files_router
1313
from askui.chat.api.health.router import router as health_router
14+
from askui.chat.api.mcp_clients.dependencies import get_mcp_client_manager_manager
15+
from askui.chat.api.mcp_clients.manager import McpServerConnectionError
1416
from askui.chat.api.mcp_configs.dependencies import get_mcp_config_service
1517
from askui.chat.api.mcp_configs.router import router as mcp_configs_router
1618
from askui.chat.api.mcps.computer import mcp as computer_mcp
@@ -35,6 +37,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001
3537
mcp_config_service = get_mcp_config_service(settings=settings)
3638
mcp_config_service.seed()
3739
yield
40+
await get_mcp_client_manager_manager(mcp_config_service).disconnect_all(force=True)
3841

3942

4043
app = FastAPI(
@@ -144,6 +147,17 @@ def catch_all_exception_handler(
144147
)
145148

146149

150+
@app.exception_handler(McpServerConnectionError)
151+
def mcp_server_connection_error_handler(
152+
request: Request, # noqa: ARG001
153+
exc: McpServerConnectionError,
154+
) -> JSONResponse:
155+
return JSONResponse(
156+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
157+
content={"detail": str(exc)},
158+
)
159+
160+
147161
app.add_middleware(
148162
CORSMiddleware,
149163
allow_origins=["*"],

src/askui/chat/api/mcp_clients/__init__.py

Whitespace-only changes.
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from fastapi import Depends
2+
3+
from askui.chat.api.mcp_clients.manager import McpClientManagerManager
4+
from askui.chat.api.mcp_configs.dependencies import McpConfigServiceDep
5+
from askui.chat.api.mcp_configs.service import McpConfigService
6+
7+
8+
def get_mcp_client_manager_manager(
9+
mcp_config_service: McpConfigService = McpConfigServiceDep,
10+
) -> McpClientManagerManager:
11+
return McpClientManagerManager(mcp_config_service)
12+
13+
14+
McpClientManagerManagerDep = Depends(get_mcp_client_manager_manager)
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
import types
2+
from datetime import timedelta
3+
from typing import Any, Type
4+
5+
import anyio
6+
import mcp
7+
from fastmcp import Client
8+
from fastmcp.client.client import CallToolResult, ProgressHandler
9+
from fastmcp.exceptions import ToolError
10+
from fastmcp.mcp_config import MCPConfig
11+
12+
from askui.chat.api.mcp_configs.service import McpConfigService
13+
from askui.chat.api.models import WorkspaceId
14+
15+
McpServerName = str
16+
17+
18+
class McpServerConnectionError(Exception):
19+
"""Exception raised when a MCP server connection fails."""
20+
21+
def __init__(self, mcp_server_name: McpServerName, error: Exception):
22+
super().__init__(f"Failed to connect to MCP server: {mcp_server_name}: {error}")
23+
self.mcp_server_name = mcp_server_name
24+
self.error = error
25+
26+
27+
class McpClientManager:
28+
def __init__(
29+
self, mcp_clients: dict[McpServerName, Client[Any]] | None = None
30+
) -> None:
31+
self._mcp_clients = mcp_clients or {}
32+
self._tools: dict[McpServerName, list[mcp.types.Tool]] = {}
33+
34+
@classmethod
35+
def from_config(cls, mcp_config: MCPConfig) -> "McpClientManager":
36+
mcp_clients: dict[McpServerName, Client[Any]] = {
37+
mcp_server_name: Client(mcp_server_config.to_transport())
38+
for mcp_server_name, mcp_server_config in mcp_config.mcpServers.items()
39+
}
40+
return cls(mcp_clients)
41+
42+
async def connect(self) -> "McpClientManager":
43+
for mcp_server_name, mcp_client in self._mcp_clients.items():
44+
try:
45+
await mcp_client._connect() # noqa: SLF001
46+
except Exception as e: # noqa: PERF203
47+
raise McpServerConnectionError(mcp_server_name, e) from e
48+
return self
49+
50+
async def disconnect(self, force: bool = False) -> None:
51+
for mcp_client in self._mcp_clients.values():
52+
if mcp_client.is_connected():
53+
await mcp_client._disconnect(force) # noqa: SLF001
54+
55+
async def list_tools(
56+
self,
57+
) -> list[mcp.types.Tool]:
58+
tools: list[mcp.types.Tool] = []
59+
for mcp_server_name, mcp_client in self._mcp_clients.items():
60+
if mcp_server_name not in self._tools:
61+
self._tools[mcp_server_name] = await mcp_client.list_tools()
62+
tools.extend(self._tools[mcp_server_name])
63+
return tools
64+
65+
async def call_tool(
66+
self,
67+
name: str,
68+
arguments: dict[str, Any] | None = None,
69+
timeout: timedelta | float | None = None, # noqa: ASYNC109
70+
progress_handler: ProgressHandler | None = None,
71+
raise_on_error: bool = True,
72+
) -> CallToolResult:
73+
for mcp_server_name, tools in self._tools.items(): # Make lookup faster
74+
for tool in tools:
75+
if tool.name == name:
76+
return await self._mcp_clients[mcp_server_name].call_tool(
77+
name,
78+
arguments,
79+
timeout,
80+
progress_handler,
81+
raise_on_error,
82+
)
83+
error_msg = f"Unknown tool: {name}"
84+
if raise_on_error:
85+
raise ToolError(error_msg)
86+
return CallToolResult(
87+
content=[mcp.types.TextContent(type="text", text=error_msg)],
88+
structured_content=None,
89+
data=None,
90+
is_error=True,
91+
)
92+
93+
async def __aenter__(self) -> "McpClientManager":
94+
return await self.connect()
95+
96+
async def __aexit__(
97+
self,
98+
exc_type: Type[BaseException] | None,
99+
exc_value: BaseException | None,
100+
traceback: types.TracebackType | None,
101+
) -> None:
102+
await self.disconnect()
103+
104+
105+
McpClientManagerKey = str
106+
107+
108+
class McpClientManagerManager:
109+
_mcp_client_managers: dict[McpClientManagerKey, McpClientManager | None] = {}
110+
_lock: anyio.Lock = anyio.Lock()
111+
112+
def __init__(self, mcp_config_service: McpConfigService) -> None:
113+
self._mcp_config_service = mcp_config_service
114+
115+
async def get_mcp_client_manager(
116+
self, workspace_id: WorkspaceId | None
117+
) -> McpClientManager | None:
118+
key: McpClientManagerKey = (
119+
f"workspace_{workspace_id}" if workspace_id else "global"
120+
)
121+
if key in McpClientManagerManager._mcp_client_managers:
122+
return McpClientManagerManager._mcp_client_managers[key]
123+
124+
fast_mcp_config = self._mcp_config_service.retrieve_fast_mcp_config(
125+
workspace_id
126+
)
127+
if not fast_mcp_config:
128+
McpClientManagerManager._mcp_client_managers[key] = None
129+
return None
130+
131+
async with McpClientManagerManager._lock:
132+
if key not in McpClientManagerManager._mcp_client_managers:
133+
try:
134+
mcp_client_manager = McpClientManager.from_config(fast_mcp_config)
135+
McpClientManagerManager._mcp_client_managers[key] = (
136+
mcp_client_manager
137+
)
138+
await mcp_client_manager.connect()
139+
except Exception:
140+
if key in McpClientManagerManager._mcp_client_managers:
141+
if (
142+
_mcp_client_manager
143+
:= McpClientManagerManager._mcp_client_managers[key]
144+
):
145+
await _mcp_client_manager.disconnect(force=True)
146+
del McpClientManagerManager._mcp_client_managers[key]
147+
raise
148+
return McpClientManagerManager._mcp_client_managers[key]
149+
150+
async def disconnect_all(self, force: bool = False) -> None:
151+
async with McpClientManagerManager._lock:
152+
for (
153+
mcp_client_manager
154+
) in McpClientManagerManager._mcp_client_managers.values():
155+
if mcp_client_manager:
156+
await mcp_client_manager.disconnect(force)

src/askui/chat/api/mcp_configs/seeds.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from fastmcp.mcp_config import RemoteMCPServer
1+
from fastmcp.mcp_config import RemoteMCPServer, StdioMCPServer
22

33
from askui.chat.api.dependencies import get_settings
44
from askui.chat.api.mcp_configs.models import McpConfig
@@ -10,12 +10,23 @@
1010
ASKUI_CHAT_MCP = McpConfig(
1111
id="mcpcnf_68ac2c4edc4b2f27faa5a252",
1212
created_at=now(),
13-
name="AskUI Chat MCP",
13+
name="askui_chat",
1414
mcp_server=RemoteMCPServer(
1515
url=f"http://{settings.host}:{settings.port}/mcp/sse",
1616
transport="sse",
1717
),
1818
)
1919

2020

21-
SEEDS = [ASKUI_CHAT_MCP]
21+
PLAYWRIGHT_MCP = McpConfig(
22+
id="mcpcnf_68ac2c4edc4b2f27faa5a251",
23+
created_at=now(),
24+
name="playwright",
25+
mcp_server=StdioMCPServer(
26+
command="npx",
27+
args=["@playwright/mcp@latest", "--isolated"],
28+
),
29+
)
30+
31+
32+
SEEDS = [ASKUI_CHAT_MCP, PLAYWRIGHT_MCP]

src/askui/chat/api/mcp_configs/service.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from pathlib import Path
22

3+
from fastmcp.mcp_config import MCPConfig
4+
35
from askui.chat.api.mcp_configs.models import (
46
McpConfig,
57
McpConfigCreateParams,
@@ -69,6 +71,18 @@ def retrieve(
6971
else:
7072
return mcp_config
7173

74+
def retrieve_fast_mcp_config(
75+
self, workspace_id: WorkspaceId | None
76+
) -> MCPConfig | None:
77+
list_response = self.list_(
78+
workspace_id=workspace_id,
79+
query=ListQuery(limit=LIST_LIMIT_MAX, order="asc"),
80+
)
81+
mcp_servers_dict = {
82+
mcp_config.name: mcp_config.mcp_server for mcp_config in list_response.data
83+
}
84+
return MCPConfig(mcpServers=mcp_servers_dict) if mcp_servers_dict else None
85+
7286
def _check_limit(self, workspace_id: WorkspaceId | None) -> None:
7387
limit = LIST_LIMIT_MAX
7488
list_result = self.list_(workspace_id, ListQuery(limit=limit))

src/askui/chat/api/runs/dependencies.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
from askui.chat.api.assistants.dependencies import AssistantServiceDep
66
from askui.chat.api.assistants.service import AssistantService
77
from askui.chat.api.dependencies import WorkspaceDirDep
8-
from askui.chat.api.mcp_configs.dependencies import McpConfigServiceDep
9-
from askui.chat.api.mcp_configs.service import McpConfigService
8+
from askui.chat.api.mcp_clients.dependencies import McpClientManagerManagerDep
9+
from askui.chat.api.mcp_clients.manager import McpClientManagerManager
1010
from askui.chat.api.messages.dependencies import MessageServiceDep, MessageTranslatorDep
1111
from askui.chat.api.messages.service import MessageService
1212
from askui.chat.api.messages.translator import MessageTranslator
@@ -17,15 +17,15 @@
1717
def get_runs_service(
1818
workspace_dir: Path = WorkspaceDirDep,
1919
assistant_service: AssistantService = AssistantServiceDep,
20-
mcp_config_service: McpConfigService = McpConfigServiceDep,
20+
mcp_client_manager_manager: McpClientManagerManager = McpClientManagerManagerDep,
2121
message_service: MessageService = MessageServiceDep,
2222
message_translator: MessageTranslator = MessageTranslatorDep,
2323
) -> RunService:
2424
"""Get RunService instance."""
2525
return RunService(
2626
base_dir=workspace_dir,
2727
assistant_service=assistant_service,
28-
mcp_config_service=mcp_config_service,
28+
mcp_client_manager_manager=mcp_client_manager_manager,
2929
message_service=message_service,
3030
message_translator=message_translator,
3131
)

src/askui/chat/api/runs/runner/runner.py

Lines changed: 17 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
import logging
22
from abc import ABC, abstractmethod
3-
from typing import TYPE_CHECKING, Literal, Sequence
3+
from typing import TYPE_CHECKING, Literal
44

55
import anthropic
66
import anyio
77
from anyio.abc import ObjectStream
88
from asyncer import asyncify, syncify
99
from fastmcp import Client
1010
from fastmcp.client.transports import MCPConfigTransport
11-
from fastmcp.mcp_config import MCPConfig
1211

1312
from askui.android_agent import AndroidVisionAgent
1413
from askui.chat.api.assistants.models import Assistant
@@ -18,8 +17,7 @@
1817
TESTING_AGENT,
1918
WEB_AGENT,
2019
)
21-
from askui.chat.api.mcp_configs.models import McpConfig
22-
from askui.chat.api.mcp_configs.service import McpConfigService
20+
from askui.chat.api.mcp_clients.manager import McpClientManagerManager
2321
from askui.chat.api.messages.models import MessageCreateParams
2422
from askui.chat.api.messages.service import MessageService
2523
from askui.chat.api.messages.translator import MessageTranslator
@@ -57,13 +55,6 @@
5755
logger = logging.getLogger(__name__)
5856

5957

60-
def build_fast_mcp_config(mcp_configs: Sequence[McpConfig]) -> MCPConfig:
61-
mcp_config_dict = {
62-
mcp_config.id: mcp_config.mcp_server for mcp_config in mcp_configs
63-
}
64-
return MCPConfig(mcpServers=mcp_config_dict)
65-
66-
6758
McpClient = Client[MCPConfigTransport]
6859

6960

@@ -85,7 +76,7 @@ def __init__(
8576
run: Run,
8677
message_service: MessageService,
8778
message_translator: MessageTranslator,
88-
mcp_config_service: McpConfigService,
79+
mcp_client_manager_manager: McpClientManagerManager,
8980
run_service: RunnerRunService,
9081
) -> None:
9182
self._workspace_id = workspace_id
@@ -94,18 +85,10 @@ def __init__(
9485
self._message_service = message_service
9586
self._message_translator = message_translator
9687
self._message_content_translator = message_translator.content_translator
97-
self._mcp_config_service = mcp_config_service
88+
self._mcp_client_manager_manager = mcp_client_manager_manager
9889
self._run_service = run_service
9990
self._agent_os = PynputAgentOs()
10091

101-
def _get_mcp_client(self) -> McpClient | None:
102-
mcp_configs = self._mcp_config_service.list_(
103-
workspace_id=self._workspace_id,
104-
query=ListQuery(limit=LIST_LIMIT_MAX, order="asc"),
105-
)
106-
fast_mcp_config = build_fast_mcp_config(mcp_configs.data)
107-
return Client(fast_mcp_config) if fast_mcp_config.mcpServers else None
108-
10992
def _retrieve(self) -> Run:
11093
return self._run_service.retrieve(
11194
thread_id=self._run.thread_id,
@@ -342,19 +325,24 @@ def _run_agent_inner() -> None:
342325

343326
await asyncify(_run_agent_inner)()
344327

328+
async def _get_mcp_client(self) -> McpClient | None:
329+
return await self._mcp_client_manager_manager.get_mcp_client_manager( # type: ignore
330+
self._workspace_id
331+
)
332+
345333
async def run(
346334
self,
347335
send_stream: ObjectStream[Events],
348336
) -> None:
349-
mcp_client = self._get_mcp_client()
350-
self._mark_run_as_started()
351-
await send_stream.send(
352-
RunEvent(
353-
data=self._run,
354-
event="thread.run.in_progress",
355-
)
356-
)
357337
try:
338+
mcp_client = await self._get_mcp_client()
339+
self._mark_run_as_started()
340+
await send_stream.send(
341+
RunEvent(
342+
data=self._run,
343+
event="thread.run.in_progress",
344+
)
345+
)
358346
if self._run.assistant_id == HUMAN_DEMONSTRATION_AGENT.id:
359347
await self._run_human_agent(send_stream)
360348
elif self._run.assistant_id == ANDROID_AGENT.id:

0 commit comments

Comments
 (0)