From a6941ba09bd8ea13cb8e8129e9a117325894d3cb Mon Sep 17 00:00:00 2001 From: Stefano Lottini Date: Tue, 1 Oct 2024 16:31:50 +0200 Subject: [PATCH] completed internal switching to keyspace for astrapy 1.5+ --- libs/astradb/langchain_astradb/cache.py | 4 +- .../chat_message_histories.py | 2 +- .../langchain_astradb/document_loaders.py | 4 +- libs/astradb/langchain_astradb/storage.py | 4 ++ .../langchain_astradb/utils/astradb.py | 42 +++++++++---------- .../astradb/langchain_astradb/vectorstores.py | 4 +- .../tests/integration_tests/conftest.py | 2 +- .../test_document_loaders.py | 6 +-- .../unit_tests/test_astra_db_environment.py | 28 ++++++------- 9 files changed, 50 insertions(+), 46 deletions(-) diff --git a/libs/astradb/langchain_astradb/cache.py b/libs/astradb/langchain_astradb/cache.py index adc1b69..01d26e3 100644 --- a/libs/astradb/langchain_astradb/cache.py +++ b/libs/astradb/langchain_astradb/cache.py @@ -156,7 +156,7 @@ def __init__( environment=environment, astra_db_client=astra_db_client, async_astra_db_client=async_astra_db_client, - namespace=namespace, + keyspace=namespace, setup_mode=setup_mode, pre_delete_collection=pre_delete_collection, ) @@ -411,7 +411,7 @@ async def _acache_embedding(text: str) -> list[float]: environment=environment, astra_db_client=astra_db_client, async_astra_db_client=async_astra_db_client, - namespace=namespace, + keyspace=namespace, setup_mode=setup_mode, pre_delete_collection=pre_delete_collection, embedding_dimension=embedding_dimension, diff --git a/libs/astradb/langchain_astradb/chat_message_histories.py b/libs/astradb/langchain_astradb/chat_message_histories.py index 1b71707..327b413 100644 --- a/libs/astradb/langchain_astradb/chat_message_histories.py +++ b/libs/astradb/langchain_astradb/chat_message_histories.py @@ -81,7 +81,7 @@ def __init__( environment=environment, astra_db_client=astra_db_client, async_astra_db_client=async_astra_db_client, - namespace=namespace, + keyspace=namespace, setup_mode=setup_mode, pre_delete_collection=pre_delete_collection, ) diff --git a/libs/astradb/langchain_astradb/document_loaders.py b/libs/astradb/langchain_astradb/document_loaders.py index 37ee41d..e838431 100644 --- a/libs/astradb/langchain_astradb/document_loaders.py +++ b/libs/astradb/langchain_astradb/document_loaders.py @@ -99,7 +99,7 @@ def __init__( environment=environment, astra_db_client=astra_db_client, async_astra_db_client=async_astra_db_client, - namespace=namespace, + keyspace=namespace, setup_mode=SetupMode.OFF, ) self.astra_db_env = astra_db_env @@ -151,7 +151,7 @@ def __init__( self.page_content_mapper = page_content_mapper self.metadata_mapper = metadata_mapper or ( lambda _: { - "namespace": self.astra_db_env.database.namespace, + "namespace": self.astra_db_env.database.keyspace, "api_endpoint": self.astra_db_env.database.api_endpoint, "collection": collection_name, } diff --git a/libs/astradb/langchain_astradb/storage.py b/libs/astradb/langchain_astradb/storage.py index f55d6c0..16d6c86 100644 --- a/libs/astradb/langchain_astradb/storage.py +++ b/libs/astradb/langchain_astradb/storage.py @@ -46,6 +46,10 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: raise ValueError(msg) kwargs["requested_indexing_policy"] = {"allow": ["_id"]} kwargs["default_indexing_policy"] = {"allow": ["_id"]} + + if "namespace" in kwargs: + kwargs["keyspace"] = kwargs.pop("namespace") + self.astra_env = _AstraDBCollectionEnvironment( *args, **kwargs, diff --git a/libs/astradb/langchain_astradb/utils/astradb.py b/libs/astradb/langchain_astradb/utils/astradb.py index 6c67d34..64291cd 100644 --- a/libs/astradb/langchain_astradb/utils/astradb.py +++ b/libs/astradb/langchain_astradb/utils/astradb.py @@ -23,7 +23,7 @@ TOKEN_ENV_VAR = "ASTRA_DB_APPLICATION_TOKEN" # noqa: S105 API_ENDPOINT_ENV_VAR = "ASTRA_DB_API_ENDPOINT" -NAMESPACE_ENV_VAR = "ASTRA_DB_KEYSPACE" +KEYSPACE_ENV_VAR = "ASTRA_DB_KEYSPACE" # Default settings for API data operations (concurrency & similar): # Chunk size for many-document insertions (None meaning defer to astrapy): @@ -57,7 +57,7 @@ def _survey_collection( environment: str | None = None, astra_db_client: AstraDB | None = None, async_astra_db_client: AsyncAstraDB | None = None, - namespace: str | None = None, + keyspace: str | None = None, ) -> tuple[CollectionDescriptor | None, list[dict[str, Any]]]: """Return the collection descriptor (if found) and a sample of documents.""" _environment = _AstraDBEnvironment( @@ -66,7 +66,7 @@ def _survey_collection( environment=environment, astra_db_client=astra_db_client, async_astra_db_client=async_astra_db_client, - namespace=namespace, + keyspace=keyspace, ) descriptors = [ coll_d @@ -93,11 +93,11 @@ def __init__( environment: str | None = None, astra_db_client: AstraDB | None = None, async_astra_db_client: AsyncAstraDB | None = None, - namespace: str | None = None, + keyspace: str | None = None, ) -> None: self.token: str | TokenProvider | None self.api_endpoint: str | None - self.namespace: str | None + self.keyspace: str | None self.environment: str | None self.data_api_client: DataAPIClient @@ -147,7 +147,7 @@ def __init__( if klient is not None } ) - _namespaces = list( + _keyspaces = list( { klient.namespace for klient in [astra_db_client, async_astra_db_client] @@ -164,21 +164,21 @@ def __init__( if len(_api_endpoints) != 1: msg = ( "Conflicting API endpoints found in the sync and async " - "AstraDB constructor parameters. Please check the tokens " + "AstraDB constructor parameters. Please check the endpoints " "and ensure they match." ) raise ValueError(msg) - if len(_namespaces) != 1: + if len(_keyspaces) != 1: msg = ( - "Conflicting namespaces found in the sync and async " - "AstraDB constructor parameters. Please check the tokens " - "and ensure they match." + "Conflicting keyspaces found in the sync and async " + "AstraDB constructor parameters' 'namespace' attributes. " + "Please check the keyspaces and ensure they match." ) raise ValueError(msg) # all good: these are 1-element lists here self.token = _tokens[0] self.api_endpoint = _api_endpoints[0] - self.namespace = _namespaces[0] + self.keyspace = _keyspaces[0] else: _token: str | TokenProvider | None # secrets-based initialization @@ -199,19 +199,19 @@ def __init__( _api_endpoint = os.environ.get(API_ENDPOINT_ENV_VAR) else: _api_endpoint = api_endpoint - if namespace is None: - _namespace = os.environ.get(NAMESPACE_ENV_VAR) + if keyspace is None: + _keyspace = os.environ.get(KEYSPACE_ENV_VAR) else: - _namespace = namespace + _keyspace = keyspace self.token = _token self.api_endpoint = _api_endpoint - self.namespace = _namespace + self.keyspace = _keyspace self.environment = environment - # init parameters are normalized to self.{token, api_endpoint, namespace}. - # Proceed. Namespace and token can be None (resp. on Astra DB and non-Astra) + # init parameters are normalized to self.{token, api_endpoint, keyspace}. + # Proceed. Keyspace and token can be None (resp. on Astra DB and non-Astra) if self.api_endpoint is None: msg = ( "API endpoint for Data API not provided. " @@ -232,7 +232,7 @@ def __init__( self.database = self.data_api_client.get_database( api_endpoint=self.api_endpoint, token=self.token, - keyspace=self.namespace, + keyspace=self.keyspace, ) self.async_database = self.database.to_async() @@ -247,7 +247,7 @@ def __init__( environment: str | None = None, astra_db_client: AstraDB | None = None, async_astra_db_client: AsyncAstraDB | None = None, - namespace: str | None = None, + keyspace: str | None = None, setup_mode: SetupMode = SetupMode.SYNC, pre_delete_collection: bool = False, embedding_dimension: int | Awaitable[int] | None = None, @@ -263,7 +263,7 @@ def __init__( environment=environment, astra_db_client=astra_db_client, async_astra_db_client=async_astra_db_client, - namespace=namespace, + keyspace=keyspace, ) self.collection_name = collection_name self.collection = self.database.get_collection( diff --git a/libs/astradb/langchain_astradb/vectorstores.py b/libs/astradb/langchain_astradb/vectorstores.py index 4a2ee8f..8b1576a 100644 --- a/libs/astradb/langchain_astradb/vectorstores.py +++ b/libs/astradb/langchain_astradb/vectorstores.py @@ -597,7 +597,7 @@ def __init__( environment=self.environment, astra_db_client=astra_db_client, async_astra_db_client=async_astra_db_client, - namespace=self.namespace, + keyspace=self.namespace, ) if c_descriptor is None: msg = f"Collection '{self.collection_name}' not found." @@ -653,7 +653,7 @@ def __init__( environment=self.environment, astra_db_client=astra_db_client, async_astra_db_client=async_astra_db_client, - namespace=self.namespace, + keyspace=self.namespace, setup_mode=_setup_mode, pre_delete_collection=pre_delete_collection, embedding_dimension=_embedding_dimension, diff --git a/libs/astradb/tests/integration_tests/conftest.py b/libs/astradb/tests/integration_tests/conftest.py index 9b2010e..2799088 100644 --- a/libs/astradb/tests/integration_tests/conftest.py +++ b/libs/astradb/tests/integration_tests/conftest.py @@ -175,7 +175,7 @@ def database( db = client.get_database( astra_db_credentials["api_endpoint"], token=StaticTokenProvider(astra_db_credentials["token"]), - namespace=astra_db_credentials["namespace"], + keyspace=astra_db_credentials["namespace"], ) if not is_astra_db: if astra_db_credentials["namespace"] is None: diff --git a/libs/astradb/tests/integration_tests/test_document_loaders.py b/libs/astradb/tests/integration_tests/test_document_loaders.py index 99449c0..e83e2ad 100644 --- a/libs/astradb/tests/integration_tests/test_document_loaders.py +++ b/libs/astradb/tests/integration_tests/test_document_loaders.py @@ -96,7 +96,7 @@ def test_astradb_loader_base_sync( assert content["_id"] not in ids ids.add(content["_id"]) assert doc.metadata == { - "namespace": database.namespace, + "namespace": database.keyspace, "api_endpoint": astra_db_credentials["api_endpoint"], "collection": document_loader_collection.name, } @@ -189,7 +189,7 @@ async def test_astradb_loader_prefetched_async( assert content["_id"] not in ids ids.add(content["_id"]) assert doc.metadata == { - "namespace": database.namespace, + "namespace": database.keyspace, "api_endpoint": astra_db_credentials["api_endpoint"], "collection": async_document_loader_collection.name, } @@ -234,7 +234,7 @@ async def test_astradb_loader_base_async( assert content["_id"] not in ids ids.add(content["_id"]) assert doc.metadata == { - "namespace": database.namespace, + "namespace": database.keyspace, "api_endpoint": astra_db_credentials["api_endpoint"], "collection": async_document_loader_collection.name, } diff --git a/libs/astradb/tests/unit_tests/test_astra_db_environment.py b/libs/astradb/tests/unit_tests/test_astra_db_environment.py index 7d0137c..e146f2a 100644 --- a/libs/astradb/tests/unit_tests/test_astra_db_environment.py +++ b/libs/astradb/tests/unit_tests/test_astra_db_environment.py @@ -5,7 +5,7 @@ from langchain_astradb.utils.astradb import ( API_ENDPOINT_ENV_VAR, - NAMESPACE_ENV_VAR, + KEYSPACE_ENV_VAR, TOKEN_ENV_VAR, _AstraDBEnvironment, ) @@ -47,15 +47,15 @@ def test_initialization(self) -> None: API_ENDPOINT_ENV_VAR ] del os.environ[API_ENDPOINT_ENV_VAR] - if NAMESPACE_ENV_VAR in os.environ: - env_vars_to_restore[NAMESPACE_ENV_VAR] = os.environ[NAMESPACE_ENV_VAR] - del os.environ[NAMESPACE_ENV_VAR] + if KEYSPACE_ENV_VAR in os.environ: + env_vars_to_restore[KEYSPACE_ENV_VAR] = os.environ[KEYSPACE_ENV_VAR] + del os.environ[KEYSPACE_ENV_VAR] # token+endpoint env1 = _AstraDBEnvironment( token=FAKE_TOKEN, api_endpoint=a_e_string, - namespace="n", + keyspace="n", ) # through a core AstraDB instance @@ -126,7 +126,7 @@ def test_initialization(self) -> None: ) with pytest.raises( ValueError, - match="Conflicting namespaces found in the sync and async AstraDB " + match="Conflicting keyspaces found in the sync and async AstraDB " "constructor parameters.", ), pytest.warns(DeprecationWarning): _AstraDBEnvironment( @@ -174,7 +174,7 @@ def test_initialization(self) -> None: os.environ[TOKEN_ENV_VAR] = "t" env4 = _AstraDBEnvironment( api_endpoint=a_e_string, - namespace="n", + keyspace="n", ) del os.environ[TOKEN_ENV_VAR] assert env1.data_api_client == env4.data_api_client @@ -185,7 +185,7 @@ def test_initialization(self) -> None: os.environ[API_ENDPOINT_ENV_VAR] = a_e_string env5 = _AstraDBEnvironment( token=FAKE_TOKEN, - namespace="n", + keyspace="n", ) del os.environ[API_ENDPOINT_ENV_VAR] assert env1.data_api_client == env5.data_api_client @@ -195,19 +195,19 @@ def test_initialization(self) -> None: # both and also namespace via env vars os.environ[TOKEN_ENV_VAR] = FAKE_TOKEN os.environ[API_ENDPOINT_ENV_VAR] = a_e_string - os.environ[NAMESPACE_ENV_VAR] = "n" + os.environ[KEYSPACE_ENV_VAR] = "n" env6 = _AstraDBEnvironment() assert env1.data_api_client == env6.data_api_client assert env1.database == env6.database assert env1.async_database == env6.async_database del os.environ[TOKEN_ENV_VAR] del os.environ[API_ENDPOINT_ENV_VAR] - del os.environ[NAMESPACE_ENV_VAR] + del os.environ[KEYSPACE_ENV_VAR] # env vars do not interfere if client(s) passed os.environ[TOKEN_ENV_VAR] = "NO!" os.environ[API_ENDPOINT_ENV_VAR] = "NO!" - os.environ[NAMESPACE_ENV_VAR] = "NO!" + os.environ[KEYSPACE_ENV_VAR] = "NO!" with pytest.warns(DeprecationWarning): env7a = _AstraDBEnvironment( async_astra_db_client=mock_astra_db.to_async(), @@ -235,7 +235,7 @@ def test_initialization(self) -> None: env8 = _AstraDBEnvironment( token=FAKE_TOKEN, api_endpoint=a_e_string, - namespace="n", + keyspace="n", ) assert env1.data_api_client == env8.data_api_client assert env1.database == env8.database @@ -247,7 +247,7 @@ def test_initialization(self) -> None: del os.environ[TOKEN_ENV_VAR] if API_ENDPOINT_ENV_VAR in os.environ: del os.environ[API_ENDPOINT_ENV_VAR] - if NAMESPACE_ENV_VAR in os.environ: - del os.environ[NAMESPACE_ENV_VAR] + if KEYSPACE_ENV_VAR in os.environ: + del os.environ[KEYSPACE_ENV_VAR] for env_var_name, env_var_value in env_vars_to_restore.items(): os.environ[env_var_name] = env_var_value