diff --git a/src/framex/adapter/base.py b/src/framex/adapter/base.py index d9c6391..a6cb721 100644 --- a/src/framex/adapter/base.py +++ b/src/framex/adapter/base.py @@ -1,6 +1,6 @@ import abc import inspect -from collections.abc import Callable +from collections.abc import AsyncIterable, Callable from enum import StrEnum from typing import Any, cast @@ -39,7 +39,7 @@ async def call_func(self, api: PluginApi, **kwargs: Any) -> Any: stream = await self._resolve_stream(api, kwargs) if stream: gen = self._stream_call(func, **kwargs) - if inspect.isawaitable(gen): + if not isinstance(gen, AsyncIterable) and inspect.isawaitable(gen): gen = await gen return [chunk async for chunk in gen] return await self._invoke(func, **kwargs) diff --git a/src/framex/driver/ingress.py b/src/framex/driver/ingress.py index cd52307..20756c1 100644 --- a/src/framex/driver/ingress.py +++ b/src/framex/driver/ingress.py @@ -142,7 +142,7 @@ async def stream_with_error() -> AsyncIterable[Any]: if auth_keys is not None: _verify_api_key(framex_request, framex_request.headers.get("Authorization")) gen = adapter._stream_call(c_handle, **request_kwargs) - if inspect.isawaitable(gen): + if not isinstance(gen, AsyncIterable) and inspect.isawaitable(gen): gen = await gen chunks = gen if isinstance(gen, AsyncIterable) else iterate_in_threadpool(iter(gen)) async for chunk in chunks: diff --git a/tests/adapter/test_local_adapter.py b/tests/adapter/test_local_adapter.py index 4685866..8ca96ae 100644 --- a/tests/adapter/test_local_adapter.py +++ b/tests/adapter/test_local_adapter.py @@ -9,6 +9,24 @@ from framex.plugin.model import PluginApi +class AwaitableAsyncStream: + def __init__(self, chunks): + self._chunks = iter(chunks) + + def __await__(self): + raise RuntimeError("stream response should not be awaited") + yield # pragma: no cover + + def __aiter__(self): + return self + + async def __anext__(self): + try: + return next(self._chunks) + except StopIteration as exc: + raise StopAsyncIteration from exc + + class TestLocalAdapter: """Tests for LocalAdapter class.""" @@ -247,6 +265,18 @@ def stream_func(**kwargs): assert "value1" in values assert "value2" in values + async def test_call_func_stream_does_not_await_async_iterable_response(self): + adapter = LocalAdapter() + api = PluginApi(deployment_name="demo", func_name="stream", stream=True) + + with ( + patch.object(adapter, "get_handle_func", return_value=MagicMock()), + patch.object(adapter, "_stream_call", return_value=AwaitableAsyncStream(["chunk"])), + ): + result = await adapter.call_func(api) + + assert result == ["chunk"] + async def test_stream_call_with_async_generator(self): """Test _stream_call works with async generators.""" adapter = LocalAdapter() diff --git a/tests/driver/test_ingress.py b/tests/driver/test_ingress.py index 5b15924..bb892f1 100644 --- a/tests/driver/test_ingress.py +++ b/tests/driver/test_ingress.py @@ -188,6 +188,32 @@ async def collect_stream_response(endpoint): return [chunk async for chunk in response.body_iterator] +class AwaitableAsyncStream: + def __init__(self, chunks): + self._chunks = iter(chunks) + + def __await__(self): + raise RuntimeError("stream response should not be awaited") + yield # pragma: no cover + + def __aiter__(self): + return self + + async def __anext__(self): + try: + return next(self._chunks) + except StopIteration as exc: + raise StopAsyncIteration from exc + + +async def test_register_route_stream_does_not_await_async_iterable_response(ingress, mock_app): + adapter = Mock() + adapter._stream_call.return_value = AwaitableAsyncStream(["chunk"]) + endpoint = register_stream_endpoint(ingress, mock_app, adapter) + + assert await collect_stream_response(endpoint) == ["chunk"] + + async def test_register_route_stream_converts_iteration_error_to_sse_event(ingress, mock_app): async def failing_stream(): yield "first"