Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion core/src/browser/extensions/engines/AIEngine.ts
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ export abstract class AIEngine extends BaseExtension {
/**
* Loads a model into memory
*/
abstract load(modelId: string): Promise<SessionInfo>
abstract load(modelId: string, settings?: any): Promise<SessionInfo>

/**
* Unloads a model from memory
Expand Down
10 changes: 7 additions & 3 deletions web-app/src/containers/ModelSetting.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -23,63 +23,63 @@
}

export function ModelSetting({
model,
provider,
smallIcon,
}: ModelSettingProps) {
const { updateProvider } = useModelProvider()
const { t } = useTranslation()

Check warning on line 31 in web-app/src/containers/ModelSetting.tsx

View workflow job for this annotation

GitHub Actions / coverage-check

26-31 lines are not covered with tests

// Create a debounced version of stopModel that waits 500ms after the last call.
// NOTE(review): this is declared directly in the ModelSetting component body, so
// `debounce(...)` is invoked on every render and each render gets a new debounced
// function. Confirm whether the 500ms window is meant to survive re-renders; if so,
// the instance likely needs to be memoized (e.g. useMemo/useRef) — TODO verify.
const debouncedStopModel = debounce((modelId: string) => {
stopModel(modelId)
}, 500)

Check warning on line 36 in web-app/src/containers/ModelSetting.tsx

View workflow job for this annotation

GitHub Actions / coverage-check

34-36 lines are not covered with tests

const handleSettingChange = (
key: string,
value: string | boolean | number
) => {
if (!provider) return

Check warning on line 42 in web-app/src/containers/ModelSetting.tsx

View workflow job for this annotation

GitHub Actions / coverage-check

38-42 lines are not covered with tests

// Create a copy of the model with updated settings
const updatedModel = {
...model,
settings: {
...model.settings,
[key]: {
...(model.settings?.[key] != null ? model.settings?.[key] : {}),
controller_props: {
...(model.settings?.[key]?.controller_props ?? {}),
value: value,
},
},
},
}

Check warning on line 57 in web-app/src/containers/ModelSetting.tsx

View workflow job for this annotation

GitHub Actions / coverage-check

45-57 lines are not covered with tests

// Find the model index in the provider's models array
const modelIndex = provider.models.findIndex((m) => m.id === model.id)

Check warning on line 60 in web-app/src/containers/ModelSetting.tsx

View workflow job for this annotation

GitHub Actions / coverage-check

Line 60 is not covered with tests

if (modelIndex !== -1) {

Check warning on line 62 in web-app/src/containers/ModelSetting.tsx

View workflow job for this annotation

GitHub Actions / coverage-check

Line 62 is not covered with tests
// Create a copy of the provider's models array
const updatedModels = [...provider.models]

Check warning on line 64 in web-app/src/containers/ModelSetting.tsx

View workflow job for this annotation

GitHub Actions / coverage-check

Line 64 is not covered with tests

// Update the specific model in the array
updatedModels[modelIndex] = updatedModel as Model

Check warning on line 67 in web-app/src/containers/ModelSetting.tsx

View workflow job for this annotation

GitHub Actions / coverage-check

Line 67 is not covered with tests

// Update the provider with the new models array
updateProvider(provider.provider, {
models: updatedModels,
})

Check warning on line 72 in web-app/src/containers/ModelSetting.tsx

View workflow job for this annotation

GitHub Actions / coverage-check

70-72 lines are not covered with tests

const params = Object.entries(updatedModel.settings).reduce(
(acc, [key, value]) => {
const rawVal = value.controller_props?.value
const num = parseFloat(rawVal as string)
acc[key] = !isNaN(num) ? num : rawVal
return acc
},
{} as Record<string, unknown>
) as ModelSettingParams

Check warning on line 82 in web-app/src/containers/ModelSetting.tsx

View workflow job for this annotation

GitHub Actions / coverage-check

74-82 lines are not covered with tests

updateModel({
id: model.id,
Expand All @@ -87,8 +87,10 @@
...(params as unknown as object),
})

// Call debounced stopModel after updating the model
debouncedStopModel(model.id)
// Call debounced stopModel only when updating ctx_len or ngl
if (key === 'ctx_len' || key === 'ngl') {
debouncedStopModel(model.id)
}
}
}

Expand All @@ -106,7 +108,9 @@
</SheetTrigger>
<SheetContent className="h-[calc(100%-8px)] top-1 right-1 rounded-e-md overflow-y-auto">
<SheetHeader>
<SheetTitle>{t('common:modelSettings.title', { modelId: model.id })}</SheetTitle>
<SheetTitle>
{t('common:modelSettings.title', { modelId: model.id })}
</SheetTitle>
<SheetDescription>
{t('common:modelSettings.description')}
</SheetDescription>
Expand Down
24 changes: 23 additions & 1 deletion web-app/src/hooks/useChat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -261,14 +261,36 @@ export const useChat = () => {
!abortController.signal.aborted &&
activeProvider
) {
const modelConfig = activeProvider.models.find(
(m) => m.id === selectedModel?.id
)

const modelSettings = modelConfig?.settings
? Object.fromEntries(
Object.entries(modelConfig.settings)
.filter(
([key, value]) =>
key !== 'ctx_len' &&
key !== 'ngl' &&
value.controller_props?.value !== undefined &&
value.controller_props?.value !== null &&
value.controller_props?.value !== ''
)
.map(([key, value]) => [key, value.controller_props?.value])
)
: undefined

const completion = await sendCompletion(
activeThread,
activeProvider,
builder.getMessages(),
abortController,
availableTools,
currentAssistant.parameters?.stream === false ? false : true,
currentAssistant.parameters as unknown as Record<string, object>
{
...modelSettings,
...currentAssistant.parameters,
} as unknown as Record<string, object>
)

if (!completion) throw new Error('No completion received')
Expand Down
38 changes: 31 additions & 7 deletions web-app/src/services/__tests__/models.test.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { describe, it, expect, vi, beforeEach } from 'vitest'

import {
fetchModels,
fetchModelCatalog,
Expand All @@ -10,9 +11,8 @@ import {
stopModel,
stopAllModels,
startModel,
configurePullOptions,
} from '../models'
import { EngineManager } from '@janhq/core'
import { EngineManager, Model } from '@janhq/core'

// Mock EngineManager
vi.mock('@janhq/core', () => ({
Expand Down Expand Up @@ -118,7 +118,7 @@ describe('models service', () => {
settings: [{ key: 'temperature', value: 0.7 }],
}

await updateModel(model)
await updateModel(model as any)

expect(mockEngine.updateSettings).toHaveBeenCalledWith(model.settings)
})
Expand Down Expand Up @@ -209,7 +209,14 @@ describe('models service', () => {

describe('startModel', () => {
it('should start model successfully', async () => {
const provider = { provider: 'openai', models: [] } as ProviderObject
const mockSettings = {
ctx_len: { controller_props: { value: 4096 } },
ngl: { controller_props: { value: 32 } },
}
const provider = {
provider: 'openai',
models: [{ id: 'model1', settings: mockSettings }],
} as any
const model = 'model1'
const mockSession = { id: 'session1' }

Expand All @@ -221,11 +228,21 @@ describe('models service', () => {
const result = await startModel(provider, model)

expect(result).toEqual(mockSession)
expect(mockEngine.load).toHaveBeenCalledWith(model)
expect(mockEngine.load).toHaveBeenCalledWith(model, {
ctx_size: 4096,
n_gpu_layers: 32,
})
})

it('should handle start model error', async () => {
const provider = { provider: 'openai', models: [] } as ProviderObject
const mockSettings = {
ctx_len: { controller_props: { value: 4096 } },
ngl: { controller_props: { value: 32 } },
}
const provider = {
provider: 'openai',
models: [{ id: 'model1', settings: mockSettings }],
} as any
const model = 'model1'
const error = new Error('Failed to start model')

Expand All @@ -237,7 +254,14 @@ describe('models service', () => {
await expect(startModel(provider, model)).rejects.toThrow(error)
})
it('should not load model again', async () => {
const provider = { provider: 'openai', models: [] } as ProviderObject
const mockSettings = {
ctx_len: { controller_props: { value: 4096 } },
ngl: { controller_props: { value: 32 } },
}
const provider = {
provider: 'openai',
models: [{ id: 'model1', settings: mockSettings }],
} as any
const model = 'model1'

mockEngine.getLoadedModels.mockResolvedValue({
Expand Down
24 changes: 23 additions & 1 deletion web-app/src/services/models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,29 @@ export const startModel = async (
if (!engine) return undefined

if ((await engine.getLoadedModels()).includes(model)) return undefined
return engine.load(model).catch((error) => {

// Find the model configuration to get settings
const modelConfig = provider.models.find((m) => m.id === model)

/**
 * Translates a model setting key into the name used in the settings object
 * that is handed to `engine.load`.
 *
 * Only `ctx_len` (-> `ctx_size`) and `ngl` (-> `n_gpu_layers`) are renamed;
 * every other key is returned unchanged.
 *
 * @param key - Setting key as it appears on the model configuration.
 * @returns The mapped key, or `key` itself when no mapping is defined.
 */
const mapSettingKey = (key: string): string => {
  // A Map is used instead of a plain object literal: indexing an object with an
  // arbitrary string also resolves inherited Object.prototype members, so keys
  // such as `constructor` or `toString` would yield a function rather than
  // falling back to the original key.
  const keyMappings = new Map<string, string>([
    ['ctx_len', 'ctx_size'],
    ['ngl', 'n_gpu_layers'],
  ])
  return keyMappings.get(key) ?? key
}

const settings = modelConfig?.settings
? Object.fromEntries(
Object.entries(modelConfig.settings).map(([key, value]) => [
mapSettingKey(key),
value.controller_props?.value,
])
)
: undefined

return engine.load(model, settings).catch((error) => {
console.error(
`Failed to start model ${model} for provider ${provider.provider}:`,
error
Expand Down
Loading