diff --git a/README.md b/README.md index 1f49f07..7d2951f 100644 --- a/README.md +++ b/README.md @@ -246,6 +246,53 @@ for chunk in stream: print(chunk.decode('utf-8')) ``` +## Advanced: Custom Session / Client Injection + +For enterprise environments that proxy Tavily traffic through an API gateway (e.g., for centralized auth, logging, or policy enforcement), you can pass a pre-configured HTTP session instead of a Tavily API key. + +### Sync (custom `requests.Session`) + +```python +import requests +from tavily import TavilyClient + +# Pre-configure a session with your gateway's auth +session = requests.Session() +session.headers["Authorization"] = "Bearer your-gateway-token" +session.headers["X-Subscription-Key"] = "your-subscription-key" + +# No Tavily API key needed — auth is handled by the session +client = TavilyClient( + session=session, + api_base_url="https://your-gateway.com/tavily", +) + +response = client.search("latest AI research") +``` + +### Async (custom `httpx.AsyncClient`) + +```python +import httpx +from tavily import AsyncTavilyClient + +# Pre-configure an async client with your gateway's auth +custom_client = httpx.AsyncClient( + headers={"Authorization": "Bearer your-gateway-token"}, + base_url="https://your-gateway.com/tavily", +) + +client = AsyncTavilyClient(client=custom_client) + +response = await client.search("latest AI research") +``` + +**Key behaviors:** +- If a custom session/client is provided, `api_key` is optional +- Custom session headers take precedence over SDK defaults (e.g., your `Authorization` won't be overwritten); the SDK only adds the default headers that your session/client doesn't already set +- For `TavilyClient`, proxies already set on the custom session take precedence over the `proxies` argument, which only fills in protocols the session lacks; `AsyncTavilyClient` does not apply SDK proxy settings to a custom client, so configure proxies on the client itself +- The SDK will **not** close externally provided sessions or clients — you manage their lifecycle + ## Documentation For a complete guide on how to use the different endpoints and their parameters, please head to our [Python API Reference](https://docs.tavily.com/sdk/python/reference). 
diff --git a/setup.py b/setup.py index fd317e7..4c4618f 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setup( name='tavily-python', - version='0.7.22', + version='0.7.23', url='https://github.com/tavily-ai/tavily-python', author='Tavily AI', author_email='support@tavily.com', diff --git a/tavily/async_tavily.py b/tavily/async_tavily.py index b52640e..dc896b8 100644 --- a/tavily/async_tavily.py +++ b/tavily/async_tavily.py @@ -19,48 +19,63 @@ def __init__(self, api_key: Optional[str] = None, proxies: Optional[dict[str, str]] = None, api_base_url: Optional[str] = None, client_source: Optional[str] = None, - project_id: Optional[str] = None): + project_id: Optional[str] = None, + client: Optional[httpx.AsyncClient] = None): if api_key is None: api_key = os.getenv("TAVILY_API_KEY") - if not api_key: + if not api_key and client is None: raise MissingAPIKeyError() - proxies = proxies or {} + tavily_project = project_id or os.getenv("TAVILY_PROJECT") - mapped_proxies = { - "http://": proxies.get("http", os.getenv("TAVILY_HTTP_PROXY")), - "https://": proxies.get("https", os.getenv("TAVILY_HTTPS_PROXY")), + self._api_base_url = api_base_url or "https://api.tavily.com" + self._company_info_tags = company_info_tags + + default_headers = { + "Content-Type": "application/json", + **({"Authorization": f"Bearer {api_key}"} if api_key else {}), + "X-Client-Source": client_source or "tavily-python", + **({"X-Project-ID": tavily_project} if tavily_project else {}) } - mapped_proxies = {key: value for key, value in mapped_proxies.items() if value} + self._external_client = client is not None + + if client is not None: + self._client = client + # Only set headers that aren't already configured on the external client + for key, value in default_headers.items(): + if key not in self._client.headers: + self._client.headers[key] = value + # Set base_url if the external client doesn't have one + if not str(self._client.base_url): + self._client.base_url = self._api_base_url + else: + 
proxies = proxies or {} - proxy_mounts = ( - {scheme: httpx.AsyncHTTPTransport(proxy=proxy) for scheme, proxy in mapped_proxies.items()} - if mapped_proxies - else None - ) + mapped_proxies = { + "http://": proxies.get("http", os.getenv("TAVILY_HTTP_PROXY")), + "https://": proxies.get("https", os.getenv("TAVILY_HTTPS_PROXY")), + } - tavily_project = project_id or os.getenv("TAVILY_PROJECT") + mapped_proxies = {key: value for key, value in mapped_proxies.items() if value} - self._api_base_url = api_base_url or "https://api.tavily.com" + proxy_mounts = ( + {scheme: httpx.AsyncHTTPTransport(proxy=proxy) for scheme, proxy in mapped_proxies.items()} + if mapped_proxies + else None + ) - # Create a persistent client for connection pooling - self._client = httpx.AsyncClient( - headers={ - "Content-Type": "application/json", - "Authorization": f"Bearer {api_key}", - "X-Client-Source": client_source or "tavily-python", - **({"X-Project-ID": tavily_project} if tavily_project else {}) - }, - base_url=self._api_base_url, - mounts=proxy_mounts - ) - self._company_info_tags = company_info_tags + self._client = httpx.AsyncClient( + headers=default_headers, + base_url=self._api_base_url, + mounts=proxy_mounts + ) async def close(self): """Close the client and release connection pool resources.""" - await self._client.aclose() + if not self._external_client: + await self._client.aclose() async def __aenter__(self): return self diff --git a/tavily/hybrid_rag/hybrid_rag.py b/tavily/hybrid_rag/hybrid_rag.py index c3ea97c..0a2d823 100644 --- a/tavily/hybrid_rag/hybrid_rag.py +++ b/tavily/hybrid_rag/hybrid_rag.py @@ -1,6 +1,7 @@ import os from typing import Union, Optional, Literal +import requests from tavily import TavilyClient try: @@ -76,7 +77,8 @@ def __init__( embeddings_field: str = 'embeddings', content_field: str = 'content', embedding_function: Optional[callable] = None, - ranking_function: Optional[callable] = None + ranking_function: Optional[callable] = None, + session: 
Optional[requests.Session] = None ): ''' A client for performing hybrid RAG using both the Tavily API and a local database collection. @@ -90,9 +92,10 @@ def __init__( content_field (str): The name of the field in the collection that contains the content. embedding_function (callable): If provided, this function will be used to generate embeddings for the search query and documents. ranking_function (callable): If provided, this function will be used to rerank the combined results. + session (requests.Session): If provided, this pre-configured session will be used for HTTP requests. When set, api_key is optional. ''' - - self.tavily = TavilyClient(api_key) + + self.tavily = TavilyClient(api_key, session=session) if db_provider != 'mongodb': raise ValueError("Only MongoDB is currently supported as a database provider.") diff --git a/tavily/tavily.py b/tavily/tavily.py index b8e6674..b059734 100644 --- a/tavily/tavily.py +++ b/tavily/tavily.py @@ -11,11 +11,11 @@ class TavilyClient: Tavily API client class. 
""" - def __init__(self, api_key: Optional[str] = None, proxies: Optional[dict[str, str]] = None, api_base_url: Optional[str] = None, client_source: Optional[str] = None, project_id: Optional[str] = None): + def __init__(self, api_key: Optional[str] = None, proxies: Optional[dict[str, str]] = None, api_base_url: Optional[str] = None, client_source: Optional[str] = None, project_id: Optional[str] = None, session: Optional[requests.Session] = None): if api_key is None: api_key = os.getenv("TAVILY_API_KEY") - if not api_key: + if not api_key and session is None: raise MissingAPIKeyError() resolved_proxies = { @@ -32,19 +32,26 @@ def __init__(self, api_key: Optional[str] = None, proxies: Optional[dict[str, st self.headers = { "Content-Type": "application/json", - "Authorization": f"Bearer {self.api_key}", + **({"Authorization": f"Bearer {self.api_key}"} if self.api_key else {}), "X-Client-Source": client_source or "tavily-python", **({"X-Project-ID": tavily_project} if tavily_project else {}) } - self.session = requests.Session() - self.session.headers.update(self.headers) + self._external_session = session is not None + self.session = session if session is not None else requests.Session() + # For external sessions, only set headers that aren't already configured + for key, value in self.headers.items(): + if key not in self.session.headers: + self.session.headers[key] = value if self.proxies: - self.session.proxies.update(self.proxies) + for protocol, url in self.proxies.items(): + if protocol not in self.session.proxies: + self.session.proxies[protocol] = url def close(self): """Close the session and release resources.""" - self.session.close() + if not self._external_session: + self.session.close() def __enter__(self): return self diff --git a/tests/test_custom_session.py b/tests/test_custom_session.py new file mode 100644 index 0000000..7063b28 --- /dev/null +++ b/tests/test_custom_session.py @@ -0,0 +1,410 @@ +import asyncio +import httpx +from 
tests.request_intercept import intercept_requests, clear_interceptor, MockSession +import tavily.tavily as sync_tavily +import tavily.async_tavily as async_tavily +import pytest +from tavily.errors import MissingAPIKeyError + + +@pytest.fixture +def sync_interceptor(): + yield intercept_requests(sync_tavily) + clear_interceptor(sync_tavily) + + +@pytest.fixture +def async_interceptor(): + yield intercept_requests(async_tavily) + clear_interceptor(async_tavily) + + +# --- Sync TavilyClient tests --- + +class TestSyncCustomSession: + def test_default_session_created_when_none_provided(self, sync_interceptor): + client = sync_tavily.TavilyClient(api_key="tvly-test") + assert not client._external_session + + def test_custom_session_used(self, sync_interceptor): + custom_session = MockSession(sync_interceptor) + client = sync_tavily.TavilyClient(api_key="tvly-test", session=custom_session) + assert client._external_session + assert client.session is custom_session + + def test_custom_session_preserves_existing_headers(self, sync_interceptor): + custom_session = MockSession(sync_interceptor) + custom_session.headers["Authorization"] = "Bearer apim-token-123" + custom_session.headers["X-Custom"] = "custom-value" + + client = sync_tavily.TavilyClient(api_key="tvly-test", session=custom_session) + + # Custom Authorization should be preserved (not overwritten by Tavily's) + assert client.session.headers["Authorization"] == "Bearer apim-token-123" + # Custom header should be preserved + assert client.session.headers["X-Custom"] == "custom-value" + # Tavily defaults should fill in missing headers + assert client.session.headers["Content-Type"] == "application/json" + assert client.session.headers["X-Client-Source"] == "tavily-python" + + def test_custom_session_gets_default_headers_when_empty(self, sync_interceptor): + custom_session = MockSession(sync_interceptor) + client = sync_tavily.TavilyClient(api_key="tvly-test", session=custom_session) + + assert 
client.session.headers["Authorization"] == "Bearer tvly-test" + assert client.session.headers["Content-Type"] == "application/json" + + def test_custom_session_preserves_existing_proxies(self, sync_interceptor): + custom_session = MockSession(sync_interceptor) + custom_session.proxies["https"] = "http://my-proxy:8080" + + client = sync_tavily.TavilyClient( + api_key="tvly-test", + session=custom_session, + proxies={"https": "http://tavily-proxy:9090"}, + ) + + # Custom session proxy should take precedence + assert client.session.proxies["https"] == "http://my-proxy:8080" + + def test_close_does_not_close_external_session(self, sync_interceptor): + closed = [] + custom_session = MockSession(sync_interceptor) + custom_session.close = lambda: closed.append(True) + + client = sync_tavily.TavilyClient(api_key="tvly-test", session=custom_session) + client.close() + assert len(closed) == 0 + + def test_close_closes_internal_session(self, sync_interceptor): + client = sync_tavily.TavilyClient(api_key="tvly-test") + # Should not raise — just verifies close() is called on internal session + client.close() + + def test_context_manager_does_not_close_external_session(self, sync_interceptor): + closed = [] + custom_session = MockSession(sync_interceptor) + custom_session.close = lambda: closed.append(True) + + with sync_tavily.TavilyClient(api_key="tvly-test", session=custom_session): + pass + assert len(closed) == 0 + + def test_custom_session_sends_request(self, sync_interceptor): + sync_interceptor.set_response(200, json={"results": []}) + custom_session = MockSession(sync_interceptor) + custom_session.headers["Authorization"] = "Bearer apim-token" + + client = sync_tavily.TavilyClient(api_key="tvly-test", session=custom_session) + client.search("test query") + + req = sync_interceptor.get_request() + assert req is not None + assert req.headers["Authorization"] == "Bearer apim-token" + + # --- API key validation edge cases --- + + def test_no_api_key_no_session_raises(self, 
monkeypatch): + monkeypatch.delenv("TAVILY_API_KEY", raising=False) + with pytest.raises(MissingAPIKeyError): + sync_tavily.TavilyClient() + + def test_no_api_key_with_session_allowed(self, sync_interceptor, monkeypatch): + monkeypatch.delenv("TAVILY_API_KEY", raising=False) + custom_session = MockSession(sync_interceptor) + custom_session.headers["Authorization"] = "Bearer apim-token" + client = sync_tavily.TavilyClient(session=custom_session) + assert client.api_key is None + assert "Authorization" not in client.headers + assert client.session.headers["Authorization"] == "Bearer apim-token" + + def test_no_api_key_with_session_no_auth_header_on_defaults(self, sync_interceptor, monkeypatch): + monkeypatch.delenv("TAVILY_API_KEY", raising=False) + custom_session = MockSession(sync_interceptor) + client = sync_tavily.TavilyClient(session=custom_session) + # No api_key means no Authorization in defaults + assert "Authorization" not in client.headers + # Session shouldn't get an Authorization header either + assert "Authorization" not in client.session.headers + + def test_no_api_key_with_session_sends_request(self, sync_interceptor, monkeypatch): + monkeypatch.delenv("TAVILY_API_KEY", raising=False) + sync_interceptor.set_response(200, json={"results": []}) + custom_session = MockSession(sync_interceptor) + custom_session.headers["Authorization"] = "Bearer apim-token" + + client = sync_tavily.TavilyClient(session=custom_session) + client.search("test query") + + req = sync_interceptor.get_request() + assert req is not None + assert req.headers["Authorization"] == "Bearer apim-token" + + def test_empty_string_api_key_no_session_raises(self, monkeypatch): + monkeypatch.delenv("TAVILY_API_KEY", raising=False) + with pytest.raises(MissingAPIKeyError): + sync_tavily.TavilyClient(api_key="") + + def test_empty_string_api_key_with_session_allowed(self, sync_interceptor): + custom_session = MockSession(sync_interceptor) + client = sync_tavily.TavilyClient(api_key="", 
session=custom_session) + assert "Authorization" not in client.headers + + def test_api_key_and_session_both_provided(self, sync_interceptor): + custom_session = MockSession(sync_interceptor) + client = sync_tavily.TavilyClient(api_key="tvly-test", session=custom_session) + # api_key provided and session has no Authorization, so default fills it in + assert client.session.headers["Authorization"] == "Bearer tvly-test" + + def test_custom_session_with_all_endpoints(self, sync_interceptor): + custom_session = MockSession(sync_interceptor) + custom_session.headers["Authorization"] = "Bearer apim-token" + + client = sync_tavily.TavilyClient(session=custom_session) + + # search + sync_interceptor.set_response(200, json={"results": []}) + client.search("test") + assert sync_interceptor.get_request().headers["Authorization"] == "Bearer apim-token" + + # extract + sync_interceptor.set_response(200, json={"results": [], "failed_results": []}) + client.extract(urls=["https://example.com"]) + assert sync_interceptor.get_request().headers["Authorization"] == "Bearer apim-token" + + # crawl + sync_interceptor.set_response(200, json={"results": []}) + client.crawl(url="https://example.com") + assert sync_interceptor.get_request().headers["Authorization"] == "Bearer apim-token" + + # map + sync_interceptor.set_response(200, json={"results": []}) + client.map(url="https://example.com") + assert sync_interceptor.get_request().headers["Authorization"] == "Bearer apim-token" + + def test_custom_session_with_custom_base_url(self, sync_interceptor): + sync_interceptor.set_response(200, json={"results": []}) + custom_session = MockSession(sync_interceptor) + + client = sync_tavily.TavilyClient( + api_key="tvly-test", + session=custom_session, + api_base_url="https://apim.corp.com/tavily", + ) + client.search("test") + + req = sync_interceptor.get_request() + assert req.url == "https://apim.corp.com/tavily/search" + + def test_custom_session_proxies_fill_missing_protocols(self, 
sync_interceptor): + custom_session = MockSession(sync_interceptor) + custom_session.proxies["http"] = "http://session-proxy:8080" + + client = sync_tavily.TavilyClient( + api_key="tvly-test", + session=custom_session, + proxies={"http": "http://arg-proxy:9090", "https": "http://arg-proxy:9091"}, + ) + + # http: session proxy wins + assert client.session.proxies["http"] == "http://session-proxy:8080" + # https: session didn't have it, so arg fills it in + assert client.session.proxies["https"] == "http://arg-proxy:9091" + + def test_custom_session_project_id_header(self, sync_interceptor): + custom_session = MockSession(sync_interceptor) + client = sync_tavily.TavilyClient( + api_key="tvly-test", + session=custom_session, + project_id="my-project", + ) + assert client.session.headers["X-Project-ID"] == "my-project" + + def test_shared_session_across_multiple_clients(self, sync_interceptor): + sync_interceptor.set_response(200, json={"results": []}) + shared_session = MockSession(sync_interceptor) + shared_session.headers["Authorization"] = "Bearer shared-token" + + client1 = sync_tavily.TavilyClient(session=shared_session) + client2 = sync_tavily.TavilyClient(session=shared_session) + + assert client1.session is client2.session + + client1.search("query1") + assert sync_interceptor.get_request().headers["Authorization"] == "Bearer shared-token" + + client2.search("query2") + assert sync_interceptor.get_request().headers["Authorization"] == "Bearer shared-token" + + # Closing one client should not close the shared session + client1.close() + client2.search("query3") + assert sync_interceptor.get_request() is not None + + +# --- Async AsyncTavilyClient tests --- + +class TestAsyncCustomClient: + def test_default_client_created_when_none_provided(self): + client = async_tavily.AsyncTavilyClient(api_key="tvly-test") + assert not client._external_client + + def test_custom_client_used(self): + custom_client = httpx.AsyncClient() + client = 
async_tavily.AsyncTavilyClient(api_key="tvly-test", client=custom_client) + assert client._external_client + assert client._client is custom_client + + def test_custom_client_preserves_existing_headers(self): + custom_client = httpx.AsyncClient(headers={ + "Authorization": "Bearer apim-token-123", + "X-Custom": "custom-value", + }) + client = async_tavily.AsyncTavilyClient(api_key="tvly-test", client=custom_client) + + assert client._client.headers["Authorization"] == "Bearer apim-token-123" + assert client._client.headers["X-Custom"] == "custom-value" + assert client._client.headers["Content-Type"] == "application/json" + assert client._client.headers["X-Client-Source"] == "tavily-python" + + def test_custom_client_gets_default_headers_when_empty(self): + custom_client = httpx.AsyncClient() + client = async_tavily.AsyncTavilyClient(api_key="tvly-test", client=custom_client) + + assert client._client.headers["Authorization"] == "Bearer tvly-test" + assert client._client.headers["Content-Type"] == "application/json" + + def test_custom_client_base_url_set_when_missing(self): + custom_client = httpx.AsyncClient() + client = async_tavily.AsyncTavilyClient(api_key="tvly-test", client=custom_client) + assert "api.tavily.com" in str(client._client.base_url) + + def test_custom_client_base_url_preserved_when_set(self): + custom_client = httpx.AsyncClient(base_url="https://apim.example.com/tavily") + client = async_tavily.AsyncTavilyClient(api_key="tvly-test", client=custom_client) + assert "apim.example.com" in str(client._client.base_url) + + def test_close_does_not_close_external_client(self): + closed = [] + custom_client = httpx.AsyncClient() + + async def run(): + client = async_tavily.AsyncTavilyClient(api_key="tvly-test", client=custom_client) + original_aclose = custom_client.aclose + + async def track_close(): + closed.append(True) + await original_aclose() + + custom_client.aclose = track_close + await client.close() + + asyncio.run(run()) + assert len(closed) 
== 0 + + def test_context_manager_does_not_close_external_client(self): + closed = [] + custom_client = httpx.AsyncClient() + + async def run(): + original_aclose = custom_client.aclose + + async def track_close(): + closed.append(True) + await original_aclose() + + custom_client.aclose = track_close + async with async_tavily.AsyncTavilyClient(api_key="tvly-test", client=custom_client): + pass + + asyncio.run(run()) + assert len(closed) == 0 + + def test_custom_client_sends_request(self, async_interceptor): + async_interceptor.set_response(200, json={"results": []}) + custom_client = httpx.AsyncClient( + headers={"Authorization": "Bearer apim-token"}, + base_url="https://api.tavily.com", + ) + + client = async_tavily.AsyncTavilyClient(api_key="tvly-test", client=custom_client) + asyncio.run(client.search("test query")) + + req = async_interceptor.get_request() + assert req is not None + assert req.headers["Authorization"] == "Bearer apim-token" + + # --- API key validation edge cases --- + + def test_no_api_key_no_client_raises(self, monkeypatch): + monkeypatch.delenv("TAVILY_API_KEY", raising=False) + with pytest.raises(MissingAPIKeyError): + async_tavily.AsyncTavilyClient() + + def test_no_api_key_with_client_allowed(self, monkeypatch): + monkeypatch.delenv("TAVILY_API_KEY", raising=False) + custom_client = httpx.AsyncClient( + headers={"Authorization": "Bearer apim-token"}, + ) + client = async_tavily.AsyncTavilyClient(client=custom_client) + assert client._client.headers["Authorization"] == "Bearer apim-token" + + def test_no_api_key_with_client_no_auth_header_on_defaults(self, monkeypatch): + monkeypatch.delenv("TAVILY_API_KEY", raising=False) + custom_client = httpx.AsyncClient() + client = async_tavily.AsyncTavilyClient(client=custom_client) + # httpx always has headers dict but Authorization shouldn't be added + assert "authorization" not in [k.lower() for k in client._client.headers.keys() + if k.lower() == "authorization" + and 
client._client.headers[k].startswith("Bearer None")] + + def test_no_api_key_with_client_sends_request(self, async_interceptor, monkeypatch): + monkeypatch.delenv("TAVILY_API_KEY", raising=False) + async_interceptor.set_response(200, json={"results": []}) + custom_client = httpx.AsyncClient( + headers={"Authorization": "Bearer apim-token"}, + base_url="https://api.tavily.com", + ) + + client = async_tavily.AsyncTavilyClient(client=custom_client) + asyncio.run(client.search("test query")) + + req = async_interceptor.get_request() + assert req is not None + assert req.headers["Authorization"] == "Bearer apim-token" + + def test_custom_client_with_all_endpoints(self, async_interceptor): + custom_client = httpx.AsyncClient( + headers={"Authorization": "Bearer apim-token"}, + base_url="https://api.tavily.com", + ) + client = async_tavily.AsyncTavilyClient(client=custom_client) + + # search + async_interceptor.set_response(200, json={"results": []}) + asyncio.run(client.search("test")) + assert async_interceptor.get_request().headers["Authorization"] == "Bearer apim-token" + + # extract + async_interceptor.set_response(200, json={"results": [], "failed_results": []}) + asyncio.run(client.extract(urls=["https://example.com"])) + assert async_interceptor.get_request().headers["Authorization"] == "Bearer apim-token" + + # crawl + async_interceptor.set_response(200, json={"results": []}) + asyncio.run(client.crawl(url="https://example.com")) + assert async_interceptor.get_request().headers["Authorization"] == "Bearer apim-token" + + # map + async_interceptor.set_response(200, json={"results": []}) + asyncio.run(client.map(url="https://example.com")) + assert async_interceptor.get_request().headers["Authorization"] == "Bearer apim-token" + + def test_custom_client_project_id_header(self): + custom_client = httpx.AsyncClient() + client = async_tavily.AsyncTavilyClient( + api_key="tvly-test", + client=custom_client, + project_id="my-project", + ) + assert 
client._client.headers["X-Project-ID"] == "my-project"