Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions google/cloud/firestore_v1/async_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@
from google.cloud.firestore_v1 import async_document
from google.cloud.firestore_v1.async_aggregation import AsyncAggregationQuery
from google.cloud.firestore_v1.base_document import DocumentSnapshot
from google.cloud.firestore_v1.base_vector_query import DistanceMeasure
from google.cloud.firestore_v1.async_vector_query import AsyncVectorQuery
from google.cloud.firestore_v1.vector import Vector
from typing import AsyncGenerator, List, Optional, Type, TYPE_CHECKING

if TYPE_CHECKING: # pragma: NO COVER
Expand Down Expand Up @@ -217,6 +220,34 @@ async def get(

return result

def find_nearest(
self,
vector_field: str,
query_vector: Vector,
limit: int,
distance_measure: DistanceMeasure,
) -> AsyncVectorQuery:
"""
Finds the closest vector embeddings to the given query vector.

Args:
vector_field(str): An indexed vector field to search upon. Only documents which contain
vectors whose dimensionality match the query_vector can be returned.
query_vector(Vector): The query vector that we are searching on. Must be a vector of no more
than 2048 dimensions.
limit (int): The number of nearest neighbors to return. Must be a positive integer of no more than 1000.
distance_measure(:class:`DistanceMeasure`): The Distance Measure to use.

Returns:
:class`~firestore_v1.vector_query.VectorQuery`: the vector query.
"""
return AsyncVectorQuery(self).find_nearest(
vector_field=vector_field,
query_vector=query_vector,
limit=limit,
distance_measure=distance_measure,
)

def count(
self, alias: str | None = None
) -> Type["firestore_v1.async_aggregation.AsyncAggregationQuery"]:
Expand Down
127 changes: 127 additions & 0 deletions google/cloud/firestore_v1/async_vector_query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# Copyright 2024 Google LLC All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from google.api_core import gapic_v1
from google.api_core import retry_async as retries
from google.cloud.firestore_v1 import async_document
from google.cloud.firestore_v1.base_document import DocumentSnapshot
from google.cloud.firestore_v1.base_query import (
BaseQuery,
_query_response_to_snapshot,
_collection_group_query_response_to_snapshot,
)
from google.cloud.firestore_v1.base_vector_query import BaseVectorQuery
from typing import AsyncGenerator, List, Union, Optional, TypeVar

TAsyncVectorQuery = TypeVar("TAsyncVectorQuery", bound="AsyncVectorQuery")


class AsyncVectorQuery(BaseVectorQuery):
"""Represents an async vector query to the Firestore API."""

def __init__(
self,
nested_query: Union[BaseQuery, TAsyncVectorQuery],
) -> None:
"""Presents the vector query.
Args:
nested_query (BaseQuery | VectorQuery): the base query to apply as the prefilter.
"""
super(AsyncVectorQuery, self).__init__(nested_query)

async def get(
self,
transaction=None,
retry: retries.AsyncRetry = gapic_v1.method.DEFAULT,
timeout: Optional[float] = None,
) -> List[DocumentSnapshot]:
"""Runs the vector query.
This sends a ``RunQuery`` RPC and returns a list of document messages.
Args:
transaction
(Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]):
An existing transaction that this query will run in.
If a ``transaction`` is used and it already has write operations
added, this method cannot be used (i.e. read-after-write is not
allowed).
retry (google.api_core.retry.Retry): Designation of what errors, if any,
should be retried. Defaults to a system-specified policy.
timeout (float): The timeout for this request. Defaults to a
system-specified value.
Returns:
list: The vector query results.
"""
stream_result = self.stream(
transaction=transaction, retry=retry, timeout=timeout
)
result = [snapshot async for snapshot in stream_result]
return result # type: ignore

async def stream(
self,
transaction=None,
retry: retries.AsyncRetry = gapic_v1.method.DEFAULT,
timeout: Optional[float] = None,
) -> AsyncGenerator[async_document.DocumentSnapshot, None]:
"""Reads the documents in the collection that match this query.
This sends a ``RunQuery`` RPC and then returns an iterator which
consumes each document returned in the stream of ``RunQueryResponse``
messages.
If a ``transaction`` is used and it already has write operations
added, this method cannot be used (i.e. read-after-write is not
allowed).
Args:
transaction
(Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]):
An existing transaction that this query will run in.
retry (google.api_core.retry.Retry): Designation of what errors, if any,
should be retried. Defaults to a system-specified policy.
timeout (float): The timeout for this request. Defaults to a
system-specified value.
Yields:
:class:`~google.cloud.firestore_v1.document.DocumentSnapshot`:
The next document that fulfills the query.
"""
request, expected_prefix, kwargs = self._prep_stream(
transaction,
retry,
timeout,
)

response_iterator = await self._client._firestore_api.run_query(
request=request,
metadata=self._client._rpc_metadata,
**kwargs,
)

async for response in response_iterator:
if self._nested_query._all_descendants:
snapshot = _collection_group_query_response_to_snapshot(
response, self._nested_query._parent
)
else:
snapshot = _query_response_to_snapshot(
response, self._nested_query._parent, expected_prefix
)
if snapshot is not None:
yield snapshot
2 changes: 1 addition & 1 deletion google/cloud/firestore_v1/base_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -978,7 +978,7 @@ def _to_protobuf(self) -> StructuredQuery:
def find_nearest(
self,
vector_field: str,
queryVector: Vector,
query_vector: Vector,
limit: int,
distance_measure: DistanceMeasure,
) -> BaseVectorQuery:
Expand Down
10 changes: 9 additions & 1 deletion google/cloud/firestore_v1/base_vector_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from google.cloud.firestore_v1.base_document import DocumentSnapshot
from google.cloud.firestore_v1.types import query
from google.cloud.firestore_v1.vector import Vector
from google.cloud.firestore_v1 import _helpers
from google.cloud.firestore_v1 import document, _helpers


class DistanceMeasure(Enum):
Expand Down Expand Up @@ -117,3 +117,11 @@ def find_nearest(
self._limit = limit
self._distance_measure = distance_measure
return self

def stream(
self,
transaction=None,
retry: retries.Retry = gapic_v1.method.DEFAULT,
timeout: float = None,
) -> Iterable[document.DocumentSnapshot]:
"""Reads the documents in the collection that match this query."""
43 changes: 43 additions & 0 deletions tests/system/test_system_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
from google.cloud._helpers import _datetime_to_pb_timestamp
from google.cloud import firestore_v1 as firestore
from google.cloud.firestore_v1.base_query import FieldFilter, And, Or
from google.cloud.firestore_v1.base_vector_query import DistanceMeasure
from google.cloud.firestore_v1.vector import Vector

from tests.system.test__helpers import (
FIRESTORE_CREDS,
Expand Down Expand Up @@ -339,6 +341,47 @@ async def test_document_update_w_int_field(client, cleanup, database):
assert snapshot1.to_dict() == expected


@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data")
@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
async def test_vector_search_collection(client, database):
collection_id = "vector_search"
collection = client.collection(collection_id)
vector_query = collection.where("color", "==", "red").find_nearest(
vector_field="embedding",
query_vector=Vector([1.0, 2.0, 3.0]),
limit=1,
distance_measure=DistanceMeasure.EUCLIDEAN,
)
returned = await vector_query.get()
assert isinstance(returned, list)
assert len(returned) == 1
assert returned[0].to_dict() == {
"embedding": Vector([1.0, 2.0, 3.0]),
"color": "red",
}


@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data")
@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
async def test_vector_search_collection_group(client, database):
collection_id = "vector_search"
collection_group = client.collection_group(collection_id)

vector_query = collection_group.where("color", "==", "red").find_nearest(
vector_field="embedding",
query_vector=Vector([1.0, 2.0, 3.0]),
distance_measure=DistanceMeasure.EUCLIDEAN,
limit=1,
)
returned = await vector_query.get()
assert isinstance(returned, list)
assert len(returned) == 1
assert returned[0].to_dict() == {
"embedding": Vector([1.0, 2.0, 3.0]),
"color": "red",
}


@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Internal Issue b/137867104")
@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
async def test_update_document(client, cleanup, database):
Expand Down
Loading