From 891d57d1ebbc19570cb5eaf0f34fe9b55b097613 Mon Sep 17 00:00:00 2001 From: Sichen Liu Date: Wed, 3 Apr 2024 22:56:09 +0000 Subject: [PATCH 1/7] support sync vector query --- google/cloud/firestore_v1/async_query.py | 30 +++ .../cloud/firestore_v1/async_vector_query.py | 136 ++++++++++++ tests/unit/v1/_test_helpers.py | 6 + tests/unit/v1/test_async_vector_query.py | 205 ++++++++++++++++++ 4 files changed, 377 insertions(+) create mode 100644 google/cloud/firestore_v1/async_vector_query.py create mode 100644 tests/unit/v1/test_async_vector_query.py diff --git a/google/cloud/firestore_v1/async_query.py b/google/cloud/firestore_v1/async_query.py index 8ee4012904..20a4c9fb3c 100644 --- a/google/cloud/firestore_v1/async_query.py +++ b/google/cloud/firestore_v1/async_query.py @@ -36,6 +36,8 @@ from google.cloud.firestore_v1 import async_document from google.cloud.firestore_v1.async_aggregation import AsyncAggregationQuery from google.cloud.firestore_v1.base_document import DocumentSnapshot +from google.cloud.firestore_v1.async_vector_query import AsyncVectorQuery +from google.cloud.firestore_v1.vector import Vector from typing import AsyncGenerator, List, Optional, Type, TYPE_CHECKING if TYPE_CHECKING: # pragma: NO COVER @@ -217,6 +219,34 @@ async def get( return result + def find_nearest( + self, + vector_field: str, + query_vector: Vector, + limit: int, + distance_measure: DistanceMeasure, + ) -> AsyncVectorQuery: + """ + Finds the closest vector embeddings to the given query vector. + + Args: + vector_field(str): An indexed vector field to search upon. Only documents which contain + vectors whose dimensionality match the query_vector can be returned. + query_vector(Vector): The query vector that we are searching on. Must be a vector of no more + than 2048 dimensions. + limit (int): The number of nearest neighbors to return. Must be a positive integer of no more than 1000. + distance_measure(:class:`DistanceMeasure`): The Distance Measure to use. 
+ + Returns: + :class:`~google.cloud.firestore_v1.async_vector_query.AsyncVectorQuery`: the vector query. + """ + return AsyncVectorQuery(self).find_nearest( + vector_field=vector_field, + query_vector=query_vector, + limit=limit, + distance_measure=distance_measure, + ) + def count( self, alias: str | None = None ) -> Type["firestore_v1.async_aggregation.AsyncAggregationQuery"]: diff --git a/google/cloud/firestore_v1/async_vector_query.py b/google/cloud/firestore_v1/async_vector_query.py new file mode 100644 index 0000000000..d9c12e4c48 --- /dev/null +++ b/google/cloud/firestore_v1/async_vector_query.py @@ -0,0 +1,136 @@ +# Copyright 2024 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Classes for representing Async aggregation queries for the Google Cloud Firestore API. + +A :class:`~google.cloud.firestore_v1.async_aggregation.AsyncAggregationQuery` can be created directly from +a :class:`~google.cloud.firestore_v1.async_collection.AsyncCollection` and that can be +a more common way to create an aggregation query than direct usage of the constructor. 
+""" +from __future__ import annotations + +from google.cloud.firestore_v1.base_vector_query import BaseVectorQuery +from google.api_core import gapic_v1 +from google.api_core import retry_async as retries +from google.cloud.firestore_v1.base_document import DocumentSnapshot +from typing import AsyncGenerator, List, Union, Optional, TypeVar + + +from google.cloud.firestore_v1.base_vector_query import ( + BaseVectorQuery, +) +from google.cloud.firestore_v1.base_query import ( + _query_response_to_snapshot, + _collection_group_query_response_to_snapshot, +) + +TAsyncVectorQuery = TypeVar("TAsyncVectorQuery", bound="AsyncVectorQuery") + + +class AsyncVectorQuery(BaseVectorQuery): + """Represents an async vector query to the Firestore API.""" + + def __init__( + self, + nested_query: Union[BaseQuery, TAsyncVectorQuery], + ) -> None: + """Presents the vector query. + Args: + nested_query (BaseQuery | VectorQuery): the base query to apply as the prefilter. + """ + super(AsyncVectorQuery, self).__init__(nested_query) + + async def get( + self, + transaction=None, + retry: retries.AsyncRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + ) -> List[DocumentSnapshot]: + """Runs the vector query. + + This sends a ``RunQuery`` RPC and returns a list of document messages. + + Args: + transaction + (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): + An existing transaction that this query will run in. + If a ``transaction`` is used and it already has write operations + added, this method cannot be used (i.e. read-after-write is not + allowed). + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. Defaults to a system-specified policy. + timeout (float): The timeout for this request. Defaults to a + system-specified value. + + Returns: + list: The vector query results. 
+ """ + stream_result = self.stream( + transaction=transaction, retry=retry, timeout=timeout + ) + result = [aggregation async for aggregation in stream_result] + return result # type: ignore + + async def stream( + self, + transaction=None, + retry: retries.AsyncRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + ) -> AsyncGenerator[async_document.DocumentSnapshot, None]: + """Reads the documents in the collection that match this query. + + This sends a ``RunQuery`` RPC and then returns an iterator which + consumes each document returned in the stream of ``RunQueryResponse`` + messages. + + If a ``transaction`` is used and it already has write operations + added, this method cannot be used (i.e. read-after-write is not + allowed). + + Args: + transaction + (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): + An existing transaction that this query will run in. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. Defaults to a system-specified policy. + timeout (float): The timeout for this request. Defaults to a + system-specified value. + + Yields: + :class:`~google.cloud.firestore_v1.document.DocumentSnapshot`: + The next document that fulfills the query. 
+ """ + request, expected_prefix, kwargs = self._prep_stream( + transaction, + retry, + timeout, + ) + + response_iterator = await self._client._firestore_api.run_query( + request=request, + metadata=self._client._rpc_metadata, + **kwargs, + ) + + async for response in response_iterator: + if self._nested_query._all_descendants: + snapshot = _collection_group_query_response_to_snapshot( + response, self._nested_query._parent + ) + else: + snapshot = _query_response_to_snapshot( + response, self._nested_query._parent, expected_prefix + ) + if snapshot is not None: + yield snapshot diff --git a/tests/unit/v1/_test_helpers.py b/tests/unit/v1/_test_helpers.py index 2734d78751..febb77a418 100644 --- a/tests/unit/v1/_test_helpers.py +++ b/tests/unit/v1/_test_helpers.py @@ -109,6 +109,12 @@ def make_vector_query(*args, **kw): return VectorQuery(*args, **kw) +def make_async_vector_query(*args, **kw): + from google.cloud.firestore_v1.async_vector_query import AsyncVectorQuery + + return AsyncVectorQuery(*args, **kw) + + def build_test_timestamp( year: int = 2021, month: int = 1, diff --git a/tests/unit/v1/test_async_vector_query.py b/tests/unit/v1/test_async_vector_query.py new file mode 100644 index 0000000000..18053c5e55 --- /dev/null +++ b/tests/unit/v1/test_async_vector_query.py @@ -0,0 +1,205 @@ +# Copyright 2024 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import mock +import pytest +import types + +from google.cloud.firestore_v1.types.query import StructuredQuery +from google.cloud.firestore_v1.vector import Vector +from google.cloud.firestore_v1.base_vector_query import DistanceMeasure + +from tests.unit.v1.test__helpers import AsyncIter, AsyncMock +from tests.unit.v1._test_helpers import ( + make_async_query, + make_async_client, + make_query, +) +from tests.unit.v1.test_base_query import _make_query_response +from google.cloud.firestore_v1._helpers import encode_value, make_retry_timeout_kwargs + +_PROJECT = "PROJECT" +_TXN_ID = b"\x00\x00\x01-work-\xf2" + + +def _transaction(client): + transaction = client.transaction() + txn_id = _TXN_ID + transaction._id = txn_id + return transaction + + +def _expected_pb(parent, vector_field, vector, distance_type, limit): + query = make_query(parent) + expected_pb = query._to_protobuf() + expected_pb.find_nearest = StructuredQuery.FindNearest( + vector_field=StructuredQuery.FieldReference(field_path=vector_field), + query_vector=encode_value(vector.to_map_value()), + distance_measure=distance_type, + limit=limit, + ) + return expected_pb + + +@pytest.mark.parametrize( + "distance_measure, expected_distance", + [ + ( + DistanceMeasure.EUCLIDEAN, + StructuredQuery.FindNearest.DistanceMeasure.EUCLIDEAN, + ), + (DistanceMeasure.COSINE, StructuredQuery.FindNearest.DistanceMeasure.COSINE), + ( + DistanceMeasure.DOT_PRODUCT, + StructuredQuery.FindNearest.DistanceMeasure.DOT_PRODUCT, + ), + ], +) +@pytest.mark.asyncio +async def test_async_vector_query_with_filter(distance_measure, expected_distance): + # Create a minimal fake GAPIC. + firestore_api = AsyncMock(spec=["run_query"]) + client = make_async_client() + client._firestore_api_internal = firestore_api + + # Make a **real** collection reference as parent. 
+ parent = client.collection("dee") + query = make_async_query(parent) + parent_path, expected_prefix = parent._parent_info() + + data = {"snooze": 10, "embedding": Vector([1.0, 2.0, 3.0])} + response_pb1 = _make_query_response( + name="{}/test_doc".format(expected_prefix), data=data + ) + response_pb2 = _make_query_response( + name="{}/test_doc".format(expected_prefix), data=data + ) + + kwargs = make_retry_timeout_kwargs(retry=None, timeout=None) + + # Execute the vector query and check the response. + firestore_api.run_query.return_value = AsyncIter([response_pb1, response_pb2]) + + vector_async__query = query.where("snooze", "==", 10).find_nearest( + vector_field="embedding", + query_vector=Vector([1.0, 2.0, 3.0]), + distance_measure=distance_measure, + limit=5, + ) + + returned = await vector_async__query.get(transaction=_transaction(client), **kwargs) + assert isinstance(returned, list) + assert len(returned) == 2 + assert returned[0].to_dict() == data + + expected_pb = _expected_pb( + parent=parent, + vector_field="embedding", + vector=Vector([1.0, 2.0, 3.0]), + distance_type=expected_distance, + limit=5, + ) + expected_pb.where = StructuredQuery.Filter( + field_filter=StructuredQuery.FieldFilter( + field=StructuredQuery.FieldReference(field_path="snooze"), + op=StructuredQuery.FieldFilter.Operator.EQUAL, + value=encode_value(10), + ) + ) + + firestore_api.run_query.assert_called_once_with( + request={ + "parent": parent_path, + "structured_query": expected_pb, + "transaction": _TXN_ID, + }, + metadata=client._rpc_metadata, + **kwargs, + ) + + +@pytest.mark.parametrize( + "distance_measure, expected_distance", + [ + ( + DistanceMeasure.EUCLIDEAN, + StructuredQuery.FindNearest.DistanceMeasure.EUCLIDEAN, + ), + (DistanceMeasure.COSINE, StructuredQuery.FindNearest.DistanceMeasure.COSINE), + ( + DistanceMeasure.DOT_PRODUCT, + StructuredQuery.FindNearest.DistanceMeasure.DOT_PRODUCT, + ), + ], +) +@pytest.mark.asyncio +async def 
test_vector_query_collection_group(distance_measure, expected_distance): + # Create a minimal fake GAPIC. + firestore_api = AsyncMock(spec=["run_query"]) + client = make_async_client() + client._firestore_api_internal = firestore_api + + # Make a **real** collection group reference as parent. + collection_group_ref = client.collection_group("dee") + + data = {"snooze": 10, "embedding": Vector([1.0, 2.0, 3.0])} + response_pb = _make_query_response(name="xxx/test_doc", data=data) + + kwargs = make_retry_timeout_kwargs(retry=None, timeout=None) + + # Execute the vector query and check the response. + firestore_api.run_query.return_value = AsyncIter([response_pb]) + + vector_query = collection_group_ref.where("snooze", "==", 10).find_nearest( + vector_field="embedding", + query_vector=Vector([1.0, 2.0, 3.0]), + distance_measure=distance_measure, + limit=5, + ) + + returned = await vector_query.get(transaction=_transaction(client), **kwargs) + assert isinstance(returned, list) + assert len(returned) == 1 + assert returned[0].to_dict() == data + + parent = client.collection("dee") + parent_path, expected_prefix = parent._parent_info() + + expected_pb = _expected_pb( + parent=parent, + vector_field="embedding", + vector=Vector([1.0, 2.0, 3.0]), + distance_type=expected_distance, + limit=5, + ) + expected_pb.where = StructuredQuery.Filter( + field_filter=StructuredQuery.FieldFilter( + field=StructuredQuery.FieldReference(field_path="snooze"), + op=StructuredQuery.FieldFilter.Operator.EQUAL, + value=encode_value(10), + ) + ) + expected_pb.from_ = [ + StructuredQuery.CollectionSelector(collection_id="dee", all_descendants=True) + ] + + firestore_api.run_query.assert_called_once_with( + request={ + "parent": parent_path, + "structured_query": expected_pb, + "transaction": _TXN_ID, + }, + metadata=client._rpc_metadata, + **kwargs, + ) From 5ed48be990b5e721cb04bb3f0f4273c6417d77bf Mon Sep 17 00:00:00 2001 From: Sichen Liu Date: Wed, 3 Apr 2024 23:08:35 +0000 Subject: [PATCH 
2/7] lint --- google/cloud/firestore_v1/async_query.py | 1 + google/cloud/firestore_v1/async_vector_query.py | 11 ++++------- tests/unit/v1/test_async_vector_query.py | 2 -- 3 files changed, 5 insertions(+), 9 deletions(-) diff --git a/google/cloud/firestore_v1/async_query.py b/google/cloud/firestore_v1/async_query.py index 20a4c9fb3c..5980b269cb 100644 --- a/google/cloud/firestore_v1/async_query.py +++ b/google/cloud/firestore_v1/async_query.py @@ -36,6 +36,7 @@ from google.cloud.firestore_v1 import async_document from google.cloud.firestore_v1.async_aggregation import AsyncAggregationQuery from google.cloud.firestore_v1.base_document import DocumentSnapshot +from google.cloud.firestore_v1.base_vector_query import DistanceMeasure from google.cloud.firestore_v1.async_vector_query import AsyncVectorQuery from google.cloud.firestore_v1.vector import Vector from typing import AsyncGenerator, List, Optional, Type, TYPE_CHECKING diff --git a/google/cloud/firestore_v1/async_vector_query.py b/google/cloud/firestore_v1/async_vector_query.py index d9c12e4c48..4687ba27ec 100644 --- a/google/cloud/firestore_v1/async_vector_query.py +++ b/google/cloud/firestore_v1/async_vector_query.py @@ -20,20 +20,17 @@ """ from __future__ import annotations -from google.cloud.firestore_v1.base_vector_query import BaseVectorQuery from google.api_core import gapic_v1 from google.api_core import retry_async as retries +from google.cloud.firestore_v1 import async_document from google.cloud.firestore_v1.base_document import DocumentSnapshot -from typing import AsyncGenerator, List, Union, Optional, TypeVar - - -from google.cloud.firestore_v1.base_vector_query import ( - BaseVectorQuery, -) from google.cloud.firestore_v1.base_query import ( + BaseQuery, _query_response_to_snapshot, _collection_group_query_response_to_snapshot, ) +from google.cloud.firestore_v1.base_vector_query import BaseVectorQuery +from typing import AsyncGenerator, List, Union, Optional, TypeVar TAsyncVectorQuery = 
TypeVar("TAsyncVectorQuery", bound="AsyncVectorQuery") diff --git a/tests/unit/v1/test_async_vector_query.py b/tests/unit/v1/test_async_vector_query.py index 18053c5e55..b42a4cd45a 100644 --- a/tests/unit/v1/test_async_vector_query.py +++ b/tests/unit/v1/test_async_vector_query.py @@ -12,9 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import mock import pytest -import types from google.cloud.firestore_v1.types.query import StructuredQuery from google.cloud.firestore_v1.vector import Vector From 9c3faf01a842c73870b2fa7223653dfd7e3f42c6 Mon Sep 17 00:00:00 2001 From: Sichen Liu Date: Wed, 3 Apr 2024 23:20:53 +0000 Subject: [PATCH 3/7] coverage --- .../cloud/firestore_v1/async_vector_query.py | 2 +- tests/unit/v1/_test_helpers.py | 6 --- tests/unit/v1/test_async_vector_query.py | 38 +++++++++++++++++++ 3 files changed, 39 insertions(+), 7 deletions(-) diff --git a/google/cloud/firestore_v1/async_vector_query.py b/google/cloud/firestore_v1/async_vector_query.py index 4687ba27ec..0ea5fbcfbe 100644 --- a/google/cloud/firestore_v1/async_vector_query.py +++ b/google/cloud/firestore_v1/async_vector_query.py @@ -76,7 +76,7 @@ async def get( stream_result = self.stream( transaction=transaction, retry=retry, timeout=timeout ) - result = [aggregation async for aggregation in stream_result] + result = [snapshot async for snapshot in stream_result] return result # type: ignore async def stream( diff --git a/tests/unit/v1/_test_helpers.py b/tests/unit/v1/_test_helpers.py index febb77a418..2734d78751 100644 --- a/tests/unit/v1/_test_helpers.py +++ b/tests/unit/v1/_test_helpers.py @@ -109,12 +109,6 @@ def make_vector_query(*args, **kw): return VectorQuery(*args, **kw) -def make_async_vector_query(*args, **kw): - from google.cloud.firestore_v1.async_vector_query import AsyncVectorQuery - - return AsyncVectorQuery(*args, **kw) - - def build_test_timestamp( year: int = 2021, month: int = 1, diff --git 
a/tests/unit/v1/test_async_vector_query.py b/tests/unit/v1/test_async_vector_query.py index b42a4cd45a..eae018de30 100644 --- a/tests/unit/v1/test_async_vector_query.py +++ b/tests/unit/v1/test_async_vector_query.py @@ -201,3 +201,41 @@ async def test_vector_query_collection_group(distance_measure, expected_distance metadata=client._rpc_metadata, **kwargs, ) + + +@pytest.mark.asyncio +async def test_async_query_stream_multiple_empty_response_in_stream(): + # Create a minimal fake GAPIC with a dummy response. + firestore_api = AsyncMock(spec=["run_query"]) + empty_response1 = _make_query_response() + empty_response2 = _make_query_response() + run_query_response = AsyncIter([empty_response1, empty_response2]) + firestore_api.run_query.return_value = run_query_response + + # Attach the fake GAPIC to a real client. + client = make_async_client() + client._firestore_api_internal = firestore_api + + # Make a **real** collection reference as parent. + parent = client.collection("dah", "dah", "dum") + async_vector_query = parent.where("snooze", "==", 10).find_nearest( + vector_field="embedding", + query_vector=Vector([1.0, 2.0, 3.0]), + distance_measure=DistanceMeasure.EUCLIDEAN, + limit=5, + ) + + result = [snapshot async for snapshot in async_vector_query.stream()] + + assert list(result) == [] + + # Verify the mock call. 
+ parent_path, _ = parent._parent_info() + firestore_api.run_query.assert_called_once_with( + request={ + "parent": parent_path, + "structured_query": async_vector_query._to_protobuf(), + "transaction": None, + }, + metadata=client._rpc_metadata, + ) From 38bfbd96db06901f906eb03e3180b4865eead9c6 Mon Sep 17 00:00:00 2001 From: Sichen Liu Date: Wed, 3 Apr 2024 23:26:37 +0000 Subject: [PATCH 4/7] docs --- google/cloud/firestore_v1/async_vector_query.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/google/cloud/firestore_v1/async_vector_query.py b/google/cloud/firestore_v1/async_vector_query.py index 0ea5fbcfbe..27de5251ca 100644 --- a/google/cloud/firestore_v1/async_vector_query.py +++ b/google/cloud/firestore_v1/async_vector_query.py @@ -12,12 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Classes for representing Async aggregation queries for the Google Cloud Firestore API. - -A :class:`~google.cloud.firestore_v1.async_aggregation.AsyncAggregationQuery` can be created directly from -a :class:`~google.cloud.firestore_v1.async_collection.AsyncCollection` and that can be -a more common way to create an aggregation query than direct usage of the constructor. 
-""" from __future__ import annotations from google.api_core import gapic_v1 From b9e5237bbccbe59380d5c0d87b860912cca7e461 Mon Sep 17 00:00:00 2001 From: Sichen Liu Date: Thu, 4 Apr 2024 22:21:17 +0000 Subject: [PATCH 5/7] Resolve comments --- google/cloud/firestore_v1/base_query.py | 2 +- .../cloud/firestore_v1/base_vector_query.py | 10 ++++- tests/system/test_system_async.py | 42 +++++++++++++++++++ 3 files changed, 52 insertions(+), 2 deletions(-) diff --git a/google/cloud/firestore_v1/base_query.py b/google/cloud/firestore_v1/base_query.py index c8c2f3ceb2..9e75514a56 100644 --- a/google/cloud/firestore_v1/base_query.py +++ b/google/cloud/firestore_v1/base_query.py @@ -978,7 +978,7 @@ def _to_protobuf(self) -> StructuredQuery: def find_nearest( self, vector_field: str, - queryVector: Vector, + query_vector: Vector, limit: int, distance_measure: DistanceMeasure, ) -> BaseVectorQuery: diff --git a/google/cloud/firestore_v1/base_vector_query.py b/google/cloud/firestore_v1/base_vector_query.py index e41717d2b5..cb9c00b3af 100644 --- a/google/cloud/firestore_v1/base_vector_query.py +++ b/google/cloud/firestore_v1/base_vector_query.py @@ -25,7 +25,7 @@ from google.cloud.firestore_v1.base_document import DocumentSnapshot from google.cloud.firestore_v1.types import query from google.cloud.firestore_v1.vector import Vector -from google.cloud.firestore_v1 import _helpers +from google.cloud.firestore_v1 import document, _helpers class DistanceMeasure(Enum): @@ -117,3 +117,11 @@ def find_nearest( self._limit = limit self._distance_measure = distance_measure return self + + def stream( + self, + transaction=None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + ) -> Iterable[document.DocumentSnapshot]: + """Reads the documents in the collection that match this query.""" diff --git a/tests/system/test_system_async.py b/tests/system/test_system_async.py index 5b681e7b33..b4a4479756 100644 --- a/tests/system/test_system_async.py +++ 
b/tests/system/test_system_async.py @@ -35,6 +35,8 @@ from google.cloud._helpers import _datetime_to_pb_timestamp from google.cloud import firestore_v1 as firestore from google.cloud.firestore_v1.base_query import FieldFilter, And, Or +from google.cloud.firestore_v1.base_vector_query import DistanceMeasure +from google.cloud.firestore_v1.vector import Vector from tests.system.test__helpers import ( FIRESTORE_CREDS, @@ -339,6 +341,46 @@ async def test_document_update_w_int_field(client, cleanup, database): assert snapshot1.to_dict() == expected +@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +async def test_vector_search_collection(client, database): + collection_id = "vector_search" + collection = client.collection(collection_id) + vector_query = collection.where("color", "==", "red").find_nearest( + vector_field="embedding", + query_vector=Vector([1.0, 2.0, 3.0]), + limit=1, + distance_measure=DistanceMeasure.EUCLIDEAN, + ) + returned = await vector_query.get() + assert isinstance(returned, list) + assert len(returned) == 1 + assert returned[0].to_dict() == { + "embedding": Vector([1.0, 2.0, 3.0]), + "color": "red", + } + + +@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +async def test_vector_search_collection_group(client, database): + collection_id = "vector_search" + collection_group = client.collection_group(collection_id) + + vector_query = collection_group.where("color", "==", "red").find_nearest( + vector_field="embedding", + query_vector=Vector([1.0, 2.0, 3.0]), + distance_measure=DistanceMeasure.EUCLIDEAN, + limit=1, + ) + returned = await vector_query.get() + assert isinstance(returned, list) + assert len(returned) == 1 + assert returned[0].to_dict() == { + "embedding": Vector([1.0, 2.0, 3.0]), + "color": "red", + } + 
@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Internal Issue b/137867104") @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) async def test_update_document(client, cleanup, database): From 255d9ed7274c46a950775c53408c2de0688a4045 Mon Sep 17 00:00:00 2001 From: Sichen Liu Date: Thu, 4 Apr 2024 22:23:11 +0000 Subject: [PATCH 6/7] lint --- tests/system/test_system_async.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/system/test_system_async.py b/tests/system/test_system_async.py index b4a4479756..4418323534 100644 --- a/tests/system/test_system_async.py +++ b/tests/system/test_system_async.py @@ -381,6 +381,7 @@ async def test_vector_search_collection_group(client, database): "color": "red", } + @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Internal Issue b/137867104") @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) async def test_update_document(client, cleanup, database): From 3abb96beec538b8207186bf4a7ea3125f4c6f6c9 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 10 Jul 2024 16:42:03 -0700 Subject: [PATCH 7/7] move imports into TYPE_CHECKING --- google/cloud/firestore_v1/async_query.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/google/cloud/firestore_v1/async_query.py b/google/cloud/firestore_v1/async_query.py index 7b98a43760..7a17eee47a 100644 --- a/google/cloud/firestore_v1/async_query.py +++ b/google/cloud/firestore_v1/async_query.py @@ -35,10 +35,7 @@ from google.cloud.firestore_v1 import async_document from google.cloud.firestore_v1.async_aggregation import AsyncAggregationQuery -from google.cloud.firestore_v1.base_document import DocumentSnapshot -from google.cloud.firestore_v1.base_vector_query import DistanceMeasure from google.cloud.firestore_v1.async_vector_query import AsyncVectorQuery -from google.cloud.firestore_v1.vector import Vector from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator from 
google.cloud.firestore_v1 import transaction from typing import AsyncGenerator, List, Optional, Type, TYPE_CHECKING @@ -46,7 +43,9 @@ if TYPE_CHECKING: # pragma: NO COVER # Types needed only for Type Hints from google.cloud.firestore_v1.base_document import DocumentSnapshot + from google.cloud.firestore_v1.base_vector_query import DistanceMeasure from google.cloud.firestore_v1.field_path import FieldPath + from google.cloud.firestore_v1.vector import Vector class AsyncQuery(BaseQuery):