177 changes: 137 additions & 40 deletions Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift
@@ -455,6 +455,14 @@ import Foundation

defer { llama_free(context) }

// Reject embedding-only models early: such contexts have no KV cache, so
// llama_get_memory returns nil. A complementary architectural check in
// prepareInitialBatch catches encoder-only models (like BERT) by their
// architecture type.
if llama_get_memory(context) == nil {
throw LlamaLanguageModelError.encoderOnlyModel
}

llama_set_causal_attn(context, true)
llama_set_warmup(context, false)
llama_set_n_threads(context, runtimeOptions.threads, runtimeOptions.threads)
@@ -514,6 +522,14 @@ import Foundation
}
defer { llama_free(context) }

// Reject embedding-only models early: such contexts have no KV cache, so
// llama_get_memory returns nil. A complementary architectural check in
// prepareInitialBatch catches encoder-only models (like BERT) by their
// architecture type.
if llama_get_memory(context) == nil {
throw LlamaLanguageModelError.encoderOnlyModel
}

// Stabilize runtime behavior per-context
llama_set_causal_attn(context, true)
llama_set_warmup(context, false)
@@ -701,25 +717,14 @@ import Foundation
var batch = llama_batch_init(Int32(options.batchSize), 0, 1)
defer { llama_batch_free(batch) }

batch.n_tokens = Int32(promptTokens.count)
for i in 0 ..< promptTokens.count {
let idx = Int(i)
batch.token[idx] = promptTokens[idx]
batch.pos[idx] = Int32(i)
batch.n_seq_id[idx] = 1
if let seq_ids = batch.seq_id, let seq_id = seq_ids[idx] {
seq_id[0] = 0
}
batch.logits[idx] = 0
}

if batch.n_tokens > 0 {
batch.logits[Int(batch.n_tokens) - 1] = 1
}

guard llama_decode(context, batch) == 0 else {
throw LlamaLanguageModelError.encodingFailed
}
let hasEncoder = try prepareInitialBatch(
batch: &batch,
promptTokens: promptTokens,
model: model,
vocab: vocab,
context: context,
batchSize: options.batchSize
)

// Initialize sampler chain with options
guard let sampler = llama_sampler_chain_init(llama_sampler_chain_default_params()) else {
@@ -752,7 +757,9 @@ import Foundation

// Generate tokens one by one
var generatedText = ""
var n_cur = batch.n_tokens
// Track the decode position: encoder-decoder models start at position 1 (after the decoder start token);
// decoder-only models continue from the end of the prompt.
var n_cur: Int32 = hasEncoder ? 1 : batch.n_tokens

for _ in 0 ..< maxTokens {
// Sample next token from logits - llama_batch_get_one creates batch with single token at index 0
@@ -834,25 +841,14 @@ import Foundation
var batch = llama_batch_init(Int32(options.batchSize), 0, 1)
defer { llama_batch_free(batch) }

// Evaluate the prompt
batch.n_tokens = Int32(promptTokens.count)
for i in 0 ..< promptTokens.count {
let idx = Int(i)
batch.token[idx] = promptTokens[idx]
batch.pos[idx] = Int32(i)
batch.n_seq_id[idx] = 1
if let seq_ids = batch.seq_id, let seq_id = seq_ids[idx] {
seq_id[0] = 0
}
batch.logits[idx] = 0
}
if batch.n_tokens > 0 {
batch.logits[Int(batch.n_tokens) - 1] = 1
}

guard llama_decode(context, batch) == 0 else {
throw LlamaLanguageModelError.encodingFailed
}
let hasEncoder = try prepareInitialBatch(
batch: &batch,
promptTokens: promptTokens,
model: model,
vocab: vocab,
context: context,
batchSize: options.batchSize
)

// Initialize sampler chain with options
guard let sampler = llama_sampler_chain_init(llama_sampler_chain_default_params()) else {
@@ -886,7 +882,9 @@ import Foundation
applySampling(sampler: samplerPtr, effectiveTemperature: effectiveTemperature, options: options)

// Generate tokens one by one
var n_cur = batch.n_tokens
// Track the decode position: encoder-decoder models start at position 1 (after the decoder start token);
// decoder-only models continue from the end of the prompt.
var n_cur: Int32 = hasEncoder ? 1 : batch.n_tokens

for _ in 0 ..< maxTokens {
// Sample next token from logits of the last token we just decoded
@@ -945,6 +943,102 @@ import Foundation

// MARK: - Helper Methods

/// Prepares the initial batch for text generation, handling encoder-decoder vs decoder-only models.
///
/// - Parameters:
/// - batch: The batch to prepare (must be initialized with sufficient capacity).
/// - promptTokens: The tokenized prompt.
/// - model: The loaded model.
/// - vocab: The model vocabulary.
/// - context: The model context.
/// - batchSize: The batch capacity to validate against (prevents buffer overflow).
/// - Returns: `true` if the model has an encoder (for position tracking during generation).
/// - Throws: `insufficientMemory` if the prompt token count exceeds the batch capacity, `encoderOnlyModel` if the model cannot generate text, or `encodingFailed`/`decodingFailed` if the underlying encode/decode call fails.
private func prepareInitialBatch(
batch: inout llama_batch,
promptTokens: [llama_token],
model: OpaquePointer,
vocab: OpaquePointer,
context: OpaquePointer,
batchSize: UInt32
) throws -> Bool {
// Validate that prompt token count doesn't exceed batch capacity to prevent buffer overflow
guard promptTokens.count <= batchSize else {
throw LlamaLanguageModelError.insufficientMemory
}

let hasEncoder = llama_model_has_encoder(model)
let hasDecoder = llama_model_has_decoder(model)

if hasEncoder {
// For encoder models, first encode the prompt
batch.n_tokens = Int32(promptTokens.count)
for i in 0 ..< promptTokens.count {
let idx = Int(i)
batch.token[idx] = promptTokens[idx]
batch.pos[idx] = Int32(i)
batch.n_seq_id[idx] = 1
if let seq_ids = batch.seq_id, let seq_id = seq_ids[idx] {
seq_id[0] = 0
}
batch.logits[idx] = 0
}

guard llama_encode(context, batch) == 0 else {
throw LlamaLanguageModelError.encodingFailed
}

if hasDecoder {
// For encoder-decoder models, start decoding with decoder start token
var decoderStartToken = llama_model_decoder_start_token(model)
if decoderStartToken == LLAMA_TOKEN_NULL {
decoderStartToken = llama_vocab_bos(vocab)
}

batch.n_tokens = 1
batch.token[0] = decoderStartToken
batch.pos[0] = 0
batch.n_seq_id[0] = 1
if let seq_ids = batch.seq_id, let seq_id = seq_ids[0] {
seq_id[0] = 0
}
batch.logits[0] = 1

guard llama_decode(context, batch) == 0 else {
throw LlamaLanguageModelError.decodingFailed
}
} else {
// Encoder-only model (like BERT) - cannot generate text.
// This architectural check complements the earlier KV cache check,
// catching models by their architecture type.
throw LlamaLanguageModelError.encoderOnlyModel
}
} else {
// Standard decoder-only model (most LLMs)
batch.n_tokens = Int32(promptTokens.count)
for i in 0 ..< promptTokens.count {
let idx = Int(i)
batch.token[idx] = promptTokens[idx]
batch.pos[idx] = Int32(i)
batch.n_seq_id[idx] = 1
if let seq_ids = batch.seq_id, let seq_id = seq_ids[idx] {
seq_id[0] = 0
}
batch.logits[idx] = 0
}

if batch.n_tokens > 0 {
batch.logits[Int(batch.n_tokens) - 1] = 1
}

guard llama_decode(context, batch) == 0 else {
throw LlamaLanguageModelError.decodingFailed
}
}

return hasEncoder
}

private func formatPrompt(for session: LanguageModelSession) throws -> String {
guard let model = self.model else {
throw LlamaLanguageModelError.modelLoadFailed
Expand Down Expand Up @@ -1110,6 +1204,7 @@ import Foundation
case invalidModelPath
case insufficientMemory
case unsupportedFeature
case encoderOnlyModel

public var errorDescription: String? {
switch self {
@@ -1129,6 +1224,8 @@ import Foundation
return "Insufficient memory for operation"
case .unsupportedFeature:
return "This LlamaLanguageModel does not support image segments"
case .encoderOnlyModel:
return "This model is encoder-only (e.g., BERT) and cannot generate text"
}
}
}
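For context, a minimal caller-side sketch of how the new error cases surface. This is not part of the PR; the model path, constructor, and response handling are assumptions for illustration only.

func generate() async {
    // Hypothetical GGUF path; the LlamaLanguageModel initializer shown here is an assumption.
    let model = LlamaLanguageModel(modelPath: "/path/to/model.gguf")
    let session = LanguageModelSession(model: model)
    do {
        _ = try await session.respond(
            to: "Summarize this document.",
            options: GenerationOptions(maximumResponseTokens: 64)
        )
    } catch LlamaLanguageModelError.encoderOnlyModel {
        // Raised by either the KV-cache check or the architectural check in
        // prepareInitialBatch when the model (e.g. BERT) cannot generate text.
        print("Choose a decoder-capable model for text generation.")
    } catch LlamaLanguageModelError.insufficientMemory {
        // Raised when the tokenized prompt exceeds the configured batch capacity.
        print("Increase batchSize or shorten the prompt.")
    } catch {
        print("Generation failed: \(error)")
    }
}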
22 changes: 22 additions & 0 deletions Tests/AnyLanguageModelTests/LlamaLanguageModelTests.swift
@@ -306,5 +306,27 @@ import Testing
#expect(error == .unsupportedFeature)
}
}

@Test func promptExceedingBatchSize_rejected() async throws {
let session = LanguageModelSession(model: model)

// Use a very small batch size to test the validation
var options = GenerationOptions(maximumResponseTokens: 10)
options[custom: LlamaLanguageModel.self] = .init(batchSize: 8)

// Build a prompt that tokenizes to well over 8 tokens with any tokenizer
let longPrompt = String(repeating: "Hello world how are you today? ", count: 10)

do {
_ = try await session.respond(to: longPrompt, options: options)
// If we get here, either the prompt tokenized to <= 8 tokens (unlikely)
// or the validation did not fire; in practice this call should throw
// insufficientMemory.
} catch let error as LlamaLanguageModelError {
// Expected: prompt token count exceeds batch size
#expect(error == .insufficientMemory)
}
}
}
#endif // Llama
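As a follow-up to the new test, a minimal sketch of sizing the custom batchSize to the prompt before generating. The helper name and the token-count heuristic are assumptions, not part of the PR; only the options subscript and respond(to:options:) call mirror the test above.

func respondWithAdequateBatch(session: LanguageModelSession, prompt: String) async throws {
    // Rough upper bound (assumption): at most two tokens per whitespace-separated word.
    let estimatedTokens = prompt.split(separator: " ").count * 2

    var options = GenerationOptions(maximumResponseTokens: 128)
    options[custom: LlamaLanguageModel.self] = .init(batchSize: UInt32(max(512, estimatedTokens)))

    // With the validation added in this PR, a prompt that still exceeds batchSize
    // fails fast with .insufficientMemory instead of overflowing the batch buffer.
    _ = try await session.respond(to: prompt, options: options)
}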