Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 129 additions & 53 deletions Sources/AnyLanguageModel/LanguageModelSession.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@ import Foundation
import Observation

@Observable
public final class LanguageModelSession {
public final class LanguageModelSession: @unchecked Sendable {
Comment thread
mattt marked this conversation as resolved.
public private(set) var isResponding: Bool = false
public private(set) var transcript: Transcript

private let model: any LanguageModel
public let tools: [any Tool]
public let instructions: Instructions?

@ObservationIgnored private let respondingState = RespondingState()

public convenience init(
model: any LanguageModel,
tools: [any Tool] = [],
Expand Down Expand Up @@ -58,60 +60,114 @@ public final class LanguageModelSession {
model.prewarm(for: self, promptPrefix: promptPrefix)
}

public struct Response<Content> where Content: Generable {
nonisolated private func beginResponding() async {
let count = await respondingState.increment()
let active = count > 0
await MainActor.run {
self.isResponding = active
}
}

nonisolated private func endResponding() async {
let count = await respondingState.decrement()
let active = count > 0
await MainActor.run {
self.isResponding = active
}
}

nonisolated private func wrapRespond<T>(_ operation: () async throws -> T) async throws -> T {
await beginResponding()
do {
let result = try await operation()
await endResponding()
return result
} catch {
await endResponding()
throw error
}
}

nonisolated private func wrapStream<Content>(
_ upstream: sending ResponseStream<Content>
) -> ResponseStream<Content> where Content: Generable, Content.PartiallyGenerated: Sendable {
let session = self
let relay = AsyncThrowingStream<ResponseStream<Content>.Snapshot, any Error> { continuation in
let stream = upstream
Task {
await session.beginResponding()
do {
for try await snapshot in stream {
continuation.yield(snapshot)
}
continuation.finish()
} catch {
continuation.finish(throwing: error)
}
await session.endResponding()
}
}
return ResponseStream(stream: relay)
}

public struct Response<Content>: Sendable where Content: Generable, Content: Sendable {
public let content: Content
public let rawContent: GeneratedContent
public let transcriptEntries: ArraySlice<Transcript.Entry>
}

@discardableResult
nonisolated public func respond(
public func respond(
Comment thread
mattt marked this conversation as resolved.
Outdated
to prompt: Prompt,
options: GenerationOptions = GenerationOptions()
) async throws -> Response<String> {
try await model.respond(
within: self,
to: prompt,
generating: String.self,
includeSchemaInPrompt: true,
options: options
)
try await wrapRespond {
try await model.respond(
within: self,
to: prompt,
generating: String.self,
includeSchemaInPrompt: true,
options: options
)
}
}

@discardableResult
nonisolated public func respond(
public func respond(
to prompt: String,
options: GenerationOptions = GenerationOptions()
) async throws -> Response<String> {
try await respond(to: Prompt(prompt), options: options)
}

@discardableResult
nonisolated public func respond(
public func respond(
options: GenerationOptions = GenerationOptions(),
@PromptBuilder prompt: () throws -> Prompt
) async throws -> Response<String> {
try await respond(to: try prompt(), options: options)
}

@discardableResult
nonisolated public func respond(
public func respond(
to prompt: Prompt,
schema: GenerationSchema,
includeSchemaInPrompt: Bool = true,
options: GenerationOptions = GenerationOptions()
) async throws -> Response<GeneratedContent> {
try await model.respond(
within: self,
to: prompt,
generating: GeneratedContent.self,
includeSchemaInPrompt: includeSchemaInPrompt,
options: options
)
try await wrapRespond {
try await model.respond(
within: self,
to: prompt,
generating: GeneratedContent.self,
includeSchemaInPrompt: includeSchemaInPrompt,
options: options
)
}
}

@discardableResult
nonisolated public func respond(
public func respond(
to prompt: String,
schema: GenerationSchema,
includeSchemaInPrompt: Bool = true,
Expand All @@ -126,7 +182,7 @@ public final class LanguageModelSession {
}

@discardableResult
nonisolated public func respond(
public func respond(
schema: GenerationSchema,
includeSchemaInPrompt: Bool = true,
options: GenerationOptions = GenerationOptions(),
Expand All @@ -141,23 +197,25 @@ public final class LanguageModelSession {
}

@discardableResult
nonisolated public func respond<Content>(
public func respond<Content>(
to prompt: Prompt,
generating type: Content.Type = Content.self,
includeSchemaInPrompt: Bool = true,
options: GenerationOptions = GenerationOptions()
) async throws -> Response<Content> where Content: Generable {
try await model.respond(
within: self,
to: prompt,
generating: type,
includeSchemaInPrompt: includeSchemaInPrompt,
options: options
)
try await wrapRespond {
try await model.respond(
within: self,
to: prompt,
generating: type,
includeSchemaInPrompt: includeSchemaInPrompt,
options: options
)
}
}

@discardableResult
nonisolated public func respond<Content>(
public func respond<Content>(
to prompt: String,
generating type: Content.Type = Content.self,
includeSchemaInPrompt: Bool = true,
Expand All @@ -172,7 +230,7 @@ public final class LanguageModelSession {
}

@discardableResult
nonisolated public func respond<Content>(
public func respond<Content>(
generating type: Content.Type = Content.self,
includeSchemaInPrompt: Bool = true,
options: GenerationOptions = GenerationOptions(),
Expand All @@ -192,12 +250,14 @@ public final class LanguageModelSession {
includeSchemaInPrompt: Bool = true,
options: GenerationOptions = GenerationOptions()
) -> sending ResponseStream<GeneratedContent> {
model.streamResponse(
within: self,
to: prompt,
generating: GeneratedContent.self,
includeSchemaInPrompt: includeSchemaInPrompt,
options: options
wrapStream(
model.streamResponse(
within: self,
to: prompt,
generating: GeneratedContent.self,
includeSchemaInPrompt: includeSchemaInPrompt,
options: options
)
)
}

Expand Down Expand Up @@ -230,12 +290,14 @@ public final class LanguageModelSession {
includeSchemaInPrompt: Bool = true,
options: GenerationOptions = GenerationOptions()
) -> sending ResponseStream<Content> where Content: Generable {
model.streamResponse(
within: self,
to: prompt,
generating: type,
includeSchemaInPrompt: includeSchemaInPrompt,
options: options
wrapStream(
model.streamResponse(
within: self,
to: prompt,
generating: type,
includeSchemaInPrompt: includeSchemaInPrompt,
options: options
)
)
}

Expand Down Expand Up @@ -271,12 +333,14 @@ public final class LanguageModelSession {
to prompt: Prompt,
options: GenerationOptions = GenerationOptions()
) -> sending ResponseStream<String> {
model.streamResponse(
within: self,
to: prompt,
generating: String.self,
includeSchemaInPrompt: true,
options: options
wrapStream(
model.streamResponse(
within: self,
to: prompt,
generating: String.self,
includeSchemaInPrompt: true,
options: options
)
)
}

Expand Down Expand Up @@ -309,7 +373,19 @@ public final class LanguageModelSession {
}
}

extension LanguageModelSession: @unchecked Sendable, Observable {}
private actor RespondingState {
private var count = 0

func increment() -> Int {
count += 1
return count
}

func decrement() -> Int {
count = max(0, count - 1)
return count
}
}

extension LanguageModelSession {
public enum GenerationError: Error, LocalizedError {
Expand Down Expand Up @@ -401,7 +477,7 @@ extension LanguageModelSession {
}

extension LanguageModelSession {
public struct ResponseStream<Content> where Content: Generable {
public struct ResponseStream<Content>: Sendable where Content: Generable, Content.PartiallyGenerated: Sendable {
private let content: Content
private let rawContent: GeneratedContent
private let streaming: AsyncThrowingStream<Snapshot, any Error>?
Expand All @@ -420,7 +496,7 @@ extension LanguageModelSession {
self.streaming = stream
}

public struct Snapshot {
public struct Snapshot: Sendable where Content.PartiallyGenerated: Sendable {
public var content: Content.PartiallyGenerated
public var rawContent: GeneratedContent
}
Expand Down
10 changes: 5 additions & 5 deletions Sources/AnyLanguageModel/Models/SystemLanguageModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
/// let model = SystemLanguageModel()
/// ```
@available(macOS 26.0, iOS 26.0, watchOS 26.0, tvOS 26.0, visionOS 26.0, *)
public struct SystemLanguageModel: LanguageModel {
public actor SystemLanguageModel: LanguageModel {
/// The reason the model is unavailable.
public typealias UnavailableReason = FoundationModels.SystemLanguageModel.Availability.UnavailableReason

Expand Down Expand Up @@ -54,7 +54,7 @@
}

/// The availability status for the system language model.
public var availability: Availability<UnavailableReason> {
nonisolated public var availability: Availability<UnavailableReason> {
switch systemModel.availability {
case .available:
.available
Expand All @@ -63,7 +63,7 @@
}
}

public func respond<Content>(
nonisolated public func respond<Content>(
within session: LanguageModelSession,
to prompt: Prompt,
generating type: Content.Type,
Expand Down Expand Up @@ -100,7 +100,7 @@
}
}

public func streamResponse<Content>(
nonisolated public func streamResponse<Content>(
within session: LanguageModelSession,
to prompt: Prompt,
generating type: Content.Type,
Expand Down Expand Up @@ -180,7 +180,7 @@
return LanguageModelSession.ResponseStream(stream: stream)
}

public func logFeedbackAttachment(
nonisolated public func logFeedbackAttachment(
within session: LanguageModelSession,
sentiment: LanguageModelFeedback.Sentiment?,
issues: [LanguageModelFeedback.Issue],
Expand Down
Loading