diff --git a/src/lib/AbstractChatCompletionRunner.ts b/src/lib/AbstractChatCompletionRunner.ts index 5764b85b2..590013aa6 100644 --- a/src/lib/AbstractChatCompletionRunner.ts +++ b/src/lib/AbstractChatCompletionRunner.ts @@ -8,7 +8,7 @@ import { type ChatCompletionCreateParams, type ChatCompletionTool, } from 'openai/resources/chat/completions'; -import { APIUserAbortError, OpenAIError } from 'openai/error'; +import { OpenAIError } from 'openai/error'; import { type RunnableFunction, isRunnableFunctionWithParse, @@ -20,6 +20,7 @@ import { ChatCompletionStreamingToolRunnerParams, } from './ChatCompletionStreamingRunner'; import { isAssistantMessage, isFunctionMessage, isToolMessage } from './chatCompletionUtils'; +import { BaseEvents, EventStream } from './EventStream'; const DEFAULT_MAX_CHAT_COMPLETIONS = 10; export interface RunnerOptions extends Core.RequestOptions { @@ -27,60 +28,16 @@ export interface RunnerOptions extends Core.RequestOptions { maxChatCompletions?: number; } -export abstract class AbstractChatCompletionRunner< - Events extends CustomEvents = AbstractChatCompletionRunnerEvents, -> { - controller: AbortController = new AbortController(); - - #connectedPromise: Promise; - #resolveConnectedPromise: () => void = () => {}; - #rejectConnectedPromise: (error: OpenAIError) => void = () => {}; - - #endPromise: Promise; - #resolveEndPromise: () => void = () => {}; - #rejectEndPromise: (error: OpenAIError) => void = () => {}; - - #listeners: { [Event in keyof Events]?: ListenersForEvent } = {}; - +export class AbstractChatCompletionRunner< + EventTypes extends AbstractChatCompletionRunnerEvents, +> extends EventStream { protected _chatCompletions: ChatCompletion[] = []; messages: ChatCompletionMessageParam[] = []; - #ended = false; - #errored = false; - #aborted = false; - #catchingPromiseCreated = false; - - constructor() { - this.#connectedPromise = new Promise((resolve, reject) => { - this.#resolveConnectedPromise = resolve; - this.#rejectConnectedPromise = reject; - }); - - this.#endPromise = new Promise((resolve, reject) => { - this.#resolveEndPromise = resolve; - this.#rejectEndPromise = reject; - }); - - // Don't let these promises cause unhandled rejection errors. - // we will manually cause an unhandled rejection error later - // if the user hasn't registered any error listener or called - // any promise-returning method. - this.#connectedPromise.catch(() => {}); - this.#endPromise.catch(() => {}); - } - - protected _run(executor: () => Promise) { - // Unfortunately if we call `executor()` immediately we get runtime errors about - // references to `this` before the `super()` constructor call returns. - setTimeout(() => { - executor().then(() => { - this._emitFinal(); - this._emit('end'); - }, this.#handleError); - }, 0); - } - - protected _addChatCompletion(chatCompletion: ChatCompletion): ChatCompletion { + protected _addChatCompletion( + this: AbstractChatCompletionRunner, + chatCompletion: ChatCompletion, + ): ChatCompletion { this._chatCompletions.push(chatCompletion); this._emit('chatCompletion', chatCompletion); const message = chatCompletion.choices[0]?.message; @@ -88,7 +45,11 @@ export abstract class AbstractChatCompletionRunner< return chatCompletion; } - protected _addMessage(message: ChatCompletionMessageParam, emit = true) { + protected _addMessage( + this: AbstractChatCompletionRunner, + message: ChatCompletionMessageParam, + emit = true, + ) { if (!('content' in message)) message.content = null; this.messages.push(message); @@ -110,99 +71,6 @@ export abstract class AbstractChatCompletionRunner< } } - protected _connected() { - if (this.ended) return; - this.#resolveConnectedPromise(); - this._emit('connect'); - } - - get ended(): boolean { - return this.#ended; - } - - get errored(): boolean { - return this.#errored; - } - - get aborted(): boolean { - return this.#aborted; - } - - abort() { - this.controller.abort(); - } - - /** - * Adds the listener function to the end of the listeners array for the event. - * No checks are made to see if the listener has already been added. Multiple calls passing - * the same combination of event and listener will result in the listener being added, and - * called, multiple times. - * @returns this ChatCompletionStream, so that calls can be chained - */ - on(event: Event, listener: ListenerForEvent): this { - const listeners: ListenersForEvent = - this.#listeners[event] || (this.#listeners[event] = []); - listeners.push({ listener }); - return this; - } - - /** - * Removes the specified listener from the listener array for the event. - * off() will remove, at most, one instance of a listener from the listener array. If any single - * listener has been added multiple times to the listener array for the specified event, then - * off() must be called multiple times to remove each instance. - * @returns this ChatCompletionStream, so that calls can be chained - */ - off(event: Event, listener: ListenerForEvent): this { - const listeners = this.#listeners[event]; - if (!listeners) return this; - const index = listeners.findIndex((l) => l.listener === listener); - if (index >= 0) listeners.splice(index, 1); - return this; - } - - /** - * Adds a one-time listener function for the event. The next time the event is triggered, - * this listener is removed and then invoked. - * @returns this ChatCompletionStream, so that calls can be chained - */ - once(event: Event, listener: ListenerForEvent): this { - const listeners: ListenersForEvent = - this.#listeners[event] || (this.#listeners[event] = []); - listeners.push({ listener, once: true }); - return this; - } - - /** - * This is similar to `.once()`, but returns a Promise that resolves the next time - * the event is triggered, instead of calling a listener callback. - * @returns a Promise that resolves the next time given event is triggered, - * or rejects if an error is emitted. (If you request the 'error' event, - * returns a promise that resolves with the error). - * - * Example: - * - * const message = await stream.emitted('message') // rejects if the stream errors - */ - emitted( - event: Event, - ): Promise< - EventParameters extends [infer Param] ? Param - : EventParameters extends [] ? void - : EventParameters - > { - return new Promise((resolve, reject) => { - this.#catchingPromiseCreated = true; - if (event !== 'error') this.once('error', reject); - this.once(event, resolve as any); - }); - } - - async done(): Promise { - this.#catchingPromiseCreated = true; - await this.#endPromise; - } - /** * @returns a promise that resolves with the final ChatCompletion, or rejects * if an error occurred or the stream ended prematurely without producing a ChatCompletion. @@ -327,75 +195,7 @@ export abstract class AbstractChatCompletionRunner< return [...this._chatCompletions]; } - #handleError = (error: unknown) => { - this.#errored = true; - if (error instanceof Error && error.name === 'AbortError') { - error = new APIUserAbortError(); - } - if (error instanceof APIUserAbortError) { - this.#aborted = true; - return this._emit('abort', error); - } - if (error instanceof OpenAIError) { - return this._emit('error', error); - } - if (error instanceof Error) { - const openAIError: OpenAIError = new OpenAIError(error.message); - // @ts-ignore - openAIError.cause = error; - return this._emit('error', openAIError); - } - return this._emit('error', new OpenAIError(String(error))); - }; - - protected _emit(event: Event, ...args: EventParameters) { - // make sure we don't emit any events after end - if (this.#ended) { - return; - } - - if (event === 'end') { - this.#ended = true; - this.#resolveEndPromise(); - } - - const listeners: ListenersForEvent | undefined = this.#listeners[event]; - if (listeners) { - this.#listeners[event] = listeners.filter((l) => !l.once) as any; - listeners.forEach(({ listener }: any) => listener(...args)); - } - - if (event === 'abort') { - const error = args[0] as APIUserAbortError; - if (!this.#catchingPromiseCreated && !listeners?.length) { - Promise.reject(error); - } - this.#rejectConnectedPromise(error); - this.#rejectEndPromise(error); - this._emit('end'); - return; - } - - if (event === 'error') { - // NOTE: _emit('error', error) should only be called from #handleError(). - - const error = args[0] as OpenAIError; - if (!this.#catchingPromiseCreated && !listeners?.length) { - // Trigger an unhandled rejection if the user hasn't registered any error handlers. - // If you are seeing stack traces here, make sure to handle errors via either: - // - runner.on('error', () => ...) - // - await runner.done() - // - await runner.finalChatCompletion() - // - etc. - Promise.reject(error); - } - this.#rejectConnectedPromise(error); - this.#rejectEndPromise(error); - this._emit('end'); - } - } - - protected _emitFinal() { + protected override _emitFinal(this: AbstractChatCompletionRunner) { const completion = this._chatCompletions[this._chatCompletions.length - 1]; if (completion) this._emit('finalChatCompletion', completion); const finalMessage = this.#getFinalMessage(); @@ -650,27 +450,7 @@ export abstract class AbstractChatCompletionRunner< } } -type CustomEvents = { - [k in Event]: k extends keyof AbstractChatCompletionRunnerEvents ? AbstractChatCompletionRunnerEvents[k] - : (...args: any[]) => void; -}; - -type ListenerForEvent, Event extends keyof Events> = Event extends ( - keyof AbstractChatCompletionRunnerEvents -) ? - AbstractChatCompletionRunnerEvents[Event] -: Events[Event]; - -type ListenersForEvent, Event extends keyof Events> = Array<{ - listener: ListenerForEvent; - once?: boolean; -}>; -type EventParameters, Event extends keyof Events> = Parameters< - ListenerForEvent ->; - -export interface AbstractChatCompletionRunnerEvents { - connect: () => void; +export interface AbstractChatCompletionRunnerEvents extends BaseEvents { functionCall: (functionCall: ChatCompletionMessage.FunctionCall) => void; message: (message: ChatCompletionMessageParam) => void; chatCompletion: (completion: ChatCompletion) => void; @@ -680,8 +460,5 @@ export interface AbstractChatCompletionRunnerEvents { finalFunctionCall: (functionCall: ChatCompletionMessage.FunctionCall) => void; functionCallResult: (content: string) => void; finalFunctionCallResult: (content: string) => void; - error: (error: OpenAIError) => void; - abort: (error: APIUserAbortError) => void; - end: () => void; totalUsage: (usage: CompletionUsage) => void; } diff --git a/src/lib/AssistantStream.ts b/src/lib/AssistantStream.ts index de7511b5d..0f88530b3 100644 --- a/src/lib/AssistantStream.ts +++ b/src/lib/AssistantStream.ts @@ -19,10 +19,6 @@ import { RunSubmitToolOutputsParamsBase, RunSubmitToolOutputsParamsStreaming, } from 'openai/resources/beta/threads/runs/runs'; -import { - AbstractAssistantRunnerEvents, - AbstractAssistantStreamRunner, -} from './AbstractAssistantStreamRunner'; import { type ReadableStream } from 'openai/_shims/index'; import { Stream } from 'openai/streaming'; import { APIUserAbortError, OpenAIError } from 'openai/error'; @@ -34,9 +30,12 @@ import { } from 'openai/resources/beta/assistants'; import { RunStep, RunStepDelta, ToolCall, ToolCallDelta } from 'openai/resources/beta/threads/runs/steps'; import { ThreadCreateAndRunParamsBase, Threads } from 'openai/resources/beta/threads/threads'; +import { BaseEvents, EventStream } from './EventStream'; import MessageDelta = Messages.MessageDelta; -export interface AssistantStreamEvents extends AbstractAssistantRunnerEvents { +export interface AssistantStreamEvents extends BaseEvents { + run: (run: Run) => void; + //New event structure messageCreated: (message: Message) => void; messageDelta: (message: MessageDelta, snapshot: Message) => void; @@ -57,8 +56,6 @@ export interface AssistantStreamEvents extends AbstractAssistantRunnerEvents { //No created or delta as this is not streamed imageFileDone: (content: ImageFile, snapshot: Message) => void; - end: () => void; - event: (event: AssistantStreamEvent) => void; } @@ -75,7 +72,7 @@ export type RunSubmitToolOutputsParamsStream = Omit + extends EventStream implements AsyncIterable { //Track all events in a single list for reference @@ -207,7 +204,7 @@ export class AssistantStream return runner; } - protected override async _createToolAssistantStream( + protected async _createToolAssistantStream( run: Runs, threadId: string, runId: string, @@ -304,7 +301,7 @@ export class AssistantStream return this.#finalRun; } - protected override async _createThreadAssistantStream( + protected async _createThreadAssistantStream( thread: Threads, params: ThreadCreateAndRunParamsBase, options?: Core.RequestOptions, @@ -330,7 +327,7 @@ export class AssistantStream return this._addRun(this.#endRequest()); } - protected override async _createAssistantStream( + protected async _createAssistantStream( run: Runs, threadId: string, params: RunCreateParamsBase, @@ -417,7 +414,7 @@ export class AssistantStream return this.#finalRun; } - #handleMessage(event: MessageStreamEvent) { + #handleMessage(this: AssistantStream, event: MessageStreamEvent) { const [accumulatedMessage, newContent] = this.#accumulateMessage(event, this.#messageSnapshot); this.#messageSnapshot = accumulatedMessage; this.#messageSnapshots[accumulatedMessage.id] = accumulatedMessage; @@ -500,7 +497,7 @@ export class AssistantStream } } - #handleRunStep(event: RunStepStreamEvent) { + #handleRunStep(this: AssistantStream, event: RunStepStreamEvent) { const accumulatedRunStep = this.#accumulateRunStep(event); this.#currentRunStepSnapshot = accumulatedRunStep; @@ -556,7 +553,7 @@ export class AssistantStream } } - #handleEvent(event: AssistantStreamEvent) { + #handleEvent(this: AssistantStream, event: AssistantStreamEvent) { this.#events.push(event); this._emit('event', event); } @@ -696,7 +693,7 @@ export class AssistantStream return acc; } - #handleRun(event: RunStreamEvent) { + #handleRun(this: AssistantStream, event: RunStreamEvent) { this.#currentRunSnapshot = event.data; switch (event.event) { case 'thread.run.created': @@ -720,4 +717,35 @@ export class AssistantStream break; } } + + protected _addRun(run: Run): Run { + return run; + } + + protected async _threadAssistantStream( + body: ThreadCreateAndRunParamsBase, + thread: Threads, + options?: Core.RequestOptions, + ): Promise { + return await this._createThreadAssistantStream(thread, body, options); + } + + protected async _runAssistantStream( + threadId: string, + runs: Runs, + params: RunCreateParamsBase, + options?: Core.RequestOptions, + ): Promise { + return await this._createAssistantStream(runs, threadId, params, options); + } + + protected async _runToolAssistantStream( + threadId: string, + runId: string, + runs: Runs, + params: RunSubmitToolOutputsParamsStream, + options?: Core.RequestOptions, + ): Promise { + return await this._createToolAssistantStream(runs, threadId, runId, params, options); + } } diff --git a/src/lib/ChatCompletionRunner.ts b/src/lib/ChatCompletionRunner.ts index a110f0192..c756919b0 100644 --- a/src/lib/ChatCompletionRunner.ts +++ b/src/lib/ChatCompletionRunner.ts @@ -59,7 +59,7 @@ export class ChatCompletionRunner extends AbstractChatCompletionRunner { + [Symbol.asyncIterator](this: ChatCompletionStream): AsyncIterator { const pushQueue: ChatCompletionChunk[] = []; const readQueue: { resolve: (chunk: ChatCompletionChunk | undefined) => void; diff --git a/src/lib/AbstractAssistantStreamRunner.ts b/src/lib/EventStream.ts similarity index 55% rename from src/lib/AbstractAssistantStreamRunner.ts rename to src/lib/EventStream.ts index b600f0df3..a18c771dd 100644 --- a/src/lib/AbstractAssistantStreamRunner.ts +++ b/src/lib/EventStream.ts @@ -1,12 +1,6 @@ -import * as Core from 'openai/core'; import { APIUserAbortError, OpenAIError } from 'openai/error'; -import { Run, RunSubmitToolOutputsParamsBase } from 'openai/resources/beta/threads/runs/runs'; -import { RunCreateParamsBase, Runs } from 'openai/resources/beta/threads/runs/runs'; -import { ThreadCreateAndRunParamsBase, Threads } from 'openai/resources/beta/threads/threads'; -export abstract class AbstractAssistantStreamRunner< - Events extends CustomEvents = AbstractAssistantRunnerEvents, -> { +export class EventStream { controller: AbortController = new AbortController(); #connectedPromise: Promise; @@ -17,7 +11,9 @@ export abstract class AbstractAssistantStreamRunner< #resolveEndPromise: () => void = () => {}; #rejectEndPromise: (error: OpenAIError) => void = () => {}; - #listeners: { [Event in keyof Events]?: ListenersForEvent } = {}; + #listeners: { + [Event in keyof EventTypes]?: EventListeners; + } = {}; #ended = false; #errored = false; @@ -43,22 +39,18 @@ export abstract class AbstractAssistantStreamRunner< this.#endPromise.catch(() => {}); } - protected _run(executor: () => Promise) { + protected _run(this: EventStream, executor: () => Promise) { // Unfortunately if we call `executor()` immediately we get runtime errors about // references to `this` before the `super()` constructor call returns. setTimeout(() => { executor().then(() => { - // this._emitFinal(); + this._emitFinal(); this._emit('end'); - }, this.#handleError); + }, this.#handleError.bind(this)); }, 0); } - protected _addRun(run: Run): Run { - return run; - } - - protected _connected() { + protected _connected(this: EventStream) { if (this.ended) return; this.#resolveConnectedPromise(); this._emit('connect'); @@ -87,8 +79,8 @@ export abstract class AbstractAssistantStreamRunner< * called, multiple times. * @returns this ChatCompletionStream, so that calls can be chained */ - on(event: Event, listener: ListenerForEvent): this { - const listeners: ListenersForEvent = + on(event: Event, listener: EventListener): this { + const listeners: EventListeners = this.#listeners[event] || (this.#listeners[event] = []); listeners.push({ listener }); return this; @@ -101,7 +93,7 @@ export abstract class AbstractAssistantStreamRunner< * off() must be called multiple times to remove each instance. * @returns this ChatCompletionStream, so that calls can be chained */ - off(event: Event, listener: ListenerForEvent): this { + off(event: Event, listener: EventListener): this { const listeners = this.#listeners[event]; if (!listeners) return this; const index = listeners.findIndex((l) => l.listener === listener); @@ -114,8 +106,8 @@ export abstract class AbstractAssistantStreamRunner< * this listener is removed and then invoked. * @returns this ChatCompletionStream, so that calls can be chained */ - once(event: Event, listener: ListenerForEvent): this { - const listeners: ListenersForEvent = + once(event: Event, listener: EventListener): this { + const listeners: EventListeners = this.#listeners[event] || (this.#listeners[event] = []); listeners.push({ listener, once: true }); return this; @@ -132,12 +124,12 @@ export abstract class AbstractAssistantStreamRunner< * * const message = await stream.emitted('message') // rejects if the stream errors */ - emitted( + emitted( event: Event, ): Promise< - EventParameters extends [infer Param] ? Param - : EventParameters extends [] ? void - : EventParameters + EventParameters extends [infer Param] ? Param + : EventParameters extends [] ? void + : EventParameters > { return new Promise((resolve, reject) => { this.#catchingPromiseCreated = true; @@ -151,7 +143,7 @@ export abstract class AbstractAssistantStreamRunner< await this.#endPromise; } - #handleError = (error: unknown) => { + #handleError(this: EventStream, error: unknown) { this.#errored = true; if (error instanceof Error && error.name === 'AbortError') { error = new APIUserAbortError(); @@ -170,9 +162,15 @@ export abstract class AbstractAssistantStreamRunner< return this._emit('error', openAIError); } return this._emit('error', new OpenAIError(String(error))); - }; + } - protected _emit(event: Event, ...args: EventParameters) { + _emit(event: Event, ...args: EventParameters): void; + _emit(event: Event, ...args: EventParameters): void; + _emit( + this: EventStream, + event: Event, + ...args: EventParameters + ) { // make sure we don't emit any events after end if (this.#ended) { return; @@ -183,10 +181,10 @@ export abstract class AbstractAssistantStreamRunner< this.#resolveEndPromise(); } - const listeners: ListenersForEvent | undefined = this.#listeners[event]; + const listeners: EventListeners | undefined = this.#listeners[event]; if (listeners) { this.#listeners[event] = listeners.filter((l) => !l.once) as any; - listeners.forEach(({ listener }: any) => listener(...args)); + listeners.forEach(({ listener }: any) => listener(...(args as any))); } if (event === 'abort') { @@ -219,121 +217,22 @@ export abstract class AbstractAssistantStreamRunner< } } - protected async _threadAssistantStream( - body: ThreadCreateAndRunParamsBase, - thread: Threads, - options?: Core.RequestOptions, - ): Promise { - return await this._createThreadAssistantStream(thread, body, options); - } - - protected async _runAssistantStream( - threadId: string, - runs: Runs, - params: RunCreateParamsBase, - options?: Core.RequestOptions, - ): Promise { - return await this._createAssistantStream(runs, threadId, params, options); - } - - protected async _runToolAssistantStream( - threadId: string, - runId: string, - runs: Runs, - params: RunSubmitToolOutputsParamsBase, - options?: Core.RequestOptions, - ): Promise { - return await this._createToolAssistantStream(runs, threadId, runId, params, options); - } - - protected async _createThreadAssistantStream( - thread: Threads, - body: ThreadCreateAndRunParamsBase, - options?: Core.RequestOptions, - ): Promise { - const signal = options?.signal; - if (signal) { - if (signal.aborted) this.controller.abort(); - signal.addEventListener('abort', () => this.controller.abort()); - } - // this.#validateParams(params); - - const runResult = await thread.createAndRun( - { ...body, stream: false }, - { ...options, signal: this.controller.signal }, - ); - this._connected(); - return this._addRun(runResult as Run); - } - - protected async _createToolAssistantStream( - run: Runs, - threadId: string, - runId: string, - params: RunSubmitToolOutputsParamsBase, - options?: Core.RequestOptions, - ): Promise { - const signal = options?.signal; - if (signal) { - if (signal.aborted) this.controller.abort(); - signal.addEventListener('abort', () => this.controller.abort()); - } - - const runResult = await run.submitToolOutputs( - threadId, - runId, - { ...params, stream: false }, - { ...options, signal: this.controller.signal }, - ); - this._connected(); - return this._addRun(runResult as Run); - } - - protected async _createAssistantStream( - run: Runs, - threadId: string, - params: RunCreateParamsBase, - options?: Core.RequestOptions, - ): Promise { - const signal = options?.signal; - if (signal) { - if (signal.aborted) this.controller.abort(); - signal.addEventListener('abort', () => this.controller.abort()); - } - // this.#validateParams(params); - - const runResult = await run.create( - threadId, - { ...params, stream: false }, - { ...options, signal: this.controller.signal }, - ); - this._connected(); - return this._addRun(runResult as Run); - } + protected _emitFinal(): void {} } -type CustomEvents = { - [k in Event]: k extends keyof AbstractAssistantRunnerEvents ? AbstractAssistantRunnerEvents[k] - : (...args: any[]) => void; -}; +type EventListener = Events[EventType]; -type ListenerForEvent, Event extends keyof Events> = Event extends ( - keyof AbstractAssistantRunnerEvents -) ? - AbstractAssistantRunnerEvents[Event] -: Events[Event]; - -type ListenersForEvent, Event extends keyof Events> = Array<{ - listener: ListenerForEvent; +type EventListeners = Array<{ + listener: EventListener; once?: boolean; }>; -type EventParameters, Event extends keyof Events> = Parameters< - ListenerForEvent ->; -export interface AbstractAssistantRunnerEvents { +export type EventParameters = { + [Event in EventType]: EventListener extends (...args: infer P) => any ? P : never; +}[EventType]; + +export interface BaseEvents { connect: () => void; - run: (run: Run) => void; error: (error: OpenAIError) => void; abort: (error: APIUserAbortError) => void; end: () => void;