Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 44 additions & 7 deletions Sources/AnyLanguageModel/Models/MLXLanguageModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,12 @@ import Foundation
public var kvGroupSize: Int
/// Sets the token offset where quantized KV storage starts.
public var quantizedKVStart: Int
/// Additional key-value pairs injected into the chat template rendering context.
public var additionalContext: [String: MLXLMCommon.JSONValue]?

/// Bridges `additionalContext` into the `[String: any Sendable]` dictionary
/// shape accepted by `MLXLMCommon.UserInput(additionalContext:)`, converting
/// each `JSONValue` recursively via `toSendable()`.
/// Returns `nil` when no additional context has been set.
var additionalContextForUserInput: [String: any Sendable]? {
    additionalContext?.mapValues { $0.toSendable() }
}

/// Creates MLX-specific generation options.
///
Expand All @@ -218,16 +224,20 @@ import Foundation
/// Pass `nil` to disable KV quantization.
/// - kvGroupSize: The token group size used for KV quantization.
/// - quantizedKVStart: The token index where quantized KV storage begins.
/// - additionalContext: Additional key-value pairs injected into the chat
/// template rendering context.
public init(
    maxKVSize: Int? = nil,
    kvBits: Int? = nil,
    kvGroupSize: Int = 64,
    quantizedKVStart: Int = 0,
    additionalContext: [String: MLXLMCommon.JSONValue]? = nil
) {
    self.maxKVSize = maxKVSize
    self.kvBits = kvBits
    self.kvGroupSize = kvGroupSize
    self.quantizedKVStart = quantizedKVStart
    self.additionalContext = additionalContext
}
}

Expand Down Expand Up @@ -813,6 +823,9 @@ import Foundation
// Map AnyLanguageModel GenerationOptions to MLX GenerateParameters
let generateParameters = toGenerateParameters(options)

// Extract additional context from custom options
let additionalContext = options[custom: MLXLanguageModel.self]?.additionalContextForUserInput

// Build chat history from full transcript
var chat = convertTranscriptToMLXChat(session: session, fallbackPrompt: prompt.description)

Expand All @@ -828,7 +841,8 @@ import Foundation
let userInput = MLXLMCommon.UserInput(
chat: chat,
processing: .init(resize: .init(width: 512, height: 512)),
tools: toolSpecs
tools: toolSpecs,
additionalContext: additionalContext,
)
let lmInput = try await context.processor.prepare(input: userInput)
let resolved = resolveCache(
Expand Down Expand Up @@ -991,10 +1005,15 @@ import Foundation

// Build chat inside task to avoid Sendable issues
let generateParameters = toGenerateParameters(options)
let userInput = makeUserInput(
session: session,
fallbackPrompt: prompt.description,
tools: nil
let chat = convertTranscriptToMLXChat(session: session, fallbackPrompt: prompt.description)

let additionalContext = options[custom: MLXLanguageModel.self]?.additionalContextForUserInput

let userInput = MLXLMCommon.UserInput(
chat: chat,
processing: .init(resize: .init(width: 512, height: 512)),
tools: nil,
additionalContext: additionalContext
)
let lmInput = try await context.processor.prepare(input: userInput)
let resolved = resolveCache(
Expand Down Expand Up @@ -1529,10 +1548,14 @@ import Foundation
let baseChat = convertTranscriptToMLXChat(session: session, fallbackPrompt: prompt.description)
let schemaPrompt = includeSchemaInPrompt ? schemaPrompt(for: schema) : nil
let chat = normalizeChatForStructuredGeneration(baseChat, schemaPrompt: schemaPrompt)

let additionalContext = options[custom: MLXLanguageModel.self]?.additionalContextForUserInput

let userInput = MLXLMCommon.UserInput(
chat: chat,
processing: .init(resize: .init(width: 512, height: 512)),
tools: nil
tools: nil,
additionalContext: additionalContext,
)
let lmInput = try await context.processor.prepare(input: userInput)

Expand Down Expand Up @@ -1773,4 +1796,18 @@ import Foundation
return sampledToken.item(Int.self)
}
}
extension MLXLMCommon.JSONValue {
    /// Recursively unwraps this `JSONValue` into its plain Swift
    /// counterpart (`String`, `Int`, `Double`, `Bool`, `NSNull`,
    /// `[any Sendable]`, or `[String: any Sendable]`) so the value can
    /// cross a `Sendable` boundary.
    func toSendable() -> any Sendable {
        switch self {
        case let .string(value):
            return value
        case let .int(value):
            return value
        case let .double(value):
            return value
        case let .bool(value):
            return value
        case .null:
            // Preserve explicit JSON null rather than dropping the key.
            return NSNull()
        case let .array(elements):
            return elements.map { $0.toSendable() }
        case let .object(fields):
            return fields.mapValues { $0.toSendable() }
        }
    }
}
#endif // MLX
22 changes: 22 additions & 0 deletions Tests/AnyLanguageModelTests/MLXLanguageModelTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,28 @@ import Testing
#expect([Priority.low, Priority.medium, Priority.high].contains(response.content))
}

// Verifies that MLX custom options carrying `additionalContext` flow through
// `respond(to:options:)` and still produce a non-empty response.
@Test func withAdditionalContext() async throws {
    let session = LanguageModelSession(model: model)

    var options = GenerationOptions(temperature: 0.7, maximumResponseTokens: 32)
    options[custom: MLXLanguageModel.self] = .init(
        additionalContext: [
            "turn_count": .int(3),
            "user_name": .string("Alice"),
            "verbose": .bool(true),
        ]
    )

    let reply = try await session.respond(to: "Say hello", options: options)
    #expect(!reply.content.isEmpty)
}

@Test func unavailableForNonexistentModel() async {
let model = MLXLanguageModel(modelId: "mlx-community/does-not-exist-anylanguagemodel-test")
await model.removeFromCache()
Expand Down
Loading