1+ import os
2+ import asyncio
3+ import uuid
4+ from datetime import datetime
5+ import logging
6+
7+ from aiomqtt import Client
8+ from pynumaflow .shared .asynciter import NonBlockingIterator
9+ from pynumaflow .sourcer import (
10+ ReadRequest ,
11+ Message ,
12+ AckRequest ,
13+ PendingResponse ,
14+ Offset ,
15+ PartitionsResponse ,
16+ get_default_partitions ,
17+ Sourcer ,
18+ SourceAsyncServer ,
19+ NackRequest ,
20+ )
21+
# Process-wide logging setup: INFO and above, one line per record with a
# second-resolution timestamp and a left-aligned, 8-character level column.
logging.basicConfig(
    format="%(asctime)s %(levelname)-8s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)

# Module-level logger shared by everything defined in this file.
logger = logging.getLogger(__name__)
28+
29+
class MQTTAsyncSource(Sourcer):
    """
    User-defined source for MQTT messages.

    A background asyncio task subscribes to ``topic`` on the configured broker
    and buffers each decoded payload in an in-memory queue. ``read_handler``
    drains that queue on demand and assigns every emitted message an offset
    taken from a counter that starts at 0 for each instance.

    The payload of every emitted message is retained until it is acknowledged,
    so a negatively acknowledged offset is re-delivered with its original
    payload rather than with whatever happens to be next in the queue.
    """

    # read_handler emits nothing while at least this many offsets are unacked.
    MAX_UNACKED = 500

    def __init__(self, broker, port, topic):
        """
        Args:
            broker: MQTT broker address to connect to.
            port: Port number of the MQTT broker.
            topic: MQTT topic to subscribe to for receiving messages.
        """
        # The offset idx till where the messages have been read
        self.read_idx: int = 0
        # Set to maintain a track of the offsets yet to be acknowledged
        self.to_ack_set: set[int] = set()
        # Set to maintain a track of the offsets that have been negatively acknowledged
        self.nacked: set[int] = set()
        # Payload of every emitted-but-unacked offset, kept for re-delivery
        self._payloads: dict[int, str] = {}
        # MQTT broker address to connect to.
        self.broker = broker
        # Port number of the MQTT broker.
        self.port = port
        # MQTT topic to subscribe to for receiving messages.
        self.topic = topic
        # Async queue to store incoming messages
        self.messages = asyncio.Queue()
        # Asyncio task for MQTT client loop
        self._mqtt_task = None
        # Flag indicating if source has started
        self._started = False

    @staticmethod
    def _decode_payload(raw) -> str:
        """Convert a raw MQTT payload to text without ever raising.

        Bytes are decoded as UTF-8 with undecodable bytes replaced, ``None``
        becomes an empty string and any other value is passed through
        ``str()``. A single malformed message therefore cannot tear down the
        broker connection.
        """
        if raw is None:
            return ""
        if isinstance(raw, (bytes, bytearray)):
            return raw.decode("utf-8", errors="replace")
        return str(raw)

    async def start_mqtt_consumer(self):
        """Start the MQTT consumer task; calling it again is a no-op."""
        if self._started:
            return
        # Set before the task is created so a second call cannot start
        # another consumer.
        self._started = True

        logger.info(f"Starting MQTT consumer for broker={self.broker} , port={self.port} , topic={self.topic} ")

        async def mqtt_loop():
            # Reconnect forever: any failure is logged and retried after a
            # fixed 5 second pause.
            while True:
                try:
                    async with Client(self.broker, self.port) as client:
                        await client.subscribe(self.topic)
                        logger.info(f"Successfully subscribed to MQTT topic: {self.topic} ")
                        async for msg in client.messages:
                            payload = self._decode_payload(msg.payload)
                            logger.info(f"Received MQTT message: {payload} ")
                            await self.messages.put(payload)
                except Exception as e:
                    logger.error(f"MQTT consumer error: {e} . Retrying in 5 seconds...")
                    await asyncio.sleep(5)

        self._mqtt_task = asyncio.create_task(mqtt_loop())

    async def read_handler(self, datum: ReadRequest, output: NonBlockingIterator):
        """
        read_handler is used to read the data from the source and send the data forward.

        The consumer task is started on the first call. For each read request
        ``datum.num_records`` messages are emitted; every emitted offset is
        added to the ack set. Negatively acknowledged offsets are re-delivered
        first, with their original payload, before new offsets are allocated.

        Nothing is emitted while ``MAX_UNACKED`` or more offsets are awaiting
        acknowledgement. The limit is checked once per request, so the ack set
        can exceed it by up to ``num_records - 1``.
        """

        if not self._started:
            await self.start_mqtt_consumer()

        if len(self.to_ack_set) >= self.MAX_UNACKED:
            return

        for _ in range(datum.num_records):
            payload = None
            if self.nacked:
                # Re-deliver a nacked offset together with the payload it
                # was originally sent with.
                idx = self.nacked.pop()
                payload = self._payloads.get(idx)
            else:
                idx = self.read_idx
                self.read_idx += 1

            if payload is None:
                try:
                    payload = self.messages.get_nowait()
                except asyncio.QueueEmpty:
                    # NOTE(review): with nothing buffered a synthetic payload
                    # is emitted instead of waiting - confirm this placeholder
                    # traffic is wanted outside of demos.
                    payload = f"dummy-{idx} "
                else:
                    logger.info(f"Sending MQTT message: {payload} ")
                self._payloads[idx] = payload

            headers = {"x-txn-id": str(uuid.uuid4())}
            await output.put(
                Message(
                    payload=str(payload).encode(),
                    offset=Offset.offset_with_default_partition_id(str(idx).encode()),
                    # NOTE(review): tz-naive local time - confirm how the SDK
                    # interprets naive datetimes.
                    event_time=datetime.now(),
                    headers=headers,
                )
            )
            self.to_ack_set.add(idx)

    async def ack_handler(self, ack_request: AckRequest):
        """
        Handle message acknowledgments.

        Each acknowledged offset is dropped from the ack set and its retained
        payload is released. Unknown or repeated offsets are ignored instead
        of raising ``KeyError``.
        """
        for req in ack_request.offsets:
            offset = int(req.offset)
            self.to_ack_set.discard(offset)
            self._payloads.pop(offset, None)

    async def nack_handler(self, ack_request: NackRequest):
        """
        Add the offsets that have been negatively acknowledged to the nacked set.

        Only offsets currently awaiting acknowledgement are moved; an unknown
        or repeated offset is logged and skipped instead of raising
        ``KeyError``. The payload is kept so the next read can re-deliver it.
        """

        for req in ack_request.offsets:
            offset = int(req.offset)
            if offset not in self.to_ack_set:
                logger.warning("Ignoring nack for offset not awaiting ack: %s", offset)
                continue
            self.to_ack_set.remove(offset)
            self.nacked.add(offset)
        logger.info("Negatively acknowledged offsets: %s", self.nacked)

    async def pending_handler(self) -> PendingResponse:
        """
        Return the number of pending messages in the queue
        (received from MQTT but not yet emitted by read_handler).
        """
        return PendingResponse(count=self.messages.qsize())

    async def partitions_handler(self) -> PartitionsResponse:
        """
        Return default partitions.
        """
        return PartitionsResponse(partitions=get_default_partitions())
146+
147+
if __name__ == "__main__":
    # Connection settings come from the environment, falling back to a local
    # broker on the standard MQTT port and the "test" topic.
    env = os.environ
    broker = env.get("MQTT_BROKER", "localhost")
    port = int(env.get("MQTT_PORT", 1883))
    topic = env.get("MQTT_TOPIC", "test")

    logger.info(f"Configuring MQTT Source: broker={broker} , port={port} , topic={topic} ")

    # Serve the source over gRPC on the Numaflow source socket.
    grpc_server = SourceAsyncServer(
        MQTTAsyncSource(broker, port, topic),
        sock_path="/var/run/numaflow/source.sock",
    )

    logger.info("Starting MQTT UDS gRPC server")
    grpc_server.start()
0 commit comments