5555
5656from six import iteritems
5757
58+ import txredisapi as redis
5859from prometheus_client import Counter
5960
6061from twisted .protocols .basic import LineOnlyReceiver
6162from twisted .python .failure import Failure
6263
64+ from synapse .logging .context import PreserveLoggingContext
6365from synapse .metrics import LaterGauge
6466from synapse .metrics .background_process_metrics import run_as_background_process
6567from synapse .replication .tcp .commands import (
@@ -420,6 +422,8 @@ class CommandHandler:
420422 def __init__ (self , hs , handler ):
421423 self .handler = handler
422424
425+ self .is_master = hs .config .worker .worker_app is None
426+
423427 self .clock = hs .get_clock ()
424428
425429 self .streams = {
@@ -458,11 +462,22 @@ def lost_connection(self, connection):
458462 self .handler .lost_connection (connection )
459463
460464 async def on_USER_SYNC (self , cmd : UserSyncCommand ):
465+ if not self .connection :
466+ raise Exception ("Not connected" )
467+
461468 await self .handler .on_user_sync (
462469 self .connection .conn_id , cmd .user_id , cmd .is_syncing , cmd .last_sync_ms
463470 )
464471
465472 async def on_REPLICATE (self , cmd : ReplicateCommand ):
473+ # We only want to announce positions by the writer of the streams.
474+ # Currently this is just the master process.
475+ if not self .is_master :
476+ return
477+
478+ if not self .connection :
479+ raise Exception ("Not connected" )
480+
466481 for stream_name , stream in self .streams .items ():
467482 current_token = stream .current_token ()
468483 self .connection .send_command (PositionCommand (stream_name , current_token ))
@@ -483,15 +498,14 @@ async def on_SYNC(self, cmd: SyncCommand):
483498 self .handler .on_sync (cmd .data )
484499
485500 async def on_RDATA (self , cmd : RdataCommand ):
501+
486502 stream_name = cmd .stream_name
487503 inbound_rdata_count .labels (stream_name ).inc ()
488504
489505 try :
490506 row = STREAMS_MAP [stream_name ].parse_row (cmd .row )
491507 except Exception :
492- logger .exception (
493- "[%s] Failed to parse RDATA: %r %r" , self .id (), stream_name , cmd .row
494- )
508+ logger .exception ("[%s] Failed to parse RDATA: %r" , stream_name , cmd .row )
495509 raise
496510
497511 if cmd .token is None or stream_name in self .streams_connecting :
@@ -519,7 +533,7 @@ async def on_POSITION(self, cmd: PositionCommand):
519533 return
520534
521535 # Fetch all updates between then and now.
522- limited = True
536+ limited = cmd . token != current_token
523537 while limited :
524538 updates , current_token , limited = await stream .get_updates_since (
525539 current_token , cmd .token
@@ -582,7 +596,7 @@ def lost_connection(self, connection):
582596 raise NotImplementedError ()
583597
584598 @abc .abstractmethod
585- def on_user_sync (
599+ async def on_user_sync (
586600 self , conn_id : str , user_id : str , is_syncing : bool , last_sync_ms : int
587601 ):
588602 """A client has started/stopped syncing on a worker.
@@ -794,3 +808,112 @@ def transport_kernel_read_buffer_size(protocol, read=True):
794808inbound_rdata_count = Counter (
795809 "synapse_replication_tcp_protocol_inbound_rdata_count" , "" , ["stream_name" ]
796810)
811+
812+
813+ class RedisSubscriber (redis .SubscriberProtocol ):
814+ def connectionMade (self ):
815+ logger .info ("MADE CONNECTION" )
816+ self .subscribe (self .stream_name )
817+ self .send_command (ReplicateCommand ("ALL" ))
818+
819+ self .handler .new_connection (self )
820+
821+ def messageReceived (self , pattern , channel , message ):
822+ if message .strip () == "" :
823+ # Ignore blank lines
824+ return
825+
826+ line = message
827+ cmd_name , rest_of_line = line .split (" " , 1 )
828+
829+ cmd_cls = COMMAND_MAP [cmd_name ]
830+ try :
831+ cmd = cmd_cls .from_line (rest_of_line )
832+ except Exception as e :
833+ logger .exception (
834+ "[%s] failed to parse line %r: %r" , self .id (), cmd_name , rest_of_line
835+ )
836+ self .send_error (
837+ "failed to parse line for %r: %r (%r):" % (cmd_name , e , rest_of_line )
838+ )
839+ return
840+
841+ # Now lets try and call on_<CMD_NAME> function
842+ run_as_background_process (
843+ "replication-" + cmd .get_logcontext_id (), self .handle_command , cmd
844+ )
845+
846+ async def handle_command (self , cmd : Command ):
847+ """Handle a command we have received over the replication stream.
848+
849+ By default delegates to on_<COMMAND>, which should return an awaitable.
850+
851+ Args:
852+ cmd: received command
853+ """
854+ # First call any command handlers on this instance. These are for TCP
855+ # specific handling.
856+ cmd_func = getattr (self , "on_%s" % (cmd .NAME ,), None )
857+ if cmd_func :
858+ await cmd_func (cmd )
859+
860+ # Then call out to the handler.
861+ cmd_func = getattr (self .handler , "on_%s" % (cmd .NAME ,), None )
862+ if cmd_func :
863+ await cmd_func (cmd )
864+
865+ def connectionLost (self , reason ):
866+ logger .info ("LOST CONNECTION" )
867+ self .handler .lost_connection (self )
868+
869+ def send_command (self , cmd ):
870+ """Send a command if connection has been established.
871+
872+ Args:
873+ cmd (Command)
874+ """
875+ string = "%s %s" % (cmd .NAME , cmd .to_line ())
876+ if "\n " in string :
877+ raise Exception ("Unexpected newline in command: %r" , string )
878+
879+ encoded_string = string .encode ("utf-8" )
880+
881+ async def _send ():
882+ with PreserveLoggingContext ():
883+ await self .redis_connection .publish (self .stream_name , encoded_string )
884+
885+ run_as_background_process ("send-cmd" , _send )
886+
887+ def stream_update (self , stream_name , token , data ):
888+ """Called when a new update is available to stream to clients.
889+
890+ We need to check if the client is interested in the stream or not
891+ """
892+ self .send_command (RdataCommand (stream_name , token , data ))
893+
894+ def send_sync (self , data ):
895+ self .send_command (SyncCommand (data ))
896+
897+ def send_remote_server_up (self , server : str ):
898+ self .send_command (RemoteServerUpCommand (server ))
899+
900+
901+ class RedisFactory (redis .SubscriberFactory ):
902+
903+ maxDelay = 5
904+ continueTrying = True
905+ protocol = RedisSubscriber
906+
907+ def __init__ (self , hs , handler ):
908+ super (RedisFactory , self ).__init__ ()
909+
910+ self .handler = CommandHandler (hs , handler )
911+ self .stream_name = hs .hostname
912+
913+ def buildProtocol (self , addr ):
914+ p = super (RedisFactory , self ).buildProtocol (addr )
915+ p .handler = self .handler
916+ p .redis_connection = redis .lazyConnection ("redis" )
917+ p .conn_id = random_string (5 ) # TODO: FIXME
918+ p .stream_name = self .stream_name
919+ return p
0 commit comments