diff --git a/core/src/browser/extensions/engines/AIEngine.ts b/core/src/browser/extensions/engines/AIEngine.ts
index 0e8a75fcaa..855f6e4dca 100644
--- a/core/src/browser/extensions/engines/AIEngine.ts
+++ b/core/src/browser/extensions/engines/AIEngine.ts
@@ -240,6 +240,12 @@ export abstract class AIEngine extends BaseExtension {
     EngineManager.instance().register(this)
   }
 
+  /**
+   * Gets model info
+   * @param modelId
+   */
+  abstract get(modelId: string): Promise<modelInfo | undefined>
+
   /**
    * Lists available models
    */
diff --git a/extensions-web/src/jan-provider-web/provider.ts b/extensions-web/src/jan-provider-web/provider.ts
index 216da66c94..dfdfe01b42 100644
--- a/extensions-web/src/jan-provider-web/provider.ts
+++ b/extensions-web/src/jan-provider-web/provider.ts
@@ -22,7 +22,7 @@ export default class JanProviderWeb extends AIEngine {
 
   override async onLoad() {
     console.log('Loading Jan Provider Extension...')
-    
+
     try {
       // Initialize authentication and fetch models
       await janApiClient.initialize()
@@ -37,20 +37,43 @@ export default class JanProviderWeb extends AIEngine {
 
   override async onUnload() {
     console.log('Unloading Jan Provider Extension...')
-    
+
     // Clear all sessions
     for (const sessionId of this.activeSessions.keys()) {
       await this.unload(sessionId)
     }
-    
+
     janProviderStore.reset()
     console.log('Jan Provider Extension unloaded')
   }
 
+  async get(modelId: string): Promise<modelInfo | undefined> {
+    return janApiClient
+      .getModels()
+      .then((list) => list.find((e) => e.id === modelId))
+      .then((model) =>
+        model
+          ? {
+              id: model.id,
+              name: model.id, // Use ID as name for now
+              quant_type: undefined,
+              providerId: this.provider,
+              port: 443, // HTTPS port for API
+              sizeBytes: 0, // Size not provided by Jan API
+              tags: [],
+              path: undefined, // Remote model, no local path
+              owned_by: model.owned_by,
+              object: model.object,
+              capabilities: ['tools'], // Jan models support tools via MCP
+            }
+          : undefined
+      )
+  }
+
   async list(): Promise<modelInfo[]> {
     try {
       const janModels = await janApiClient.getModels()
-      
+
       return janModels.map((model) => ({
         id: model.id,
         name: model.id, // Use ID as name for now
@@ -75,7 +98,7 @@ export default class JanProviderWeb extends AIEngine {
       // For Jan API, we don't actually "load" models in the traditional sense
       // We just create a session reference for tracking
       const sessionId = `jan-${modelId}-${Date.now()}`
-      
+
       const sessionInfo: SessionInfo = {
         pid: Date.now(), // Use timestamp as pseudo-PID
         port: 443, // HTTPS port
@@ -85,8 +108,10 @@ export default class JanProviderWeb extends AIEngine {
       }
 
       this.activeSessions.set(sessionId, sessionInfo)
-      
-      console.log(`Jan model session created: ${sessionId} for model ${modelId}`)
+
+      console.log(
+        `Jan model session created: ${sessionId} for model ${modelId}`
+      )
       return sessionInfo
     } catch (error) {
       console.error(`Failed to load Jan model ${modelId}:`, error)
@@ -97,23 +122,23 @@ export default class JanProviderWeb extends AIEngine {
   async unload(sessionId: string): Promise<UnloadResult> {
     try {
       const session = this.activeSessions.get(sessionId)
-      
+
       if (!session) {
         return {
           success: false,
-          error: `Session ${sessionId} not found`
+          error: `Session ${sessionId} not found`,
         }
       }
 
       this.activeSessions.delete(sessionId)
       console.log(`Jan model session unloaded: ${sessionId}`)
-      
+
       return { success: true }
     } catch (error) {
       console.error(`Failed to unload Jan session ${sessionId}:`, error)
       return {
         success: false,
-        error: error instanceof Error ? error.message : 'Unknown error'
+        error:
+          error instanceof Error ? error.message : 'Unknown error',
       }
     }
   }
@@ -136,9 +161,12 @@ export default class JanProviderWeb extends AIEngine {
     }
 
     // Convert core chat completion request to Jan API format
-    const janMessages: JanChatMessage[] = opts.messages.map(msg => ({
+    const janMessages: JanChatMessage[] = opts.messages.map((msg) => ({
       role: msg.role as 'system' | 'user' | 'assistant',
-      content: typeof msg.content === 'string' ? msg.content : JSON.stringify(msg.content)
+      content:
+        typeof msg.content === 'string'
+          ? msg.content
+          : JSON.stringify(msg.content),
     }))
 
     const janRequest = {
@@ -162,18 +190,18 @@ export default class JanProviderWeb extends AIEngine {
     } else {
       // Return single response
       const response = await janApiClient.createChatCompletion(janRequest)
-      
+
       // Check if aborted after completion
       if (abortController?.signal?.aborted) {
         throw new Error('Request was aborted')
       }
-      
+
       return {
         id: response.id,
         object: 'chat.completion' as const,
         created: response.created,
         model: response.model,
-        choices: response.choices.map(choice => ({
+        choices: response.choices.map((choice) => ({
           index: choice.index,
           message: {
             role: choice.message.role,
@@ -182,7 +210,12 @@ export default class JanProviderWeb extends AIEngine {
             reasoning_content: choice.message.reasoning_content,
             tool_calls: choice.message.tool_calls,
           },
-          finish_reason: (choice.finish_reason || 'stop') as 'stop' | 'length' | 'tool_calls' | 'content_filter' | 'function_call',
+          finish_reason: (choice.finish_reason || 'stop') as
+            | 'stop'
+            | 'length'
+            | 'tool_calls'
+            | 'content_filter'
+            | 'function_call',
         })),
         usage: response.usage,
       }
@@ -193,7 +226,10 @@ export default class JanProviderWeb extends AIEngine {
     }
   }
 
-  private async *createStreamingGenerator(janRequest: any, abortController?: AbortController) {
+  private async *createStreamingGenerator(
+    janRequest: any,
+    abortController?: AbortController
+  ) {
     let resolve: () => void
     let reject: (error: Error) => void
     const chunks: any[] = []
@@ -231,7 +267,7 @@ export default class JanProviderWeb extends AIEngine {
             object: chunk.object,
             created: chunk.created,
             model: chunk.model,
-            choices: chunk.choices.map(choice => ({
+            choices: chunk.choices.map((choice) => ({
               index: choice.index,
               delta: {
                 role: choice.delta.role,
@@ -261,14 +297,14 @@ export default class JanProviderWeb extends AIEngine {
       if (abortController?.signal?.aborted) {
         throw new Error('Request was aborted')
       }
-      
+
       while (yieldedIndex < chunks.length) {
         yield chunks[yieldedIndex]
         yieldedIndex++
       }
-      
+
       // Wait a bit before checking again
-      await new Promise(resolve => setTimeout(resolve, 10))
+      await new Promise((resolve) => setTimeout(resolve, 10))
     }
 
     // Yield any remaining chunks
@@ -291,24 +327,32 @@ export default class JanProviderWeb extends AIEngine {
   }
 
   async delete(modelId: string): Promise<void> {
-    throw new Error(`Delete operation not supported for remote Jan API model: ${modelId}`)
+    throw new Error(
+      `Delete operation not supported for remote Jan API model: ${modelId}`
+    )
   }
 
   async import(modelId: string, _opts: ImportOptions): Promise<void> {
-    throw new Error(`Import operation not supported for remote Jan API model: ${modelId}`)
+    throw new Error(
+      `Import operation not supported for remote Jan API model: ${modelId}`
+    )
   }
 
   async abortImport(modelId: string): Promise<void> {
-    throw new Error(`Abort import operation not supported for remote Jan API model: ${modelId}`)
+    throw new Error(
+      `Abort import operation not supported for remote Jan API model: ${modelId}`
+    )
   }
 
   async getLoadedModels(): Promise<string[]> {
-    return Array.from(this.activeSessions.values()).map(session => session.model_id)
+    return Array.from(this.activeSessions.values()).map(
+      (session) => session.model_id
+    )
   }
 
   async isToolSupported(modelId: string): Promise<boolean> {
     // Jan models support tool calls via MCP
-    console.log(`Checking tool support for Jan model ${modelId}: supported`);
-    return true;
+    console.log(`Checking tool support for Jan model ${modelId}: supported`)
+    return true
   }
-}
\ No newline at end of file
+}
diff --git a/extensions/llamacpp-extension/src/index.ts b/extensions/llamacpp-extension/src/index.ts
index 77b0aafcdd..8fad4fd87a 100644
--- a/extensions/llamacpp-extension/src/index.ts
+++ b/extensions/llamacpp-extension/src/index.ts
@@ -922,6 +922,30 @@ export default class llamacpp_extension extends AIEngine {
     return hash
   }
 
+  override async get(modelId: string): Promise<modelInfo | undefined> {
+    const modelPath = await joinPath([
+      await this.getProviderPath(),
+      'models',
+      modelId,
+    ])
+    const path = await joinPath([modelPath, 'model.yml'])
+
+    if (!(await fs.existsSync(path))) return undefined
+
+    const modelConfig = await invoke('read_yaml', {
+      path,
+    })
+
+    return {
+      id: modelId,
+      name: modelConfig.name ?? modelId,
+      quant_type: undefined, // TODO: parse quantization type from model.yml or model.gguf
+      providerId: this.provider,
+      port: 0, // port is not known until the model is loaded
+      sizeBytes: modelConfig.size_bytes ?? 0,
+    } as modelInfo
+  }
+
   // Implement the required LocalProvider interface methods
   override async list(): Promise<modelInfo[]> {
     const modelsDir = await joinPath([await this.getProviderPath(), 'models'])
@@ -1085,7 +1109,10 @@ export default class llamacpp_extension extends AIEngine {
     const archiveName = await basename(path)
     logger.info(`Installing backend from path: ${path}`)
 
-    if (!(await fs.existsSync(path)) || (!path.endsWith('tar.gz') && !path.endsWith('zip'))) {
+    if (
+      !(await fs.existsSync(path)) ||
+      (!path.endsWith('tar.gz') && !path.endsWith('zip'))
+    ) {
       logger.error(`Invalid path or file ${path}`)
       throw new Error(`Invalid path or file ${path}`)
     }
@@ -2601,7 +2628,8 @@ export default class llamacpp_extension extends AIEngine {
     metadata: Record<string, unknown>
   ): Promise<number> {
     // Extract vision parameters from metadata
-    const projectionDim = Math.floor(Number(metadata['clip.vision.projection_dim']) / 10) || 256
+    const projectionDim =
+      Math.floor(Number(metadata['clip.vision.projection_dim']) / 10) || 256
 
     // Count images in messages
     let imageCount = 0
diff --git a/web-app/src/containers/DownloadButton.tsx b/web-app/src/containers/DownloadButton.tsx
new file mode 100644
index 0000000000..7d4db703b2
--- /dev/null
+++ b/web-app/src/containers/DownloadButton.tsx
@@ -0,0 +1,142 @@
+import { Button } from '@/components/ui/button'
+import { Progress } from '@/components/ui/progress'
+import { useDownloadStore } from '@/hooks/useDownloadStore'
+import { useGeneralSetting } from '@/hooks/useGeneralSetting'
+import { useModelProvider } from '@/hooks/useModelProvider'
+import { useServiceHub } from '@/hooks/useServiceHub'
+import { useTranslation } from '@/i18n'
+import { extractModelName } from '@/lib/models'
+import { cn, sanitizeModelId } from '@/lib/utils'
+import { CatalogModel } from '@/services/models/types'
+import { useCallback, useMemo } from 'react'
+import { useShallow } from 'zustand/shallow'
+
+type ModelProps = {
+  model: CatalogModel
+  handleUseModel: (modelId: string) => void
+}
+const defaultModelQuantizations = ['iq4_xs', 'q4_k_m']
+
+export function DownloadButtonPlaceholder({
+  model,
+  handleUseModel,
+}: ModelProps) {
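+  // Subscribe with a shallow selector so this button only re-renders when
+  // these download slices change, not on every store update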
+  const { downloads, localDownloadingModels, addLocalDownloadingModel } =
+    useDownloadStore(
+      useShallow((state) => ({
+        downloads: state.downloads,
+        localDownloadingModels: state.localDownloadingModels,
+        addLocalDownloadingModel: state.addLocalDownloadingModel,
+      }))
+    )
+  const { t } = useTranslation()
+  const getProviderByName = useModelProvider((state) => state.getProviderByName)
+  const llamaProvider = getProviderByName('llamacpp')
+
+  const serviceHub = useServiceHub()
+  const huggingfaceToken = useGeneralSetting((state) => state.huggingfaceToken)
+
+  const quant =
+    model.quants.find((e) =>
+      defaultModelQuantizations.some((m) =>
+        e.model_id.toLowerCase().includes(m)
+      )
+    ) ?? model.quants[0]
+
+  const modelId = quant?.model_id || model.model_name
+
+  const downloadProcesses = useMemo(
+    () =>
+      Object.values(downloads).map((download) => ({
+        id: download.name,
+        name: download.name,
+        progress: download.progress,
+        current: download.current,
+        total: download.total,
+      })),
+    [downloads]
+  )
+
+  const isRecommendedModel = useCallback((modelId: string) => {
+    return (extractModelName(modelId)?.toLowerCase() ===
+      'jan-nano-gguf') as boolean
+  }, [])
+
+  if (model.quants.length === 0) {
+    return (
+      <div className="flex items-center justify-end">
+        <Button size="sm" onClick={() => handleUseModel(model.model_name)}>
+          {t('use')}
+        </Button>
+      </div>
+    )
+  }
+
+  const modelUrl = quant?.path || modelId
+  const isDownloading =
+    localDownloadingModels.has(modelId) ||
+    downloadProcesses.some((e) => e.id === modelId)
+
+  const downloadProgress =
+    downloadProcesses.find((e) => e.id === modelId)?.progress || 0
+  const isDownloaded = llamaProvider?.models.some(
+    (m: { id: string }) =>
+      m.id === modelId ||
+      m.id === `${model.developer}/${sanitizeModelId(modelId)}`
+  )
+  const isRecommended = isRecommendedModel(model.model_name)
+
+  const handleDownload = () => {
+    // Immediately set local downloading state
+    addLocalDownloadingModel(modelId)
+    const mmprojPath = (
+      model.mmproj_models?.find(
+        (e) => e.model_id.toLowerCase() === 'mmproj-f16'
+      ) || model.mmproj_models?.[0]
+    )?.path
+    serviceHub
+      .models()
+      .pullModelWithMetadata(modelId, modelUrl, mmprojPath, huggingfaceToken)
+  }
+
+  return (
+    <div className="flex items-center gap-2">
+      {isDownloading && !isDownloaded && (
+        <div className="flex items-center gap-2">
+          <Progress value={downloadProgress * 100} className="w-16" />
+          <span className="text-xs">
+            {Math.round(downloadProgress * 100)}%
+          </span>
+        </div>
+      )}
+      {isDownloaded ? (
+        <Button size="sm" onClick={() => handleUseModel(modelId)}>
+          {t('use')}
+        </Button>
+      ) : (
+        <Button
+          size="sm"
+          className={cn(
+            isDownloading && 'hidden',
+            isRecommended && 'hub-download-button-step'
+          )}
+          onClick={handleDownload}
+        >
+          {t('download')}
+        </Button>
+      )}
+    </div>
+  )
+}
diff --git a/web-app/src/containers/RenderMarkdown.tsx b/web-app/src/containers/RenderMarkdown.tsx
index da702eff6d..31d08cf105 100644
--- a/web-app/src/containers/RenderMarkdown.tsx
+++ b/web-app/src/containers/RenderMarkdown.tsx
@@ -89,6 +89,7 @@ const CodeComponent = memo(
     onCopy,
     copiedId,
     ...props
+    // eslint-disable-next-line @typescript-eslint/no-explicit-any
   }: any) => {
     const { t } = useTranslation()
     const match = /language-(\w+)/.exec(className || '')
diff --git a/web-app/src/routes/hub/$modelId.tsx b/web-app/src/routes/hub/$modelId.tsx
index 75ccc58bf6..102b5cecea 100644
--- a/web-app/src/routes/hub/$modelId.tsx
+++ b/web-app/src/routes/hub/$modelId.tsx
@@ -21,10 +21,7 @@ import { useEffect, useMemo, useCallback, useState } from 'react'
 import { useModelProvider } from '@/hooks/useModelProvider'
 import { useDownloadStore } from '@/hooks/useDownloadStore'
 import { useServiceHub } from '@/hooks/useServiceHub'
-import type {
-  CatalogModel,
-  ModelQuant,
-} from '@/services/models/types'
+import type { CatalogModel, ModelQuant } from '@/services/models/types'
 import { Progress } from '@/components/ui/progress'
 import { Button } from '@/components/ui/button'
 import { cn } from '@/lib/utils'
@@ -80,12 +77,13 @@ function HubModelDetailContent() {
   }, [fetchSources])
 
   const fetchRepo = useCallback(async () => {
-    const repoInfo = await serviceHub.models().fetchHuggingFaceRepo(
-      search.repo || modelId,
-      huggingfaceToken
-    )
+    const repoInfo = await serviceHub
+      .models()
+      .fetchHuggingFaceRepo(search.repo || modelId, huggingfaceToken)
     if (repoInfo) {
-      const repoDetail = serviceHub.models().convertHfRepoToCatalogModel(repoInfo)
+      const repoDetail = serviceHub
+        .models()
+        .convertHfRepoToCatalogModel(repoInfo)
       setRepoData(repoDetail || undefined)
     }
   }, [serviceHub, modelId, search, huggingfaceToken])
@@ -168,7 +166,9 @@ function HubModelDetailContent() {
     try {
       // Use the HuggingFace path for the model
       const modelPath = variant.path
-      const supported = await serviceHub.models().isModelSupported(modelPath, 8192)
+      const supported = await serviceHub
+        .models()
+        .isModelSupported(modelPath, 8192)
       setModelSupportStatus((prev) => ({
         ...prev,
         [modelKey]: supported,
@@ -473,12 +473,20 @@ function HubModelDetailContent() {
                           addLocalDownloadingModel(
                             variant.model_id
                           )
-                          serviceHub.models().pullModelWithMetadata(
-                            variant.model_id,
-                            variant.path,
-                            modelData.mmproj_models?.[0]?.path,
-                            huggingfaceToken
-                          )
+                          serviceHub
+                            .models()
+                            .pullModelWithMetadata(
+                              variant.model_id,
+                              variant.path,
+                              (
+                                modelData.mmproj_models?.find(
+                                  (e) =>
+                                    e.model_id.toLowerCase() ===
+                                    'mmproj-f16'
+                                ) || modelData.mmproj_models?.[0]
+                              )?.path,
+                              huggingfaceToken
+                            )
                         }}
                         className={cn(isDownloading && 'hidden')}
                       >
diff --git a/web-app/src/routes/hub/index.tsx b/web-app/src/routes/hub/index.tsx
index 2a53a848fa..be63c49b69 100644
--- a/web-app/src/routes/hub/index.tsx
+++ b/web-app/src/routes/hub/index.tsx
@@ -1,6 +1,6 @@
 /* eslint-disable @typescript-eslint/no-explicit-any */
 import { useVirtualizer } from '@tanstack/react-virtual'
-import { createFileRoute, useNavigate, useSearch } from '@tanstack/react-router'
+import { createFileRoute, useNavigate } from '@tanstack/react-router'
 import { route } from '@/constants/routes'
 import { useModelSources } from '@/hooks/useModelSources'
 import { cn } from '@/lib/utils'
@@ -34,8 +34,6 @@ import {
   TooltipTrigger,
 } from '@/components/ui/tooltip'
 import { ModelInfoHoverCard } from '@/containers/ModelInfoHoverCard'
-import Joyride, { CallBackProps, STATUS } from 'react-joyride'
-import { CustomTooltipJoyRide } from '@/containers/CustomeTooltipJoyRide'
 import {
   DropdownMenu,
   DropdownMenuContent,
@@ -51,10 +49,9 @@ import { Loader } from 'lucide-react'
 import { useTranslation } from '@/i18n/react-i18next-compat'
 import Fuse from 'fuse.js'
 import { useGeneralSetting } from '@/hooks/useGeneralSetting'
+import { DownloadButtonPlaceholder } from '@/containers/DownloadButton'
+import { useShallow } from 'zustand/shallow'
 
-type ModelProps = {
-  model: CatalogModel
-}
 type SearchParams = {
   repo: string
 }
@@ -77,7 +74,7 @@ function Hub() {
 
 function HubContent() {
   const parentRef = useRef<HTMLDivElement>(null)
-  const { huggingfaceToken } = useGeneralSetting()
+  const huggingfaceToken = useGeneralSetting((state) => state.huggingfaceToken)
 
   const serviceHub = useServiceHub()
   const { t } = useTranslation()
@@ -93,7 +90,13 @@ function HubContent() {
     }
   }, [])
 
-  const { sources, fetchSources, loading } = useModelSources()
+  const { sources, fetchSources, loading } = useModelSources(
+    useShallow((state) => ({
+      sources: state.sources,
+      fetchSources: state.fetchSources,
+      loading: state.loading,
+    }))
+  )
 
   const [searchValue, setSearchValue] = useState('')
   const [sortSelected, setSortSelected] = useState('newest')
@@ -108,16 +111,9 @@ function HubContent() {
   const [modelSupportStatus, setModelSupportStatus] = useState<
     Record<string, 'RED' | 'YELLOW' | 'GREEN' | undefined>
   >({})
-  const [joyrideReady, setJoyrideReady] = useState(false)
-  const [currentStepIndex, setCurrentStepIndex] = useState(0)
   const addModelSourceTimeoutRef = useRef<ReturnType<typeof setTimeout> | null>(
     null
   )
-  const downloadButtonRef = useRef<HTMLButtonElement>(null)
-  const hasTriggeredDownload = useRef(false)
-
-  const { getProviderByName } = useModelProvider()
-  const llamaProvider = getProviderByName('llamacpp')
 
   const toggleModelExpansion = (modelId: string) => {
     setExpandedModels((prev) => ({
@@ -168,9 +164,10 @@ function HubContent() {
       ?.map((model) => ({
         ...model,
         quants: model.quants.filter((variant) =>
-          llamaProvider?.models.some(
-            (m: { id: string }) => m.id === variant.model_id
-          )
+          useModelProvider
+            .getState()
+            .getProviderByName('llamacpp')
+            ?.models.some((m: { id: string }) => m.id === variant.model_id)
         ),
       }))
       .filter((model) => model.quants.length > 0)
@@ -186,7 +183,6 @@ function HubContent() {
     showOnlyDownloaded,
     huggingFaceRepo,
     searchOptions,
-    llamaProvider?.models,
   ])
 
   // The virtualizer
@@ -215,9 +211,13 @@ function HubContent() {
 
     addModelSourceTimeoutRef.current = setTimeout(async () => {
       try {
-        const repoInfo = await serviceHub.models().fetchHuggingFaceRepo(searchValue, huggingfaceToken)
+        const repoInfo = await serviceHub
+          .models()
+          .fetchHuggingFaceRepo(searchValue, huggingfaceToken)
         if (repoInfo) {
-          const catalogModel = serviceHub.models().convertHfRepoToCatalogModel(repoInfo)
+          const catalogModel = serviceHub
+            .models()
+            .convertHfRepoToCatalogModel(repoInfo)
           if (
             !sources.some(
               (s) =>
@@ -303,7 +303,9 @@ function HubContent() {
       try {
         // Use the HuggingFace path for the model
        const modelPath = variant.path
-        const supportStatus = await serviceHub.models().isModelSupported(modelPath, 8192)
+        const supportStatus = await serviceHub
+          .models()
+          .isModelSupported(modelPath, 8192)
 
         setModelSupportStatus((prev) => ({
           ...prev,
@@ -320,178 +322,7 @@ function HubContent() {
     [modelSupportStatus, serviceHub]
   )
 
-  const DownloadButtonPlaceholder = useMemo(() => {
-    return ({ model }: ModelProps) => {
-      // Check if this is a HuggingFace repository (no quants)
-      if (model.quants.length === 0) {
-        return (
-          <div className="flex items-center justify-end">
-            <Button size="sm" onClick={() => handleUseModel(model.model_name)}>
-              {t('use')}
-            </Button>
-          </div>
-        )
-      }
-
-      const quant =
-        model.quants.find((e) =>
-          defaultModelQuantizations.some((m) =>
-            e.model_id.toLowerCase().includes(m)
-          )
-        ) ?? model.quants[0]
-      const modelId = quant?.model_id || model.model_name
-      const modelUrl = quant?.path || modelId
-      const isDownloading =
-        localDownloadingModels.has(modelId) ||
-        downloadProcesses.some((e) => e.id === modelId)
-      const downloadProgress =
-        downloadProcesses.find((e) => e.id === modelId)?.progress || 0
-      const isDownloaded = llamaProvider?.models.some(
-        (m: { id: string }) => m.id === modelId
-      )
-      const isRecommended = isRecommendedModel(model.model_name)
-
-      const handleDownload = () => {
-        // Immediately set local downloading state
-        addLocalDownloadingModel(modelId)
-        const mmprojPath = model.mmproj_models?.[0]?.path
-        serviceHub.models().pullModelWithMetadata(
-          modelId,
-          modelUrl,
-          mmprojPath,
-          huggingfaceToken
-        )
-      }
-
-      return (
-        <div className="flex items-center gap-2">
-          {isDownloading && !isDownloaded && (
-            <div className="flex items-center gap-2">
-              <Progress value={downloadProgress * 100} className="w-16" />
-              <span className="text-xs">
-                {Math.round(downloadProgress * 100)}%
-              </span>
-            </div>
-          )}
-          {isDownloaded ? (
-            <Button size="sm" onClick={() => handleUseModel(modelId)}>
-              {t('use')}
-            </Button>
-          ) : (
-            <Button
-              size="sm"
-              ref={isRecommended ? downloadButtonRef : undefined}
-              className={cn(
-                isDownloading && 'hidden',
-                isRecommended && 'hub-download-button-step'
-              )}
-              onClick={handleDownload}
-            >
-              {t('download')}
-            </Button>
-          )}
-        </div>
-      )
-    }
-  }, [
-    localDownloadingModels,
-    downloadProcesses,
-    llamaProvider?.models,
-    isRecommendedModel,
-    t,
-    addLocalDownloadingModel,
-    huggingfaceToken,
-    handleUseModel,
-    serviceHub,
-  ])
-
-  const { step } = useSearch({ from: Route.id })
-  const isSetup = step === 'setup_local_provider'
-
-  // Wait for DOM to be ready before starting Joyride
-  useEffect(() => {
-    if (!loading && filteredModels.length > 0 && isSetup) {
-      const timer = setTimeout(() => {
-        setJoyrideReady(true)
-      }, 100)
-      return () => clearTimeout(timer)
-    } else {
-      setJoyrideReady(false)
-    }
-  }, [loading, filteredModels.length, isSetup])
-
-  const handleJoyrideCallback = (data: CallBackProps) => {
-    const { status, index } = data
-
-    if (
-      status === STATUS.FINISHED &&
-      !isDownloading &&
-      isLastStep &&
-      !hasTriggeredDownload.current
-    ) {
-      const recommendedModel = filteredModels.find((model) =>
-        isRecommendedModel(model.model_name)
-      )
-      if (recommendedModel && recommendedModel.quants[0]?.model_id) {
-        if (downloadButtonRef.current) {
-          hasTriggeredDownload.current = true
-          downloadButtonRef.current.click()
-        }
-        return
-      }
-    }
-
-    if (status === STATUS.FINISHED) {
-      navigate({
-        to: route.hub.index,
-      })
-    }
-
-    // Track current step index
-    setCurrentStepIndex(index)
-  }
-
-  // Check if any model is currently downloading
-  const isDownloading =
-    localDownloadingModels.size > 0 || downloadProcesses.length > 0
-
-  const steps = [
-    {
-      target: '.hub-model-card-step',
-      title: t('hub:joyride.recommendedModelTitle'),
-      disableBeacon: true,
-      content: t('hub:joyride.recommendedModelContent'),
-    },
-    {
-      target: '.hub-download-button-step',
-      title: isDownloading
-        ? t('hub:joyride.downloadInProgressTitle')
-        : t('hub:joyride.downloadModelTitle'),
-      disableBeacon: true,
-      content: isDownloading
-        ? t('hub:joyride.downloadInProgressContent')
-        : t('hub:joyride.downloadModelContent'),
-    },
-  ]
-  // Check if we're on the last step
-  const isLastStep = currentStepIndex === steps.length - 1
-
   const renderFilter = () => {
     return (
       <>
@@ -544,31 +375,6 @@ function HubContent() {
   return (
     <>
-      <Joyride
-        run={joyrideReady}
-        steps={steps}
-        continuous
-        callback={handleJoyrideCallback}
-        tooltipComponent={CustomTooltipJoyRide}
-      />
@@ -698,6 +504,7 @@ function HubContent() {
                     <DownloadButtonPlaceholder
                       model={model}
+                      handleUseModel={handleUseModel}
                     />
@@ -908,10 +715,13 @@ function HubContent() {
                             (e) => e.id === variant.model_id
                           )?.progress || 0
                         const isDownloaded =
-                          llamaProvider?.models.some(
-                            (m: { id: string }) =>
-                              m.id === variant.model_id
-                          )
+                          useModelProvider
+                            .getState()
+                            .getProviderByName('llamacpp')
+                            ?.models.some(
+                              (m: { id: string }) =>
+                                m.id === variant.model_id
+                            )
 
                         if (isDownloading) {
                           return (
@@ -962,14 +772,26 @@ function HubContent() {
                           addLocalDownloadingModel(
                             variant.model_id
                           )
-                          serviceHub.models().pullModelWithMetadata(
-                            variant.model_id,
-                            variant.path,
-                            filteredModels[
-                              virtualItem.index
-                            ].mmproj_models?.[0]?.path,
-                            huggingfaceToken
-                          )
+                          serviceHub
+                            .models()
+                            .pullModelWithMetadata(
+                              variant.model_id,
+                              variant.path,
+
+                              (
+                                filteredModels[
+                                  virtualItem.index
+                                ].mmproj_models?.find(
+                                  (e) =>
+                                    e.model_id.toLowerCase() ===
+                                    'mmproj-f16'
+                                ) ||
+                                filteredModels[
+                                  virtualItem.index
+                                ].mmproj_models?.[0]
+                              )?.path,
+                              huggingfaceToken
+                            )
                         }}
                       >
                         {
diff --git a/web-app/src/services/models/default.test.ts b/web-app/src/services/models/default.test.ts
index 9c4dcc75cd..50ed29c8a1 100644
--- a/web-app/src/services/models/default.test.ts
+++ b/web-app/src/services/models/default.test.ts
@@ -1,10 +1,16 @@
 describe('DefaultModelsService', () => {
   let modelsService: DefaultModelsService
-
+
   const mockEngine = {
     list: vi.fn(),
     updateSettings: vi.fn(),
@@ -246,7 +249,9 @@ describe('DefaultModelsService', () => {
       })
       mockEngine.load.mockRejectedValue(error)
 
-      await expect(modelsService.startModel(provider, model)).rejects.toThrow(error)
+      await expect(modelsService.startModel(provider, model)).rejects.toThrow(
+        error
+      )
     })
     it('should not load model again', async () => {
       const mockSettings = {
@@ -263,7 +265,9 @@ describe('DefaultModelsService', () => {
         includes: () => true,
       })
       expect(mockEngine.load).toBeCalledTimes(0)
-      await expect(modelsService.startModel(provider, model)).resolves.toBe(undefined)
+      await expect(modelsService.startModel(provider, model)).resolves.toBe(
+        undefined
+      )
     })
   })
 
@@ -312,7 +316,9 @@ describe('DefaultModelsService', () => {
         json: vi.fn().mockResolvedValue(mockRepoData),
       })
 
-      const result = await modelsService.fetchHuggingFaceRepo('microsoft/DialoGPT-medium')
+      const result = await modelsService.fetchHuggingFaceRepo(
+        'microsoft/DialoGPT-medium'
+      )
 
       expect(result).toEqual(mockRepoData)
       expect(fetch).toHaveBeenCalledWith(
@@ -342,7 +348,9 @@ describe('DefaultModelsService', () => {
       )
 
       // Test with domain prefix
-      await modelsService.fetchHuggingFaceRepo('huggingface.co/microsoft/DialoGPT-medium')
+      await modelsService.fetchHuggingFaceRepo(
+        'huggingface.co/microsoft/DialoGPT-medium'
+      )
       expect(fetch).toHaveBeenCalledWith(
         'https://huggingface.co/api/models/microsoft/DialoGPT-medium?blobs=true&files_metadata=true',
         {
@@ -365,7 +373,9 @@ describe('DefaultModelsService', () => {
       expect(await modelsService.fetchHuggingFaceRepo('')).toBeNull()
 
       // Test string without slash
-      expect(await modelsService.fetchHuggingFaceRepo('invalid-repo')).toBeNull()
+      expect(
+        await modelsService.fetchHuggingFaceRepo('invalid-repo')
+      ).toBeNull()
 
       // Test whitespace only
      expect(await modelsService.fetchHuggingFaceRepo('   ')).toBeNull()
@@ -378,7 +388,8 @@ describe('DefaultModelsService', () => {
         statusText: 'Not Found',
       })
 
-      const result = await modelsService.fetchHuggingFaceRepo('nonexistent/model')
+      const result =
+        await modelsService.fetchHuggingFaceRepo('nonexistent/model')
 
       expect(result).toBeNull()
       expect(fetch).toHaveBeenCalledWith(
@@ -398,7 +409,9 @@ describe('DefaultModelsService', () => {
         statusText: 'Internal Server Error',
       })
 
-      const result = await modelsService.fetchHuggingFaceRepo('microsoft/DialoGPT-medium')
+      const result = await modelsService.fetchHuggingFaceRepo(
+        'microsoft/DialoGPT-medium'
+      )
 
       expect(result).toBeNull()
       expect(consoleSpy).toHaveBeenCalledWith(
@@ -414,7 +427,9 @@ describe('DefaultModelsService', () => {
 
       ;(fetch as any).mockRejectedValue(new Error('Network error'))
 
-      const result = await modelsService.fetchHuggingFaceRepo('microsoft/DialoGPT-medium')
+      const result = await modelsService.fetchHuggingFaceRepo(
+        'microsoft/DialoGPT-medium'
+      )
 
       expect(result).toBeNull()
       expect(consoleSpy).toHaveBeenCalledWith(
@@ -448,7 +463,9 @@ describe('DefaultModelsService', () => {
         json: vi.fn().mockResolvedValue(mockRepoData),
       })
 
-      const result = await modelsService.fetchHuggingFaceRepo('microsoft/DialoGPT-medium')
+      const result = await modelsService.fetchHuggingFaceRepo(
+        'microsoft/DialoGPT-medium'
+      )
 
       expect(result).toEqual(mockRepoData)
     })
@@ -487,7 +504,9 @@ describe('DefaultModelsService', () => {
         json: vi.fn().mockResolvedValue(mockRepoData),
       })
 
-      const result = await modelsService.fetchHuggingFaceRepo('microsoft/DialoGPT-medium')
+      const result = await modelsService.fetchHuggingFaceRepo(
+        'microsoft/DialoGPT-medium'
+      )
 
       expect(result).toEqual(mockRepoData)
     })
@@ -531,7 +550,9 @@ describe('DefaultModelsService', () => {
         json: vi.fn().mockResolvedValue(mockRepoData),
       })
 
-      const result = await modelsService.fetchHuggingFaceRepo('microsoft/DialoGPT-medium')
+      const result = await modelsService.fetchHuggingFaceRepo(
+        'microsoft/DialoGPT-medium'
+      )
 
       expect(result).toEqual(mockRepoData)
       // Verify the GGUF file is present in siblings
@@ -576,7 +597,8 @@ describe('DefaultModelsService', () => {
     }
 
     it('should convert HuggingFace repo to catalog model format', () => {
-      const result = modelsService.convertHfRepoToCatalogModel(mockHuggingFaceRepo)
+      const result =
+        modelsService.convertHfRepoToCatalogModel(mockHuggingFaceRepo)
 
       const expected: CatalogModel = {
         model_name: 'microsoft/DialoGPT-medium',
@@ -586,12 +608,12 @@ describe('DefaultModelsService', () => {
         num_quants: 2,
         quants: [
           {
-            model_id: 'model-q4_0',
+            model_id: 'microsoft/model-q4_0',
             path: 'https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/model-q4_0.gguf',
             file_size: '2.0 GB',
           },
           {
-            model_id: 'model-q8_0',
+            model_id: 'microsoft/model-q8_0',
             path: 'https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/model-q8_0.GGUF',
             file_size: '4.0 GB',
           },
@@ -635,7 +657,8 @@ describe('DefaultModelsService', () => {
         siblings: undefined,
       }
 
-      const result = modelsService.convertHfRepoToCatalogModel(repoWithoutSiblings)
+      const result =
+        modelsService.convertHfRepoToCatalogModel(repoWithoutSiblings)
 
       expect(result.num_quants).toBe(0)
       expect(result.quants).toEqual([])
@@ -663,7 +686,9 @@ describe('DefaultModelsService', () => {
         ],
       }
 
-      const result = modelsService.convertHfRepoToCatalogModel(repoWithVariousFileSizes)
+      const result = modelsService.convertHfRepoToCatalogModel(
+        repoWithVariousFileSizes
+      )
 
       expect(result.quants[0].file_size).toBe('500.0 MB')
       expect(result.quants[1].file_size).toBe('3.5 GB')
@@ -676,7 +701,8 @@ describe('DefaultModelsService', () => {
         tags: [],
       }
 
-      const result = modelsService.convertHfRepoToCatalogModel(repoWithEmptyTags)
+      const result =
+        modelsService.convertHfRepoToCatalogModel(repoWithEmptyTags)
 
       expect(result.description).toBe('**Tags**: ')
     })
@@ -687,7 +713,8 @@ describe('DefaultModelsService', () => {
         downloads: undefined as any,
       }
 
-      const result = modelsService.convertHfRepoToCatalogModel(repoWithoutDownloads)
+      const result =
+        modelsService.convertHfRepoToCatalogModel(repoWithoutDownloads)
 
       expect(result.downloads).toBe(0)
     })
@@ -714,15 +741,17 @@ describe('DefaultModelsService', () => {
         ],
       }
 
-      const result = modelsService.convertHfRepoToCatalogModel(repoWithVariousGGUF)
+      const result =
+        modelsService.convertHfRepoToCatalogModel(repoWithVariousGGUF)
 
-      expect(result.quants[0].model_id).toBe('model')
-      expect(result.quants[1].model_id).toBe('MODEL')
-      expect(result.quants[2].model_id).toBe('complex-model-name')
+      expect(result.quants[0].model_id).toBe('microsoft/model')
+      expect(result.quants[1].model_id).toBe('microsoft/MODEL')
+      expect(result.quants[2].model_id).toBe('microsoft/complex-model-name')
     })
 
     it('should generate correct download paths', () => {
-      const result = modelsService.convertHfRepoToCatalogModel(mockHuggingFaceRepo)
+      const result =
+        modelsService.convertHfRepoToCatalogModel(mockHuggingFaceRepo)
 
       expect(result.quants[0].path).toBe(
         'https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/model-q4_0.gguf'
@@ -733,7 +762,8 @@ describe('DefaultModelsService', () => {
     })
 
     it('should generate correct readme URL', () => {
-      const result = modelsService.convertHfRepoToCatalogModel(mockHuggingFaceRepo)
+      const result =
+        modelsService.convertHfRepoToCatalogModel(mockHuggingFaceRepo)
 
       expect(result.readme).toBe(
         'https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/README.md'
@@ -767,13 +797,14 @@ describe('DefaultModelsService', () => {
         ],
       }
 
-      const result = modelsService.convertHfRepoToCatalogModel(repoWithMixedCase)
+      const result =
+        modelsService.convertHfRepoToCatalogModel(repoWithMixedCase)
 
       expect(result.num_quants).toBe(3)
       expect(result.quants).toHaveLength(3)
-      expect(result.quants[0].model_id).toBe('model-1')
-      expect(result.quants[1].model_id).toBe('model-2')
-      expect(result.quants[2].model_id).toBe('model-3')
+      expect(result.quants[0].model_id).toBe('microsoft/model-1')
+      expect(result.quants[1].model_id).toBe('microsoft/model-2')
+      expect(result.quants[2].model_id).toBe('microsoft/model-3')
     })
 
     it('should handle edge cases with file size formatting', () => {
@@ -798,7 +829,8 @@ describe('DefaultModelsService', () => {
         ],
       }
 
-      const result = modelsService.convertHfRepoToCatalogModel(repoWithEdgeCases)
+      const result =
+        modelsService.convertHfRepoToCatalogModel(repoWithEdgeCases)
 
      expect(result.quants[0].file_size).toBe('0.0 MB')
       expect(result.quants[1].file_size).toBe('1.0 GB')
@@ -850,7 +882,10 @@ describe('DefaultModelsService', () => {
 
       mockEngineManager.get.mockReturnValue(mockEngineWithSupport)
 
-      const result = await modelsService.isModelSupported('/path/to/model.gguf', 4096)
+      const result = await modelsService.isModelSupported(
+        '/path/to/model.gguf',
+        4096
+      )
 
       expect(result).toBe('GREEN')
       expect(mockEngineWithSupport.isModelSupported).toHaveBeenCalledWith(
@@ -867,7 +902,10 @@ describe('DefaultModelsService', () => {
 
       mockEngineManager.get.mockReturnValue(mockEngineWithSupport)
 
-      const result = await modelsService.isModelSupported('/path/to/model.gguf', 8192)
+      const result = await modelsService.isModelSupported(
+        '/path/to/model.gguf',
+        8192
+      )
 
       expect(result).toBe('YELLOW')
       expect(mockEngineWithSupport.isModelSupported).toHaveBeenCalledWith(
@@ -884,7 +922,9 @@ describe('DefaultModelsService', () => {
 
       mockEngineManager.get.mockReturnValue(mockEngineWithSupport)
 
-      const result = await modelsService.isModelSupported('/path/to/large-model.gguf')
+      const result = await modelsService.isModelSupported(
+        '/path/to/large-model.gguf'
+      )
 
      expect(result).toBe('RED')
       expect(mockEngineWithSupport.isModelSupported).toHaveBeenCalledWith(
diff --git a/web-app/src/services/models/default.ts b/web-app/src/services/models/default.ts
index 5a31f39936..1867063346 100644
--- a/web-app/src/services/models/default.ts
+++ b/web-app/src/services/models/default.ts
@@ -30,6 +30,10 @@ export class DefaultModelsService implements ModelsService {
     return EngineManager.instance().get(provider) as AIEngine | undefined
   }
 
+  async getModel(modelId: string): Promise<modelInfo | undefined> {
+    return this.getEngine()?.get(modelId)
+  }
+
   async fetchModels(): Promise<modelInfo[]> {
     return this.getEngine()?.list() ?? []
   }
@@ -127,7 +131,7 @@ export class DefaultModelsService implements ModelsService {
       const modelId = file.rfilename.replace(/\.gguf$/i, '')
       return {
-        model_id: sanitizeModelId(modelId),
+        model_id: `${repo.author}/${sanitizeModelId(modelId)}`,
         path: `https://huggingface.co/${repo.modelId}/resolve/main/${file.rfilename}`,
         file_size: formatFileSize(file.size),
       }
diff --git a/web-app/src/services/models/types.ts b/web-app/src/services/models/types.ts
index 5bf66b8bfb..d92dae38a8 100644
--- a/web-app/src/services/models/types.ts
+++ b/web-app/src/services/models/types.ts
@@ -90,6 +90,7 @@ export interface ModelPlan {
 }
 
 export interface ModelsService {
+  getModel(modelId: string): Promise<modelInfo | undefined>
   fetchModels(): Promise<modelInfo[]>
   fetchModelCatalog(): Promise<CatalogModel[]>
   fetchHuggingFaceRepo(