Skip to content

Commit c14e1ea

Browse files
Vanalitelouis-jan
authored andcommitted
Merge pull request #6715 from menloresearch/fix/get-model-capabilities-correctly
fix: Extract model capabilities correctly for various providers on various platforms
1 parent f537429 commit c14e1ea

File tree

5 files changed

+123
-49
lines changed

5 files changed

+123
-49
lines changed

web-app/src/containers/dialogs/AddModel.tsx

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,7 @@ import { IconPlus } from '@tabler/icons-react'
1515
import { useState } from 'react'
1616
import { getProviderTitle } from '@/lib/utils'
1717
import { useTranslation } from '@/i18n/react-i18next-compat'
18-
import { ModelCapabilities } from '@/types/models'
19-
import { models as providerModels } from 'token.js'
18+
import { getModelCapabilities } from '@/lib/models'
2019
import { toast } from 'sonner'
2120

2221
type DialogAddModelProps = {
@@ -52,23 +51,7 @@ export const DialogAddModel = ({ provider, trigger }: DialogAddModelProps) => {
5251
id: modelId,
5352
model: modelId,
5453
name: modelId,
55-
capabilities: [
56-
ModelCapabilities.COMPLETION,
57-
(
58-
providerModels[
59-
provider.provider as unknown as keyof typeof providerModels
60-
]?.supportsToolCalls as unknown as string[]
61-
)?.includes(modelId)
62-
? ModelCapabilities.TOOLS
63-
: undefined,
64-
(
65-
providerModels[
66-
provider.provider as unknown as keyof typeof providerModels
67-
]?.supportsImages as unknown as string[]
68-
)?.includes(modelId)
69-
? ModelCapabilities.VISION
70-
: undefined,
71-
].filter(Boolean) as string[],
54+
capabilities: getModelCapabilities(provider.provider, modelId),
7255
version: '1.0',
7356
}
7457

web-app/src/lib/__tests__/models.test.ts

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,30 @@ import {
55
removeYamlFrontMatter,
66
extractModelName,
77
extractModelRepo,
8+
getModelCapabilities,
89
} from '../models'
10+
import { ModelCapabilities } from '@/types/models'
911

1012
// Mock the token.js module
1113
vi.mock('token.js', () => ({
1214
models: {
1315
openai: {
1416
models: ['gpt-3.5-turbo', 'gpt-4'],
17+
supportsToolCalls: ['gpt-3.5-turbo', 'gpt-4'],
18+
supportsImages: ['gpt-4-vision-preview'],
1519
},
1620
anthropic: {
1721
models: ['claude-3-sonnet', 'claude-3-haiku'],
22+
supportsToolCalls: ['claude-3-sonnet'],
23+
supportsImages: ['claude-3-sonnet', 'claude-3-haiku'],
1824
},
1925
mistral: {
2026
models: ['mistral-7b', 'mistral-8x7b'],
27+
supportsToolCalls: ['mistral-8x7b'],
28+
},
29+
// Provider with no capability arrays
30+
cohere: {
31+
models: ['command', 'command-light'],
2132
},
2233
},
2334
}))
@@ -223,3 +234,74 @@ describe('extractModelRepo', () => {
223234
)
224235
})
225236
})
237+
238+
describe('getModelCapabilities', () => {
239+
it('returns completion capability for all models', () => {
240+
const capabilities = getModelCapabilities('openai', 'gpt-3.5-turbo')
241+
expect(capabilities).toContain(ModelCapabilities.COMPLETION)
242+
})
243+
244+
it('includes tools capability when model supports it', () => {
245+
const capabilities = getModelCapabilities('openai', 'gpt-3.5-turbo')
246+
expect(capabilities).toContain(ModelCapabilities.TOOLS)
247+
expect(capabilities).toContain(ModelCapabilities.COMPLETION)
248+
})
249+
250+
it('excludes tools capability when model does not support it', () => {
251+
const capabilities = getModelCapabilities('mistral', 'mistral-7b')
252+
expect(capabilities).not.toContain(ModelCapabilities.TOOLS)
253+
expect(capabilities).toContain(ModelCapabilities.COMPLETION)
254+
})
255+
256+
it('includes vision capability when model supports it', () => {
257+
const capabilities = getModelCapabilities('openai', 'gpt-4-vision-preview')
258+
expect(capabilities).toContain(ModelCapabilities.VISION)
259+
expect(capabilities).toContain(ModelCapabilities.COMPLETION)
260+
})
261+
262+
it('excludes vision capability when model does not support it', () => {
263+
const capabilities = getModelCapabilities('openai', 'gpt-3.5-turbo')
264+
expect(capabilities).not.toContain(ModelCapabilities.VISION)
265+
})
266+
267+
it('includes both tools and vision when model supports both', () => {
268+
const capabilities = getModelCapabilities('anthropic', 'claude-3-sonnet')
269+
expect(capabilities).toContain(ModelCapabilities.COMPLETION)
270+
expect(capabilities).toContain(ModelCapabilities.TOOLS)
271+
expect(capabilities).toContain(ModelCapabilities.VISION)
272+
})
273+
274+
it('handles provider with no capability arrays gracefully', () => {
275+
const capabilities = getModelCapabilities('cohere', 'command')
276+
expect(capabilities).toEqual([ModelCapabilities.COMPLETION])
277+
expect(capabilities).not.toContain(ModelCapabilities.TOOLS)
278+
expect(capabilities).not.toContain(ModelCapabilities.VISION)
279+
})
280+
281+
it('handles unknown provider gracefully', () => {
282+
const capabilities = getModelCapabilities('openrouter', 'some-model')
283+
expect(capabilities).toEqual([ModelCapabilities.COMPLETION])
284+
expect(capabilities).not.toContain(ModelCapabilities.TOOLS)
285+
expect(capabilities).not.toContain(ModelCapabilities.VISION)
286+
})
287+
288+
it('handles model not in capability list', () => {
289+
const capabilities = getModelCapabilities('anthropic', 'claude-3-haiku')
290+
expect(capabilities).toContain(ModelCapabilities.COMPLETION)
291+
expect(capabilities).toContain(ModelCapabilities.VISION)
292+
expect(capabilities).not.toContain(ModelCapabilities.TOOLS)
293+
})
294+
295+
it('returns only completion for provider with partial capability data', () => {
296+
// Mistral has supportsToolCalls but no supportsImages
297+
const capabilities = getModelCapabilities('mistral', 'mistral-7b')
298+
expect(capabilities).toEqual([ModelCapabilities.COMPLETION])
299+
})
300+
301+
it('handles model that supports tools but not vision', () => {
302+
const capabilities = getModelCapabilities('mistral', 'mistral-8x7b')
303+
expect(capabilities).toContain(ModelCapabilities.COMPLETION)
304+
expect(capabilities).toContain(ModelCapabilities.TOOLS)
305+
expect(capabilities).not.toContain(ModelCapabilities.VISION)
306+
})
307+
})

web-app/src/lib/models.ts

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import { models } from 'token.js'
2+
import { ModelCapabilities } from '@/types/models'
23

34
export const defaultModel = (provider?: string) => {
45
if (!provider || !Object.keys(models).includes(provider)) {
@@ -10,6 +11,38 @@ export const defaultModel = (provider?: string) => {
1011
)[0]
1112
}
1213

14+
/**
15+
* Determines model capabilities based on provider configuration from token.js
16+
* @param providerName - The provider name (e.g., 'openai', 'anthropic', 'openrouter')
17+
* @param modelId - The model ID to check capabilities for
18+
* @returns Array of model capabilities
19+
*/
20+
export const getModelCapabilities = (
21+
providerName: string,
22+
modelId: string
23+
): string[] => {
24+
const providerConfig =
25+
models[providerName as unknown as keyof typeof models]
26+
27+
const supportsToolCalls = Array.isArray(
28+
providerConfig?.supportsToolCalls as unknown
29+
)
30+
? (providerConfig.supportsToolCalls as unknown as string[])
31+
: []
32+
33+
const supportsImages = Array.isArray(
34+
providerConfig?.supportsImages as unknown
35+
)
36+
? (providerConfig.supportsImages as unknown as string[])
37+
: []
38+
39+
return [
40+
ModelCapabilities.COMPLETION,
41+
supportsToolCalls.includes(modelId) ? ModelCapabilities.TOOLS : undefined,
42+
supportsImages.includes(modelId) ? ModelCapabilities.VISION : undefined,
43+
].filter(Boolean) as string[]
44+
}
45+
1346
/**
1447
* This utility is to extract cortexso model description from README.md file
1548
* @returns

web-app/src/services/providers/tauri.ts

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import { modelSettings } from '@/lib/predefined'
1010
import { ExtensionManager } from '@/lib/extension'
1111
import { fetch as fetchTauri } from '@tauri-apps/plugin-http'
1212
import { DefaultProvidersService } from './default'
13+
import { getModelCapabilities } from '@/lib/models'
1314

1415
export class TauriProvidersService extends DefaultProvidersService {
1516
fetch(): typeof fetch {
@@ -26,32 +27,16 @@ export class TauriProvidersService extends DefaultProvidersService {
2627
provider.provider as unknown as keyof typeof providerModels
2728
].models as unknown as string[]
2829

29-
if (Array.isArray(builtInModels))
30+
if (Array.isArray(builtInModels)) {
3031
models = builtInModels.map((model) => {
3132
const modelManifest = models.find((e) => e.id === model)
3233
// TODO: Check chat_template for tool call support
33-
const capabilities = [
34-
ModelCapabilities.COMPLETION,
35-
(
36-
providerModels[
37-
provider.provider as unknown as keyof typeof providerModels
38-
]?.supportsToolCalls as unknown as string[]
39-
)?.includes(model)
40-
? ModelCapabilities.TOOLS
41-
: undefined,
42-
(
43-
providerModels[
44-
provider.provider as unknown as keyof typeof providerModels
45-
]?.supportsImages as unknown as string[]
46-
)?.includes(model)
47-
? ModelCapabilities.VISION
48-
: undefined,
49-
].filter(Boolean) as string[]
5034
return {
5135
...(modelManifest ?? { id: model, name: model }),
52-
capabilities,
36+
capabilities: getModelCapabilities(provider.provider, model),
5337
} as Model
5438
})
39+
}
5540
}
5641

5742
return {

web-app/src/services/providers/web.ts

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import { ExtensionManager } from '@/lib/extension'
1111
import type { ProvidersService } from './types'
1212
import { PlatformFeatures } from '@/lib/platform/const'
1313
import { PlatformFeature } from '@/lib/platform/types'
14+
import { getModelCapabilities } from '@/lib/models'
1415

1516
export class WebProvidersService implements ProvidersService {
1617
async getProviders(): Promise<ModelProvider[]> {
@@ -88,19 +89,9 @@ export class WebProvidersService implements ProvidersService {
8889
models = builtInModels.map((model) => {
8990
const modelManifest = models.find((e) => e.id === model)
9091
// TODO: Check chat_template for tool call support
91-
const capabilities = [
92-
ModelCapabilities.COMPLETION,
93-
(
94-
providerModels[
95-
provider.provider as unknown as keyof typeof providerModels
96-
]?.supportsToolCalls as unknown as string[]
97-
)?.includes(model)
98-
? ModelCapabilities.TOOLS
99-
: undefined,
100-
].filter(Boolean) as string[]
10192
return {
10293
...(modelManifest ?? { id: model, name: model }),
103-
capabilities,
94+
capabilities: getModelCapabilities(provider.provider, model),
10495
} as Model
10596
})
10697
}

0 commit comments

Comments
 (0)