diff --git a/src/channel_manager.ts b/src/channel_manager.ts index 2805e149e..1c9208400 100644 --- a/src/channel_manager.ts +++ b/src/channel_manager.ts @@ -133,6 +133,10 @@ export type ChannelManagerOptions = { lockChannelOrder?: boolean; }; +export type QueryChannelsRequestType = ( + ...params: Parameters +) => Promise; + export const DEFAULT_CHANNEL_MANAGER_OPTIONS = { abortInFlightQuery: false, allowNotLoadedChannelPromotionForEvent: { @@ -160,6 +164,7 @@ export class ChannelManager extends WithSubscriptions { private client: StreamChat; private eventHandlers: Map = new Map(); private eventHandlerOverrides: Map = new Map(); + private queryChannelsRequest: QueryChannelsRequestType; private options: ChannelManagerOptions = {}; private stateOptions: ChannelStateOptions = {}; private id: string; @@ -168,10 +173,12 @@ export class ChannelManager extends WithSubscriptions { client, eventHandlerOverrides = {}, options = {}, + queryChannelsOverride, }: { client: StreamChat; eventHandlerOverrides?: ChannelManagerEventHandlerOverrides; options?: ChannelManagerOptions; + queryChannelsOverride?: QueryChannelsRequestType; }) { super(); @@ -192,6 +199,8 @@ export class ChannelManager extends WithSubscriptions { }); this.setEventHandlerOverrides(eventHandlerOverrides); this.setOptions(options); + this.queryChannelsRequest = + queryChannelsOverride ?? ((...params) => this.client.queryChannels(...params)); this.eventHandlers = new Map( Object.entries({ channelDeletedHandler: this.channelDeletedHandler, @@ -252,6 +261,10 @@ export class ChannelManager extends WithSubscriptions { ); }; + public setQueryChannelsRequest = (queryChannelsRequest: QueryChannelsRequestType) => { + this.queryChannelsRequest = queryChannelsRequest; + }; + public setOptions = (options: ChannelManagerOptions = {}) => { this.options = { ...DEFAULT_CHANNEL_MANAGER_OPTIONS, ...options }; }; @@ -266,7 +279,7 @@ export class ChannelManager extends WithSubscriptions { ...options, }; try { - const channels = await this.client.queryChannels( + const channels = await this.queryChannelsRequest( filters, sort, options, @@ -407,7 +420,7 @@ export class ChannelManager extends WithSubscriptions { this.state.partialNext({ pagination: { ...pagination, isLoading: false, isLoadingNext: true }, }); - const nextChannels = await this.client.queryChannels( + const nextChannels = await this.queryChannelsRequest( filters, sort, options, diff --git a/src/client.ts b/src/client.ts index ffc030541..8ff633715 100644 --- a/src/client.ts +++ b/src/client.ts @@ -233,6 +233,7 @@ import { PollManager } from './poll_manager'; import type { ChannelManagerEventHandlerOverrides, ChannelManagerOptions, + QueryChannelsRequestType, } from './channel_manager'; import { ChannelManager } from './channel_manager'; import { NotificationManager } from './notifications'; @@ -720,10 +721,18 @@ export class StreamChat { createChannelManager = ({ eventHandlerOverrides = {}, options = {}, + queryChannelsOverride, }: { eventHandlerOverrides?: ChannelManagerEventHandlerOverrides; options?: ChannelManagerOptions; - }) => new ChannelManager({ client: this, eventHandlerOverrides, options }); + queryChannelsOverride?: QueryChannelsRequestType; + }) => + new ChannelManager({ + client: this, + eventHandlerOverrides, + options, + queryChannelsOverride, + }); /** * Creates a new WebSocket connection with the current user. Returns empty promise, if there is an active connection diff --git a/test/unit/channel_manager.test.ts b/test/unit/channel_manager.test.ts index 8c9445e50..c96747688 100644 --- a/test/unit/channel_manager.test.ts +++ b/test/unit/channel_manager.test.ts @@ -9,6 +9,7 @@ import { DEFAULT_CHANNEL_MANAGER_OPTIONS, channelManagerEventToHandlerMapping, DEFAULT_CHANNEL_MANAGER_PAGINATION_OPTIONS, + QueryChannelsRequestType, } from '../../src'; import { generateChannel } from './test-utils/generateChannel'; @@ -52,6 +53,11 @@ describe('ChannelManager', () => { channelManager = client.createChannelManager({}); }); + afterEach(() => { + sinon.restore(); + sinon.reset(); + }); + it('initializes properly', () => { const state = channelManager.state.getLatestValue(); expect(state.channels).to.be.empty; @@ -66,7 +72,7 @@ describe('ChannelManager', () => { expect(state.initialized).to.be.false; }); - it('should properly set eventHandlerOverrides and options if they are passed', () => { + it('should properly set eventHandlerOverrides, options and queryChannelsRequest if they are passed', async () => { const eventHandlerOverrides = { newMessageHandler: () => {} }; const options = { allowNotLoadedChannelPromotionForEvent: { @@ -76,9 +82,16 @@ describe('ChannelManager', () => { 'notification.message_new': false, }, }; + const queryChannelsOverride = async () => { + console.log('Called from override.'); + return new Promise((resolve) => { + resolve([]); + }); + }; const newChannelManager = client.createChannelManager({ eventHandlerOverrides, options, + queryChannelsOverride, }); expect( @@ -88,9 +101,13 @@ describe('ChannelManager', () => { ...DEFAULT_CHANNEL_MANAGER_OPTIONS, ...options, }); + + const consoleLogSpy = vi.spyOn(console, 'log'); + await (newChannelManager as any).queryChannelsRequest({}); + expect(consoleLogSpy).toHaveBeenCalledWith('Called from override.'); }); - it('should properly set the default event handlers', () => { + it('should properly set the default event handlers', async () => { const { eventHandlers, channelDeletedHandler, @@ -113,6 +130,12 @@ describe('ChannelManager', () => { notificationNewMessageHandler, notificationRemovedFromChannelHandler, }); + + const clientQueryChannelsSpy = vi + .spyOn(client, 'queryChannels') + .mockImplementation(async () => []); + await (channelManager as any).queryChannelsRequest({}); + expect(clientQueryChannelsSpy).toHaveBeenCalledOnce(); }); }); @@ -138,6 +161,21 @@ describe('ChannelManager', () => { ).to.deep.equal(eventHandlerOverrides); }); + it('should properly set queryChannelRequest', async () => { + const queryChannelsOverride = async () => { + console.log('Called from override.'); + return new Promise((resolve) => { + resolve([]); + }); + }; + + channelManager.setQueryChannelsRequest(queryChannelsOverride); + + const consoleLogSpy = vi.spyOn(console, 'log'); + await (channelManager as any).queryChannelsRequest({}); + expect(consoleLogSpy).toHaveBeenCalledWith('Called from override.'); + }); + it('should properly set options', () => { const options = { lockChannelOrder: true, @@ -407,6 +445,7 @@ describe('ChannelManager', () => { describe('querying and pagination', () => { let clientQueryChannelsStub: sinon.SinonStub; let mockChannelPages: Array>; + let mockChannelCidMap: Record; let channelManager: ChannelManager; beforeEach(() => { @@ -422,9 +461,20 @@ describe('ChannelManager', () => { client.channel(c.channel.type, c.channel.id), ); }); + mockChannelCidMap = Object.fromEntries( + mockChannelPages.flat().map((obj) => [obj.cid, obj]), + ); clientQueryChannelsStub = sinon .stub(client, 'queryChannels') - .callsFake((_filters, _sort, options) => { + .callsFake((filters, _sort, options) => { + if ( + typeof filters.cid === 'object' && + filters.cid !== null && + '$in' in filters.cid + ) { + const toReturn = (filters.cid['$in'] ?? []) as string[]; + return Promise.resolve(toReturn.map((cid) => mockChannelCidMap[cid])); + } const offset = options?.offset ?? 0; return Promise.resolve(mockChannelPages[Math.floor(offset / 10)]); }); @@ -877,6 +927,39 @@ describe('ChannelManager', () => { expect(offset).to.equal(5); expect(hasNext).to.be.false; }); + + it('should execute queryChannelsOverride if set', async () => { + const fetchedChannels = mockChannelPages[2].concat(mockChannelPages[1]); + const queryChannelsOverride = async ( + ...params: Parameters + ) => { + const [filters, ...restParams] = params; + filters.cid = { $in: fetchedChannels.map((c) => c.cid) }; + + return await client.queryChannels(filters, ...restParams); + }; + channelManager.setQueryChannelsRequest(queryChannelsOverride); + + await channelManager.queryChannels( + { filterA: true }, + { asc: 1 }, + { limit: 15, offset: 0 }, + ); + + const { + channels, + pagination: { + hasNext, + options: { offset }, + }, + } = channelManager.state.getLatestValue(); + + expect(clientQueryChannelsStub.calledOnce).to.be.true; + expect(channels.length).to.equal(15); + expect(channels).to.deep.equal(fetchedChannels); + expect(offset).to.equal(15); + expect(hasNext).to.be.true; + }); }); describe('loadNext', () => { @@ -1173,6 +1256,60 @@ describe('ChannelManager', () => { expect(hasNext).to.be.false; expect(offset).to.equal(25); }); + + it('should properly paginate with queryChannelsOverride if set', async () => { + const fetchedChannels = mockChannelPages[2].concat(mockChannelPages[1]); + const fetchedNextPageChannels = mockChannelPages[0]; + const queryChannelsOverride = async ( + ...params: Parameters + ) => { + const [filters, sort, options, ...restParams] = params; + const isInitialPage = options?.offset === 0; + filters.cid = { + $in: (isInitialPage ? fetchedChannels : fetchedNextPageChannels).map( + (c) => c.cid, + ), + }; + + return await client.queryChannels(filters, sort, options, ...restParams); + }; + + channelManager.setQueryChannelsRequest(queryChannelsOverride); + + await channelManager.queryChannels( + { filterA: true }, + { asc: 1 }, + { limit: 15, offset: 0 }, + ); + + const { + channels: prevChannels, + pagination: { + hasNext: prevHasNext, + options: { offset: prevOffset }, + }, + } = channelManager.state.getLatestValue(); + + expect(prevChannels.length).to.equal(15); + expect(prevChannels).to.deep.equal(fetchedChannels); + expect(prevOffset).to.equal(15); + expect(prevHasNext).to.be.true; + + await channelManager.loadNext(); + + const { + channels, + pagination: { + hasNext, + options: { offset }, + }, + } = channelManager.state.getLatestValue(); + + expect(channels.length).to.equal(25); + expect(channels).to.deep.equal(fetchedChannels.concat(fetchedNextPageChannels)); + expect(offset).to.equal(25); + expect(hasNext).to.be.false; + }); }); });