Skip to content

Commit b17eaf2

Browse files
Merge pull request #193 from askui/feat/select-model
feat(chat): select model of run with request params
2 parents 5f00d12 + f3c06e7 commit b17eaf2

File tree

3 files changed

+11
-4
lines changed

3 files changed

+11
-4
lines changed

src/askui/chat/api/runs/models.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ class RunCreate(BaseModel):
4040

4141
stream: bool = False
4242
assistant_id: AssistantId
43+
model: str | None = None
4344

4445

4546
class RunStart(BaseModel):
@@ -146,7 +147,7 @@ def create(
146147
thread_id=thread_id,
147148
created_at=now(),
148149
expires_at=now() + timedelta(minutes=10),
149-
**params.model_dump(exclude={"stream"}),
150+
**params.model_dump(exclude={"model", "stream"}),
150151
)
151152

152153
@computed_field # type: ignore[prop-decorator]

src/askui/chat/api/runs/runner/runner.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import logging
33
from abc import ABC, abstractmethod
44
from datetime import datetime, timezone
5-
from typing import Any
65

76
from anthropic.types.beta import BetaCacheControlEphemeralParam, BetaTextBlockParam
87
from anyio.abc import ObjectStream
@@ -34,7 +33,6 @@
3433
)
3534
from askui.chat.api.settings import Settings
3635
from askui.custom_agent import CustomAgent
37-
from askui.models.models import ModelName
3836
from askui.models.shared.agent_message_param import MessageParam
3937
from askui.models.shared.agent_on_message_cb import OnMessageCbParam
4038
from askui.models.shared.settings import ActSettings, MessageSettings
@@ -67,6 +65,7 @@ def __init__(
6765
mcp_client_manager_manager: McpClientManagerManager,
6866
run_service: RunnerRunService,
6967
settings: Settings,
68+
model: str | None = None,
7069
) -> None:
7170
self._run_id = run_id
7271
self._workspace_id = workspace_id
@@ -76,6 +75,7 @@ def __init__(
7675
self._mcp_client_manager_manager = mcp_client_manager_manager
7776
self._run_service = run_service
7877
self._settings = settings
78+
self._model: str | None = model
7979

8080
def _retrieve_run(self) -> Run:
8181
return self._run_service.retrieve(
@@ -164,7 +164,7 @@ def _run_agent_inner() -> None:
164164
)
165165
betas = tools.retrieve_tool_beta_flags()
166166
system = self._build_system()
167-
model = self._settings.model
167+
model = self._get_model()
168168
messages = syncify(self._chat_history_manager.retrieve_message_params)(
169169
workspace_id=self._workspace_id,
170170
thread_id=self._thread_id,
@@ -269,3 +269,8 @@ async def run(
269269

270270
def _should_abort(self, run: Run) -> bool:
271271
return run.status in ("cancelled", "cancelling", "expired")
272+
273+
def _get_model(self) -> str:
274+
if self._model is not None:
275+
return self._model
276+
return self._settings.model

src/askui/chat/api/runs/service.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ async def create(
9292
mcp_client_manager_manager=self._mcp_client_manager_manager,
9393
run_service=self,
9494
settings=self._settings,
95+
model=params.model,
9596
)
9697

9798
async def event_generator() -> AsyncGenerator[Event, None]:

0 commit comments

Comments
 (0)