Skip to content

Commit 1d66aa3

Browse files
authored
Merge pull request #152 from tavily-ai/feat/TAV-4105-session-pooling
feat(client): add session pooling for connection reuse TAV-4105
2 parents ed6f3b0 + 5cdb58d commit 1d66aa3

2 files changed

Lines changed: 97 additions & 124 deletions

File tree

tavily/tavily.py

Lines changed: 70 additions & 124 deletions
Original file line number · Diff line number · Diff line change
@@ -3,7 +3,6 @@
33
import os
44
import warnings
55
from typing import Literal, Sequence, Optional, List, Union, Generator
6-
from concurrent.futures import ThreadPoolExecutor, as_completed
76
from .utils import get_max_items_from_list
87
from .errors import UsageLimitExceededError, InvalidAPIKeyError, MissingAPIKeyError, BadRequestError, ForbiddenError, TimeoutError
98

@@ -38,6 +37,21 @@ def __init__(self, api_key: Optional[str] = None, proxies: Optional[dict[str, st
3837
**({"X-Project-ID": tavily_project} if tavily_project else {})
3938
}
4039

40+
self.session = requests.Session()
41+
self.session.headers.update(self.headers)
42+
if self.proxies:
43+
self.session.proxies.update(self.proxies)
44+
45+
def close(self):
46+
"""Close the session and release resources."""
47+
self.session.close()
48+
49+
def __enter__(self):
50+
return self
51+
52+
def __exit__(self, exc_type, exc_val, exc_tb):
53+
self.close()
54+
4155
def _search(self,
4256
query: str,
4357
search_depth: Literal["basic", "advanced", "fast", "ultra-fast"] = None,
@@ -89,9 +103,11 @@ def _search(self,
89103
data.update(kwargs)
90104

91105
timeout = min(timeout, 120)
106+
url = self.base_url + "/search"
107+
payload = json.dumps(data)
92108

93109
try:
94-
response = requests.post(self.base_url + "/search", data=json.dumps(data), headers=self.headers, timeout=timeout, proxies=self.proxies)
110+
response = self.session.post(url, data=payload, timeout=timeout)
95111
except requests.exceptions.Timeout:
96112
raise TimeoutError(timeout)
97113

@@ -106,13 +122,12 @@ def _search(self,
106122

107123
if response.status_code == 429:
108124
raise UsageLimitExceededError(detail)
109-
elif response.status_code in [403,432,433]:
125+
elif response.status_code in [403, 432, 433]:
110126
raise ForbiddenError(detail)
111127
elif response.status_code == 401:
112128
raise InvalidAPIKeyError(detail)
113129
elif response.status_code == 400:
114130
raise BadRequestError(detail)
115-
116131
else:
117132
raise response.raise_for_status()
118133

@@ -160,13 +175,8 @@ def search(self,
160175
auto_parameters=auto_parameters,
161176
include_favicon=include_favicon,
162177
include_usage=include_usage,
163-
**kwargs,
164-
)
165-
166-
tavily_results = response_dict.get("results", [])
167-
168-
response_dict["results"] = tavily_results
169-
178+
**kwargs)
179+
response_dict.setdefault("results", [])
170180
return response_dict
171181

172182
def _extract(self,
@@ -202,7 +212,7 @@ def _extract(self,
202212
data.update(kwargs)
203213

204214
try:
205-
response = requests.post(self.base_url + "/extract", data=json.dumps(data), headers=self.headers, timeout=timeout, proxies=self.proxies)
215+
response = self.session.post(self.base_url + "/extract", data=json.dumps(data), timeout=timeout)
206216
except requests.exceptions.Timeout:
207217
raise TimeoutError(timeout)
208218

@@ -217,7 +227,7 @@ def _extract(self,
217227

218228
if response.status_code == 429:
219229
raise UsageLimitExceededError(detail)
220-
elif response.status_code in [403,432,433]:
230+
elif response.status_code in [403, 432, 433]:
221231
raise ForbiddenError(detail)
222232
elif response.status_code == 401:
223233
raise InvalidAPIKeyError(detail)
@@ -251,13 +261,8 @@ def extract(self,
251261
query=query,
252262
chunks_per_source=chunks_per_source,
253263
**kwargs)
254-
255-
tavily_results = response_dict.get("results", [])
256-
failed_results = response_dict.get("failed_results", [])
257-
258-
response_dict["results"] = tavily_results
259-
response_dict["failed_results"] = failed_results
260-
264+
response_dict.setdefault("results", [])
265+
response_dict.setdefault("failed_results", [])
261266
return response_dict
262267

263268
def _crawl(self,
@@ -310,8 +315,7 @@ def _crawl(self,
310315
data = {k: v for k, v in data.items() if v is not None}
311316

312317
try:
313-
response = requests.post(
314-
self.base_url + "/crawl", data=json.dumps(data), headers=self.headers, timeout=timeout, proxies=self.proxies)
318+
response = self.session.post(self.base_url + "/crawl", data=json.dumps(data), timeout=timeout)
315319
except requests.exceptions.Timeout:
316320
raise TimeoutError(timeout)
317321

@@ -326,7 +330,7 @@ def _crawl(self,
326330

327331
if response.status_code == 429:
328332
raise UsageLimitExceededError(detail)
329-
elif response.status_code in [403,432,433]:
333+
elif response.status_code in [403, 432, 433]:
330334
raise ForbiddenError(detail)
331335
elif response.status_code == 401:
332336
raise InvalidAPIKeyError(detail)
@@ -359,26 +363,24 @@ def crawl(self,
359363
Combined crawl method.
360364
include_favicon: If True, include the favicon in the crawl results.
361365
"""
362-
response_dict = self._crawl(url,
363-
max_depth=max_depth,
364-
max_breadth=max_breadth,
365-
limit=limit,
366-
instructions=instructions,
367-
select_paths=select_paths,
368-
select_domains=select_domains,
369-
exclude_paths=exclude_paths,
370-
exclude_domains=exclude_domains,
371-
allow_external=allow_external,
372-
include_images=include_images,
373-
extract_depth=extract_depth,
374-
format=format,
375-
timeout=timeout,
376-
include_favicon=include_favicon,
377-
include_usage=include_usage,
378-
chunks_per_source=chunks_per_source,
379-
**kwargs)
380-
381-
return response_dict
366+
return self._crawl(url,
367+
max_depth=max_depth,
368+
max_breadth=max_breadth,
369+
limit=limit,
370+
instructions=instructions,
371+
select_paths=select_paths,
372+
select_domains=select_domains,
373+
exclude_paths=exclude_paths,
374+
exclude_domains=exclude_domains,
375+
allow_external=allow_external,
376+
include_images=include_images,
377+
extract_depth=extract_depth,
378+
format=format,
379+
timeout=timeout,
380+
include_favicon=include_favicon,
381+
include_usage=include_usage,
382+
chunks_per_source=chunks_per_source,
383+
**kwargs)
382384

383385
def _map(self,
384386
url: str,
@@ -421,8 +423,7 @@ def _map(self,
421423
data = {k: v for k, v in data.items() if v is not None}
422424

423425
try:
424-
response = requests.post(
425-
self.base_url + "/map", data=json.dumps(data), headers=self.headers, timeout=timeout, proxies=self.proxies)
426+
response = self.session.post(self.base_url + "/map", data=json.dumps(data), timeout=timeout)
426427
except requests.exceptions.Timeout:
427428
raise TimeoutError(timeout)
428429

@@ -437,7 +438,7 @@ def _map(self,
437438

438439
if response.status_code == 429:
439440
raise UsageLimitExceededError(detail)
440-
elif response.status_code in [403,432,433]:
441+
elif response.status_code in [403, 432, 433]:
441442
raise ForbiddenError(detail)
442443
elif response.status_code == 401:
443444
raise InvalidAPIKeyError(detail)
@@ -466,22 +467,20 @@ def map(self,
466467
Combined map method.
467468
468469
"""
469-
response_dict = self._map(url,
470-
max_depth=max_depth,
471-
max_breadth=max_breadth,
472-
limit=limit,
473-
instructions=instructions,
474-
select_paths=select_paths,
475-
select_domains=select_domains,
476-
exclude_paths=exclude_paths,
477-
exclude_domains=exclude_domains,
478-
allow_external=allow_external,
479-
include_images=include_images,
480-
timeout=timeout,
481-
include_usage=include_usage,
482-
**kwargs)
483-
484-
return response_dict
470+
return self._map(url,
471+
max_depth=max_depth,
472+
max_breadth=max_breadth,
473+
limit=limit,
474+
instructions=instructions,
475+
select_paths=select_paths,
476+
select_domains=select_domains,
477+
exclude_paths=exclude_paths,
478+
exclude_domains=exclude_domains,
479+
allow_external=allow_external,
480+
include_images=include_images,
481+
timeout=timeout,
482+
include_usage=include_usage,
483+
**kwargs)
485484

486485
def get_search_context(self,
487486
query: str,
@@ -563,47 +562,6 @@ def qna_search(self,
563562
)
564563
return response_dict.get("answer", "")
565564

566-
def get_company_info(self,
567-
query: str,
568-
search_depth: Literal["basic",
569-
"advanced",
570-
"fast",
571-
"ultra-fast"] = "advanced",
572-
max_results: int = 5,
573-
timeout: float = 60,
574-
country: str = None,
575-
) -> Sequence[dict]:
576-
""" Company information search method. Search depth is advanced by default to get the best answer. """
577-
warnings.warn("get_company_info is deprecated and will be removed in future versions.",
578-
DeprecationWarning, stacklevel=2)
579-
def _perform_search(topic):
580-
return self._search(query,
581-
search_depth=search_depth,
582-
topic=topic,
583-
max_results=max_results,
584-
include_answer=False,
585-
timeout=timeout,
586-
country=country)
587-
588-
with ThreadPoolExecutor() as executor:
589-
# Initiate the search for each topic in parallel
590-
future_to_topic = {executor.submit(_perform_search, topic): topic for topic in
591-
["news", "general", "finance"]}
592-
593-
all_results = []
594-
595-
# Process the results as they become available
596-
for future in as_completed(future_to_topic):
597-
data = future.result()
598-
if 'results' in data:
599-
all_results.extend(data['results'])
600-
601-
# Sort all the results by score in descending order and take the top 'max_results' items
602-
sorted_results = sorted(all_results, key=lambda x: x['score'], reverse=True)[
603-
:max_results]
604-
605-
return sorted_results
606-
607565
def _research(self,
608566
input: str,
609567
model: Literal["mini", "pro", "auto"] = None,
@@ -631,12 +589,10 @@ def _research(self,
631589

632590
if stream:
633591
try:
634-
response = requests.post(
592+
response = self.session.post(
635593
self.base_url + "/research",
636594
data=json.dumps(data),
637-
headers=self.headers,
638595
timeout=timeout,
639-
proxies=self.proxies,
640596
stream=True
641597
)
642598
except requests.exceptions.Timeout:
@@ -651,7 +607,7 @@ def _research(self,
651607

652608
if response.status_code == 429:
653609
raise UsageLimitExceededError(detail)
654-
elif response.status_code in [403,432,433]:
610+
elif response.status_code in [403, 432, 433]:
655611
raise ForbiddenError(detail)
656612
elif response.status_code == 401:
657613
raise InvalidAPIKeyError(detail)
@@ -671,12 +627,10 @@ def stream_generator() -> Generator[bytes, None, None]:
671627
return stream_generator()
672628
else:
673629
try:
674-
response = requests.post(
630+
response = self.session.post(
675631
self.base_url + "/research",
676632
data=json.dumps(data),
677-
headers=self.headers,
678-
timeout=timeout,
679-
proxies=self.proxies
633+
timeout=timeout
680634
)
681635
except requests.exceptions.Timeout:
682636
raise TimeoutError(timeout)
@@ -692,7 +646,7 @@ def stream_generator() -> Generator[bytes, None, None]:
692646

693647
if response.status_code == 429:
694648
raise UsageLimitExceededError(detail)
695-
elif response.status_code in [403,432,433]:
649+
elif response.status_code in [403, 432, 433]:
696650
raise ForbiddenError(detail)
697651
elif response.status_code == 401:
698652
raise InvalidAPIKeyError(detail)
@@ -726,8 +680,7 @@ def research(self,
726680
dict: Response containing request_id, created_at, status, input, and model.
727681
"""
728682

729-
730-
response_dict = self._research(
683+
return self._research(
731684
input=input,
732685
model=model,
733686
output_schema=output_schema,
@@ -737,8 +690,6 @@ def research(self,
737690
**kwargs
738691
)
739692

740-
return response_dict
741-
742693
def get_research(self,
743694
request_id: str
744695
) -> dict:
@@ -752,17 +703,12 @@ def get_research(self,
752703
dict: Research response containing request_id, created_at, completed_at, status, content, and sources.
753704
"""
754705
try:
755-
response = requests.get(
756-
self.base_url + f"/research/{request_id}",
757-
headers=self.headers,
758-
proxies=self.proxies,
759-
)
706+
response = self.session.get(self.base_url + f"/research/{request_id}")
760707
except Exception as e:
761708
raise Exception(f"Error getting research: {e}")
762709

763710
if response.status_code in (200, 202):
764-
data = response.json()
765-
return data
711+
return response.json()
766712
else:
767713
detail = ""
768714
try:
@@ -772,7 +718,7 @@ def get_research(self,
772718

773719
if response.status_code == 429:
774720
raise UsageLimitExceededError(detail)
775-
elif response.status_code in [403,432,433]:
721+
elif response.status_code in [403, 432, 433]:
776722
raise ForbiddenError(detail)
777723
elif response.status_code == 401:
778724
raise InvalidAPIKeyError(detail)

0 commit comments

Comments (0)