Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
214 changes: 213 additions & 1 deletion js/plugins/compat-oai/src/audio.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ import type {
SpeechCreateParams,
Transcription,
TranscriptionCreateParams,
TranslationCreateParams,
TranslationCreateResponse,
} from 'openai/resources/audio/index.mjs';
import { PluginOptions } from './index.js';
import { maybeCreateRequestScopedOpenAIClient, toModelName } from './utils.js';
Expand All @@ -40,8 +42,12 @@ export type TranscriptionRequestBuilder = (
req: GenerateRequest,
params: TranscriptionCreateParams
) => void;
export type TranslationRequestBuilder = (
req: GenerateRequest,
params: TranslationCreateParams
) => void;

export const TRANSCRIPTION_MODEL_INFO = {
export const TRANSCRIPTION_MODEL_INFO: ModelInfo = {
supports: {
media: true,
output: ['text', 'json'],
Expand All @@ -61,6 +67,16 @@ export const SPEECH_MODEL_INFO: ModelInfo = {
},
};

/**
 * Capabilities advertised for Whisper-family audio models: accepts audio
 * media input and produces text/JSON output; single-turn, no system role,
 * no tool calling.
 *
 * NOTE(review): "MODER" looks like a typo for "MODEL". Renaming the export
 * would be a breaking change for importers (and the reference in
 * compatOaiWhisperModelRef below), so it is left as-is here.
 */
export const WHISPER_MODER_INFO: ModelInfo = {
  supports: {
    media: true,
    output: ['text', 'json'],
    multiturn: false,
    systemRole: false,
    tools: false,
  },
};

const ChunkingStrategySchema = z.object({
type: z.string(),
prefix_padding_ms: z.number().int().optional(),
Expand Down Expand Up @@ -92,6 +108,19 @@ export const SpeechConfigSchema = z.object({
.optional(),
});

/**
 * Request-config schema for Whisper models. Covers both the Transcription
 * API (default) and the Translation API (selected with `translate: true`).
 */
export const WhisperConfigSchema = GenerationCommonConfigSchema.pick({
  temperature: true,
}).extend({
  /** When true, uses the Translation API instead of Transcription. Default: false */
  translate: z.boolean().optional().default(false),
  // Output formats accepted by the OpenAI audio endpoints; checked against
  // request.output.format when building the API request.
  response_format: z
    .enum(['json', 'text', 'srt', 'verbose_json', 'vtt'])
    .optional(),
  // transcription-only fields (ignored when translate=true)
  language: z.string().optional(),
  timestamp_granularities: z.array(z.enum(['word', 'segment'])).optional(),
});

/**
* Supported media formats for Audio generation
*/
Expand Down Expand Up @@ -420,3 +449,186 @@ export function compatOaiTranscriptionModelRef<
namespace,
});
}

/**
 * Converts a Genkit {@link GenerateRequest} into OpenAI Translation API
 * parameters.
 *
 * The first message must carry audio media as a data URL; its text (if any)
 * is forwarded as the translation prompt.
 *
 * @param modelName Model to use when the config does not supply a `version`.
 * @param request The Genkit generate request to translate.
 * @param requestBuilder Optional callback that takes full control of
 *   populating the params (disables the default config passthrough).
 * @returns Parameters ready for `client.audio.translations.create`.
 * @throws Error when no media is present, when a custom response format is
 *   incompatible with the requested output format, or when the output
 *   format is `media`.
 */
function toTranslationRequest(
  modelName: string,
  request: GenerateRequest,
  requestBuilder?: TranslationRequestBuilder
): TranslationCreateParams {
  const message = new Message(request.messages[0]);
  const media = message.media;
  if (!media?.url) {
    throw new Error('No media found in the request');
  }
  // Media arrives as a data URL; the base64 payload follows the first comma.
  const mediaBuffer = Buffer.from(
    media.url.slice(media.url.indexOf(',') + 1),
    'base64'
  );
  const mediaFile = new File([mediaBuffer], 'input', {
    type:
      media.contentType ??
      // Fall back to the MIME type embedded in the data URL header.
      media.url.slice('data:'.length, media.url.indexOf(';')),
  });
  const {
    temperature,
    version: modelVersion,
    maxOutputTokens,
    stopSequences,
    topK,
    topP,
    // Plugin-level and transcription-only options must NOT be forwarded to
    // the Translation API: without stripping them here, `translate: true`
    // (always set on this code path), `language`, and
    // `timestamp_granularities` would leak into the request via
    // restOfConfig. (response_format is overwritten below.)
    translate,
    language,
    timestamp_granularities,
    ...restOfConfig
  } = request.config ?? {};

  let options: TranslationCreateParams = {
    model: modelVersion ?? modelName,
    file: mediaFile,
    prompt: message.text,
    temperature,
  };
  if (requestBuilder) {
    requestBuilder(request, options);
  } else {
    options = {
      ...options,
      ...restOfConfig, // passthrough rest of the config
    };
  }
  const outputFormat = request.output?.format as 'json' | 'text' | 'media';
  const customFormat = request.config?.response_format;
  if (outputFormat && customFormat) {
    if (
      outputFormat === 'json' &&
      customFormat !== 'json' &&
      customFormat !== 'verbose_json'
    ) {
      throw new Error(
        `Custom response format ${customFormat} is not compatible with output format ${outputFormat}`
      );
    }
  }
  if (outputFormat === 'media') {
    throw new Error(`Output format ${outputFormat} is not supported.`);
  }
  options.response_format = customFormat || outputFormat || 'text';
  // Drop undefined entries so they are not serialized into the API call.
  for (const key of Object.keys(options) as (keyof TranslationCreateParams)[]) {
    if (options[key] === undefined) {
      delete options[key];
    }
  }
  return options;
}

/**
 * Maps an OpenAI translation result onto Genkit's GenerateResponseData
 * shape: a single model message containing the translated text, with the
 * raw API result preserved for callers that need it.
 *
 * @param result Either the plain-text translation or the structured
 *   response object from the Translation API.
 * @returns A Genkit response with finishReason 'stop'.
 */
function translationToGenerateResponse(
  result: TranslationCreateResponse | string
): GenerateResponseData {
  const text = typeof result === 'string' ? result : result.text;
  return {
    message: {
      role: 'model',
      content: [{ text }],
    },
    finishReason: 'stop',
    raw: result,
  };
}

/**
* Method to define a new Genkit Model that is compatible with Open AI
* Whisper API.
*
* These models are to be used to transcribe or translate audio to text.
*
* @param params An object containing parameters for defining the OpenAI
* whisper model.
* @param params.ai The Genkit AI instance.
* @param params.name The name of the model.
* @param params.client The OpenAI client instance.
* @param params.modelRef Optional reference to the model's configuration and
* custom options.
*
* @returns the created {@link ModelAction}
*/
/**
 * Defines a Genkit model backed by an OpenAI-compatible Whisper endpoint.
 *
 * The resulting model transcribes audio by default, and translates it
 * instead when the request config sets `translate: true`.
 *
 * @param params.name The name of the model.
 * @param params.client The OpenAI client instance used by default (a
 *   request-scoped client may override it per call).
 * @param params.pluginOptions Optional plugin-level options (namespace,
 *   request-scoped client support).
 * @param params.modelRef Optional reference supplying model info and a
 *   custom config schema.
 * @param params.requestBuilder Optional hook to customize the outgoing
 *   transcription/translation parameters.
 * @returns the created {@link ModelAction}
 */
export function defineCompatOpenAIWhisperModel(params: {
  name: string;
  client: OpenAI;
  pluginOptions?: PluginOptions;
  modelRef?: ModelReference<any>;
  requestBuilder?: TranscriptionRequestBuilder | TranslationRequestBuilder;
}) {
  const modelName = toModelName(params.name, params.pluginOptions?.name);
  const actionName =
    params.modelRef?.name ??
    `${params.pluginOptions?.name ?? 'compat-oai'}/${modelName}`;

  return model(
    {
      name: actionName,
      ...params.modelRef?.info,
      configSchema: params.modelRef?.configSchema,
    },
    async (request, { abortSignal }) => {
      const client = maybeCreateRequestScopedOpenAIClient(
        params.pluginOptions,
        request,
        params.client
      );

      if (request.config?.translate === true) {
        // Translation API
        const translationParams = toTranslationRequest(
          modelName,
          request,
          params.requestBuilder
        );
        const result = await client.audio.translations.create(
          translationParams,
          { signal: abortSignal }
        );
        return translationToGenerateResponse(result);
      }

      // Transcription API
      const sttParams = toSttRequest(modelName, request, params.requestBuilder);
      const result = await client.audio.transcriptions.create(
        {
          ...sttParams,
          stream: false,
        },
        { signal: abortSignal }
      );
      return transcriptionToGenerateResponse(result);
    }
  );
}

/** Whisper ModelRef helper, with reasonable defaults for
* OpenAI-compatible providers */
/**
 * Builds a Whisper {@link ModelReference} with sensible defaults for
 * OpenAI-compatible providers: Whisper model info and the
 * {@link WhisperConfigSchema} config schema, unless overridden.
 */
export function compatOaiWhisperModelRef<
  CustomOptions extends z.ZodTypeAny = z.ZodTypeAny,
>(params: {
  name: string;
  info?: ModelInfo;
  configSchema?: CustomOptions;
  config?: any;
  namespace?: string;
}) {
  return modelRef({
    name: params.name,
    configSchema: params.configSchema || (WhisperConfigSchema as any),
    info: params.info ?? WHISPER_MODER_INFO,
    config: params.config,
    namespace: params.namespace,
  });
}
4 changes: 4 additions & 0 deletions js/plugins/compat-oai/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,16 @@ import { toModelName } from './utils.js';
export {
SpeechConfigSchema,
TranscriptionConfigSchema,
WhisperConfigSchema,
compatOaiSpeechModelRef,
compatOaiTranscriptionModelRef,
compatOaiWhisperModelRef,
defineCompatOpenAISpeechModel,
defineCompatOpenAITranscriptionModel,
defineCompatOpenAIWhisperModel,
type SpeechRequestBuilder,
type TranscriptionRequestBuilder,
type TranslationRequestBuilder,
} from './audio.js';
export { defineCompatOpenAIEmbedder } from './embedder.js';
export {
Expand Down
57 changes: 44 additions & 13 deletions js/plugins/compat-oai/src/openai/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,10 @@ import OpenAI from 'openai';
import {
defineCompatOpenAISpeechModel,
defineCompatOpenAITranscriptionModel,
defineCompatOpenAIWhisperModel,
SpeechConfigSchema,
TranscriptionConfigSchema,
WhisperConfigSchema,
} from '../audio.js';
import { defineCompatOpenAIEmbedder } from '../embedder.js';
import {
Expand All @@ -56,6 +58,7 @@ import {
} from './gpt.js';
import { openAITranscriptionModelRef, SUPPORTED_STT_MODELS } from './stt.js';
import { openAISpeechModelRef, SUPPORTED_TTS_MODELS } from './tts.js';
import { openAIWhisperModelRef, SUPPORTED_WHISPER_MODELS } from './whisper.js';

export type OpenAIPluginOptions = Omit<PluginOptions, 'name' | 'baseURL'>;

Expand Down Expand Up @@ -88,10 +91,15 @@ function createResolver(pluginOptions: PluginOptions) {
pluginOptions,
modelRef,
});
} else if (
actionName.includes('whisper') ||
actionName.includes('transcribe')
) {
} else if (actionName.includes('whisper')) {
const modelRef = openAIWhisperModelRef({ name: actionName });
return defineCompatOpenAIWhisperModel({
name: modelRef.name,
client,
pluginOptions,
modelRef,
});
} else if (actionName.includes('transcribe')) {
const modelRef = openAITranscriptionModelRef({
name: actionName,
});
Expand Down Expand Up @@ -147,10 +155,16 @@ const listActions = async (client: OpenAI): Promise<ActionMetadata[]> => {
info: modelRef.info,
configSchema: modelRef.configSchema,
});
} else if (
model.id.includes('whisper') ||
model.id.includes('transcribe')
) {
} else if (model.id.includes('whisper')) {
const modelRef =
SUPPORTED_WHISPER_MODELS[model.id] ??
openAIWhisperModelRef({ name: model.id });
return modelActionMetadata({
name: modelRef.name,
info: modelRef.info,
configSchema: modelRef.configSchema,
});
} else if (model.id.includes('transcribe')) {
const modelRef =
SUPPORTED_STT_MODELS[model.id] ??
openAITranscriptionModelRef({ name: model.id });
Expand Down Expand Up @@ -209,6 +223,16 @@ export function openAIPlugin(options?: OpenAIPluginOptions): GenkitPluginV2 {
})
)
);
models.push(
...Object.values(SUPPORTED_WHISPER_MODELS).map((modelRef) =>
defineCompatOpenAIWhisperModel({
name: modelRef.name,
client,
pluginOptions,
modelRef,
})
)
);
models.push(
...Object.values(SUPPORTED_STT_MODELS).map((modelRef) =>
defineCompatOpenAITranscriptionModel({
Expand Down Expand Up @@ -256,10 +280,11 @@ export type OpenAIPlugin = {
config?: z.infer<typeof SpeechConfigSchema>
): ModelReference<typeof SpeechConfigSchema>;
model(
name:
| keyof typeof SUPPORTED_STT_MODELS
| (`whisper-${string}` & {})
| (`${string}-transcribe` & {}),
name: keyof typeof SUPPORTED_WHISPER_MODELS | (`whisper-${string}` & {}),
config?: z.infer<typeof WhisperConfigSchema>
): ModelReference<typeof WhisperConfigSchema>;
model(
name: keyof typeof SUPPORTED_STT_MODELS | (`${string}-transcribe` & {}),
config?: z.infer<typeof TranscriptionConfigSchema>
): ModelReference<typeof TranscriptionConfigSchema>;
model(
Expand Down Expand Up @@ -292,7 +317,13 @@ const model = ((name: string, config?: any): ModelReference<z.ZodTypeAny> => {
config,
});
}
if (name.includes('whisper') || name.includes('transcribe')) {
if (name.includes('whisper')) {
return openAIWhisperModelRef({
name,
config,
});
}
if (name.includes('transcribe')) {
return openAITranscriptionModelRef({
name,
config,
Expand Down
3 changes: 0 additions & 3 deletions js/plugins/compat-oai/src/openai/stt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,4 @@ export const SUPPORTED_STT_MODELS = {
'gpt-4o-mini-transcribe': openAITranscriptionModelRef({
name: 'gpt-4o-mini-transcribe',
}),
'whisper-1': openAITranscriptionModelRef({
name: 'whisper-1',
}),
};
Loading