Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions google/cloud/firestore_v1/async_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

from google.cloud.firestore_v1 import (
async_aggregation,
async_document,
async_query,
async_vector_query,
transaction,
Expand All @@ -31,11 +30,10 @@
BaseCollectionReference,
_item_to_document_ref,
)
from google.cloud.firestore_v1.document import DocumentReference

if TYPE_CHECKING: # pragma: NO COVER
import datetime

from google.cloud.firestore_v1.async_document import AsyncDocumentReference
from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator
from google.cloud.firestore_v1.base_document import DocumentSnapshot
from google.cloud.firestore_v1.query_profile import ExplainOptions
Expand Down Expand Up @@ -142,9 +140,7 @@ async def add(
write_result = await document_ref.create(document_data, **kwargs)
return write_result.update_time, document_ref

def document(
self, document_id: str | None = None
) -> async_document.AsyncDocumentReference:
def document(self, document_id: str | None = None) -> AsyncDocumentReference:
"""Create a sub-document underneath the current collection.

Args:
Expand All @@ -166,7 +162,7 @@ async def list_documents(
timeout: float | None = None,
*,
read_time: datetime.datetime | None = None,
) -> AsyncGenerator[DocumentReference, None]:
) -> AsyncGenerator[AsyncDocumentReference, None]:
"""List all subdocuments of the current collection.

Args:
Expand Down
15 changes: 12 additions & 3 deletions google/cloud/firestore_v1/async_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,16 @@
"""
from __future__ import annotations

from typing import TYPE_CHECKING, Any, AsyncGenerator, List, Optional, Type
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
List,
Optional,
Type,
Union,
Sequence,
)

from google.api_core import gapic_v1
from google.api_core import retry_async as retries
Expand Down Expand Up @@ -256,7 +265,7 @@ async def get(
def find_nearest(
self,
vector_field: str,
query_vector: Vector,
query_vector: Union[Vector, Sequence[float]],
limit: int,
distance_measure: DistanceMeasure,
*,
Expand All @@ -269,7 +278,7 @@ def find_nearest(
Args:
vector_field (str): An indexed vector field to search upon. Only documents which contain
vectors whose dimensionality match the query_vector can be returned.
query_vector (Vector): The query vector that we are searching on. Must be a vector of no more
query_vector (Vector | Sequence[float]): The query vector that we are searching on. Must be a vector of no more
than 2048 dimensions.
limit (int): The number of nearest neighbors to return. Must be a positive integer of no more than 1000.
distance_measure (:class:`DistanceMeasure`): The Distance Measure to use.
Expand Down
27 changes: 16 additions & 11 deletions google/cloud/firestore_v1/base_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,40 +83,45 @@ def __init__(self, alias: str | None = None):
def _to_protobuf(self):
"""Convert this instance to the protobuf representation"""
aggregation_pb = StructuredAggregationQuery.Aggregation()
aggregation_pb.alias = self.alias
if self.alias:
aggregation_pb.alias = self.alias
aggregation_pb.count = StructuredAggregationQuery.Aggregation.Count()
return aggregation_pb


class SumAggregation(BaseAggregation):
def __init__(self, field_ref: str | FieldPath, alias: str | None = None):
if isinstance(field_ref, FieldPath):
# convert field path to string
field_ref = field_ref.to_api_repr()
self.field_ref = field_ref
# convert field path to string if needed
field_str = (
field_ref.to_api_repr() if isinstance(field_ref, FieldPath) else field_ref
)
self.field_ref: str = field_str
super(SumAggregation, self).__init__(alias=alias)

def _to_protobuf(self):
"""Convert this instance to the protobuf representation"""
aggregation_pb = StructuredAggregationQuery.Aggregation()
aggregation_pb.alias = self.alias
if self.alias:
aggregation_pb.alias = self.alias
aggregation_pb.sum = StructuredAggregationQuery.Aggregation.Sum()
aggregation_pb.sum.field.field_path = self.field_ref
return aggregation_pb


class AvgAggregation(BaseAggregation):
def __init__(self, field_ref: str | FieldPath, alias: str | None = None):
if isinstance(field_ref, FieldPath):
# convert field path to string
field_ref = field_ref.to_api_repr()
self.field_ref = field_ref
# convert field path to string if needed
field_str = (
field_ref.to_api_repr() if isinstance(field_ref, FieldPath) else field_ref
)
self.field_ref: str = field_str
super(AvgAggregation, self).__init__(alias=alias)

def _to_protobuf(self):
"""Convert this instance to the protobuf representation"""
aggregation_pb = StructuredAggregationQuery.Aggregation()
aggregation_pb.alias = self.alias
if self.alias:
aggregation_pb.alias = self.alias
aggregation_pb.avg = StructuredAggregationQuery.Aggregation.Avg()
aggregation_pb.avg.field.field_path = self.field_ref
return aggregation_pb
Expand Down
2 changes: 1 addition & 1 deletion google/cloud/firestore_v1/base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ def _prep_collections(
read_time: datetime.datetime | None = None,
) -> Tuple[dict, dict]:
"""Shared setup for async/sync :meth:`collections`."""
request = {
request: dict[str, Any] = {
"parent": "{}/documents".format(self._database_string),
}
if read_time is not None:
Expand Down
16 changes: 11 additions & 5 deletions google/cloud/firestore_v1/base_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
BaseVectorQuery,
DistanceMeasure,
)
from google.cloud.firestore_v1.async_document import AsyncDocumentReference
from google.cloud.firestore_v1.document import DocumentReference
from google.cloud.firestore_v1.field_path import FieldPath
from google.cloud.firestore_v1.query_profile import ExplainOptions
Expand Down Expand Up @@ -132,7 +133,7 @@ def _aggregation_query(self) -> BaseAggregationQuery:
def _vector_query(self) -> BaseVectorQuery:
raise NotImplementedError

def document(self, document_id: Optional[str] = None) -> DocumentReference:
def document(self, document_id: Optional[str] = None):
"""Create a sub-document underneath the current collection.

Args:
Expand All @@ -142,7 +143,7 @@ def document(self, document_id: Optional[str] = None) -> DocumentReference:
uppercase and lowercase and letters.

Returns:
:class:`~google.cloud.firestore_v1.document.DocumentReference`:
:class:`~google.cloud.firestore_v1.base_document.BaseDocumentReference`:
The child document.
"""
if document_id is None:
Expand Down Expand Up @@ -182,7 +183,7 @@ def _prep_add(
document_id: Optional[str] = None,
retry: retries.Retry | retries.AsyncRetry | object | None = None,
timeout: Optional[float] = None,
) -> Tuple[DocumentReference, dict]:
):
"""Shared setup for async / sync :method:`add`"""
if document_id is None:
document_id = _auto_id()
Expand Down Expand Up @@ -234,7 +235,8 @@ def list_documents(
*,
read_time: Optional[datetime.datetime] = None,
) -> Union[
Generator[DocumentReference, Any, Any], AsyncGenerator[DocumentReference, Any]
Generator[DocumentReference, Any, Any],
AsyncGenerator[AsyncDocumentReference, Any],
]:
raise NotImplementedError

Expand Down Expand Up @@ -612,13 +614,17 @@ def _auto_id() -> str:
return "".join(random.choice(_AUTO_ID_CHARS) for _ in range(20))


def _item_to_document_ref(collection_reference, item) -> DocumentReference:
def _item_to_document_ref(collection_reference, item):
"""Convert Document resource to document ref.

Args:
collection_reference (google.api_core.page_iterator.GRPCIterator):
iterator response
item (dict): document resource

Returns:
:class:`~google.cloud.firestore_v1.base_document.BaseDocumentReference`:
The child document
"""
document_id = item.name.split(_helpers.DOCUMENT_PATH_DELIMITER)[-1]
return collection_reference.document(document_id)
10 changes: 5 additions & 5 deletions google/cloud/firestore_v1/base_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def _validate_opation(op_string, value):
class FieldFilter(BaseFilter):
"""Class representation of a Field Filter."""

def __init__(self, field_path, op_string, value=None):
def __init__(self, field_path: str, op_string: str, value: Any | None = None):
self.field_path = field_path
self.value = value
self.op_string = _validate_opation(op_string, value)
Expand All @@ -208,8 +208,8 @@ class BaseCompositeFilter(BaseFilter):

def __init__(
self,
operator=StructuredQuery.CompositeFilter.Operator.OPERATOR_UNSPECIFIED,
filters=None,
operator: int = StructuredQuery.CompositeFilter.Operator.OPERATOR_UNSPECIFIED,
filters: list[BaseFilter] | None = None,
):
self.operator = operator
if filters is None:
Expand Down Expand Up @@ -241,7 +241,7 @@ def _to_pb(self):
class Or(BaseCompositeFilter):
"""Class representation of an OR Filter."""

def __init__(self, filters):
def __init__(self, filters: list[BaseFilter]):
super().__init__(
operator=StructuredQuery.CompositeFilter.Operator.OR, filters=filters
)
Expand All @@ -250,7 +250,7 @@ def __init__(self, filters):
class And(BaseCompositeFilter):
"""Class representation of an AND Filter."""

def __init__(self, filters):
def __init__(self, filters: list[BaseFilter]):
super().__init__(
operator=StructuredQuery.CompositeFilter.Operator.AND, filters=filters
)
Expand Down
3 changes: 2 additions & 1 deletion google/cloud/firestore_v1/bulk_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def wrapper(self, *args, **kwargs):
# For code parity, even `SendMode.serial` scenarios should return
# a future here. Anything else would badly complicate calling code.
result = fn(self, *args, **kwargs)
future = concurrent.futures.Future()
future: concurrent.futures.Future = concurrent.futures.Future()
future.set_result(result)
return future

Expand Down Expand Up @@ -319,6 +319,7 @@ def __init__(
self._total_batches_sent: int = 0
self._total_write_operations: int = 0

self._executor: concurrent.futures.ThreadPoolExecutor
self._ensure_executor()

@staticmethod
Expand Down
14 changes: 7 additions & 7 deletions google/cloud/firestore_v1/field_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,15 +263,15 @@ class FieldPath(object):
Indicating path of the key to be used.
"""

def __init__(self, *parts):
def __init__(self, *parts: str):
for part in parts:
if not isinstance(part, str) or not part:
error = "One or more components is not a string or is empty."
raise ValueError(error)
self.parts = tuple(parts)

@classmethod
def from_api_repr(cls, api_repr: str):
def from_api_repr(cls, api_repr: str) -> "FieldPath":
"""Factory: create a FieldPath from the string formatted per the API.

Args:
Expand All @@ -288,7 +288,7 @@ def from_api_repr(cls, api_repr: str):
return cls(*parse_field_path(api_repr))

@classmethod
def from_string(cls, path_string: str):
def from_string(cls, path_string: str) -> "FieldPath":
"""Factory: create a FieldPath from a unicode string representation.

This method splits on the character `.` and disallows the
Expand Down Expand Up @@ -351,7 +351,7 @@ def __add__(self, other):
else:
return NotImplemented

def to_api_repr(self):
def to_api_repr(self) -> str:
"""Render a quoted string representation of the FieldPath

Returns:
Expand All @@ -360,7 +360,7 @@ def to_api_repr(self):
"""
return render_field_path(self.parts)

def eq_or_parent(self, other):
def eq_or_parent(self, other) -> bool:
"""Check whether ``other`` is an ancestor.

Returns:
Expand All @@ -369,7 +369,7 @@ def eq_or_parent(self, other):
"""
return self.parts[: len(other.parts)] == other.parts[: len(self.parts)]

def lineage(self):
def lineage(self) -> set["FieldPath"]:
"""Return field paths for all parents.

Returns: Set[:class:`FieldPath`]
Expand All @@ -378,7 +378,7 @@ def lineage(self):
return {FieldPath(*self.parts[:index]) for index in indexes}

@staticmethod
def document_id():
def document_id() -> str:
"""A special FieldPath value to refer to the ID of a document. It can be used
in queries to sort or filter by the document ID.

Expand Down
37 changes: 21 additions & 16 deletions google/cloud/firestore_v1/watch.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import collections
import functools
Expand Down Expand Up @@ -232,7 +233,7 @@ def __init__(
def _init_stream(self):
rpc_request = self._get_rpc_request

self._rpc = ResumableBidiRpc(
self._rpc: ResumableBidiRpc | None = ResumableBidiRpc(
start_rpc=self._api._transport.listen,
should_recover=_should_recover,
should_terminate=_should_terminate,
Expand All @@ -243,7 +244,9 @@ def _init_stream(self):
self._rpc.add_done_callback(self._on_rpc_done)

# The server assigns and updates the resume token.
self._consumer = BackgroundConsumer(self._rpc, self.on_snapshot)
self._consumer: BackgroundConsumer | None = BackgroundConsumer(
self._rpc, self.on_snapshot
)
self._consumer.start()

@classmethod
Expand Down Expand Up @@ -330,16 +333,18 @@ def close(self, reason=None):
return

# Stop consuming messages.
if self.is_active:
_LOGGER.debug("Stopping consumer.")
self._consumer.stop()
self._consumer._on_response = None
if self._consumer:
if self.is_active:
_LOGGER.debug("Stopping consumer.")
self._consumer.stop()
self._consumer._on_response = None
self._consumer = None

self._snapshot_callback = None
self._rpc.close()
self._rpc._initial_request = None
self._rpc._callbacks = []
if self._rpc:
self._rpc.close()
self._rpc._initial_request = None
self._rpc._callbacks = []
self._rpc = None
self._closed = True
_LOGGER.debug("Finished stopping manager.")
Expand Down Expand Up @@ -460,13 +465,13 @@ def on_snapshot(self, proto):
message = f"Unknown target change type: {target_change_type}"
_LOGGER.info(f"on_snapshot: {message}")
self.close(reason=ValueError(message))

try:
# Use 'proto' vs 'pb' for datetime handling
meth(self, proto.target_change)
except Exception as exc2:
_LOGGER.debug(f"meth(proto) exc: {exc2}")
raise
else:
try:
# Use 'proto' vs 'pb' for datetime handling
meth(self, proto.target_change)
except Exception as exc2:
_LOGGER.debug(f"meth(proto) exc: {exc2}")
raise

# NOTE:
# in other implementations, such as node, the backoff is reset here
Expand Down
Loading
Loading