Skip to content

Commit 77055db

Browse files
authored
Fix tests on postgresql (#3740)
1 parent 567363e commit 77055db

18 files changed

+356
-341
lines changed

.travis.yml

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,6 @@ matrix:
3535
- python: 3.6
3636
env: TOX_ENV=check-newsfragment
3737

38-
allow_failures:
39-
- python: 2.7
40-
env: TOX_ENV=py27-postgres TRIAL_FLAGS="-j 4"
41-
4238
install:
4339
- pip install tox
4440

changelog.d/3740.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
The test suite now passes on PostgreSQL.

tests/handlers/test_device.py

Lines changed: 75 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# -*- coding: utf-8 -*-
22
# Copyright 2016 OpenMarket Ltd
3+
# Copyright 2018 New Vector Ltd
34
#
45
# Licensed under the Apache License, Version 2.0 (the "License");
56
# you may not use this file except in compliance with the License.
@@ -13,79 +14,79 @@
1314
# See the License for the specific language governing permissions and
1415
# limitations under the License.
1516

16-
from twisted.internet import defer
17-
1817
import synapse.api.errors
1918
import synapse.handlers.device
2019
import synapse.storage
2120

22-
from tests import unittest, utils
21+
from tests import unittest
2322

2423
user1 = "@boris:aaa"
2524
user2 = "@theresa:bbb"
2625

2726

28-
class DeviceTestCase(unittest.TestCase):
29-
def __init__(self, *args, **kwargs):
30-
super(DeviceTestCase, self).__init__(*args, **kwargs)
31-
self.store = None # type: synapse.storage.DataStore
32-
self.handler = None # type: synapse.handlers.device.DeviceHandler
33-
self.clock = None # type: utils.MockClock
34-
35-
@defer.inlineCallbacks
36-
def setUp(self):
37-
hs = yield utils.setup_test_homeserver(self.addCleanup)
27+
class DeviceTestCase(unittest.HomeserverTestCase):
28+
def make_homeserver(self, reactor, clock):
29+
hs = self.setup_test_homeserver("server", http_client=None)
3830
self.handler = hs.get_device_handler()
3931
self.store = hs.get_datastore()
40-
self.clock = hs.get_clock()
32+
return hs
33+
34+
def prepare(self, reactor, clock, hs):
35+
# These tests assume that it starts 1000 seconds in.
36+
self.reactor.advance(1000)
4137

42-
@defer.inlineCallbacks
4338
def test_device_is_created_if_doesnt_exist(self):
44-
res = yield self.handler.check_device_registered(
45-
user_id="@boris:foo",
46-
device_id="fco",
47-
initial_device_display_name="display name",
39+
res = self.get_success(
40+
self.handler.check_device_registered(
41+
user_id="@boris:foo",
42+
device_id="fco",
43+
initial_device_display_name="display name",
44+
)
4845
)
4946
self.assertEqual(res, "fco")
5047

51-
dev = yield self.handler.store.get_device("@boris:foo", "fco")
48+
dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco"))
5249
self.assertEqual(dev["display_name"], "display name")
5350

54-
@defer.inlineCallbacks
5551
def test_device_is_preserved_if_exists(self):
56-
res1 = yield self.handler.check_device_registered(
57-
user_id="@boris:foo",
58-
device_id="fco",
59-
initial_device_display_name="display name",
52+
res1 = self.get_success(
53+
self.handler.check_device_registered(
54+
user_id="@boris:foo",
55+
device_id="fco",
56+
initial_device_display_name="display name",
57+
)
6058
)
6159
self.assertEqual(res1, "fco")
6260

63-
res2 = yield self.handler.check_device_registered(
64-
user_id="@boris:foo",
65-
device_id="fco",
66-
initial_device_display_name="new display name",
61+
res2 = self.get_success(
62+
self.handler.check_device_registered(
63+
user_id="@boris:foo",
64+
device_id="fco",
65+
initial_device_display_name="new display name",
66+
)
6767
)
6868
self.assertEqual(res2, "fco")
6969

70-
dev = yield self.handler.store.get_device("@boris:foo", "fco")
70+
dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco"))
7171
self.assertEqual(dev["display_name"], "display name")
7272

73-
@defer.inlineCallbacks
7473
def test_device_id_is_made_up_if_unspecified(self):
75-
device_id = yield self.handler.check_device_registered(
76-
user_id="@theresa:foo",
77-
device_id=None,
78-
initial_device_display_name="display",
74+
device_id = self.get_success(
75+
self.handler.check_device_registered(
76+
user_id="@theresa:foo",
77+
device_id=None,
78+
initial_device_display_name="display",
79+
)
7980
)
8081

81-
dev = yield self.handler.store.get_device("@theresa:foo", device_id)
82+
dev = self.get_success(self.handler.store.get_device("@theresa:foo", device_id))
8283
self.assertEqual(dev["display_name"], "display")
8384

84-
@defer.inlineCallbacks
8585
def test_get_devices_by_user(self):
86-
yield self._record_users()
86+
self._record_users()
87+
88+
res = self.get_success(self.handler.get_devices_by_user(user1))
8789

88-
res = yield self.handler.get_devices_by_user(user1)
8990
self.assertEqual(3, len(res))
9091
device_map = {d["device_id"]: d for d in res}
9192
self.assertDictContainsSubset(
@@ -119,11 +120,10 @@ def test_get_devices_by_user(self):
119120
device_map["abc"],
120121
)
121122

122-
@defer.inlineCallbacks
123123
def test_get_device(self):
124-
yield self._record_users()
124+
self._record_users()
125125

126-
res = yield self.handler.get_device(user1, "abc")
126+
res = self.get_success(self.handler.get_device(user1, "abc"))
127127
self.assertDictContainsSubset(
128128
{
129129
"user_id": user1,
@@ -135,59 +135,66 @@ def test_get_device(self):
135135
res,
136136
)
137137

138-
@defer.inlineCallbacks
139138
def test_delete_device(self):
140-
yield self._record_users()
139+
self._record_users()
141140

142141
# delete the device
143-
yield self.handler.delete_device(user1, "abc")
142+
self.get_success(self.handler.delete_device(user1, "abc"))
144143

145144
# check the device was deleted
146-
with self.assertRaises(synapse.api.errors.NotFoundError):
147-
yield self.handler.get_device(user1, "abc")
145+
res = self.handler.get_device(user1, "abc")
146+
self.pump()
147+
self.assertIsInstance(
148+
self.failureResultOf(res).value, synapse.api.errors.NotFoundError
149+
)
148150

149151
# we'd like to check the access token was invalidated, but that's a
150152
# bit of a PITA.
151153

152-
@defer.inlineCallbacks
153154
def test_update_device(self):
154-
yield self._record_users()
155+
self._record_users()
155156

156157
update = {"display_name": "new display"}
157-
yield self.handler.update_device(user1, "abc", update)
158+
self.get_success(self.handler.update_device(user1, "abc", update))
158159

159-
res = yield self.handler.get_device(user1, "abc")
160+
res = self.get_success(self.handler.get_device(user1, "abc"))
160161
self.assertEqual(res["display_name"], "new display")
161162

162-
@defer.inlineCallbacks
163163
def test_update_unknown_device(self):
164164
update = {"display_name": "new_display"}
165-
with self.assertRaises(synapse.api.errors.NotFoundError):
166-
yield self.handler.update_device("user_id", "unknown_device_id", update)
165+
res = self.handler.update_device("user_id", "unknown_device_id", update)
166+
self.pump()
167+
self.assertIsInstance(
168+
self.failureResultOf(res).value, synapse.api.errors.NotFoundError
169+
)
167170

168-
@defer.inlineCallbacks
169171
def _record_users(self):
170172
# check this works for both devices which have a recorded client_ip,
171173
# and those which don't.
172-
yield self._record_user(user1, "xyz", "display 0")
173-
yield self._record_user(user1, "fco", "display 1", "token1", "ip1")
174-
yield self._record_user(user1, "abc", "display 2", "token2", "ip2")
175-
yield self._record_user(user1, "abc", "display 2", "token3", "ip3")
174+
self._record_user(user1, "xyz", "display 0")
175+
self._record_user(user1, "fco", "display 1", "token1", "ip1")
176+
self._record_user(user1, "abc", "display 2", "token2", "ip2")
177+
self._record_user(user1, "abc", "display 2", "token3", "ip3")
178+
179+
self._record_user(user2, "def", "dispkay", "token4", "ip4")
176180

177-
yield self._record_user(user2, "def", "dispkay", "token4", "ip4")
181+
self.reactor.advance(10000)
178182

179-
@defer.inlineCallbacks
180183
def _record_user(
181184
self, user_id, device_id, display_name, access_token=None, ip=None
182185
):
183-
device_id = yield self.handler.check_device_registered(
184-
user_id=user_id,
185-
device_id=device_id,
186-
initial_device_display_name=display_name,
186+
device_id = self.get_success(
187+
self.handler.check_device_registered(
188+
user_id=user_id,
189+
device_id=device_id,
190+
initial_device_display_name=display_name,
191+
)
187192
)
188193

189194
if ip is not None:
190-
yield self.store.insert_client_ip(
191-
user_id, access_token, ip, "user_agent", device_id
195+
self.get_success(
196+
self.store.insert_client_ip(
197+
user_id, access_token, ip, "user_agent", device_id
198+
)
192199
)
193-
self.clock.advance_time(1000)
200+
self.reactor.advance(1000)

tests/replication/slave/storage/_base.py

Lines changed: 49 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# Copyright 2016 OpenMarket Ltd
2+
# Copyright 2018 New Vector Ltd
23
#
34
# Licensed under the Apache License, Version 2.0 (the "License");
45
# you may not use this file except in compliance with the License.
@@ -11,89 +12,91 @@
1112
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1213
# See the License for the specific language governing permissions and
1314
# limitations under the License.
14-
import tempfile
1515

1616
from mock import Mock, NonCallableMock
1717

18-
from twisted.internet import defer, reactor
19-
from twisted.internet.defer import Deferred
18+
import attr
2019

2120
from synapse.replication.tcp.client import (
2221
ReplicationClientFactory,
2322
ReplicationClientHandler,
2423
)
2524
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
26-
from synapse.util.logcontext import PreserveLoggingContext, make_deferred_yieldable
2725

2826
from tests import unittest
29-
from tests.utils import setup_test_homeserver
3027

3128

32-
class TestReplicationClientHandler(ReplicationClientHandler):
33-
"""Overrides on_rdata so that we can wait for it to happen"""
29+
class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
30+
def make_homeserver(self, reactor, clock):
3431

35-
def __init__(self, store):
36-
super(TestReplicationClientHandler, self).__init__(store)
37-
self._rdata_awaiters = []
38-
39-
def await_replication(self):
40-
d = Deferred()
41-
self._rdata_awaiters.append(d)
42-
return make_deferred_yieldable(d)
43-
44-
def on_rdata(self, stream_name, token, rows):
45-
awaiters = self._rdata_awaiters
46-
self._rdata_awaiters = []
47-
super(TestReplicationClientHandler, self).on_rdata(stream_name, token, rows)
48-
with PreserveLoggingContext():
49-
for a in awaiters:
50-
a.callback(None)
51-
52-
53-
class BaseSlavedStoreTestCase(unittest.TestCase):
54-
@defer.inlineCallbacks
55-
def setUp(self):
56-
self.hs = yield setup_test_homeserver(
57-
self.addCleanup,
32+
hs = self.setup_test_homeserver(
5833
"blue",
59-
http_client=None,
6034
federation_client=Mock(),
6135
ratelimiter=NonCallableMock(spec_set=["send_message"]),
6236
)
63-
self.hs.get_ratelimiter().send_message.return_value = (True, 0)
37+
38+
hs.get_ratelimiter().send_message.return_value = (True, 0)
39+
40+
return hs
41+
42+
def prepare(self, reactor, clock, hs):
6443

6544
self.master_store = self.hs.get_datastore()
6645
self.slaved_store = self.STORE_TYPE(self.hs.get_db_conn(), self.hs)
6746
self.event_id = 0
6847

6948
server_factory = ReplicationStreamProtocolFactory(self.hs)
70-
# XXX: mktemp is unsafe and should never be used. but we're just a test.
71-
path = tempfile.mktemp(prefix="base_slaved_store_test_case_socket")
72-
listener = reactor.listenUNIX(path, server_factory)
73-
self.addCleanup(listener.stopListening)
7449
self.streamer = server_factory.streamer
7550

76-
self.replication_handler = TestReplicationClientHandler(self.slaved_store)
51+
self.replication_handler = ReplicationClientHandler(self.slaved_store)
7752
client_factory = ReplicationClientFactory(
7853
self.hs, "client_name", self.replication_handler
7954
)
80-
client_connector = reactor.connectUNIX(path, client_factory)
81-
self.addCleanup(client_factory.stopTrying)
82-
self.addCleanup(client_connector.disconnect)
55+
56+
server = server_factory.buildProtocol(None)
57+
client = client_factory.buildProtocol(None)
58+
59+
@attr.s
60+
class FakeTransport(object):
61+
62+
other = attr.ib()
63+
disconnecting = False
64+
buffer = attr.ib(default=b'')
65+
66+
def registerProducer(self, producer, streaming):
67+
68+
self.producer = producer
69+
70+
def _produce():
71+
self.producer.resumeProducing()
72+
reactor.callLater(0.1, _produce)
73+
74+
reactor.callLater(0.0, _produce)
75+
76+
def write(self, byt):
77+
self.buffer = self.buffer + byt
78+
79+
if getattr(self.other, "transport") is not None:
80+
self.other.dataReceived(self.buffer)
81+
self.buffer = b""
82+
83+
def writeSequence(self, seq):
84+
for x in seq:
85+
self.write(x)
86+
87+
client.makeConnection(FakeTransport(server))
88+
server.makeConnection(FakeTransport(client))
8389

8490
def replicate(self):
8591
"""Tell the master side of replication that something has happened, and then
8692
wait for the replication to occur.
8793
"""
88-
# xxx: should we be more specific in what we wait for?
89-
d = self.replication_handler.await_replication()
9094
self.streamer.on_notifier_poke()
91-
return d
95+
self.pump(0.1)
9296

93-
@defer.inlineCallbacks
9497
def check(self, method, args, expected_result=None):
95-
master_result = yield getattr(self.master_store, method)(*args)
96-
slaved_result = yield getattr(self.slaved_store, method)(*args)
98+
master_result = self.get_success(getattr(self.master_store, method)(*args))
99+
slaved_result = self.get_success(getattr(self.slaved_store, method)(*args))
97100
if expected_result is not None:
98101
self.assertEqual(master_result, expected_result)
99102
self.assertEqual(slaved_result, expected_result)

0 commit comments

Comments
 (0)