diff --git a/.gitignore b/.gitignore index 80b089b..653f9e0 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,4 @@ __pycache__/ .env venv/ test_local.py +test_local_suite.py diff --git a/setup.py b/setup.py index fe2f97d..9f65b52 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setup( name='tavily-python', - version='0.7.20', + version='0.7.21', url='https://github.com/tavily-ai/tavily-python', author='Tavily AI', author_email='support@tavily.com', diff --git a/tavily/async_tavily.py b/tavily/async_tavily.py index 8969047..80dce2b 100644 --- a/tavily/async_tavily.py +++ b/tavily/async_tavily.py @@ -44,8 +44,9 @@ def __init__(self, api_key: Optional[str] = None, tavily_project = project_id or os.getenv("TAVILY_PROJECT") self._api_base_url = api_base_url or "https://api.tavily.com" - - self._client_creator = lambda: httpx.AsyncClient( + + # Create a persistent client for connection pooling + self._client = httpx.AsyncClient( headers={ "Content-Type": "application/json", "Authorization": f"Bearer {api_key}", @@ -57,6 +58,16 @@ def __init__(self, api_key: Optional[str] = None, ) self._company_info_tags = company_info_tags + async def close(self): + """Close the client and release connection pool resources.""" + await self._client.aclose() + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.close() + async def _search( self, query: str, @@ -109,11 +120,10 @@ async def _search( timeout = min(timeout, 120) - async with self._client_creator() as client: - try: - response = await client.post("/search", content=json.dumps(data), timeout=timeout) - except httpx.TimeoutException: - raise TimeoutError(timeout) + try: + response = await self._client.post("/search", content=json.dumps(data), timeout=timeout) + except httpx.TimeoutException: + raise TimeoutError(timeout) if response.status_code == 200: return response.json() @@ -221,11 +231,10 @@ async def _extract( if kwargs: data.update(kwargs) - async with self._client_creator() as client: - try: - response = await client.post("/extract", content=json.dumps(data), timeout=timeout) - except httpx.TimeoutException: - raise TimeoutError(timeout) + try: + response = await self._client.post("/extract", content=json.dumps(data), timeout=timeout) + except httpx.TimeoutException: + raise TimeoutError(timeout) if response.status_code == 200: return response.json() @@ -283,7 +292,7 @@ async def extract(self, response_dict["failed_results"] = failed_results return response_dict - + async def _crawl(self, url: str, max_depth: int = None, @@ -332,31 +341,30 @@ async def _crawl(self, data = {k: v for k, v in data.items() if v is not None} - async with self._client_creator() as client: + try: + response = await self._client.post("/crawl", content=json.dumps(data), timeout=timeout) + except httpx.TimeoutException: + raise TimeoutError(timeout) + + if response.status_code == 200: + return response.json() + else: + detail = "" try: - response = await client.post("/crawl", content=json.dumps(data), timeout=timeout) - except httpx.TimeoutException: - raise TimeoutError(timeout) + detail = response.json().get("detail", {}).get("error", None) + except Exception: + pass - if response.status_code == 200: - return response.json() + if response.status_code == 429: + raise UsageLimitExceededError(detail) + elif response.status_code in [403,432,433]: + raise ForbiddenError(detail) + elif response.status_code == 401: + raise InvalidAPIKeyError(detail) + elif response.status_code == 400: + raise BadRequestError(detail) else: - detail = "" - try: - detail = response.json().get("detail", {}).get("error", None) - except Exception: - pass - - if response.status_code == 429: - raise UsageLimitExceededError(detail) - elif response.status_code in [403,432,433]: - raise ForbiddenError(detail) - elif response.status_code == 401: - raise InvalidAPIKeyError(detail) - elif response.status_code == 400: - raise BadRequestError(detail) - else: - raise response.raise_for_status() + raise response.raise_for_status() async def crawl(self, url: str, @@ -380,7 +388,7 @@ async def crawl(self, ) -> dict: """ Combined crawl method. - + """ response_dict = await self._crawl(url, max_depth=max_depth, @@ -402,7 +410,7 @@ async def crawl(self, **kwargs) return response_dict - + async def _map(self, url: str, max_depth: int = None, @@ -443,31 +451,30 @@ async def _map(self, data = {k: v for k, v in data.items() if v is not None} - async with self._client_creator() as client: + try: + response = await self._client.post("/map", content=json.dumps(data), timeout=timeout) + except httpx.TimeoutException: + raise TimeoutError(timeout) + + if response.status_code == 200: + return response.json() + else: + detail = "" try: - response = await client.post("/map", content=json.dumps(data), timeout=timeout) - except httpx.TimeoutException: - raise TimeoutError(timeout) + detail = response.json().get("detail", {}).get("error", None) + except Exception: + pass - if response.status_code == 200: - return response.json() + if response.status_code == 429: + raise UsageLimitExceededError(detail) + elif response.status_code in [403,432,433]: + raise ForbiddenError(detail) + elif response.status_code == 401: + raise InvalidAPIKeyError(detail) + elif response.status_code == 400: + raise BadRequestError(detail) else: - detail = "" - try: - detail = response.json().get("detail", {}).get("error", None) - except Exception: - pass - - if response.status_code == 429: - raise UsageLimitExceededError(detail) - elif response.status_code in [403,432,433]: - raise ForbiddenError(detail) - elif response.status_code == 401: - raise InvalidAPIKeyError(detail) - elif response.status_code == 400: - raise BadRequestError(detail) - else: - raise response.raise_for_status() + raise response.raise_for_status() async def map(self, url: str, @@ -639,68 +646,66 @@ def _research(self, if stream: async def stream_generator() -> AsyncGenerator[bytes, None]: try: - async with self._client_creator() as client: - async with client.stream( - "POST", - "/research", - content=json.dumps(data), - timeout=timeout - ) as response: - if response.status_code != 200: - try: - error_text = await response.aread() - error_text = error_text.decode('utf-8') if isinstance(error_text, bytes) else error_text - except Exception: - error_text = "Unknown error" - - if response.status_code == 429: - raise UsageLimitExceededError(error_text) - elif response.status_code in [403,432,433]: - raise ForbiddenError(error_text) - elif response.status_code == 401: - raise InvalidAPIKeyError(error_text) - elif response.status_code == 400: - raise BadRequestError(error_text) - else: - raise Exception(f"Error {response.status_code}: {error_text}") - - async for chunk in response.aiter_bytes(): - if chunk: - yield chunk + async with self._client.stream( + "POST", + "/research", + content=json.dumps(data), + timeout=timeout + ) as response: + if response.status_code != 200: + try: + error_text = await response.aread() + error_text = error_text.decode('utf-8') if isinstance(error_text, bytes) else error_text + except Exception: + error_text = "Unknown error" + + if response.status_code == 429: + raise UsageLimitExceededError(error_text) + elif response.status_code in [403,432,433]: + raise ForbiddenError(error_text) + elif response.status_code == 401: + raise InvalidAPIKeyError(error_text) + elif response.status_code == 400: + raise BadRequestError(error_text) + else: + raise Exception(f"Error {response.status_code}: {error_text}") + + async for chunk in response.aiter_bytes(): + if chunk: + yield chunk except httpx.TimeoutException: raise TimeoutError(timeout) except Exception as e: raise Exception(f"Error during research stream: {str(e)}") - + return stream_generator() else: async def _make_request(): - async with self._client_creator() as client: - try: - response = await client.post("/research", content=json.dumps(data), timeout=timeout) - except httpx.TimeoutException: - raise TimeoutError(timeout) + try: + response = await self._client.post("/research", content=json.dumps(data), timeout=timeout) + except httpx.TimeoutException: + raise TimeoutError(timeout) - if response.status_code == 200: - return response.json() + if response.status_code == 200: + return response.json() + else: + detail = "" + try: + detail = response.json().get("detail", {}).get("error", None) + except Exception: + pass + + if response.status_code == 429: + raise UsageLimitExceededError(detail) + elif response.status_code in [403,432,433]: + raise ForbiddenError(detail) + elif response.status_code == 401: + raise InvalidAPIKeyError(detail) + elif response.status_code == 400: + raise BadRequestError(detail) else: - detail = "" - try: - detail = response.json().get("detail", {}).get("error", None) - except Exception: - pass - - if response.status_code == 429: - raise UsageLimitExceededError(detail) - elif response.status_code in [403,432,433]: - raise ForbiddenError(detail) - elif response.status_code == 401: - raise InvalidAPIKeyError(detail) - elif response.status_code == 400: - raise BadRequestError(detail) - else: - raise response.raise_for_status() - + raise response.raise_for_status() + return _make_request() async def research(self, @@ -714,16 +719,16 @@ async def research(self, ) -> Union[dict, AsyncGenerator[bytes, None]]: """ Research method to create a research task. - + Args: input: The research task description (required). model: Research depth - must be either 'mini', 'pro', or 'auto'. output_schema: Schema for the 'structured_output' response format (JSON Schema dict). stream: Whether to stream the research task. citation_format: Citation format - must be either 'numbered', 'mla', 'apa', or 'chicago'. - timeout: Optional HTTP request timeout in seconds. + timeout: Optional HTTP request timeout in seconds. **kwargs: Additional custom arguments. - + Returns: When stream=False: dict - the response dictionary. When stream=True: AsyncGenerator[bytes, None] - iterate over this to get streaming chunks. @@ -740,43 +745,42 @@ async def research(self, if stream: return result # Don't await the result, it's an AsyncGenerator that will be lazy and only execute when iterated over with async for else: - return await result + return await result async def get_research(self, request_id: str ) -> dict: """ Get research results by request_id. - + Args: request_id: The research request ID. - + Returns: dict: Research response containing request_id, created_at, completed_at, status, content, and sources. """ - async with self._client_creator() as client: + try: + response = await self._client.get(f"/research/{request_id}") + except Exception as e: + raise Exception(f"Error getting research: {e}") + + if response.status_code in (200, 202): + data = response.json() + return data + else: + detail = "" try: - response = await client.get(f"/research/{request_id}") - except Exception as e: - raise Exception(f"Error getting research: {e}") + detail = response.json().get("detail", {}).get("error", None) + except Exception: + pass - if response.status_code in (200, 202): - data = response.json() - return data + if response.status_code == 429: + raise UsageLimitExceededError(detail) + elif response.status_code in [403,432,433]: + raise ForbiddenError(detail) + elif response.status_code == 401: + raise InvalidAPIKeyError(detail) + elif response.status_code == 400: + raise BadRequestError(detail) else: - detail = "" - try: - detail = response.json().get("detail", {}).get("error", None) - except Exception: - pass - - if response.status_code == 429: - raise UsageLimitExceededError(detail) - elif response.status_code in [403,432,433]: - raise ForbiddenError(detail) - elif response.status_code == 401: - raise InvalidAPIKeyError(detail) - elif response.status_code == 400: - raise BadRequestError(detail) - else: - raise response.raise_for_status() + raise response.raise_for_status() diff --git a/tests/test_errors.py b/tests/test_errors.py index 48a66ef..71276d5 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -11,7 +11,10 @@ def set_api_key(): old_key = os.getenv("TAVILY_API_KEY") os.environ["TAVILY_API_KEY"] = "test_api_key" yield - os.environ["TAVILY_API_KEY"] = old_key + if old_key is not None: + os.environ["TAVILY_API_KEY"] = old_key + elif "TAVILY_API_KEY" in os.environ: + del os.environ["TAVILY_API_KEY"] @pytest.fixture def clear_api_key(): diff --git a/tests/test_session_pooling.py b/tests/test_session_pooling.py new file mode 100644 index 0000000..233a893 --- /dev/null +++ b/tests/test_session_pooling.py @@ -0,0 +1,152 @@ +""" +Tests for session pooling functionality in both sync and async clients. +""" +import asyncio +import pytest + +import tavily.tavily as sync_tavily +import tavily.async_tavily as async_tavily +from tests.request_intercept import intercept_requests, clear_interceptor + + +dummy_response = { + "query": "test", + "results": [{"title": "Test", "url": "https://test.com", "content": "Test content", "score": 0.99}], + "response_time": 0.1 +} + + +# ============================================================================= +# SYNC CLIENT TESTS +# ============================================================================= + +class TestSyncSessionPooling: + """Test sync client session pooling and lifecycle.""" + + @pytest.fixture + def interceptor(self): + yield intercept_requests(sync_tavily) + clear_interceptor(sync_tavily) + + def test_context_manager(self, interceptor): + """Test that sync client works with context manager.""" + interceptor.set_response(200, json=dummy_response) + + with sync_tavily.TavilyClient(api_key="tvly-test") as client: + response = client.search("test query") + assert response["results"][0]["title"] == "Test" + + def test_close_method(self, interceptor): + """Test that close() method works without error.""" + interceptor.set_response(200, json=dummy_response) + + client = sync_tavily.TavilyClient(api_key="tvly-test") + response = client.search("test query") + assert response["results"][0]["title"] == "Test" + + # close() should not raise + client.close() + + def test_multiple_sequential_requests(self, interceptor): + """Test that multiple requests work with same client (connection reuse).""" + interceptor.set_response(200, json=dummy_response) + + client = sync_tavily.TavilyClient(api_key="tvly-test") + + # Make multiple requests with same client + for i in range(3): + response = client.search(f"test query {i}") + assert response["results"][0]["title"] == "Test" + + client.close() + + def test_context_manager_multiple_requests(self, interceptor): + """Test multiple requests within context manager.""" + interceptor.set_response(200, json=dummy_response) + + with sync_tavily.TavilyClient(api_key="tvly-test") as client: + for i in range(3): + response = client.search(f"test query {i}") + assert response["results"][0]["title"] == "Test" + + +# ============================================================================= +# ASYNC CLIENT TESTS +# ============================================================================= + +class TestAsyncSessionPooling: + """Test async client session pooling and lifecycle.""" + + @pytest.fixture + def interceptor(self): + yield intercept_requests(async_tavily) + clear_interceptor(async_tavily) + + def test_context_manager(self, interceptor): + """Test that async client works with async context manager.""" + interceptor.set_response(200, json=dummy_response) + + async def run(): + async with async_tavily.AsyncTavilyClient(api_key="tvly-test") as client: + response = await client.search("test query") + assert response["results"][0]["title"] == "Test" + + asyncio.run(run()) + + def test_close_method(self, interceptor): + """Test that close() method works without error.""" + interceptor.set_response(200, json=dummy_response) + + async def run(): + client = async_tavily.AsyncTavilyClient(api_key="tvly-test") + response = await client.search("test query") + assert response["results"][0]["title"] == "Test" + + # close() should not raise + await client.close() + + asyncio.run(run()) + + def test_multiple_sequential_requests(self, interceptor): + """Test that multiple requests work with same client (connection reuse).""" + interceptor.set_response(200, json=dummy_response) + + async def run(): + client = async_tavily.AsyncTavilyClient(api_key="tvly-test") + + # Make multiple requests with same client + for i in range(3): + response = await client.search(f"test query {i}") + assert response["results"][0]["title"] == "Test" + + await client.close() + + asyncio.run(run()) + + def test_context_manager_multiple_requests(self, interceptor): + """Test multiple requests within async context manager.""" + interceptor.set_response(200, json=dummy_response) + + async def run(): + async with async_tavily.AsyncTavilyClient(api_key="tvly-test") as client: + for i in range(3): + response = await client.search(f"test query {i}") + assert response["results"][0]["title"] == "Test" + + asyncio.run(run()) + + def test_concurrent_requests_same_client(self, interceptor): + """Test concurrent requests using same client.""" + interceptor.set_response(200, json=dummy_response) + + async def run(): + async with async_tavily.AsyncTavilyClient(api_key="tvly-test") as client: + # Run multiple searches concurrently + tasks = [client.search(f"query {i}") for i in range(3)] + results = await asyncio.gather(*tasks) + + assert len(results) == 3 + for result in results: + assert result["results"][0]["title"] == "Test" + + asyncio.run(run())