Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion packages/core/src/config/models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -120,14 +120,29 @@
);
}

/**
* Checks if the model is a Gemini 3 model.
*
* @param model The model name to check.
* @returns True if the model is a Gemini 3 model.
*/
export function isGemini3Model(model: string): boolean {
return (
/^gemini-3(\.|-|$)/.test(model) ||
model === PREVIEW_GEMINI_MODEL_AUTO ||
model === GEMINI_MODEL_ALIAS_PRO ||
model === GEMINI_MODEL_ALIAS_AUTO
);
}

/**
* Checks if the model is a Gemini 2.x model.
*
* @param model The model name to check.
* @returns True if the model is a Gemini-2.x model.
*/
export function isGemini2Model(model: string): boolean {
return /^gemini-2(\.|$)/.test(model);
return /^gemini-2(\.|-|$)/.test(model) || model === DEFAULT_GEMINI_MODEL_AUTO;

Check warning on line 145 in packages/core/src/config/models.ts

View workflow job for this annotation

GitHub Actions / Lint

Found sensitive keyword "gemini-2". Please make sure this change is appropriate to submit.
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@ import type { RoutingContext } from '../routingStrategy.js';
import type { Config } from '../../config/config.js';
import type { BaseLlmClient } from '../../core/baseLlmClient.js';
import {
DEFAULT_GEMINI_FLASH_MODEL,
DEFAULT_GEMINI_MODEL,
PREVIEW_GEMINI_FLASH_MODEL,
PREVIEW_GEMINI_MODEL,
PREVIEW_GEMINI_MODEL_AUTO,
DEFAULT_GEMINI_MODEL_AUTO,
} from '../../config/models.js';
import { promptIdContext } from '../../utils/promptIdContext.js';
Expand Down Expand Up @@ -46,7 +47,7 @@ describe('NumericalClassifierStrategy', () => {
modelConfigService: {
getResolvedConfig: vi.fn().mockReturnValue(mockResolvedConfig),
},
getModel: () => DEFAULT_GEMINI_MODEL_AUTO,
getModel: vi.fn().mockReturnValue(PREVIEW_GEMINI_MODEL_AUTO),
getSessionId: vi.fn().mockReturnValue('control-group-id'), // Default to Control Group (Hash 71 >= 50)
getNumericalRoutingEnabled: vi.fn().mockResolvedValue(true),
getClassifierThreshold: vi.fn().mockResolvedValue(undefined),
Expand Down Expand Up @@ -75,6 +76,19 @@ describe('NumericalClassifierStrategy', () => {
expect(mockBaseLlmClient.generateJson).not.toHaveBeenCalled();
});

it('should return null if the model is not a Gemini 3 model', async () => {
vi.mocked(mockConfig.getModel).mockReturnValue(DEFAULT_GEMINI_MODEL_AUTO);

const decision = await strategy.route(
mockContext,
mockConfig,
mockBaseLlmClient,
);

expect(decision).toBeNull();
expect(mockBaseLlmClient.generateJson).not.toHaveBeenCalled();
});

it('should call generateJson with the correct parameters and wrapped user content', async () => {
const mockApiResponse = {
complexity_reasoning: 'Simple task',
Expand Down Expand Up @@ -119,7 +133,7 @@ describe('NumericalClassifierStrategy', () => {
);

expect(decision).toEqual({
model: DEFAULT_GEMINI_FLASH_MODEL,
model: PREVIEW_GEMINI_FLASH_MODEL,
metadata: {
source: 'NumericalClassifier (Control)',
latencyMs: expect.any(Number),
Expand All @@ -145,7 +159,7 @@ describe('NumericalClassifierStrategy', () => {
);

expect(decision).toEqual({
model: DEFAULT_GEMINI_MODEL,
model: PREVIEW_GEMINI_MODEL,
metadata: {
source: 'NumericalClassifier (Control)',
latencyMs: expect.any(Number),
Expand All @@ -171,7 +185,7 @@ describe('NumericalClassifierStrategy', () => {
);

expect(decision).toEqual({
model: DEFAULT_GEMINI_FLASH_MODEL, // Routed to Flash because 60 < 80
model: PREVIEW_GEMINI_FLASH_MODEL, // Routed to Flash because 60 < 80
metadata: {
source: 'NumericalClassifier (Strict)',
latencyMs: expect.any(Number),
Expand All @@ -197,7 +211,7 @@ describe('NumericalClassifierStrategy', () => {
);

expect(decision).toEqual({
model: DEFAULT_GEMINI_MODEL,
model: PREVIEW_GEMINI_MODEL,
metadata: {
source: 'NumericalClassifier (Strict)',
latencyMs: expect.any(Number),
Expand Down Expand Up @@ -225,7 +239,7 @@ describe('NumericalClassifierStrategy', () => {
);

expect(decision).toEqual({
model: DEFAULT_GEMINI_FLASH_MODEL, // Score 60 < Threshold 70
model: PREVIEW_GEMINI_FLASH_MODEL, // Score 60 < Threshold 70
metadata: {
source: 'NumericalClassifier (Remote)',
latencyMs: expect.any(Number),
Expand All @@ -251,7 +265,7 @@ describe('NumericalClassifierStrategy', () => {
);

expect(decision).toEqual({
model: DEFAULT_GEMINI_FLASH_MODEL, // Score 40 < Threshold 45.5
model: PREVIEW_GEMINI_FLASH_MODEL, // Score 40 < Threshold 45.5
metadata: {
source: 'NumericalClassifier (Remote)',
latencyMs: expect.any(Number),
Expand All @@ -277,7 +291,7 @@ describe('NumericalClassifierStrategy', () => {
);

expect(decision).toEqual({
model: DEFAULT_GEMINI_MODEL, // Score 35 >= Threshold 30
model: PREVIEW_GEMINI_MODEL, // Score 35 >= Threshold 30
metadata: {
source: 'NumericalClassifier (Remote)',
latencyMs: expect.any(Number),
Expand Down Expand Up @@ -305,7 +319,7 @@ describe('NumericalClassifierStrategy', () => {
);

expect(decision).toEqual({
model: DEFAULT_GEMINI_FLASH_MODEL, // Score 40 < Default A/B Threshold 50
model: PREVIEW_GEMINI_FLASH_MODEL, // Score 40 < Default A/B Threshold 50
metadata: {
source: 'NumericalClassifier (Control)',
latencyMs: expect.any(Number),
Expand All @@ -332,7 +346,7 @@ describe('NumericalClassifierStrategy', () => {
);

expect(decision).toEqual({
model: DEFAULT_GEMINI_FLASH_MODEL,
model: PREVIEW_GEMINI_FLASH_MODEL,
metadata: {
source: 'NumericalClassifier (Control)',
latencyMs: expect.any(Number),
Expand All @@ -359,7 +373,7 @@ describe('NumericalClassifierStrategy', () => {
);

expect(decision).toEqual({
model: DEFAULT_GEMINI_MODEL,
model: PREVIEW_GEMINI_MODEL,
metadata: {
source: 'NumericalClassifier (Control)',
latencyMs: expect.any(Number),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import type {
RoutingDecision,
RoutingStrategy,
} from '../routingStrategy.js';
import { resolveClassifierModel } from '../../config/models.js';
import { resolveClassifierModel, isGemini3Model } from '../../config/models.js';
import { createUserContent, Type } from '@google/genai';
import type { Config } from '../../config/config.js';
import { debugLogger } from '../../utils/debugLogger.js';
Expand Down Expand Up @@ -138,6 +138,10 @@ export class NumericalClassifierStrategy implements RoutingStrategy {
return null;
}

if (!isGemini3Model(config.getModel())) {
return null;
}

const promptId = getPromptIdWithFallback('classifier-router');

const finalHistory = context.history.slice(-HISTORY_TURNS_FOR_CONTEXT);
Expand Down
Loading