Skip to content

Commit 84445a4

Browse files
Implement Auto recovery connection for consumer (#250)
* This PR implements auto-recovery functionality for consumers when connections are broken. The implementation includes a new recovery strategy system with backoff capabilities and comprehensive test coverage for connection recovery scenarios. - Adds a configurable recovery strategy system to automatically reconnect consumers when connections fail - Implements BackOffRecoveryStrategy with exponential backoff and jitter for robust reconnection - Updates test utilities to use standardized HTTP API function names and improve test reliability Breaking changes - removed the API `consumer.reconnect_stream()` - Auto-reconnection is handled by RecoveryStrategy configuration --------- Signed-off-by: Gabriele Santomaggio <[email protected]> Co-authored-by: Copilot <[email protected]>
1 parent 5f6b581 commit 84445a4

16 files changed

Lines changed: 530 additions & 324 deletions

.github/workflows/test.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,9 @@ jobs:
4545
- name: black check
4646
run: poetry run black --check .
4747
- name: flake8
48-
run: poetry run flake8 --exclude=venv,local_tests,docs/examples --max-line-length=120 --ignore=E203,W503,E701,E704
48+
run: poetry run flake8 --exclude=venv,local_tests,docs/examples --max-line-length=120 --ignore=E203,W503,E701,E704,E131
4949
- name: mypy
5050
run: |
5151
poetry run mypy .
5252
- name: poetry run pytest
53-
run: poetry run pytest
53+
run: poetry run pytest -v

Makefile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ format:
1010
poetry --version
1111
poetry run isort .
1212
poetry run black --exclude=venv .
13-
poetry run flake8 --exclude=venv,local_tests,docs/examples --max-line-length=120 --ignore=E203,W503,E701,E704
13+
poetry run flake8 --exclude=venv,local_tests,docs/examples --max-line-length=120 --ignore=E203,W503,E701,E704,E131
1414
poetry run mypy .
1515

1616
rabbitmq-ha-proxy:
@@ -23,6 +23,6 @@ rabbitmq-ha-proxy:
2323
cd compose/ha_tls; docker compose up
2424

2525
test: format
26-
poetry run pytest .
26+
poetry run pytest . -s -v
2727
help:
2828
cat Makefile

docs/examples/reliable_client/BestPracticesClient.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -102,26 +102,6 @@ async def on_close_connection(on_closed_info: OnClosedErrorInfo) -> None:
102102
+ str(on_closed_info.reason)
103103
)
104104

105-
await asyncio.sleep(2)
106-
# reconnect just if the partition exists
107-
for stream in on_closed_info.streams:
108-
backoff = 1
109-
while True:
110-
try:
111-
print("reconnecting stream: {}".format(stream))
112-
if consumer is not None:
113-
await consumer.reconnect_stream(stream)
114-
break
115-
except Exception as ex:
116-
if backoff > 32:
117-
# failed to found the leader
118-
print("reconnection failed")
119-
break
120-
backoff = backoff * 2
121-
print("exception reconnecting waiting 120s: " + str(ex))
122-
await asyncio.sleep(30)
123-
continue
124-
125105

126106
# Make consumers
127107
async def make_consumer(rabbitmq_data: dict) -> Consumer | SuperStreamConsumer: # type: ignore

pyproject.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ typing_extensions ="^4.15.0"
3333

3434
[tool.black]
3535
line-length = 110
36+
skip-string-normalization = true
3637

3738
[tool.mypy]
3839
ignore_missing_imports = true
@@ -45,3 +46,8 @@ ignore_errors = true
4546
[tool.isort]
4647
profile = "black"
4748

49+
50+
[tool.flake8]
51+
max-line-length = 120
52+
extend-ignore = ["E203", "W503"]
53+

rstream/client.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
Union,
2424
)
2525

26+
from rstream.schema import Broker
27+
2628
from . import (
2729
__license__,
2830
__version__,
@@ -509,7 +511,7 @@ async def stream_exists(self, stream: str) -> bool:
509511
async def query_leader_and_replicas(
510512
self,
511513
stream: str,
512-
) -> tuple[schema.Broker, list[schema.Broker]]:
514+
) -> tuple[Broker, list[Broker]]:
513515
while True:
514516
metadata_resp = await self.sync_request(
515517
schema.Metadata(

rstream/consumer.py

Lines changed: 80 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,13 @@
3232
SlasMechanism,
3333
)
3434
from .exceptions import StreamAlreadySubscribed
35+
from .recovery import BackOffRecoveryStrategy, IReliableEntity, RecoveryStrategy
3536
from .schema import OffsetSpecification
3637
from .utils import FilterConfiguration, OnClosedErrorInfo
3738

3839
MT = TypeVar("MT")
3940
CB = Annotated[Callable[[MT, Any], Union[None, Awaitable[None]]], "Message callback type"]
40-
CB_CONN = Annotated[Callable[[MT], Union[None, Awaitable[None]]], "Message callback type"]
41+
CB_CONN = Annotated[Callable[[MT], Union[None, Awaitable[Any]]], "Message callback type"]
4142
logger = logging.getLogger(__name__)
4243

4344

@@ -63,26 +64,26 @@ class EventContext:
6364
@dataclass
6465
class _Subscriber:
6566
stream: str
67+
client: Client
6668
subscription_id: int
6769
reference: Optional[str]
68-
client: Client
6970
callback: Callable[[AMQPMessage, MessageContext], Union[None, Awaitable[None]]]
7071
decoder: Callable[[bytes], Any]
7172
offset_type: OffsetType
7273
offset: int
7374
filter_input: Optional[FilterConfiguration]
7475

7576

76-
class Consumer:
77+
class Consumer(IReliableEntity):
7778
def __init__(
7879
self,
7980
host: str,
8081
port: int = 5552,
8182
*,
82-
ssl_context: Optional[ssl.SSLContext] = None,
83-
vhost: str = "/",
8483
username: str,
8584
password: str,
85+
ssl_context: Optional[ssl.SSLContext] = None,
86+
vhost: str = "/",
8687
frame_max: int = 1 * 1024 * 1024,
8788
heartbeat: int = 60,
8889
load_balancer_mode: bool = False,
@@ -91,7 +92,9 @@ def __init__(
9192
on_close_handler: Optional[CB_CONN[OnClosedErrorInfo]] = None,
9293
connection_name: str = "",
9394
sasl_configuration_mechanism: SlasMechanism = SlasMechanism.MechanismPlain,
95+
recovery_strategy: RecoveryStrategy = BackOffRecoveryStrategy(),
9496
):
97+
super().__init__()
9598
self._pool = ClientPool(
9699
host,
97100
port,
@@ -113,6 +116,7 @@ def __init__(
113116
if max_subscribers_by_connection > MAX_ITEM_ALLOWED:
114117
raise ValueError(f"max_subscribers_by_connection must be less than {MAX_ITEM_ALLOWED}")
115118

119+
self._recovery_strategy = recovery_strategy
116120
self._default_client: Optional[Client] = None
117121
self._clients: dict[str, Client] = {}
118122
self._subscribers: dict[int, _Subscriber] = {}
@@ -128,7 +132,7 @@ def __init__(
128132

129133
@property
130134
async def default_client(self) -> Client:
131-
if self._default_client is None:
135+
if self._default_client is None or not self._default_client.is_connection_alive():
132136
self._default_client = await self._create_locator_connection()
133137
return self._default_client
134138

@@ -141,7 +145,7 @@ async def __aexit__(self, *_: Any) -> None:
141145

142146
async def start(self) -> None:
143147
self._default_client = await self._pool.get(
144-
connection_closed_handler=self._on_close_handler,
148+
connection_closed_handler=self._on_close_connection,
145149
connection_name=self._connection_name,
146150
sasl_configuration_mechanism=self._sasl_configuration_mechanism,
147151
max_clients_by_connections=self._max_subscribers_by_connection,
@@ -175,7 +179,7 @@ async def _get_or_create_client(self, stream: str) -> Client:
175179
if self._default_client is None:
176180
logger.debug("_get_or_create_client(): Creating locator connection")
177181
self._default_client = await self._pool.get(
178-
connection_closed_handler=self._on_close_handler,
182+
connection_closed_handler=self._on_close_connection,
179183
connection_name=self._connection_name,
180184
max_clients_by_connections=self._max_subscribers_by_connection,
181185
)
@@ -185,7 +189,7 @@ async def _get_or_create_client(self, stream: str) -> Client:
185189
logger.debug("_get_or_create_client(): Getting/Creating connection")
186190
self._clients[stream] = await self._pool.get(
187191
addr=Addr(broker.host, broker.port),
188-
connection_closed_handler=self._on_close_handler,
192+
connection_closed_handler=self._on_close_connection,
189193
connection_name=self._connection_name,
190194
stream=stream,
191195
max_clients_by_connections=self._max_subscribers_by_connection,
@@ -254,7 +258,7 @@ async def subscribe(
254258
offset_specification = ConsumerOffsetSpecification(OffsetType.FIRST, None)
255259

256260
async with self._lock:
257-
logger.debug("subscribe(): Create subscriber")
261+
logger.debug("[subscribe] Create subscriber for stream: {}".format(stream))
258262
subscriber = await self._create_subscriber(
259263
stream=stream,
260264
subscriber_name=subscriber_name,
@@ -270,7 +274,7 @@ async def subscribe(
270274
handler=partial(self._on_deliver, subscriber=subscriber, filter_value=filter_input),
271275
)
272276

273-
logger.debug("subscribe(): Adding handlers")
277+
logger.debug("[subscribe] Adding handlers for stream: {}".format(stream))
274278
subscriber.client.add_handler(
275279
schema.Deliver,
276280
partial(self._on_deliver, subscriber=subscriber, filter_value=filter_input),
@@ -299,7 +303,7 @@ async def subscribe(
299303
)
300304

301305
if filter_input is not None:
302-
logger.debug("subscribe(): Filtering scenario enabled")
306+
logger.debug("[subscribe] Filtering scenario enabled for stream: {}".format(stream))
303307
await self._check_if_filtering_is_supported()
304308
values_to_filter = filter_input.values()
305309
if len(values_to_filter) <= 0:
@@ -315,7 +319,7 @@ async def subscribe(
315319
else:
316320
properties[SUBSCRIPTION_PROPERTY_MATCH_UNFILTERED] = "false"
317321

318-
logger.debug("subscribe(): Subscribing")
322+
logger.debug("[subscribe] subscribing to stream: {}".format(stream))
319323
await subscriber.client.subscribe(
320324
stream=stream,
321325
subscription_id=subscriber.subscription_id,
@@ -325,7 +329,7 @@ async def subscribe(
325329
initial_credit=initial_credit,
326330
properties=properties,
327331
)
328-
332+
logger.debug("[subscribe] created for: {}, id: {}".format(stream, subscriber.subscription_id))
329333
return subscriber.subscription_id
330334

331335
async def get_available_id(self) -> int:
@@ -369,9 +373,7 @@ async def unsubscribe(self, subscriber_id: int) -> None:
369373
stream = subscriber.stream
370374

371375
if stream in self._clients:
372-
await self._clients[stream].remove_stream(stream)
373-
await self._clients[stream].free_available_id()
374-
del self._clients[stream]
376+
await self._remove_stream_from_client(stream)
375377

376378
del self._subscribers[subscriber_id]
377379

@@ -456,6 +458,48 @@ async def _on_metadata_update(self, frame: schema.MetadataUpdate) -> None:
456458
if result is not None and inspect.isawaitable(result):
457459
await result
458460

461+
async def _on_close_connection(self, on_closed_info: OnClosedErrorInfo) -> None:
462+
# clone on_closed_info to avoid modification during iteration
463+
new_on_closed_info = OnClosedErrorInfo(
464+
reason=on_closed_info.reason,
465+
streams=list(on_closed_info.streams) if on_closed_info.streams else [],
466+
)
467+
468+
if self._on_close_handler is not None:
469+
result = self._on_close_handler(new_on_closed_info)
470+
if result is not None and inspect.isawaitable(result):
471+
await result
472+
473+
for stream in on_closed_info.streams.copy():
474+
current_subscriber = await self._get_subscriber_by_stream(stream)
475+
if current_subscriber is not None:
476+
del self._subscribers[current_subscriber.subscription_id]
477+
await self._remove_stream_from_client(stream)
478+
result = self._recovery_strategy.recover(
479+
self,
480+
current_subscriber.stream,
481+
error=Exception(on_closed_info.reason),
482+
attempt=1,
483+
# fmt: off
484+
recovery_fun=lambda stream_s=current_subscriber.stream,
485+
reference=current_subscriber.reference,
486+
callback=current_subscriber.callback,
487+
decoder=current_subscriber.decoder,
488+
offset=current_subscriber.offset,
489+
filter_input=current_subscriber.filter_input:
490+
self.subscribe(
491+
stream=stream_s,
492+
subscriber_name=reference,
493+
callback=callback,
494+
decoder=decoder,
495+
offset_specification=ConsumerOffsetSpecification(OffsetType.OFFSET, offset),
496+
filter_input=filter_input,
497+
),
498+
# fmt: on
499+
)
500+
if result is not None and inspect.isawaitable(result):
501+
await result
502+
459503
async def _on_consumer_update_query_response(
460504
self,
461505
frame: schema.ConsumerUpdateResponse,
@@ -527,46 +571,6 @@ async def stream_exists(self, stream: str, on_close_event: bool = False) -> bool
527571

528572
return stream_exists
529573

530-
async def reconnect_stream(self, stream: str, offset: Optional[int] = None) -> None:
531-
logging.debug("reconnect_stream")
532-
curr_subscriber = None
533-
curr_subscriber_id = None
534-
for subscriber_id in self._subscribers:
535-
if stream == self._subscribers[subscriber_id].stream:
536-
curr_subscriber = self._subscribers[subscriber_id]
537-
curr_subscriber_id = subscriber_id
538-
if curr_subscriber_id is not None:
539-
del self._subscribers[curr_subscriber_id]
540-
541-
if stream in self._clients:
542-
if curr_subscriber is not None:
543-
await self._clients[stream].free_available_id()
544-
await self._clients[stream].close()
545-
del self._clients[stream]
546-
547-
if self._default_client is not None:
548-
if not self._default_client.is_connection_alive():
549-
await self._default_client.close()
550-
self._default_client = None
551-
552-
if offset is None:
553-
if curr_subscriber is not None:
554-
offset = curr_subscriber.offset
555-
556-
logging.debug("reconnect_stream(): Subscribing again")
557-
offset_specification = ConsumerOffsetSpecification(OffsetType.OFFSET, offset)
558-
if curr_subscriber is not None:
559-
asyncio.create_task(
560-
self.subscribe(
561-
stream=curr_subscriber.stream,
562-
# subscriber_name=curr_subscriber.reference,
563-
callback=curr_subscriber.callback,
564-
decoder=curr_subscriber.decoder,
565-
offset_specification=offset_specification,
566-
filter_input=curr_subscriber.filter_input,
567-
)
568-
)
569-
570574
async def _check_if_filtering_is_supported(self) -> None:
571575
command_version_input = schema.FrameHandlerInfo(Key.Publish.value, min_version=1, max_version=2)
572576
server_command_version: schema.FrameHandlerInfo = await (
@@ -592,17 +596,27 @@ async def _close_locator_connection(self):
592596
await (await self.default_client).close()
593597
self._default_client = None
594598

595-
async def _maybe_clean_up_during_lost_connection(self, stream: str):
596-
curr_subscriber = None
597-
598-
for subscriber_id in self._subscribers:
599-
if stream == self._subscribers[subscriber_id].stream:
600-
curr_subscriber = self._subscribers[subscriber_id]
601-
599+
async def _remove_stream_from_client(self, stream: str) -> None:
602600
if stream in self._clients:
603601
await self._clients[stream].remove_stream(stream)
604-
if curr_subscriber is not None:
605-
await self._clients[stream].free_available_id()
602+
await self._clients[stream].free_available_id()
606603
if await self._clients[stream].get_stream_count() == 0:
607604
await self._clients[stream].close()
608605
del self._clients[stream]
606+
607+
async def _get_subscriber_by_stream(self, stream: str) -> Optional[_Subscriber]:
608+
for subscriber in self._subscribers.values():
609+
if stream == subscriber.stream:
610+
return subscriber
611+
return None
612+
613+
async def _maybe_clean_up_during_lost_connection(self, stream: str) -> Optional[int]:
614+
offset = None
615+
curr_subscriber = await self._get_subscriber_by_stream(stream)
616+
if curr_subscriber is not None:
617+
offset = curr_subscriber.offset
618+
619+
if stream in self._clients:
620+
await self._remove_stream_from_client(stream)
621+
622+
return offset

0 commit comments

Comments
 (0)