Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions packages/core/src/config/models.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import { describe, it, expect } from 'vitest';
import {
resolveModel,
resolveClassifierModel,
isGemini3Model,
isGemini2Model,
isAutoModel,
getDisplayString,
Expand All @@ -25,6 +26,29 @@ import {
DEFAULT_GEMINI_MODEL_AUTO,
} from './models.js';

// Exercises isGemini3Model with concrete Gemini 3 names, aliases, Gemini 2
// names, and an unrelated string.
describe('isGemini3Model', () => {
  it('should return true for gemini-3 models', () => {
    const gemini3Names = ['gemini-3-pro-preview', 'gemini-3-flash-preview'];
    for (const name of gemini3Names) {
      expect(isGemini3Model(name)).toBe(true);
    }
  });

  it('should return true for aliases that resolve to Gemini 3', () => {
    // These are alias / auto constants rather than literal "gemini-3-*" names.
    const gemini3Aliases = [
      GEMINI_MODEL_ALIAS_AUTO,
      GEMINI_MODEL_ALIAS_PRO,
      PREVIEW_GEMINI_MODEL_AUTO,
    ];
    for (const alias of gemini3Aliases) {
      expect(isGemini3Model(alias)).toBe(true);
    }
  });

  it('should return false for Gemini 2 models', () => {
    const gemini2Names = [
      'gemini-2.5-pro',
      'gemini-2.5-flash',
      DEFAULT_GEMINI_MODEL_AUTO,
    ];
    for (const name of gemini2Names) {
      expect(isGemini3Model(name)).toBe(false);
    }
  });

  it('should return false for arbitrary strings', () => {
    const unrelatedName = 'gpt-4';
    expect(isGemini3Model(unrelatedName)).toBe(false);
  });
});

describe('getDisplayString', () => {
it('should return Auto (Gemini 3) for preview auto model', () => {
expect(getDisplayString(PREVIEW_GEMINI_MODEL_AUTO)).toBe('Auto (Gemini 3)');
Expand Down
11 changes: 11 additions & 0 deletions packages/core/src/config/models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,17 @@ export function isPreviewModel(model: string): boolean {
);
}

/**
 * Reports whether a model name belongs to the Gemini 3 family.
 *
 * The name is first passed through `resolveModel`; the result counts as
 * Gemini 3 when it starts with `gemini-3` and that prefix is followed by
 * `.`, `-`, or the end of the string.
 *
 * @param model The model name (or alias) to check.
 * @returns True if the resolved name matches the Gemini 3 pattern.
 */
export function isGemini3Model(model: string): boolean {
  const gemini3Pattern = /^gemini-3(\.|-|$)/;
  const concreteName = resolveModel(model);
  return gemini3Pattern.test(concreteName);
}

/**
* Checks if the model is a Gemini 2.x model.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import {
DEFAULT_GEMINI_FLASH_MODEL,
DEFAULT_GEMINI_MODEL,
DEFAULT_GEMINI_MODEL_AUTO,
PREVIEW_GEMINI_MODEL_AUTO,
} from '../../config/models.js';
import { promptIdContext } from '../../utils/promptIdContext.js';
import type { Content } from '@google/genai';
Expand Down Expand Up @@ -50,8 +51,12 @@ describe('ClassifierStrategy', () => {
modelConfigService: {
getResolvedConfig: vi.fn().mockReturnValue(mockResolvedConfig),
},
// Resolved merge conflict: `getModel` must be a vi.fn() because individual
// tests override it with vi.mocked(mockConfig.getModel).mockReturnValue(...).
getModel: vi.fn().mockReturnValue(DEFAULT_GEMINI_MODEL_AUTO),
// Kept from the HEAD side: ClassifierStrategy.route still passes
// config.getPreviewFeatures() to resolveClassifierModel.
getPreviewFeatures: () => false,
getNumericalRoutingEnabled: vi.fn().mockResolvedValue(false),
} as unknown as Config;
mockBaseLlmClient = {
Expand All @@ -61,8 +66,9 @@ describe('ClassifierStrategy', () => {
vi.spyOn(promptIdContext, 'getStore').mockReturnValue('test-prompt-id');
});

it('should return null if numerical routing is enabled', async () => {
it('should return null if numerical routing is enabled and model is Gemini 3', async () => {
vi.mocked(mockConfig.getNumericalRoutingEnabled).mockResolvedValue(true);
vi.mocked(mockConfig.getModel).mockReturnValue(PREVIEW_GEMINI_MODEL_AUTO);

const decision = await strategy.route(
mockContext,
Expand All @@ -74,6 +80,24 @@ describe('ClassifierStrategy', () => {
expect(mockBaseLlmClient.generateJson).not.toHaveBeenCalled();
});

// Numerical routing is on, but the active model is the default (non-preview)
// auto model, so this strategy is expected to run the classifier itself.
it('should NOT return null if numerical routing is enabled but model is NOT Gemini 3', async () => {
  vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue({
    reasoning: 'test',
    model_choice: 'flash',
  });
  vi.mocked(mockConfig.getModel).mockReturnValue(DEFAULT_GEMINI_MODEL_AUTO);
  vi.mocked(mockConfig.getNumericalRoutingEnabled).mockResolvedValue(true);

  const routingResult = await strategy.route(
    mockContext,
    mockConfig,
    mockBaseLlmClient,
  );

  expect(routingResult).not.toBeNull();
  expect(mockBaseLlmClient.generateJson).toHaveBeenCalled();
});

it('should call generateJson with the correct parameters', async () => {
const mockApiResponse = {
reasoning: 'Simple task',
Expand Down
10 changes: 7 additions & 3 deletions packages/core/src/routing/strategies/classifierStrategy.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import type {
RoutingDecision,
RoutingStrategy,
} from '../routingStrategy.js';
import { resolveClassifierModel } from '../../config/models.js';
import { resolveClassifierModel, isGemini3Model } from '../../config/models.js';
import { createUserContent, Type } from '@google/genai';
import type { Config } from '../../config/config.js';
import {
Expand Down Expand Up @@ -133,7 +133,11 @@ export class ClassifierStrategy implements RoutingStrategy {
): Promise<RoutingDecision | null> {
const startTime = Date.now();
try {
if (await config.getNumericalRoutingEnabled()) {
const model = context.requestedModel ?? config.getModel();
if (
(await config.getNumericalRoutingEnabled()) &&
isGemini3Model(model)
) {
return null;
}

Expand Down Expand Up @@ -164,7 +168,7 @@ export class ClassifierStrategy implements RoutingStrategy {
const reasoning = routerResponse.reasoning;
const latencyMs = Date.now() - startTime;
const selectedModel = resolveClassifierModel(
context.requestedModel ?? config.getModel(),
model,
routerResponse.model_choice,
config.getPreviewFeatures(),
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@ import type { RoutingContext } from '../routingStrategy.js';
import type { Config } from '../../config/config.js';
import type { BaseLlmClient } from '../../core/baseLlmClient.js';
import {
DEFAULT_GEMINI_FLASH_MODEL,
DEFAULT_GEMINI_MODEL,
PREVIEW_GEMINI_FLASH_MODEL,
PREVIEW_GEMINI_MODEL,
PREVIEW_GEMINI_MODEL_AUTO,
DEFAULT_GEMINI_MODEL_AUTO,
DEFAULT_GEMINI_MODEL,
} from '../../config/models.js';
import { promptIdContext } from '../../utils/promptIdContext.js';
import type { Content } from '@google/genai';
Expand Down Expand Up @@ -46,8 +48,12 @@ describe('NumericalClassifierStrategy', () => {
modelConfigService: {
getResolvedConfig: vi.fn().mockReturnValue(mockResolvedConfig),
},
// Resolved merge conflict: `getModel` must be a vi.fn() because individual
// tests override it with vi.mocked(mockConfig.getModel).mockReturnValue(...).
getModel: vi.fn().mockReturnValue(PREVIEW_GEMINI_MODEL_AUTO),
// Kept from the HEAD side so the mock still answers getPreviewFeatures()
// if the strategy queries it. NOTE(review): confirm `false` is the intended
// value for the preview-model expectations in this suite.
getPreviewFeatures: () => false,
getSessionId: vi.fn().mockReturnValue('control-group-id'), // Default to Control Group (Hash 71 >= 50)
getNumericalRoutingEnabled: vi.fn().mockResolvedValue(true),
getClassifierThreshold: vi.fn().mockResolvedValue(undefined),
Expand Down Expand Up @@ -76,6 +82,32 @@ describe('NumericalClassifierStrategy', () => {
expect(mockBaseLlmClient.generateJson).not.toHaveBeenCalled();
});

// With the default auto model active, the numerical strategy should decline
// to route and never reach the LLM client.
it('should return null if the model is not a Gemini 3 model', async () => {
  const getModelMock = vi.mocked(mockConfig.getModel);
  getModelMock.mockReturnValue(DEFAULT_GEMINI_MODEL_AUTO);

  const routingResult = await strategy.route(
    mockContext,
    mockConfig,
    mockBaseLlmClient,
  );

  expect(routingResult).toBeNull();
  expect(mockBaseLlmClient.generateJson).not.toHaveBeenCalled();
});

// Same expectation as for the default auto model, but with the concrete
// DEFAULT_GEMINI_MODEL selected instead of an auto alias.
it('should return null if the model is explicitly a Gemini 2 model', async () => {
  const getModelMock = vi.mocked(mockConfig.getModel);
  getModelMock.mockReturnValue(DEFAULT_GEMINI_MODEL);

  const routingResult = await strategy.route(
    mockContext,
    mockConfig,
    mockBaseLlmClient,
  );

  expect(routingResult).toBeNull();
  expect(mockBaseLlmClient.generateJson).not.toHaveBeenCalled();
});

it('should call generateJson with the correct parameters and wrapped user content', async () => {
const mockApiResponse = {
complexity_reasoning: 'Simple task',
Expand Down Expand Up @@ -120,7 +152,7 @@ describe('NumericalClassifierStrategy', () => {
);

expect(decision).toEqual({
model: DEFAULT_GEMINI_FLASH_MODEL,
model: PREVIEW_GEMINI_FLASH_MODEL,
metadata: {
source: 'NumericalClassifier (Control)',
latencyMs: expect.any(Number),
Expand All @@ -146,7 +178,7 @@ describe('NumericalClassifierStrategy', () => {
);

expect(decision).toEqual({
model: DEFAULT_GEMINI_MODEL,
model: PREVIEW_GEMINI_MODEL,
metadata: {
source: 'NumericalClassifier (Control)',
latencyMs: expect.any(Number),
Expand All @@ -172,7 +204,7 @@ describe('NumericalClassifierStrategy', () => {
);

expect(decision).toEqual({
model: DEFAULT_GEMINI_FLASH_MODEL, // Routed to Flash because 60 < 80
model: PREVIEW_GEMINI_FLASH_MODEL, // Routed to Flash because 60 < 80
metadata: {
source: 'NumericalClassifier (Strict)',
latencyMs: expect.any(Number),
Expand All @@ -198,7 +230,7 @@ describe('NumericalClassifierStrategy', () => {
);

expect(decision).toEqual({
model: DEFAULT_GEMINI_MODEL,
model: PREVIEW_GEMINI_MODEL,
metadata: {
source: 'NumericalClassifier (Strict)',
latencyMs: expect.any(Number),
Expand Down Expand Up @@ -226,7 +258,7 @@ describe('NumericalClassifierStrategy', () => {
);

expect(decision).toEqual({
model: DEFAULT_GEMINI_FLASH_MODEL, // Score 60 < Threshold 70
model: PREVIEW_GEMINI_FLASH_MODEL, // Score 60 < Threshold 70
metadata: {
source: 'NumericalClassifier (Remote)',
latencyMs: expect.any(Number),
Expand All @@ -252,7 +284,7 @@ describe('NumericalClassifierStrategy', () => {
);

expect(decision).toEqual({
model: DEFAULT_GEMINI_FLASH_MODEL, // Score 40 < Threshold 45.5
model: PREVIEW_GEMINI_FLASH_MODEL, // Score 40 < Threshold 45.5
metadata: {
source: 'NumericalClassifier (Remote)',
latencyMs: expect.any(Number),
Expand All @@ -278,7 +310,7 @@ describe('NumericalClassifierStrategy', () => {
);

expect(decision).toEqual({
model: DEFAULT_GEMINI_MODEL, // Score 35 >= Threshold 30
model: PREVIEW_GEMINI_MODEL, // Score 35 >= Threshold 30
metadata: {
source: 'NumericalClassifier (Remote)',
latencyMs: expect.any(Number),
Expand Down Expand Up @@ -306,7 +338,7 @@ describe('NumericalClassifierStrategy', () => {
);

expect(decision).toEqual({
model: DEFAULT_GEMINI_FLASH_MODEL, // Score 40 < Default A/B Threshold 50
model: PREVIEW_GEMINI_FLASH_MODEL, // Score 40 < Default A/B Threshold 50
metadata: {
source: 'NumericalClassifier (Control)',
latencyMs: expect.any(Number),
Expand All @@ -333,7 +365,7 @@ describe('NumericalClassifierStrategy', () => {
);

expect(decision).toEqual({
model: DEFAULT_GEMINI_FLASH_MODEL,
model: PREVIEW_GEMINI_FLASH_MODEL,
metadata: {
source: 'NumericalClassifier (Control)',
latencyMs: expect.any(Number),
Expand All @@ -360,7 +392,7 @@ describe('NumericalClassifierStrategy', () => {
);

expect(decision).toEqual({
model: DEFAULT_GEMINI_MODEL,
model: PREVIEW_GEMINI_MODEL,
metadata: {
source: 'NumericalClassifier (Control)',
latencyMs: expect.any(Number),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import type {
RoutingDecision,
RoutingStrategy,
} from '../routingStrategy.js';
import { resolveClassifierModel } from '../../config/models.js';
import { resolveClassifierModel, isGemini3Model } from '../../config/models.js';
import { createUserContent, Type } from '@google/genai';
import type { Config } from '../../config/config.js';
import { debugLogger } from '../../utils/debugLogger.js';
Expand Down Expand Up @@ -134,10 +134,15 @@ export class NumericalClassifierStrategy implements RoutingStrategy {
): Promise<RoutingDecision | null> {
const startTime = Date.now();
try {
const model = context.requestedModel ?? config.getModel();
if (!(await config.getNumericalRoutingEnabled())) {
return null;
}

if (!isGemini3Model(model)) {
return null;
}

const promptId = getPromptIdWithFallback('classifier-router');

const finalHistory = context.history.slice(-HISTORY_TURNS_FOR_CONTEXT);
Expand Down Expand Up @@ -176,11 +181,15 @@ export class NumericalClassifierStrategy implements RoutingStrategy {
config.getSessionId() || 'unknown-session',
);

// Resolved merge conflict. `model` is the value computed at the top of this
// try block (context.requestedModel ?? config.getModel()), i.e. the same
// model that passed the isGemini3Model gate, so the routing decision honours
// a per-request model override. The preview-features flag is still passed,
// matching the call in ClassifierStrategy.
const selectedModel = resolveClassifierModel(
  model,
  modelAlias,
  config.getPreviewFeatures(),
);

const latencyMs = Date.now() - startTime;

Expand Down
Loading