Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions packages/core/src/config/models.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import {
resolveModel,
resolveClassifierModel,
isGemini3Model,
isGemini2Model,
isAutoModel,
getDisplayString,
Expand All @@ -24,6 +25,29 @@
DEFAULT_GEMINI_MODEL_AUTO,
} from './models.js';

describe('isGemini3Model', () => {
it('should return true for gemini-3 models', () => {
expect(isGemini3Model('gemini-3-pro-preview')).toBe(true);
expect(isGemini3Model('gemini-3-flash-preview')).toBe(true);
});

it('should return true for aliases that resolve to Gemini 3', () => {
expect(isGemini3Model(GEMINI_MODEL_ALIAS_AUTO)).toBe(true);
expect(isGemini3Model(GEMINI_MODEL_ALIAS_PRO)).toBe(true);
expect(isGemini3Model(PREVIEW_GEMINI_MODEL_AUTO)).toBe(true);
});

it('should return false for Gemini 2 models', () => {
expect(isGemini3Model('gemini-2.5-pro')).toBe(false);
expect(isGemini3Model('gemini-2.5-flash')).toBe(false);
expect(isGemini3Model(DEFAULT_GEMINI_MODEL_AUTO)).toBe(false);
});

it('should return false for arbitrary strings', () => {
expect(isGemini3Model('gpt-4')).toBe(false);
});
});

describe('getDisplayString', () => {
it('should return Auto (Gemini 3) for preview auto model', () => {
expect(getDisplayString(PREVIEW_GEMINI_MODEL_AUTO)).toBe('Auto (Gemini 3)');
Expand Down Expand Up @@ -58,7 +82,7 @@
expect(supportsMultimodalFunctionResponse('gemini-3-pro')).toBe(true);
});

it('should return false for gemini-2 models', () => {

Check warning on line 85 in packages/core/src/config/models.test.ts

View workflow job for this annotation

GitHub Actions / Lint

Found sensitive keyword "gemini-2". Please make sure this change is appropriate to submit.
expect(supportsMultimodalFunctionResponse('gemini-2.5-pro')).toBe(false);
expect(supportsMultimodalFunctionResponse('gemini-2.5-flash')).toBe(false);
});
Expand Down
11 changes: 11 additions & 0 deletions packages/core/src/config/models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,17 @@
);
}

/**
* Checks if the model is a Gemini 3 model.
*
* @param model The model name to check.
* @returns True if the model is a Gemini 3 model.
*/
export function isGemini3Model(model: string): boolean {
const resolved = resolveModel(model);
return /^gemini-3(\.|-|$)/.test(resolved);
}

/**
* Checks if the model is a Gemini 2.x model.
*
Expand All @@ -127,7 +138,7 @@
* @returns True if the model is a Gemini-2.x model.
*/
export function isGemini2Model(model: string): boolean {
return /^gemini-2(\.|$)/.test(model);

Check warning on line 141 in packages/core/src/config/models.ts

View workflow job for this annotation

GitHub Actions / Lint

Found sensitive keyword "gemini-2". Please make sure this change is appropriate to submit.
}

/**
Expand Down
24 changes: 22 additions & 2 deletions packages/core/src/routing/strategies/classifierStrategy.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import {
DEFAULT_GEMINI_FLASH_MODEL,
DEFAULT_GEMINI_MODEL,
DEFAULT_GEMINI_MODEL_AUTO,
PREVIEW_GEMINI_MODEL_AUTO,
} from '../../config/models.js';
import { promptIdContext } from '../../utils/promptIdContext.js';
import type { Content } from '@google/genai';
Expand Down Expand Up @@ -50,7 +51,7 @@ describe('ClassifierStrategy', () => {
modelConfigService: {
getResolvedConfig: vi.fn().mockReturnValue(mockResolvedConfig),
},
getModel: () => DEFAULT_GEMINI_MODEL_AUTO,
getModel: vi.fn().mockReturnValue(DEFAULT_GEMINI_MODEL_AUTO),
getNumericalRoutingEnabled: vi.fn().mockResolvedValue(false),
} as unknown as Config;
mockBaseLlmClient = {
Expand All @@ -60,8 +61,9 @@ describe('ClassifierStrategy', () => {
vi.spyOn(promptIdContext, 'getStore').mockReturnValue('test-prompt-id');
});

it('should return null if numerical routing is enabled', async () => {
it('should return null if numerical routing is enabled and model is Gemini 3', async () => {
vi.mocked(mockConfig.getNumericalRoutingEnabled).mockResolvedValue(true);
vi.mocked(mockConfig.getModel).mockReturnValue(PREVIEW_GEMINI_MODEL_AUTO);

const decision = await strategy.route(
mockContext,
Expand All @@ -73,6 +75,24 @@ describe('ClassifierStrategy', () => {
expect(mockBaseLlmClient.generateJson).not.toHaveBeenCalled();
});

it('should NOT return null if numerical routing is enabled but model is NOT Gemini 3', async () => {
vi.mocked(mockConfig.getNumericalRoutingEnabled).mockResolvedValue(true);
vi.mocked(mockConfig.getModel).mockReturnValue(DEFAULT_GEMINI_MODEL_AUTO);
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue({
reasoning: 'test',
model_choice: 'flash',
});

const decision = await strategy.route(
mockContext,
mockConfig,
mockBaseLlmClient,
);

expect(decision).not.toBeNull();
expect(mockBaseLlmClient.generateJson).toHaveBeenCalled();
});

it('should call generateJson with the correct parameters', async () => {
const mockApiResponse = {
reasoning: 'Simple task',
Expand Down
10 changes: 7 additions & 3 deletions packages/core/src/routing/strategies/classifierStrategy.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import type {
RoutingDecision,
RoutingStrategy,
} from '../routingStrategy.js';
import { resolveClassifierModel } from '../../config/models.js';
import { resolveClassifierModel, isGemini3Model } from '../../config/models.js';
import { createUserContent, Type } from '@google/genai';
import type { Config } from '../../config/config.js';
import {
Expand Down Expand Up @@ -133,7 +133,11 @@ export class ClassifierStrategy implements RoutingStrategy {
): Promise<RoutingDecision | null> {
const startTime = Date.now();
try {
if (await config.getNumericalRoutingEnabled()) {
const model = context.requestedModel ?? config.getModel();
if (
(await config.getNumericalRoutingEnabled()) &&
isGemini3Model(model)
) {
return null;
}

Expand Down Expand Up @@ -164,7 +168,7 @@ export class ClassifierStrategy implements RoutingStrategy {
const reasoning = routerResponse.reasoning;
const latencyMs = Date.now() - startTime;
const selectedModel = resolveClassifierModel(
context.requestedModel ?? config.getModel(),
model,
routerResponse.model_choice,
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@ import type { RoutingContext } from '../routingStrategy.js';
import type { Config } from '../../config/config.js';
import type { BaseLlmClient } from '../../core/baseLlmClient.js';
import {
DEFAULT_GEMINI_FLASH_MODEL,
DEFAULT_GEMINI_MODEL,
PREVIEW_GEMINI_FLASH_MODEL,
PREVIEW_GEMINI_MODEL,
PREVIEW_GEMINI_MODEL_AUTO,
DEFAULT_GEMINI_MODEL_AUTO,
DEFAULT_GEMINI_MODEL,
} from '../../config/models.js';
import { promptIdContext } from '../../utils/promptIdContext.js';
import type { Content } from '@google/genai';
Expand Down Expand Up @@ -46,7 +48,7 @@ describe('NumericalClassifierStrategy', () => {
modelConfigService: {
getResolvedConfig: vi.fn().mockReturnValue(mockResolvedConfig),
},
getModel: () => DEFAULT_GEMINI_MODEL_AUTO,
getModel: vi.fn().mockReturnValue(PREVIEW_GEMINI_MODEL_AUTO),
getSessionId: vi.fn().mockReturnValue('control-group-id'), // Default to Control Group (Hash 71 >= 50)
getNumericalRoutingEnabled: vi.fn().mockResolvedValue(true),
getClassifierThreshold: vi.fn().mockResolvedValue(undefined),
Expand Down Expand Up @@ -75,6 +77,32 @@ describe('NumericalClassifierStrategy', () => {
expect(mockBaseLlmClient.generateJson).not.toHaveBeenCalled();
});

it('should return null if the model is not a Gemini 3 model', async () => {
vi.mocked(mockConfig.getModel).mockReturnValue(DEFAULT_GEMINI_MODEL_AUTO);

const decision = await strategy.route(
mockContext,
mockConfig,
mockBaseLlmClient,
);

expect(decision).toBeNull();
expect(mockBaseLlmClient.generateJson).not.toHaveBeenCalled();
});

it('should return null if the model is explicitly a Gemini 2 model', async () => {
vi.mocked(mockConfig.getModel).mockReturnValue(DEFAULT_GEMINI_MODEL);

const decision = await strategy.route(
mockContext,
mockConfig,
mockBaseLlmClient,
);

expect(decision).toBeNull();
expect(mockBaseLlmClient.generateJson).not.toHaveBeenCalled();
});

it('should call generateJson with the correct parameters and wrapped user content', async () => {
const mockApiResponse = {
complexity_reasoning: 'Simple task',
Expand Down Expand Up @@ -119,7 +147,7 @@ describe('NumericalClassifierStrategy', () => {
);

expect(decision).toEqual({
model: DEFAULT_GEMINI_FLASH_MODEL,
model: PREVIEW_GEMINI_FLASH_MODEL,
metadata: {
source: 'NumericalClassifier (Control)',
latencyMs: expect.any(Number),
Expand All @@ -145,7 +173,7 @@ describe('NumericalClassifierStrategy', () => {
);

expect(decision).toEqual({
model: DEFAULT_GEMINI_MODEL,
model: PREVIEW_GEMINI_MODEL,
metadata: {
source: 'NumericalClassifier (Control)',
latencyMs: expect.any(Number),
Expand All @@ -171,7 +199,7 @@ describe('NumericalClassifierStrategy', () => {
);

expect(decision).toEqual({
model: DEFAULT_GEMINI_FLASH_MODEL, // Routed to Flash because 60 < 80
model: PREVIEW_GEMINI_FLASH_MODEL, // Routed to Flash because 60 < 80
metadata: {
source: 'NumericalClassifier (Strict)',
latencyMs: expect.any(Number),
Expand All @@ -197,7 +225,7 @@ describe('NumericalClassifierStrategy', () => {
);

expect(decision).toEqual({
model: DEFAULT_GEMINI_MODEL,
model: PREVIEW_GEMINI_MODEL,
metadata: {
source: 'NumericalClassifier (Strict)',
latencyMs: expect.any(Number),
Expand Down Expand Up @@ -225,7 +253,7 @@ describe('NumericalClassifierStrategy', () => {
);

expect(decision).toEqual({
model: DEFAULT_GEMINI_FLASH_MODEL, // Score 60 < Threshold 70
model: PREVIEW_GEMINI_FLASH_MODEL, // Score 60 < Threshold 70
metadata: {
source: 'NumericalClassifier (Remote)',
latencyMs: expect.any(Number),
Expand All @@ -251,7 +279,7 @@ describe('NumericalClassifierStrategy', () => {
);

expect(decision).toEqual({
model: DEFAULT_GEMINI_FLASH_MODEL, // Score 40 < Threshold 45.5
model: PREVIEW_GEMINI_FLASH_MODEL, // Score 40 < Threshold 45.5
metadata: {
source: 'NumericalClassifier (Remote)',
latencyMs: expect.any(Number),
Expand All @@ -277,7 +305,7 @@ describe('NumericalClassifierStrategy', () => {
);

expect(decision).toEqual({
model: DEFAULT_GEMINI_MODEL, // Score 35 >= Threshold 30
model: PREVIEW_GEMINI_MODEL, // Score 35 >= Threshold 30
metadata: {
source: 'NumericalClassifier (Remote)',
latencyMs: expect.any(Number),
Expand Down Expand Up @@ -305,7 +333,7 @@ describe('NumericalClassifierStrategy', () => {
);

expect(decision).toEqual({
model: DEFAULT_GEMINI_FLASH_MODEL, // Score 40 < Default A/B Threshold 50
model: PREVIEW_GEMINI_FLASH_MODEL, // Score 40 < Default A/B Threshold 50
metadata: {
source: 'NumericalClassifier (Control)',
latencyMs: expect.any(Number),
Expand All @@ -332,7 +360,7 @@ describe('NumericalClassifierStrategy', () => {
);

expect(decision).toEqual({
model: DEFAULT_GEMINI_FLASH_MODEL,
model: PREVIEW_GEMINI_FLASH_MODEL,
metadata: {
source: 'NumericalClassifier (Control)',
latencyMs: expect.any(Number),
Expand All @@ -359,7 +387,7 @@ describe('NumericalClassifierStrategy', () => {
);

expect(decision).toEqual({
model: DEFAULT_GEMINI_MODEL,
model: PREVIEW_GEMINI_MODEL,
metadata: {
source: 'NumericalClassifier (Control)',
latencyMs: expect.any(Number),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import type {
RoutingDecision,
RoutingStrategy,
} from '../routingStrategy.js';
import { resolveClassifierModel } from '../../config/models.js';
import { resolveClassifierModel, isGemini3Model } from '../../config/models.js';
import { createUserContent, Type } from '@google/genai';
import type { Config } from '../../config/config.js';
import { debugLogger } from '../../utils/debugLogger.js';
Expand Down Expand Up @@ -134,10 +134,15 @@ export class NumericalClassifierStrategy implements RoutingStrategy {
): Promise<RoutingDecision | null> {
const startTime = Date.now();
try {
const model = context.requestedModel ?? config.getModel();
if (!(await config.getNumericalRoutingEnabled())) {
return null;
}

if (!isGemini3Model(model)) {
return null;
}

const promptId = getPromptIdWithFallback('classifier-router');

const finalHistory = context.history.slice(-HISTORY_TURNS_FOR_CONTEXT);
Expand Down Expand Up @@ -176,10 +181,7 @@ export class NumericalClassifierStrategy implements RoutingStrategy {
config.getSessionId() || 'unknown-session',
);

const selectedModel = resolveClassifierModel(
config.getModel(),
modelAlias,
);
const selectedModel = resolveClassifierModel(model, modelAlias);

const latencyMs = Date.now() - startTime;

Expand Down
Loading