Skip to content
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
ded134c
Fix SystemLanguageModel to pass schema for structured generation
eastriverlee Dec 20, 2025
0fa246d
Implement logit-constrained structured generation for LlamaLanguageModel
eastriverlee Dec 20, 2025
be79024
Implement logit-constrained structured generation for MLXLanguageModel
eastriverlee Dec 20, 2025
887964a
Fix duplicate type crash in schema generation
eastriverlee Dec 20, 2025
ecae1bd
Enforce count + numeric range guides
eastriverlee Dec 21, 2025
9fb930b
Respect temperature for structured generation
eastriverlee Dec 21, 2025
77950e3
Refactor Llama and MLX structured generation to shared constrained ge…
eastriverlee Dec 21, 2025
c8fdb9d
swift format -i -r .
mattt Jan 20, 2026
200886b
Restore SystemLanguageModel.swift from HEAD of main
mattt Jan 20, 2026
51294c6
Align structured generation prompts and defaults, and enrich schema p…
mattt Jan 20, 2026
2488201
Respect schema prompt flag and enhance structured prompts with JSONSc…
mattt Jan 20, 2026
04650e1
Add documentation comments to helper methods
mattt Jan 20, 2026
a5ccc1e
Refactor tests and fixtures
mattt Feb 3, 2026
996db5a
Incorporate feedback from review
mattt Feb 3, 2026
3659992
Conform internal GenerationSchema.Node to Equatable
mattt Feb 3, 2026
e76f048
Refactor and reorganize GenerationSchema
mattt Feb 3, 2026
87ef375
Rework StructuredGeneration adding docs and tests
mattt Feb 3, 2026
2564d59
Incorporate feedback from review
mattt Feb 3, 2026
5e5ed58
Update GenerableMacro to build guides using .minimum, .maximum, .rang…
mattt Feb 3, 2026
42e5934
Lower APIs from package to internal access
mattt Feb 3, 2026
4768a98
Change var to let
mattt Feb 3, 2026
dfb2f5d
Adjust constrained generation to never allow end tokens as valid term…
mattt Feb 3, 2026
c0cc2a3
Tighten free-string budget so we don’t exhaust the token budget on mu…
mattt Feb 3, 2026
8e46927
Fix eos token logic
mattt Feb 3, 2026
27f39b6
Incorporate feedback from review
mattt Feb 4, 2026
5d10c2a
Fix parsing of guide constraints without a description string
mattt Feb 4, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 33 additions & 18 deletions Sources/AnyLanguageModel/GenerationGuide.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,24 @@ import struct Foundation.Decimal
import class Foundation.NSDecimalNumber

/// Guides that control how values are generated.
public struct GenerationGuide<Value> {}
public struct GenerationGuide<Value>: Sendable {
package var minimumCount: Int?
package var maximumCount: Int?
package var minimum: Double?
package var maximum: Double?

public init() {}

package init(minimumCount: Int?, maximumCount: Int?) {
self.minimumCount = minimumCount
self.maximumCount = maximumCount
}

package init(minimum: Double?, maximum: Double?) {
self.minimum = minimum
self.maximum = maximum
}
}

// MARK: - String Guides

Expand Down Expand Up @@ -45,7 +62,7 @@ extension GenerationGuide where Value == Int {
/// }
/// ```
public static func minimum(_ value: Int) -> GenerationGuide<Int> {
GenerationGuide<Int>()
GenerationGuide<Int>(minimum: Double(value), maximum: nil)
}

/// Enforces a maximum value.
Expand All @@ -65,7 +82,7 @@ extension GenerationGuide where Value == Int {
/// }
/// ```
public static func maximum(_ value: Int) -> GenerationGuide<Int> {
GenerationGuide<Int>()
GenerationGuide<Int>(minimum: nil, maximum: Double(value))
}

/// Enforces values fall within a range.
Expand All @@ -85,7 +102,7 @@ extension GenerationGuide where Value == Int {
/// }
/// ```
public static func range(_ range: ClosedRange<Int>) -> GenerationGuide<Int> {
GenerationGuide<Int>()
GenerationGuide<Int>(minimum: Double(range.lowerBound), maximum: Double(range.upperBound))
}
}

Expand Down Expand Up @@ -144,18 +161,18 @@ extension GenerationGuide where Value == Double {
/// Enforces a minimum value.
/// The bounds are inclusive.
public static func minimum(_ value: Double) -> GenerationGuide<Double> {
GenerationGuide<Double>()
GenerationGuide<Double>(minimum: value, maximum: nil)
}

/// Enforces a maximum value.
/// The bounds are inclusive.
public static func maximum(_ value: Double) -> GenerationGuide<Double> {
GenerationGuide<Double>()
GenerationGuide<Double>(minimum: nil, maximum: value)
}

/// Enforces values fall within a range.
public static func range(_ range: ClosedRange<Double>) -> GenerationGuide<Double> {
GenerationGuide<Double>()
GenerationGuide<Double>(minimum: range.lowerBound, maximum: range.upperBound)
}
}

Expand All @@ -168,33 +185,31 @@ extension GenerationGuide {
/// The bounds are inclusive.
public static func minimumCount<Element>(_ count: Int) -> GenerationGuide<[Element]>
where Value == [Element] {
GenerationGuide<[Element]>()
GenerationGuide<[Element]>(minimumCount: count, maximumCount: nil)
}

/// Enforces a maximum number of elements in the array.
///
/// The bounds are inclusive.
public static func maximumCount<Element>(_ count: Int) -> GenerationGuide<[Element]>
where Value == [Element] {
GenerationGuide<[Element]>()
GenerationGuide<[Element]>(minimumCount: nil, maximumCount: count)
}

/// Enforces that the number of elements in the array fall within a closed range.
public static func count<Element>(_ range: ClosedRange<Int>) -> GenerationGuide<[Element]>
where Value == [Element] {
GenerationGuide<[Element]>()
GenerationGuide<[Element]>(minimumCount: range.lowerBound, maximumCount: range.upperBound)
}

/// Enforces that the array has exactly a certain number elements.
public static func count<Element>(_ count: Int) -> GenerationGuide<[Element]>
where Value == [Element] {
GenerationGuide<[Element]>()
GenerationGuide<[Element]>(minimumCount: count, maximumCount: count)
}

/// Enforces a guide on the elements within the array.
public static func element<Element>(_ guide: GenerationGuide<Element>) -> GenerationGuide<
[Element]
>
public static func element<Element>(_ guide: GenerationGuide<Element>) -> GenerationGuide<[Element]>
where Value == [Element] {
GenerationGuide<[Element]>()
}
Expand All @@ -210,7 +225,7 @@ extension GenerationGuide where Value == [Never] {
///
/// - Warning: This overload is only used for macro expansion. Don't call `GenerationGuide<[Never]>.minimumCount(_:)` on your own.
public static func minimumCount(_ count: Int) -> GenerationGuide<Value> {
GenerationGuide<Value>()
GenerationGuide<Value>(minimumCount: count, maximumCount: nil)
}

/// Enforces a maximum number of elements in the array.
Expand All @@ -219,20 +234,20 @@ extension GenerationGuide where Value == [Never] {
///
/// - Warning: This overload is only used for macro expansion. Don't call `GenerationGuide<[Never]>.maximumCount(_:)` on your own.
public static func maximumCount(_ count: Int) -> GenerationGuide<Value> {
GenerationGuide<Value>()
GenerationGuide<Value>(minimumCount: nil, maximumCount: count)
}

/// Enforces that the number of elements in the array fall within a closed range.
///
/// - Warning: This overload is only used for macro expansion. Don't call `GenerationGuide<[Never]>.count(_:)` on your own.
public static func count(_ range: ClosedRange<Int>) -> GenerationGuide<Value> {
GenerationGuide<Value>()
GenerationGuide<Value>(minimumCount: range.lowerBound, maximumCount: range.upperBound)
}

/// Enforces that the array has exactly a certain number elements.
///
/// - Warning: This overload is only used for macro expansion. Don't call `GenerationGuide<[Never]>.count(_:)` on your own.
public static func count(_ count: Int) -> GenerationGuide<Value> {
GenerationGuide<Value>()
GenerationGuide<Value>(minimumCount: count, maximumCount: count)
}
}
100 changes: 97 additions & 3 deletions Sources/AnyLanguageModel/GenerationSchema.swift
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,16 @@ public struct GenerationSchema: Sendable, Codable, CustomDebugStringConvertible
)
}
}

var nodeDescription: String? {
switch self {
case .object(let node): node.description
case .array(let node): node.description
case .string(let node): node.description
case .number(let node): node.description
case .boolean, .anyOf, .ref: nil
}
}
}

struct ObjectNode: Sendable, Codable {
Expand Down Expand Up @@ -504,12 +514,32 @@ public struct GenerationSchema: Sendable, Codable, CustomDebugStringConvertible
}

private static func nodesEqual(_ a: Node, _ b: Node) -> Bool {
// Simple structural equality - could be enhanced
switch (a, b) {
case (.boolean, .boolean):
return true
case (.ref(let aName), .ref(let bName)):
return aName == bName
case (.string(let aString), .string(let bString)):
return aString.pattern == bString.pattern
&& aString.enumChoices == bString.enumChoices
case (.number(let aNumber), .number(let bNumber)):
return aNumber.integerOnly == bNumber.integerOnly
&& aNumber.minimum == bNumber.minimum
&& aNumber.maximum == bNumber.maximum
case (.array(let aArray), .array(let bArray)):
return aArray.minItems == bArray.minItems
&& aArray.maxItems == bArray.maxItems
&& nodesEqual(aArray.items, bArray.items)
case (.object(let aObject), .object(let bObject)):
return aObject.required == bObject.required
&& aObject.properties.keys == bObject.properties.keys
&& aObject.properties.allSatisfy { key, aNode in
guard let bNode = bObject.properties[key] else { return false }
return nodesEqual(aNode, bNode)
}
case (.anyOf(let aNodes), .anyOf(let bNodes)):
return aNodes.count == bNodes.count
&& zip(aNodes, bNodes).allSatisfy(nodesEqual)
default:
return false
}
Expand Down Expand Up @@ -693,16 +723,43 @@ extension GenerationSchema {
} else if type == String.self {
return (.string(StringNode(description: description, pattern: nil, enumChoices: nil)), [:])
} else if type == Int.self {
var minimum: Double?
var maximum: Double?
for guide in guides {
if let min = guide.minimum { minimum = min }
if let max = guide.maximum { maximum = max }
}
return (
.number(NumberNode(description: description, minimum: nil, maximum: nil, integerOnly: true)), [:]
.number(
NumberNode(description: description, minimum: minimum, maximum: maximum, integerOnly: true)
), [:]
)
} else if type == Float.self || type == Double.self || type == Decimal.self {
var minimum: Double?
var maximum: Double?
for guide in guides {
if let min = guide.minimum { minimum = min }
if let max = guide.maximum { maximum = max }
}
return (
.number(NumberNode(description: description, minimum: nil, maximum: nil, integerOnly: false)), [:]
.number(
NumberNode(description: description, minimum: minimum, maximum: maximum, integerOnly: false)
), [:]
)
} else {
// Complex type - use its schema
let schema = Value.generationSchema

// Arrays should be inlined, not referenced
if case .array(var arrayNode) = schema.root {
arrayNode.description = description
for guide in guides {
if let min = guide.minimumCount { arrayNode.minItems = min }
if let max = guide.maximumCount { arrayNode.maxItems = max }
}
return (.array(arrayNode), schema.defs)
}

let typeName = String(reflecting: Value.self)

var deps = schema.defs
Expand Down Expand Up @@ -800,4 +857,41 @@ extension GenerationSchema {
/// let data = try encoder.encode(schema)
/// ```
static let omitAdditionalPropertiesKey = CodingUserInfoKey(rawValue: "GenerationSchema.omitAdditionalProperties")!

package func schemaPrompt() -> String {
let encoder = JSONEncoder()
encoder.outputFormatting = [.prettyPrinted, .sortedKeys]
guard let data = try? encoder.encode(self),
let schemaJSON = String(data: data, encoding: .utf8)
else {
return "Respond with valid JSON only."
}
return "Respond with valid JSON matching this schema:\n\(schemaJSON)"
}
}

extension Character {
package static let jsonQuoteScalars: Set<UInt32> = [0x22, 0x201C, 0x201D, 0x2018, 0x2019]
package static let jsonAllowedWhitespaceCharacters: Set<Character> = [" ", "\t", "\n"]

package var containsEmojiScalar: Bool {
unicodeScalars.contains { scalar in
scalar.properties.isEmojiPresentation || scalar.properties.isEmoji
}
}

package var isValidJSONStringCharacter: Bool {
guard self != "\\" else { return false }
guard let scalar = unicodeScalars.first, scalar.value >= 0x20 else { return false }
guard !Self.jsonQuoteScalars.contains(scalar.value) else { return false }

if let ascii = asciiValue {
let char = Character(UnicodeScalar(ascii))
if Self.jsonAllowedWhitespaceCharacters.contains(char) { return true }
return isLetter || isNumber || (isASCII && (isPunctuation || isSymbol))
}

// Allow non-ASCII letters/numbers and emoji, but disallow non-ASCII punctuation (e.g. "】")
return isLetter || isNumber || containsEmojiScalar
}
}
10 changes: 10 additions & 0 deletions Sources/AnyLanguageModel/LanguageModelSession.swift
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,16 @@ public final class LanguageModelSession: @unchecked Sendable {
public let content: Content
public let rawContent: GeneratedContent
public let transcriptEntries: ArraySlice<Transcript.Entry>

init(
content: Content,
rawContent: GeneratedContent,
transcriptEntries: ArraySlice<Transcript.Entry>
) {
self.content = content
self.rawContent = rawContent
self.transcriptEntries = transcriptEntries
}
}

@discardableResult
Expand Down
Loading