Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions Sources/AnyLanguageModel/LanguageModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ public protocol LanguageModel: Sendable {

var availability: Availability<UnavailableReason> { get }

var isResponding: Bool { get }

func prewarm(
for session: LanguageModelSession,
promptPrefix: Prompt?
Expand Down Expand Up @@ -45,6 +47,10 @@ extension LanguageModel {
}
}

Comment thread
mattt marked this conversation as resolved.
/// Default implementation: always reports `false`.
public var isResponding: Bool { false }

public func prewarm(
for session: LanguageModelSession,
promptPrefix: Prompt? = nil
Expand Down
180 changes: 127 additions & 53 deletions Sources/AnyLanguageModel/LanguageModelSession.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@ import Foundation
import Observation

@Observable
public final class LanguageModelSession {
public final class LanguageModelSession: @unchecked Sendable {
Comment thread
mattt marked this conversation as resolved.
public private(set) var isResponding: Bool = false
public private(set) var transcript: Transcript

private let model: any LanguageModel
public let tools: [any Tool]
public let instructions: Instructions?

@ObservationIgnored private let respondingState = RespondingState()

public convenience init(
model: any LanguageModel,
tools: [any Tool] = [],
Expand Down Expand Up @@ -58,60 +60,112 @@ public final class LanguageModelSession {
model.prewarm(for: self, promptPrefix: promptPrefix)
}

public struct Response<Content> where Content: Generable {
/// Records the start of a response and refreshes `isResponding`.
///
/// The work runs in an unstructured `Task`, so this method returns before the
/// counter or the flag has actually changed.
///
/// NOTE(review): nothing here orders this task relative to the one started by
/// `endResponding()`; confirm callers tolerate the two completing out of order.
nonisolated private func beginResponding() {
    Task {
        // Bump the actor-isolated counter, then derive the flag from the new value.
        let inFlight = await respondingState.increment()
        let nowResponding = inFlight > 0
        // The observable flag is written on the main actor.
        await MainActor.run {
            self.isResponding = nowResponding
        }
    }
}

/// Records the end of a response and refreshes `isResponding`.
///
/// The work runs in an unstructured `Task`, so this method returns before the
/// counter or the flag has actually changed.
///
/// NOTE(review): nothing here orders this task relative to the one started by
/// `beginResponding()`; confirm callers tolerate the two completing out of order.
nonisolated private func endResponding() {
    Task {
        // Lower the actor-isolated counter, then derive the flag from the new value.
        let inFlight = await respondingState.decrement()
        let stillResponding = inFlight > 0
        // The observable flag is written on the main actor.
        await MainActor.run {
            self.isResponding = stillResponding
        }
    }
}

/// Runs `operation` bracketed by `beginResponding()` and `endResponding()`.
///
/// - Parameter operation: The asynchronous work to perform.
/// - Returns: Whatever `operation` returns.
/// - Throws: Rethrows any error thrown by `operation`, after `endResponding()`
///   has been called.
nonisolated private func wrapRespond<T>(_ operation: () async throws -> T) async throws -> T {
    beginResponding()
    let outcome: T
    do {
        outcome = try await operation()
    } catch {
        // Failure path: still signal the end before propagating the error.
        endResponding()
        throw error
    }
    endResponding()
    return outcome
}

/// Wraps `upstream` in a relay stream that brackets its lifetime with
/// `beginResponding()` and `endResponding()`.
///
/// `beginResponding()` is called immediately. A task then forwards every
/// snapshot from `upstream` into the returned stream and calls
/// `endResponding()` once forwarding stops, whether upstream finished or threw.
///
/// If the consumer stops listening (the returned stream is cancelled or
/// released), the forwarding task is cancelled so it does not keep draining
/// `upstream` for nobody and delay `endResponding()`.
///
/// - Parameter upstream: The stream produced by the underlying model.
/// - Returns: A stream that yields the same snapshots as `upstream` and
///   finishes, or fails, the same way.
nonisolated private consuming func wrapStream<Content>(
    _ upstream: sending ResponseStream<Content>
) -> ResponseStream<Content> where Content: Generable, Content.PartiallyGenerated: Sendable {
    beginResponding()
    let session = self
    let relay = AsyncThrowingStream<ResponseStream<Content>.Snapshot, any Error> { continuation in
        let stream = upstream
        let forwarding = Task {
            do {
                for try await snapshot in stream {
                    continuation.yield(snapshot)
                }
                continuation.finish()
            } catch {
                continuation.finish(throwing: error)
            }
            // Runs on every exit path, including cancellation of this task.
            session.endResponding()
        }
        // Cancellation is cooperative: this assumes iterating `upstream` ends
        // promptly once its task is cancelled — TODO confirm for ResponseStream.
        // After a normal finish the cancel is a harmless no-op.
        continuation.onTermination = { _ in
            forwarding.cancel()
        }
    }
    return ResponseStream(stream: relay)
}

/// The value returned by the session's `respond` methods.
///
/// `Content` must be both `Generable` and `Sendable`.
public struct Response<Content>: Sendable where Content: Generable, Content: Sendable {
/// The response content, typed as `Content`.
public let content: Content
/// The response as a `GeneratedContent` value.
/// NOTE(review): presumably the untyped form of `content` — confirm.
public let rawContent: GeneratedContent
/// A slice of transcript entries carried with this response.
/// NOTE(review): presumably the entries produced for this response — confirm.
public let transcriptEntries: ArraySlice<Transcript.Entry>
}

@discardableResult
nonisolated public func respond(
public func respond(
Comment thread
mattt marked this conversation as resolved.
Outdated
to prompt: Prompt,
options: GenerationOptions = GenerationOptions()
) async throws -> Response<String> {
try await model.respond(
within: self,
to: prompt,
generating: String.self,
includeSchemaInPrompt: true,
options: options
)
try await wrapRespond {
try await model.respond(
within: self,
to: prompt,
generating: String.self,
includeSchemaInPrompt: true,
options: options
)
}
}

@discardableResult
nonisolated public func respond(
public func respond(
to prompt: String,
options: GenerationOptions = GenerationOptions()
) async throws -> Response<String> {
try await respond(to: Prompt(prompt), options: options)
}

@discardableResult
nonisolated public func respond(
public func respond(
options: GenerationOptions = GenerationOptions(),
@PromptBuilder prompt: () throws -> Prompt
) async throws -> Response<String> {
try await respond(to: try prompt(), options: options)
}

@discardableResult
nonisolated public func respond(
public func respond(
to prompt: Prompt,
schema: GenerationSchema,
includeSchemaInPrompt: Bool = true,
options: GenerationOptions = GenerationOptions()
) async throws -> Response<GeneratedContent> {
try await model.respond(
within: self,
to: prompt,
generating: GeneratedContent.self,
includeSchemaInPrompt: includeSchemaInPrompt,
options: options
)
try await wrapRespond {
try await model.respond(
within: self,
to: prompt,
generating: GeneratedContent.self,
includeSchemaInPrompt: includeSchemaInPrompt,
options: options
)
}
}

@discardableResult
nonisolated public func respond(
public func respond(
to prompt: String,
schema: GenerationSchema,
includeSchemaInPrompt: Bool = true,
Expand All @@ -126,7 +180,7 @@ public final class LanguageModelSession {
}

@discardableResult
nonisolated public func respond(
public func respond(
schema: GenerationSchema,
includeSchemaInPrompt: Bool = true,
options: GenerationOptions = GenerationOptions(),
Expand All @@ -141,23 +195,25 @@ public final class LanguageModelSession {
}

@discardableResult
nonisolated public func respond<Content>(
public func respond<Content>(
to prompt: Prompt,
generating type: Content.Type = Content.self,
includeSchemaInPrompt: Bool = true,
options: GenerationOptions = GenerationOptions()
) async throws -> Response<Content> where Content: Generable {
try await model.respond(
within: self,
to: prompt,
generating: type,
includeSchemaInPrompt: includeSchemaInPrompt,
options: options
)
try await wrapRespond {
try await model.respond(
within: self,
to: prompt,
generating: type,
includeSchemaInPrompt: includeSchemaInPrompt,
options: options
)
}
}

@discardableResult
nonisolated public func respond<Content>(
public func respond<Content>(
to prompt: String,
generating type: Content.Type = Content.self,
includeSchemaInPrompt: Bool = true,
Expand All @@ -172,7 +228,7 @@ public final class LanguageModelSession {
}

@discardableResult
nonisolated public func respond<Content>(
public func respond<Content>(
generating type: Content.Type = Content.self,
includeSchemaInPrompt: Bool = true,
options: GenerationOptions = GenerationOptions(),
Expand All @@ -192,12 +248,14 @@ public final class LanguageModelSession {
includeSchemaInPrompt: Bool = true,
options: GenerationOptions = GenerationOptions()
) -> sending ResponseStream<GeneratedContent> {
model.streamResponse(
within: self,
to: prompt,
generating: GeneratedContent.self,
includeSchemaInPrompt: includeSchemaInPrompt,
options: options
wrapStream(
model.streamResponse(
within: self,
to: prompt,
generating: GeneratedContent.self,
includeSchemaInPrompt: includeSchemaInPrompt,
options: options
)
)
}

Expand Down Expand Up @@ -230,12 +288,14 @@ public final class LanguageModelSession {
includeSchemaInPrompt: Bool = true,
options: GenerationOptions = GenerationOptions()
) -> sending ResponseStream<Content> where Content: Generable {
model.streamResponse(
within: self,
to: prompt,
generating: type,
includeSchemaInPrompt: includeSchemaInPrompt,
options: options
wrapStream(
model.streamResponse(
within: self,
to: prompt,
generating: type,
includeSchemaInPrompt: includeSchemaInPrompt,
options: options
)
)
}

Expand Down Expand Up @@ -271,12 +331,14 @@ public final class LanguageModelSession {
to prompt: Prompt,
options: GenerationOptions = GenerationOptions()
) -> sending ResponseStream<String> {
model.streamResponse(
within: self,
to: prompt,
generating: String.self,
includeSchemaInPrompt: true,
options: options
wrapStream(
model.streamResponse(
within: self,
to: prompt,
generating: String.self,
includeSchemaInPrompt: true,
options: options
)
)
}

Expand Down Expand Up @@ -309,7 +371,19 @@ public final class LanguageModelSession {
}
}

extension LanguageModelSession: @unchecked Sendable, Observable {}
/// Actor-isolated counter that never drops below zero.
private actor RespondingState {
    private var inFlight = 0

    /// Adds one to the counter and returns the new value.
    func increment() -> Int {
        inFlight += 1
        return inFlight
    }

    /// Subtracts one from the counter, stopping at zero, and returns the new value.
    func decrement() -> Int {
        if inFlight > 0 {
            inFlight -= 1
        }
        return inFlight
    }
}

extension LanguageModelSession {
public enum GenerationError: Error, LocalizedError {
Expand Down Expand Up @@ -401,7 +475,7 @@ extension LanguageModelSession {
}

extension LanguageModelSession {
public struct ResponseStream<Content> where Content: Generable {
public struct ResponseStream<Content>: Sendable where Content: Generable, Content.PartiallyGenerated: Sendable {
private let content: Content
private let rawContent: GeneratedContent
private let streaming: AsyncThrowingStream<Snapshot, any Error>?
Expand All @@ -420,7 +494,7 @@ extension LanguageModelSession {
self.streaming = stream
}

public struct Snapshot {
public struct Snapshot: Sendable where Content.PartiallyGenerated: Sendable {
public var content: Content.PartiallyGenerated
public var rawContent: GeneratedContent
}
Expand Down
10 changes: 5 additions & 5 deletions Sources/AnyLanguageModel/Models/SystemLanguageModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
/// let model = SystemLanguageModel()
/// ```
@available(macOS 26.0, iOS 26.0, watchOS 26.0, tvOS 26.0, visionOS 26.0, *)
public struct SystemLanguageModel: LanguageModel {
public actor SystemLanguageModel: LanguageModel {
/// The reason the model is unavailable.
public typealias UnavailableReason = FoundationModels.SystemLanguageModel.Availability.UnavailableReason

Expand Down Expand Up @@ -54,7 +54,7 @@
}

/// The availability status for the system language model.
public var availability: Availability<UnavailableReason> {
nonisolated public var availability: Availability<UnavailableReason> {
switch systemModel.availability {
case .available:
.available
Expand All @@ -63,7 +63,7 @@
}
}

public func respond<Content>(
nonisolated public func respond<Content>(
within session: LanguageModelSession,
to prompt: Prompt,
generating type: Content.Type,
Expand Down Expand Up @@ -100,7 +100,7 @@
}
}

public func streamResponse<Content>(
nonisolated public func streamResponse<Content>(
within session: LanguageModelSession,
to prompt: Prompt,
generating type: Content.Type,
Expand Down Expand Up @@ -180,7 +180,7 @@
return LanguageModelSession.ResponseStream(stream: stream)
}

public func logFeedbackAttachment(
nonisolated public func logFeedbackAttachment(
within session: LanguageModelSession,
sentiment: LanguageModelFeedback.Sentiment?,
issues: [LanguageModelFeedback.Issue],
Expand Down