Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 70 additions & 124 deletions tavily/tavily.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import os
import warnings
from typing import Literal, Sequence, Optional, List, Union, Generator
from concurrent.futures import ThreadPoolExecutor, as_completed
from .utils import get_max_items_from_list
from .errors import UsageLimitExceededError, InvalidAPIKeyError, MissingAPIKeyError, BadRequestError, ForbiddenError, TimeoutError

Expand Down Expand Up @@ -38,6 +37,21 @@ def __init__(self, api_key: Optional[str] = None, proxies: Optional[dict[str, st
**({"X-Project-ID": tavily_project} if tavily_project else {})
}

self.session = requests.Session()
self.session.headers.update(self.headers)
if self.proxies:
self.session.proxies.update(self.proxies)

def close(self) -> None:
    """Close the session and release resources.

    Delegates to ``self.session.close()``. Call this once the client is no
    longer needed, or use the client in a ``with`` block so it is invoked
    automatically on exit.
    """
    self.session.close()

def __enter__(self):
    """Enter a ``with`` block and hand back this same client instance."""
    return self

def __exit__(self, exc_type, exc_val, exc_tb) -> None:
    """Leave a ``with`` block by calling ``self.close()``.

    The exception arguments are ignored and ``None`` is returned, so an
    exception raised inside the block is not suppressed.
    """
    self.close()

def _search(self,
query: str,
search_depth: Literal["basic", "advanced", "fast", "ultra-fast"] = None,
Expand Down Expand Up @@ -89,9 +103,11 @@ def _search(self,
data.update(kwargs)

timeout = min(timeout, 120)
url = self.base_url + "/search"
payload = json.dumps(data)

try:
response = requests.post(self.base_url + "/search", data=json.dumps(data), headers=self.headers, timeout=timeout, proxies=self.proxies)
response = self.session.post(url, data=payload, timeout=timeout)
except requests.exceptions.Timeout:
raise TimeoutError(timeout)

Expand All @@ -106,13 +122,12 @@ def _search(self,

if response.status_code == 429:
raise UsageLimitExceededError(detail)
elif response.status_code in [403,432,433]:
elif response.status_code in [403, 432, 433]:
raise ForbiddenError(detail)
elif response.status_code == 401:
raise InvalidAPIKeyError(detail)
elif response.status_code == 400:
raise BadRequestError(detail)

else:
raise response.raise_for_status()

Expand Down Expand Up @@ -160,13 +175,8 @@ def search(self,
auto_parameters=auto_parameters,
include_favicon=include_favicon,
include_usage=include_usage,
**kwargs,
)

tavily_results = response_dict.get("results", [])

response_dict["results"] = tavily_results

**kwargs)
response_dict.setdefault("results", [])
return response_dict

def _extract(self,
Expand Down Expand Up @@ -202,7 +212,7 @@ def _extract(self,
data.update(kwargs)

try:
response = requests.post(self.base_url + "/extract", data=json.dumps(data), headers=self.headers, timeout=timeout, proxies=self.proxies)
response = self.session.post(self.base_url + "/extract", data=json.dumps(data), timeout=timeout)
except requests.exceptions.Timeout:
raise TimeoutError(timeout)

Expand All @@ -217,7 +227,7 @@ def _extract(self,

if response.status_code == 429:
raise UsageLimitExceededError(detail)
elif response.status_code in [403,432,433]:
elif response.status_code in [403, 432, 433]:
raise ForbiddenError(detail)
elif response.status_code == 401:
raise InvalidAPIKeyError(detail)
Expand Down Expand Up @@ -251,13 +261,8 @@ def extract(self,
query=query,
chunks_per_source=chunks_per_source,
**kwargs)

tavily_results = response_dict.get("results", [])
failed_results = response_dict.get("failed_results", [])

response_dict["results"] = tavily_results
response_dict["failed_results"] = failed_results

response_dict.setdefault("results", [])
response_dict.setdefault("failed_results", [])
return response_dict

def _crawl(self,
Expand Down Expand Up @@ -310,8 +315,7 @@ def _crawl(self,
data = {k: v for k, v in data.items() if v is not None}

try:
response = requests.post(
self.base_url + "/crawl", data=json.dumps(data), headers=self.headers, timeout=timeout, proxies=self.proxies)
response = self.session.post(self.base_url + "/crawl", data=json.dumps(data), timeout=timeout)
except requests.exceptions.Timeout:
raise TimeoutError(timeout)

Expand All @@ -326,7 +330,7 @@ def _crawl(self,

if response.status_code == 429:
raise UsageLimitExceededError(detail)
elif response.status_code in [403,432,433]:
elif response.status_code in [403, 432, 433]:
raise ForbiddenError(detail)
elif response.status_code == 401:
raise InvalidAPIKeyError(detail)
Expand Down Expand Up @@ -359,26 +363,24 @@ def crawl(self,
Combined crawl method.
include_favicon: If True, include the favicon in the crawl results.
"""
response_dict = self._crawl(url,
max_depth=max_depth,
max_breadth=max_breadth,
limit=limit,
instructions=instructions,
select_paths=select_paths,
select_domains=select_domains,
exclude_paths=exclude_paths,
exclude_domains=exclude_domains,
allow_external=allow_external,
include_images=include_images,
extract_depth=extract_depth,
format=format,
timeout=timeout,
include_favicon=include_favicon,
include_usage=include_usage,
chunks_per_source=chunks_per_source,
**kwargs)

return response_dict
return self._crawl(url,
max_depth=max_depth,
max_breadth=max_breadth,
limit=limit,
instructions=instructions,
select_paths=select_paths,
select_domains=select_domains,
exclude_paths=exclude_paths,
exclude_domains=exclude_domains,
allow_external=allow_external,
include_images=include_images,
extract_depth=extract_depth,
format=format,
timeout=timeout,
include_favicon=include_favicon,
include_usage=include_usage,
chunks_per_source=chunks_per_source,
**kwargs)

def _map(self,
url: str,
Expand Down Expand Up @@ -421,8 +423,7 @@ def _map(self,
data = {k: v for k, v in data.items() if v is not None}

try:
response = requests.post(
self.base_url + "/map", data=json.dumps(data), headers=self.headers, timeout=timeout, proxies=self.proxies)
response = self.session.post(self.base_url + "/map", data=json.dumps(data), timeout=timeout)
except requests.exceptions.Timeout:
raise TimeoutError(timeout)

Expand All @@ -437,7 +438,7 @@ def _map(self,

if response.status_code == 429:
raise UsageLimitExceededError(detail)
elif response.status_code in [403,432,433]:
elif response.status_code in [403, 432, 433]:
raise ForbiddenError(detail)
elif response.status_code == 401:
raise InvalidAPIKeyError(detail)
Expand Down Expand Up @@ -466,22 +467,20 @@ def map(self,
Combined map method.

"""
response_dict = self._map(url,
max_depth=max_depth,
max_breadth=max_breadth,
limit=limit,
instructions=instructions,
select_paths=select_paths,
select_domains=select_domains,
exclude_paths=exclude_paths,
exclude_domains=exclude_domains,
allow_external=allow_external,
include_images=include_images,
timeout=timeout,
include_usage=include_usage,
**kwargs)

return response_dict
return self._map(url,
max_depth=max_depth,
max_breadth=max_breadth,
limit=limit,
instructions=instructions,
select_paths=select_paths,
select_domains=select_domains,
exclude_paths=exclude_paths,
exclude_domains=exclude_domains,
allow_external=allow_external,
include_images=include_images,
timeout=timeout,
include_usage=include_usage,
**kwargs)

def get_search_context(self,
query: str,
Expand Down Expand Up @@ -563,47 +562,6 @@ def qna_search(self,
)
return response_dict.get("answer", "")

def get_company_info(self,
                     query: str,
                     search_depth: Literal["basic",
                                           "advanced",
                                           "fast",
                                           "ultra-fast"] = "advanced",
                     max_results: int = 5,
                     timeout: float = 60,
                     country: Optional[str] = None,
                     ) -> Sequence[dict]:
    """Company information search method (deprecated).

    Runs the same query against the "news", "general" and "finance" topics
    in parallel, merges the returned results and keeps the highest-scored
    ones. Search depth is advanced by default to get the best answer.

    Args:
        query: The search query.
        search_depth: Search depth forwarded to ``self._search``.
        max_results: Forwarded to each per-topic search and also used as
            the cap on the merged list.
        timeout: Per-request timeout forwarded to ``self._search``.
        country: Optional country forwarded to ``self._search``.

    Returns:
        At most ``max_results`` result dicts, ordered by ``score`` from
        highest to lowest. Results without a usable score sort last.

    Raises:
        Any exception raised by ``self._search`` for one of the topics.
    """
    warnings.warn("get_company_info is deprecated and will be removed in future versions.",
                  DeprecationWarning, stacklevel=2)

    topics = ("news", "general", "finance")

    def _perform_search(topic):
        return self._search(query,
                            search_depth=search_depth,
                            topic=topic,
                            max_results=max_results,
                            include_answer=False,
                            timeout=timeout,
                            country=country)

    def _score(result):
        # A result without a score must not abort the whole merge; rank it last.
        score = result.get("score")
        return float("-inf") if score is None else score

    # executor.map yields responses in topic order (not completion order), so
    # the stable sort below breaks score ties the same way on every call.
    with ThreadPoolExecutor(max_workers=len(topics)) as executor:
        responses = list(executor.map(_perform_search, topics))

    all_results = []
    for data in responses:
        # Tolerate a missing or null "results" entry in a topic response.
        all_results.extend(data.get("results") or [])

    # Sort all the results by score in descending order and take the top
    # 'max_results' items.
    return sorted(all_results, key=_score, reverse=True)[:max_results]

def _research(self,
input: str,
model: Literal["mini", "pro", "auto"] = None,
Expand Down Expand Up @@ -631,12 +589,10 @@ def _research(self,

if stream:
try:
response = requests.post(
response = self.session.post(
self.base_url + "/research",
data=json.dumps(data),
headers=self.headers,
timeout=timeout,
proxies=self.proxies,
stream=True
)
except requests.exceptions.Timeout:
Expand All @@ -651,7 +607,7 @@ def _research(self,

if response.status_code == 429:
raise UsageLimitExceededError(detail)
elif response.status_code in [403,432,433]:
elif response.status_code in [403, 432, 433]:
raise ForbiddenError(detail)
elif response.status_code == 401:
raise InvalidAPIKeyError(detail)
Expand All @@ -671,12 +627,10 @@ def stream_generator() -> Generator[bytes, None, None]:
return stream_generator()
else:
try:
response = requests.post(
response = self.session.post(
self.base_url + "/research",
data=json.dumps(data),
headers=self.headers,
timeout=timeout,
proxies=self.proxies
timeout=timeout
)
except requests.exceptions.Timeout:
raise TimeoutError(timeout)
Expand All @@ -692,7 +646,7 @@ def stream_generator() -> Generator[bytes, None, None]:

if response.status_code == 429:
raise UsageLimitExceededError(detail)
elif response.status_code in [403,432,433]:
elif response.status_code in [403, 432, 433]:
raise ForbiddenError(detail)
elif response.status_code == 401:
raise InvalidAPIKeyError(detail)
Expand Down Expand Up @@ -726,8 +680,7 @@ def research(self,
dict: Response containing request_id, created_at, status, input, and model.
"""


response_dict = self._research(
return self._research(
input=input,
model=model,
output_schema=output_schema,
Expand All @@ -737,8 +690,6 @@ def research(self,
**kwargs
)

return response_dict

def get_research(self,
request_id: str
) -> dict:
Expand All @@ -752,17 +703,12 @@ def get_research(self,
dict: Research response containing request_id, created_at, completed_at, status, content, and sources.
"""
try:
response = requests.get(
self.base_url + f"/research/{request_id}",
headers=self.headers,
proxies=self.proxies,
)
response = self.session.get(self.base_url + f"/research/{request_id}")
except Exception as e:
raise Exception(f"Error getting research: {e}")

if response.status_code in (200, 202):
data = response.json()
return data
return response.json()
else:
detail = ""
try:
Expand All @@ -772,7 +718,7 @@ def get_research(self,

if response.status_code == 429:
raise UsageLimitExceededError(detail)
elif response.status_code in [403,432,433]:
elif response.status_code in [403, 432, 433]:
raise ForbiddenError(detail)
elif response.status_code == 401:
raise InvalidAPIKeyError(detail)
Expand Down
Loading