From 5cdb58d32703a567b09bd53aeed037489aa208c1 Mon Sep 17 00:00:00 2001 From: tinosattavily Date: Mon, 26 Jan 2026 13:27:59 -0500 Subject: [PATCH] feat(client): add session pooling for connection reuse - Add requests.Session() for HTTP connection pooling - Add close() method and context manager support for resource cleanup - Remove deprecated get_company_info() method - Simplify public methods with direct returns - Ensure response key normalization matches async client - Update test interceptor to support Session mocking Co-Authored-By: Claude Opus 4.5 --- tavily/tavily.py | 194 +++++++++++++------------------------ tests/request_intercept.py | 27 ++++++ 2 files changed, 97 insertions(+), 124 deletions(-) diff --git a/tavily/tavily.py b/tavily/tavily.py index abd22ba..194d209 100644 --- a/tavily/tavily.py +++ b/tavily/tavily.py @@ -3,7 +3,6 @@ import os import warnings from typing import Literal, Sequence, Optional, List, Union, Generator -from concurrent.futures import ThreadPoolExecutor, as_completed from .utils import get_max_items_from_list from .errors import UsageLimitExceededError, InvalidAPIKeyError, MissingAPIKeyError, BadRequestError, ForbiddenError, TimeoutError @@ -38,6 +37,21 @@ def __init__(self, api_key: Optional[str] = None, proxies: Optional[dict[str, st **({"X-Project-ID": tavily_project} if tavily_project else {}) } + self.session = requests.Session() + self.session.headers.update(self.headers) + if self.proxies: + self.session.proxies.update(self.proxies) + + def close(self): + """Close the session and release resources.""" + self.session.close() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + def _search(self, query: str, search_depth: Literal["basic", "advanced", "fast", "ultra-fast"] = None, @@ -89,9 +103,11 @@ def _search(self, data.update(kwargs) timeout = min(timeout, 120) + url = self.base_url + "/search" + payload = json.dumps(data) try: - response = requests.post(self.base_url + "/search", data=json.dumps(data), headers=self.headers, timeout=timeout, proxies=self.proxies) + response = self.session.post(url, data=payload, timeout=timeout) except requests.exceptions.Timeout: raise TimeoutError(timeout) @@ -106,13 +122,12 @@ def _search(self, if response.status_code == 429: raise UsageLimitExceededError(detail) - elif response.status_code in [403,432,433]: + elif response.status_code in [403, 432, 433]: raise ForbiddenError(detail) elif response.status_code == 401: raise InvalidAPIKeyError(detail) elif response.status_code == 400: raise BadRequestError(detail) - else: raise response.raise_for_status() @@ -160,13 +175,8 @@ def search(self, auto_parameters=auto_parameters, include_favicon=include_favicon, include_usage=include_usage, - **kwargs, - ) - - tavily_results = response_dict.get("results", []) - - response_dict["results"] = tavily_results - + **kwargs) + response_dict.setdefault("results", []) return response_dict def _extract(self, @@ -202,7 +212,7 @@ def _extract(self, data.update(kwargs) try: - response = requests.post(self.base_url + "/extract", data=json.dumps(data), headers=self.headers, timeout=timeout, proxies=self.proxies) + response = self.session.post(self.base_url + "/extract", data=json.dumps(data), timeout=timeout) except requests.exceptions.Timeout: raise TimeoutError(timeout) @@ -217,7 +227,7 @@ def _extract(self, if response.status_code == 429: raise UsageLimitExceededError(detail) - elif response.status_code in [403,432,433]: + elif response.status_code in [403, 432, 433]: raise ForbiddenError(detail) elif response.status_code == 401: raise InvalidAPIKeyError(detail) @@ -251,13 +261,8 @@ def extract(self, query=query, chunks_per_source=chunks_per_source, **kwargs) - - tavily_results = response_dict.get("results", []) - failed_results = response_dict.get("failed_results", []) - - response_dict["results"] = tavily_results - response_dict["failed_results"] = failed_results - + response_dict.setdefault("results", []) + response_dict.setdefault("failed_results", []) return response_dict def _crawl(self, @@ -310,8 +315,7 @@ def _crawl(self, data = {k: v for k, v in data.items() if v is not None} try: - response = requests.post( - self.base_url + "/crawl", data=json.dumps(data), headers=self.headers, timeout=timeout, proxies=self.proxies) + response = self.session.post(self.base_url + "/crawl", data=json.dumps(data), timeout=timeout) except requests.exceptions.Timeout: raise TimeoutError(timeout) @@ -326,7 +330,7 @@ def _crawl(self, if response.status_code == 429: raise UsageLimitExceededError(detail) - elif response.status_code in [403,432,433]: + elif response.status_code in [403, 432, 433]: raise ForbiddenError(detail) elif response.status_code == 401: raise InvalidAPIKeyError(detail) @@ -359,26 +363,24 @@ def crawl(self, Combined crawl method. include_favicon: If True, include the favicon in the crawl results. """ - response_dict = self._crawl(url, - max_depth=max_depth, - max_breadth=max_breadth, - limit=limit, - instructions=instructions, - select_paths=select_paths, - select_domains=select_domains, - exclude_paths=exclude_paths, - exclude_domains=exclude_domains, - allow_external=allow_external, - include_images=include_images, - extract_depth=extract_depth, - format=format, - timeout=timeout, - include_favicon=include_favicon, - include_usage=include_usage, - chunks_per_source=chunks_per_source, - **kwargs) - - return response_dict + return self._crawl(url, + max_depth=max_depth, + max_breadth=max_breadth, + limit=limit, + instructions=instructions, + select_paths=select_paths, + select_domains=select_domains, + exclude_paths=exclude_paths, + exclude_domains=exclude_domains, + allow_external=allow_external, + include_images=include_images, + extract_depth=extract_depth, + format=format, + timeout=timeout, + include_favicon=include_favicon, + include_usage=include_usage, + chunks_per_source=chunks_per_source, + **kwargs) def _map(self, url: str, @@ -421,8 +423,7 @@ def _map(self, data = {k: v for k, v in data.items() if v is not None} try: - response = requests.post( - self.base_url + "/map", data=json.dumps(data), headers=self.headers, timeout=timeout, proxies=self.proxies) + response = self.session.post(self.base_url + "/map", data=json.dumps(data), timeout=timeout) except requests.exceptions.Timeout: raise TimeoutError(timeout) @@ -437,7 +438,7 @@ def _map(self, if response.status_code == 429: raise UsageLimitExceededError(detail) - elif response.status_code in [403,432,433]: + elif response.status_code in [403, 432, 433]: raise ForbiddenError(detail) elif response.status_code == 401: raise InvalidAPIKeyError(detail) @@ -466,22 +467,20 @@ def map(self, Combined map method. """ - response_dict = self._map(url, - max_depth=max_depth, - max_breadth=max_breadth, - limit=limit, - instructions=instructions, - select_paths=select_paths, - select_domains=select_domains, - exclude_paths=exclude_paths, - exclude_domains=exclude_domains, - allow_external=allow_external, - include_images=include_images, - timeout=timeout, - include_usage=include_usage, - **kwargs) - - return response_dict + return self._map(url, + max_depth=max_depth, + max_breadth=max_breadth, + limit=limit, + instructions=instructions, + select_paths=select_paths, + select_domains=select_domains, + exclude_paths=exclude_paths, + exclude_domains=exclude_domains, + allow_external=allow_external, + include_images=include_images, + timeout=timeout, + include_usage=include_usage, + **kwargs) def get_search_context(self, query: str, @@ -563,47 +562,6 @@ def qna_search(self, ) return response_dict.get("answer", "") - def get_company_info(self, - query: str, - search_depth: Literal["basic", - "advanced", - "fast", - "ultra-fast"] = "advanced", - max_results: int = 5, - timeout: float = 60, - country: str = None, - ) -> Sequence[dict]: - """ Company information search method. Search depth is advanced by default to get the best answer. """ - warnings.warn("get_company_info is deprecated and will be removed in future versions.", - DeprecationWarning, stacklevel=2) - def _perform_search(topic): - return self._search(query, - search_depth=search_depth, - topic=topic, - max_results=max_results, - include_answer=False, - timeout=timeout, - country=country) - - with ThreadPoolExecutor() as executor: - # Initiate the search for each topic in parallel - future_to_topic = {executor.submit(_perform_search, topic): topic for topic in - ["news", "general", "finance"]} - - all_results = [] - - # Process the results as they become available - for future in as_completed(future_to_topic): - data = future.result() - if 'results' in data: - all_results.extend(data['results']) - - # Sort all the results by score in descending order and take the top 'max_results' items - sorted_results = sorted(all_results, key=lambda x: x['score'], reverse=True)[ - :max_results] - - return sorted_results - def _research(self, input: str, model: Literal["mini", "pro", "auto"] = None, @@ -631,12 +589,10 @@ def _research(self, if stream: try: - response = requests.post( + response = self.session.post( self.base_url + "/research", data=json.dumps(data), - headers=self.headers, timeout=timeout, - proxies=self.proxies, stream=True ) except requests.exceptions.Timeout: @@ -651,7 +607,7 @@ def _research(self, if response.status_code == 429: raise UsageLimitExceededError(detail) - elif response.status_code in [403,432,433]: + elif response.status_code in [403, 432, 433]: raise ForbiddenError(detail) elif response.status_code == 401: raise InvalidAPIKeyError(detail) @@ -671,12 +627,10 @@ def stream_generator() -> Generator[bytes, None, None]: return stream_generator() else: try: - response = requests.post( + response = self.session.post( self.base_url + "/research", data=json.dumps(data), - headers=self.headers, - timeout=timeout, - proxies=self.proxies + timeout=timeout ) except requests.exceptions.Timeout: raise TimeoutError(timeout) @@ -692,7 +646,7 @@ def stream_generator() -> Generator[bytes, None, None]: if response.status_code == 429: raise UsageLimitExceededError(detail) - elif response.status_code in [403,432,433]: + elif response.status_code in [403, 432, 433]: raise ForbiddenError(detail) elif response.status_code == 401: raise InvalidAPIKeyError(detail) @@ -726,8 +680,7 @@ def research(self, dict: Response containing request_id, created_at, status, input, and model. """ - - response_dict = self._research( + return self._research( input=input, model=model, output_schema=output_schema, @@ -737,8 +690,6 @@ def research(self, **kwargs ) - return response_dict - def get_research(self, request_id: str ) -> dict: @@ -752,17 +703,12 @@ def get_research(self, dict: Research response containing request_id, created_at, completed_at, status, content, and sources. """ try: - response = requests.get( - self.base_url + f"/research/{request_id}", - headers=self.headers, - proxies=self.proxies, - ) + response = self.session.get(self.base_url + f"/research/{request_id}") except Exception as e: raise Exception(f"Error getting research: {e}") if response.status_code in (200, 202): - data = response.json() - return data + return response.json() else: detail = "" try: @@ -772,7 +718,7 @@ def get_research(self, if response.status_code == 429: raise UsageLimitExceededError(detail) - elif response.status_code in [403,432,433]: + elif response.status_code in [403, 432, 433]: raise ForbiddenError(detail) elif response.status_code == 401: raise InvalidAPIKeyError(detail) diff --git a/tests/request_intercept.py b/tests/request_intercept.py index 31ba81a..482d84d 100644 --- a/tests/request_intercept.py +++ b/tests/request_intercept.py @@ -46,6 +46,29 @@ def iter_content(self): def close(self): pass +class MockSession: + """Mock requests.Session for testing.""" + def __init__(self, interceptor): + self._interceptor = interceptor + self.headers = {} + self.proxies = {} + + def post(self, url, data=None, headers=None, timeout=None, proxies=None, stream=False): + # Merge session headers with request headers (request headers take precedence) + merged_headers = {**self.headers, **(headers or {})} + merged_proxies = {**self.proxies, **(proxies or {})} if proxies else self.proxies + return self._interceptor.post(url, data, merged_headers, timeout, merged_proxies, stream) + + def get(self, url, headers=None, timeout=None, proxies=None): + # Merge session headers with request headers (request headers take precedence) + merged_headers = {**self.headers, **(headers or {})} + merged_proxies = {**self.proxies, **(proxies or {})} if proxies else self.proxies + return self._interceptor.get(url, merged_headers, timeout, merged_proxies) + + def close(self): + pass + + class Interceptor: def __init__(self): self._request = None @@ -55,6 +78,10 @@ class Exceptions: Timeout = TimeoutError self.exceptions = Exceptions() + def Session(self): + """Return a mock Session object.""" + return MockSession(self) + def set_response(self, status_code, headers={}, body=None, json=None): if json is not None: body = dumps(json)