Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 26 additions & 27 deletions google_nest_sdm/subscriber_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
_T = TypeVar("_T")

RPC_TIMEOUT_SECONDS = 30.0
STREAMING_PULL_TIMEOUT_SECONDS = 45.0
STREAMING_PULL_TIMEOUT_SECONDS = 55.0
STREAM_ACK_TIMEOUT_SECONDS = 180
STREAM_ACK_FREQUENCY_SECONDS = 90

Expand All @@ -53,19 +53,14 @@ def refresh_creds(creds: Credentials) -> Credentials:


def exception_handler[_T: Any](
func_name: str, timeout: float = RPC_TIMEOUT_SECONDS
func_name: str,
) -> Callable[..., Callable[..., Awaitable[_T]]]:
"""Wrap a function with exception handling."""

def wrapped(func: Callable[..., Awaitable[_T]]) -> Callable[..., Awaitable[_T]]:
async def wrapped_func(*args: Any, **kwargs: Any) -> _T:
try:
async with asyncio.timeout(timeout):
return await func(*args, **kwargs)
except asyncio.TimeoutError as err:
_LOGGER.debug("Timeout in %s: %s", func_name, err)
DIAGNOSTICS.increment(f"{func_name}.timeout")
raise SubscriberException(f"Timeout in {func_name}") from err
return await func(*args, **kwargs)
except NotFound as err:
_LOGGER.debug("NotFound error in %s: %s", func_name, err)
DIAGNOSTICS.increment(f"{func_name}.not_found_error")
Expand Down Expand Up @@ -174,28 +169,26 @@ async def _async_get_client(self) -> pubsub_v1.SubscriberAsyncClient:
self._client = pubsub_v1.SubscriberAsyncClient(credentials=self._creds)
return self._client

@exception_handler("streaming_pull")
async def streaming_pull(
self,
ack_ids_generator: Callable[[], list[str]],
) -> AsyncIterable[pubsub_v1.types.StreamingPullResponse]:
"""Start the streaming pull."""
stream: AsyncIterable[pubsub_v1.types.StreamingPullResponse] = (
await self._streaming_pull_req(ack_ids_generator)
)
return aiter_exception_handler(stream)

@exception_handler(
func_name="streaming_pull", timeout=STREAMING_PULL_TIMEOUT_SECONDS
)
async def _streaming_pull_req(
self,
ack_ids_generator: Callable[[], list[str]],
) -> AsyncIterable[pubsub_v1.types.StreamingPullResponse]:
client = await self._async_get_client()
req_gen = pull_request_generator(self._subscription_name, ack_ids_generator)
_LOGGER.debug("Sending streaming pull request for %s", self._subscription_name)
return await client.streaming_pull(
requests=pull_request_generator(self._subscription_name, ack_ids_generator)
)
try:
async with asyncio.timeout(STREAMING_PULL_TIMEOUT_SECONDS):
stream: AsyncIterable[pubsub_v1.types.StreamingPullResponse] = (
await client.streaming_pull(requests=req_gen)
)
except asyncio.TimeoutError as err:
_LOGGER.debug("Timeout in streaming_pull %s", err)
DIAGNOSTICS.increment("streaming_pull.timeout")
raise SubscriberException("Timeout in streaming_pull") from err
_LOGGER.debug("Streaming pull started")
return aiter_exception_handler(stream)

@exception_handler("acknowledge")
async def ack_messages(self, ack_ids: list[str]) -> None:
Expand All @@ -204,7 +197,13 @@ async def ack_messages(self, ack_ids: list[str]) -> None:
return
client = await self._async_get_client()
_LOGGER.debug("Acking %s messages", len(ack_ids))
await client.acknowledge(
subscription=self._subscription_name,
ack_ids=ack_ids,
)
try:
async with asyncio.timeout(RPC_TIMEOUT_SECONDS):
await client.acknowledge(
subscription=self._subscription_name,
ack_ids=ack_ids,
)
except asyncio.TimeoutError as err:
_LOGGER.debug("Timeout in acknowledge: %s", err)
DIAGNOSTICS.increment("acknowledge.timeout")
raise SubscriberException("Timeout in acknowledge") from err