Skip to content

Commit 66329b2

Browse files
feat(opencode): add dynamic model fetching for OpenAI-compatible providers
- Add fetch-models.ts module with OpenAI-compatible model fetching
- Add caching support with TTL for fetched models
- Add shouldFetchModels option to provider config (defaults to true)
- Update tests to disable model fetching in test fixtures
1 parent 765aa51 commit 66329b2

File tree

2 files changed

+302
-1
lines changed

2 files changed

+302
-1
lines changed
Lines changed: 297 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,297 @@
1+
import { Global } from "../global"
2+
import { Log } from "../util/log"
3+
import path from "path"
4+
import { Installation } from "../installation"
5+
import type { ModelsDev } from "./models"
6+
7+
// Module-scoped logger; all log lines from this file are tagged "fetch-models".
const log = Log.create({ service: "fetch-models" })
8+
9+
export namespace FetchModels {
10+
export interface OpenAIModel {
11+
id: string
12+
object: string
13+
created: number
14+
owned_by: string
15+
}
16+
17+
export interface OpenAIModelsResponse {
18+
object: string
19+
data: OpenAIModel[]
20+
}
21+
22+
export interface FetchOptions {
23+
baseURL: string
24+
apiKey?: string
25+
customEndpoint?: string
26+
headers?: Record<string, string>
27+
}
28+
29+
const cacheDir = path.join(Global.Path.cache, "fetched-models")
30+
31+
function getCacheFilePath(providerId: string): string {
32+
return path.join(cacheDir, `${providerId}.json`)
33+
}
34+
35+
function inferContextLimit(modelId: string): number {
36+
const lower = modelId.toLowerCase()
37+
// Check for explicit context size in name
38+
const contextMatch = lower.match(/(\d+)[kkm](?=\b|[^a-z])/)
39+
if (contextMatch) {
40+
const num = parseInt(contextMatch[1])
41+
if (lower.includes("m")) return num * 1000000
42+
if (lower.includes("k")) return num * 1000
43+
}
44+
// Known model families
45+
if (lower.includes("claude-3-opus") || lower.includes("claude-opus-4")) return 200000
46+
if (lower.includes("claude-3-5-sonnet") || lower.includes("claude-sonnet-4")) return 200000
47+
if (lower.includes("claude-3-haiku") || lower.includes("claude-haiku-4")) return 200000
48+
if (lower.includes("claude-3")) return 200000
49+
if (lower.includes("gpt-4o")) return 128000
50+
if (lower.includes("gpt-4-turbo")) return 128000
51+
if (lower.includes("gpt-4")) return 8192
52+
if (lower.includes("gpt-3.5-turbo")) return 16385
53+
if (lower.includes("gemini-1.5-pro") || lower.includes("gemini-3-pro")) return 2097152
54+
if (lower.includes("gemini-1.5-flash") || lower.includes("gemini-3-flash")) return 1048576
55+
if (lower.includes("gemini")) return 1048576
56+
if (lower.includes("deepseek")) return 262144
57+
if (lower.includes("llama-3.1") || lower.includes("llama-3-1")) return 128000
58+
if (lower.includes("llama-3")) return 8192
59+
if (lower.includes("mistral-large") || lower.includes("mistral-small")) return 128000
60+
if (lower.includes("mixtral")) return 32768
61+
if (lower.includes("kimi")) return 262144
62+
return 128000 // Safe default
63+
}
64+
65+
function inferOutputLimit(modelId: string): number {
66+
const lower = modelId.toLowerCase()
67+
if (lower.includes("claude-3-opus") || lower.includes("claude-opus-4")) return 32000
68+
if (lower.includes("claude-3-5-sonnet") || lower.includes("claude-sonnet-4")) return 32000
69+
if (lower.includes("claude-3-haiku") || lower.includes("claude-haiku-4")) return 32000
70+
if (lower.includes("claude-3")) return 4096
71+
if (lower.includes("gpt-4o")) return 16384
72+
if (lower.includes("gpt-4-turbo")) return 4096
73+
if (lower.includes("gpt-4")) return 8192
74+
if (lower.includes("gemini-3")) return 32000
75+
if (lower.includes("gemini-1.5")) return 8192
76+
if (lower.includes("deepseek")) return 32000
77+
if (lower.includes("kimi")) return 32000
78+
return 4096 // Default
79+
}
80+
81+
function inferAttachmentSupport(modelId: string): boolean {
82+
const visionModels = [
83+
"vision",
84+
"claude-3",
85+
"claude-opus-4",
86+
"claude-sonnet-4",
87+
"claude-haiku-4",
88+
"gpt-4o",
89+
"gemini",
90+
"kimi",
91+
"qwen-vl",
92+
"multimodal",
93+
"pixtral",
94+
"llava",
95+
]
96+
return visionModels.some((v) => modelId.toLowerCase().includes(v))
97+
}
98+
99+
function inferReasoningSupport(modelId: string): boolean {
100+
const reasoningModels = ["o1", "o3", "reasoning", "r1", "thinking", "deepseek-r1", "glm-flash-thinking"]
101+
return reasoningModels.some((r) => modelId.toLowerCase().includes(r))
102+
}
103+
104+
function inferToolCallSupport(modelId: string): boolean {
105+
const noToolModels = ["embedding", "embed", "tts", "whisper", "moderation", "dall-e", "image", "audio"]
106+
const hasNoTool = noToolModels.some((m) => modelId.toLowerCase().includes(m))
107+
if (hasNoTool) return false
108+
109+
// Most modern models support tool calling
110+
const modernModels = [
111+
"claude-3",
112+
"claude-opus-4",
113+
"claude-sonnet-4",
114+
"claude-haiku-4",
115+
"gpt-4",
116+
"gpt-3.5-turbo",
117+
"gemini",
118+
"mistral",
119+
"mixtral",
120+
"llama-3",
121+
"kimi",
122+
"deepseek",
123+
"command-r",
124+
]
125+
return modernModels.some((m) => modelId.toLowerCase().includes(m))
126+
}
127+
128+
function inferModalities(modelId: string): { input: string[]; output: string[] } {
129+
const lower = modelId.toLowerCase()
130+
const isEmbedding = lower.includes("embedding") || lower.includes("embed")
131+
const isTTS = lower.includes("tts") || lower.includes("whisper")
132+
const isImageGen = lower.includes("dall-e") || lower.includes("image") || lower.includes("stable-diffusion")
133+
const isVision = inferAttachmentSupport(modelId)
134+
135+
if (isEmbedding) {
136+
return { input: ["text"], output: [] }
137+
}
138+
if (isTTS) {
139+
return { input: ["text"], output: ["audio"] }
140+
}
141+
if (isImageGen) {
142+
return { input: ["text"], output: ["image"] }
143+
}
144+
145+
const input: string[] = ["text"]
146+
if (isVision) input.push("image")
147+
148+
return { input, output: ["text"] }
149+
}
150+
151+
function transformOpenAIModel(model: OpenAIModel): ModelsDev.Model {
152+
const modalities = inferModalities(model.id)
153+
const releaseDate = model.created ? new Date(model.created * 1000).toISOString().split("T")[0] : new Date().toISOString().split("T")[0]
154+
155+
return {
156+
id: model.id,
157+
name: model.id,
158+
release_date: releaseDate,
159+
attachment: inferAttachmentSupport(model.id),
160+
reasoning: inferReasoningSupport(model.id),
161+
temperature: true,
162+
tool_call: inferToolCallSupport(model.id),
163+
limit: {
164+
context: inferContextLimit(model.id),
165+
output: inferOutputLimit(model.id),
166+
},
167+
modalities: {
168+
input: modalities.input as ("text" | "audio" | "image" | "video" | "pdf")[],
169+
output: modalities.output as ("text" | "audio" | "image" | "video" | "pdf")[],
170+
},
171+
options: {},
172+
}
173+
}
174+
175+
export async function fetchFromEndpoint(options: FetchOptions): Promise<ModelsDev.Model[]> {
176+
const endpoint = options.customEndpoint || `${options.baseURL.replace(/\/$/, "")}/v1/models`
177+
178+
log.info("Fetching models from endpoint", { endpoint })
179+
180+
const headers: Record<string, string> = {
181+
"User-Agent": Installation.USER_AGENT,
182+
...options.headers,
183+
}
184+
185+
if (options.apiKey) {
186+
headers["Authorization"] = `Bearer ${options.apiKey}`
187+
}
188+
189+
const response = await fetch(endpoint, {
190+
headers,
191+
signal: AbortSignal.timeout(30 * 1000), // 30 second timeout
192+
})
193+
194+
if (!response.ok) {
195+
const errorText = await response.text().catch(() => "Unknown error")
196+
throw new Error(`Failed to fetch models: ${response.status} ${response.statusText} - ${errorText}`)
197+
}
198+
199+
const data: OpenAIModelsResponse = await response.json()
200+
201+
if (!data.data || !Array.isArray(data.data)) {
202+
throw new Error("Invalid response format: expected 'data' array")
203+
}
204+
205+
log.info("Successfully fetched models", { count: data.data.length })
206+
207+
return data.data.map(transformOpenAIModel)
208+
}
209+
210+
export interface CachedModels {
211+
timestamp: number
212+
models: ModelsDev.Model[]
213+
ttl: number
214+
}
215+
216+
export async function getCached(providerId: string, ttlMs: number = 60 * 60 * 1000): Promise<ModelsDev.Model[] | undefined> {
217+
try {
218+
const cacheFile = Bun.file(getCacheFilePath(providerId))
219+
const exists = await cacheFile.exists()
220+
if (!exists) return undefined
221+
222+
const cached: CachedModels = await cacheFile.json()
223+
const now = Date.now()
224+
225+
if (now - cached.timestamp > (cached.ttl || ttlMs)) {
226+
log.debug("Cache expired", { providerId })
227+
return undefined
228+
}
229+
230+
log.debug("Using cached models", { providerId, count: cached.models.length })
231+
return cached.models
232+
} catch (error) {
233+
log.debug("Failed to read cache", { providerId, error })
234+
return undefined
235+
}
236+
}
237+
238+
export async function setCached(providerId: string, models: ModelsDev.Model[], ttlMs: number = 60 * 60 * 1000): Promise<void> {
239+
try {
240+
// Ensure cache directory exists
241+
await Bun.$`mkdir -p ${cacheDir}`.nothrow().quiet()
242+
243+
const cacheData: CachedModels = {
244+
timestamp: Date.now(),
245+
models,
246+
ttl: ttlMs,
247+
}
248+
249+
const cacheFile = getCacheFilePath(providerId)
250+
await Bun.write(cacheFile, JSON.stringify(cacheData, null, 2))
251+
252+
log.debug("Cached models", { providerId, count: models.length })
253+
} catch (error) {
254+
log.warn("Failed to cache models", { providerId, error })
255+
}
256+
}
257+
258+
export async function fetchWithCache(
259+
providerId: string,
260+
options: FetchOptions,
261+
cacheOptions?: { enabled?: boolean; ttlMs?: number }
262+
): Promise<ModelsDev.Model[]> {
263+
const { enabled = true, ttlMs = 60 * 60 * 1000 } = cacheOptions || {}
264+
265+
if (enabled) {
266+
const cached = await getCached(providerId, ttlMs)
267+
if (cached) return cached
268+
}
269+
270+
const models = await fetchFromEndpoint(options)
271+
272+
if (enabled) {
273+
await setCached(providerId, models, ttlMs)
274+
}
275+
276+
return models
277+
}
278+
279+
export async function invalidateCache(providerId?: string): Promise<void> {
280+
try {
281+
if (providerId) {
282+
const cacheFile = getCacheFilePath(providerId)
283+
await Bun.file(cacheFile).delete().catch(() => {})
284+
log.info("Invalidated cache for provider", { providerId })
285+
} else {
286+
// Invalidate all caches
287+
const glob = new Bun.Glob("*.json")
288+
for await (const file of glob.scan({ cwd: cacheDir })) {
289+
await Bun.file(path.join(cacheDir, file)).delete().catch(() => {})
290+
}
291+
log.info("Invalidated all model caches")
292+
}
293+
} catch (error) {
294+
log.warn("Failed to invalidate cache", { providerId, error })
295+
}
296+
}
297+
}

packages/opencode/test/session/llm.test.ts

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ beforeAll(() => {
134134
}
135135

136136
const url = new URL(req.url)
137-
const body = (await req.json()) as Record<string, unknown>
137+
const body = req.method === "GET" ? {} : ((await req.json()) as Record<string, unknown>)
138138
next.resolve({ url, headers: req.headers, body })
139139

140140
if (!url.pathname.endsWith(next.path)) {
@@ -250,6 +250,7 @@ describe("session.llm.stream", () => {
250250
enabled_providers: [providerID],
251251
provider: {
252252
[providerID]: {
253+
shouldFetchModels: false,
253254
options: {
254255
apiKey: "test-key",
255256
baseURL: `${server.url.origin}/v1`,
@@ -374,6 +375,7 @@ describe("session.llm.stream", () => {
374375
provider: {
375376
openai: {
376377
name: "OpenAI",
378+
shouldFetchModels: false,
377379
env: ["OPENAI_API_KEY"],
378380
npm: "@ai-sdk/openai",
379381
api: "https://api.openai.com/v1",
@@ -502,6 +504,7 @@ describe("session.llm.stream", () => {
502504
enabled_providers: [providerID],
503505
provider: {
504506
[providerID]: {
507+
shouldFetchModels: false,
505508
options: {
506509
apiKey: "test-anthropic-key",
507510
baseURL: `${server.url.origin}/v1`,
@@ -603,6 +606,7 @@ describe("session.llm.stream", () => {
603606
enabled_providers: [providerID],
604607
provider: {
605608
[providerID]: {
609+
shouldFetchModels: false,
606610
options: {
607611
apiKey: "test-google-key",
608612
baseURL: `${server.url.origin}/v1beta`,

0 commit comments

Comments
 (0)