Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 58 additions & 7 deletions packages/core/src/core/client.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1892,11 +1892,16 @@ ${JSON.stringify(
);
});

it('should recursively call sendMessageStream with "Please continue." when InvalidStream event is received', async () => {
it('should recursively call sendMessageStream with "Please continue." when InvalidStream event is received for Gemini 2 models', async () => {
vi.spyOn(client['config'], 'getContinueOnFailedApiCall').mockReturnValue(
true,
);
// Arrange
// Arrange - router must return a Gemini 2 model for retry to trigger
mockRouterService.route.mockResolvedValue({
model: 'gemini-2.0-flash',
reason: 'test',
});

const mockStream1 = (async function* () {
yield { type: GeminiEventType.InvalidStream };
})();
Expand Down Expand Up @@ -1926,7 +1931,7 @@ ${JSON.stringify(

// Assert
expect(events).toEqual([
{ type: GeminiEventType.ModelInfo, value: 'default-routed-model' },
{ type: GeminiEventType.ModelInfo, value: 'gemini-2.0-flash' },
{ type: GeminiEventType.InvalidStream },
{ type: GeminiEventType.Content, value: 'Continued content' },
]);
Expand All @@ -1937,7 +1942,7 @@ ${JSON.stringify(
// First call with original request
expect(mockTurnRunFn).toHaveBeenNthCalledWith(
1,
{ model: 'default-routed-model', isChatModel: true },
{ model: 'gemini-2.0-flash', isChatModel: true },
initialRequest,
expect.any(AbortSignal),
undefined,
Expand All @@ -1946,7 +1951,7 @@ ${JSON.stringify(
// Second call with "Please continue."
expect(mockTurnRunFn).toHaveBeenNthCalledWith(
2,
{ model: 'default-routed-model', isChatModel: true },
{ model: 'gemini-2.0-flash', isChatModel: true },
[{ text: 'System: Please continue.' }],
expect.any(AbortSignal),
undefined,
Expand Down Expand Up @@ -1990,11 +1995,57 @@ ${JSON.stringify(
expect(mockTurnRunFn).toHaveBeenCalledTimes(1);
});

it('should not retry with "Please continue." when InvalidStream event is received for non-Gemini-2 models', async () => {
vi.spyOn(client['config'], 'getContinueOnFailedApiCall').mockReturnValue(
true,
);
// Arrange - router returns a non-Gemini-2 model
mockRouterService.route.mockResolvedValue({
model: 'gemini-3.0-pro',
reason: 'test',
});

const mockStream1 = (async function* () {
yield { type: GeminiEventType.InvalidStream };
})();

mockTurnRunFn.mockReturnValueOnce(mockStream1);

const mockChat: Partial<GeminiChat> = {
addHistory: vi.fn(),
setTools: vi.fn(),
getHistory: vi.fn().mockReturnValue([]),
getLastPromptTokenCount: vi.fn(),
};
client['chat'] = mockChat as GeminiChat;

const initialRequest = [{ text: 'Hi' }];
const promptId = 'prompt-id-invalid-stream-non-g2';
const signal = new AbortController().signal;

// Act
const stream = client.sendMessageStream(initialRequest, signal, promptId);
const events = await fromAsync(stream);

// Assert
expect(events).toEqual([
{ type: GeminiEventType.ModelInfo, value: 'gemini-3.0-pro' },
{ type: GeminiEventType.InvalidStream },
]);

// Verify that turn.run was called only once (no retry)
expect(mockTurnRunFn).toHaveBeenCalledTimes(1);
});

it('should stop recursing after one retry when InvalidStream events are repeatedly received', async () => {
vi.spyOn(client['config'], 'getContinueOnFailedApiCall').mockReturnValue(
true,
);
// Arrange
// Arrange - router must return a Gemini 2 model for retry to trigger
mockRouterService.route.mockResolvedValue({
model: 'gemini-2.0-flash',
reason: 'test',
});
// Always return a new invalid stream
mockTurnRunFn.mockImplementation(() =>
(async function* () {
Expand Down Expand Up @@ -2025,7 +2076,7 @@ ${JSON.stringify(
events
.filter((e) => e.type === GeminiEventType.ModelInfo)
.map((e) => e.value),
).toEqual(['default-routed-model']);
).toEqual(['gemini-2.0-flash']);

// Verify that turn.run was called twice
expect(mockTurnRunFn).toHaveBeenCalledTimes(2);
Expand Down
7 changes: 5 additions & 2 deletions packages/core/src/core/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ import {
applyModelSelection,
createAvailabilityContextProvider,
} from '../availability/policyHelpers.js';
import { resolveModel } from '../config/models.js';
import { resolveModel, isGemini2Model } from '../config/models.js';
import type { RetryAvailabilityContext } from '../utils/retry.js';
import { partToString } from '../utils/partUtils.js';
import { coreEvents, CoreEvent } from '../utils/events.js';
Expand Down Expand Up @@ -725,7 +725,10 @@ export class GeminiClient {
}

if (isInvalidStream) {
if (this.config.getContinueOnFailedApiCall()) {
if (
this.config.getContinueOnFailedApiCall() &&
isGemini2Model(modelToUse)
) {
if (isInvalidStreamRetry) {
logContentRetryFailure(
this.config,
Expand Down
Loading