From dad74e93cbe223cc01adf654e5de260ccd429f0d Mon Sep 17 00:00:00 2001 From: Sichen Liu Date: Mon, 8 Apr 2024 20:52:41 +0000 Subject: [PATCH 1/7] Accept Sequence[float] as query_vector in FindNearest --- google/cloud/firestore_v1/base_collection.py | 3 +- .../cloud/firestore_v1/base_vector_query.py | 6 ++-- tests/unit/v1/test_vector_query.py | 36 +++++++++++++++++++ 3 files changed, 42 insertions(+), 3 deletions(-) diff --git a/google/cloud/firestore_v1/base_collection.py b/google/cloud/firestore_v1/base_collection.py index 98f690e6d9..3f187cb6e3 100644 --- a/google/cloud/firestore_v1/base_collection.py +++ b/google/cloud/firestore_v1/base_collection.py @@ -38,6 +38,7 @@ Iterator, Iterable, NoReturn, + Sequence, Tuple, Union, TYPE_CHECKING, @@ -549,7 +550,7 @@ def avg(self, field_ref: str | FieldPath, alias=None): def find_nearest( self, vector_field: str, - query_vector: Vector, + query_vector: Union[Vector, Sequence[float]], limit: int, distance_measure: DistanceMeasure, ) -> VectorQuery: diff --git a/google/cloud/firestore_v1/base_vector_query.py b/google/cloud/firestore_v1/base_vector_query.py index e41717d2b5..afb49b528f 100644 --- a/google/cloud/firestore_v1/base_vector_query.py +++ b/google/cloud/firestore_v1/base_vector_query.py @@ -19,7 +19,7 @@ from abc import ABC from enum import Enum -from typing import Iterable, Optional, Tuple, Union +from typing import Iterable, Optional, Sequence, Tuple, Union from google.api_core import gapic_v1 from google.api_core import retry as retries from google.cloud.firestore_v1.base_document import DocumentSnapshot @@ -107,11 +107,13 @@ def get( def find_nearest( self, vector_field: str, - query_vector: Vector, + query_vector: Union[Vector, Sequence[float]], limit: int, distance_measure: DistanceMeasure, ): """Finds the closest vector embeddings to the given query vector.""" + if not isinstance (query_vector, Vector): + self._query_vector = Vector(query_vector) self._vector_field = vector_field self._query_vector = query_vector self._limit = limit diff --git a/tests/unit/v1/test_vector_query.py b/tests/unit/v1/test_vector_query.py index 92dca45c4d..4b44cabcb0 100644 --- a/tests/unit/v1/test_vector_query.py +++ b/tests/unit/v1/test_vector_query.py @@ -324,6 +324,42 @@ def test_vector_query_collection_group(distance_measure, expected_distance): **kwargs, ) +def test_vector_query_list_as_query_vector(): + # Create a minimal fake GAPIC with a dummy response. + firestore_api = mock.Mock(spec=["run_query"]) + response_pb = _make_query_response(name="xxx/test_doc", data=data) + run_query_response = iter([response_pb]) + firestore_api.run_query.return_value = run_query_response + + # Attach the fake GAPIC to a real client. + client = make_client() + client._firestore_api_internal = firestore_api + + # Make a **real** collection reference as parent. + parent = client.collection("dah", "dah", "dum") + vector_query = parent.where("snooze", "==", 10).find_nearest( + vector_field="embedding", + query_vector=[1.0, 2.0, 3.0], + distance_measure=DistanceMeasure.EUCLIDEAN, + limit=5, + ) + + get_response = vector_query.stream() + assert isinstance(get_response, types.GeneratorType) + assert list(get_response) == [] + + # Verify the mock call. + parent_path, _ = parent._parent_info() + firestore_api.run_query.assert_called_once_with( + request={ + "parent": parent_path, + "structured_query": vector_query._to_protobuf(), + "transaction": None, + }, + metadata=client._rpc_metadata, + ) + + def test_query_stream_multiple_empty_response_in_stream(): # Create a minimal fake GAPIC with a dummy response. From d391e0714f2ce278100f62e8ccac0b54d05aa5ab Mon Sep 17 00:00:00 2001 From: Sichen Liu Date: Mon, 8 Apr 2024 21:00:30 +0000 Subject: [PATCH 2/7] lint --- google/cloud/firestore_v1/base_vector_query.py | 2 +- tests/unit/v1/test_vector_query.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/google/cloud/firestore_v1/base_vector_query.py b/google/cloud/firestore_v1/base_vector_query.py index afb49b528f..6f2c8351ec 100644 --- a/google/cloud/firestore_v1/base_vector_query.py +++ b/google/cloud/firestore_v1/base_vector_query.py @@ -112,7 +112,7 @@ def find_nearest( distance_measure: DistanceMeasure, ): """Finds the closest vector embeddings to the given query vector.""" - if not isinstance (query_vector, Vector): + if not isinstance(query_vector, Vector): self._query_vector = Vector(query_vector) self._vector_field = vector_field self._query_vector = query_vector diff --git a/tests/unit/v1/test_vector_query.py b/tests/unit/v1/test_vector_query.py index 4b44cabcb0..cbb07e18a4 100644 --- a/tests/unit/v1/test_vector_query.py +++ b/tests/unit/v1/test_vector_query.py @@ -324,6 +324,7 @@ def test_vector_query_collection_group(distance_measure, expected_distance): **kwargs, ) + def test_vector_query_list_as_query_vector(): # Create a minimal fake GAPIC with a dummy response. firestore_api = mock.Mock(spec=["run_query"]) @@ -360,7 +361,6 @@ def test_vector_query_list_as_query_vector(): ) - def test_query_stream_multiple_empty_response_in_stream(): # Create a minimal fake GAPIC with a dummy response. firestore_api = mock.Mock(spec=["run_query"]) From aa6347ef9b8d44d8dcff8b29208ff5d3732a4c68 Mon Sep 17 00:00:00 2001 From: Sichen Liu Date: Mon, 8 Apr 2024 21:44:20 +0000 Subject: [PATCH 3/7] lint & tests --- google/cloud/firestore_v1/base_query.py | 3 +- .../cloud/firestore_v1/base_vector_query.py | 3 +- google/cloud/firestore_v1/query.py | 6 +- tests/unit/v1/test_vector_query.py | 56 ++++++++++++++----- 4 files changed, 48 insertions(+), 20 deletions(-) diff --git a/google/cloud/firestore_v1/base_query.py b/google/cloud/firestore_v1/base_query.py index c8c2f3ceb2..15525d9901 100644 --- a/google/cloud/firestore_v1/base_query.py +++ b/google/cloud/firestore_v1/base_query.py @@ -46,6 +46,7 @@ Iterable, NoReturn, Optional, + Sequence, Tuple, Type, TypeVar, @@ -978,7 +979,7 @@ def _to_protobuf(self) -> StructuredQuery: def find_nearest( self, vector_field: str, - queryVector: Vector, + query_vector: Union[Vector, Sequence[float]], limit: int, distance_measure: DistanceMeasure, ) -> BaseVectorQuery: diff --git a/google/cloud/firestore_v1/base_vector_query.py b/google/cloud/firestore_v1/base_vector_query.py index 6f2c8351ec..7e5283b707 100644 --- a/google/cloud/firestore_v1/base_vector_query.py +++ b/google/cloud/firestore_v1/base_vector_query.py @@ -114,8 +114,9 @@ def find_nearest( """Finds the closest vector embeddings to the given query vector.""" if not isinstance(query_vector, Vector): self._query_vector = Vector(query_vector) + else: + self._query_vector = query_vector self._vector_field = vector_field - self._query_vector = query_vector self._limit = limit self._distance_measure = distance_measure return self diff --git a/google/cloud/firestore_v1/query.py b/google/cloud/firestore_v1/query.py index c46a06918a..f0e84b18ea 100644 --- a/google/cloud/firestore_v1/query.py +++ b/google/cloud/firestore_v1/query.py @@ -41,7 +41,7 @@ from google.cloud.firestore_v1 import document from google.cloud.firestore_v1.watch import Watch -from typing import Any, Callable, Generator, List, Optional, Type, TYPE_CHECKING +from typing import Any, Callable, Generator, List, Optional, Sequence, Type, TYPE_CHECKING if TYPE_CHECKING: # pragma: NO COVER from google.cloud.firestore_v1.field_path import FieldPath @@ -245,7 +245,7 @@ def _retry_query_after_exception(self, exc, retry, transaction): def find_nearest( self, vector_field: str, - query_vector: Vector, + query_vector: Union[Vector, Sequence[float]], limit: int, distance_measure: DistanceMeasure, ) -> Type["firestore_v1.vector_query.VectorQuery"]: @@ -255,7 +255,7 @@ def find_nearest( Args: vector_field(str): An indexed vector field to search upon. Only documents which contain vectors whose dimensionality match the query_vector can be returned. - query_vector(Vector): The query vector that we are searching on. Must be a vector of no more + query_vector(Vector | Sequence[float]): The query vector that we are searching on. Must be a vector of no more than 2048 dimensions. limit (int): The number of nearest neighbors to return. Must be a positive integer of no more than 1000. distance_measure(:class:`DistanceMeasure`): The Distance Measure to use. diff --git a/tests/unit/v1/test_vector_query.py b/tests/unit/v1/test_vector_query.py index cbb07e18a4..2a5c8c5a5a 100644 --- a/tests/unit/v1/test_vector_query.py +++ b/tests/unit/v1/test_vector_query.py @@ -326,38 +326,64 @@ def test_vector_query_collection_group(distance_measure, expected_distance): def test_vector_query_list_as_query_vector(): - # Create a minimal fake GAPIC with a dummy response. + # Create a minimal fake GAPIC. firestore_api = mock.Mock(spec=["run_query"]) - response_pb = _make_query_response(name="xxx/test_doc", data=data) - run_query_response = iter([response_pb]) - firestore_api.run_query.return_value = run_query_response - - # Attach the fake GAPIC to a real client. client = make_client() client._firestore_api_internal = firestore_api # Make a **real** collection reference as parent. - parent = client.collection("dah", "dah", "dum") - vector_query = parent.where("snooze", "==", 10).find_nearest( + parent = client.collection("dee") + query = make_query(parent) + parent_path, expected_prefix = parent._parent_info() + + data = {"snooze": 10, "embedding": Vector([1.0, 2.0, 3.0])} + response_pb1 = _make_query_response( + name="{}/test_doc".format(expected_prefix), data=data + ) + response_pb2 = _make_query_response( + name="{}/test_doc".format(expected_prefix), data=data + ) + + kwargs = make_retry_timeout_kwargs(retry=None, timeout=None) + + # Execute the vector query and check the response. + firestore_api.run_query.return_value = iter([response_pb1, response_pb2]) + + vector_query = query.where("snooze", "==", 10).find_nearest( vector_field="embedding", query_vector=[1.0, 2.0, 3.0], distance_measure=DistanceMeasure.EUCLIDEAN, limit=5, ) - get_response = vector_query.stream() - assert isinstance(get_response, types.GeneratorType) - assert list(get_response) == [] + returned = vector_query.get(transaction=_transaction(client), **kwargs) + assert isinstance(returned, list) + assert len(returned) == 2 + assert returned[0].to_dict() == data + + expected_pb = _expected_pb( + parent=parent, + vector_field="embedding", + vector=Vector([1.0, 2.0, 3.0]), + distance_type=StructuredQuery.FindNearest.DistanceMeasure.EUCLIDEAN, + limit=5, + ) + expected_pb.where = StructuredQuery.Filter( + field_filter=StructuredQuery.FieldFilter( + field=StructuredQuery.FieldReference(field_path="snooze"), + op=StructuredQuery.FieldFilter.Operator.EQUAL, + value=encode_value(10), + ) + ) - # Verify the mock call. - parent_path, _ = parent._parent_info() firestore_api.run_query.assert_called_once_with( request={ "parent": parent_path, - "structured_query": vector_query._to_protobuf(), - "transaction": None, + "structured_query": expected_pb, + "transaction": _TXN_ID, }, metadata=client._rpc_metadata, + **kwargs, ) From c8a74c0d189a582e14193eaa5a0a8bb83329142d Mon Sep 17 00:00:00 2001 From: Sichen Liu Date: Mon, 8 Apr 2024 21:45:59 +0000 Subject: [PATCH 4/7] lint --- google/cloud/firestore_v1/base_collection.py | 2 +- google/cloud/firestore_v1/query.py | 11 ++++++++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/google/cloud/firestore_v1/base_collection.py b/google/cloud/firestore_v1/base_collection.py index 3f187cb6e3..72f6211dd3 100644 --- a/google/cloud/firestore_v1/base_collection.py +++ b/google/cloud/firestore_v1/base_collection.py @@ -560,7 +560,7 @@ def find_nearest( Args: vector_field(str): An indexed vector field to search upon. Only documents which contain vectors whose dimensionality match the query_vector can be returned. - query_vector(Vector): The query vector that we are searching on. Must be a vector of no more + query_vector(Union[Vector, Sequence[float]]): The query vector that we are searching on. Must be a vector of no more than 2048 dimensions. limit (int): The number of nearest neighbors to return. Must be a positive integer of no more than 1000. distance_measure(:class:`DistanceMeasure`): The Distance Measure to use. diff --git a/google/cloud/firestore_v1/query.py b/google/cloud/firestore_v1/query.py index f0e84b18ea..97d13a0d0c 100644 --- a/google/cloud/firestore_v1/query.py +++ b/google/cloud/firestore_v1/query.py @@ -41,7 +41,16 @@ from google.cloud.firestore_v1 import document from google.cloud.firestore_v1.watch import Watch -from typing import Any, Callable, Generator, List, Optional, Sequence, Type, TYPE_CHECKING +from typing import ( + Any, + Callable, + Generator, + List, + Optional, + Sequence, + Type, + TYPE_CHECKING, +) if TYPE_CHECKING: # pragma: NO COVER from google.cloud.firestore_v1.field_path import FieldPath From b600941152500ced2d0117310970538287cfea71 Mon Sep 17 00:00:00 2001 From: Sichen Liu Date: Mon, 8 Apr 2024 21:47:40 +0000 Subject: [PATCH 5/7] lint --- google/cloud/firestore_v1/query.py | 1 + 1 file changed, 1 insertion(+) diff --git a/google/cloud/firestore_v1/query.py b/google/cloud/firestore_v1/query.py index 97d13a0d0c..8e71d976f6 100644 --- a/google/cloud/firestore_v1/query.py +++ b/google/cloud/firestore_v1/query.py @@ -50,6 +50,7 @@ Sequence, Type, TYPE_CHECKING, + Union, ) if TYPE_CHECKING: # pragma: NO COVER From 5cc2a2e9786bfd5b49ef29da0d694575c151a8a2 Mon Sep 17 00:00:00 2001 From: Sichen Liu Date: Tue, 11 Jun 2024 11:56:00 +0000 Subject: [PATCH 6/7] Fix typo in test_vector --- tests/unit/v1/test_vector.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/v1/test_vector.py b/tests/unit/v1/test_vector.py index 6ca1ce4134..fb8ca62af0 100644 --- a/tests/unit/v1/test_vector.py +++ b/tests/unit/v1/test_vector.py @@ -24,7 +24,7 @@ from unittest import mock -def _make_commit_repsonse(): +def _make_commit_response(): response = mock.create_autospec(firestore.CommitResponse) response.write_results = [mock.sentinel.write_result] response.commit_time = mock.sentinel.commit_time @@ -34,7 +34,7 @@ def _make_commit_repsonse(): def _make_firestore_api(): firestore_api = mock.Mock() firestore_api.commit.mock_add_spec(spec=["commit"]) - firestore_api.commit.return_value = _make_commit_repsonse() + firestore_api.commit.return_value = _make_commit_response() return firestore_api From d0015ab2dafd1e5ab5f95a9a11271474b775a370 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 21 May 2025 14:44:18 -0700 Subject: [PATCH 7/7] fixed CI issues after merge --- google/cloud/firestore_v1/base_collection.py | 2 +- google/cloud/firestore_v1/base_vector_query.py | 4 ++-- google/cloud/firestore_v1/query.py | 12 +++++++++++- 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/google/cloud/firestore_v1/base_collection.py b/google/cloud/firestore_v1/base_collection.py index dc5b848fe2..0e5ae6ed1e 100644 --- a/google/cloud/firestore_v1/base_collection.py +++ b/google/cloud/firestore_v1/base_collection.py @@ -25,10 +25,10 @@ Generator, Generic, Iterable, - NoReturn, Sequence, Tuple, Union, + Optional, ) from google.api_core import retry as retries diff --git a/google/cloud/firestore_v1/base_vector_query.py b/google/cloud/firestore_v1/base_vector_query.py index bd6a3ca2c6..88e40635f9 100644 --- a/google/cloud/firestore_v1/base_vector_query.py +++ b/google/cloud/firestore_v1/base_vector_query.py @@ -19,13 +19,14 @@ import abc from abc import ABC from enum import Enum -from typing import TYPE_CHECKING, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Any, Coroutine, Optional, Sequence, Tuple, Union from google.api_core import gapic_v1 from google.api_core import retry as retries from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.types import query +from google.cloud.firestore_v1.vector import Vector if TYPE_CHECKING: # pragma: NO COVER from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator @@ -33,7 +34,6 @@ from google.cloud.firestore_v1.query_profile import ExplainOptions from google.cloud.firestore_v1.query_results import QueryResultsList from google.cloud.firestore_v1.stream_generator import StreamGenerator - from google.cloud.firestore_v1.vector import Vector class DistanceMeasure(Enum): diff --git a/google/cloud/firestore_v1/query.py b/google/cloud/firestore_v1/query.py index c510246633..a8b821bdc4 100644 --- a/google/cloud/firestore_v1/query.py +++ b/google/cloud/firestore_v1/query.py @@ -20,7 +20,17 @@ """ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, Generator, List, Optional, Sequence, Type +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Generator, + List, + Optional, + Sequence, + Type, + Union, +) from google.api_core import exceptions, gapic_v1 from google.api_core import retry as retries