1212# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313# See the License for the specific language governing permissions and
1414# limitations under the License.
15+ import logging
16+ from typing import Optional
1517
1618from mock import Mock
1719
20+ import attr
21+
22+ from twisted .internet .interfaces import IConsumer , IPullProducer , IReactorTime
23+ from twisted .internet .task import LoopingCall
24+ from twisted .web .http import HTTPChannel
25+
26+ from synapse .app .generic_worker import GenericWorkerServer
27+ from synapse .http .site import SynapseRequest
28+ from synapse .replication .tcp .client import ReplicationDataHandler
1829from synapse .replication .tcp .handler import ReplicationCommandHandler
1930from synapse .replication .tcp .protocol import ClientReplicationStreamProtocol
2031from synapse .replication .tcp .resource import ReplicationStreamProtocolFactory
32+ from synapse .util import Clock
2133
2234from tests import unittest
2335from tests .server import FakeTransport
2436
37+ logger = logging .getLogger (__name__ )
38+
2539
2640class BaseStreamTestCase (unittest .HomeserverTestCase ):
2741 """Base class for tests of the replication streams"""
2842
29- def make_homeserver (self , reactor , clock ):
30- self .test_handler = Mock (wraps = TestReplicationDataHandler ())
31- return self .setup_test_homeserver (replication_data_handler = self .test_handler )
32-
3343 def prepare (self , reactor , clock , hs ):
3444 # build a replication server
3545 server_factory = ReplicationStreamProtocolFactory (hs )
3646 self .streamer = hs .get_replication_streamer ()
3747 self .server = server_factory .buildProtocol (None )
3848
39- repl_handler = ReplicationCommandHandler (hs )
40- repl_handler .handler = self .test_handler
49+ # Make a new HomeServer object for the worker
50+ config = self .default_config ()
51+ config ["worker_app" ] = "synapse.app.generic_worker"
52+ config ["worker_replication_host" ] = "testserv"
53+ config ["worker_replication_http_port" ] = "8765"
54+
55+ self .reactor .lookups ["testserv" ] = "1.2.3.4"
56+
57+ self .worker_hs = self .setup_test_homeserver (
58+ http_client = None ,
59+ homeserverToUse = GenericWorkerServer ,
60+ config = config ,
61+ reactor = self .reactor ,
62+ )
63+
64+ # Since we use sqlite in memory databases we need to make sure the
65+ # databases objects are the same.
66+ self .worker_hs .get_datastore ().db = hs .get_datastore ().db
67+
68+ self .test_handler = Mock (
69+ wraps = TestReplicationDataHandler (self .worker_hs .get_datastore ())
70+ )
71+ self .worker_hs .replication_data_handler = self .test_handler
72+
73+ repl_handler = ReplicationCommandHandler (self .worker_hs )
4174 self .client = ClientReplicationStreamProtocol (
42- hs , "client" , "test" , clock , repl_handler ,
75+ self . worker_hs , "client" , "test" , clock , repl_handler ,
4376 )
4477
4578 self ._client_transport = None
@@ -74,11 +107,75 @@ def replicate(self):
74107 self .streamer .on_notifier_poke ()
75108 self .pump (0.1 )
76109
110+ def handle_http_replication_attempt (self ) -> SynapseRequest :
111+ """Asserts that a connection attempt was made to the master HS on the
112+ HTTP replication port, then proxies it to the master HS object to be
113+ handled.
114+
115+ Returns:
116+ The request object received by master HS.
117+ """
118+
119+ # We should have an outbound connection attempt.
120+ clients = self .reactor .tcpClients
121+ self .assertEqual (len (clients ), 1 )
122+ (host , port , client_factory , _timeout , _bindAddress ) = clients .pop (0 )
123+ self .assertEqual (host , "1.2.3.4" )
124+ self .assertEqual (port , 8765 )
125+
126+ # Set up client side protocol
127+ client_protocol = client_factory .buildProtocol (None )
128+
129+ request_factory = OneShotRequestFactory ()
77130
78- class TestReplicationDataHandler :
131+ # Set up the server side protocol
132+ channel = _PushHTTPChannel (self .reactor )
133+ channel .requestFactory = request_factory
134+ channel .site = self .site
135+
136+ # Connect client to server and vice versa.
137+ client_to_server_transport = FakeTransport (
138+ channel , self .reactor , client_protocol
139+ )
140+ client_protocol .makeConnection (client_to_server_transport )
141+
142+ server_to_client_transport = FakeTransport (
143+ client_protocol , self .reactor , channel
144+ )
145+ channel .makeConnection (server_to_client_transport )
146+
147+ # The request will now be processed by `self.site` and the response
148+ # streamed back.
149+ self .reactor .advance (0 )
150+
151+ # We tear down the connection so it doesn't get reused without our
152+ # knowledge.
153+ server_to_client_transport .loseConnection ()
154+ client_to_server_transport .loseConnection ()
155+
156+ return request_factory .request
157+
158+ def assert_request_is_get_repl_stream_updates (
159+ self , request : SynapseRequest , stream_name : str
160+ ):
161+ """Asserts that the given request is a HTTP replication request for
162+ fetching updates for given stream.
163+ """
164+
165+ self .assertRegex (
166+ request .path ,
167+ br"^/_synapse/replication/get_repl_stream_updates/%s/[^/]+$"
168+ % (stream_name .encode ("ascii" ),),
169+ )
170+
171+ self .assertEqual (request .method , b"GET" )
172+
173+
174+ class TestReplicationDataHandler (ReplicationDataHandler ):
79175 """Drop-in for ReplicationDataHandler which just collects RDATA rows"""
80176
81- def __init__ (self ):
177+ def __init__ (self , hs ):
178+ super ().__init__ (hs )
82179 self .streams = set ()
83180 self ._received_rdata_rows = []
84181
@@ -90,8 +187,118 @@ def get_streams_to_replicate(self):
90187 return positions
91188
92189 async def on_rdata (self , stream_name , token , rows ):
190+ await super ().on_rdata (stream_name , token , rows )
93191 for r in rows :
94192 self ._received_rdata_rows .append ((stream_name , token , r ))
95193
96- async def on_position (self , stream_name , token ):
97- pass
194+
195+ @attr .s ()
196+ class OneShotRequestFactory :
197+ """A simple request factory that generates a single `SynapseRequest` and
198+ stores it for future use. Can only be used once.
199+ """
200+
201+ request = attr .ib (default = None )
202+
203+ def __call__ (self , * args , ** kwargs ):
204+ assert self .request is None
205+
206+ self .request = SynapseRequest (* args , ** kwargs )
207+ return self .request
208+
209+
210+ class _PushHTTPChannel (HTTPChannel ):
211+ """A HTTPChannel that wraps pull producers to push producers.
212+
213+ This is a hack to get around the fact that HTTPChannel transparently wraps a
214+ pull producer (which is what Synapse uses to reply to requests) with
215+ `_PullToPush` to convert it to a push producer. Unfortunately `_PullToPush`
216+ uses the standard reactor rather than letting us use our test reactor, which
217+ makes it very hard to test.
218+ """
219+
220+ def __init__ (self , reactor : IReactorTime ):
221+ super ().__init__ ()
222+ self .reactor = reactor
223+
224+ self ._pull_to_push_producer = None
225+
226+ def registerProducer (self , producer , streaming ):
227+ # Convert pull producers to push producer.
228+ if not streaming :
229+ self ._pull_to_push_producer = _PullToPushProducer (
230+ self .reactor , producer , self
231+ )
232+ producer = self ._pull_to_push_producer
233+
234+ super ().registerProducer (producer , True )
235+
236+ def unregisterProducer (self ):
237+ if self ._pull_to_push_producer :
238+ # We need to manually stop the _PullToPushProducer.
239+ self ._pull_to_push_producer .stop ()
240+
241+
242+ class _PullToPushProducer :
243+ """A push producer that wraps a pull producer.
244+ """
245+
246+ def __init__ (
247+ self , reactor : IReactorTime , producer : IPullProducer , consumer : IConsumer
248+ ):
249+ self ._clock = Clock (reactor )
250+ self ._producer = producer
251+ self ._consumer = consumer
252+
253+ # While running we use a looping call with a zero delay to call
254+ # resumeProducing on given producer.
255+ self ._looping_call = None # type: Optional[LoopingCall]
256+
257+ # We start writing next reactor tick.
258+ self ._start_loop ()
259+
260+ def _start_loop (self ):
261+ """Start the looping call to
262+ """
263+
264+ if not self ._looping_call :
265+ # Start a looping call which runs every tick.
266+ self ._looping_call = self ._clock .looping_call (self ._run_once , 0 )
267+
268+ def stop (self ):
269+ """Stops calling resumeProducing.
270+ """
271+ if self ._looping_call :
272+ self ._looping_call .stop ()
273+ self ._looping_call = None
274+
275+ def pauseProducing (self ):
276+ """Implements IPushProducer
277+ """
278+ self .stop ()
279+
280+ def resumeProducing (self ):
281+ """Implements IPushProducer
282+ """
283+ self ._start_loop ()
284+
285+ def stopProducing (self ):
286+ """Implements IPushProducer
287+ """
288+ self .stop ()
289+ self ._producer .stopProducing ()
290+
291+ def _run_once (self ):
292+ """Calls resumeProducing on producer once.
293+ """
294+
295+ try :
296+ self ._producer .resumeProducing ()
297+ except Exception :
298+ logger .exception ("Failed to call resumeProducing" )
299+ try :
300+ self ._consumer .unregisterProducer ()
301+ except Exception :
302+ pass
303+
304+ self .stopProducing ()
0 commit comments