@@ -109,18 +109,14 @@ def test_normal_user_pair(self) -> None:
109109 tok = alice_token ,
110110 )
111111
112- users = self .get_success (self .user_dir_helper .get_users_in_user_directory ())
113- in_public = self .get_success (self .user_dir_helper .get_users_in_public_rooms ())
114- in_private = self .get_success (
115- self .user_dir_helper .get_users_who_share_private_rooms ()
112+ # The user directory should reflect the room memberships above.
113+ users , in_public , in_private = self .get_success (
114+ self .user_dir_helper .get_tables ()
116115 )
117-
118116 self .assertEqual (users , {alice , bob })
117+ self .assertEqual (in_public , {(alice , public ), (bob , public ), (alice , public2 )})
119118 self .assertEqual (
120- set (in_public ), {(alice , public ), (bob , public ), (alice , public2 )}
121- )
122- self .assertEqual (
123- self .user_dir_helper ._compress_shared (in_private ),
119+ in_private ,
124120 {(alice , bob , private ), (bob , alice , private )},
125121 )
126122
@@ -209,6 +205,88 @@ def test_user_not_in_users_table(self) -> None:
209205 in_public = self .get_success (self .user_dir_helper .get_users_in_public_rooms ())
210206 self .assertEqual (set (in_public ), {(user1 , room ), (user2 , room )})
211207
208+ def test_excludes_users_when_making_room_public (self ) -> None :
209+ # Create a regular user and a support user.
210+ alice = self .register_user ("alice" , "pass" )
211+ alice_token = self .login (alice , "pass" )
212+ support = "@support1:test"
213+ self .get_success (
214+ self .store .register_user (
215+ user_id = support , password_hash = None , user_type = UserTypes .SUPPORT
216+ )
217+ )
218+
219+ # Make a public and private room containing Alice and the support user
220+ public , initially_private = self ._create_rooms_and_inject_memberships (
221+ alice , alice_token , support
222+ )
223+ self ._check_only_one_user_in_directory (alice , public )
224+
225+ # Alice makes the private room public.
226+ self .helper .send_state (
227+ initially_private ,
228+ "m.room.join_rules" ,
229+ {"join_rule" : "public" },
230+ tok = alice_token ,
231+ )
232+
233+ users , in_public , in_private = self .get_success (
234+ self .user_dir_helper .get_tables ()
235+ )
236+ self .assertEqual (users , {alice })
237+ self .assertEqual (in_public , {(alice , public ), (alice , initially_private )})
238+ self .assertEqual (in_private , set ())
239+
240+ def test_switching_from_private_to_public_to_private (self ) -> None :
241+ """Check we update the room sharing tables when switching a room
242+ from private to public, then back again to private."""
243+ # Alice and Bob share a private room.
244+ alice = self .register_user ("alice" , "pass" )
245+ alice_token = self .login (alice , "pass" )
246+ bob = self .register_user ("bob" , "pass" )
247+ bob_token = self .login (bob , "pass" )
248+ room = self .helper .create_room_as (alice , is_public = False , tok = alice_token )
249+ self .helper .invite (room , alice , bob , tok = alice_token )
250+ self .helper .join (room , bob , tok = bob_token )
251+
252+ # The user directory should reflect this.
253+ def check_user_dir_for_private_room () -> None :
254+ users , in_public , in_private = self .get_success (
255+ self .user_dir_helper .get_tables ()
256+ )
257+ self .assertEqual (users , {alice , bob })
258+ self .assertEqual (in_public , set ())
259+ self .assertEqual (in_private , {(alice , bob , room ), (bob , alice , room )})
260+
261+ check_user_dir_for_private_room ()
262+
263+ # Alice makes the room public.
264+ self .helper .send_state (
265+ room ,
266+ "m.room.join_rules" ,
267+ {"join_rule" : "public" },
268+ tok = alice_token ,
269+ )
270+
271+ # The user directory should be updated accordingly
272+ users , in_public , in_private = self .get_success (
273+ self .user_dir_helper .get_tables ()
274+ )
275+ self .assertEqual (users , {alice , bob })
276+ self .assertEqual (in_public , {(alice , room ), (bob , room )})
277+ self .assertEqual (in_private , set ())
278+
279+ # Alice makes the room private.
280+ self .helper .send_state (
281+ room ,
282+ "m.room.join_rules" ,
283+ {"join_rule" : "invite" },
284+ tok = alice_token ,
285+ )
286+
287+ # The user directory should be updated accordingly
288+ check_user_dir_for_private_room ()
289+
212290 def _create_rooms_and_inject_memberships (
213291 self , creator : str , token : str , joiner : str
214292 ) -> Tuple [str , str ]:
@@ -232,15 +310,18 @@ def _create_rooms_and_inject_memberships(
232310 return public_room , private_room
233311
234312 def _check_only_one_user_in_directory (self , user : str , public : str ) -> None :
235- users = self .get_success (self .user_dir_helper .get_users_in_user_directory ())
236- in_public = self .get_success (self .user_dir_helper .get_users_in_public_rooms ())
237- in_private = self .get_success (
238- self .user_dir_helper .get_users_who_share_private_rooms ()
239- )
313+ """Check that the user directory DB tables show that:
240314
315+ - only one user is in the user directory
316+ - they belong to exactly one public room
317+ - they don't share a private room with anyone.
318+ """
319+ users , in_public , in_private = self .get_success (
320+ self .user_dir_helper .get_tables ()
321+ )
241322 self .assertEqual (users , {user })
242- self .assertEqual (set ( in_public ) , {(user , public )})
243- self .assertEqual (in_private , [] )
323+ self .assertEqual (in_public , {(user , public )})
324+ self .assertEqual (in_private , set () )
244325
245326 def test_handle_local_profile_change_with_support_user (self ) -> None :
246327 support_user_id = "@support:test"
@@ -581,11 +662,8 @@ def test_private_room(self) -> None:
581662 self .user_dir_helper .get_users_in_public_rooms ()
582663 )
583664
584- self .assertEqual (
585- self .user_dir_helper ._compress_shared (shares_private ),
586- {(u1 , u2 , room ), (u2 , u1 , room )},
587- )
588- self .assertEqual (public_users , [])
665+ self .assertEqual (shares_private , {(u1 , u2 , room ), (u2 , u1 , room )})
666+ self .assertEqual (public_users , set ())
589667
590668 # We get one search result when searching for user2 by user1.
591669 s = self .get_success (self .handler .search_users (u1 , "user2" , 10 ))
@@ -610,8 +688,8 @@ def test_private_room(self) -> None:
610688 self .user_dir_helper .get_users_in_public_rooms ()
611689 )
612690
613- self .assertEqual (self . user_dir_helper . _compress_shared ( shares_private ) , set ())
614- self .assertEqual (public_users , [] )
691+ self .assertEqual (shares_private , set ())
692+ self .assertEqual (public_users , set () )
615693
616694 # User1 now gets no search results for any of the other users.
617695 s = self .get_success (self .handler .search_users (u1 , "user2" , 10 ))
@@ -645,11 +723,8 @@ def test_spam_checker(self) -> None:
645723 self .user_dir_helper .get_users_in_public_rooms ()
646724 )
647725
648- self .assertEqual (
649- self .user_dir_helper ._compress_shared (shares_private ),
650- {(u1 , u2 , room ), (u2 , u1 , room )},
651- )
652- self .assertEqual (public_users , [])
726+ self .assertEqual (shares_private , {(u1 , u2 , room ), (u2 , u1 , room )})
727+ self .assertEqual (public_users , set ())
653728
654729 # We get one search result when searching for user2 by user1.
655730 s = self .get_success (self .handler .search_users (u1 , "user2" , 10 ))
@@ -704,11 +779,8 @@ def test_legacy_spam_checker(self) -> None:
704779 self .user_dir_helper .get_users_in_public_rooms ()
705780 )
706781
707- self .assertEqual (
708- self .user_dir_helper ._compress_shared (shares_private ),
709- {(u1 , u2 , room ), (u2 , u1 , room )},
710- )
711- self .assertEqual (public_users , [])
782+ self .assertEqual (shares_private , {(u1 , u2 , room ), (u2 , u1 , room )})
783+ self .assertEqual (public_users , set ())
712784
713785 # Configure a spam checker.
714786 spam_checker = self .hs .get_spam_checker ()
@@ -740,8 +812,8 @@ def test_initial_share_all_users(self) -> None:
740812 )
741813
742814 # No users share rooms
743- self .assertEqual (public_users , [] )
744- self .assertEqual (self . user_dir_helper . _compress_shared ( shares_private ) , set ())
815+ self .assertEqual (public_users , set () )
816+ self .assertEqual (shares_private , set ())
745817
746818 # Despite not sharing a room, search_all_users means we get a search
747819 # result.
0 commit comments