Skip to content

Commit dbddb3a

Browse files
Merge pull request #124 from askui/feat/create-thread-and-run-endpoint
feat/create thread and run endpoint
2 parents c518fc5 + c664a27 commit dbddb3a

File tree

4 files changed

+356
-10
lines changed

4 files changed

+356
-10
lines changed

src/askui/chat/api/runs/models.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from datetime import datetime, timedelta, timezone
22
from typing import Literal
33

4-
from pydantic import BaseModel, Field, computed_field
4+
from pydantic import BaseModel, computed_field
55

66
from askui.chat.api.models import AssistantId, RunId, ThreadId
7+
from askui.chat.api.threads.models import ThreadCreateParams
78
from askui.utils.api_utils import Resource
89
from askui.utils.datetime_utils import UnixDatetime, now
910
from askui.utils.id_utils import generate_time_ordered_id
@@ -38,6 +39,10 @@ class RunCreateParams(RunBase):
3839
stream: bool = False
3940

4041

42+
class ThreadAndRunCreateParams(RunCreateParams):
43+
thread: ThreadCreateParams
44+
45+
4146
class Run(RunBase, Resource):
4247
"""A run execution within a thread."""
4348

src/askui/chat/api/runs/router.py

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,13 @@
1313
from askui.utils.api_utils import ListQuery, ListResponse
1414

1515
from .dependencies import RunServiceDep
16-
from .models import Run
16+
from .models import Run, ThreadAndRunCreateParams
1717
from .service import RunService
1818

19-
router = APIRouter(prefix="/threads/{thread_id}/runs", tags=["runs"])
19+
router = APIRouter(tags=["runs"])
2020

2121

22-
@router.post("")
22+
@router.post("/threads/{thread_id}/runs")
2323
async def create_run(
2424
thread_id: Annotated[ThreadId, Path(...)],
2525
params: RunCreateParams,
@@ -53,7 +53,40 @@ async def _run_async_generator() -> None:
5353
return JSONResponse(status_code=status.HTTP_201_CREATED, content=run.model_dump())
5454

5555

56-
@router.get("/{run_id}")
56+
@router.post("/runs")
57+
async def create_thread_and_run(
58+
params: ThreadAndRunCreateParams,
59+
background_tasks: BackgroundTasks,
60+
thread_facade: ThreadFacade = ThreadFacadeDep,
61+
) -> Response:
62+
stream = params.stream
63+
run, async_generator = await thread_facade.create_thread_and_run(params)
64+
if stream:
65+
66+
async def sse_event_stream() -> AsyncGenerator[str, None]:
67+
async for event in async_generator:
68+
data = (
69+
event.data.model_dump_json()
70+
if isinstance(event.data, BaseModel)
71+
else event.data
72+
)
73+
yield f"event: {event.event}\ndata: {data}\n\n"
74+
75+
return StreamingResponse(
76+
status_code=status.HTTP_201_CREATED,
77+
content=sse_event_stream(),
78+
media_type="text/event-stream",
79+
)
80+
81+
async def _run_async_generator() -> None:
82+
async for _ in async_generator:
83+
pass
84+
85+
background_tasks.add_task(_run_async_generator)
86+
return JSONResponse(status_code=status.HTTP_201_CREATED, content=run.model_dump())
87+
88+
89+
@router.get("/threads/{thread_id}/runs/{run_id}")
5790
def retrieve_run(
5891
thread_id: Annotated[ThreadId, Path(...)],
5992
run_id: Annotated[RunId, Path(...)],
@@ -62,7 +95,7 @@ def retrieve_run(
6295
return run_service.retrieve(thread_id, run_id)
6396

6497

65-
@router.get("")
98+
@router.get("/threads/{thread_id}/runs")
6699
def list_runs(
67100
thread_id: Annotated[ThreadId, Path(...)],
68101
query: ListQuery = ListQueryDep,
@@ -71,7 +104,7 @@ def list_runs(
71104
return thread_facade.list_runs(thread_id, query=query)
72105

73106

74-
@router.post("/{run_id}/cancel")
107+
@router.post("/threads/{thread_id}/runs/{run_id}/cancel")
75108
def cancel_run(
76109
thread_id: Annotated[ThreadId, Path(...)],
77110
run_id: Annotated[RunId, Path(...)],

src/askui/chat/api/threads/facade.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from askui.chat.api.messages.models import Message, MessageCreateParams
44
from askui.chat.api.messages.service import MessageService
55
from askui.chat.api.models import ThreadId
6-
from askui.chat.api.runs.models import Run, RunCreateParams
6+
from askui.chat.api.runs.models import Run, RunCreateParams, ThreadAndRunCreateParams
77
from askui.chat.api.runs.runner.events.events import Events
88
from askui.chat.api.runs.service import RunService
99
from askui.chat.api.threads.service import ThreadService
@@ -43,6 +43,13 @@ async def create_run(
4343
self._ensure_thread_exists(thread_id)
4444
return await self._run_service.create(thread_id, params)
4545

46+
async def create_thread_and_run(
47+
self, params: ThreadAndRunCreateParams
48+
) -> tuple[Run, AsyncGenerator[Events, None]]:
49+
"""Create a thread and a run, ensuring the thread exists first."""
50+
thread = self._thread_service.create(params.thread)
51+
return await self._run_service.create(thread.id, params)
52+
4653
def list_messages(
4754
self, thread_id: ThreadId, query: ListQuery
4855
) -> ListResponse[Message]:

0 commit comments

Comments
 (0)