diff --git a/google/cloud/firestore_v1/async_collection.py b/google/cloud/firestore_v1/async_collection.py index 1b71372dd..cc99aa460 100644 --- a/google/cloud/firestore_v1/async_collection.py +++ b/google/cloud/firestore_v1/async_collection.py @@ -22,7 +22,6 @@ from google.cloud.firestore_v1 import ( async_aggregation, - async_document, async_query, async_vector_query, transaction, @@ -31,11 +30,10 @@ BaseCollectionReference, _item_to_document_ref, ) -from google.cloud.firestore_v1.document import DocumentReference if TYPE_CHECKING: # pragma: NO COVER import datetime - + from google.cloud.firestore_v1.async_document import AsyncDocumentReference from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator from google.cloud.firestore_v1.base_document import DocumentSnapshot from google.cloud.firestore_v1.query_profile import ExplainOptions @@ -142,9 +140,7 @@ async def add( write_result = await document_ref.create(document_data, **kwargs) return write_result.update_time, document_ref - def document( - self, document_id: str | None = None - ) -> async_document.AsyncDocumentReference: + def document(self, document_id: str | None = None) -> AsyncDocumentReference: """Create a sub-document underneath the current collection. Args: @@ -166,7 +162,7 @@ async def list_documents( timeout: float | None = None, *, read_time: datetime.datetime | None = None, - ) -> AsyncGenerator[DocumentReference, None]: + ) -> AsyncGenerator[AsyncDocumentReference, None]: """List all subdocuments of the current collection. Args: diff --git a/google/cloud/firestore_v1/async_query.py b/google/cloud/firestore_v1/async_query.py index 98de75bd6..de6c3c1cf 100644 --- a/google/cloud/firestore_v1/async_query.py +++ b/google/cloud/firestore_v1/async_query.py @@ -20,7 +20,16 @@ """ from __future__ import annotations -from typing import TYPE_CHECKING, Any, AsyncGenerator, List, Optional, Type +from typing import ( + TYPE_CHECKING, + Any, + AsyncGenerator, + List, + Optional, + Type, + Union, + Sequence, +) from google.api_core import gapic_v1 from google.api_core import retry_async as retries @@ -256,7 +265,7 @@ async def get( def find_nearest( self, vector_field: str, - query_vector: Vector, + query_vector: Union[Vector, Sequence[float]], limit: int, distance_measure: DistanceMeasure, *, @@ -269,7 +278,7 @@ def find_nearest( Args: vector_field (str): An indexed vector field to search upon. Only documents which contain vectors whose dimensionality match the query_vector can be returned. - query_vector (Vector): The query vector that we are searching on. Must be a vector of no more + query_vector (Vector | Sequence[float]): The query vector that we are searching on. Must be a vector of no more than 2048 dimensions. limit (int): The number of nearest neighbors to return. Must be a positive integer of no more than 1000. distance_measure (:class:`DistanceMeasure`): The Distance Measure to use. diff --git a/google/cloud/firestore_v1/base_aggregation.py b/google/cloud/firestore_v1/base_aggregation.py index da1af1ec1..c5e6a7b7f 100644 --- a/google/cloud/firestore_v1/base_aggregation.py +++ b/google/cloud/firestore_v1/base_aggregation.py @@ -83,23 +83,26 @@ def __init__(self, alias: str | None = None): def _to_protobuf(self): """Convert this instance to the protobuf representation""" aggregation_pb = StructuredAggregationQuery.Aggregation() - aggregation_pb.alias = self.alias + if self.alias: + aggregation_pb.alias = self.alias aggregation_pb.count = StructuredAggregationQuery.Aggregation.Count() return aggregation_pb class SumAggregation(BaseAggregation): def __init__(self, field_ref: str | FieldPath, alias: str | None = None): - if isinstance(field_ref, FieldPath): - # convert field path to string - field_ref = field_ref.to_api_repr() - self.field_ref = field_ref + # convert field path to string if needed + field_str = ( + field_ref.to_api_repr() if isinstance(field_ref, FieldPath) else field_ref + ) + self.field_ref: str = field_str super(SumAggregation, self).__init__(alias=alias) def _to_protobuf(self): """Convert this instance to the protobuf representation""" aggregation_pb = StructuredAggregationQuery.Aggregation() - aggregation_pb.alias = self.alias + if self.alias: + aggregation_pb.alias = self.alias aggregation_pb.sum = StructuredAggregationQuery.Aggregation.Sum() aggregation_pb.sum.field.field_path = self.field_ref return aggregation_pb @@ -107,16 +110,18 @@ def _to_protobuf(self): class AvgAggregation(BaseAggregation): def __init__(self, field_ref: str | FieldPath, alias: str | None = None): - if isinstance(field_ref, FieldPath): - # convert field path to string - field_ref = field_ref.to_api_repr() - self.field_ref = field_ref + # convert field path to string if needed + field_str = ( + field_ref.to_api_repr() if isinstance(field_ref, FieldPath) else field_ref + ) + self.field_ref: str = field_str super(AvgAggregation, self).__init__(alias=alias) def _to_protobuf(self): """Convert this instance to the protobuf representation""" aggregation_pb = StructuredAggregationQuery.Aggregation() - aggregation_pb.alias = self.alias + if self.alias: + aggregation_pb.alias = self.alias aggregation_pb.avg = StructuredAggregationQuery.Aggregation.Avg() aggregation_pb.avg.field.field_path = self.field_ref return aggregation_pb diff --git a/google/cloud/firestore_v1/base_client.py b/google/cloud/firestore_v1/base_client.py index acbd148fb..4a0e3f6b8 100644 --- a/google/cloud/firestore_v1/base_client.py +++ b/google/cloud/firestore_v1/base_client.py @@ -476,7 +476,7 @@ def _prep_collections( read_time: datetime.datetime | None = None, ) -> Tuple[dict, dict]: """Shared setup for async/sync :meth:`collections`.""" - request = { + request: dict[str, Any] = { "parent": "{}/documents".format(self._database_string), } if read_time is not None: diff --git a/google/cloud/firestore_v1/base_collection.py b/google/cloud/firestore_v1/base_collection.py index ada23529d..1b1ef0411 100644 --- a/google/cloud/firestore_v1/base_collection.py +++ b/google/cloud/firestore_v1/base_collection.py @@ -45,6 +45,7 @@ BaseVectorQuery, DistanceMeasure, ) + from google.cloud.firestore_v1.async_document import AsyncDocumentReference from google.cloud.firestore_v1.document import DocumentReference from google.cloud.firestore_v1.field_path import FieldPath from google.cloud.firestore_v1.query_profile import ExplainOptions @@ -132,7 +133,7 @@ def _aggregation_query(self) -> BaseAggregationQuery: def _vector_query(self) -> BaseVectorQuery: raise NotImplementedError - def document(self, document_id: Optional[str] = None) -> DocumentReference: + def document(self, document_id: Optional[str] = None): """Create a sub-document underneath the current collection. Args: @@ -142,7 +143,7 @@ def document(self, document_id: Optional[str] = None) -> DocumentReference: uppercase and lowercase and letters. Returns: - :class:`~google.cloud.firestore_v1.document.DocumentReference`: + :class:`~google.cloud.firestore_v1.base_document.BaseDocumentReference`: The child document. """ if document_id is None: @@ -182,7 +183,7 @@ def _prep_add( document_id: Optional[str] = None, retry: retries.Retry | retries.AsyncRetry | object | None = None, timeout: Optional[float] = None, - ) -> Tuple[DocumentReference, dict]: + ): """Shared setup for async / sync :method:`add`""" if document_id is None: document_id = _auto_id() @@ -234,7 +235,8 @@ def list_documents( *, read_time: Optional[datetime.datetime] = None, ) -> Union[ - Generator[DocumentReference, Any, Any], AsyncGenerator[DocumentReference, Any] + Generator[DocumentReference, Any, Any], + AsyncGenerator[AsyncDocumentReference, Any], ]: raise NotImplementedError @@ -612,13 +614,17 @@ def _auto_id() -> str: return "".join(random.choice(_AUTO_ID_CHARS) for _ in range(20)) -def _item_to_document_ref(collection_reference, item) -> DocumentReference: +def _item_to_document_ref(collection_reference, item): """Convert Document resource to document ref. Args: collection_reference (google.api_core.page_iterator.GRPCIterator): iterator response item (dict): document resource + + Returns: + :class:`~google.cloud.firestore_v1.base_document.BaseDocumentReference`: + The child document """ document_id = item.name.split(_helpers.DOCUMENT_PATH_DELIMITER)[-1] return collection_reference.document(document_id) diff --git a/google/cloud/firestore_v1/base_query.py b/google/cloud/firestore_v1/base_query.py index 7f0ca15d2..14df886bc 100644 --- a/google/cloud/firestore_v1/base_query.py +++ b/google/cloud/firestore_v1/base_query.py @@ -182,7 +182,7 @@ def _validate_opation(op_string, value): class FieldFilter(BaseFilter): """Class representation of a Field Filter.""" - def __init__(self, field_path, op_string, value=None): + def __init__(self, field_path: str, op_string: str, value: Any | None = None): self.field_path = field_path self.value = value self.op_string = _validate_opation(op_string, value) @@ -208,8 +208,8 @@ class BaseCompositeFilter(BaseFilter): def __init__( self, - operator=StructuredQuery.CompositeFilter.Operator.OPERATOR_UNSPECIFIED, - filters=None, + operator: int = StructuredQuery.CompositeFilter.Operator.OPERATOR_UNSPECIFIED, + filters: list[BaseFilter] | None = None, ): self.operator = operator if filters is None: @@ -241,7 +241,7 @@ def _to_pb(self): class Or(BaseCompositeFilter): """Class representation of an OR Filter.""" - def __init__(self, filters): + def __init__(self, filters: list[BaseFilter]): super().__init__( operator=StructuredQuery.CompositeFilter.Operator.OR, filters=filters ) @@ -250,7 +250,7 @@ def __init__(self, filters): class And(BaseCompositeFilter): """Class representation of an AND Filter.""" - def __init__(self, filters): + def __init__(self, filters: list[BaseFilter]): super().__init__( operator=StructuredQuery.CompositeFilter.Operator.AND, filters=filters ) diff --git a/google/cloud/firestore_v1/bulk_writer.py b/google/cloud/firestore_v1/bulk_writer.py index eff936300..6747bc234 100644 --- a/google/cloud/firestore_v1/bulk_writer.py +++ b/google/cloud/firestore_v1/bulk_writer.py @@ -110,7 +110,7 @@ def wrapper(self, *args, **kwargs): # For code parity, even `SendMode.serial` scenarios should return # a future here. Anything else would badly complicate calling code. result = fn(self, *args, **kwargs) - future = concurrent.futures.Future() + future: concurrent.futures.Future = concurrent.futures.Future() future.set_result(result) return future @@ -319,6 +319,7 @@ def __init__( self._total_batches_sent: int = 0 self._total_write_operations: int = 0 + self._executor: concurrent.futures.ThreadPoolExecutor self._ensure_executor() @staticmethod diff --git a/google/cloud/firestore_v1/field_path.py b/google/cloud/firestore_v1/field_path.py index 048eb64d0..27ac6cc45 100644 --- a/google/cloud/firestore_v1/field_path.py +++ b/google/cloud/firestore_v1/field_path.py @@ -263,7 +263,7 @@ class FieldPath(object): Indicating path of the key to be used. """ - def __init__(self, *parts): + def __init__(self, *parts: str): for part in parts: if not isinstance(part, str) or not part: error = "One or more components is not a string or is empty." @@ -271,7 +271,7 @@ def __init__(self, *parts): self.parts = tuple(parts) @classmethod - def from_api_repr(cls, api_repr: str): + def from_api_repr(cls, api_repr: str) -> "FieldPath": """Factory: create a FieldPath from the string formatted per the API. Args: @@ -288,7 +288,7 @@ def from_api_repr(cls, api_repr: str): return cls(*parse_field_path(api_repr)) @classmethod - def from_string(cls, path_string: str): + def from_string(cls, path_string: str) -> "FieldPath": """Factory: create a FieldPath from a unicode string representation. This method splits on the character `.` and disallows the @@ -351,7 +351,7 @@ def __add__(self, other): else: return NotImplemented - def to_api_repr(self): + def to_api_repr(self) -> str: """Render a quoted string representation of the FieldPath Returns: @@ -360,7 +360,7 @@ def to_api_repr(self): """ return render_field_path(self.parts) - def eq_or_parent(self, other): + def eq_or_parent(self, other) -> bool: """Check whether ``other`` is an ancestor. Returns: @@ -369,7 +369,7 @@ def eq_or_parent(self, other): """ return self.parts[: len(other.parts)] == other.parts[: len(self.parts)] - def lineage(self): + def lineage(self) -> set["FieldPath"]: """Return field paths for all parents. Returns: Set[:class:`FieldPath`] @@ -378,7 +378,7 @@ def lineage(self): return {FieldPath(*self.parts[:index]) for index in indexes} @staticmethod - def document_id(): + def document_id() -> str: """A special FieldPath value to refer to the ID of a document. It can be used in queries to sort or filter by the document ID. diff --git a/google/cloud/firestore_v1/watch.py b/google/cloud/firestore_v1/watch.py index 79933aeca..971485655 100644 --- a/google/cloud/firestore_v1/watch.py +++ b/google/cloud/firestore_v1/watch.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import collections import functools @@ -232,7 +233,7 @@ def __init__( def _init_stream(self): rpc_request = self._get_rpc_request - self._rpc = ResumableBidiRpc( + self._rpc: ResumableBidiRpc | None = ResumableBidiRpc( start_rpc=self._api._transport.listen, should_recover=_should_recover, should_terminate=_should_terminate, @@ -243,7 +244,9 @@ def _init_stream(self): self._rpc.add_done_callback(self._on_rpc_done) # The server assigns and updates the resume token. - self._consumer = BackgroundConsumer(self._rpc, self.on_snapshot) + self._consumer: BackgroundConsumer | None = BackgroundConsumer( + self._rpc, self.on_snapshot + ) self._consumer.start() @classmethod @@ -330,16 +333,18 @@ def close(self, reason=None): return # Stop consuming messages. - if self.is_active: - _LOGGER.debug("Stopping consumer.") - self._consumer.stop() - self._consumer._on_response = None + if self._consumer: + if self.is_active: + _LOGGER.debug("Stopping consumer.") + self._consumer.stop() + self._consumer._on_response = None self._consumer = None self._snapshot_callback = None - self._rpc.close() - self._rpc._initial_request = None - self._rpc._callbacks = [] + if self._rpc: + self._rpc.close() + self._rpc._initial_request = None + self._rpc._callbacks = [] self._rpc = None self._closed = True _LOGGER.debug("Finished stopping manager.") @@ -460,13 +465,13 @@ def on_snapshot(self, proto): message = f"Unknown target change type: {target_change_type}" _LOGGER.info(f"on_snapshot: {message}") self.close(reason=ValueError(message)) - - try: - # Use 'proto' vs 'pb' for datetime handling - meth(self, proto.target_change) - except Exception as exc2: - _LOGGER.debug(f"meth(proto) exc: {exc2}") - raise + else: + try: + # Use 'proto' vs 'pb' for datetime handling + meth(self, proto.target_change) + except Exception as exc2: + _LOGGER.debug(f"meth(proto) exc: {exc2}") + raise # NOTE: # in other implementations, such as node, the backoff is reset here diff --git a/noxfile.py b/noxfile.py index 7ef3ed5b8..9e81d7179 100644 --- a/noxfile.py +++ b/noxfile.py @@ -155,9 +155,16 @@ def pytype(session): def mypy(session): """Verify type hints are mypy compatible.""" session.install("-e", ".") - session.install("mypy", "types-setuptools") - # TODO: also verify types on tests, all of google package - session.run("mypy", "-p", "google.cloud.firestore", "--no-incremental") + session.install("mypy", "types-setuptools", "types-protobuf") + session.run( + "mypy", + "-p", + "google.cloud.firestore_v1", + "--no-incremental", + "--check-untyped-defs", + "--exclude", + "services", + ) @nox.session(python=DEFAULT_PYTHON_VERSION) diff --git a/tests/unit/v1/test_aggregation.py b/tests/unit/v1/test_aggregation.py index 767089e98..c8a2af9ef 100644 --- a/tests/unit/v1/test_aggregation.py +++ b/tests/unit/v1/test_aggregation.py @@ -51,6 +51,12 @@ def test_count_aggregation_to_pb(): assert count_aggregation._to_protobuf() == expected_aggregation_query_pb +def test_count_aggregation_no_alias_to_pb(): + count_aggregation = CountAggregation(alias=None) + got_pb = count_aggregation._to_protobuf() + assert got_pb.alias == "" + + def test_sum_aggregation_w_field_path(): """ SumAggregation should convert FieldPath inputs into strings @@ -88,6 +94,12 @@ def test_sum_aggregation_to_pb(): assert sum_aggregation._to_protobuf() == expected_aggregation_query_pb +def test_sum_aggregation_no_alias_to_pb(): + sum_aggregation = SumAggregation("someref", alias=None) + got_pb = sum_aggregation._to_protobuf() + assert got_pb.alias == "" + + def test_avg_aggregation_to_pb(): from google.cloud.firestore_v1.types import query as query_pb2 @@ -103,6 +115,12 @@ def test_avg_aggregation_to_pb(): assert avg_aggregation._to_protobuf() == expected_aggregation_query_pb +def test_avg_aggregation_no_alias_to_pb(): + avg_aggregation = AvgAggregation("someref", alias=None) + got_pb = avg_aggregation._to_protobuf() + assert got_pb.alias == "" + + def test_aggregation_query_constructor(): client = make_client() parent = client.collection("dee") diff --git a/tests/unit/v1/test_watch.py b/tests/unit/v1/test_watch.py index 6d8c12abc..63e2233a4 100644 --- a/tests/unit/v1/test_watch.py +++ b/tests/unit/v1/test_watch.py @@ -322,6 +322,15 @@ def test_watch_close(): assert inst._closed +def test_watch_close_w_empty_attrs(): + inst = _make_watch() + inst._consumer = None + inst._rpc = None + inst.close() + assert inst._consumer is None + assert inst._rpc is None + + def test_watch__get_rpc_request_wo_resume_token(): inst = _make_watch()