3232 SlasMechanism ,
3333)
3434from .exceptions import StreamAlreadySubscribed
35+ from .recovery import BackOffRecoveryStrategy , IReliableEntity , RecoveryStrategy
3536from .schema import OffsetSpecification
3637from .utils import FilterConfiguration , OnClosedErrorInfo
3738
3839MT = TypeVar ("MT" )
3940CB = Annotated [Callable [[MT , Any ], Union [None , Awaitable [None ]]], "Message callback type" ]
40- CB_CONN = Annotated [Callable [[MT ], Union [None , Awaitable [None ]]], "Message callback type" ]
41+ CB_CONN = Annotated [Callable [[MT ], Union [None , Awaitable [Any ]]], "Message callback type" ]
4142logger = logging .getLogger (__name__ )
4243
4344
@@ -63,26 +64,26 @@ class EventContext:
6364@dataclass
6465class _Subscriber :
6566 stream : str
67+ client : Client
6668 subscription_id : int
6769 reference : Optional [str ]
68- client : Client
6970 callback : Callable [[AMQPMessage , MessageContext ], Union [None , Awaitable [None ]]]
7071 decoder : Callable [[bytes ], Any ]
7172 offset_type : OffsetType
7273 offset : int
7374 filter_input : Optional [FilterConfiguration ]
7475
7576
76- class Consumer :
77+ class Consumer ( IReliableEntity ) :
7778 def __init__ (
7879 self ,
7980 host : str ,
8081 port : int = 5552 ,
8182 * ,
82- ssl_context : Optional [ssl .SSLContext ] = None ,
83- vhost : str = "/" ,
8483 username : str ,
8584 password : str ,
85+ ssl_context : Optional [ssl .SSLContext ] = None ,
86+ vhost : str = "/" ,
8687 frame_max : int = 1 * 1024 * 1024 ,
8788 heartbeat : int = 60 ,
8889 load_balancer_mode : bool = False ,
@@ -91,7 +92,9 @@ def __init__(
9192 on_close_handler : Optional [CB_CONN [OnClosedErrorInfo ]] = None ,
9293 connection_name : str = "" ,
9394 sasl_configuration_mechanism : SlasMechanism = SlasMechanism .MechanismPlain ,
95+ recovery_strategy : RecoveryStrategy = BackOffRecoveryStrategy (),
9496 ):
97+ super ().__init__ ()
9598 self ._pool = ClientPool (
9699 host ,
97100 port ,
@@ -113,6 +116,7 @@ def __init__(
113116 if max_subscribers_by_connection > MAX_ITEM_ALLOWED :
114117 raise ValueError (f"max_subscribers_by_connection must be less than { MAX_ITEM_ALLOWED } " )
115118
119+ self ._recovery_strategy = recovery_strategy
116120 self ._default_client : Optional [Client ] = None
117121 self ._clients : dict [str , Client ] = {}
118122 self ._subscribers : dict [int , _Subscriber ] = {}
@@ -128,7 +132,7 @@ def __init__(
128132
129133 @property
130134 async def default_client (self ) -> Client :
131- if self ._default_client is None :
135+ if self ._default_client is None or not self . _default_client . is_connection_alive () :
132136 self ._default_client = await self ._create_locator_connection ()
133137 return self ._default_client
134138
@@ -141,7 +145,7 @@ async def __aexit__(self, *_: Any) -> None:
141145
142146 async def start (self ) -> None :
143147 self ._default_client = await self ._pool .get (
144- connection_closed_handler = self ._on_close_handler ,
148+ connection_closed_handler = self ._on_close_connection ,
145149 connection_name = self ._connection_name ,
146150 sasl_configuration_mechanism = self ._sasl_configuration_mechanism ,
147151 max_clients_by_connections = self ._max_subscribers_by_connection ,
@@ -175,7 +179,7 @@ async def _get_or_create_client(self, stream: str) -> Client:
175179 if self ._default_client is None :
176180 logger .debug ("_get_or_create_client(): Creating locator connection" )
177181 self ._default_client = await self ._pool .get (
178- connection_closed_handler = self ._on_close_handler ,
182+ connection_closed_handler = self ._on_close_connection ,
179183 connection_name = self ._connection_name ,
180184 max_clients_by_connections = self ._max_subscribers_by_connection ,
181185 )
@@ -185,7 +189,7 @@ async def _get_or_create_client(self, stream: str) -> Client:
185189 logger .debug ("_get_or_create_client(): Getting/Creating connection" )
186190 self ._clients [stream ] = await self ._pool .get (
187191 addr = Addr (broker .host , broker .port ),
188- connection_closed_handler = self ._on_close_handler ,
192+ connection_closed_handler = self ._on_close_connection ,
189193 connection_name = self ._connection_name ,
190194 stream = stream ,
191195 max_clients_by_connections = self ._max_subscribers_by_connection ,
@@ -254,7 +258,7 @@ async def subscribe(
254258 offset_specification = ConsumerOffsetSpecification (OffsetType .FIRST , None )
255259
256260 async with self ._lock :
257- logger .debug ("subscribe(): Create subscriber" )
261+ logger .debug ("[ subscribe] Create subscriber for stream: {}" . format ( stream ) )
258262 subscriber = await self ._create_subscriber (
259263 stream = stream ,
260264 subscriber_name = subscriber_name ,
@@ -270,7 +274,7 @@ async def subscribe(
270274 handler = partial (self ._on_deliver , subscriber = subscriber , filter_value = filter_input ),
271275 )
272276
273- logger .debug ("subscribe(): Adding handlers" )
277+ logger .debug ("[ subscribe] Adding handlers for stream: {}" . format ( stream ) )
274278 subscriber .client .add_handler (
275279 schema .Deliver ,
276280 partial (self ._on_deliver , subscriber = subscriber , filter_value = filter_input ),
@@ -299,7 +303,7 @@ async def subscribe(
299303 )
300304
301305 if filter_input is not None :
302- logger .debug ("subscribe(): Filtering scenario enabled" )
306+ logger .debug ("[ subscribe] Filtering scenario enabled for stream: {}" . format ( stream ) )
303307 await self ._check_if_filtering_is_supported ()
304308 values_to_filter = filter_input .values ()
305309 if len (values_to_filter ) <= 0 :
@@ -315,7 +319,7 @@ async def subscribe(
315319 else :
316320 properties [SUBSCRIPTION_PROPERTY_MATCH_UNFILTERED ] = "false"
317321
318- logger .debug ("subscribe(): Subscribing" )
322+ logger .debug ("[ subscribe] subscribing to stream: {}" . format ( stream ) )
319323 await subscriber .client .subscribe (
320324 stream = stream ,
321325 subscription_id = subscriber .subscription_id ,
@@ -325,7 +329,7 @@ async def subscribe(
325329 initial_credit = initial_credit ,
326330 properties = properties ,
327331 )
328-
332+ logger . debug ( "[subscribe] created for: {}, id: {}" . format ( stream , subscriber . subscription_id ))
329333 return subscriber .subscription_id
330334
331335 async def get_available_id (self ) -> int :
@@ -369,9 +373,7 @@ async def unsubscribe(self, subscriber_id: int) -> None:
369373 stream = subscriber .stream
370374
371375 if stream in self ._clients :
372- await self ._clients [stream ].remove_stream (stream )
373- await self ._clients [stream ].free_available_id ()
374- del self ._clients [stream ]
376+ await self ._remove_stream_from_client (stream )
375377
376378 del self ._subscribers [subscriber_id ]
377379
@@ -456,6 +458,48 @@ async def _on_metadata_update(self, frame: schema.MetadataUpdate) -> None:
456458 if result is not None and inspect .isawaitable (result ):
457459 await result
458460
461+ async def _on_close_connection (self , on_closed_info : OnClosedErrorInfo ) -> None :
462+ # clone on_closed_info to avoid modification during iteration
463+ new_on_closed_info = OnClosedErrorInfo (
464+ reason = on_closed_info .reason ,
465+ streams = list (on_closed_info .streams ) if on_closed_info .streams else [],
466+ )
467+
468+ if self ._on_close_handler is not None :
469+ result = self ._on_close_handler (new_on_closed_info )
470+ if result is not None and inspect .isawaitable (result ):
471+ await result
472+
473+ for stream in on_closed_info .streams .copy ():
474+ current_subscriber = await self ._get_subscriber_by_stream (stream )
475+ if current_subscriber is not None :
476+ del self ._subscribers [current_subscriber .subscription_id ]
477+ await self ._remove_stream_from_client (stream )
478+ result = self ._recovery_strategy .recover (
479+ self ,
480+ current_subscriber .stream ,
481+ error = Exception (on_closed_info .reason ),
482+ attempt = 1 ,
483+ # fmt: off
484+ recovery_fun = lambda stream_s = current_subscriber .stream ,
485+ reference = current_subscriber .reference ,
486+ callback = current_subscriber .callback ,
487+ decoder = current_subscriber .decoder ,
488+ offset = current_subscriber .offset ,
489+ filter_input = current_subscriber .filter_input :
490+ self .subscribe (
491+ stream = stream_s ,
492+ subscriber_name = reference ,
493+ callback = callback ,
494+ decoder = decoder ,
495+ offset_specification = ConsumerOffsetSpecification (OffsetType .OFFSET , offset ),
496+ filter_input = filter_input ,
497+ ),
498+ # fmt: on
499+ )
500+ if result is not None and inspect .isawaitable (result ):
501+ await result
502+
459503 async def _on_consumer_update_query_response (
460504 self ,
461505 frame : schema .ConsumerUpdateResponse ,
@@ -527,46 +571,6 @@ async def stream_exists(self, stream: str, on_close_event: bool = False) -> bool
527571
528572 return stream_exists
529573
530- async def reconnect_stream (self , stream : str , offset : Optional [int ] = None ) -> None :
531- logging .debug ("reconnect_stream" )
532- curr_subscriber = None
533- curr_subscriber_id = None
534- for subscriber_id in self ._subscribers :
535- if stream == self ._subscribers [subscriber_id ].stream :
536- curr_subscriber = self ._subscribers [subscriber_id ]
537- curr_subscriber_id = subscriber_id
538- if curr_subscriber_id is not None :
539- del self ._subscribers [curr_subscriber_id ]
540-
541- if stream in self ._clients :
542- if curr_subscriber is not None :
543- await self ._clients [stream ].free_available_id ()
544- await self ._clients [stream ].close ()
545- del self ._clients [stream ]
546-
547- if self ._default_client is not None :
548- if not self ._default_client .is_connection_alive ():
549- await self ._default_client .close ()
550- self ._default_client = None
551-
552- if offset is None :
553- if curr_subscriber is not None :
554- offset = curr_subscriber .offset
555-
556- logging .debug ("reconnect_stream(): Subscribing again" )
557- offset_specification = ConsumerOffsetSpecification (OffsetType .OFFSET , offset )
558- if curr_subscriber is not None :
559- asyncio .create_task (
560- self .subscribe (
561- stream = curr_subscriber .stream ,
562- # subscriber_name=curr_subscriber.reference,
563- callback = curr_subscriber .callback ,
564- decoder = curr_subscriber .decoder ,
565- offset_specification = offset_specification ,
566- filter_input = curr_subscriber .filter_input ,
567- )
568- )
569-
570574 async def _check_if_filtering_is_supported (self ) -> None :
571575 command_version_input = schema .FrameHandlerInfo (Key .Publish .value , min_version = 1 , max_version = 2 )
572576 server_command_version : schema .FrameHandlerInfo = await (
@@ -592,17 +596,27 @@ async def _close_locator_connection(self):
592596 await (await self .default_client ).close ()
593597 self ._default_client = None
594598
595- async def _maybe_clean_up_during_lost_connection (self , stream : str ):
596- curr_subscriber = None
597-
598- for subscriber_id in self ._subscribers :
599- if stream == self ._subscribers [subscriber_id ].stream :
600- curr_subscriber = self ._subscribers [subscriber_id ]
601-
599+ async def _remove_stream_from_client (self , stream : str ) -> None :
602600 if stream in self ._clients :
603601 await self ._clients [stream ].remove_stream (stream )
604- if curr_subscriber is not None :
605- await self ._clients [stream ].free_available_id ()
602+ await self ._clients [stream ].free_available_id ()
606603 if await self ._clients [stream ].get_stream_count () == 0 :
607604 await self ._clients [stream ].close ()
608605 del self ._clients [stream ]
606+
607+ async def _get_subscriber_by_stream (self , stream : str ) -> Optional [_Subscriber ]:
608+ for subscriber in self ._subscribers .values ():
609+ if stream == subscriber .stream :
610+ return subscriber
611+ return None
612+
613+ async def _maybe_clean_up_during_lost_connection (self , stream : str ) -> Optional [int ]:
614+ offset = None
615+ curr_subscriber = await self ._get_subscriber_by_stream (stream )
616+ if curr_subscriber is not None :
617+ offset = curr_subscriber .offset
618+
619+ if stream in self ._clients :
620+ await self ._remove_stream_from_client (stream )
621+
622+ return offset
0 commit comments