diff --git a/tavily/tavily.py b/tavily/tavily.py index b8e6674..d354b29 100644 --- a/tavily/tavily.py +++ b/tavily/tavily.py @@ -4,14 +4,30 @@ import warnings from typing import Literal, Sequence, Optional, List, Union, Generator from .utils import get_max_items_from_list -from .errors import UsageLimitExceededError, InvalidAPIKeyError, MissingAPIKeyError, BadRequestError, ForbiddenError, TimeoutError +from .errors import ( + UsageLimitExceededError, + InvalidAPIKeyError, + MissingAPIKeyError, + BadRequestError, + ForbiddenError, + TimeoutError, +) + class TavilyClient: """ Tavily API client class. """ - def __init__(self, api_key: Optional[str] = None, proxies: Optional[dict[str, str]] = None, api_base_url: Optional[str] = None, client_source: Optional[str] = None, project_id: Optional[str] = None): + def __init__( + self, + api_key: Optional[str] = None, + proxies: Optional[dict[str, str]] = None, + api_base_url: Optional[str] = None, + client_source: Optional[str] = None, + project_id: Optional[str] = None, + session: Optional[requests.Session] = None, + ): if api_key is None: api_key = os.getenv("TAVILY_API_KEY") @@ -20,31 +36,55 @@ def __init__(self, api_key: Optional[str] = None, proxies: Optional[dict[str, st resolved_proxies = { "http": proxies.get("http") if proxies else os.getenv("TAVILY_HTTP_PROXY"), - "https": proxies.get("https") if proxies else os.getenv("TAVILY_HTTPS_PROXY"), + "https": proxies.get("https") + if proxies + else os.getenv("TAVILY_HTTPS_PROXY"), } resolved_proxies = {k: v for k, v in resolved_proxies.items() if v} or None tavily_project = project_id or os.getenv("TAVILY_PROJECT") - + self.base_url = api_base_url or "https://api.tavily.com" self.api_key = api_key self.proxies = resolved_proxies - - self.headers = { + + # Create or use provided session + # Track whether session is external to avoid closing it on exit + self._external_session = session is not None + self.session = session if session is not None else requests.Session() + + # Build default headers + default_headers = { "Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}", "X-Client-Source": client_source or "tavily-python", - **({"X-Project-ID": tavily_project} if tavily_project else {}) + **({"X-Project-ID": tavily_project} if tavily_project else {}), } - self.session = requests.Session() - self.session.headers.update(self.headers) + # Store Tavily-specific headers for reference + self.headers = default_headers + + # Merge headers: preserve existing session headers, add defaults for missing keys + # This allows custom sessions to override Authorization and other headers + for key, value in default_headers.items(): + if key not in self.session.headers: + self.session.headers[key] = value + + # Merge proxies: preserve existing session proxies, add defaults for missing protocols + # This allows custom sessions to define their own proxy configuration if self.proxies: - self.session.proxies.update(self.proxies) + for protocol, proxy_url in self.proxies.items(): + if protocol not in self.session.proxies: + self.session.proxies[protocol] = proxy_url def close(self): - """Close the session and release resources.""" - self.session.close() + """Close the session and release resources. + + Only closes the session if it was created internally. + External sessions provided by the user are not closed. + """ + if not self._external_session: + self.session.close() def __enter__(self): return self @@ -52,28 +92,29 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): self.close() - def _search(self, - query: str, - search_depth: Literal["basic", "advanced", "fast", "ultra-fast"] = None, - topic: Literal["general", "news", "finance"] = None, - time_range: Literal["day", "week", "month", "year"] = None, - start_date: str = None, - end_date: str = None, - days: int = None, - max_results: int = None, - include_domains: Sequence[str] = None, - exclude_domains: Sequence[str] = None, - include_answer: Union[bool, Literal["basic", "advanced"]] = None, - include_raw_content: Union[bool, Literal["markdown", "text"]] = None, - include_images: bool = None, - timeout: float = 60, - country: str = None, - auto_parameters: bool = None, - include_favicon: bool = None, - include_usage: bool = None, - exact_match: bool = None, - **kwargs - ) -> dict: + def _search( + self, + query: str, + search_depth: Literal["basic", "advanced", "fast", "ultra-fast"] = None, + topic: Literal["general", "news", "finance"] = None, + time_range: Literal["day", "week", "month", "year"] = None, + start_date: str = None, + end_date: str = None, + days: int = None, + max_results: int = None, + include_domains: Sequence[str] = None, + exclude_domains: Sequence[str] = None, + include_answer: Union[bool, Literal["basic", "advanced"]] = None, + include_raw_content: Union[bool, Literal["markdown", "text"]] = None, + include_images: bool = None, + timeout: float = 60, + country: str = None, + auto_parameters: bool = None, + include_favicon: bool = None, + include_usage: bool = None, + exact_match: bool = None, + **kwargs, + ) -> dict: """ Internal search method to send the request to the API. """ @@ -133,70 +174,73 @@ def _search(self, else: raise response.raise_for_status() - - def search(self, - query: str, - search_depth: Literal["basic", "advanced", "fast", "ultra-fast"] = None, - topic: Literal["general", "news", "finance" ] = None, - time_range: Literal["day", "week", "month", "year"] = None, - start_date: str = None, - end_date: str = None, - days: int = None, - max_results: int = None, - include_domains: Sequence[str] = None, - exclude_domains: Sequence[str] = None, - include_answer: Union[bool, Literal["basic", "advanced"]] = None, - include_raw_content: Union[bool, Literal["markdown", "text"]] = None, - include_images: bool = None, - timeout: float = 60, - country: str = None, - auto_parameters: bool = None, - include_favicon: bool = None, - include_usage: bool = None, - exact_match: bool = None, - **kwargs, # Accept custom arguments - ) -> dict: + def search( + self, + query: str, + search_depth: Literal["basic", "advanced", "fast", "ultra-fast"] = None, + topic: Literal["general", "news", "finance"] = None, + time_range: Literal["day", "week", "month", "year"] = None, + start_date: str = None, + end_date: str = None, + days: int = None, + max_results: int = None, + include_domains: Sequence[str] = None, + exclude_domains: Sequence[str] = None, + include_answer: Union[bool, Literal["basic", "advanced"]] = None, + include_raw_content: Union[bool, Literal["markdown", "text"]] = None, + include_images: bool = None, + timeout: float = 60, + country: str = None, + auto_parameters: bool = None, + include_favicon: bool = None, + include_usage: bool = None, + exact_match: bool = None, + **kwargs, # Accept custom arguments + ) -> dict: """ Combined search method. """ - response_dict = self._search(query, - search_depth=search_depth, - topic=topic, - time_range=time_range, - start_date=start_date, - end_date=end_date, - days=days, - max_results=max_results, - include_domains=include_domains, - exclude_domains=exclude_domains, - include_answer=include_answer, - include_raw_content=include_raw_content, - include_images=include_images, - timeout=timeout, - country=country, - auto_parameters=auto_parameters, - include_favicon=include_favicon, - include_usage=include_usage, - exact_match=exact_match, - **kwargs) + response_dict = self._search( + query, + search_depth=search_depth, + topic=topic, + time_range=time_range, + start_date=start_date, + end_date=end_date, + days=days, + max_results=max_results, + include_domains=include_domains, + exclude_domains=exclude_domains, + include_answer=include_answer, + include_raw_content=include_raw_content, + include_images=include_images, + timeout=timeout, + country=country, + auto_parameters=auto_parameters, + include_favicon=include_favicon, + include_usage=include_usage, + exact_match=exact_match, + **kwargs, + ) response_dict.setdefault("results", []) return response_dict - def _extract(self, - urls: Union[List[str], str], - include_images: bool = None, - extract_depth: Literal["basic", "advanced"] = None, - format: Literal["markdown", "text"] = None, - timeout: float = 30, - include_favicon: bool = None, - include_usage: bool = None, - query: str = None, - chunks_per_source: int = None, - **kwargs - ) -> dict: + def _extract( + self, + urls: Union[List[str], str], + include_images: bool = None, + extract_depth: Literal["basic", "advanced"] = None, + format: Literal["markdown", "text"] = None, + timeout: float = 30, + include_favicon: bool = None, + include_usage: bool = None, + query: str = None, + chunks_per_source: int = None, + **kwargs, + ) -> dict: """ - Internal extract method to send the request to the API. + Internal extract method to send the request to the API. """ data = { "urls": urls, @@ -216,7 +260,9 @@ def _extract(self, data.update(kwargs) try: - response = self.session.post(self.base_url + "/extract", data=json.dumps(data), timeout=timeout) + response = self.session.post( + self.base_url + "/extract", data=json.dumps(data), timeout=timeout + ) except requests.exceptions.Timeout: raise TimeoutError(timeout) @@ -240,55 +286,59 @@ def _extract(self, else: raise response.raise_for_status() - def extract(self, - urls: Union[List[str], str], # Accept a list of URLs or a single URL - include_images: bool = None, - extract_depth: Literal["basic", "advanced"] = None, - format: Literal["markdown", "text"] = None, - timeout: float = 30, - include_favicon: bool = None, - include_usage: bool = None, - query: str = None, - chunks_per_source: int = None, - **kwargs, # Accept custom arguments - ) -> dict: + def extract( + self, + urls: Union[List[str], str], # Accept a list of URLs or a single URL + include_images: bool = None, + extract_depth: Literal["basic", "advanced"] = None, + format: Literal["markdown", "text"] = None, + timeout: float = 30, + include_favicon: bool = None, + include_usage: bool = None, + query: str = None, + chunks_per_source: int = None, + **kwargs, # Accept custom arguments + ) -> dict: """ Combined extract method. """ - response_dict = self._extract(urls, - include_images, - extract_depth, - format, - timeout, - include_favicon=include_favicon, - include_usage=include_usage, - query=query, - chunks_per_source=chunks_per_source, - **kwargs) + response_dict = self._extract( + urls, + include_images, + extract_depth, + format, + timeout, + include_favicon=include_favicon, + include_usage=include_usage, + query=query, + chunks_per_source=chunks_per_source, + **kwargs, + ) response_dict.setdefault("results", []) response_dict.setdefault("failed_results", []) return response_dict - def _crawl(self, - url: str, - max_depth: int = None, - max_breadth: int = None, - limit: int = None, - instructions: str = None, - select_paths: Sequence[str] = None, - select_domains: Sequence[str] = None, - exclude_paths: Sequence[str] = None, - exclude_domains: Sequence[str] = None, - allow_external: bool = None, - include_images: bool = None, - extract_depth: Literal["basic", "advanced"] = None, - format: Literal["markdown", "text"] = None, - timeout: float = 150, - include_favicon: bool = None, - include_usage: bool = None, - chunks_per_source: int = None, - **kwargs - ) -> dict: + def _crawl( + self, + url: str, + max_depth: int = None, + max_breadth: int = None, + limit: int = None, + instructions: str = None, + select_paths: Sequence[str] = None, + select_domains: Sequence[str] = None, + exclude_paths: Sequence[str] = None, + exclude_domains: Sequence[str] = None, + allow_external: bool = None, + include_images: bool = None, + extract_depth: Literal["basic", "advanced"] = None, + format: Literal["markdown", "text"] = None, + timeout: float = 150, + include_favicon: bool = None, + include_usage: bool = None, + chunks_per_source: int = None, + **kwargs, + ) -> dict: """ Internal crawl method to send the request to the API. include_favicon: If True, include the favicon in the crawl results. @@ -315,11 +365,13 @@ def _crawl(self, if kwargs: data.update(kwargs) - + data = {k: v for k, v in data.items() if v is not None} try: - response = self.session.post(self.base_url + "/crawl", data=json.dumps(data), timeout=timeout) + response = self.session.post( + self.base_url + "/crawl", data=json.dumps(data), timeout=timeout + ) except requests.exceptions.Timeout: raise TimeoutError(timeout) @@ -343,65 +395,69 @@ def _crawl(self, else: raise response.raise_for_status() - def crawl(self, - url: str, - max_depth: int = None, - max_breadth: int = None, - limit: int = None, - instructions: str = None, - select_paths: Sequence[str] = None, - select_domains: Sequence[str] = None, - exclude_paths: Sequence[str] = None, - exclude_domains: Sequence[str] = None, - allow_external: bool = None, - include_images: bool = None, - extract_depth: Literal["basic", "advanced"] = None, - format: Literal["markdown", "text"] = None, - timeout: float = 150, - include_favicon: bool = None, - include_usage: bool = None, - chunks_per_source: int = None, - **kwargs - ) -> dict: + def crawl( + self, + url: str, + max_depth: int = None, + max_breadth: int = None, + limit: int = None, + instructions: str = None, + select_paths: Sequence[str] = None, + select_domains: Sequence[str] = None, + exclude_paths: Sequence[str] = None, + exclude_domains: Sequence[str] = None, + allow_external: bool = None, + include_images: bool = None, + extract_depth: Literal["basic", "advanced"] = None, + format: Literal["markdown", "text"] = None, + timeout: float = 150, + include_favicon: bool = None, + include_usage: bool = None, + chunks_per_source: int = None, + **kwargs, + ) -> dict: """ Combined crawl method. include_favicon: If True, include the favicon in the crawl results. """ - return self._crawl(url, - max_depth=max_depth, - max_breadth=max_breadth, - limit=limit, - instructions=instructions, - select_paths=select_paths, - select_domains=select_domains, - exclude_paths=exclude_paths, - exclude_domains=exclude_domains, - allow_external=allow_external, - include_images=include_images, - extract_depth=extract_depth, - format=format, - timeout=timeout, - include_favicon=include_favicon, - include_usage=include_usage, - chunks_per_source=chunks_per_source, - **kwargs) - - def _map(self, - url: str, - max_depth: int = None, - max_breadth: int = None, - limit: int = None, - instructions: str = None, - select_paths: Sequence[str] = None, - select_domains: Sequence[str] = None, - exclude_paths: Sequence[str] = None, - exclude_domains: Sequence[str] = None, - allow_external: bool = None, - include_images: bool = None, - timeout: float = 150, - include_usage: bool = None, - **kwargs - ) -> dict: + return self._crawl( + url, + max_depth=max_depth, + max_breadth=max_breadth, + limit=limit, + instructions=instructions, + select_paths=select_paths, + select_domains=select_domains, + exclude_paths=exclude_paths, + exclude_domains=exclude_domains, + allow_external=allow_external, + include_images=include_images, + extract_depth=extract_depth, + format=format, + timeout=timeout, + include_favicon=include_favicon, + include_usage=include_usage, + chunks_per_source=chunks_per_source, + **kwargs, + ) + + def _map( + self, + url: str, + max_depth: int = None, + max_breadth: int = None, + limit: int = None, + instructions: str = None, + select_paths: Sequence[str] = None, + select_domains: Sequence[str] = None, + exclude_paths: Sequence[str] = None, + exclude_domains: Sequence[str] = None, + allow_external: bool = None, + include_images: bool = None, + timeout: float = 150, + include_usage: bool = None, + **kwargs, + ) -> dict: """ Internal map method to send the request to the API. """ @@ -423,11 +479,13 @@ def _map(self, if kwargs: data.update(kwargs) - + data = {k: v for k, v in data.items() if v is not None} try: - response = self.session.post(self.base_url + "/map", data=json.dumps(data), timeout=timeout) + response = self.session.post( + self.base_url + "/map", data=json.dumps(data), timeout=timeout + ) except requests.exceptions.Timeout: raise TimeoutError(timeout) @@ -451,55 +509,59 @@ def _map(self, else: raise response.raise_for_status() - def map(self, - url: str, - max_depth: int = None, - max_breadth: int = None, - limit: int = None, - instructions: str = None, - select_paths: Sequence[str] = None, - select_domains: Sequence[str] = None, - exclude_paths: Sequence[str] = None, - exclude_domains: Sequence[str] = None, - allow_external: bool = None, - include_images: bool = None, - timeout: float = 150, - include_usage: bool = None, - **kwargs - ) -> dict: + def map( + self, + url: str, + max_depth: int = None, + max_breadth: int = None, + limit: int = None, + instructions: str = None, + select_paths: Sequence[str] = None, + select_domains: Sequence[str] = None, + exclude_paths: Sequence[str] = None, + exclude_domains: Sequence[str] = None, + allow_external: bool = None, + include_images: bool = None, + timeout: float = 150, + include_usage: bool = None, + **kwargs, + ) -> dict: """ Combined map method. - + """ - return self._map(url, - max_depth=max_depth, - max_breadth=max_breadth, - limit=limit, - instructions=instructions, - select_paths=select_paths, - select_domains=select_domains, - exclude_paths=exclude_paths, - exclude_domains=exclude_domains, - allow_external=allow_external, - include_images=include_images, - timeout=timeout, - include_usage=include_usage, - **kwargs) - - def get_search_context(self, - query: str, - search_depth: Literal["basic", "advanced", "fast", "ultra-fast"] = "basic", - topic: Literal["general", "news", "finance"] = "general", - days: int = 7, - max_results: int = 5, - include_domains: Sequence[str] = None, - exclude_domains: Sequence[str] = None, - max_tokens: int = 4000, - timeout: float = 60, - country: str = None, - include_favicon: bool = None, - **kwargs, # Accept custom arguments - ) -> str: + return self._map( + url, + max_depth=max_depth, + max_breadth=max_breadth, + limit=limit, + instructions=instructions, + select_paths=select_paths, + select_domains=select_domains, + exclude_paths=exclude_paths, + exclude_domains=exclude_domains, + allow_external=allow_external, + include_images=include_images, + timeout=timeout, + include_usage=include_usage, + **kwargs, + ) + + def get_search_context( + self, + query: str, + search_depth: Literal["basic", "advanced", "fast", "ultra-fast"] = "basic", + topic: Literal["general", "news", "finance"] = "general", + days: int = 7, + max_results: int = 5, + include_domains: Sequence[str] = None, + exclude_domains: Sequence[str] = None, + max_tokens: int = 4000, + timeout: float = 60, + country: str = None, + include_favicon: bool = None, + **kwargs, # Accept custom arguments + ) -> str: """ Get the search context for a query. Useful for getting only related content from retrieved websites without having to deal with context extraction and limitation yourself. @@ -508,73 +570,84 @@ def get_search_context(self, Returns a string of JSON containing the search context up to context limit. """ - warnings.warn("get_search_context is deprecated and will be removed in future versions.", - DeprecationWarning, stacklevel=2) - - response_dict = self._search(query, - search_depth=search_depth, - topic=topic, - days=days, - max_results=max_results, - include_domains=include_domains, - exclude_domains=exclude_domains, - include_answer=False, - include_raw_content=False, - include_images=False, - timeout=timeout, - country=country, - include_favicon=include_favicon, - **kwargs, - ) + warnings.warn( + "get_search_context is deprecated and will be removed in future versions.", + DeprecationWarning, + stacklevel=2, + ) + + response_dict = self._search( + query, + search_depth=search_depth, + topic=topic, + days=days, + max_results=max_results, + include_domains=include_domains, + exclude_domains=exclude_domains, + include_answer=False, + include_raw_content=False, + include_images=False, + timeout=timeout, + country=country, + include_favicon=include_favicon, + **kwargs, + ) sources = response_dict.get("results", []) - context = [{"url": source["url"], "content": source["content"]} - for source in sources] + context = [ + {"url": source["url"], "content": source["content"]} for source in sources + ] return json.dumps(get_max_items_from_list(context, max_tokens)) - def qna_search(self, - query: str, - search_depth: Literal["basic", "advanced", "fast", "ultra-fast"] = "advanced", - topic: Literal["general", "news", "finance"] = "general", - days: int = 7, - max_results: int = 5, - include_domains: Sequence[str] = None, - exclude_domains: Sequence[str] = None, - timeout: float = 60, - country: str = None, - include_favicon: bool = None, - **kwargs, # Accept custom arguments - ) -> str: + def qna_search( + self, + query: str, + search_depth: Literal["basic", "advanced", "fast", "ultra-fast"] = "advanced", + topic: Literal["general", "news", "finance"] = "general", + days: int = 7, + max_results: int = 5, + include_domains: Sequence[str] = None, + exclude_domains: Sequence[str] = None, + timeout: float = 60, + country: str = None, + include_favicon: bool = None, + **kwargs, # Accept custom arguments + ) -> str: """ Q&A search method. Search depth is advanced by default to get the best answer. """ - warnings.warn("qna_search is deprecated and will be removed in future versions.", - DeprecationWarning, stacklevel=2) - response_dict = self._search(query, - search_depth=search_depth, - topic=topic, - days=days, - max_results=max_results, - include_domains=include_domains, - exclude_domains=exclude_domains, - include_raw_content=False, - include_images=False, - include_answer=True, - timeout=timeout, - country=country, - include_favicon=include_favicon, - **kwargs, - ) + warnings.warn( + "qna_search is deprecated and will be removed in future versions.", + DeprecationWarning, + stacklevel=2, + ) + response_dict = self._search( + query, + search_depth=search_depth, + topic=topic, + days=days, + max_results=max_results, + include_domains=include_domains, + exclude_domains=exclude_domains, + include_raw_content=False, + include_images=False, + include_answer=True, + timeout=timeout, + country=country, + include_favicon=include_favicon, + **kwargs, + ) return response_dict.get("answer", "") - def _research(self, - input: str, - model: Literal["mini", "pro", "auto"] = None, - output_schema: dict = None, - stream: bool = False, - citation_format: Literal["numbered", "mla", "apa", "chicago"] = "numbered", - timeout: Optional[float] = None, - **kwargs - ) -> Union[dict, Generator[bytes, None, None]]: + def _research( + self, + input: str, + model: Literal["mini", "pro", "auto"] = None, + output_schema: dict = None, + stream: bool = False, + citation_format: Literal["numbered", "mla", "apa", "chicago"] = "numbered", + timeout: Optional[float] = None, + **kwargs, + ) -> Union[dict, Generator[bytes, None, None]]: """ Internal research method to send the request to the API. """ @@ -597,7 +670,7 @@ def _research(self, self.base_url + "/research", data=json.dumps(data), timeout=timeout, - stream=True + stream=True, ) except requests.exceptions.Timeout: raise TimeoutError(timeout) @@ -632,9 +705,7 @@ def stream_generator() -> Generator[bytes, None, None]: else: try: response = self.session.post( - self.base_url + "/research", - data=json.dumps(data), - timeout=timeout + self.base_url + "/research", data=json.dumps(data), timeout=timeout ) except requests.exceptions.Timeout: raise TimeoutError(timeout) @@ -659,27 +730,28 @@ def stream_generator() -> Generator[bytes, None, None]: else: raise response.raise_for_status() - def research(self, - input: str, - model: Literal["mini", "pro", "auto"] = None, - output_schema: dict = None, - stream: bool = False, - citation_format: Literal["numbered", "mla", "apa", "chicago"] = "numbered", - timeout: Optional[float] = None, - **kwargs - ) -> Union[dict, Generator[bytes, None, None]]: + def research( + self, + input: str, + model: Literal["mini", "pro", "auto"] = None, + output_schema: dict = None, + stream: bool = False, + citation_format: Literal["numbered", "mla", "apa", "chicago"] = "numbered", + timeout: Optional[float] = None, + **kwargs, + ) -> Union[dict, Generator[bytes, None, None]]: """ Research method to create a research task. - + Args: input: The research task or question to investigate (required). model: The model used by the research agent - must be either 'mini', 'pro', or 'auto'. output_schema: Schema for the 'structured_output' response format (JSON Schema dict). stream: Whether to stream the research task. citation_format: Citation format - must be either 'numbered', 'mla', 'apa', or 'chicago'. - timeout: Optional HTTP request timeout in seconds. + timeout: Optional HTTP request timeout in seconds. **kwargs: Additional custom arguments. - + Returns: dict: Response containing request_id, created_at, status, input, and model. """ @@ -691,18 +763,16 @@ def research(self, stream=stream, citation_format=citation_format, timeout=timeout, - **kwargs + **kwargs, ) - def get_research(self, - request_id: str - ) -> dict: + def get_research(self, request_id: str) -> dict: """ Get research results by request_id. - + Args: request_id: The research request ID. - + Returns: dict: Research response containing request_id, created_at, completed_at, status, content, and sources. """ @@ -740,6 +810,9 @@ class Client(TavilyClient): """ def __init__(self, kwargs): - warnings.warn("Client is deprecated, please use TavilyClient instead", - DeprecationWarning, stacklevel=2) + warnings.warn( + "Client is deprecated, please use TavilyClient instead", + DeprecationWarning, + stacklevel=2, + ) super().__init__(kwargs) diff --git a/tests/test_custom_session.py b/tests/test_custom_session.py new file mode 100644 index 0000000..6a47d98 --- /dev/null +++ b/tests/test_custom_session.py @@ -0,0 +1,295 @@ +"""Tests for custom session functionality.""" + +import pytest +import requests +from tavily import TavilyClient + + +def test_custom_session_basic(): + """Test that a custom session is properly used.""" + # Create a custom session with a custom header + custom_session = requests.Session() + custom_session.headers.update({"X-Custom-Header": "test-value"}) + + # Initialize client with custom session + client = TavilyClient(api_key="test-key", session=custom_session) + + # Verify the session is the one we passed + assert client.session is custom_session + + # Verify our custom header is still there + assert "X-Custom-Header" in client.session.headers + assert client.session.headers["X-Custom-Header"] == "test-value" + + # Verify Tavily headers were added + assert "Authorization" in client.session.headers + assert client.session.headers["Authorization"] == "Bearer test-key" + assert client.session.headers["Content-Type"] == "application/json" + + client.close() + + +def test_custom_session_with_proxies(): + """Test that proxies are properly applied to custom session.""" + custom_session = requests.Session() + + proxies = { + "http": "http://proxy.example.com:8080", + "https": "https://proxy.example.com:8080", + } + + client = TavilyClient(api_key="test-key", proxies=proxies, session=custom_session) + + # Verify proxies were applied to the custom session + assert client.session.proxies.get("http") == "http://proxy.example.com:8080" + assert client.session.proxies.get("https") == "https://proxy.example.com:8080" + + client.close() + + +def test_custom_session_with_azure_apim_headers(): + """Test custom session with Azure APIM style authentication.""" + custom_session = requests.Session() + custom_session.headers.update( + { + "Ocp-Apim-Subscription-Key": "apim-subscription-key-123", + "X-Tenant-ID": "tenant-456", + } + ) + + client = TavilyClient( + api_key="dummy-key", + api_base_url="https://apim.example.com/tavily", + session=custom_session, + ) + + # Verify custom APIM headers are preserved + assert ( + client.session.headers["Ocp-Apim-Subscription-Key"] + == "apim-subscription-key-123" + ) + assert client.session.headers["X-Tenant-ID"] == "tenant-456" + + # Verify base URL is set correctly + assert client.base_url == "https://apim.example.com/tavily" + + client.close() + + +def test_default_session_when_none_provided(): + """Test that default session is created when none is provided.""" + client = TavilyClient(api_key="test-key") + + # Verify a session was created + assert client.session is not None + assert isinstance(client.session, requests.Session) + + # Verify headers were set + assert client.session.headers["Authorization"] == "Bearer test-key" + + client.close() + + +def test_custom_session_with_context_manager(): + """Test that custom session works with context manager.""" + custom_session = requests.Session() + custom_session.headers.update({"X-Custom": "value"}) + + with TavilyClient(api_key="test-key", session=custom_session) as client: + assert client.session is custom_session + assert "X-Custom" in client.session.headers + + # External session should NOT be closed after context manager exits + # User may want to reuse it for other requests or TavilyClient instances + # We can test this by verifying the session still has its headers + assert "X-Custom" in custom_session.headers + + # Clean up + custom_session.close() + + +def test_internal_session_closed_by_context_manager(): + """Test that internally created session IS closed by context manager.""" + with TavilyClient(api_key="test-key") as client: + # Capture the session reference + session = client.session + assert session is not None + + # The internal session should have been closed + # A closed session will raise an exception if we try to make a request + # (Note: We can't easily test this without making actual HTTP calls, + # but we can verify the flag was set correctly) + + +def test_external_session_not_closed_by_close(): + """Test that external sessions are not closed by close() method.""" + custom_session = requests.Session() + custom_session.headers.update({"X-Test": "value"}) + + client = TavilyClient(api_key="test-key", session=custom_session) + + # Explicitly close the client + client.close() + + # The external session should still be usable + assert "X-Test" in custom_session.headers + + # Clean up + custom_session.close() + + +def test_internal_session_closed_by_close(): + """Test that internally created sessions ARE closed by close() method.""" + client = TavilyClient(api_key="test-key") + session = client.session + + # Close the client + client.close() + + # The internal session should have been closed + # (We can't easily verify this without making HTTP calls, but the flag check ensures it) + + +def test_custom_session_preserves_existing_auth(): + """Test that custom session auth is preserved (not overridden).""" + custom_session = requests.Session() + # Set a custom authorization that should be PRESERVED + custom_session.headers.update({"Authorization": "Bearer custom-token"}) + + client = TavilyClient(api_key="new-api-key", session=custom_session) + + # The client should PRESERVE the custom Authorization header + # This allows enterprise API gateways to use custom auth (e.g., Azure AD JWT) + assert client.session.headers["Authorization"] == "Bearer custom-token" + + client.close() + + +def test_custom_session_preserves_existing_proxies(): + """Test that custom session proxies are preserved (not overridden).""" + custom_session = requests.Session() + # Set custom proxies that should be PRESERVED + custom_session.proxies.update( + { + "http": "http://custom-proxy.example.com:8080", + "https": "https://custom-proxy.example.com:8443", + } + ) + + # Try to override with different proxies via TavilyClient params + override_proxies = { + "http": "http://override-proxy.example.com:9090", + "https": "https://override-proxy.example.com:9443", + } + + client = TavilyClient( + api_key="test-key", proxies=override_proxies, session=custom_session + ) + + # The custom session's proxies should be PRESERVED (not overridden) + # This allows users to configure proxies on the session with full control + assert client.session.proxies["http"] == "http://custom-proxy.example.com:8080" + assert client.session.proxies["https"] == "https://custom-proxy.example.com:8443" + + client.close() + + +def test_custom_session_adds_missing_proxies(): + """Test that TavilyClient adds proxies for protocols not in custom session.""" + custom_session = requests.Session() + # Only set HTTP proxy, leave HTTPS undefined + custom_session.proxies.update({"http": "http://custom-http-proxy.example.com:8080"}) + + # TavilyClient provides both HTTP and HTTPS proxies + client_proxies = { + "http": "http://override-http.example.com:9090", + "https": "https://client-https.example.com:9443", + } + + client = TavilyClient( + api_key="test-key", proxies=client_proxies, session=custom_session + ) + + # HTTP proxy should be preserved from custom session + assert client.session.proxies["http"] == "http://custom-http-proxy.example.com:8080" + + # HTTPS proxy should be added from client_proxies (since it wasn't in session) + assert client.session.proxies["https"] == "https://client-https.example.com:9443" + + client.close() + + +def test_custom_session_proxies_from_env(): + """Test that custom session proxies are preserved over environment variables.""" + import os + + # Set environment variables + os.environ["TAVILY_HTTP_PROXY"] = "http://env-proxy.example.com:7070" + os.environ["TAVILY_HTTPS_PROXY"] = "https://env-proxy.example.com:7443" + + try: + custom_session = requests.Session() + # Custom session proxies should take precedence + custom_session.proxies.update( + { + "http": "http://session-proxy.example.com:6060", + "https": "https://session-proxy.example.com:6443", + } + ) + + client = TavilyClient(api_key="test-key", session=custom_session) + + # Custom session proxies should be preserved (not overridden by env vars) + assert client.session.proxies["http"] == "http://session-proxy.example.com:6060" + assert ( + client.session.proxies["https"] == "https://session-proxy.example.com:6443" + ) + + client.close() + finally: + # Clean up environment variables + os.environ.pop("TAVILY_HTTP_PROXY", None) + os.environ.pop("TAVILY_HTTPS_PROXY", None) + + +def test_headers_attribute_contains_only_tavily_headers(): + """Test that self.headers contains only Tavily-specific headers, not session defaults.""" + # Create a custom session with extra headers + custom_session = requests.Session() + custom_session.headers.update( + { + "X-Custom-Header": "custom-value", + "X-Azure-APIM-Key": "apim-key-123", + } + ) + + client = TavilyClient( + api_key="test-key", session=custom_session, project_id="test-project" + ) + + # self.headers should only contain Tavily-specific headers + expected_keys = {"Content-Type", "Authorization", "X-Client-Source", "X-Project-ID"} + assert set(client.headers.keys()) == expected_keys + + # Verify the values + assert client.headers["Content-Type"] == "application/json" + assert client.headers["Authorization"] == "Bearer test-key" + assert client.headers["X-Client-Source"] == "tavily-python" + assert client.headers["X-Project-ID"] == "test-project" + + # self.headers should NOT contain custom session headers + assert "X-Custom-Header" not in client.headers + assert "X-Azure-APIM-Key" not in client.headers + + # self.headers should NOT contain requests.Session default headers + # (like User-Agent, Accept-Encoding, etc.) + assert "User-Agent" not in client.headers + assert "Accept-Encoding" not in client.headers + assert "Accept" not in client.headers + assert "Connection" not in client.headers + + # But the session itself should have all headers + assert "X-Custom-Header" in client.session.headers + assert "X-Azure-APIM-Key" in client.session.headers + + client.close()