diff --git a/docs/reflection-v2-protocol.md b/docs/reflection-v2-protocol.md new file mode 100644 index 0000000000..ac974c8657 --- /dev/null +++ b/docs/reflection-v2-protocol.md @@ -0,0 +1,234 @@ +# Genkit Reflection Protocol V2 (WebSocket) + +This document outlines the design for the V2 Reflection API, which uses WebSockets for bidirectional communication between the Genkit CLI (Runtime Manager) and Genkit Runtimes (User Applications). + +## Overview + +In V2, the connection direction is reversed compared to V1: +- **Server**: The Genkit CLI (`RuntimeManagerV2`) starts a WebSocket server. +- **Client**: The Genkit Runtime connects to the CLI's WebSocket server. + +This architecture allows the CLI to easily manage multiple runtimes (e.g., for multi-service projects) and eliminates the need for runtimes to manage their own HTTP servers and ports for reflection. + +## Transport + +| Feature | Specification | +| :--- | :--- | +| **Protocol** | WebSocket | +| **Data Format** | JSON | +| **Message Structure** | JSON-RPC 2.0 (modified for streaming) | + +## Message Format + +All messages follow the JSON-RPC 2.0 specification. + +### Request +```json +{ + "jsonrpc": "2.0", + "method": "methodName", + "params": { ... }, + "id": 1 +} +``` +*Note: The `id` is generated by the sender (Manager). It can be a number (auto-incrementing) or a string (UUID). It must be unique for the pending request within the WebSocket session.* + +### Response (Success) +```json +{ + "jsonrpc": "2.0", + "result": { ... }, + "id": 1 +} +``` + +### Response (Error) +```json +{ + "jsonrpc": "2.0", + "error": { + "code": -32000, + "message": "Error message", + "data": { + "code": 13, + "message": "Error message", + "details": { + "traceId": "...", + "stack": "..." + } + } + }, + "id": 1 +} +``` + +The `data` field contains a `Status` object (matching V1 API) with: +- **`code`**: Genkit canonical status code (e.g., 13 for INTERNAL, 3 for INVALID_ARGUMENT). +- **`message`**: The error message. +- **`details`**: Additional context, including `traceId` and `stack` trace. + +### Notification +A request without an `id`. +```json +{ + "jsonrpc": "2.0", + "method": "methodName", + "params": { ... } +} +``` + +## Streaming Extension + +JSON-RPC 2.0 does not natively support streaming. We extend it by using Notifications from the Runtime to the Manager associated with a specific Request ID. + +| Message Type | Method | Direction | Description | +| :--- | :--- | :--- | :--- | +| **Stream Chunk** | `streamChunk` | Runtime -> Manager | Sent by the Runtime during a streaming `runAction` request. | +| **State Update** | `runActionState` | Runtime -> Manager | Sent by the Runtime to provide status updates (e.g., trace ID) before the result. | + +### Stream Chunk Notification +```json +{ + "jsonrpc": "2.0", + "method": "streamChunk", + "params": { + "requestId": 1, + "chunk": { ... } + } +} +``` + +### Run Action State Notification +```json +{ + "jsonrpc": "2.0", + "method": "runActionState", + "params": { + "requestId": 1, + "state": { "traceId": "..." } + } +} +``` + +## Protocol Methods Summary + +| Method | Direction | Type | Description | +| :--- | :--- | :--- | :--- | +| **`register`** | Runtime -> Manager | Request | Registers the runtime with the Manager. | +| **`configure`** | Manager -> Runtime | Notification | Pushes configuration updates to the Runtime. | +| **`listActions`** | Manager -> Runtime | Request | Retrieves the list of available actions. | +| **`listValues`** | Manager -> Runtime | Request | Retrieves the list of values (prompts, schemas, etc.). | +| **`runAction`** | Manager -> Runtime | Request | Executes an action. | +| **`cancelAction`** | Manager -> Runtime | Request | Cancels a running action. | + +## Detailed API + +### 1. Registration +**Direction:** Runtime -> Manager +**Type:** Request + +**Parameters:** +| Field | Type | Description | +| :--- | :--- | :--- | +| `id` | `string` | Unique Runtime ID. | +| `pid` | `number` | Process ID. | +| `name` | `string` | App name (optional). | +| `genkitVersion` | `string` | e.g., "0.9.0". | +| `reflectionApiSpecVersion` | `number` | Protocol version. | +| `envs` | `string[]` | Configured environments (optional). | + +**Result:** `void` + +### 2. Configuration +**Direction:** Manager -> Runtime +**Type:** Notification + +**Parameters:** +| Field | Type | Description | +| :--- | :--- | :--- | +| `telemetryServerUrl` | `string` | URL of the telemetry server (optional). | + +### 3. List Actions +**Direction:** Manager -> Runtime +**Type:** Request + +**Parameters:** `void` + +**Result:** +| Type | Description | +| :--- | :--- | +| `Record` | Map of action keys to Action definitions. (Same schema as V1 `/api/actions`) | + +### 4. List Values +**Direction:** Manager -> Runtime +**Type:** Request + +**Parameters:** +| Field | Type | Description | +| :--- | :--- | :--- | +| `type` | `string` | The type of value to list (e.g., "model", "prompt", "schema"). | + +**Result:** +| Type | Description | +| :--- | :--- | +| `Record` | Map of value keys to value definitions. | + +### 5. Run Action +**Direction:** Manager -> Runtime +**Type:** Request + +**Parameters:** +| Field | Type | Description | +| :--- | :--- | :--- | +| `key` | `string` | Action key (e.g., "/flow/myFlow"). | +| `input` | `any` | Input payload. | +| `context` | `any` | Context data (optional). | +| `telemetryLabels` | `Record` | Telemetry labels (optional). | +| `stream` | `boolean` | Whether to stream results. | +| `streamInput` | `boolean` | Whether to stream input (for bidi actions). | + +**Result (Non-Streaming):** +| Field | Type | Description | +| :--- | :--- | :--- | +| `result` | `any` | The return value. | +| `telemetry` | `object` | Telemetry metadata (e.g., `{ traceId: string }`). | + +**Streaming Flow:** +1. Runtime sends optional `runActionState` notifications. +2. Runtime sends `streamChunk` notifications. +3. Runtime sends final response with `result` (same structure as non-streaming). + +**Bidirectional Streaming Flow (if `streamInput: true`):** +1. Manager sends `streamInputChunk` notifications. +2. Manager sends `endStreamInput` notification. +3. Runtime behaves as per Streaming Flow. + +### 6. Cancel Action +**Direction:** Manager -> Runtime +**Type:** Request + +**Parameters:** +| Field | Type | Description | +| :--- | :--- | :--- | +| `traceId` | `string` | The trace ID of the action to cancel. | + +**Result:** +| Field | Type | Description | +| :--- | :--- | :--- | +| `message` | `string` | Confirmation message. | + +## Health Checks + +| Check Type | Description | +| :--- | :--- | +| **Connection State** | The WebSocket connection state itself serves as a basic health check. | +| **Heartbeats** | Standard WebSocket Ping/Pong frames should be used to maintain the connection and detect timeouts. | + +## Compatibility + +| Version | Architecture | +| :--- | :--- | +| **V1** | HTTP Server on Runtime, Polling/Request from CLI. | +| **V2** | WebSocket Server on CLI, Persistent Connection from Runtime. | + +The CLI will determine which mode to use based on configuration (e.g., `--experimental-reflection-v2`). diff --git a/genkit-tools/cli/src/commands/dev-test-model.ts b/genkit-tools/cli/src/commands/dev-test-model.ts index b03d01c429..527238c8d6 100644 --- a/genkit-tools/cli/src/commands/dev-test-model.ts +++ b/genkit-tools/cli/src/commands/dev-test-model.ts @@ -21,8 +21,8 @@ import { Part, } from '@genkit-ai/tools-common'; import { + BaseRuntimeManager, GenkitToolsError, - RuntimeManager, } from '@genkit-ai/tools-common/manager'; import { findProjectRoot, logger } from '@genkit-ai/tools-common/utils'; import { Command } from 'commander'; @@ -345,7 +345,7 @@ const TEST_CASES: Record = { }, }; -async function waitForRuntime(manager: RuntimeManager) { +async function waitForRuntime(manager: BaseRuntimeManager) { // Poll for runtimes for (let i = 0; i < 20; i++) { if (manager.listRuntimes().length > 0) return; @@ -355,7 +355,7 @@ async function waitForRuntime(manager: RuntimeManager) { } async function runTest( - manager: RuntimeManager, + manager: BaseRuntimeManager, model: string, testCase: TestCase ): Promise { @@ -397,7 +397,7 @@ async function runTest( } async function runTestSuite( - manager: RuntimeManager, + manager: BaseRuntimeManager, suite: TestSuite, defaultSupports: string[] ): Promise<{ passed: number; failed: number }> { @@ -470,7 +470,7 @@ export const devTestModel = new Command('dev:test-model') if (args) cmd = args; } - let manager: RuntimeManager; + let manager: BaseRuntimeManager; if (cmd.length > 0) { const result = await startDevProcessManager( diff --git a/genkit-tools/cli/src/commands/start.ts b/genkit-tools/cli/src/commands/start.ts index 151f468f93..4a239aec60 100644 --- a/genkit-tools/cli/src/commands/start.ts +++ b/genkit-tools/cli/src/commands/start.ts @@ -14,7 +14,7 @@ * limitations under the License. */ -import type { RuntimeManager } from '@genkit-ai/tools-common/manager'; +import type { BaseRuntimeManager } from '@genkit-ai/tools-common/manager'; import { startServer } from '@genkit-ai/tools-common/server'; import { findProjectRoot, logger } from '@genkit-ai/tools-common/utils'; import { Command } from 'commander'; @@ -27,6 +27,8 @@ interface RunOptions { port?: string; open?: boolean; disableRealtimeTelemetry?: boolean; + experimentalReflectionV2?: boolean; + allowedTelemetryCorsHostnames?: string; } /** Command to run code in dev mode and/or the Dev UI. */ @@ -39,6 +41,14 @@ export const start = new Command('start') '--disable-realtime-telemetry', 'Disable real-time telemetry streaming' ) + .option( + '--experimental-reflection-v2', + 'start the experimental reflection server (WebSocket)' + ) + .option( + '--allowed-telemetry-cors-hostnames ', + 'comma separated list of allowed telemetry CORS hostnames' + ) .action(async (options: RunOptions) => { const projectRoot = await findProjectRoot(); if (projectRoot.includes('/.Trash/')) { @@ -48,19 +58,32 @@ export const start = new Command('start') ); } // Always start the manager. - let manager: RuntimeManager; + let manager: BaseRuntimeManager; let processPromise: Promise | undefined; + const allowedTelemetryCorsHostnames = options.allowedTelemetryCorsHostnames + ? options.allowedTelemetryCorsHostnames.split(',') + : undefined; + if (start.args.length > 0) { const result = await startDevProcessManager( projectRoot, start.args[0], start.args.slice(1), - { disableRealtimeTelemetry: options.disableRealtimeTelemetry } + { + disableRealtimeTelemetry: options.disableRealtimeTelemetry, + experimentalReflectionV2: options.experimentalReflectionV2, + allowedTelemetryCorsHostnames, + } ); manager = result.manager; processPromise = result.processPromise; } else { - manager = await startManager(projectRoot, true); + manager = await startManager( + projectRoot, + true, + options.experimentalReflectionV2, + allowedTelemetryCorsHostnames + ); processPromise = new Promise(() => {}); } if (!options.noui) { diff --git a/genkit-tools/cli/src/mcp/utils.ts b/genkit-tools/cli/src/mcp/utils.ts index c0c5a20192..154c46de18 100644 --- a/genkit-tools/cli/src/mcp/utils.ts +++ b/genkit-tools/cli/src/mcp/utils.ts @@ -14,7 +14,7 @@ * limitations under the License. */ -import { RuntimeManager } from '@genkit-ai/tools-common/manager'; +import { BaseRuntimeManager } from '@genkit-ai/tools-common/manager'; import { z } from 'zod'; import { startDevProcessManager, startManager } from '../utils/manager-utils'; @@ -62,7 +62,7 @@ export function resolveProjectRoot( /** Genkit Runtime manager specifically for the MCP server. Allows lazy * initialization and dev process manangement. */ export class McpRuntimeManager { - private manager: RuntimeManager | undefined; + private manager: BaseRuntimeManager | undefined; private currentProjectRoot: string | undefined; async getManager(projectRoot: string) { @@ -83,7 +83,7 @@ export class McpRuntimeManager { args: string[]; explicitProjectRoot: boolean; timeout?: number; - }): Promise { + }): Promise { const { projectRoot, command, args, timeout, explicitProjectRoot } = params; if (this.manager) { await this.manager.stop(); diff --git a/genkit-tools/cli/src/utils/manager-utils.ts b/genkit-tools/cli/src/utils/manager-utils.ts index a961f6835c..1648866643 100644 --- a/genkit-tools/cli/src/utils/manager-utils.ts +++ b/genkit-tools/cli/src/utils/manager-utils.ts @@ -20,6 +20,7 @@ import { } from '@genkit-ai/telemetry-server'; import type { Status } from '@genkit-ai/tools-common'; import { + BaseRuntimeManager, ProcessManager, RuntimeEvent, RuntimeManager, @@ -34,7 +35,8 @@ import getPort, { makeRange } from 'get-port'; * This function is not idempotent. Typically you want to make sure it's called only once per cli instance. */ export async function resolveTelemetryServer( - projectRoot: string + projectRoot: string, + allowedTelemetryCorsHostnames?: string[] ): Promise { let telemetryServerUrl = process.env.GENKIT_TELEMETRY_SERVER; if (!telemetryServerUrl) { @@ -46,6 +48,7 @@ export async function resolveTelemetryServer( storeRoot: projectRoot, indexRoot: projectRoot, }), + allowedCorsHostnames: allowedTelemetryCorsHostnames, }); } return telemetryServerUrl; @@ -56,13 +59,19 @@ export async function resolveTelemetryServer( */ export async function startManager( projectRoot: string, - manageHealth?: boolean -): Promise { - const telemetryServerUrl = await resolveTelemetryServer(projectRoot); - const manager = RuntimeManager.create({ + manageHealth?: boolean, + experimentalReflectionV2?: boolean, + allowedTelemetryCorsHostnames?: string[] +): Promise { + const telemetryServerUrl = await resolveTelemetryServer( + projectRoot, + allowedTelemetryCorsHostnames + ); + const manager = await RuntimeManager.create({ telemetryServerUrl, manageHealth, projectRoot, + experimentalReflectionV2, }); return manager; } @@ -73,6 +82,8 @@ export interface DevProcessManagerOptions { healthCheck?: boolean; timeout?: number; cwd?: string; + experimentalReflectionV2?: boolean; + allowedTelemetryCorsHostnames?: string[]; } export async function startDevProcessManager( @@ -80,13 +91,25 @@ export async function startDevProcessManager( command: string, args: string[], options?: DevProcessManagerOptions -): Promise<{ manager: RuntimeManager; processPromise: Promise }> { - const telemetryServerUrl = await resolveTelemetryServer(projectRoot); +): Promise<{ manager: BaseRuntimeManager; processPromise: Promise }> { + const telemetryServerUrl = await resolveTelemetryServer( + projectRoot, + options?.allowedTelemetryCorsHostnames + ); const disableRealtimeTelemetry = options?.disableRealtimeTelemetry ?? false; + const experimentalReflectionV2 = options?.experimentalReflectionV2 ?? false; + + let reflectionV2Port: number | undefined; const envVars: Record = { GENKIT_TELEMETRY_SERVER: telemetryServerUrl, GENKIT_ENV: 'dev', }; + + if (experimentalReflectionV2) { + reflectionV2Port = await getPort({ port: makeRange(3200, 3400) }); + envVars.GENKIT_REFLECTION_V2_SERVER = `ws://localhost:${reflectionV2Port}`; + } + if (!disableRealtimeTelemetry) { envVars.GENKIT_ENABLE_REALTIME_TELEMETRY = 'true'; } @@ -97,6 +120,8 @@ export async function startDevProcessManager( projectRoot, processManager, disableRealtimeTelemetry, + experimentalReflectionV2, + reflectionV2Port, }); const processPromise = processManager.start({ ...options }); @@ -112,7 +137,7 @@ export async function startDevProcessManager( * Rejects if the process exits or if the timeout is reached. */ export async function waitForRuntime( - manager: RuntimeManager, + manager: BaseRuntimeManager, processPromise: Promise, timeoutMs: number = 30000 ): Promise { @@ -165,9 +190,9 @@ export async function waitForRuntime( */ export async function runWithManager( projectRoot: string, - fn: (manager: RuntimeManager) => Promise + fn: (manager: BaseRuntimeManager) => Promise ) { - let manager: RuntimeManager; + let manager: BaseRuntimeManager; try { manager = await startManager(projectRoot, false); // Don't manage health in this case. } catch (e) { diff --git a/genkit-tools/cli/tests/commands/start_test.ts b/genkit-tools/cli/tests/commands/start_test.ts index 76c90e0833..b6e640def0 100644 --- a/genkit-tools/cli/tests/commands/start_test.ts +++ b/genkit-tools/cli/tests/commands/start_test.ts @@ -92,7 +92,12 @@ describe('start command', () => { await serverStartedPromise; - expect(startManagerSpy).toHaveBeenCalledWith('/mock/root', true); + expect(startManagerSpy).toHaveBeenCalledWith( + '/mock/root', + true, + undefined, + undefined + ); expect(startDevProcessManagerSpy).not.toHaveBeenCalled(); expect(startServerSpy).toHaveBeenCalled(); }); diff --git a/genkit-tools/common/package.json b/genkit-tools/common/package.json index acb152ec1e..fd3b67ad5c 100644 --- a/genkit-tools/common/package.json +++ b/genkit-tools/common/package.json @@ -22,6 +22,7 @@ "commander": "^11.1.0", "configstore": "^5.0.1", "cors": "^2.8.5", + "events": "^3.3.0", "express": "^4.21.0", "get-port": "5.1.1", "glob": "^10.3.12", @@ -34,7 +35,8 @@ "winston": "^3.11.0", "yaml": "^2.4.1", "zod": "^3.22.4", - "zod-to-json-schema": "^3.22.4" + "zod-to-json-schema": "^3.22.4", + "ws": "^8.18.3" }, "devDependencies": { "@jest/globals": "^29.7.0", @@ -43,6 +45,7 @@ "@types/cli-color": "^2.0.6", "@types/configstore": "^6.0.2", "@types/cors": "^2.8.19", + "@types/events": "^3.0.3", "@types/express": "^4.17.21", "@types/inquirer": "^8.1.3", "@types/jest": "^29.5.12", @@ -50,6 +53,7 @@ "@types/json-schema": "^7.0.15", "@types/node": "^20.11.19", "@types/uuid": "^9.0.8", + "@types/ws": "^8.18.1", "bun-types": "^1.2.16", "genversion": "^3.2.0", "jest": "^29.7.0", diff --git a/genkit-tools/common/src/eval/evaluate.ts b/genkit-tools/common/src/eval/evaluate.ts index 36ee970d35..985c128d75 100644 --- a/genkit-tools/common/src/eval/evaluate.ts +++ b/genkit-tools/common/src/eval/evaluate.ts @@ -17,7 +17,7 @@ import { randomUUID } from 'crypto'; import { z } from 'zod'; import { getDatasetStore, getEvalStore } from '.'; -import type { RuntimeManager } from '../manager/manager'; +import type { BaseRuntimeManager } from '../manager/manager'; import { DatasetSchema, GenerateActionOptions, @@ -72,7 +72,7 @@ const GENERATE_ACTION_UTIL = '/util/generate'; * Starts a new evaluation run. Intended to be used via the reflection API. */ export async function runNewEvaluation( - manager: RuntimeManager, + manager: BaseRuntimeManager, request: RunNewEvaluationRequest ): Promise { const { dataSource, actionRef, evaluators } = request; @@ -141,7 +141,7 @@ export async function runNewEvaluation( /** Handles the Inference part of Inference-Evaluation cycle */ export async function runInference(params: { - manager: RuntimeManager; + manager: BaseRuntimeManager; actionRef: string; inferenceDataset: Dataset; context?: string; @@ -165,7 +165,7 @@ export async function runInference(params: { /** Handles the Evaluation part of Inference-Evaluation cycle */ export async function runEvaluation(params: { - manager: RuntimeManager; + manager: BaseRuntimeManager; evaluatorActions: Action[]; evalDataset: EvalInput[]; augments?: EvalKeyAugments; @@ -221,7 +221,7 @@ export async function runEvaluation(params: { } export async function getAllEvaluatorActions( - manager: RuntimeManager + manager: BaseRuntimeManager ): Promise { const allActions = await manager.listActions(); const allEvaluatorActions = []; @@ -234,7 +234,7 @@ export async function getAllEvaluatorActions( } export async function getMatchingEvaluatorActions( - manager: RuntimeManager, + manager: BaseRuntimeManager, evaluators?: string[] ): Promise { if (!evaluators) { @@ -253,7 +253,7 @@ export async function getMatchingEvaluatorActions( } async function bulkRunAction(params: { - manager: RuntimeManager; + manager: BaseRuntimeManager; actionRef: string; inferenceDataset: Dataset; context?: string; @@ -315,7 +315,7 @@ async function bulkRunAction(params: { } async function runFlowAction(params: { - manager: RuntimeManager; + manager: BaseRuntimeManager; actionRef: string; sample: FullInferenceSample; context?: any; @@ -347,7 +347,7 @@ async function runFlowAction(params: { } async function runModelAction(params: { - manager: RuntimeManager; + manager: BaseRuntimeManager; actionRef: string; sample: FullInferenceSample; modelConfig?: any; @@ -379,7 +379,7 @@ async function runModelAction(params: { } async function runPromptAction(params: { - manager: RuntimeManager; + manager: BaseRuntimeManager; actionRef: string; sample: FullInferenceSample; context?: any; @@ -466,7 +466,7 @@ async function runPromptAction(params: { } async function gatherEvalInput(params: { - manager: RuntimeManager; + manager: BaseRuntimeManager; actionRef: string; state: InferenceRunState; }): Promise { diff --git a/genkit-tools/common/src/eval/validate.ts b/genkit-tools/common/src/eval/validate.ts index 7453f8c507..3f204da3a8 100644 --- a/genkit-tools/common/src/eval/validate.ts +++ b/genkit-tools/common/src/eval/validate.ts @@ -17,7 +17,7 @@ import Ajv, { type ErrorObject, type JSONSchemaType } from 'ajv'; import addFormats from 'ajv-formats'; import { getDatasetStore } from '.'; -import type { RuntimeManager } from '../manager'; +import type { BaseRuntimeManager } from '../manager'; import { InferenceDatasetSchema, type Action, @@ -35,7 +35,7 @@ type JSONSchema = JSONSchemaType | any; * reflection API. */ export async function validateSchema( - manager: RuntimeManager, + manager: BaseRuntimeManager, request: ValidateDataRequest ): Promise { const { dataSource, actionRef } = request; @@ -125,7 +125,7 @@ function toErrorDetail(error: ErrorObject): ErrorDetail { } async function getAction( - manager: RuntimeManager, + manager: BaseRuntimeManager, actionRef: string ): Promise { const actions = await manager.listActions(); diff --git a/genkit-tools/common/src/manager/index.ts b/genkit-tools/common/src/manager/index.ts index d28014cc0b..b72c27125b 100644 --- a/genkit-tools/common/src/manager/index.ts +++ b/genkit-tools/common/src/manager/index.ts @@ -14,7 +14,7 @@ * limitations under the License. */ -export { RuntimeManager } from './manager'; +export { BaseRuntimeManager, RuntimeManager } from './manager'; export { AppProcessStatus, ProcessManager, diff --git a/genkit-tools/common/src/manager/manager-v2.ts b/genkit-tools/common/src/manager/manager-v2.ts new file mode 100644 index 0000000000..679f11438e --- /dev/null +++ b/genkit-tools/common/src/manager/manager-v2.ts @@ -0,0 +1,469 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import EventEmitter from 'events'; +import getPort, { makeRange } from 'get-port'; +import { WebSocket, WebSocketServer } from 'ws'; +import { + Action, + RunActionResponse, + RunActionResponseSchema, +} from '../types/action'; +import * as apis from '../types/apis'; +import { logger } from '../utils/logger'; +import { DevToolsInfo } from '../utils/utils'; +import { BaseRuntimeManager, RuntimeManagerOptions } from './manager'; +import { ProcessManager } from './process-manager'; +import { + GenkitToolsError, + RuntimeEvent, + RuntimeInfo, + StreamingCallback, +} from './types'; + +interface JsonRpcRequest { + jsonrpc: '2.0'; + method: string; + params?: any; + id?: number | string; +} + +interface JsonRpcResponse { + jsonrpc: '2.0'; + result?: any; + error?: { + code: number; + message: string; + data?: any; + }; + id: number | string; +} + +type JsonRpcMessage = JsonRpcRequest | JsonRpcResponse; + +interface ConnectedRuntime { + ws: WebSocket; + info: RuntimeInfo; +} + +export class RuntimeManagerV2 extends BaseRuntimeManager { + private _port?: number; + private wss?: WebSocketServer; + private runtimes: Map = new Map(); + + get port(): number | undefined { + return this._port; + } + private pendingRequests: Map< + number | string, + { resolve: (value: any) => void; reject: (reason?: any) => void } + > = new Map(); + private streamCallbacks: Map> = + new Map(); + private traceIdCallbacks: Map void> = + new Map(); + private eventEmitter = new EventEmitter(); + private requestIdCounter = 0; + + constructor( + telemetryServerUrl: string | undefined, + readonly manageHealth: boolean, + projectRoot: string, + processManager?: ProcessManager, + disableRealtimeTelemetry: boolean = false + ) { + super( + telemetryServerUrl, + projectRoot, + processManager, + disableRealtimeTelemetry + ); + } + + static async create( + options: RuntimeManagerOptions + ): Promise { + const manager = new RuntimeManagerV2( + options.telemetryServerUrl, + options.manageHealth ?? true, + options.projectRoot, + options.processManager, + options.disableRealtimeTelemetry + ); + await manager.startWebSocketServer(options.reflectionV2Port); + return manager; + } + + /** + * Starts a WebSocket server. + */ + private async startWebSocketServer(port?: number): Promise<{ port: number }> { + if (!port) { + port = await getPort({ port: makeRange(3200, 3400) }); + } + this.wss = new WebSocketServer({ port }); + + this._port = port; + logger.info(`Starting reflection server: ws://localhost:${port}`); + + this.wss.on('connection', (ws) => { + ws.on('error', (err) => logger.error(`WebSocket error: ${err}`)); + + ws.on('message', (data) => { + try { + const message = JSON.parse(data.toString()) as JsonRpcMessage; + this.handleMessage(ws, message); + } catch (error) { + logger.error('Failed to parse WebSocket message:', error); + } + }); + + ws.on('close', () => { + this.handleDisconnect(ws); + }); + }); + return { port }; + } + + private handleMessage(ws: WebSocket, message: JsonRpcMessage) { + if ('method' in message) { + this.handleRequest(ws, message as JsonRpcRequest); + } else { + this.handleResponse(message as JsonRpcResponse); + } + } + + private handleRequest(ws: WebSocket, request: JsonRpcRequest) { + switch (request.method) { + case 'register': + this.handleRegister(ws, request); + break; + case 'streamChunk': + this.handleStreamChunk(request); + break; + case 'runActionState': + this.handleRunActionState(request); + break; + default: + logger.warn(`Unknown method: ${request.method}`); + } + } + + private handleRegister(ws: WebSocket, request: JsonRpcRequest) { + const params = request.params; + const runtimeInfo: RuntimeInfo = { + id: params.id, + pid: params.pid, + name: params.name, + genkitVersion: params.genkitVersion, + reflectionApiSpecVersion: params.reflectionApiSpecVersion, + reflectionServerUrl: `ws://localhost:${this.port}`, // Virtual URL for compatibility + timestamp: new Date().toISOString(), + projectName: params.name || 'Unknown', // Or derive from other means if needed + }; + + this.runtimes.set(runtimeInfo.id, { ws, info: runtimeInfo }); + this.eventEmitter.emit(RuntimeEvent.ADD, runtimeInfo); + + // Send success response + if (request.id) { + ws.send( + JSON.stringify({ + jsonrpc: '2.0', + result: null, + id: request.id, + }) + ); + } + + // Configure the runtime immediately + this.notifyRuntime(runtimeInfo.id); + } + + private handleStreamChunk(notification: JsonRpcRequest) { + const { requestId, chunk } = notification.params; + const callback = this.streamCallbacks.get(requestId); + if (callback) { + callback(chunk); + } + } + + private handleRunActionState(notification: JsonRpcRequest) { + const { requestId, state } = notification.params; + const callback = this.traceIdCallbacks.get(requestId); + if (callback && state?.traceId) { + callback(state.traceId); + } + } + + private handleResponse(response: JsonRpcResponse) { + const pending = this.pendingRequests.get(response.id); + if (pending) { + if (response.error) { + const errorData = response.error.data || {}; + const massagedData = { + ...errorData, + stack: errorData.details?.stack, + data: { + genkitErrorMessage: errorData.message, + genkitErrorDetails: errorData.details, + }, + }; + const error = new GenkitToolsError(response.error.message); + error.data = massagedData; + pending.reject(error); + } else { + pending.resolve(response.result); + } + this.pendingRequests.delete(response.id); + } else { + logger.warn(`Received response for unknown request ID ${response.id}`); + } + } + + private handleDisconnect(ws: WebSocket) { + for (const [id, runtime] of this.runtimes.entries()) { + if (runtime.ws === ws) { + this.runtimes.delete(id); + this.eventEmitter.emit(RuntimeEvent.REMOVE, runtime.info); + break; + } + } + } + + private async sendRequest( + runtimeId: string, + method: string, + params?: any + ): Promise { + const runtime = this.runtimes.get(runtimeId); + if (!runtime) { + throw new Error(`Runtime ${runtimeId} not found`); + } + + const id = ++this.requestIdCounter; + const message: JsonRpcRequest = { + jsonrpc: '2.0', + method, + params, + id, + }; + + return new Promise((resolve, reject) => { + const timeoutId = setTimeout(() => { + if (this.pendingRequests.has(id)) { + this.pendingRequests.delete(id); + reject(new Error(`Request ${id} timed out`)); + } + }, 30000); + + this.pendingRequests.set(id, { + resolve: (value) => { + clearTimeout(timeoutId); + resolve(value); + }, + reject: (reason) => { + clearTimeout(timeoutId); + reject(reason); + }, + }); + + runtime.ws.send(JSON.stringify(message)); + }); + } + + private sendNotification(runtimeId: string, method: string, params?: any) { + const runtime = this.runtimes.get(runtimeId); + if (!runtime) { + logger.warn(`Runtime ${runtimeId} not found, cannot send notification`); + return; + } + const message: JsonRpcRequest = { + jsonrpc: '2.0', + method, + params, + }; + runtime.ws.send(JSON.stringify(message)); + } + + private notifyRuntime(runtimeId: string) { + this.sendNotification(runtimeId, 'configure', { + telemetryServerUrl: this.telemetryServerUrl, + }); + } + + listRuntimes(): RuntimeInfo[] { + return Array.from(this.runtimes.values()).map((r) => r.info); + } + + getRuntimeById(id: string): RuntimeInfo | undefined { + return this.runtimes.get(id)?.info; + } + + getMostRecentRuntime(): RuntimeInfo | undefined { + const runtimes = this.listRuntimes(); + if (runtimes.length === 0) return undefined; + return runtimes[runtimes.length - 1]; + } + + getMostRecentDevUI(): DevToolsInfo | undefined { + // Not applicable for V2 yet + return undefined; + } + + onRuntimeEvent( + listener: (eventType: RuntimeEvent, runtime: RuntimeInfo) => void + ) { + const listeners: Array<{ event: string; fn: (rt: RuntimeInfo) => void }> = + []; + Object.values(RuntimeEvent).forEach((event) => { + const fn = (rt: RuntimeInfo) => listener(event, rt); + this.eventEmitter.on(event, fn); + listeners.push({ event, fn }); + }); + return () => { + listeners.forEach(({ event, fn }) => { + this.eventEmitter.off(event, fn); + }); + }; + } + + async listActions( + input?: apis.ListActionsRequest + ): Promise> { + const runtimeId = input?.runtimeId || this.getMostRecentRuntime()?.id; + if (!runtimeId) { + throw new Error( + input?.runtimeId + ? `No runtime found with ID ${input.runtimeId}.` + : 'No runtimes found. Make sure your app is running using the `start_runtime` MCP tool or the CLI: `genkit start -- ...`. See getting started documentation.' + ); + } + return this.sendRequest(runtimeId, 'listActions'); + } + + async listValues( + input: apis.ListValuesRequest + ): Promise> { + const runtimeId = input?.runtimeId || this.getMostRecentRuntime()?.id; + if (!runtimeId) { + throw new Error( + input?.runtimeId + ? `No runtime found with ID ${input.runtimeId}.` + : 'No runtimes found. Make sure your app is running using `genkit start -- ...`. See getting started documentation.' + ); + } + return this.sendRequest(runtimeId, 'listValues', { type: input.type }); + } + + async stop() { + if (this.wss) { + this.wss.close(); + } + if (this.processManager) { + await this.processManager.kill(); + } + } + + async runAction( + input: apis.RunActionRequest, + streamingCallback?: StreamingCallback, + onTraceId?: (traceId: string) => void, + inputStream?: AsyncIterable + ): Promise { + const runtimeId = input.runtimeId || this.getMostRecentRuntime()?.id; + if (!runtimeId) { + throw new Error( + 'No runtimes found. Make sure your app is running using the `start_runtime` MCP tool or the CLI: `genkit start -- ...`. See getting started documentation.' + ); + } + + const runtime = this.runtimes.get(runtimeId); + if (!runtime) { + throw new Error(`Runtime ${runtimeId} not found`); + } + + const id = ++this.requestIdCounter; + + if (streamingCallback) { + this.streamCallbacks.set(id, streamingCallback); + } + if (onTraceId) { + this.traceIdCallbacks.set(id, onTraceId!); + } + + const message: JsonRpcRequest = { + jsonrpc: '2.0', + method: 'runAction', + params: { + ...input, + stream: !!streamingCallback, + streamInput: !!inputStream, + }, + id, + }; + + const promise = new Promise((resolve, reject) => { + this.pendingRequests.set(id, { resolve, reject }); + runtime!.ws.send(JSON.stringify(message)); + }) + .then((result) => { + return RunActionResponseSchema.parse(result); + }) + .finally(() => { + if (streamingCallback) { + this.streamCallbacks.delete(id); + } + if (onTraceId) { + this.traceIdCallbacks.delete(id); + } + }); + + if (inputStream) { + (async () => { + try { + for await (const chunk of inputStream!) { + this.sendNotification(runtimeId!, 'streamInputChunk', { + requestId: id, + chunk, + }); + } + this.sendNotification(runtimeId!, 'endStreamInput', { + requestId: id, + }); + } catch (e) { + logger.error(`Error streaming input: ${e}`); + } + })(); + } + + return promise as Promise; + } + + async cancelAction(input: { + traceId: string; + runtimeId?: string; + }): Promise<{ message: string }> { + const runtimeId = input.runtimeId || this.getMostRecentRuntime()?.id; + if (!runtimeId) { + throw new Error('No runtime found'); + } + // Assuming cancelAction is a request that returns a message + return this.sendRequest(runtimeId, 'cancelAction', { + traceId: input.traceId, + }); + } +} diff --git a/genkit-tools/common/src/manager/manager.ts b/genkit-tools/common/src/manager/manager.ts index 7a245e02e8..80ee90eb4e 100644 --- a/genkit-tools/common/src/manager/manager.ts +++ b/genkit-tools/common/src/manager/manager.ts @@ -51,7 +51,7 @@ const STREAM_DELIMITER = '\n'; const HEALTH_CHECK_INTERVAL = 5000; export const GENKIT_REFLECTION_API_SPEC_VERSION = 1; -interface RuntimeManagerOptions { +export interface RuntimeManagerOptions { /** URL of the telemetry server. */ telemetryServerUrl?: string; /** Whether to clean up unhealthy runtimes. */ @@ -62,11 +62,210 @@ interface RuntimeManagerOptions { processManager?: ProcessManager; /** Whether to disable realtime telemetry streaming. Defaults to false. */ disableRealtimeTelemetry?: boolean; + /** Experimental Reflection V2 flag */ + experimentalReflectionV2?: boolean; + /** Reflection V2 Port */ + reflectionV2Port?: number; } -export class RuntimeManager { - readonly processManager?: ProcessManager; - readonly disableRealtimeTelemetry: boolean; +export abstract class BaseRuntimeManager { + constructor( + readonly telemetryServerUrl: string | undefined, + readonly projectRoot: string, + readonly processManager?: ProcessManager, + readonly disableRealtimeTelemetry: boolean = false + ) {} + + abstract listRuntimes(): RuntimeInfo[]; + abstract getRuntimeById(id: string): RuntimeInfo | undefined; + abstract getMostRecentRuntime(): RuntimeInfo | undefined; + abstract getMostRecentDevUI(): DevToolsInfo | undefined; + abstract onRuntimeEvent( + listener: (eventType: RuntimeEvent, runtime: RuntimeInfo) => void + ): () => void; + abstract listActions( + input?: apis.ListActionsRequest + ): Promise>; + abstract runAction( + input: apis.RunActionRequest, + streamingCallback?: StreamingCallback, + onTraceId?: (traceId: string) => void, + inputStream?: AsyncIterable + ): Promise; + abstract cancelAction(input: { + traceId: string; + runtimeId?: string; + }): Promise<{ message: string }>; + abstract listValues( + input: apis.ListValuesRequest + ): Promise>; + + abstract stop(): Promise; + + /** + * Retrieves all traces + */ + async listTraces( + input: apis.ListTracesRequest + ): Promise { + const { limit, continuationToken, filter } = input; + let query = ''; + if (limit) { + query += `limit=${limit}`; + } + if (continuationToken) { + if (query !== '') { + query += '&'; + } + query += `continuationToken=${continuationToken}`; + } + if (filter) { + if (query !== '') { + query += '&'; + } + query += `filter=${encodeURI(JSON.stringify(filter))}`; + } + + const response = await axios + .get(`${this.telemetryServerUrl}/api/traces?${query}`) + .catch((err) => + this.httpErrorHandler(err, `Error listing traces for query='${query}'.`) + ); + + return apis.ListTracesResponseSchema.parse(response.data); + } + + /** + * Retrieves a trace for a given ID. + */ + async getTrace(input: apis.GetTraceRequest): Promise { + const { traceId } = input; + const response = await axios + .get(`${this.telemetryServerUrl}/api/traces/${traceId}`) + .catch((err) => + this.httpErrorHandler( + err, + `Error getting trace for traceId='${traceId}'` + ) + ); + + return response.data as TraceData; + } + + /** + * Streams trace updates in real-time from the telemetry server. + * Connects to the telemetry server's SSE endpoint and forwards updates via callback. + */ + async streamTrace( + input: apis.StreamTraceRequest, + streamingCallback: StreamingCallback + ): Promise { + const { traceId } = input; + + if (!this.telemetryServerUrl) { + throw new Error( + 'Telemetry server URL not configured. Cannot stream trace updates.' + ); + } + + const response = await axios + .get(`${this.telemetryServerUrl}/api/traces/${traceId}/stream`, { + headers: { + Accept: 'text/event-stream', + }, + responseType: 'stream', + }) + .catch((err) => + this.httpErrorHandler( + err, + `Error streaming trace for traceId='${traceId}'` + ) + ); + + const stream = response.data; + let buffer = ''; + + // Return a promise that resolves when the stream ends + return new Promise((resolve, reject) => { + stream.on('data', (chunk: Buffer) => { + buffer += chunk.toString(); + + // Process complete messages (ending with \n\n) + while (buffer.includes('\n\n')) { + const messageEnd = buffer.indexOf('\n\n'); + const message = buffer.substring(0, messageEnd).trim(); + buffer = buffer.substring(messageEnd + 2); + + // Skip empty messages + if (!message) { + continue; + } + // Parse SSE data line - strip "data: " prefix + try { + const jsonData = message.startsWith('data: ') + ? message.slice(6) + : message; + const parsed = JSON.parse(jsonData); + streamingCallback(parsed); + } catch (err) { + logger.error(`Error parsing stream data: ${err}`); + } + } + }); + + stream.on('end', () => { + resolve(); + }); + + stream.on('error', (err: Error) => { + logger.error(`Stream error for traceId='${traceId}': ${err}`); + reject(err); + }); + }); + } + + /** + * Adds a trace to the trace store + */ + async addTrace(input: TraceData): Promise { + await axios + .post(`${this.telemetryServerUrl}/api/traces/`, input) + .catch((err) => + this.httpErrorHandler(err, 'Error writing trace to store.') + ); + } + + /** + * Handles an HTTP error. + */ + protected httpErrorHandler(error: AxiosError, message?: string): never { + const newError = new GenkitToolsError(message || 'Internal Error'); + + if (error.response) { + // we got a non-200 response; copy the payload and rethrow + newError.data = error.response.data as GenkitError; + newError.stack = (error.response?.data as any).message; + if ((error.response?.data as any).message) { + newError.data.data = { + ...newError.data.data, + genkitErrorMessage: message, + genkitErrorDetails: { + stack: (error.response?.data as any).message, + traceId: (error.response?.data as any).traceId, + }, + }; + } + throw newError; + } + + // We actually have an exception; wrap it and re-throw. + throw new GenkitToolsError(message || 'Internal Error', { + cause: error.cause, + }); + } +} + +export class RuntimeManager extends BaseRuntimeManager { private filenameToRuntimeMap: Record = {}; private filenameToDevUiMap: Record = {}; private idToFileMap: Record = {}; @@ -75,20 +274,31 @@ export class RuntimeManager { private healthCheckInterval?: NodeJS.Timeout; private constructor( - readonly telemetryServerUrl: string | undefined, + telemetryServerUrl: string | undefined, private manageHealth: boolean, - readonly projectRoot: string, + projectRoot: string, processManager?: ProcessManager, disableRealtimeTelemetry?: boolean ) { - this.processManager = processManager; - this.disableRealtimeTelemetry = disableRealtimeTelemetry ?? false; + super( + telemetryServerUrl, + projectRoot, + processManager, + disableRealtimeTelemetry + ); } /** * Creates a new runtime manager. */ - static async create(options: RuntimeManagerOptions) { + static async create( + options: RuntimeManagerOptions + ): Promise { + if (options.experimentalReflectionV2) { + const { RuntimeManagerV2 } = await import('./manager-v2'); + return RuntimeManagerV2.create(options); + } + const manager = new RuntimeManager( options.telemetryServerUrl, options.manageHealth ?? true, @@ -209,7 +419,7 @@ export class RuntimeManager { } /** - * Retrieves all valuess. + * Retrieves all values. */ async listValues( input: apis.ListValuesRequest @@ -252,7 +462,8 @@ export class RuntimeManager { async runAction( input: apis.RunActionRequest, streamingCallback?: StreamingCallback, - onTraceId?: (traceId: string) => void + onTraceId?: (traceId: string) => void, + inputStream?: AsyncIterable ): Promise { const runtime = input.runtimeId ? this.getRuntimeById(input.runtimeId) @@ -461,139 +672,6 @@ export class RuntimeManager { } } - /** - * Retrieves all traces - */ - async listTraces( - input: apis.ListTracesRequest - ): Promise { - const { limit, continuationToken, filter } = input; - let query = ''; - if (limit) { - query += `limit=${limit}`; - } - if (continuationToken) { - if (query !== '') { - query += '&'; - } - query += `continuationToken=${continuationToken}`; - } - if (filter) { - if (query !== '') { - query += '&'; - } - query += `filter=${encodeURI(JSON.stringify(filter))}`; - } - - const response = await axios - .get(`${this.telemetryServerUrl}/api/traces?${query}`) - .catch((err) => - this.httpErrorHandler(err, `Error listing traces for query='${query}'.`) - ); - - return apis.ListTracesResponseSchema.parse(response.data); - } - - /** - * Retrieves a trace for a given ID. - */ - async getTrace(input: apis.GetTraceRequest): Promise { - const { traceId } = input; - const response = await axios - .get(`${this.telemetryServerUrl}/api/traces/${traceId}`) - .catch((err) => - this.httpErrorHandler( - err, - `Error getting trace for traceId='${traceId}'` - ) - ); - - return response.data as TraceData; - } - - /** - * Streams trace updates in real-time from the telemetry server. - * Connects to the telemetry server's SSE endpoint and forwards updates via callback. - */ - async streamTrace( - input: apis.StreamTraceRequest, - streamingCallback: StreamingCallback - ): Promise { - const { traceId } = input; - - if (!this.telemetryServerUrl) { - throw new Error( - 'Telemetry server URL not configured. Cannot stream trace updates.' - ); - } - - const response = await axios - .get(`${this.telemetryServerUrl}/api/traces/${traceId}/stream`, { - headers: { - Accept: 'text/event-stream', - }, - responseType: 'stream', - }) - .catch((err) => - this.httpErrorHandler( - err, - `Error streaming trace for traceId='${traceId}'` - ) - ); - - const stream = response.data; - let buffer = ''; - - // Return a promise that resolves when the stream ends - return new Promise((resolve, reject) => { - stream.on('data', (chunk: Buffer) => { - buffer += chunk.toString(); - - // Process complete messages (ending with \n\n) - while (buffer.includes('\n\n')) { - const messageEnd = buffer.indexOf('\n\n'); - const message = buffer.substring(0, messageEnd).trim(); - buffer = buffer.substring(messageEnd + 2); - - // Skip empty messages - if (!message) { - continue; - } - // Parse SSE data line - strip "data: " prefix - try { - const jsonData = message.startsWith('data: ') - ? message.slice(6) - : message; - const parsed = JSON.parse(jsonData); - streamingCallback(parsed); - } catch (err) { - logger.error(`Error parsing stream data: ${err}`); - } - } - }); - - stream.on('end', () => { - resolve(); - }); - - stream.on('error', (err: Error) => { - logger.error(`Stream error for traceId='${traceId}': ${err}`); - reject(err); - }); - }); - } - - /** - * Adds a trace to the trace store - */ - async addTrace(input: TraceData): Promise { - await axios - .post(`${this.telemetryServerUrl}/api/traces/`, input) - .catch((err) => - this.httpErrorHandler(err, 'Error writing trace to store.') - ); - } - /** * Notifies the runtime of dependencies it may need (e.g. telemetry server URL). */ @@ -789,35 +867,6 @@ export class RuntimeManager { } } - /** - * Handles an HTTP error. - */ - private httpErrorHandler(error: AxiosError, message?: string): never { - const newError = new GenkitToolsError(message || 'Internal Error'); - - if (error.response) { - // we got a non-200 response; copy the payload and rethrow - newError.data = error.response.data as GenkitError; - newError.stack = (error.response?.data as any).message; - if ((error.response?.data as any).message) { - newError.data.data = { - ...newError.data.data, - genkitErrorMessage: message, - genkitErrorDetails: { - stack: (error.response?.data as any).message, - traceId: (error.response?.data as any).traceId, - }, - }; - } - throw newError; - } - - // We actually have an exception; wrap it and re-throw. - throw new GenkitToolsError(message || 'Internal Error', { - cause: error.cause, - }); - } - /** * Handles a stream error by reading the stream and then calling httpErrorHandler. */ diff --git a/genkit-tools/common/src/server/router.ts b/genkit-tools/common/src/server/router.ts index e85a374c07..e137ac53b1 100644 --- a/genkit-tools/common/src/server/router.ts +++ b/genkit-tools/common/src/server/router.ts @@ -21,7 +21,7 @@ import { runNewEvaluation, validateSchema, } from '../eval'; -import type { RuntimeManager } from '../manager/manager'; +import type { BaseRuntimeManager } from '../manager/manager'; import { AppProcessStatus } from '../manager/process-manager'; import { GenkitToolsError, type RuntimeInfo } from '../manager/types'; import { TraceDataSchema } from '../types'; @@ -125,7 +125,7 @@ const loggedProcedure = t.procedure.use(async (opts) => { }); // eslint-disable-next-line @typescript-eslint/explicit-function-return-type -export const TOOLS_SERVER_ROUTER = (manager: RuntimeManager) => +export const TOOLS_SERVER_ROUTER = (manager: BaseRuntimeManager) => t.router({ /** Retrieves all runnable actions. */ listActions: loggedProcedure diff --git a/genkit-tools/common/src/server/server.ts b/genkit-tools/common/src/server/server.ts index 5ef7caf3ff..de20bd41e8 100644 --- a/genkit-tools/common/src/server/server.ts +++ b/genkit-tools/common/src/server/server.ts @@ -23,7 +23,7 @@ import type { Server } from 'http'; import os from 'os'; import path from 'path'; import type { GenkitToolsError } from '../manager'; -import type { RuntimeManager } from '../manager/manager'; +import type { BaseRuntimeManager } from '../manager/manager'; import { writeToolsInfoFile } from '../utils'; import { logger } from '../utils/logger'; import { toolsPackage } from '../utils/package'; @@ -43,10 +43,50 @@ const UI_ASSETS_ROOT = path.resolve( const UI_ASSETS_SERVE_PATH = path.resolve(UI_ASSETS_ROOT, 'ui', 'browser'); const API_BASE_PATH = '/api'; +class PushableAsyncIterable implements AsyncIterable { + private queue: T[] = []; + private resolvers: ((value: IteratorResult) => void)[] = []; + private closed = false; + + [Symbol.asyncIterator]() { + return { + next: () => this.next(), + }; + } + + next(): Promise> { + if (this.queue.length > 0) { + return Promise.resolve({ value: this.queue.shift()!, done: false }); + } + if (this.closed) { + return Promise.resolve({ value: undefined as any, done: true }); + } + return new Promise((resolve) => this.resolvers.push(resolve)); + } + + push(value: T) { + if (this.closed) return; + if (this.resolvers.length > 0) { + this.resolvers.shift()!({ value, done: false }); + } else { + this.queue.push(value); + } + } + + close() { + this.closed = true; + while (this.resolvers.length > 0) { + this.resolvers.shift()!({ value: undefined as any, done: true }); + } + } +} + +const activeInputStreams = new Map>(); + /** * Starts up the Genkit Tools server which includes static files for the UI and the Tools API. */ -export function startServer(manager: RuntimeManager, port: number) { +export function startServer(manager: BaseRuntimeManager, port: number) { let server: Server; const app = express(); @@ -131,30 +171,77 @@ export function startServer(manager: RuntimeManager, port: number) { res.flushHeaders(); } + let inputStream: PushableAsyncIterable | undefined; + const bidi = req.query.bidi === 'true'; + if (bidi) { + inputStream = new PushableAsyncIterable(); + } + + let capturedTraceId: string | undefined; + try { - const onTraceIdCallback = !manager.disableRealtimeTelemetry - ? (traceId: string) => { - // Set trace ID header and flush - this fires before first chunk - res.setHeader('X-Genkit-Trace-Id', traceId); - res.flushHeaders(); - } - : undefined; + const onTraceIdCallback = (traceId: string) => { + capturedTraceId = traceId; + // Set trace ID header and flush - this fires before first chunk + if (!manager.disableRealtimeTelemetry) { + res.setHeader('X-Genkit-Trace-Id', traceId); + res.flushHeaders(); + } + if (bidi && inputStream) { + activeInputStreams.set(traceId, inputStream); + } + }; const result = await manager.runAction( req.body, (chunk) => { res.write(JSON.stringify(chunk) + '\n'); }, - onTraceIdCallback + onTraceIdCallback, + inputStream ); res.write(JSON.stringify(result)); } catch (err) { res.write(JSON.stringify({ error: (err as GenkitToolsError).data })); + } finally { + if (capturedTraceId) { + activeInputStreams.delete(capturedTraceId); + } } res.end(); } ); + app.post( + '/api/sendBidiInput', + bodyParser.json({ limit: MAX_PAYLOAD_SIZE }), + (req, res) => { + const { traceId, chunk } = req.body; + const stream = activeInputStreams.get(traceId); + if (stream) { + stream.push(chunk); + res.status(200).send('OK'); + } else { + res.status(404).send('Stream not found'); + } + } + ); + + app.post('/api/endBidiInput', bodyParser.json(), (req, res) => { + const { traceId } = req.body; + const stream = activeInputStreams.get(traceId); + if (stream) { + stream.close(); + // Don't delete here, wait for action to complete (finally block) + // or delete if we want to ensure no more writes. + // If we delete here, subsequent writes will fail, which is correct. + // But finally block handles cleanup anyway. + res.status(200).send('OK'); + } else { + res.status(404).send('Stream not found'); + } + }); + app.post( '/api/streamTrace', bodyParser.json({ limit: MAX_PAYLOAD_SIZE }), diff --git a/genkit-tools/common/src/utils/eval.ts b/genkit-tools/common/src/utils/eval.ts index 8f0eefe5e6..3992debf94 100644 --- a/genkit-tools/common/src/utils/eval.ts +++ b/genkit-tools/common/src/utils/eval.ts @@ -19,7 +19,7 @@ import { randomUUID } from 'crypto'; import { createReadStream } from 'fs'; import { readFile } from 'fs/promises'; import { createInterface } from 'readline'; -import type { RuntimeManager } from '../manager'; +import type { BaseRuntimeManager } from '../manager'; import { findToolsConfig, isEvalField, @@ -323,7 +323,7 @@ async function readLines(fileName: string): Promise { } export async function hasAction(params: { - manager: RuntimeManager; + manager: BaseRuntimeManager; actionRef: string; }): Promise { const { manager, actionRef } = { ...params }; @@ -333,7 +333,7 @@ export async function hasAction(params: { } export async function getAction(params: { - manager: RuntimeManager; + manager: BaseRuntimeManager; actionRef: string; }): Promise { const { manager, actionRef } = { ...params }; diff --git a/genkit-tools/common/tests/manager-v2_test.ts b/genkit-tools/common/tests/manager-v2_test.ts new file mode 100644 index 0000000000..6c2caa3041 --- /dev/null +++ b/genkit-tools/common/tests/manager-v2_test.ts @@ -0,0 +1,436 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { afterEach, beforeEach, describe, expect, it } from '@jest/globals'; +import WebSocket from 'ws'; +import { RuntimeManagerV2 } from '../src/manager/manager-v2'; +import { RuntimeEvent } from '../src/manager/types'; + +describe('RuntimeManagerV2', () => { + let manager: RuntimeManagerV2; + let wsClient: WebSocket; + let port: number; + + beforeEach(async () => { + manager = await RuntimeManagerV2.create({ + projectRoot: './', + }); + port = manager.port!; + }); + + afterEach(async () => { + if (wsClient) { + wsClient.close(); + } + // Clean up server + await manager.stop(); + }); + + it('should accept connections and handle registration', (done) => { + wsClient = new WebSocket(`ws://localhost:${port}`); + + const unsubscribe = manager.onRuntimeEvent((event, runtime) => { + if (event === RuntimeEvent.ADD) { + expect(runtime.id).toBe('test-runtime-1'); + expect(runtime.pid).toBe(1234); + expect(manager.listRuntimes().length).toBe(1); + unsubscribe(); + done(); + } + }); + + wsClient.on('open', () => { + const registerMessage = { + jsonrpc: '2.0', + method: 'register', + params: { + id: 'test-runtime-1', + pid: 1234, + name: 'Test Runtime', + genkitVersion: '0.0.1', + reflectionApiSpecVersion: 1, + }, + id: 1, + }; + wsClient.send(JSON.stringify(registerMessage)); + }); + }); + + it('should allow unsubscribing from runtime events', async () => { + wsClient = new WebSocket(`ws://localhost:${port}`); + const listener = jest.fn(); + + const unsubscribe = manager.onRuntimeEvent(listener); + + await new Promise((resolve) => { + wsClient.on('open', () => { + wsClient.send( + JSON.stringify({ + jsonrpc: '2.0', + method: 'register', + params: { id: 'test-runtime-unsubscribe', pid: 1234 }, + id: 1, + }) + ); + setTimeout(resolve, 100); + }); + }); + + // Wait for event + await new Promise((resolve) => setTimeout(resolve, 100)); + expect(listener).toHaveBeenCalled(); + + unsubscribe(); + listener.mockClear(); + + // Trigger another event (e.g. disconnect) + wsClient.close(); + await new Promise((resolve) => setTimeout(resolve, 100)); + + expect(listener).not.toHaveBeenCalled(); + }); + + it('should send requests and handle responses', async () => { + wsClient = new WebSocket(`ws://localhost:${port}`); + + await new Promise((resolve) => { + wsClient.on('open', () => { + wsClient.send( + JSON.stringify({ + jsonrpc: '2.0', + method: 'register', + params: { id: 'test-runtime-2', pid: 1234 }, + id: 1, + }) + ); + // Wait for server to acknowledge or just wait a bit + setTimeout(resolve, 100); + }); + }); + + // Mock runtime response to runAction + wsClient.on('message', (data) => { + const message = JSON.parse(data.toString()); + if (message.method === 'runAction') { + const response = { + jsonrpc: '2.0', + result: { + result: 'Hello World', + telemetry: { + traceId: '1234', + }, + }, + id: message.id, + }; + wsClient.send(JSON.stringify(response)); + } + }); + + const response = await manager.runAction({ + key: 'testAction', + input: {}, + }); + + expect(response.result).toBe('Hello World'); + expect(response.telemetry).toStrictEqual({ + traceId: '1234', + }); + }); + + it('should handle listValues', async () => { + wsClient = new WebSocket(`ws://localhost:${port}`); + + await new Promise((resolve) => { + wsClient.on('open', () => { + wsClient.send( + JSON.stringify({ + jsonrpc: '2.0', + method: 'register', + params: { id: 'test-runtime-values', pid: 1234 }, + id: 1, + }) + ); + setTimeout(resolve, 100); + }); + }); + + wsClient.on('message', (data) => { + const message = JSON.parse(data.toString()); + if (message.method === 'listValues') { + const response = { + jsonrpc: '2.0', + result: { + 'my-prompt': { template: 'foo' }, + }, + id: message.id, + }; + wsClient.send(JSON.stringify(response)); + } + }); + + const values = await manager.listValues({ + type: 'prompt', + }); + + expect(values['my-prompt']).toBeDefined(); + expect(values['my-prompt']).toEqual({ template: 'foo' }); + }); + + it('should handle streaming', async () => { + wsClient = new WebSocket(`ws://localhost:${port}`); + + await new Promise((resolve) => { + wsClient.on('open', () => { + wsClient.send( + JSON.stringify({ + jsonrpc: '2.0', + method: 'register', + params: { id: 'test-runtime-3', pid: 1234 }, + id: 1, + }) + ); + setTimeout(resolve, 100); + }); + }); + + wsClient.on('message', (data) => { + const message = JSON.parse(data.toString()); + if (message.method === 'runAction' && message.params.stream) { + // Send chunk 1 + wsClient.send( + JSON.stringify({ + jsonrpc: '2.0', + method: 'streamChunk', + params: { requestId: message.id, chunk: { content: 'Hello' } }, + }) + ); + // Send chunk 2 + wsClient.send( + JSON.stringify({ + jsonrpc: '2.0', + method: 'streamChunk', + params: { requestId: message.id, chunk: { content: ' World' } }, + }) + ); + // Send final result + wsClient.send( + JSON.stringify({ + jsonrpc: '2.0', + result: { result: 'Hello World', telemetry: {} }, + id: message.id, + }) + ); + } + }); + + const chunks: any[] = []; + const response = await manager.runAction( + { + key: 'testAction', + input: {}, + }, + (chunk) => { + chunks.push(chunk); + } + ); + + expect(chunks).toHaveLength(2); + expect(chunks[0]).toEqual({ content: 'Hello' }); + expect(chunks[1]).toEqual({ content: ' World' }); + expect(response.result).toBe('Hello World'); + expect(response.telemetry).toBeDefined(); + }); + + it('should handle streaming errors and massage the error object', async () => { + wsClient = new WebSocket(`ws://localhost:${port}`); + + await new Promise((resolve) => { + wsClient.on('open', () => { + wsClient.send( + JSON.stringify({ + jsonrpc: '2.0', + method: 'register', + params: { id: 'test-runtime-error', pid: 1234 }, + id: 1, + }) + ); + setTimeout(resolve, 100); + }); + }); + + wsClient.on('message', (data) => { + const message = JSON.parse(data.toString()); + if (message.method === 'runAction' && message.params.stream) { + // Send chunk 1 + wsClient.send( + JSON.stringify({ + jsonrpc: '2.0', + method: 'streamChunk', + params: { requestId: message.id, chunk: { content: 'Hello' } }, + }) + ); + // Send error + const errorResponse = { + code: -32000, + message: 'Test Error', + data: { + code: 13, + message: 'Test Error', + details: { + stack: 'Error stack...', + traceId: 'trace-123', + }, + }, + }; + wsClient.send( + JSON.stringify({ + jsonrpc: '2.0', + error: errorResponse, + id: message.id, + }) + ); + } + }); + + const chunks: any[] = []; + try { + await manager.runAction( + { + key: 'testAction', + input: {}, + }, + (chunk) => { + chunks.push(chunk); + } + ); + throw new Error('Should have thrown'); + } catch (err: any) { + expect(chunks).toHaveLength(1); + expect(chunks[0]).toEqual({ content: 'Hello' }); + expect(err.message).toBe('Test Error'); + expect(err.data).toBeDefined(); + expect(err.data.data.genkitErrorMessage).toBe('Test Error'); + expect(err.data.stack).toBe('Error stack...'); + expect(err.data.data.genkitErrorDetails).toEqual({ + stack: 'Error stack...', + traceId: 'trace-123', + }); + } + }); + + it('should send cancelAction request', async () => { + wsClient = new WebSocket(`ws://localhost:${port}`); + + await new Promise((resolve) => { + wsClient.on('open', () => { + wsClient.send( + JSON.stringify({ + jsonrpc: '2.0', + method: 'register', + params: { id: 'test-runtime-cancel', pid: 1234 }, + id: 1, + }) + ); + setTimeout(resolve, 100); + }); + }); + + wsClient.on('message', (data) => { + const message = JSON.parse(data.toString()); + if (message.method === 'cancelAction') { + const response = { + jsonrpc: '2.0', + result: { + message: 'Action cancelled', + }, + id: message.id, + }; + wsClient.send(JSON.stringify(response)); + } + }); + + const response = await manager.cancelAction({ + traceId: '1234', + }); + + expect(response.message).toBe('Action cancelled'); + }); + + it('should handle runActionState for early trace info', async () => { + wsClient = new WebSocket(`ws://localhost:${port}`); + + await new Promise((resolve) => { + wsClient.on('open', () => { + wsClient.send( + JSON.stringify({ + jsonrpc: '2.0', + method: 'register', + params: { id: 'test-runtime-trace', pid: 1234 }, + id: 1, + }) + ); + setTimeout(resolve, 100); + }); + }); + + wsClient.on('message', (data) => { + const message = JSON.parse(data.toString()); + if (message.method === 'runAction') { + // Send runActionState with traceId + wsClient.send( + JSON.stringify({ + jsonrpc: '2.0', + method: 'runActionState', + params: { + requestId: message.id, + state: { + traceId: 'early-trace-id', + }, + }, + }) + ); + + // Send final result + const response = { + jsonrpc: '2.0', + result: { + result: 'Hello World', + telemetry: { + traceId: 'early-trace-id', + }, + }, + id: message.id, + }; + wsClient.send(JSON.stringify(response)); + } + }); + + let capturedTraceId: string | undefined; + const response = await manager.runAction( + { + key: 'testAction', + input: {}, + }, + undefined, + (traceId) => { + capturedTraceId = traceId; + } + ); + + expect(capturedTraceId).toBe('early-trace-id'); + expect(response.result).toBe('Hello World'); + }); +}); diff --git a/genkit-tools/common/tests/server_test.ts b/genkit-tools/common/tests/server_test.ts new file mode 100644 index 0000000000..876657041b --- /dev/null +++ b/genkit-tools/common/tests/server_test.ts @@ -0,0 +1,165 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { + afterEach, + beforeEach, + describe, + expect, + it, + jest, +} from '@jest/globals'; +import axios from 'axios'; +import getPort from 'get-port'; +import { BaseRuntimeManager } from '../src/manager/manager'; +import { startServer } from '../src/server/server'; + +describe('Tools Server', () => { + let port: number; + let serverPromise: Promise; + let mockManager: any; + + beforeEach(async () => { + port = await getPort(); + mockManager = { + projectRoot: './', + disableRealtimeTelemetry: false, + runAction: jest.fn(), + streamTrace: jest.fn(), + listActions: jest.fn(), + listTraces: jest.fn(), + getTrace: jest.fn(), + getMostRecentRuntime: jest.fn(), + listRuntimes: jest.fn(), + onRuntimeEvent: jest.fn(), + cancelAction: jest.fn(), + }; + serverPromise = startServer(mockManager as BaseRuntimeManager, port); + }); + + afterEach(async () => { + const exitSpy = jest + .spyOn(process, 'exit') + .mockImplementation((code?: any) => { + return undefined as never; + }); + try { + await axios.post(`http://localhost:${port}/api/__quitquitquit`); + } catch (e) { + // Ignore + } + await serverPromise; + exitSpy.mockRestore(); + }); + + it('should handle runAction', async () => { + mockManager.runAction.mockResolvedValue({ result: 'bar' }); + + let response; + try { + response = await axios.post(`http://localhost:${port}/api/runAction`, { + key: 'foo', + input: 'bar', + }); + } catch (e: any) { + throw new Error(`runAction failed: ${e.message}`); + } + + expect(response.data.result).toBe('bar'); + expect(mockManager.runAction).toHaveBeenCalledWith( + expect.objectContaining({ key: 'foo' }), + undefined, + expect.any(Function) + ); + }); + + it('should handle bidi streaming', async () => { + let inputStream: AsyncIterable | undefined; + let finishAction: (() => void) | undefined; + + mockManager.runAction.mockImplementation( + async (input: any, cb: any, trace: any, stream: any) => { + inputStream = stream; + await new Promise((resolve) => { + finishAction = resolve; + }); + return { result: 'done' }; + } + ); + + const responsePromise = axios + .post( + `http://localhost:${port}/api/streamAction?bidi=true`, + { key: 'bidi' }, + { responseType: 'stream' } + ) + .catch((e) => { + throw new Error(`Stream action failed: ${e.message}`); + }); + + // Wait for runAction to be called + while (!inputStream) { + await new Promise((r) => setTimeout(r, 10)); + } + + const traceId = 'test-trace-id'; + // Get the onTraceId callback from the mock call args + const [inputArg, cb, onTraceIdCallback] = + mockManager.runAction.mock.calls[0]; + onTraceIdCallback(traceId); + + // Collect input chunks in background + const chunks: any[] = []; + const collectPromise = (async () => { + for await (const chunk of inputStream!) { + chunks.push(chunk); + } + })(); + + // Now send input + try { + await axios.post(`http://localhost:${port}/api/sendBidiInput`, { + traceId, + chunk: 'input1', + }); + + await axios.post(`http://localhost:${port}/api/endBidiInput`, { + traceId, + }); + } catch (e: any) { + throw new Error(`send/end input failed: ${e.message}`); + } + + await collectPromise; + expect(chunks).toEqual(['input1']); + + // Emit output chunk + if (cb) cb({ result: 'chunk1' }); + + // Finish action + finishAction!(); + + const response = await responsePromise; + const stream = response.data; + const outputChunks: string[] = []; + for await (const chunk of stream) { + outputChunks.push(chunk.toString()); + } + const output = outputChunks.join(''); + expect(output).toContain('chunk1'); + expect(output).toContain('done'); + }); +}); diff --git a/genkit-tools/pnpm-lock.yaml b/genkit-tools/pnpm-lock.yaml index ea1b78beac..fef3cebbcd 100644 --- a/genkit-tools/pnpm-lock.yaml +++ b/genkit-tools/pnpm-lock.yaml @@ -144,6 +144,9 @@ importers: cors: specifier: ^2.8.5 version: 2.8.5 + events: + specifier: ^3.3.0 + version: 3.3.0 express: specifier: ^4.21.0 version: 4.21.2 @@ -174,6 +177,9 @@ importers: winston: specifier: ^3.11.0 version: 3.17.0 + ws: + specifier: ^8.18.3 + version: 8.18.3 yaml: specifier: ^2.4.1 version: 2.8.0 @@ -202,6 +208,9 @@ importers: '@types/cors': specifier: ^2.8.19 version: 2.8.19 + '@types/events': + specifier: ^3.0.3 + version: 3.0.3 '@types/express': specifier: ^4.17.21 version: 4.17.23 @@ -223,6 +232,9 @@ importers: '@types/uuid': specifier: ^9.0.8 version: 9.0.8 + '@types/ws': + specifier: ^8.18.1 + version: 8.18.1 bun-types: specifier: ^1.2.16 version: 1.2.16 @@ -280,6 +292,9 @@ importers: async-mutex: specifier: ^0.5.0 version: 0.5.0 + cors: + specifier: ^2.8.5 + version: 2.8.5 express: specifier: ^4.21.0 version: 4.21.2 @@ -290,6 +305,9 @@ importers: specifier: ^3.22.4 version: 3.25.67 devDependencies: + '@types/cors': + specifier: ^2.8.19 + version: 2.8.19 '@types/express': specifier: ~4.17.21 version: 4.17.23 @@ -1120,6 +1138,9 @@ packages: '@types/cors@2.8.19': resolution: {integrity: sha512-mFNylyeyqN93lfe/9CSxOGREz8cpzAhH+E93xJ4xWQf62V8sQ/24reV2nyzUWM6H6Xji+GGHpkbLe7pVoUEskg==} + '@types/events@3.0.3': + resolution: {integrity: sha512-trOc4AAUThEz9hapPtSd7wf5tiQKvTtu5b371UxXdTuqzIh0ArcRspRP0i0Viu+LXstIQ1z96t1nsPxT9ol01g==} + '@types/express-serve-static-core@4.19.0': resolution: {integrity: sha512-bGyep3JqPCRry1wq+O5n7oiBgGWmeIJXPjXXCo8EK0u8duZGSYar7cGqd3ML2JUsLGeB7fmc06KYo9fLGWqPvQ==} @@ -1201,6 +1222,9 @@ packages: '@types/uuid@9.0.8': resolution: {integrity: sha512-jg+97EGIcY9AGHJJRaaPVgetKDsrTgbRjQ5Msgjh/DQKEFl0DtyRr/VCOyD1T2R1MNeWPK/u7JoGhlDZnKBAfA==} + '@types/ws@8.18.1': + resolution: {integrity: sha512-ThVF6DCVhA8kUGy+aazFQ4kXQ7E1Ty7A3ypFOe0IcJV8O/M511G99AW24irKrW56Wt44yG9+ij8FaqoBGkuBXg==} + '@types/yargs-parser@21.0.3': resolution: {integrity: sha512-I4q9QU9MQv4oEOz4tAHJtNz1cwuLxn2F3xcc2iV5WdqLPpUnj30aUuxt1mAxYTG+oe8CZMV/+6rU4S4gRDzqtQ==} @@ -1809,6 +1833,10 @@ packages: resolution: {integrity: sha512-i/2XbnSz/uxRCU6+NdVJgKWDTM427+MqYbkQzD321DuCQJUqOuJKIA0IM2+W2xtYHdKOmZ4dR6fExsd4SXL+WQ==} engines: {node: '>=6'} + events@3.3.0: + resolution: {integrity: sha512-mQw+2fkQbALzQ7V0MY0IqdnXNOeTtP4r0lN9z7AAawCXgqea7bDii20AYrIBrFd/Hx0M2Ocz6S111CaFkUcb0Q==} + engines: {node: '>=0.8.x'} + eventsource-parser@3.0.3: resolution: {integrity: sha512-nVpZkTMM9rF6AQ9gPJpFsNAMt48wIzB5TQgiTLdHiuO8XEDhUgZEhqKlZWXbIzo9VmJ/HvysHqEaVeD5v9TPvA==} engines: {node: '>=20.0.0'} @@ -3473,6 +3501,18 @@ packages: resolution: {integrity: sha512-7KxauUdBmSdWnmpaGFg+ppNjKF8uNLry8LyzjauQDOVONfFLNKrKvQOxZ/VuTIcS/gge/YNahf5RIIQWTSarlg==} engines: {node: ^12.13.0 || ^14.15.0 || >=16.0.0} + ws@8.18.3: + resolution: {integrity: sha512-PEIGCY5tSlUt50cqyMXfCzX+oOPqN0vuGqWzbcJ2xvnkzkq46oOpz7dQaTDBdfICb4N14+GARUDw2XV2N4tvzg==} + engines: {node: '>=10.0.0'} + peerDependencies: + bufferutil: ^4.0.1 + utf-8-validate: '>=5.0.2' + peerDependenciesMeta: + bufferutil: + optional: true + utf-8-validate: + optional: true + xdg-basedir@4.0.0: resolution: {integrity: sha512-PSNhEJDejZYV7h50BohL09Er9VaIefr2LMAf3OEmpCkjOi34eYyQYAXUTjEQtZJTKcF0E2UKTh+osDLsgNim9Q==} engines: {node: '>=8'} @@ -4433,6 +4473,8 @@ snapshots: dependencies: '@types/node': 20.19.1 + '@types/events@3.0.3': {} + '@types/express-serve-static-core@4.19.0': dependencies: '@types/node': 20.19.1 @@ -4525,6 +4567,10 @@ snapshots: '@types/uuid@9.0.8': {} + '@types/ws@8.18.1': + dependencies: + '@types/node': 20.19.1 + '@types/yargs-parser@21.0.3': {} '@types/yargs@17.0.32': @@ -5224,6 +5270,8 @@ snapshots: event-target-shim@5.0.1: {} + events@3.3.0: {} + eventsource-parser@3.0.3: {} eventsource@3.0.7: @@ -7251,6 +7299,8 @@ snapshots: imurmurhash: 0.1.4 signal-exit: 3.0.7 + ws@8.18.3: {} + xdg-basedir@4.0.0: {} y18n@5.0.8: {} diff --git a/genkit-tools/telemetry-server/package.json b/genkit-tools/telemetry-server/package.json index df17335908..f6cebf5cfb 100644 --- a/genkit-tools/telemetry-server/package.json +++ b/genkit-tools/telemetry-server/package.json @@ -37,9 +37,11 @@ "async-mutex": "^0.5.0", "express": "^4.21.0", "lockfile": "^1.0.4", - "zod": "^3.22.4" + "zod": "^3.22.4", + "cors": "^2.8.5" }, "devDependencies": { + "@types/cors": "^2.8.19", "@types/express": "~4.17.21", "@types/lockfile": "^1.0.4", "@types/node": "^20.11.30", diff --git a/genkit-tools/telemetry-server/src/index.ts b/genkit-tools/telemetry-server/src/index.ts index c24536f794..f9228a74fc 100644 --- a/genkit-tools/telemetry-server/src/index.ts +++ b/genkit-tools/telemetry-server/src/index.ts @@ -20,6 +20,7 @@ import { type SpanData, } from '@genkit-ai/tools-common'; import { logger } from '@genkit-ai/tools-common/utils'; +import cors from 'cors'; import express from 'express'; import type * as http from 'http'; import { BroadcastManager } from './broadcast-manager.js'; @@ -46,10 +47,31 @@ export async function startTelemetryServer(params: { * Defaults to '5mb'. */ maxRequestBodySize?: string | number; + allowedCorsHostnames?: string[]; }) { await params.traceStore.init(); const api = express(); + api.use( + cors({ + origin: (origin, callback) => { + // Allow requests with no origin (like mobile apps or curl requests) + if (!origin) return callback(null, true); + + const hostname = new URL(origin).hostname; + if ( + hostname === 'localhost' || + hostname === '127.0.0.1' || + params.allowedCorsHostnames?.includes(hostname) + ) { + return callback(null, true); + } + + return callback(new Error('Not allowed by CORS'), false); + }, + }) + ); + api.use(express.json({ limit: params.maxRequestBodySize ?? '100mb' })); api.get('/api/__health', async (_, response) => { diff --git a/js/core/package.json b/js/core/package.json index 0b1282ebef..5fb10daa79 100644 --- a/js/core/package.json +++ b/js/core/package.json @@ -43,9 +43,11 @@ "get-port": "^5.1.0", "json-schema": "^0.4.0", "zod": "^3.23.8", - "zod-to-json-schema": "^3.22.4" + "zod-to-json-schema": "^3.22.4", + "ws": "^8.18.0" }, "devDependencies": { + "@types/ws": "^8.5.10", "@types/express": "^4.17.21", "@types/node": "^20.11.30", "genversion": "^3.2.0", diff --git a/js/core/src/reflection-v2.ts b/js/core/src/reflection-v2.ts new file mode 100644 index 0000000000..0179aa0fec --- /dev/null +++ b/js/core/src/reflection-v2.ts @@ -0,0 +1,408 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import WebSocket from 'ws'; +import { StatusCodes, type Status } from './action.js'; +import { Channel } from './async.js'; +import { GENKIT_REFLECTION_API_SPEC_VERSION, GENKIT_VERSION } from './index.js'; +import { logger } from './logging.js'; +import type { Registry } from './registry.js'; +import { toJsonSchema } from './schema.js'; +import { flushTracing, setTelemetryServerUrl } from './tracing.js'; + +let apiIndex = 0; + +interface JsonRpcRequest { + jsonrpc: '2.0'; + method: string; + params?: any; + id?: number | string; +} + +interface JsonRpcResponse { + jsonrpc: '2.0'; + result?: any; + error?: { + code: number; + message: string; + data?: any; + }; + id: number | string; +} + +type JsonRpcMessage = JsonRpcRequest | JsonRpcResponse; + +export interface ReflectionServerV2Options { + configuredEnvs?: string[]; + name?: string; + url: string; +} + +export class ReflectionServerV2 { + private registry: Registry; + private options: ReflectionServerV2Options; + private ws: WebSocket | null = null; + private url: string; + private index = apiIndex++; + private activeActions = new Map< + string, + { + abortController: AbortController; + startTime: Date; + } + >(); + private activeRequests = new Map>(); + + constructor(registry: Registry, options: ReflectionServerV2Options) { + this.registry = registry; + this.options = { + configuredEnvs: ['dev'], + ...options, + }; + // The URL should be provided via environment variable by the CLI manager + this.url = this.options.url; + } + + async start() { + logger.debug(`Connecting to Reflection V2 server at ${this.url}`); + this.ws = new WebSocket(this.url); + + this.ws.on('open', () => { + logger.debug('Connected to Reflection V2 server.'); + this.register(); + }); + + this.ws.on('message', async (data) => { + try { + const message = JSON.parse(data.toString()) as JsonRpcMessage; + if ('method' in message) { + await this.handleRequest(message as JsonRpcRequest); + } + } catch (error) { + logger.error(`Failed to parse message: ${error}`); + } + }); + + this.ws.on('error', (error) => { + logger.error(`Reflection V2 WebSocket error: ${error}`); + }); + + this.ws.on('close', () => { + logger.debug('Reflection V2 WebSocket closed.'); + }); + } + + async stop() { + if (this.ws) { + this.ws.close(); + this.ws = null; + } + } + + private send(message: JsonRpcMessage) { + if (this.ws && this.ws.readyState === WebSocket.OPEN) { + this.ws.send(JSON.stringify(message)); + } + } + + private sendResponse(id: number | string, result: any) { + this.send({ + jsonrpc: '2.0', + result, + id, + }); + } + + private sendError( + id: number | string, + code: number, + message: string, + data?: any + ) { + this.send({ + jsonrpc: '2.0', + error: { code, message, data }, + id, + }); + } + + private sendNotification(method: string, params: any) { + this.send({ + jsonrpc: '2.0', + method, + params, + }); + } + + private register() { + const params = { + id: process.env.GENKIT_RUNTIME_ID || this.runtimeId, + pid: process.pid, + name: this.options.name || this.runtimeId, + genkitVersion: GENKIT_VERSION, + reflectionApiSpecVersion: GENKIT_REFLECTION_API_SPEC_VERSION, + envs: this.options.configuredEnvs, + }; + this.sendNotification('register', params); + } + + get runtimeId() { + return `${process.pid}${this.index ? `-${this.index}` : ''}`; + } + + private async handleRequest(request: JsonRpcRequest) { + try { + switch (request.method) { + case 'listActions': + await this.handleListActions(request); + break; + case 'listValues': + await this.handleListValues(request); + break; + case 'runAction': + await this.handleRunAction(request); + break; + case 'configure': + this.handleConfigure(request); + break; + case 'cancelAction': + await this.handleCancelAction(request); + break; + case 'streamInputChunk': + this.handleStreamInputChunk(request); + break; + case 'endStreamInput': + this.handleEndStreamInput(request); + break; + default: + if (request.id) { + this.sendError( + request.id, + -32601, + `Method not found: ${request.method}` + ); + } + } + } catch (error: any) { + if (request.id) { + this.sendError(request.id, -32000, error.message, { + stack: error.stack, + }); + } + } + } + + private async handleListActions(request: JsonRpcRequest) { + if (!request.id) return; // Should be a request + const actions = await this.registry.listResolvableActions(); + const convertedActions: Record = {}; + + Object.keys(actions).forEach((key) => { + const action = actions[key]; + convertedActions[key] = { + key, + name: action.name, + description: action.description, + metadata: action.metadata, + }; + if (action.inputSchema || action.inputJsonSchema) { + convertedActions[key].inputSchema = toJsonSchema({ + schema: action.inputSchema, + jsonSchema: action.inputJsonSchema, + }); + } + if (action.outputSchema || action.outputJsonSchema) { + convertedActions[key].outputSchema = toJsonSchema({ + schema: action.outputSchema, + jsonSchema: action.outputJsonSchema, + }); + } + }); + + this.sendResponse(request.id, convertedActions); + } + + private async handleListValues(request: JsonRpcRequest) { + if (!request.id) return; + const { type } = request.params; + const values = await this.registry.listValues(type); + this.sendResponse(request.id, values); + } + + private async handleRunAction(request: JsonRpcRequest) { + if (!request.id) return; + + const { key, input, context, telemetryLabels, stream, streamInput } = + request.params; + const action = await this.registry.lookupAction(key); + + if (!action) { + this.sendError(request.id, 404, `action ${key} not found`); + return; + } + + const abortController = new AbortController(); + let traceId: string | undefined; + let inputStream: Channel | undefined; + + // Set up input stream for bidi actions + if (action.__action.metadata?.bidi) { + inputStream = new Channel(); + this.activeRequests.set(request.id, inputStream); + + // If initial input is provided, send it + if (input !== undefined) { + inputStream.send(input); + } + + // If input streaming is not requested, close the stream immediately + // effectively treating initial input as the only input. + if (!streamInput) { + inputStream.close(); + } + } + + try { + const onTraceStartCallback = ({ traceId: tid }: { traceId: string }) => { + traceId = tid; + this.activeActions.set(tid, { + abortController, + startTime: new Date(), + }); + // Send early trace ID notification + this.sendNotification('runActionState', { + requestId: request.id, + state: { traceId: tid }, + }); + }; + + if (stream) { + const callback = (chunk: any) => { + this.sendNotification('streamChunk', { + requestId: request.id, + chunk, + }); + }; + + const result = await action.run(input, { + context, + onChunk: callback, + telemetryLabels, + onTraceStart: onTraceStartCallback, + abortSignal: abortController.signal, + inputStream, + }); + + await flushTracing(); + + // Send final result + this.sendResponse(request.id, { + result: result.result, + telemetry: { + traceId: result.telemetry.traceId, + }, + }); + } else { + const result = await action.run(input, { + context, + telemetryLabels, + onTraceStart: onTraceStartCallback, + abortSignal: abortController.signal, + inputStream, + }); + await flushTracing(); + + this.sendResponse(request.id, { + result: result.result, + telemetry: { + traceId: result.telemetry.traceId, + }, + }); + } + } catch (err: any) { + const isAbort = + err?.name === 'AbortError' || + (typeof DOMException !== 'undefined' && + err instanceof DOMException && + err.name === 'AbortError'); + + const errorResponse: Status = { + code: isAbort ? StatusCodes.CANCELLED : StatusCodes.INTERNAL, + message: isAbort ? 'Action was cancelled' : err.message, + details: { + stack: err.stack, + }, + }; + if (err.traceId || traceId) { + errorResponse.details.traceId = err.traceId || traceId; + } + + this.sendError(request.id, -32000, errorResponse.message, errorResponse); + } finally { + if (traceId) { + this.activeActions.delete(traceId); + } + if (request.id) { + this.activeRequests.delete(request.id); + } + } + } + + private handleConfigure(request: JsonRpcRequest) { + const { telemetryServerUrl } = request.params; + if (telemetryServerUrl && !process.env.GENKIT_TELEMETRY_SERVER) { + setTelemetryServerUrl(telemetryServerUrl); + logger.debug(`Connected to telemetry server on ${telemetryServerUrl}`); + } + } + + private async handleCancelAction(request: JsonRpcRequest) { + if (!request.id) return; + const { traceId } = request.params; + if (!traceId || typeof traceId !== 'string') { + this.sendError(request.id, 400, 'traceId is required'); + return; + } + const activeAction = this.activeActions.get(traceId); + if (activeAction) { + activeAction.abortController.abort(); + this.activeActions.delete(traceId); + this.sendResponse(request.id, { message: 'Action cancelled' }); + } else { + this.sendError(request.id, 404, 'Action not found or already completed'); + } + } + + private handleStreamInputChunk(request: JsonRpcRequest) { + const { requestId, chunk } = request.params; + const channel = this.activeRequests.get(requestId); + if (channel) { + channel.send(chunk); + } else { + logger.warn(`Received input chunk for unknown request ${requestId}`); + } + } + + private handleEndStreamInput(request: JsonRpcRequest) { + const { requestId } = request.params; + const channel = this.activeRequests.get(requestId); + if (channel) { + channel.close(); + } else { + logger.warn(`Received end stream input for unknown request ${requestId}`); + } + } +} diff --git a/js/core/src/reflection.ts b/js/core/src/reflection.ts index a63959bf13..ec2ca7f726 100644 --- a/js/core/src/reflection.ts +++ b/js/core/src/reflection.ts @@ -92,6 +92,7 @@ export class ReflectionServer { startTime: Date; } >(); + private v2Server: any | null = null; constructor(registry: Registry, options?: ReflectionServerOptions) { this.registry = registry; @@ -135,6 +136,17 @@ export class ReflectionServer { ); return; } + if (process.env.GENKIT_REFLECTION_V2_SERVER) { + const { ReflectionServerV2 } = await import('./reflection-v2.js'); + this.v2Server = new ReflectionServerV2(this.registry, { + configuredEnvs: this.options.configuredEnvs, + name: this.options.name, + url: process.env.GENKIT_REFLECTION_V2_SERVER, + }); + await this.v2Server.start(); + ReflectionServer.RUNNING_SERVERS.push(this); + return; + } const server = express(); @@ -439,6 +451,15 @@ export class ReflectionServer { * Stops the server and removes it from the list of running servers to clean up on exit. */ async stop(): Promise { + if (this.v2Server) { + await this.v2Server.stop(); + const index = ReflectionServer.RUNNING_SERVERS.indexOf(this); + if (index > -1) { + ReflectionServer.RUNNING_SERVERS.splice(index, 1); + } + return; + } + if (!this.server) { return; } diff --git a/js/core/tests/reflection-v2_test.ts b/js/core/tests/reflection-v2_test.ts new file mode 100644 index 0000000000..28de731c03 --- /dev/null +++ b/js/core/tests/reflection-v2_test.ts @@ -0,0 +1,470 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import * as assert from 'assert'; +import { afterEach, beforeEach, describe, it } from 'node:test'; +import { WebSocketServer } from 'ws'; +import { z } from 'zod'; +import { action, defineBidiAction } from '../src/action.js'; +import { initNodeFeatures } from '../src/node.js'; +import { ReflectionServerV2 } from '../src/reflection-v2.js'; +import { Registry } from '../src/registry.js'; + +initNodeFeatures(); + +describe('ReflectionServerV2', () => { + let wss: WebSocketServer; + let server: ReflectionServerV2; + let registry: Registry; + let port: number; + let serverWs: any; + + beforeEach(() => { + return new Promise((resolve) => { + wss = new WebSocketServer({ port: 0 }); + wss.on('listening', () => { + port = (wss.address() as any).port; + resolve(); + }); + wss.on('connection', (ws) => { + serverWs = ws; + }); + registry = new Registry(); + }); + }); + + afterEach(async () => { + if (server) { + await server.stop(); + } + if (serverWs) { + serverWs.terminate(); + } + await new Promise((resolve) => { + wss.close(() => resolve()); + }); + }); + + it('should connect to the server and register', async () => { + const connected = new Promise((resolve) => { + wss.on('connection', (ws) => { + ws.on('message', (data) => { + const msg = JSON.parse(data.toString()); + if (msg.method === 'register') { + assert.strictEqual(msg.params.name, 'test-app'); + resolve(); + } + }); + }); + }); + + server = new ReflectionServerV2(registry, { + url: `ws://localhost:${port}`, + name: 'test-app', + }); + await server.start(); + await connected; + }); + + it('should handle listActions', async () => { + // Register a dummy action + const testAction = action( + { + name: 'testAction', + description: 'A test action', + inputSchema: z.object({ foo: z.string() }), + outputSchema: z.object({ bar: z.string() }), + actionType: 'custom', + }, + async (input) => ({ bar: input.foo }) + ); + registry.registerAction('custom', testAction); + + const gotListActions = new Promise((resolve) => { + wss.on('connection', (ws) => { + ws.on('message', (data) => { + const msg = JSON.parse(data.toString()); + if (msg.method === 'register') { + // After registration, request listActions + ws.send( + JSON.stringify({ + jsonrpc: '2.0', + method: 'listActions', + id: '123', + }) + ); + } else if (msg.id === '123') { + assert.ok(msg.result['/custom/testAction']); + assert.strictEqual( + msg.result['/custom/testAction'].name, + 'testAction' + ); + resolve(); + } + }); + }); + }); + + server = new ReflectionServerV2(registry, { + url: `ws://localhost:${port}`, + }); + await server.start(); + await gotListActions; + }); + + it('should handle listValues', async () => { + registry.registerValue('prompt', 'my-prompt', { template: 'foo' }); + + const gotListValues = new Promise((resolve) => { + wss.on('connection', (ws) => { + ws.on('message', (data) => { + const msg = JSON.parse(data.toString()); + if (msg.method === 'register') { + ws.send( + JSON.stringify({ + jsonrpc: '2.0', + method: 'listValues', + params: { type: 'prompt' }, + id: '124', + }) + ); + } else if (msg.id === '124') { + assert.ok(msg.result['my-prompt']); + assert.strictEqual(msg.result['my-prompt'].template, 'foo'); + resolve(); + } + }); + }); + }); + + server = new ReflectionServerV2(registry, { + url: `ws://localhost:${port}`, + }); + await server.start(); + await gotListValues; + }); + + it('should handle runAction', async () => { + const testAction = action( + { + name: 'testAction', + inputSchema: z.object({ foo: z.string() }), + outputSchema: z.object({ bar: z.string() }), + actionType: 'custom', + }, + async (input) => ({ bar: input.foo }) + ); + registry.registerAction('custom', testAction); + + const actionRun = new Promise((resolve, reject) => { + const timeout = setTimeout( + () => reject(new Error('runAction timeout')), + 2000 + ); + wss.on('connection', (ws) => { + ws.on('message', (data) => { + try { + const msg = JSON.parse(data.toString()); + if (msg.method === 'register') { + ws.send( + JSON.stringify({ + jsonrpc: '2.0', + method: 'runAction', + params: { + key: '/custom/testAction', + input: { foo: 'baz' }, + }, + id: '456', + }) + ); + } else if (msg.id === '456') { + if (msg.error) { + reject( + new Error(`runAction error: ${JSON.stringify(msg.error)}`) + ); + return; + } + assert.strictEqual(msg.result.result.bar, 'baz'); + clearTimeout(timeout); + resolve(); + } + } catch (e) { + clearTimeout(timeout); + reject(e); + } + }); + }); + }); + + server = new ReflectionServerV2(registry, { + url: `ws://localhost:${port}`, + }); + await server.start(); + await actionRun; + }); + + it('should handle streaming runAction', async () => { + const streamAction = action( + { + name: 'streamAction', + inputSchema: z.object({ foo: z.string() }), + outputSchema: z.string(), + actionType: 'custom', + }, + async (input, { sendChunk }) => { + sendChunk('chunk1'); + sendChunk('chunk2'); + return 'done'; + } + ); + registry.registerAction('custom', streamAction); + + const chunks: any[] = []; + const actionRun = new Promise((resolve, reject) => { + const timeout = setTimeout( + () => reject(new Error('streamAction timeout')), + 2000 + ); + wss.on('connection', (ws) => { + ws.on('message', (data) => { + try { + const msg = JSON.parse(data.toString()); + if (msg.method === 'register') { + ws.send( + JSON.stringify({ + jsonrpc: '2.0', + method: 'runAction', + params: { + key: '/custom/streamAction', + input: { foo: 'baz' }, + stream: true, + }, + id: '789', + }) + ); + } else if (msg.method === 'streamChunk') { + chunks.push(msg.params.chunk); + } else if (msg.id === '789') { + if (msg.error) { + reject( + new Error(`streamAction error: ${JSON.stringify(msg.error)}`) + ); + return; + } + assert.strictEqual(msg.result.result, 'done'); + assert.deepStrictEqual(chunks, ['chunk1', 'chunk2']); + clearTimeout(timeout); + resolve(); + } + } catch (e) { + clearTimeout(timeout); + reject(e); + } + }); + }); + }); + + server = new ReflectionServerV2(registry, { + url: `ws://localhost:${port}`, + }); + await server.start(); + await actionRun; + }); + + it('should handle cancelAction', async () => { + let cancelSignal: AbortSignal | undefined; + const longAction = action( + { + name: 'longAction', + inputSchema: z.any(), + outputSchema: z.any(), + actionType: 'custom', + }, + async (_, { abortSignal }) => { + cancelSignal = abortSignal; + await new Promise((resolve, reject) => { + const timer = setTimeout(resolve, 5000); + if (abortSignal.aborted) { + clearTimeout(timer); + reject(new Error('Action cancelled')); + return; + } + abortSignal.addEventListener('abort', () => { + clearTimeout(timer); + reject(new Error('Action cancelled')); + }); + }); + } + ); + registry.registerAction('custom', longAction); + + const actionCancelled = new Promise((resolve, reject) => { + const timeout = setTimeout( + () => reject(new Error('cancelAction timeout')), + 2000 + ); + wss.on('connection', (ws) => { + ws.on('message', (data) => { + try { + const msg = JSON.parse(data.toString()); + if (msg.method === 'register') { + // Start action + ws.send( + JSON.stringify({ + jsonrpc: '2.0', + method: 'runAction', + params: { + key: '/custom/longAction', + input: {}, + }, + id: '999', + }) + ); + } else if (msg.method === 'runActionState') { + // Got traceId, send cancel + const traceId = msg.params.state.traceId; + ws.send( + JSON.stringify({ + jsonrpc: '2.0', + method: 'cancelAction', + params: { traceId }, + id: '1000', + }) + ); + } else if (msg.id === '1000') { + // Cancel response + assert.strictEqual(msg.result.message, 'Action cancelled'); + } else if (msg.id === '999') { + // Run action response (should be error) + if (msg.error) { + // Ensure code indicates cancellation if possible, or just error + // In implementation we send code -32000 and message 'Action was cancelled' + assert.match(msg.error.message, /cancelled/); + assert.ok(cancelSignal?.aborted); + clearTimeout(timeout); + resolve(); + } else { + reject(new Error('Action should have failed')); + } + } + } catch (e) { + clearTimeout(timeout); + reject(e); + } + }); + }); + }); + + server = new ReflectionServerV2(registry, { + url: `ws://localhost:${port}`, + }); + await server.start(); + await actionCancelled; + }); + + it('should handle bidi streaming runAction', async () => { + defineBidiAction( + registry, + { + name: 'bidiAction', + inputSchema: z.string(), + outputSchema: z.string(), + actionType: 'custom', + }, + async function* ({ inputStream }) { + for await (const chunk of inputStream) { + yield `echo ${chunk}`; + } + return 'done'; + } + ); + + const chunks: any[] = []; + const actionRun = new Promise((resolve, reject) => { + const timeout = setTimeout( + () => reject(new Error('bidiAction timeout')), + 2000 + ); + wss.on('connection', (ws) => { + ws.on('message', (data) => { + try { + const msg = JSON.parse(data.toString()); + if (msg.method === 'register') { + ws.send( + JSON.stringify({ + jsonrpc: '2.0', + method: 'runAction', + params: { + key: '/custom/bidiAction', + stream: true, + streamInput: true, + }, + id: '111', + }) + ); + // Send input chunks shortly after + setTimeout(() => { + ws.send( + JSON.stringify({ + jsonrpc: '2.0', + method: 'streamInputChunk', + params: { requestId: '111', chunk: 'foo' }, + }) + ); + ws.send( + JSON.stringify({ + jsonrpc: '2.0', + method: 'streamInputChunk', + params: { requestId: '111', chunk: 'bar' }, + }) + ); + ws.send( + JSON.stringify({ + jsonrpc: '2.0', + method: 'endStreamInput', + params: { requestId: '111' }, + }) + ); + }, 10); + } else if (msg.method === 'streamChunk') { + chunks.push(msg.params.chunk); + } else if (msg.id === '111') { + if (msg.error) { + reject( + new Error(`bidiAction error: ${JSON.stringify(msg.error)}`) + ); + return; + } + assert.strictEqual(msg.result.result, 'done'); + assert.deepStrictEqual(chunks, ['echo foo', 'echo bar']); + clearTimeout(timeout); + resolve(); + } + } catch (e) { + clearTimeout(timeout); + reject(e); + } + }); + }); + }); + + server = new ReflectionServerV2(registry, { + url: `ws://localhost:${port}`, + }); + await server.start(); + await actionRun; + }); +}); diff --git a/js/pnpm-lock.yaml b/js/pnpm-lock.yaml index b365e14875..0bb794172b 100644 --- a/js/pnpm-lock.yaml +++ b/js/pnpm-lock.yaml @@ -144,6 +144,9 @@ importers: json-schema: specifier: ^0.4.0 version: 0.4.0 + ws: + specifier: ^8.18.0 + version: 8.18.3 zod: specifier: ^3.23.8 version: 3.25.67 @@ -157,6 +160,9 @@ importers: '@types/node': specifier: ^20.11.30 version: 20.19.1 + '@types/ws': + specifier: ^8.5.10 + version: 8.18.1 genversion: specifier: ^3.2.0 version: 3.2.0 @@ -2787,11 +2793,27 @@ packages: firebase: optional: true + '@genkit-ai/firebase@1.25.0': + resolution: {integrity: sha512-Z0FbnJHQs8qS0yxG++Dn3CZ7gv+YNaihGaWXoDKy02mNOkeRzHA6UPaWxSTaWkWHYdB0MyOnMGlyqxnWyqVdmg==} + peerDependencies: + '@google-cloud/firestore': ^7.11.0 + firebase: '>=11.5.0' + firebase-admin: '>=12.2' + genkit: ^1.25.0 + peerDependenciesMeta: + firebase: + optional: true + '@genkit-ai/google-cloud@1.16.1': resolution: {integrity: sha512-uujjdGr/sra7iKHApufwkt5jGo7CQcRCJNWPgnSg4g179CjtvtZBGjxmFRVBtKzuF61ktkY6E9JoLz83nWEyAA==} peerDependencies: genkit: ^1.16.1 + '@genkit-ai/google-cloud@1.25.0': + resolution: {integrity: sha512-wHCa8JSTv7MtwzXjUQ9AT5v0kCTJrz0In+ffgAYw1yt8ComAz5o7Ir+xks+sX1vJfN8ptvW0GUa6rsUaXCB3kA==} + peerDependencies: + genkit: ^1.25.0 + '@gerrit0/mini-shiki@1.27.2': resolution: {integrity: sha512-GeWyHz8ao2gBiUW4OJnQDxXQnFgZQwwQk05t/CVVgNBN7/rK8XZ7xY6YhLVv9tH3VppWWmr9DCl3MwemB/i+Og==} @@ -4430,6 +4452,9 @@ packages: '@types/node@20.19.1': resolution: {integrity: sha512-jJD50LtlD2dodAEO653i3YF04NWak6jN3ky+Ri3Em3mGR39/glWiboM/IePaRbgwSfqM1TpGXfAg8ohn/4dTgA==} + '@types/node@20.19.26': + resolution: {integrity: sha512-0l6cjgF0XnihUpndDhk+nyD3exio3iKaYROSgvh/qSevPXax3L8p5DBRFjbvalnwatGgHEQn2R88y2fA3g4irg==} + '@types/node@22.15.32': resolution: {integrity: sha512-3jigKqgSjsH6gYZv2nEsqdXfZqIFGAV36XYYjf9KGZ3PSG+IhLecqPnI310RvjutyMwifE2hhhNEklOUrvx/wA==} @@ -4508,6 +4533,9 @@ packages: '@types/whatwg-url@11.0.5': resolution: {integrity: sha512-coYR071JRaHa+xoEvvYqvnIHaVqaYrLPbsufM9BF63HkwI5Lgmy2QR8Q5K/lYDYo5AK82wOvSOS0UsLTpTG7uQ==} + '@types/ws@8.18.1': + resolution: {integrity: sha512-ThVF6DCVhA8kUGy+aazFQ4kXQ7E1Ty7A3ypFOe0IcJV8O/M511G99AW24irKrW56Wt44yG9+ij8FaqoBGkuBXg==} + '@types/yargs-parser@21.0.3': resolution: {integrity: sha512-I4q9QU9MQv4oEOz4tAHJtNz1cwuLxn2F3xcc2iV5WdqLPpUnj30aUuxt1mAxYTG+oe8CZMV/+6rU4S4gRDzqtQ==} @@ -9777,7 +9805,7 @@ snapshots: dependencies: '@genkit-ai/core': 1.28.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)(genkit@1.28.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)) '@opentelemetry/api': 1.9.0 - '@types/node': 20.19.1 + '@types/node': 20.19.26 colorette: 2.0.20 dotprompt: 1.1.1 json5: 2.2.3 @@ -9798,7 +9826,7 @@ snapshots: dependencies: '@genkit-ai/core': 1.28.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)(genkit@genkit) '@opentelemetry/api': 1.9.0 - '@types/node': 20.19.1 + '@types/node': 20.19.26 colorette: 2.0.20 dotprompt: 1.1.1 json5: 2.2.3 @@ -9836,7 +9864,7 @@ snapshots: zod-to-json-schema: 3.24.5(zod@3.25.67) optionalDependencies: '@cfworker/json-schema': 4.1.1 - '@genkit-ai/firebase': 1.16.1(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)(genkit@1.28.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)) + '@genkit-ai/firebase': 1.25.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)(genkit@1.28.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)) transitivePeerDependencies: - '@google-cloud/firestore' - encoding @@ -9868,7 +9896,7 @@ snapshots: zod-to-json-schema: 3.24.5(zod@3.25.67) optionalDependencies: '@cfworker/json-schema': 4.1.1 - '@genkit-ai/firebase': 1.16.1(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)(genkit@genkit) + '@genkit-ai/firebase': 1.25.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)(genkit@genkit) transitivePeerDependencies: - '@google-cloud/firestore' - encoding @@ -9900,9 +9928,22 @@ snapshots: - supports-color optional: true - '@genkit-ai/firebase@1.16.1(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)(genkit@genkit)': + '@genkit-ai/firebase@1.25.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)(genkit@1.28.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1))': dependencies: - '@genkit-ai/google-cloud': 1.16.1(encoding@0.1.13)(genkit@genkit) + '@genkit-ai/google-cloud': 1.25.0(encoding@0.1.13)(genkit@1.28.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)) + '@google-cloud/firestore': 7.11.1(encoding@0.1.13) + firebase-admin: 13.6.0(encoding@0.1.13) + genkit: 1.28.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1) + optionalDependencies: + firebase: 11.9.1 + transitivePeerDependencies: + - encoding + - supports-color + optional: true + + '@genkit-ai/firebase@1.25.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)(genkit@genkit)': + dependencies: + '@genkit-ai/google-cloud': 1.25.0(encoding@0.1.13)(genkit@genkit) '@google-cloud/firestore': 7.11.1(encoding@0.1.13) firebase-admin: 13.6.0(encoding@0.1.13) genkit: link:genkit @@ -9938,7 +9979,32 @@ snapshots: - supports-color optional: true - '@genkit-ai/google-cloud@1.16.1(encoding@0.1.13)(genkit@genkit)': + '@genkit-ai/google-cloud@1.25.0(encoding@0.1.13)(genkit@1.28.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1))': + dependencies: + '@google-cloud/logging-winston': 6.0.1(encoding@0.1.13)(winston@3.17.0) + '@google-cloud/opentelemetry-cloud-monitoring-exporter': 0.19.0(@opentelemetry/api@1.9.0)(@opentelemetry/core@1.25.1(@opentelemetry/api@1.9.0))(@opentelemetry/resources@1.25.1(@opentelemetry/api@1.9.0))(@opentelemetry/sdk-metrics@1.25.1(@opentelemetry/api@1.9.0))(encoding@0.1.13) + '@google-cloud/opentelemetry-cloud-trace-exporter': 2.4.1(@opentelemetry/api@1.9.0)(@opentelemetry/core@1.25.1(@opentelemetry/api@1.9.0))(@opentelemetry/resources@1.25.1(@opentelemetry/api@1.9.0))(@opentelemetry/sdk-trace-base@1.25.1(@opentelemetry/api@1.9.0))(encoding@0.1.13) + '@google-cloud/opentelemetry-resource-util': 2.4.0(@opentelemetry/resources@1.25.1(@opentelemetry/api@1.9.0))(encoding@0.1.13) + '@opentelemetry/api': 1.9.0 + '@opentelemetry/auto-instrumentations-node': 0.49.2(@opentelemetry/api@1.9.0)(encoding@0.1.13) + '@opentelemetry/core': 1.25.1(@opentelemetry/api@1.9.0) + '@opentelemetry/instrumentation': 0.52.1(@opentelemetry/api@1.9.0) + '@opentelemetry/instrumentation-pino': 0.41.0(@opentelemetry/api@1.9.0) + '@opentelemetry/instrumentation-winston': 0.39.0(@opentelemetry/api@1.9.0) + '@opentelemetry/resources': 1.25.1(@opentelemetry/api@1.9.0) + '@opentelemetry/sdk-metrics': 1.25.1(@opentelemetry/api@1.9.0) + '@opentelemetry/sdk-node': 0.52.1(@opentelemetry/api@1.9.0) + '@opentelemetry/sdk-trace-base': 1.25.1(@opentelemetry/api@1.9.0) + genkit: 1.28.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1) + google-auth-library: 9.15.1(encoding@0.1.13) + node-fetch: 3.3.2 + winston: 3.17.0 + transitivePeerDependencies: + - encoding + - supports-color + optional: true + + '@genkit-ai/google-cloud@1.25.0(encoding@0.1.13)(genkit@genkit)': dependencies: '@google-cloud/logging-winston': 6.0.1(encoding@0.1.13)(winston@3.17.0) '@google-cloud/opentelemetry-cloud-monitoring-exporter': 0.19.0(@opentelemetry/api@1.9.0)(@opentelemetry/core@1.25.1(@opentelemetry/api@1.9.0))(@opentelemetry/resources@1.25.1(@opentelemetry/api@1.9.0))(@opentelemetry/sdk-metrics@1.25.1(@opentelemetry/api@1.9.0))(encoding@0.1.13) @@ -11683,7 +11749,7 @@ snapshots: '@types/graceful-fs@4.1.9': dependencies: - '@types/node': 20.19.1 + '@types/node': 20.19.26 '@types/handlebars@4.1.0': dependencies: @@ -11757,6 +11823,10 @@ snapshots: dependencies: undici-types: 6.21.0 + '@types/node@20.19.26': + dependencies: + undici-types: 6.21.0 + '@types/node@22.15.32': dependencies: undici-types: 6.21.0 @@ -11841,6 +11911,10 @@ snapshots: dependencies: '@types/webidl-conversions': 7.0.3 + '@types/ws@8.18.1': + dependencies: + '@types/node': 20.19.26 + '@types/yargs-parser@21.0.3': {} '@types/yargs@17.0.33': @@ -14335,7 +14409,7 @@ snapshots: '@jest/expect': 29.7.0 '@jest/test-result': 29.7.0 '@jest/types': 29.6.3 - '@types/node': 20.19.1 + '@types/node': 20.19.26 chalk: 4.1.2 co: 4.6.0 dedent: 1.5.3 @@ -14479,7 +14553,7 @@ snapshots: '@jest/environment': 29.7.0 '@jest/fake-timers': 29.7.0 '@jest/types': 29.6.3 - '@types/node': 20.19.1 + '@types/node': 20.19.26 jest-mock: 29.7.0 jest-util: 29.7.0 @@ -14665,7 +14739,7 @@ snapshots: jest-worker@29.7.0: dependencies: - '@types/node': 20.19.1 + '@types/node': 20.19.26 jest-util: 29.7.0 merge-stream: 2.0.0 supports-color: 8.1.1 @@ -15903,7 +15977,7 @@ snapshots: '@protobufjs/path': 1.1.2 '@protobufjs/pool': 1.1.0 '@protobufjs/utf8': 1.1.0 - '@types/node': 20.19.1 + '@types/node': 20.19.26 long: 5.3.2 proxy-addr@2.0.7: