33import os
44import warnings
55from typing import Literal , Sequence , Optional , List , Union , Generator
6- from concurrent .futures import ThreadPoolExecutor , as_completed
76from .utils import get_max_items_from_list
87from .errors import UsageLimitExceededError , InvalidAPIKeyError , MissingAPIKeyError , BadRequestError , ForbiddenError , TimeoutError
98
@@ -38,6 +37,21 @@ def __init__(self, api_key: Optional[str] = None, proxies: Optional[dict[str, st
3837 ** ({"X-Project-ID" : tavily_project } if tavily_project else {})
3938 }
4039
40+ self .session = requests .Session ()
41+ self .session .headers .update (self .headers )
42+ if self .proxies :
43+ self .session .proxies .update (self .proxies )
44+
45+ def close (self ):
46+ """Close the session and release resources."""
47+ self .session .close ()
48+
49+ def __enter__ (self ):
50+ return self
51+
52+ def __exit__ (self , exc_type , exc_val , exc_tb ):
53+ self .close ()
54+
4155 def _search (self ,
4256 query : str ,
4357 search_depth : Literal ["basic" , "advanced" , "fast" , "ultra-fast" ] = None ,
@@ -89,9 +103,11 @@ def _search(self,
89103 data .update (kwargs )
90104
91105 timeout = min (timeout , 120 )
106+ url = self .base_url + "/search"
107+ payload = json .dumps (data )
92108
93109 try :
94- response = requests . post (self . base_url + "/search" , data = json . dumps ( data ), headers = self . headers , timeout = timeout , proxies = self . proxies )
110+ response = self . session . post (url , data = payload , timeout = timeout )
95111 except requests .exceptions .Timeout :
96112 raise TimeoutError (timeout )
97113
@@ -106,13 +122,12 @@ def _search(self,
106122
107123 if response .status_code == 429 :
108124 raise UsageLimitExceededError (detail )
109- elif response .status_code in [403 ,432 ,433 ]:
125+ elif response .status_code in [403 , 432 , 433 ]:
110126 raise ForbiddenError (detail )
111127 elif response .status_code == 401 :
112128 raise InvalidAPIKeyError (detail )
113129 elif response .status_code == 400 :
114130 raise BadRequestError (detail )
115-
116131 else :
117132 raise response .raise_for_status ()
118133
@@ -160,13 +175,8 @@ def search(self,
160175 auto_parameters = auto_parameters ,
161176 include_favicon = include_favicon ,
162177 include_usage = include_usage ,
163- ** kwargs ,
164- )
165-
166- tavily_results = response_dict .get ("results" , [])
167-
168- response_dict ["results" ] = tavily_results
169-
178+ ** kwargs )
179+ response_dict .setdefault ("results" , [])
170180 return response_dict
171181
172182 def _extract (self ,
@@ -202,7 +212,7 @@ def _extract(self,
202212 data .update (kwargs )
203213
204214 try :
205- response = requests . post (self .base_url + "/extract" , data = json .dumps (data ), headers = self . headers , timeout = timeout , proxies = self . proxies )
215+ response = self . session . post (self .base_url + "/extract" , data = json .dumps (data ), timeout = timeout )
206216 except requests .exceptions .Timeout :
207217 raise TimeoutError (timeout )
208218
@@ -217,7 +227,7 @@ def _extract(self,
217227
218228 if response .status_code == 429 :
219229 raise UsageLimitExceededError (detail )
220- elif response .status_code in [403 ,432 ,433 ]:
230+ elif response .status_code in [403 , 432 , 433 ]:
221231 raise ForbiddenError (detail )
222232 elif response .status_code == 401 :
223233 raise InvalidAPIKeyError (detail )
@@ -251,13 +261,8 @@ def extract(self,
251261 query = query ,
252262 chunks_per_source = chunks_per_source ,
253263 ** kwargs )
254-
255- tavily_results = response_dict .get ("results" , [])
256- failed_results = response_dict .get ("failed_results" , [])
257-
258- response_dict ["results" ] = tavily_results
259- response_dict ["failed_results" ] = failed_results
260-
264+ response_dict .setdefault ("results" , [])
265+ response_dict .setdefault ("failed_results" , [])
261266 return response_dict
262267
263268 def _crawl (self ,
@@ -310,8 +315,7 @@ def _crawl(self,
310315 data = {k : v for k , v in data .items () if v is not None }
311316
312317 try :
313- response = requests .post (
314- self .base_url + "/crawl" , data = json .dumps (data ), headers = self .headers , timeout = timeout , proxies = self .proxies )
318+ response = self .session .post (self .base_url + "/crawl" , data = json .dumps (data ), timeout = timeout )
315319 except requests .exceptions .Timeout :
316320 raise TimeoutError (timeout )
317321
@@ -326,7 +330,7 @@ def _crawl(self,
326330
327331 if response .status_code == 429 :
328332 raise UsageLimitExceededError (detail )
329- elif response .status_code in [403 ,432 ,433 ]:
333+ elif response .status_code in [403 , 432 , 433 ]:
330334 raise ForbiddenError (detail )
331335 elif response .status_code == 401 :
332336 raise InvalidAPIKeyError (detail )
@@ -359,26 +363,24 @@ def crawl(self,
359363 Combined crawl method.
360364 include_favicon: If True, include the favicon in the crawl results.
361365 """
362- response_dict = self ._crawl (url ,
363- max_depth = max_depth ,
364- max_breadth = max_breadth ,
365- limit = limit ,
366- instructions = instructions ,
367- select_paths = select_paths ,
368- select_domains = select_domains ,
369- exclude_paths = exclude_paths ,
370- exclude_domains = exclude_domains ,
371- allow_external = allow_external ,
372- include_images = include_images ,
373- extract_depth = extract_depth ,
374- format = format ,
375- timeout = timeout ,
376- include_favicon = include_favicon ,
377- include_usage = include_usage ,
378- chunks_per_source = chunks_per_source ,
379- ** kwargs )
380-
381- return response_dict
366+ return self ._crawl (url ,
367+ max_depth = max_depth ,
368+ max_breadth = max_breadth ,
369+ limit = limit ,
370+ instructions = instructions ,
371+ select_paths = select_paths ,
372+ select_domains = select_domains ,
373+ exclude_paths = exclude_paths ,
374+ exclude_domains = exclude_domains ,
375+ allow_external = allow_external ,
376+ include_images = include_images ,
377+ extract_depth = extract_depth ,
378+ format = format ,
379+ timeout = timeout ,
380+ include_favicon = include_favicon ,
381+ include_usage = include_usage ,
382+ chunks_per_source = chunks_per_source ,
383+ ** kwargs )
382384
383385 def _map (self ,
384386 url : str ,
@@ -421,8 +423,7 @@ def _map(self,
421423 data = {k : v for k , v in data .items () if v is not None }
422424
423425 try :
424- response = requests .post (
425- self .base_url + "/map" , data = json .dumps (data ), headers = self .headers , timeout = timeout , proxies = self .proxies )
426+ response = self .session .post (self .base_url + "/map" , data = json .dumps (data ), timeout = timeout )
426427 except requests .exceptions .Timeout :
427428 raise TimeoutError (timeout )
428429
@@ -437,7 +438,7 @@ def _map(self,
437438
438439 if response .status_code == 429 :
439440 raise UsageLimitExceededError (detail )
440- elif response .status_code in [403 ,432 ,433 ]:
441+ elif response .status_code in [403 , 432 , 433 ]:
441442 raise ForbiddenError (detail )
442443 elif response .status_code == 401 :
443444 raise InvalidAPIKeyError (detail )
@@ -466,22 +467,20 @@ def map(self,
466467 Combined map method.
467468
468469 """
469- response_dict = self ._map (url ,
470- max_depth = max_depth ,
471- max_breadth = max_breadth ,
472- limit = limit ,
473- instructions = instructions ,
474- select_paths = select_paths ,
475- select_domains = select_domains ,
476- exclude_paths = exclude_paths ,
477- exclude_domains = exclude_domains ,
478- allow_external = allow_external ,
479- include_images = include_images ,
480- timeout = timeout ,
481- include_usage = include_usage ,
482- ** kwargs )
483-
484- return response_dict
470+ return self ._map (url ,
471+ max_depth = max_depth ,
472+ max_breadth = max_breadth ,
473+ limit = limit ,
474+ instructions = instructions ,
475+ select_paths = select_paths ,
476+ select_domains = select_domains ,
477+ exclude_paths = exclude_paths ,
478+ exclude_domains = exclude_domains ,
479+ allow_external = allow_external ,
480+ include_images = include_images ,
481+ timeout = timeout ,
482+ include_usage = include_usage ,
483+ ** kwargs )
485484
486485 def get_search_context (self ,
487486 query : str ,
@@ -563,47 +562,6 @@ def qna_search(self,
563562 )
564563 return response_dict .get ("answer" , "" )
565564
566- def get_company_info (self ,
567- query : str ,
568- search_depth : Literal ["basic" ,
569- "advanced" ,
570- "fast" ,
571- "ultra-fast" ] = "advanced" ,
572- max_results : int = 5 ,
573- timeout : float = 60 ,
574- country : str = None ,
575- ) -> Sequence [dict ]:
576- """ Company information search method. Search depth is advanced by default to get the best answer. """
577- warnings .warn ("get_company_info is deprecated and will be removed in future versions." ,
578- DeprecationWarning , stacklevel = 2 )
579- def _perform_search (topic ):
580- return self ._search (query ,
581- search_depth = search_depth ,
582- topic = topic ,
583- max_results = max_results ,
584- include_answer = False ,
585- timeout = timeout ,
586- country = country )
587-
588- with ThreadPoolExecutor () as executor :
589- # Initiate the search for each topic in parallel
590- future_to_topic = {executor .submit (_perform_search , topic ): topic for topic in
591- ["news" , "general" , "finance" ]}
592-
593- all_results = []
594-
595- # Process the results as they become available
596- for future in as_completed (future_to_topic ):
597- data = future .result ()
598- if 'results' in data :
599- all_results .extend (data ['results' ])
600-
601- # Sort all the results by score in descending order and take the top 'max_results' items
602- sorted_results = sorted (all_results , key = lambda x : x ['score' ], reverse = True )[
603- :max_results ]
604-
605- return sorted_results
606-
607565 def _research (self ,
608566 input : str ,
609567 model : Literal ["mini" , "pro" , "auto" ] = None ,
@@ -631,12 +589,10 @@ def _research(self,
631589
632590 if stream :
633591 try :
634- response = requests .post (
592+ response = self . session .post (
635593 self .base_url + "/research" ,
636594 data = json .dumps (data ),
637- headers = self .headers ,
638595 timeout = timeout ,
639- proxies = self .proxies ,
640596 stream = True
641597 )
642598 except requests .exceptions .Timeout :
@@ -651,7 +607,7 @@ def _research(self,
651607
652608 if response .status_code == 429 :
653609 raise UsageLimitExceededError (detail )
654- elif response .status_code in [403 ,432 ,433 ]:
610+ elif response .status_code in [403 , 432 , 433 ]:
655611 raise ForbiddenError (detail )
656612 elif response .status_code == 401 :
657613 raise InvalidAPIKeyError (detail )
@@ -671,12 +627,10 @@ def stream_generator() -> Generator[bytes, None, None]:
671627 return stream_generator ()
672628 else :
673629 try :
674- response = requests .post (
630+ response = self . session .post (
675631 self .base_url + "/research" ,
676632 data = json .dumps (data ),
677- headers = self .headers ,
678- timeout = timeout ,
679- proxies = self .proxies
633+ timeout = timeout
680634 )
681635 except requests .exceptions .Timeout :
682636 raise TimeoutError (timeout )
@@ -692,7 +646,7 @@ def stream_generator() -> Generator[bytes, None, None]:
692646
693647 if response .status_code == 429 :
694648 raise UsageLimitExceededError (detail )
695- elif response .status_code in [403 ,432 ,433 ]:
649+ elif response .status_code in [403 , 432 , 433 ]:
696650 raise ForbiddenError (detail )
697651 elif response .status_code == 401 :
698652 raise InvalidAPIKeyError (detail )
@@ -726,8 +680,7 @@ def research(self,
726680 dict: Response containing request_id, created_at, status, input, and model.
727681 """
728682
729-
730- response_dict = self ._research (
683+ return self ._research (
731684 input = input ,
732685 model = model ,
733686 output_schema = output_schema ,
@@ -737,8 +690,6 @@ def research(self,
737690 ** kwargs
738691 )
739692
740- return response_dict
741-
742693 def get_research (self ,
743694 request_id : str
744695 ) -> dict :
@@ -752,17 +703,12 @@ def get_research(self,
752703 dict: Research response containing request_id, created_at, completed_at, status, content, and sources.
753704 """
754705 try :
755- response = requests .get (
756- self .base_url + f"/research/{ request_id } " ,
757- headers = self .headers ,
758- proxies = self .proxies ,
759- )
706+ response = self .session .get (self .base_url + f"/research/{ request_id } " )
760707 except Exception as e :
761708 raise Exception (f"Error getting research: { e } " )
762709
763710 if response .status_code in (200 , 202 ):
764- data = response .json ()
765- return data
711+ return response .json ()
766712 else :
767713 detail = ""
768714 try :
@@ -772,7 +718,7 @@ def get_research(self,
772718
773719 if response .status_code == 429 :
774720 raise UsageLimitExceededError (detail )
775- elif response .status_code in [403 ,432 ,433 ]:
721+ elif response .status_code in [403 , 432 , 433 ]:
776722 raise ForbiddenError (detail )
777723 elif response .status_code == 401 :
778724 raise InvalidAPIKeyError (detail )
0 commit comments