diff --git a/libs/astradb/langchain_astradb/cache.py b/libs/astradb/langchain_astradb/cache.py index bf2fe6a..059bcc2 100644 --- a/libs/astradb/langchain_astradb/cache.py +++ b/libs/astradb/langchain_astradb/cache.py @@ -23,6 +23,7 @@ ) if TYPE_CHECKING: + from astrapy.api_options import APIOptions from astrapy.authentication import TokenProvider from langchain_core.embeddings import Embeddings from langchain_core.language_models import LLM @@ -117,6 +118,7 @@ def __init__( pre_delete_collection: bool = False, setup_mode: SetupMode = SetupMode.SYNC, ext_callers: list[tuple[str | None, str | None] | str | None] | None = None, + api_options: APIOptions | None = None, ): """Cache that uses Astra DB as a backend. @@ -151,6 +153,13 @@ def __init__( or just strings if no version info is provided, which, if supplied, becomes the leading part of the User-Agent string in all API requests related to this component. + api_options: an instance of ``astrapy.utils.api_options.APIOptions`` that + can be supplied to customize the interaction with the Data API + regarding serialization/deserialization, timeouts, custom headers + and so on. The provided options are applied on top of settings already + tailored to this library, and if specified will take precedence. + Passing None (default) means no customization is requested. + Refer to the astrapy documentation for details. """ self.astra_env = _AstraDBCollectionEnvironment( collection_name=collection_name, @@ -162,6 +171,7 @@ def __init__( pre_delete_collection=pre_delete_collection, ext_callers=ext_callers, component_name=COMPONENT_NAME_CACHE, + api_options=api_options, ) self.collection = self.astra_env.collection self.async_collection = self.astra_env.async_collection @@ -328,6 +338,7 @@ def __init__( metric: str | None = None, similarity_threshold: float = ASTRA_DB_SEMANTIC_CACHE_DEFAULT_THRESHOLD, ext_callers: list[tuple[str | None, str | None] | str | None] | None = None, + api_options: APIOptions | None = None, ): """Astra DB semantic cache. @@ -372,6 +383,13 @@ def __init__( or just strings if no version info is provided, which, if supplied, becomes the leading part of the User-Agent string in all API requests related to this component. + api_options: an instance of ``astrapy.utils.api_options.APIOptions`` that + can be supplied to customize the interaction with the Data API + regarding serialization/deserialization, timeouts, custom headers + and so on. The provided options are applied on top of settings already + tailored to this library, and if specified will take precedence. + Passing None (default) means no customization is requested. + Refer to the astrapy documentation for details. """ self.embedding = embedding self.metric = metric @@ -413,6 +431,7 @@ async def _acache_embedding(text: str) -> list[float]: metric=metric, ext_callers=ext_callers, component_name=COMPONENT_NAME_SEMANTICCACHE, + api_options=api_options, ) self.collection = self.astra_env.collection self.async_collection = self.astra_env.async_collection diff --git a/libs/astradb/langchain_astradb/chat_message_histories.py b/libs/astradb/langchain_astradb/chat_message_histories.py index e3a19ad..5eb02ae 100644 --- a/libs/astradb/langchain_astradb/chat_message_histories.py +++ b/libs/astradb/langchain_astradb/chat_message_histories.py @@ -22,6 +22,7 @@ ) if TYPE_CHECKING: + from astrapy.api_options import APIOptions from astrapy.authentication import TokenProvider DEFAULT_COLLECTION_NAME = "langchain_message_store" @@ -40,6 +41,7 @@ def __init__( setup_mode: SetupMode = SetupMode.SYNC, pre_delete_collection: bool = False, ext_callers: list[tuple[str | None, str | None] | str | None] | None = None, + api_options: APIOptions | None = None, ) -> None: """Chat message history that stores history in Astra DB. @@ -70,6 +72,13 @@ def __init__( or just strings if no version info is provided, which, if supplied, becomes the leading part of the User-Agent string in all API requests related to this component. + api_options: an instance of ``astrapy.utils.api_options.APIOptions`` that + can be supplied to customize the interaction with the Data API + regarding serialization/deserialization, timeouts, custom headers + and so on. The provided options are applied on top of settings already + tailored to this library, and if specified will take precedence. + Passing None (default) means no customization is requested. + Refer to the astrapy documentation for details. """ self.astra_env = _AstraDBCollectionEnvironment( collection_name=collection_name, @@ -81,6 +90,7 @@ def __init__( pre_delete_collection=pre_delete_collection, ext_callers=ext_callers, component_name=COMPONENT_NAME_CHATMESSAGEHISTORY, + api_options=api_options, ) self.collection = self.astra_env.collection diff --git a/libs/astradb/langchain_astradb/document_loaders.py b/libs/astradb/langchain_astradb/document_loaders.py index a879aa9..5fa538f 100644 --- a/libs/astradb/langchain_astradb/document_loaders.py +++ b/libs/astradb/langchain_astradb/document_loaders.py @@ -24,6 +24,7 @@ ) if TYPE_CHECKING: + from astrapy.api_options import APIOptions from astrapy.authentication import TokenProvider logger = logging.getLogger(__name__) @@ -47,6 +48,7 @@ def __init__( page_content_mapper: Callable[[dict], str] = json.dumps, metadata_mapper: Callable[[dict], dict[str, Any]] | None = None, ext_callers: list[tuple[str | None, str | None] | str | None] | None = None, + api_options: APIOptions | None = None, ) -> None: """Load DataStax Astra DB documents. @@ -81,6 +83,13 @@ def __init__( or just strings if no version info is provided, which, if supplied, becomes the leading part of the User-Agent string in all API requests related to this component. + api_options: an instance of ``astrapy.utils.api_options.APIOptions`` that + can be supplied to customize the interaction with the Data API + regarding serialization/deserialization, timeouts, custom headers + and so on. The provided options are applied on top of settings already + tailored to this library, and if specified will take precedence. + Passing None (default) means no customization is requested. + Refer to the astrapy documentation for details. """ astra_db_env = _AstraDBCollectionEnvironment( collection_name=collection_name, @@ -91,6 +100,7 @@ def __init__( setup_mode=SetupMode.OFF, ext_callers=ext_callers, component_name=COMPONENT_NAME_LOADER, + api_options=api_options, ) self.astra_db_env = astra_db_env self.filter = filter_criteria diff --git a/libs/astradb/langchain_astradb/storage.py b/libs/astradb/langchain_astradb/storage.py index 34cc3cc..a8ac54b 100644 --- a/libs/astradb/langchain_astradb/storage.py +++ b/libs/astradb/langchain_astradb/storage.py @@ -30,6 +30,7 @@ ) if TYPE_CHECKING: + from astrapy.api_options import APIOptions from astrapy.authentication import TokenProvider from astrapy.results import CollectionUpdateResult @@ -246,6 +247,7 @@ def __init__( pre_delete_collection: bool = False, setup_mode: SetupMode = SetupMode.SYNC, ext_callers: list[tuple[str | None, str | None] | str | None] | None = None, + api_options: APIOptions | None = None, ) -> None: """BaseStore implementation using DataStax AstraDB as the underlying store. @@ -286,6 +288,13 @@ def __init__( or just strings if no version info is provided, which, if supplied, becomes the leading part of the User-Agent string in all API requests related to this component. + api_options: an instance of ``astrapy.utils.api_options.APIOptions`` that + can be supplied to customize the interaction with the Data API + regarding serialization/deserialization, timeouts, custom headers + and so on. The provided options are applied on top of settings already + tailored to this library, and if specified will take precedence. + Passing None (default) means no customization is requested. + Refer to the astrapy documentation for details. """ super().__init__( collection_name=collection_name, @@ -297,6 +306,7 @@ def __init__( pre_delete_collection=pre_delete_collection, ext_callers=ext_callers, component_name=COMPONENT_NAME_STORE, + api_options=api_options, ) @override @@ -320,6 +330,7 @@ def __init__( pre_delete_collection: bool = False, setup_mode: SetupMode = SetupMode.SYNC, ext_callers: list[tuple[str | None, str | None] | str | None] | None = None, + api_options: APIOptions | None = None, ) -> None: """ByteStore implementation using DataStax AstraDB as the underlying store. @@ -357,6 +368,13 @@ def __init__( or just strings if no version info is provided, which, if supplied, becomes the leading part of the User-Agent string in all API requests related to this component. + api_options: an instance of ``astrapy.utils.api_options.APIOptions`` that + can be supplied to customize the interaction with the Data API + regarding serialization/deserialization, timeouts, custom headers + and so on. The provided options are applied on top of settings already + tailored to this library, and if specified will take precedence. + Passing None (default) means no customization is requested. + Refer to the astrapy documentation for details. """ super().__init__( collection_name=collection_name, @@ -368,6 +386,7 @@ def __init__( pre_delete_collection=pre_delete_collection, ext_callers=ext_callers, component_name=COMPONENT_NAME_BYTESTORE, + api_options=api_options, ) @override diff --git a/libs/astradb/langchain_astradb/utils/astradb.py b/libs/astradb/langchain_astradb/utils/astradb.py index 6538cb5..736a55a 100644 --- a/libs/astradb/langchain_astradb/utils/astradb.py +++ b/libs/astradb/langchain_astradb/utils/astradb.py @@ -33,6 +33,7 @@ CollectionRerankOptions, RerankServiceOptions, ) +from astrapy.utils.api_options import defaultAPIOptions if TYPE_CHECKING: from astrapy.info import CollectionDescriptor, VectorServiceOptions @@ -139,6 +140,7 @@ def _survey_collection( environment: str | None = None, ext_callers: list[tuple[str | None, str | None] | str | None] | None = None, component_name: str | None = None, + api_options: APIOptions | None = None, ) -> tuple[CollectionDescriptor | None, list[dict[str, Any]]]: """Return the collection descriptor (if found) and a sample of documents.""" _astra_db_env = _AstraDBEnvironment( @@ -148,6 +150,7 @@ def _survey_collection( environment=environment, ext_callers=ext_callers, component_name=component_name, + api_options=api_options, ) descriptors = [ coll_d @@ -207,11 +210,13 @@ def __init__( environment: str | None = None, ext_callers: list[tuple[str | None, str | None] | str | None] | None = None, component_name: str | None = None, + api_options: APIOptions | None = None, ) -> None: self.token: TokenProvider self.api_endpoint: str | None self.keyspace: str | None self.environment: str | None + self.api_options: APIOptions | None self.data_api_client: DataAPIClient self.database: Database @@ -270,6 +275,8 @@ def __init__( self.api_endpoint, ) + self.api_options = api_options + # prepare the "callers" list to create the clients. # The callers, passed to astrapy, are made of these Caller pairs in this order: # - zero, one or more are the "ext_callers" passed to this environment @@ -292,9 +299,10 @@ def __init__( (self.component_name, LC_ASTRADB_VERSION), ] # create the client (set to return plain lists for vectors) - self.data_api_client = DataAPIClient( - environment=self.environment, - api_options=APIOptions( + # first must take care of two levels of customizing of the base astrapy options + astrapy_default_api_options = defaultAPIOptions(self.environment) + adapted_default_api_options = astrapy_default_api_options.with_override( + APIOptions( callers=self.full_callers, serdes_options=SerdesOptions(custom_datatypes_in_reading=False), timeout_options=TimeoutOptions( @@ -302,6 +310,13 @@ def __init__( ), ), ) + final_api_options = adapted_default_api_options.with_override( + api_options if api_options is not None else APIOptions() + ) + self.data_api_client = DataAPIClient( + environment=self.environment, + api_options=final_api_options, + ) self.database = self.data_api_client.get_database( api_endpoint=self.api_endpoint, @@ -322,6 +337,7 @@ def __init__( environment: str | None = None, ext_callers: list[tuple[str | None, str | None] | str | None] | None = None, component_name: str | None = None, + api_options: APIOptions | None = None, setup_mode: SetupMode = SetupMode.SYNC, pre_delete_collection: bool = False, embedding_dimension: int | Awaitable[int] | None = None, @@ -347,6 +363,7 @@ def __init__( environment=environment, ext_callers=ext_callers, component_name=component_name, + api_options=api_options, ) self.collection_name = collection_name self.collection_embedding_api_key = ( @@ -493,6 +510,7 @@ def copy( component_name=self.component_name if component_name is None else component_name, + api_options=self.api_options, setup_mode=SetupMode.OFF, collection_embedding_api_key=self.collection_embedding_api_key if collection_embedding_api_key is None diff --git a/libs/astradb/langchain_astradb/utils/vector_store_codecs.py b/libs/astradb/langchain_astradb/utils/vector_store_codecs.py index c0e8757..834017f 100644 --- a/libs/astradb/langchain_astradb/utils/vector_store_codecs.py +++ b/libs/astradb/langchain_astradb/utils/vector_store_codecs.py @@ -305,7 +305,7 @@ def encode_query( the resulting query would return Astra DB documents matching the metadata clause AND having an ID among those provided to this method. If, instead, an OR is required, one should run two separate queries and subsequently merge - the result (taking care of avoiding duplcates). + the result (taking care of avoiding duplicates). Args: ids: an iterable over Document IDs. If provided, the resulting Astra DB diff --git a/libs/astradb/langchain_astradb/vectorstores.py b/libs/astradb/langchain_astradb/vectorstores.py index 7749d66..78fb4e0 100644 --- a/libs/astradb/langchain_astradb/vectorstores.py +++ b/libs/astradb/langchain_astradb/vectorstores.py @@ -63,6 +63,7 @@ ) if TYPE_CHECKING: + from astrapy.api_options import APIOptions from astrapy.authentication import ( EmbeddingHeadersProvider, RerankingHeadersProvider, @@ -104,7 +105,7 @@ class AstraDBQueryResult(NamedTuple): This class represents all that can be returned from the collection when running a query, which goes beyond just the corresponding Document. - Atributes: + Attributes: document: a ``langchain.schema.Document`` object representing the query result. id: the ID of the returned document. embedding: the embedding vector associated to the document. This may be None, @@ -656,6 +657,7 @@ def __init__( autodetect_collection: bool = False, ext_callers: list[tuple[str | None, str | None] | str | None] | None = None, component_name: str = COMPONENT_NAME_VECTORSTORE, + api_options: APIOptions | None = None, collection_rerank: CollectionRerankOptions | RerankServiceOptions | None = None, collection_reranking_api_key: str | RerankingHeadersProvider | None = None, collection_lexical: str @@ -668,7 +670,7 @@ def __init__( | dict[str, float] | HybridLimitFactorPrescription = None, ) -> None: - """A vector store wich uses DataStax Astra DB as backend. + """A vector store which uses DataStax Astra DB as backend. For more on Astra DB, visit https://docs.datastax.com/en/astra-db-serverless/index.html @@ -772,6 +774,13 @@ def __init__( Defaults to "langchain_vectorstore", but can be overridden if this component actually serves as the building block for another component (such as when the vector store is used within a ``GraphRetriever``). + api_options: an instance of ``astrapy.utils.api_options.APIOptions`` that + can be supplied to customize the interaction with the Data API + regarding serialization/deserialization, timeouts, custom headers + and so on. The provided options are applied on top of settings already + tailored to this library, and if specified will take precedence. + Passing None (default) means no customization is requested. + Refer to the astrapy documentation for details. collection_rerank: providing reranking settings is necessary to run hybrid searches for similarity. This parameter can be an instance of the astrapy classes `CollectionRerankOptions` or @@ -937,6 +946,7 @@ def __init__( environment=self.environment, ext_callers=ext_callers, component_name=component_name, + api_options=api_options, ) if c_descriptor is None: msg = f"Collection '{self.collection_name}' not found." @@ -1018,6 +1028,7 @@ def __init__( collection_embedding_api_key=self.collection_embedding_api_key, ext_callers=ext_callers, component_name=component_name, + api_options=api_options, collection_rerank=collection_rerank, collection_reranking_api_key=self.collection_reranking_api_key, collection_lexical=collection_lexical, diff --git a/libs/astradb/tests/integration_tests/test_vectorstore.py b/libs/astradb/tests/integration_tests/test_vectorstore.py index d709557..601e56b 100644 --- a/libs/astradb/tests/integration_tests/test_vectorstore.py +++ b/libs/astradb/tests/integration_tests/test_vectorstore.py @@ -9,12 +9,14 @@ from typing import TYPE_CHECKING, Any import pytest +from astrapy.api_options import APIOptions, TimeoutOptions from astrapy.authentication import ( EmbeddingAPIKeyHeaderProvider, RerankingAPIKeyHeaderProvider, StaticTokenProvider, ) from astrapy.constants import SortMode +from astrapy.exceptions import DataAPITimeoutException from langchain_core.documents import Document from langchain_astradb.utils.astradb import COMPONENT_NAME_VECTORSTORE, SetupMode @@ -2152,3 +2154,35 @@ async def test_astradb_vectorstore_arun_query( ) hits9d_l = [tpl async for tpl in hits9d] assert [doc_id for _, doc_id, _, _ in hits9d_l] == ["10", "9", "8"] + + def test_astradb_vectorstore_custom_api_options( + self, + astra_db_credentials: AstraDBCredentials, + empty_collection_d2: Collection, + embedding_d2: Embeddings, + ) -> None: + """Craft a custom APIOptions (very low timeout), expect a timeout to occur.""" + baseline_v_store = AstraDBVectorStore( + embedding=embedding_d2, + collection_name=empty_collection_d2.name, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + setup_mode=SetupMode.OFF, + ) + baseline_v_store.similarity_search("[0,1]") + + impatient_ao = APIOptions(timeout_options=TimeoutOptions(request_timeout_ms=1)) + impatient_v_store = AstraDBVectorStore( + embedding=embedding_d2, + collection_name=empty_collection_d2.name, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + setup_mode=SetupMode.OFF, + api_options=impatient_ao, + ) + with pytest.raises(DataAPITimeoutException): + impatient_v_store.similarity_search("[0,1]")