Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 13 additions & 39 deletions web-app/src/hooks/useModelProvider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,123 +30,123 @@
selectedModel: null,
deletedModels: [],
getModelBy: (modelId: string) => {
const provider = get().providers.find(
(provider) => provider.provider === get().selectedProvider
)
if (!provider) return undefined
return provider.models.find((model) => model.id === modelId)
},

Check warning on line 38 in web-app/src/hooks/useModelProvider.ts

View workflow job for this annotation

GitHub Actions / coverage-check

33-38 lines are not covered with tests
setProviders: (providers) =>
set((state) => {
const existingProviders = state.providers

Check warning on line 41 in web-app/src/hooks/useModelProvider.ts

View workflow job for this annotation

GitHub Actions / coverage-check

40-41 lines are not covered with tests
// Filter out legacy llama.cpp provider for migration
// Can remove after a couple of releases
.filter((e) => e.provider !== 'llama.cpp')
.map((provider) => {
return {
...provider,
models: provider.models.filter(
(e) =>
('id' in e || 'model' in e) &&
typeof (e.id ?? e.model) === 'string'
),
}
})

Check warning on line 54 in web-app/src/hooks/useModelProvider.ts

View workflow job for this annotation

GitHub Actions / coverage-check

44-54 lines are not covered with tests

let legacyModels: Model[] | undefined = []

Check warning on line 56 in web-app/src/hooks/useModelProvider.ts

View workflow job for this annotation

GitHub Actions / coverage-check

56 line is not covered with tests
/// Cortex Migration
if (
localStorage.getItem('cortex_model_settings_migrated') !== 'true'
) {
legacyModels = state.providers.find(
(e) => e.provider === 'llama.cpp'
)?.models
localStorage.setItem('cortex_model_settings_migrated', 'true')
}

Check warning on line 65 in web-app/src/hooks/useModelProvider.ts

View workflow job for this annotation

GitHub Actions / coverage-check

58-65 lines are not covered with tests
// Ensure deletedModels is always an array
const currentDeletedModels = Array.isArray(state.deletedModels)
? state.deletedModels
: []

Check warning on line 69 in web-app/src/hooks/useModelProvider.ts

View workflow job for this annotation

GitHub Actions / coverage-check

67-69 lines are not covered with tests

const updatedProviders = providers.map((provider) => {
const existingProvider = existingProviders.find(
(x) => x.provider === provider.provider
)
const models = (existingProvider?.models || []).filter(
(e) =>
('id' in e || 'model' in e) &&
typeof (e.id ?? e.model) === 'string'
)
const mergedModels = [
...(provider?.models ?? []).filter(
(e) =>
('id' in e || 'model' in e) &&
typeof (e.id ?? e.model) === 'string' &&
!models.some((m) => m.id === e.id) &&
!currentDeletedModels.includes(e.id)
),
...models,
]
const updatedModels = provider.models?.map((model) => {
const settings =
(legacyModels && legacyModels?.length > 0
? legacyModels
: models
).find(
(m) => m.id.split(':').slice(0, 2).join(sep()) === model.id
)?.settings || model.settings
const existingModel = models.find((m) => m.id === model.id)
return {
...model,
settings: settings,
capabilities: existingModel?.capabilities || model.capabilities,
}
})

Check warning on line 104 in web-app/src/hooks/useModelProvider.ts

View workflow job for this annotation

GitHub Actions / coverage-check

71-104 lines are not covered with tests

return {
...provider,
models: provider.persist ? updatedModels : mergedModels,
settings: provider.settings.map((setting) => {
const existingSetting = provider.persist
? undefined
: existingProvider?.settings?.find(
(x) => x.key === setting.key
)
return {
...setting,
controller_props: {
...setting.controller_props,
...(existingSetting?.controller_props || {}),
},
}
}),
api_key: existingProvider?.api_key || provider.api_key,
base_url: existingProvider?.base_url || provider.base_url,
active: existingProvider ? existingProvider?.active : true,
}
})
return {
providers: [
...updatedProviders,
...existingProviders.filter(
(e) => !updatedProviders.some((p) => p.provider === e.provider)
),
],
}
}),

Check warning on line 136 in web-app/src/hooks/useModelProvider.ts

View workflow job for this annotation

GitHub Actions / coverage-check

106-136 lines are not covered with tests
updateProvider: (providerName, data) => {
set((state) => ({
providers: state.providers.map((provider) => {
if (provider.provider === providerName) {
return {
...provider,
...data,
}
}
return provider
}),
}))
},

Check warning on line 149 in web-app/src/hooks/useModelProvider.ts

View workflow job for this annotation

GitHub Actions / coverage-check

138-149 lines are not covered with tests
getProviderByName: (providerName: string) => {
const provider = get().providers.find(
(provider) => provider.provider === providerName
Expand All @@ -156,9 +156,9 @@
},
selectModelProvider: (providerName: string, modelName: string) => {
// Find the model object
const provider = get().providers.find(
(provider) => provider.provider === providerName
)

Check warning on line 161 in web-app/src/hooks/useModelProvider.ts

View workflow job for this annotation

GitHub Actions / coverage-check

159-161 lines are not covered with tests

let modelObject: Model | undefined = undefined

Expand Down Expand Up @@ -227,42 +227,31 @@
>
}

// Migration for cont_batching description update (version 0 -> 1)
if (version === 0 && state?.providers) {
state.providers = state.providers.map((provider) => {
state.providers.forEach((provider) => {
// Update cont_batching description for llamacpp provider
if (provider.provider === 'llamacpp' && provider.settings) {
provider.settings = provider.settings.map((setting) => {
if (setting.key === 'cont_batching') {
return {
...setting,
description:
'Enable continuous batching (a.k.a dynamic batching) for concurrent requests.',
}
}
return setting
})
const contBatchingSetting = provider.settings.find(
(s) => s.key === 'cont_batching'
)
if (contBatchingSetting) {
contBatchingSetting.description =
'Enable continuous batching (a.k.a dynamic batching) for concurrent requests.'
}
}
return provider
})
}

// Migration for chatTemplate key to chat_template (version 1 -> 2)
if (version === 1 && state?.providers) {
state.providers.forEach((provider) => {
// Migrate model settings
if (provider.models) {
provider.models.forEach((model) => {
// Initialize settings if it doesn't exist
if (!model.settings) {
model.settings = {}
}
if (!model.settings) model.settings = {}

// Migrate chatTemplate key to chat_template
if (model.settings.chatTemplate) {
model.settings.chat_template = model.settings.chatTemplate
delete model.settings.chatTemplate
}

// Add missing chat_template setting if it doesn't exist
// Add missing settings with defaults
if (!model.settings.chat_template) {
model.settings.chat_template = {
...modelSettings.chatTemplate,
Expand All @@ -271,22 +260,7 @@
},
}
}
})
}
})
}

// Migration for override_tensor_buffer_type key (version 2 -> 3)
if (version === 2 && state?.providers) {
state.providers.forEach((provider) => {
if (provider.models) {
provider.models.forEach((model) => {
// Initialize settings if it doesn't exist
if (!model.settings) {
model.settings = {}
}

// Add missing override_tensor_buffer_type setting if it doesn't exist
if (!model.settings.override_tensor_buffer_t) {
model.settings.override_tensor_buffer_t = {
...modelSettings.override_tensor_buffer_t,
Expand All @@ -303,7 +277,7 @@

return state
},
version: 3,
version: 1,
}
)
)
Loading