Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion js/plugins/compat-oai/src/audio.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ export type TranscriptionRequestBuilder = (
params: TranscriptionCreateParams
) => void;

export const TRANSCRIPTION_MODEL_INFO = {
export const TRANSCRIPTION_MODEL_INFO: ModelInfo = {
supports: {
media: true,
output: ['text', 'json'],
Expand Down
6 changes: 6 additions & 0 deletions js/plugins/compat-oai/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ export {
openAIModelRunner,
type ModelRequestBuilder,
} from './model.js';
export {
TranslationConfigSchema,
compatOaiTranslationModelRef,
defineCompatOpenAITranslationModel,
type TranslationRequestBuilder,
} from './translate.js';

export interface PluginOptions extends Partial<Omit<ClientOptions, 'apiKey'>> {
apiKey?: ClientOptions['apiKey'] | false;
Expand Down
58 changes: 58 additions & 0 deletions js/plugins/compat-oai/src/openai/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ import {
} from '../image.js';
import { openAICompatible, PluginOptions } from '../index.js';
import { defineCompatOpenAIModel } from '../model.js';
import {
defineCompatOpenAITranslationModel,
TranslationConfigSchema,
} from '../translate.js';
import {
gptImage1RequestBuilder,
openAIImageModelRef,
Expand All @@ -55,6 +59,10 @@ import {
SUPPORTED_GPT_MODELS,
} from './gpt.js';
import { openAITranscriptionModelRef, SUPPORTED_STT_MODELS } from './stt.js';
import {
openAITranslationModelRef,
SUPPORTED_TRANSLATION_MODELS,
} from './translation.js';
import { openAISpeechModelRef, SUPPORTED_TTS_MODELS } from './tts.js';

export type OpenAIPluginOptions = Omit<PluginOptions, 'name' | 'baseURL'>;
Expand Down Expand Up @@ -88,6 +96,19 @@ function createResolver(pluginOptions: PluginOptions) {
pluginOptions,
modelRef,
});
} else if (actionName.includes('translate')) {
const modelRef = openAITranslationModelRef({ name: actionName });
return defineCompatOpenAITranslationModel({
name: modelRef.name,
client,
pluginOptions,
modelRef,
requestBuilder: (req, params) => {
if (modelRef.name.endsWith('whisper-1-translate')) {
params.model = 'whisper-1';
}
},
});
} else if (
actionName.includes('whisper') ||
actionName.includes('transcribe')
Expand Down Expand Up @@ -147,6 +168,15 @@ const listActions = async (client: OpenAI): Promise<ActionMetadata[]> => {
info: modelRef.info,
configSchema: modelRef.configSchema,
});
} else if (model.id.includes('translate')) {
const modelRef =
SUPPORTED_TRANSLATION_MODELS[model.id] ??
openAITranslationModelRef({ name: model.id });
return modelActionMetadata({
name: modelRef.name,
info: modelRef.info,
configSchema: modelRef.configSchema,
});
} else if (
model.id.includes('whisper') ||
model.id.includes('transcribe')
Expand Down Expand Up @@ -209,6 +239,21 @@ export function openAIPlugin(options?: OpenAIPluginOptions): GenkitPluginV2 {
})
)
);
models.push(
...Object.values(SUPPORTED_TRANSLATION_MODELS).map((modelRef) =>
defineCompatOpenAITranslationModel({
name: modelRef.name,
client,
pluginOptions,
modelRef,
requestBuilder: (req, params) => {
if (modelRef.name.endsWith('whisper-1-translate')) {
params.model = 'whisper-1';
}
},
})
)
);
models.push(
...Object.values(SUPPORTED_STT_MODELS).map((modelRef) =>
defineCompatOpenAITranscriptionModel({
Expand Down Expand Up @@ -255,6 +300,13 @@ export type OpenAIPlugin = {
| (`${string}-tts` & {}),
config?: z.infer<typeof SpeechConfigSchema>
): ModelReference<typeof SpeechConfigSchema>;
model(
name:
| keyof typeof SUPPORTED_TRANSLATION_MODELS
| (`whisper-${string}-translate` & {})
| (`${string}-translate` & {}),
config?: z.infer<typeof TranslationConfigSchema>
): ModelReference<typeof TranslationConfigSchema>;
model(
name:
| keyof typeof SUPPORTED_STT_MODELS
Expand Down Expand Up @@ -292,6 +344,12 @@ const model = ((name: string, config?: any): ModelReference<z.ZodTypeAny> => {
config,
});
}
if (name.includes('translate')) {
return openAITranslationModelRef({
name,
config,
});
}
if (name.includes('whisper') || name.includes('transcribe')) {
return openAITranscriptionModelRef({
name,
Expand Down
45 changes: 45 additions & 0 deletions js/plugins/compat-oai/src/openai/translation.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/**
* Copyright 2024 The Fire Company
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import { z } from 'genkit';
import { ModelInfo } from 'genkit/model';
import { compatOaiTranslationModelRef } from '../translate';

/**
 * Creates a translation model reference in the `openai` namespace.
 *
 * All supplied fields are forwarded unchanged to
 * `compatOaiTranslationModelRef`; only the namespace is fixed here.
 *
 * @param params Model name plus optional info, config schema and config.
 * @returns The model reference produced by `compatOaiTranslationModelRef`.
 */
export function openAITranslationModelRef<
  CustomOptions extends z.ZodTypeAny = z.ZodTypeAny,
>(params: {
  name: string;
  info?: ModelInfo;
  configSchema?: CustomOptions;
  config?: any;
}) {
  const namespacedParams = { ...params, namespace: 'openai' };
  return compatOaiTranslationModelRef(namespacedParams);
}

/**
 * Translation models known to the OpenAI plugin, keyed by the model name
 * used inside Genkit.
 */
export const SUPPORTED_TRANSLATION_MODELS = {
  /**
   * Whisper 1 used for translation.
   *
   * OpenAI's own model ID is 'whisper-1'. The '-translate' suffix exists only
   * to tell this entry apart from the 'whisper-1' transcription model; the
   * request builder in index.ts sets the model back to 'whisper-1' before the
   * OpenAI API is called.
   */
  'whisper-1-translate': openAITranslationModelRef({
    name: 'whisper-1-translate',
  }),
};
223 changes: 223 additions & 0 deletions js/plugins/compat-oai/src/translate.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
/**
* Copyright 2024 The Fire Company
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import type {
GenerateRequest,
GenerateResponseData,
ModelReference,
} from 'genkit';
import { GenerationCommonConfigSchema, Message, modelRef, z } from 'genkit';
import type { ModelAction, ModelInfo } from 'genkit/model';
import { model } from 'genkit/plugin';
import OpenAI from 'openai';
import type {
TranslationCreateParams,
TranslationCreateResponse,
} from 'openai/resources/audio/index.mjs';
import { PluginOptions } from './index.js';
import { maybeCreateRequestScopedOpenAIClient, toModelName } from './utils.js';

/**
 * Hook for customising the translation parameters of a request.
 *
 * It receives the incoming Genkit request and the parameters built from it,
 * and is expected to adjust `params` in place; its return value is ignored.
 */
export type TranslationRequestBuilder = (
  req: GenerateRequest,
  params: TranslationCreateParams
) => void;

/**
 * Model info applied to translation model references that do not supply
 * their own `info`.
 */
export const TRANSLATION_MODEL_INFO: ModelInfo = {
  supports: {
    // Media input is accepted.
    media: true,
    // Output formats that can be requested.
    output: ['text', 'json'],
    // Multi-turn conversations, system messages and tools are not supported.
    multiturn: false,
    systemRole: false,
    tools: false,
  },
};

/**
 * Config schema for translation models: `temperature` taken from the common
 * generation config, plus an optional `response_format`.
 */
export const TranslationConfigSchema = GenerationCommonConfigSchema.pick({
  temperature: true,
}).extend({
  // Optional; one of the five listed format names.
  response_format: z
    .enum(['json', 'text', 'srt', 'verbose_json', 'vtt'])
    .optional(),
});

/**
 * Extracts the MIME type from the header of a `data:` URL.
 *
 * @returns The media type before the first `;` or `,`, or an empty string
 *   when `url` is not a `data:` URL with a payload separator.
 */
function contentTypeFromDataUrl(url: string): string {
  const scheme = 'data:';
  const commaIndex = url.indexOf(',');
  if (url.slice(0, scheme.length).toLowerCase() !== scheme || commaIndex < 0) {
    return '';
  }
  // The header looks like "<mime>[;param]*[;base64]"; the MIME type is first.
  return url.slice(scheme.length, commaIndex).split(';')[0] ?? '';
}

/**
 * Builds the parameters for an audio translation call from a Genkit generate
 * request.
 *
 * The audio is read from the media part of the first message and is decoded
 * as base64 (everything after the first comma of the media URL). The text of
 * that message is sent as the prompt.
 *
 * @param modelName Model name used when the request config has no `version`.
 * @param request The Genkit generate request to convert.
 * @param requestBuilder Optional hook that may adjust the parameters in
 *   place. When it is supplied, the remaining config fields are not passed
 *   through automatically.
 * @returns The translation parameters with `undefined` entries removed.
 * @throws Error if the request has no message or no media, if the media is a
 *   remote http(s) URL, if the output format is `media`, or if
 *   `config.response_format` is incompatible with a `json` output format.
 */
function toTranslationRequest(
  modelName: string,
  request: GenerateRequest,
  requestBuilder?: TranslationRequestBuilder
): TranslationCreateParams {
  // Fail with the same descriptive error when there is no message at all,
  // instead of crashing inside the Message constructor.
  const firstMessage = request.messages?.[0];
  if (!firstMessage) {
    throw new Error('No media found in the request');
  }
  const message = new Message(firstMessage);
  const media = message.media;
  if (!media?.url) {
    throw new Error('No media found in the request');
  }
  // The payload is decoded locally as base64; a remote URL would otherwise be
  // decoded into meaningless bytes and uploaded as a corrupt file.
  if (/^https?:\/\//i.test(media.url)) {
    throw new Error(
      'Remote media URLs are not supported; provide the audio as a base64 data URL'
    );
  }
  const mediaBuffer = Buffer.from(
    media.url.slice(media.url.indexOf(',') + 1),
    'base64'
  );
  const mediaFile = new File([mediaBuffer], 'input', {
    type: media.contentType ?? contentTypeFromDataUrl(media.url),
  });
  // The generation-only fields are destructured solely to keep them out of
  // `restOfConfig`; they are intentionally unused.
  const {
    temperature,
    version: modelVersion,
    maxOutputTokens,
    stopSequences,
    topK,
    topP,
    ...restOfConfig
  } = request.config ?? {};

  let options: TranslationCreateParams = {
    model: modelVersion ?? modelName,
    file: mediaFile,
    prompt: message.text,
    temperature,
  };
  if (requestBuilder) {
    requestBuilder(request, options);
  } else {
    options = {
      ...options,
      ...restOfConfig, // passthrough rest of the config
    };
  }
  const outputFormat = request.output?.format as 'json' | 'text' | 'media';
  const customFormat = request.config?.response_format;
  if (outputFormat && customFormat) {
    if (
      outputFormat === 'json' &&
      customFormat !== 'json' &&
      customFormat !== 'verbose_json'
    ) {
      throw new Error(
        `Custom response format ${customFormat} is not compatible with output format ${outputFormat}`
      );
    }
  }
  if (outputFormat === 'media') {
    throw new Error(`Output format ${outputFormat} is not supported.`);
  }
  // An explicit response_format wins over the output format; default is text.
  options.response_format = customFormat || outputFormat || 'text';
  // Drop unset entries so they are not sent as explicit `undefined` values.
  for (const k in options) {
    if (options[k] === undefined) {
      delete options[k];
    }
  }
  return options;
}

/**
 * Wraps a translation result in a Genkit generate response.
 *
 * @param result Either the translated text itself or a response object whose
 *   `text` property holds it.
 * @returns A response with a single text part from the `model` role, finish
 *   reason `stop`, and the untouched result in `raw`.
 */
function translationToGenerateResponse(
  result: TranslationCreateResponse | string
): GenerateResponseData {
  let translatedText;
  if (typeof result === 'string') {
    translatedText = result;
  } else {
    translatedText = result.text;
  }

  return {
    message: {
      role: 'model',
      content: [{ text: translatedText }],
    },
    finishReason: 'stop',
    raw: result,
  };
}

/**
 * Defines a Genkit model backed by an OpenAI-compatible audio translation
 * endpoint.
 *
 * The action is named `<plugin name>/<model name>`, with the plugin name
 * defaulting to `compat-oai` when no plugin options are given.
 *
 * @param params An object containing parameters for defining the model.
 * @param params.name The name of the model.
 * @param params.client The default OpenAI client instance.
 * @param params.pluginOptions Optional plugin options; the plugin name is
 *   read from here, and the options are also used when choosing the client
 *   for each request.
 * @param params.modelRef Optional model reference supplying the model info
 *   and config schema.
 * @param params.requestBuilder Optional hook that can adjust the translation
 *   parameters before they are sent.
 *
 * @returns the created {@link ModelAction}
 */
export function defineCompatOpenAITranslationModel<
  CustomOptions extends z.ZodTypeAny = z.ZodTypeAny,
>(params: {
  name: string;
  client: OpenAI;
  pluginOptions?: PluginOptions;
  modelRef?: ModelReference<CustomOptions>;
  requestBuilder?: TranslationRequestBuilder;
}) {
  const { pluginOptions, requestBuilder } = params;
  const reference = params.modelRef;
  const fallbackClient = params.client;

  const modelName = toModelName(params.name, pluginOptions?.name);
  const pluginName = pluginOptions?.name ?? 'compat-oai';

  return model(
    {
      name: `${pluginName}/${modelName}`,
      ...reference?.info,
      configSchema: reference?.configSchema,
    },
    async (request, { abortSignal }) => {
      // Build the parameters first so that an invalid request is rejected
      // before a client is resolved.
      const translationParams = toTranslationRequest(
        modelName,
        request,
        requestBuilder
      );
      const requestClient = maybeCreateRequestScopedOpenAIClient(
        pluginOptions,
        request,
        fallbackClient
      );
      const translation = await requestClient.audio.translations.create(
        translationParams,
        { signal: abortSignal }
      );
      return translationToGenerateResponse(translation);
    }
  );
}

/**
 * Creates a translation model reference with defaults suited to
 * OpenAI-compatible providers.
 *
 * `info` falls back to `TRANSLATION_MODEL_INFO` when it is undefined, and
 * `configSchema` falls back to `TranslationConfigSchema` when it is falsy.
 *
 * @param params Model name plus optional info, config schema, config and
 *   namespace.
 * @returns The model reference built by `modelRef`.
 */
export function compatOaiTranslationModelRef<
  CustomOptions extends z.ZodTypeAny = z.ZodTypeAny,
>(params: {
  name: string;
  info?: ModelInfo;
  configSchema?: CustomOptions;
  config?: any;
  namespace?: string;
}) {
  const resolvedInfo =
    params.info === undefined ? TRANSLATION_MODEL_INFO : params.info;
  const resolvedSchema =
    params.configSchema || (TranslationConfigSchema as any);

  return modelRef({
    name: params.name,
    configSchema: resolvedSchema,
    info: resolvedInfo,
    config: params.config,
    namespace: params.namespace,
  });
}
Binary file added js/testapps/compat-oai/audio-korean.mp3
Binary file not shown.
Loading