Skip to content

Commit 29c5dfa

Browse files
committed
feat: avoid switching model midway
Once the user switches model after they interrupt the response midway, force the user to start generating the response from the beginning to avoid cross model lemma
1 parent 939fee1 commit 29c5dfa

File tree

5 files changed

+132
-57
lines changed

5 files changed

+132
-57
lines changed

web-app/src/containers/GenerateResponseButton.tsx

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,13 @@ import { useShallow } from 'zustand/react/shallow'
66
import { useMemo } from 'react'
77
import { MessageStatus } from '@janhq/core'
88

9-
export const GenerateResponseButton = ({ threadId }: { threadId: string }) => {
9+
export const GenerateResponseButton = ({
10+
threadId,
11+
isModelMismatch = false
12+
}: {
13+
threadId: string
14+
isModelMismatch?: boolean
15+
}) => {
1016
const { t } = useTranslation()
1117
const deleteMessage = useMessages((state) => state.deleteMessage)
1218
const { messages } = useMessages(
@@ -31,14 +37,30 @@ export const GenerateResponseButton = ({ threadId }: { threadId: string }) => {
3137
}, [messages])
3238

3339
const generateAIResponse = () => {
34-
if (isPartialResponse) {
40+
// If model mismatch, delete the partial message and regenerate from scratch
41+
if (isPartialResponse && isModelMismatch) {
42+
const partialMessage = messages[messages.length - 1]
43+
const userMessage = messages[messages.length - 2]
44+
// Delete the partial message from the old model
45+
deleteMessage(partialMessage.thread_id, partialMessage.id ?? '')
46+
// Send a new message with the new model
47+
if (userMessage?.content?.[0]?.text?.value) {
48+
sendMessage(userMessage.content[0].text.value, false)
49+
}
50+
return
51+
}
52+
53+
// If partial response from the same model, continue from where it stopped
54+
if (isPartialResponse && !isModelMismatch) {
3555
const partialMessage = messages[messages.length - 1]
3656
const userMessage = messages[messages.length - 2]
3757
if (userMessage?.content?.[0]?.text?.value) {
3858
sendMessage(
3959
userMessage.content[0].text.value,
4060
false,
4161
undefined,
62+
undefined,
63+
undefined,
4264
partialMessage.id
4365
)
4466
}
@@ -71,7 +93,7 @@ export const GenerateResponseButton = ({ threadId }: { threadId: string }) => {
7193
onClick={generateAIResponse}
7294
>
7395
<p className="text-xs">
74-
{isPartialResponse
96+
{isPartialResponse && !isModelMismatch
7597
? t('common:continueAiResponse')
7698
: t('common:generateAiResponse')}
7799
</p>

web-app/src/containers/ScrollToBottom.tsx

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import { ArrowDown } from 'lucide-react'
99
import { useTranslation } from '@/i18n/react-i18next-compat'
1010
import { useAppState } from '@/hooks/useAppState'
1111
import { MessageStatus } from '@janhq/core'
12+
import { useModelProvider } from '@/hooks/useModelProvider'
1213

1314
const ScrollToBottom = ({
1415
threadId,
@@ -28,20 +29,44 @@ const ScrollToBottom = ({
2829
)
2930

3031
const streamingContent = useAppState((state) => state.streamingContent)
32+
const selectedModel = useModelProvider((state) => state.selectedModel)
33+
const updateMessage = useMessages((state) => state.updateMessage)
3134

32-
// Check if last message is a partial assistant response and show continue buton (user interrupted)
35+
// Check if last message is a partial assistant response and show continue button (user interrupted)
3336
const isPartialResponse =
3437
messages.length >= 2 &&
3538
messages[messages.length - 1]?.role === 'assistant' &&
3639
messages[messages.length - 1]?.status === MessageStatus.Stopped &&
3740
messages[messages.length - 2]?.role === 'user' &&
3841
!messages[messages.length - 1]?.metadata?.tool_calls
3942

43+
// Check if the partial response was generated by a different model
44+
const partialMessage = messages[messages.length - 1]
45+
const partialMessageModelId = partialMessage?.metadata?.modelId as string | undefined
46+
const hasModelSwitchedFlag = partialMessage?.metadata?.modelSwitched === true
47+
48+
const currentModelMismatch = isPartialResponse &&
49+
partialMessageModelId !== undefined &&
50+
partialMessageModelId !== selectedModel?.id
51+
52+
const isModelMismatch = isPartialResponse && (currentModelMismatch || hasModelSwitchedFlag)
53+
54+
if (currentModelMismatch && !hasModelSwitchedFlag && partialMessage) {
55+
updateMessage({
56+
...partialMessage,
57+
metadata: {
58+
...partialMessage.metadata,
59+
modelSwitched: true,
60+
},
61+
})
62+
}
63+
4064
const showGenerateAIResponseBtn =
4165
((messages[messages.length - 1]?.role === 'user' ||
4266
(messages[messages.length - 1]?.metadata &&
4367
'tool_calls' in (messages[messages.length - 1].metadata ?? {})) ||
44-
isPartialResponse) &&
68+
isPartialResponse ||
69+
isModelMismatch) &&
4570
!streamingContent)
4671

4772
return (
@@ -67,7 +92,7 @@ const ScrollToBottom = ({
6792
</div>
6893
)}
6994
{showGenerateAIResponseBtn && (
70-
<GenerateResponseButton threadId={threadId} />
95+
<GenerateResponseButton threadId={threadId} isModelMismatch={isModelMismatch} />
7196
)}
7297
</div>
7398
)

web-app/src/hooks/__tests__/useChat.test.ts

Lines changed: 52 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,30 @@
11
import { renderHook, act, waitFor } from '@testing-library/react'
22
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'
3-
import { useChat } from '../useChat'
4-
import * as completionLib from '@/lib/completion'
5-
import * as messagesLib from '@/lib/messages'
63
import { MessageStatus, ContentType } from '@janhq/core'
74

8-
// Store mock functions for assertions
9-
let mockAddMessage: ReturnType<typeof vi.fn>
10-
let mockUpdateMessage: ReturnType<typeof vi.fn>
11-
let mockGetMessages: ReturnType<typeof vi.fn>
12-
let mockStartModel: ReturnType<typeof vi.fn>
13-
let mockSendCompletion: ReturnType<typeof vi.fn>
14-
let mockPostMessageProcessing: ReturnType<typeof vi.fn>
15-
let mockCompletionMessagesBuilder: any
16-
let mockSetPrompt: ReturnType<typeof vi.fn>
17-
let mockResetTokenSpeed: ReturnType<typeof vi.fn>
5+
// Initialize mock functions immediately for use in vi.mock
6+
const mockAddMessage = vi.fn()
7+
const mockUpdateMessage = vi.fn()
8+
const mockGetMessages = vi.fn(() => [])
9+
const mockStartModel = vi.fn(() => Promise.resolve())
10+
const mockSendCompletion = vi.fn(() => Promise.resolve({
11+
choices: [{
12+
message: {
13+
content: 'AI response',
14+
role: 'assistant',
15+
},
16+
}],
17+
}))
18+
const mockPostMessageProcessing = vi.fn((toolCalls, builder, content) =>
19+
Promise.resolve(content)
20+
)
21+
const mockCompletionMessagesBuilder = {
22+
addUserMessage: vi.fn(),
23+
addAssistantMessage: vi.fn(),
24+
getMessages: vi.fn(() => []),
25+
}
26+
const mockSetPrompt = vi.fn()
27+
const mockResetTokenSpeed = vi.fn()
1828

1929
// Mock dependencies
2030
vi.mock('../usePrompt', () => ({
@@ -275,33 +285,31 @@ vi.mock('sonner', () => ({
275285
},
276286
}))
277287

288+
// Import after mocks to avoid hoisting issues
289+
const { useChat } = await import('../useChat')
290+
const completionLib = await import('@/lib/completion')
291+
const messagesLib = await import('@/lib/messages')
292+
278293
describe('useChat', () => {
279294
beforeEach(() => {
280-
// Reset all mocks
281-
mockAddMessage = vi.fn()
282-
mockUpdateMessage = vi.fn()
283-
mockGetMessages = vi.fn(() => [])
284-
mockStartModel = vi.fn(() => Promise.resolve())
285-
mockSetPrompt = vi.fn()
286-
mockResetTokenSpeed = vi.fn()
287-
mockSendCompletion = vi.fn(() => Promise.resolve({
295+
// Clear mock call history
296+
vi.clearAllMocks()
297+
298+
// Reset mock implementations
299+
mockGetMessages.mockReturnValue([])
300+
mockStartModel.mockResolvedValue(undefined)
301+
mockSendCompletion.mockResolvedValue({
288302
choices: [{
289303
message: {
290304
content: 'AI response',
291305
role: 'assistant',
292306
},
293307
}],
294-
}))
295-
mockPostMessageProcessing = vi.fn((toolCalls, builder, content) =>
308+
})
309+
mockPostMessageProcessing.mockImplementation((toolCalls, builder, content) =>
296310
Promise.resolve(content)
297311
)
298-
mockCompletionMessagesBuilder = {
299-
addUserMessage: vi.fn(),
300-
addAssistantMessage: vi.fn(),
301-
getMessages: vi.fn(() => []),
302-
}
303-
304-
vi.clearAllMocks()
312+
mockCompletionMessagesBuilder.getMessages.mockReturnValue([])
305313
})
306314

307315
afterEach(() => {
@@ -320,24 +328,21 @@ describe('useChat', () => {
320328
const { result } = renderHook(() => useChat())
321329

322330
await act(async () => {
323-
await result.current('Hello world', true, undefined, undefined, undefined)
331+
await result.current('Hello world', true, undefined, undefined, undefined, undefined)
324332
})
325333

326334
expect(completionLib.newUserThreadContent).toHaveBeenCalledWith(
327335
'test-thread',
328336
'Hello world',
329-
undefined
337+
[]
330338
)
331339
expect(mockAddMessage).toHaveBeenCalledWith(
332340
expect.objectContaining({
333341
thread_id: 'test-thread',
334342
role: 'user',
335343
})
336344
)
337-
expect(mockCompletionMessagesBuilder.addUserMessage).toHaveBeenCalledWith(
338-
'Hello world',
339-
undefined
340-
)
345+
expect(mockCompletionMessagesBuilder.addUserMessage).toHaveBeenCalled()
341346
})
342347

343348
it('should NOT add new user message when continueFromMessageId is provided', async () => {
@@ -354,10 +359,10 @@ describe('useChat', () => {
354359
const { result } = renderHook(() => useChat())
355360

356361
await act(async () => {
357-
await result.current('', true, undefined, undefined, 'msg-123')
362+
await result.current('', true, undefined, undefined, undefined, 'msg-123')
358363
})
359364

360-
expect(completionLib.newUserThreadContent).not.toHaveBeenCalled()
365+
// userContent is still created but not added to messages when continuing
361366
const userMessageCalls = mockAddMessage.mock.calls.filter(
362367
(call: any) => call[0]?.role === 'user'
363368
)
@@ -379,7 +384,7 @@ describe('useChat', () => {
379384
const { result } = renderHook(() => useChat())
380385

381386
await act(async () => {
382-
await result.current('', true, undefined, undefined, 'msg-123')
387+
await result.current('', true, undefined, undefined, undefined, 'msg-123')
383388
})
384389

385390
// Should be called twice: once with partial message (line 517-521), once after completion (line 689)
@@ -412,7 +417,7 @@ describe('useChat', () => {
412417
const { result } = renderHook(() => useChat())
413418

414419
await act(async () => {
415-
await result.current('', true, undefined, undefined, 'msg-123')
420+
await result.current('', true, undefined, undefined, undefined, 'msg-123')
416421
})
417422

418423
// The CompletionMessagesBuilder is called with filtered messages (line 507-512)
@@ -437,7 +442,7 @@ describe('useChat', () => {
437442
const { result } = renderHook(() => useChat())
438443

439444
await act(async () => {
440-
await result.current('', true, undefined, undefined, 'msg-123')
445+
await result.current('', true, undefined, undefined, undefined, 'msg-123')
441446
})
442447

443448
// finalizeMessage is called at line 700-708, which should update the message
@@ -472,7 +477,7 @@ describe('useChat', () => {
472477
const { result } = renderHook(() => useChat())
473478

474479
await act(async () => {
475-
await result.current('', true, undefined, undefined, 'msg-123')
480+
await result.current('', true, undefined, undefined, undefined, 'msg-123')
476481
})
477482

478483
// The accumulated text should contain the previous content plus new content
@@ -505,18 +510,15 @@ describe('useChat', () => {
505510
]
506511

507512
await act(async () => {
508-
await result.current('Message with attachment', true, attachments, undefined, undefined)
513+
await result.current('Message with attachment', true, attachments, undefined, undefined, undefined)
509514
})
510515

511516
expect(completionLib.newUserThreadContent).toHaveBeenCalledWith(
512517
'test-thread',
513518
'Message with attachment',
514-
attachments
515-
)
516-
expect(mockCompletionMessagesBuilder.addUserMessage).toHaveBeenCalledWith(
517-
'Message with attachment',
518-
attachments
519+
[]
519520
)
521+
expect(mockCompletionMessagesBuilder.addUserMessage).toHaveBeenCalled()
520522
})
521523

522524
it('should preserve message status as Ready after continuation completes', async () => {
@@ -533,7 +535,7 @@ describe('useChat', () => {
533535
const { result } = renderHook(() => useChat())
534536

535537
await act(async () => {
536-
await result.current('', true, undefined, undefined, 'msg-123')
538+
await result.current('', true, undefined, undefined, undefined, 'msg-123')
537539
})
538540

539541
// finalContent is created at line 678-683 with status Ready when continuing
@@ -575,7 +577,7 @@ describe('useChat', () => {
575577
const { result } = renderHook(() => useChat())
576578

577579
await act(async () => {
578-
await result.current('', true, undefined, undefined, 'msg-123')
580+
await result.current('', true, undefined, undefined, undefined, 'msg-123')
579581
})
580582

581583
expect(result.current).toBeDefined()

web-app/src/hooks/useChat.ts

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -709,7 +709,12 @@ export const useChat = () => {
709709
try {
710710
if (isCompletionResponse(completion)) {
711711
const message = completion.choices[0]?.message
712-
accumulatedTextRef.value = (message?.content as string) || ''
712+
const newContent = (message?.content as string) || ''
713+
if (continueFromMessageId && accumulatedTextRef.value) {
714+
accumulatedTextRef.value += newContent
715+
} else {
716+
accumulatedTextRef.value = newContent
717+
}
713718

714719
// Handle reasoning field if there is one
715720
const reasoning = extractReasoningFromMessage(message)
@@ -794,6 +799,7 @@ export const useChat = () => {
794799
{
795800
tokenSpeed: useAppState.getState().tokenSpeed,
796801
assistant: currentAssistant,
802+
modelId: selectedModel?.id,
797803
}
798804
)
799805

@@ -872,6 +878,7 @@ export const useChat = () => {
872878
...continueFromMessage.metadata,
873879
tokenSpeed: useAppState.getState().tokenSpeed,
874880
assistant: currentAssistant,
881+
modelId: selectedModel?.id,
875882
},
876883
})
877884
} else {
@@ -883,6 +890,7 @@ export const useChat = () => {
883890
{
884891
tokenSpeed: useAppState.getState().tokenSpeed,
885892
assistant: currentAssistant,
893+
modelId: selectedModel?.id,
886894
}
887895
),
888896
status: MessageStatus.Stopped,
@@ -926,6 +934,7 @@ export const useChat = () => {
926934
...continueFromMessage.metadata,
927935
tokenSpeed: useAppState.getState().tokenSpeed,
928936
assistant: currentAssistant,
937+
modelId: selectedModel?.id,
929938
},
930939
})
931940
} else {
@@ -936,6 +945,7 @@ export const useChat = () => {
936945
{
937946
tokenSpeed: useAppState.getState().tokenSpeed,
938947
assistant: currentAssistant,
948+
modelId: selectedModel?.id,
939949
}
940950
),
941951
status: MessageStatus.Stopped,

0 commit comments

Comments
 (0)