Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 2 additions & 19 deletions web-app/src/containers/dialogs/AddModel.tsx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import {

Check warning on line 1 in web-app/src/containers/dialogs/AddModel.tsx

View workflow job for this annotation

GitHub Actions / coverage-check

1 line is not covered with tests
Dialog,
DialogContent,
DialogDescription,
Expand All @@ -7,90 +7,73 @@
DialogTrigger,
DialogFooter,
} from '@/components/ui/dialog'
import { Button } from '@/components/ui/button'
import { useModelProvider } from '@/hooks/useModelProvider'
import { useProviderModels } from '@/hooks/useProviderModels'
import { ModelCombobox } from '@/containers/ModelCombobox'
import { IconPlus } from '@tabler/icons-react'
import { useState } from 'react'
import { getProviderTitle } from '@/lib/utils'
import { useTranslation } from '@/i18n/react-i18next-compat'
import { ModelCapabilities } from '@/types/models'
import { models as providerModels } from 'token.js'
import { getModelCapabilities } from '@/lib/models'
import { toast } from 'sonner'

Check warning on line 19 in web-app/src/containers/dialogs/AddModel.tsx

View workflow job for this annotation

GitHub Actions / coverage-check

10-19 lines are not covered with tests

type DialogAddModelProps = {
provider: ModelProvider
trigger?: React.ReactNode
}

export const DialogAddModel = ({ provider, trigger }: DialogAddModelProps) => {
const { t } = useTranslation()
const { updateProvider } = useModelProvider()
const [modelId, setModelId] = useState<string>('')
const [open, setOpen] = useState(false)
const [isComboboxOpen, setIsComboboxOpen] = useState(false)

Check warning on line 31 in web-app/src/containers/dialogs/AddModel.tsx

View workflow job for this annotation

GitHub Actions / coverage-check

26-31 lines are not covered with tests

// Fetch models from provider API (API key is optional)
const { models, loading, error, refetch } = useProviderModels(
provider.base_url ? provider : undefined
)

Check warning on line 36 in web-app/src/containers/dialogs/AddModel.tsx

View workflow job for this annotation

GitHub Actions / coverage-check

34-36 lines are not covered with tests

// Handle form submission
const handleSubmit = () => {
if (!modelId.trim()) return // Don't submit if model ID is empty

Check warning on line 40 in web-app/src/containers/dialogs/AddModel.tsx

View workflow job for this annotation

GitHub Actions / coverage-check

39-40 lines are not covered with tests

if (provider.models.some((e) => e.id === modelId)) {
toast.error(t('providers:addModel.modelExists'), {
description: t('providers:addModel.modelExistsDesc'),
})
return // Don't submit if model ID already exists
}

Check warning on line 47 in web-app/src/containers/dialogs/AddModel.tsx

View workflow job for this annotation

GitHub Actions / coverage-check

42-47 lines are not covered with tests

// Create the new model
const newModel = {
id: modelId,
model: modelId,
name: modelId,
capabilities: [
ModelCapabilities.COMPLETION,
(
providerModels[
provider.provider as unknown as keyof typeof providerModels
]?.supportsToolCalls as unknown as string[]
)?.includes(modelId)
? ModelCapabilities.TOOLS
: undefined,
(
providerModels[
provider.provider as unknown as keyof typeof providerModels
]?.supportsImages as unknown as string[]
)?.includes(modelId)
? ModelCapabilities.VISION
: undefined,
].filter(Boolean) as string[],
capabilities: getModelCapabilities(provider.provider, modelId),
version: '1.0',
}

Check warning on line 56 in web-app/src/containers/dialogs/AddModel.tsx

View workflow job for this annotation

GitHub Actions / coverage-check

50-56 lines are not covered with tests

// Update the provider with the new model
const updatedModels = [...provider.models, newModel]
updateProvider(provider.provider, {
...provider,
models: updatedModels,
})

Check warning on line 63 in web-app/src/containers/dialogs/AddModel.tsx

View workflow job for this annotation

GitHub Actions / coverage-check

59-63 lines are not covered with tests

// Reset form and close dialog
setModelId('')
setOpen(false)
}

Check warning on line 68 in web-app/src/containers/dialogs/AddModel.tsx

View workflow job for this annotation

GitHub Actions / coverage-check

66-68 lines are not covered with tests

return (
<Dialog open={open} onOpenChange={setOpen}>
<DialogTrigger asChild>
{trigger || (
<div className="size-6 cursor-pointer flex items-center justify-center rounded hover:bg-main-view-fg/10 transition-all duration-200 ease-in-out">
<IconPlus size={18} className="text-main-view-fg/50" />
</div>

Check warning on line 76 in web-app/src/containers/dialogs/AddModel.tsx

View workflow job for this annotation

GitHub Actions / coverage-check

70-76 lines are not covered with tests
)}
</DialogTrigger>
<DialogContent
Expand Down
82 changes: 82 additions & 0 deletions web-app/src/lib/__tests__/models.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,30 @@ import {
removeYamlFrontMatter,
extractModelName,
extractModelRepo,
getModelCapabilities,
} from '../models'
import { ModelCapabilities } from '@/types/models'

// Mock the token.js module
vi.mock('token.js', () => ({
models: {
openai: {
models: ['gpt-3.5-turbo', 'gpt-4'],
supportsToolCalls: ['gpt-3.5-turbo', 'gpt-4'],
supportsImages: ['gpt-4-vision-preview'],
},
anthropic: {
models: ['claude-3-sonnet', 'claude-3-haiku'],
supportsToolCalls: ['claude-3-sonnet'],
supportsImages: ['claude-3-sonnet', 'claude-3-haiku'],
},
mistral: {
models: ['mistral-7b', 'mistral-8x7b'],
supportsToolCalls: ['mistral-8x7b'],
},
// Provider with no capability arrays
cohere: {
models: ['command', 'command-light'],
},
},
}))
Expand Down Expand Up @@ -223,3 +234,74 @@ describe('extractModelRepo', () => {
)
})
})

describe('getModelCapabilities', () => {
it('returns completion capability for all models', () => {
const capabilities = getModelCapabilities('openai', 'gpt-3.5-turbo')
expect(capabilities).toContain(ModelCapabilities.COMPLETION)
})

it('includes tools capability when model supports it', () => {
const capabilities = getModelCapabilities('openai', 'gpt-3.5-turbo')
expect(capabilities).toContain(ModelCapabilities.TOOLS)
expect(capabilities).toContain(ModelCapabilities.COMPLETION)
})

it('excludes tools capability when model does not support it', () => {
const capabilities = getModelCapabilities('mistral', 'mistral-7b')
expect(capabilities).not.toContain(ModelCapabilities.TOOLS)
expect(capabilities).toContain(ModelCapabilities.COMPLETION)
})

it('includes vision capability when model supports it', () => {
const capabilities = getModelCapabilities('openai', 'gpt-4-vision-preview')
expect(capabilities).toContain(ModelCapabilities.VISION)
expect(capabilities).toContain(ModelCapabilities.COMPLETION)
})

it('excludes vision capability when model does not support it', () => {
const capabilities = getModelCapabilities('openai', 'gpt-3.5-turbo')
expect(capabilities).not.toContain(ModelCapabilities.VISION)
})

it('includes both tools and vision when model supports both', () => {
const capabilities = getModelCapabilities('anthropic', 'claude-3-sonnet')
expect(capabilities).toContain(ModelCapabilities.COMPLETION)
expect(capabilities).toContain(ModelCapabilities.TOOLS)
expect(capabilities).toContain(ModelCapabilities.VISION)
})

it('handles provider with no capability arrays gracefully', () => {
const capabilities = getModelCapabilities('cohere', 'command')
expect(capabilities).toEqual([ModelCapabilities.COMPLETION])
expect(capabilities).not.toContain(ModelCapabilities.TOOLS)
expect(capabilities).not.toContain(ModelCapabilities.VISION)
})

it('handles unknown provider gracefully', () => {
const capabilities = getModelCapabilities('openrouter', 'some-model')
expect(capabilities).toEqual([ModelCapabilities.COMPLETION])
expect(capabilities).not.toContain(ModelCapabilities.TOOLS)
expect(capabilities).not.toContain(ModelCapabilities.VISION)
})

it('handles model not in capability list', () => {
const capabilities = getModelCapabilities('anthropic', 'claude-3-haiku')
expect(capabilities).toContain(ModelCapabilities.COMPLETION)
expect(capabilities).toContain(ModelCapabilities.VISION)
expect(capabilities).not.toContain(ModelCapabilities.TOOLS)
})

it('returns only completion for provider with partial capability data', () => {
// Mistral has supportsToolCalls but no supportsImages
const capabilities = getModelCapabilities('mistral', 'mistral-7b')
expect(capabilities).toEqual([ModelCapabilities.COMPLETION])
})

it('handles model that supports tools but not vision', () => {
const capabilities = getModelCapabilities('mistral', 'mistral-8x7b')
expect(capabilities).toContain(ModelCapabilities.COMPLETION)
expect(capabilities).toContain(ModelCapabilities.TOOLS)
expect(capabilities).not.toContain(ModelCapabilities.VISION)
})
})
33 changes: 33 additions & 0 deletions web-app/src/lib/models.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { models } from 'token.js'
import { ModelCapabilities } from '@/types/models'

export const defaultModel = (provider?: string) => {
if (!provider || !Object.keys(models).includes(provider)) {
Expand All @@ -10,6 +11,38 @@ export const defaultModel = (provider?: string) => {
)[0]
}

/**
* Determines model capabilities based on provider configuration from token.js
* @param providerName - The provider name (e.g., 'openai', 'anthropic', 'openrouter')
* @param modelId - The model ID to check capabilities for
* @returns Array of model capabilities
*/
export const getModelCapabilities = (
providerName: string,
modelId: string
): string[] => {
const providerConfig =
models[providerName as unknown as keyof typeof models]

const supportsToolCalls = Array.isArray(
providerConfig?.supportsToolCalls as unknown
)
? (providerConfig.supportsToolCalls as unknown as string[])
: []

const supportsImages = Array.isArray(
providerConfig?.supportsImages as unknown
)
? (providerConfig.supportsImages as unknown as string[])
: []

return [
ModelCapabilities.COMPLETION,
supportsToolCalls.includes(modelId) ? ModelCapabilities.TOOLS : undefined,
supportsImages.includes(modelId) ? ModelCapabilities.VISION : undefined,
].filter(Boolean) as string[]
}

/**
* This utility is to extract cortexso model description from README.md file
* @returns
Expand Down
23 changes: 4 additions & 19 deletions web-app/src/services/providers/tauri.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import { modelSettings } from '@/lib/predefined'
import { ExtensionManager } from '@/lib/extension'
import { fetch as fetchTauri } from '@tauri-apps/plugin-http'
import { DefaultProvidersService } from './default'
import { getModelCapabilities } from '@/lib/models'

export class TauriProvidersService extends DefaultProvidersService {
fetch(): typeof fetch {
Expand All @@ -26,32 +27,16 @@ export class TauriProvidersService extends DefaultProvidersService {
provider.provider as unknown as keyof typeof providerModels
].models as unknown as string[]

if (Array.isArray(builtInModels))
if (Array.isArray(builtInModels)) {
models = builtInModels.map((model) => {
const modelManifest = models.find((e) => e.id === model)
// TODO: Check chat_template for tool call support
const capabilities = [
ModelCapabilities.COMPLETION,
(
providerModels[
provider.provider as unknown as keyof typeof providerModels
]?.supportsToolCalls as unknown as string[]
)?.includes(model)
? ModelCapabilities.TOOLS
: undefined,
(
providerModels[
provider.provider as unknown as keyof typeof providerModels
]?.supportsImages as unknown as string[]
)?.includes(model)
? ModelCapabilities.VISION
: undefined,
].filter(Boolean) as string[]
return {
...(modelManifest ?? { id: model, name: model }),
capabilities,
capabilities: getModelCapabilities(provider.provider, model),
} as Model
})
}
}

return {
Expand Down
13 changes: 2 additions & 11 deletions web-app/src/services/providers/web.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import { ExtensionManager } from '@/lib/extension'
import type { ProvidersService } from './types'
import { PlatformFeatures } from '@/lib/platform/const'
import { PlatformFeature } from '@/lib/platform/types'
import { getModelCapabilities } from '@/lib/models'

export class WebProvidersService implements ProvidersService {
async getProviders(): Promise<ModelProvider[]> {
Expand Down Expand Up @@ -88,19 +89,9 @@ export class WebProvidersService implements ProvidersService {
models = builtInModels.map((model) => {
const modelManifest = models.find((e) => e.id === model)
// TODO: Check chat_template for tool call support
const capabilities = [
ModelCapabilities.COMPLETION,
(
providerModels[
provider.provider as unknown as keyof typeof providerModels
]?.supportsToolCalls as unknown as string[]
)?.includes(model)
? ModelCapabilities.TOOLS
: undefined,
].filter(Boolean) as string[]
return {
...(modelManifest ?? { id: model, name: model }),
capabilities,
capabilities: getModelCapabilities(provider.provider, model),
} as Model
})
}
Expand Down
Loading