Skip to content

Commit 2dc8e80

Browse files
committed
feat: efficient concurrent sessions of same thread support
1 parent 6c94508 commit 2dc8e80

10 files changed

Lines changed: 93 additions & 9 deletions

File tree

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
import { ShardedCounter } from '@convex-dev/sharded-counter'
2+
import { components } from './_generated/api'
3+
4+
export const messagesInThreadCounter = new ShardedCounter(components.shardedCounter, { defaultShards: 1 })

apps/backend-convex/convex/_generated/api.d.ts

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
* @module
99
*/
1010

11+
import type * as _counters from "../_counters.js";
1112
import type * as authInfo from "../authInfo.js";
1213
import type * as crons from "../crons.js";
1314
import type * as http_ai from "../http/ai.js";
@@ -31,6 +32,7 @@ import type {
3132
* ```
3233
*/
3334
declare const fullApi: ApiFromModules<{
35+
_counters: typeof _counters;
3436
authInfo: typeof authInfo;
3537
crons: typeof crons;
3638
"http/ai": typeof http_ai;
@@ -195,4 +197,28 @@ export declare const components: {
195197
>;
196198
};
197199
};
200+
shardedCounter: {
201+
public: {
202+
add: FunctionReference<
203+
"mutation",
204+
"internal",
205+
{ count: number; name: string; shard?: number; shards?: number },
206+
number
207+
>;
208+
count: FunctionReference<"query", "internal", { name: string }, number>;
209+
estimateCount: FunctionReference<
210+
"query",
211+
"internal",
212+
{ name: string; readFromShards?: number; shards?: number },
213+
any
214+
>;
215+
rebalance: FunctionReference<
216+
"mutation",
217+
"internal",
218+
{ name: string; shards?: number },
219+
any
220+
>;
221+
reset: FunctionReference<"mutation", "internal", { name: string }, any>;
222+
};
223+
};
198224
};
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import rateLimiter from '@convex-dev/rate-limiter/convex.config'
2+
import shardedCounter from '@convex-dev/sharded-counter/convex.config'
23
import { defineApp } from 'convex/server'
34

45
const app: ReturnType<typeof defineApp> = defineApp()
56
app.use(rateLimiter)
7+
app.use(shardedCounter)
68

79
export default app

apps/backend-convex/convex/http/ai.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ aiApp
163163
}
164164

165165
// Get conversation history
166-
const messages = await c.env.runQuery(api.messages.list, { threadId, lockerKey })
166+
const messages = await c.env.runQuery(api.messages.listByThread, { threadId, lockerKey })
167167

168168
// Prepare messages for AI API (exclude the streaming messages)
169169
const messagesContext = messages

apps/backend-convex/convex/messages.ts

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import { objectPick } from '@namesmt/utils'
22
import { ConvexError, v } from 'convex/values'
3+
import { messagesInThreadCounter } from './_counters'
34
import { internalMutation, internalQuery, mutation, query } from './_generated/server'
45
import { assertThreadAccess } from './threads'
56

6-
export const list = query({
7+
export const listByThread = query({
78
args: {
89
threadId: v.id('threads'),
910
lockerKey: v.optional(v.string()),
@@ -24,6 +25,22 @@ export const list = query({
2425
},
2526
})
2627

28+
export const countByThread = query({
29+
args: {
30+
threadId: v.id('threads'),
31+
lockerKey: v.optional(v.string()),
32+
},
33+
handler: async (ctx, args) => {
34+
const thread = await ctx.db.get(args.threadId)
35+
if (!thread)
36+
throw new ConvexError('Thread not found')
37+
38+
await assertThreadAccess(ctx, { thread, lockerKey: args.lockerKey })
39+
40+
return await messagesInThreadCounter.count(ctx, args.threadId)
41+
},
42+
})
43+
2744
export const get = query({
2845
args: {
2946
messageId: v.id('messages'),
@@ -62,6 +79,8 @@ export const add = mutation({
6279

6380
await assertThreadAccess(ctx, { thread, lockerKey: args.lockerKey })
6481

82+
await messagesInThreadCounter.inc(ctx, args.threadId)
83+
6584
return await ctx.db.insert('messages', {
6685
...objectPick(args, ['threadId', 'role', 'content', 'isStreaming', 'streamId', 'provider', 'model']),
6786
timestamp: Date.now(),

apps/backend-convex/convex/threads.ts

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import { clearUndefined } from '@namesmt/utils'
66
import { openrouter } from '@openrouter/ai-sdk-provider'
77
import { generateText } from 'ai'
88
import { ConvexError, v } from 'convex/values'
9+
import { messagesInThreadCounter } from './_counters'
910
import { api, internal } from './_generated/api'
1011
import { action, internalMutation, mutation, query } from './_generated/server'
1112

@@ -138,6 +139,8 @@ export const branchThreadFromMessage = mutation({
138139
})
139140
}))
140141

142+
await messagesInThreadCounter.inc(ctx, thread._id)
143+
141144
return newThreadId
142145
},
143146
})
@@ -188,7 +191,7 @@ export const generateThreadTitle = action({
188191
handler: async (ctx, args) => {
189192
const thread = await ctx.runQuery(api.threads.get, { threadId: args.threadId, lockerKey: args.lockerKey })
190193

191-
const messages = await ctx.runQuery(api.messages.list, { threadId: args.threadId, lockerKey: args.lockerKey })
194+
const messages = await ctx.runQuery(api.messages.listByThread, { threadId: args.threadId, lockerKey: args.lockerKey })
192195

193196
const { text } = await generateText({
194197
model: openrouter('qwen/qwen3-8b:free'),

apps/backend-convex/package.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
"dependencies": {
2020
"@ai-sdk/openai": "^1.3.22",
2121
"@convex-dev/rate-limiter": "^0.2.7",
22+
"@convex-dev/sharded-counter": "^0.1.8",
2223
"@hono/zod-validator": "^0.7.0",
2324
"@openrouter/ai-sdk-provider": "^0.7.2",
2425
"ai": "^4.3.16",

apps/frontend/app/components/chat/ChatInterface.vue

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ const nearTopBottom = computed(() => {
4242
4343
const threadIdRef = useThreadIdRef()
4444
const isThreadFrozen = computed(() => chatContext.activeThread.value?.frozen)
45+
const fetchKey = ref(0)
4546
4647
const cachedThreadsMessages: {
4748
[threadId: string]: Array<CustomMessage>
@@ -52,10 +53,10 @@ const streamingMessages = ref(0)
5253
const isFetching = ref(false)
5354
const chatInput = ref('')
5455
55-
// Fetch messages and resume streams
56+
// Fetch messages as needed and resume streams
5657
const { ignoreUpdates: ignorePathUpdate } = watchIgnorable(
57-
threadIdRef,
58-
async (threadId, oldThreadId) => {
58+
[threadIdRef, fetchKey],
59+
async ([threadId], [oldThreadId]) => {
5960
if (oldThreadId)
6061
cachedThreadsMessages[oldThreadId] = messages.value
6162
@@ -65,7 +66,7 @@ const { ignoreUpdates: ignorePathUpdate } = watchIgnorable(
6566
6667
if (threadId) {
6768
isFetching.value = true
68-
await convex.query(api.messages.list, { threadId: threadId as Doc<'threads'>['_id'], lockerKey: getLockerKey(threadId) })
69+
await convex.query(api.messages.listByThread, { threadId: threadId as Doc<'threads'>['_id'], lockerKey: getLockerKey(threadId) })
6970
.then((existingMessages) => {
7071
if (threadIdRef.value === threadId) {
7172
messages.value = existingMessages.map(customMessageTransform)
@@ -77,7 +78,6 @@ const { ignoreUpdates: ignorePathUpdate } = watchIgnorable(
7778
7879
// If the owner have deleted the thread, remove it locally
7980
// (or the demo crons cleaned it)
80-
console.log({ a: getConvexErrorMessage(e), e })
8181
if (getConvexErrorMessage(e) === 'Thread not found') {
8282
toast({ variant: 'destructive', description: t('chat.toast.threadRemovedExternal') })
8383
@@ -109,6 +109,23 @@ const { ignoreUpdates: ignorePathUpdate } = watchIgnorable(
109109
{ immediate: true },
110110
)
111111
112+
// Subscribe to a counter to check for messages from other concurrent sessions
113+
watchImmediate(threadIdRef, (threadId) => {
114+
console.log(`Subscribing to messages count of: ${threadId}`)
115+
const { unsubscribe } = convex.onUpdate(
116+
api.messages.countByThread,
117+
{ threadId: threadId as Id<'threads'>, lockerKey: getLockerKey(threadId) },
118+
(count) => {
119+
if (count > messages.value.length)
120+
++fetchKey.value
121+
},
122+
)
123+
watchOnce(threadIdRef, () => {
124+
unsubscribe()
125+
console.log(`Unsubscribed from ${threadId}`)
126+
})
127+
})
128+
112129
interface HandleSubmitArgs {
113130
input: string
114131
confirmMultiStream?: boolean

apps/frontend/app/pages/chat/[...all].vue

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ provideSidebarContext({
3737
</div> -->
3838

3939
<SidebarProvider>
40-
<ChatSidebar />
40+
<ChatSidebar class="z-5" />
4141
<ChatInterface class="h-full w-full" />
4242
<ChatFloatingMenu class="absolute left-2 top-2 z-10" />
4343
</SidebarProvider>

pnpm-lock.yaml

Lines changed: 12 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)