diff --git a/src/e2ee/CryptoClient.ts b/src/e2ee/CryptoClient.ts index 0adcc4b0..b5d69860 100644 --- a/src/e2ee/CryptoClient.ts +++ b/src/e2ee/CryptoClient.ts @@ -123,6 +123,10 @@ export class CryptoClient { */ public async onRoomJoin(roomId: string) { await this.roomTracker.onRoomJoin(roomId); + if (await this.isRoomEncrypted(roomId)) { + const members = await this.client.getRoomMembers(roomId, null, ['join', 'invite']); + await this.engine.addTrackedUsers(members.map(e => e.membershipFor)); + } } /** diff --git a/test/encryption/CryptoClientTest.ts b/test/encryption/CryptoClientTest.ts index 02ed769e..5f9a3d2d 100644 --- a/test/encryption/CryptoClientTest.ts +++ b/test/encryption/CryptoClientTest.ts @@ -555,5 +555,45 @@ describe('CryptoClient', () => { // end up not running fast enough for our callCount checks. await Promise.all([prom1, prom2]); }); + + it('should update the tracked users when joining a new room', async () => { + // Stub the room tracker + (client.crypto as any).roomTracker.onRoomJoin = () => {}; + + const targetUserIds = ["@bob:example.org", "@charlie:example.org"]; + const prom1 = new Promise(extResolve => { + (client.crypto as any).engine.addTrackedUsers = simple.mock().callFn((uids) => { + expect(uids).toEqual(targetUserIds); + extResolve(); + return Promise.resolve(); + }); + }); + + const roomId = "!room:example.org"; + const prom2 = new Promise(extResolve => { + client.getRoomMembers = simple.mock().callFn((rid, token, memberships) => { + expect(rid).toEqual(roomId); + expect(token).toBeFalsy(); + expect(memberships).toEqual(["join", "invite"]); + extResolve(); + return Promise.resolve(targetUserIds.map(u => new MembershipEvent({ + type: "m.room.member", + state_key: u, + content: { membership: "join" }, + sender: u, + }))); + }); + }); + + client.crypto.isRoomEncrypted = async (rid) => { + expect(rid).toEqual(roomId); + return true; + }; + client.emit("room.join", roomId); + + // We do weird promise things because `emit()` is sync and we're using async code, so it can + // end up not running fast enough for our callCount checks. + await Promise.all([prom1, prom2]); + }); }); });