Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
255 changes: 16 additions & 239 deletions src/lib/AbstractChatCompletionRunner.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import {
type ChatCompletionCreateParams,
type ChatCompletionTool,
} from 'openai/resources/chat/completions';
import { APIUserAbortError, OpenAIError } from 'openai/error';
import { OpenAIError } from 'openai/error';
import {
type RunnableFunction,
isRunnableFunctionWithParse,
Expand All @@ -20,75 +20,36 @@ import {
ChatCompletionStreamingToolRunnerParams,
} from './ChatCompletionStreamingRunner';
import { isAssistantMessage, isFunctionMessage, isToolMessage } from './chatCompletionUtils';
import { BaseEvents, EventStream } from './EventStream';

const DEFAULT_MAX_CHAT_COMPLETIONS = 10;
export interface RunnerOptions extends Core.RequestOptions {
/** How many requests to make before canceling. Default 10. */
maxChatCompletions?: number;
}

export abstract class AbstractChatCompletionRunner<
Events extends CustomEvents<any> = AbstractChatCompletionRunnerEvents,
> {
controller: AbortController = new AbortController();

#connectedPromise: Promise<void>;
#resolveConnectedPromise: () => void = () => {};
#rejectConnectedPromise: (error: OpenAIError) => void = () => {};

#endPromise: Promise<void>;
#resolveEndPromise: () => void = () => {};
#rejectEndPromise: (error: OpenAIError) => void = () => {};

#listeners: { [Event in keyof Events]?: ListenersForEvent<Events, Event> } = {};

export class AbstractChatCompletionRunner<
EventTypes extends AbstractChatCompletionRunnerEvents,
> extends EventStream<EventTypes> {
protected _chatCompletions: ChatCompletion[] = [];
messages: ChatCompletionMessageParam[] = [];

#ended = false;
#errored = false;
#aborted = false;
#catchingPromiseCreated = false;

constructor() {
this.#connectedPromise = new Promise<void>((resolve, reject) => {
this.#resolveConnectedPromise = resolve;
this.#rejectConnectedPromise = reject;
});

this.#endPromise = new Promise<void>((resolve, reject) => {
this.#resolveEndPromise = resolve;
this.#rejectEndPromise = reject;
});

// Don't let these promises cause unhandled rejection errors.
// we will manually cause an unhandled rejection error later
// if the user hasn't registered any error listener or called
// any promise-returning method.
this.#connectedPromise.catch(() => {});
this.#endPromise.catch(() => {});
}

protected _run(executor: () => Promise<any>) {
// Unfortunately if we call `executor()` immediately we get runtime errors about
// references to `this` before the `super()` constructor call returns.
setTimeout(() => {
executor().then(() => {
this._emitFinal();
this._emit('end');
}, this.#handleError);
}, 0);
}

protected _addChatCompletion(chatCompletion: ChatCompletion): ChatCompletion {
protected _addChatCompletion(
this: AbstractChatCompletionRunner<AbstractChatCompletionRunnerEvents>,
chatCompletion: ChatCompletion,
): ChatCompletion {
this._chatCompletions.push(chatCompletion);
this._emit('chatCompletion', chatCompletion);
const message = chatCompletion.choices[0]?.message;
if (message) this._addMessage(message as ChatCompletionMessageParam);
return chatCompletion;
}

protected _addMessage(message: ChatCompletionMessageParam, emit = true) {
protected _addMessage(
this: AbstractChatCompletionRunner<AbstractChatCompletionRunnerEvents>,
message: ChatCompletionMessageParam,
emit = true,
) {
if (!('content' in message)) message.content = null;

this.messages.push(message);
Expand All @@ -110,99 +71,6 @@ export abstract class AbstractChatCompletionRunner<
}
}

protected _connected() {
if (this.ended) return;
this.#resolveConnectedPromise();
this._emit('connect');
}

get ended(): boolean {
return this.#ended;
}

get errored(): boolean {
return this.#errored;
}

get aborted(): boolean {
return this.#aborted;
}

abort() {
this.controller.abort();
}

/**
* Adds the listener function to the end of the listeners array for the event.
* No checks are made to see if the listener has already been added. Multiple calls passing
* the same combination of event and listener will result in the listener being added, and
* called, multiple times.
* @returns this ChatCompletionStream, so that calls can be chained
*/
on<Event extends keyof Events>(event: Event, listener: ListenerForEvent<Events, Event>): this {
const listeners: ListenersForEvent<Events, Event> =
this.#listeners[event] || (this.#listeners[event] = []);
listeners.push({ listener });
return this;
}

/**
* Removes the specified listener from the listener array for the event.
* off() will remove, at most, one instance of a listener from the listener array. If any single
* listener has been added multiple times to the listener array for the specified event, then
* off() must be called multiple times to remove each instance.
* @returns this ChatCompletionStream, so that calls can be chained
*/
off<Event extends keyof Events>(event: Event, listener: ListenerForEvent<Events, Event>): this {
const listeners = this.#listeners[event];
if (!listeners) return this;
const index = listeners.findIndex((l) => l.listener === listener);
if (index >= 0) listeners.splice(index, 1);
return this;
}

/**
* Adds a one-time listener function for the event. The next time the event is triggered,
* this listener is removed and then invoked.
* @returns this ChatCompletionStream, so that calls can be chained
*/
once<Event extends keyof Events>(event: Event, listener: ListenerForEvent<Events, Event>): this {
const listeners: ListenersForEvent<Events, Event> =
this.#listeners[event] || (this.#listeners[event] = []);
listeners.push({ listener, once: true });
return this;
}

/**
* This is similar to `.once()`, but returns a Promise that resolves the next time
* the event is triggered, instead of calling a listener callback.
* @returns a Promise that resolves the next time given event is triggered,
* or rejects if an error is emitted. (If you request the 'error' event,
* returns a promise that resolves with the error).
*
* Example:
*
* const message = await stream.emitted('message') // rejects if the stream errors
*/
emitted<Event extends keyof Events>(
event: Event,
): Promise<
EventParameters<Events, Event> extends [infer Param] ? Param
: EventParameters<Events, Event> extends [] ? void
: EventParameters<Events, Event>
> {
return new Promise((resolve, reject) => {
this.#catchingPromiseCreated = true;
if (event !== 'error') this.once('error', reject);
this.once(event, resolve as any);
});
}

async done(): Promise<void> {
this.#catchingPromiseCreated = true;
await this.#endPromise;
}

/**
* @returns a promise that resolves with the final ChatCompletion, or rejects
* if an error occurred or the stream ended prematurely without producing a ChatCompletion.
Expand Down Expand Up @@ -327,75 +195,7 @@ export abstract class AbstractChatCompletionRunner<
return [...this._chatCompletions];
}

#handleError = (error: unknown) => {
this.#errored = true;
if (error instanceof Error && error.name === 'AbortError') {
error = new APIUserAbortError();
}
if (error instanceof APIUserAbortError) {
this.#aborted = true;
return this._emit('abort', error);
}
if (error instanceof OpenAIError) {
return this._emit('error', error);
}
if (error instanceof Error) {
const openAIError: OpenAIError = new OpenAIError(error.message);
// @ts-ignore
openAIError.cause = error;
return this._emit('error', openAIError);
}
return this._emit('error', new OpenAIError(String(error)));
};

protected _emit<Event extends keyof Events>(event: Event, ...args: EventParameters<Events, Event>) {
// make sure we don't emit any events after end
if (this.#ended) {
return;
}

if (event === 'end') {
this.#ended = true;
this.#resolveEndPromise();
}

const listeners: ListenersForEvent<Events, Event> | undefined = this.#listeners[event];
if (listeners) {
this.#listeners[event] = listeners.filter((l) => !l.once) as any;
listeners.forEach(({ listener }: any) => listener(...args));
}

if (event === 'abort') {
const error = args[0] as APIUserAbortError;
if (!this.#catchingPromiseCreated && !listeners?.length) {
Promise.reject(error);
}
this.#rejectConnectedPromise(error);
this.#rejectEndPromise(error);
this._emit('end');
return;
}

if (event === 'error') {
// NOTE: _emit('error', error) should only be called from #handleError().

const error = args[0] as OpenAIError;
if (!this.#catchingPromiseCreated && !listeners?.length) {
// Trigger an unhandled rejection if the user hasn't registered any error handlers.
// If you are seeing stack traces here, make sure to handle errors via either:
// - runner.on('error', () => ...)
// - await runner.done()
// - await runner.finalChatCompletion()
// - etc.
Promise.reject(error);
}
this.#rejectConnectedPromise(error);
this.#rejectEndPromise(error);
this._emit('end');
}
}

protected _emitFinal() {
protected override _emitFinal(this: AbstractChatCompletionRunner<AbstractChatCompletionRunnerEvents>) {
const completion = this._chatCompletions[this._chatCompletions.length - 1];
if (completion) this._emit('finalChatCompletion', completion);
const finalMessage = this.#getFinalMessage();
Expand Down Expand Up @@ -650,27 +450,7 @@ export abstract class AbstractChatCompletionRunner<
}
}

type CustomEvents<Event extends string> = {
[k in Event]: k extends keyof AbstractChatCompletionRunnerEvents ? AbstractChatCompletionRunnerEvents[k]
: (...args: any[]) => void;
};

type ListenerForEvent<Events extends CustomEvents<any>, Event extends keyof Events> = Event extends (
keyof AbstractChatCompletionRunnerEvents
) ?
AbstractChatCompletionRunnerEvents[Event]
: Events[Event];

type ListenersForEvent<Events extends CustomEvents<any>, Event extends keyof Events> = Array<{
listener: ListenerForEvent<Events, Event>;
once?: boolean;
}>;
type EventParameters<Events extends CustomEvents<any>, Event extends keyof Events> = Parameters<
ListenerForEvent<Events, Event>
>;

export interface AbstractChatCompletionRunnerEvents {
connect: () => void;
export interface AbstractChatCompletionRunnerEvents extends BaseEvents {
functionCall: (functionCall: ChatCompletionMessage.FunctionCall) => void;
message: (message: ChatCompletionMessageParam) => void;
chatCompletion: (completion: ChatCompletion) => void;
Expand All @@ -680,8 +460,5 @@ export interface AbstractChatCompletionRunnerEvents {
finalFunctionCall: (functionCall: ChatCompletionMessage.FunctionCall) => void;
functionCallResult: (content: string) => void;
finalFunctionCallResult: (content: string) => void;
error: (error: OpenAIError) => void;
abort: (error: APIUserAbortError) => void;
end: () => void;
totalUsage: (usage: CompletionUsage) => void;
}
Loading