Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion frontend/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
"@lezer/lr": "^1.4.2",
"@lezer/markdown": "^1.4.3",
"@lezer/python": "^1.1.18",
"@marimo-team/codemirror-ai": "^0.1.11",
"@marimo-team/codemirror-ai": "^0.2.1",
"@marimo-team/codemirror-languageserver": "^1.15.24",
"@marimo-team/marimo-api": "workspace:*",
"@marimo-team/react-slotz": "^0.2.0",
Expand Down
23 changes: 23 additions & 0 deletions frontend/src/components/app-config/user-config-form.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -1628,6 +1628,29 @@ export const UserConfigForm: React.FC = () => {
</div>
)}
/>
<FormField
  control={form.control}
  name="experimental.next_edit_prediction"
  render={({ field }) => (
    <div className="flex flex-col gap-y-1">
      <FormItem className={formItemClasses}>
        <FormLabel className="font-normal">
          Next Edit Prediction
        </FormLabel>
        <FormControl>
          <Checkbox
            {/* Unique testid; the original duplicated
                "inline-ai-checkbox" from the field above. */}
            data-testid="next-edit-prediction-checkbox"
            checked={field.value === true}
            onCheckedChange={field.onChange}
          />
        </FormControl>
      </FormItem>
      <FormDescription>
        Enable experimental next edit prediction.
      </FormDescription>
    </div>
  )}
/>
<FormField
control={form.control}
name="experimental.rtc_v2"
Expand Down
85 changes: 52 additions & 33 deletions frontend/src/core/codemirror/copilot/extension.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import type { EditorView } from "@codemirror/view";
import { keymap } from "@codemirror/view";
import {
inlineCompletion,
nextEditPrediction,
PredictionBackend,
rejectInlineCompletion,
} from "@marimo-team/codemirror-ai";
import {
Expand All @@ -24,8 +26,10 @@ import {
InlineCompletionTriggerKind,
} from "vscode-languageserver-protocol";
import type { CompletionConfig } from "@/core/config/config-schema";
import { getFeatureFlag } from "@/core/config/feature-flag";
import type { AiInlineCompletionRequest } from "@/core/kernel/messages";
import { API } from "@/core/network/api";
import { asRemoteURL } from "@/core/runtime/config";
import { Logger } from "@/utils/Logger";
import { languageAdapterState } from "../language/extension";
import { isInVimMode } from "../utils";
Expand Down Expand Up @@ -104,44 +108,59 @@ export const copilotBundle = (config: CompletionConfig): Extension => {
}

if (config.copilot === "custom") {
extensions.push(
inlineCompletion({
...commonInlineCompletionConfig,
fetchFn: async (state) => {
if (state.doc.length === 0) {
return "";
}
if (getFeatureFlag("next_edit_prediction")) {
extensions.push(
nextEditPrediction({
fetchFn: PredictionBackend.oxen({
model: "openai/oxen:ox-cold-olive-fox",
baseUrl: asRemoteURL("/api/ai/compat/chat").toString(),
headers: {
...API.headers(),
"x-marimo-ai-scope": "next-edit-prediction",
},
}),
}),
);
} else {
extensions.push(
inlineCompletion({
...commonInlineCompletionConfig,
fetchFn: async (state) => {
if (state.doc.length === 0) {
return "";
}

// If not focused, don't fetch
const prefix = state.doc.sliceString(0, state.selection.main.head);
const suffix = state.doc.sliceString(
state.selection.main.head,
state.doc.length,
);
// If not focused, don't fetch
const prefix = state.doc.sliceString(0, state.selection.main.head);
const suffix = state.doc.sliceString(
state.selection.main.head,
state.doc.length,
);

// If no prefix, don't fetch
if (prefix.length === 0) {
return "";
}
// If no prefix, don't fetch
if (prefix.length === 0) {
return "";
}

const language = state.field(languageAdapterState).type;
let res = await API.post<AiInlineCompletionRequest, string>(
"/ai/inline_completion",
{ prefix, suffix, language },
);
const language = state.field(languageAdapterState).type;
let res = await API.post<AiInlineCompletionRequest, string>(
"/ai/inline_completion",
{ prefix, suffix, language },
);

// Sometimes the prefix might get included in the response, so we need to trim it
if (prefix && res.startsWith(prefix)) {
res = res.slice(prefix.length);
}
if (suffix && res.endsWith(suffix)) {
res = res.slice(0, -suffix.length);
}
// Sometimes the prefix might get included in the response, so we need to trim it
if (prefix && res.startsWith(prefix)) {
res = res.slice(prefix.length);
}
if (suffix && res.endsWith(suffix)) {
res = res.slice(0, -suffix.length);
}

return res;
},
}),
);
return res;
},
}),
);
}
}

return [
Expand Down
2 changes: 2 additions & 0 deletions frontend/src/core/config/feature-flag.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import { getResolvedMarimoConfig } from "./config";
export interface ExperimentalFeatures {
markdown: boolean; // Used in playground (community cloud)
inline_ai_tooltip: boolean;
next_edit_prediction: boolean;
wasm_layouts: boolean; // Used in playground (community cloud)
rtc_v2: boolean;
performant_table_charts: boolean;
Expand All @@ -20,6 +21,7 @@ const defaultValues: ExperimentalFeatures = {
markdown: true,
inline_ai_tooltip: import.meta.env.DEV,
wasm_layouts: false,
next_edit_prediction: false,
rtc_v2: false,
performant_table_charts: false,
mcp_docs: false,
Expand Down
126 changes: 125 additions & 1 deletion marimo/_server/api/endpoints/ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from __future__ import annotations

from dataclasses import asdict
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, NoReturn

from starlette.authentication import requires
from starlette.exceptions import HTTPException
Expand Down Expand Up @@ -66,6 +66,45 @@ def get_ai_config(config: MarimoConfig) -> AiConfig:
return ai_config


def parse_model_format(model: str) -> tuple[str, str]:
    """Parse model format <format>/<model> and return (format, model_name).

    Args:
        model: Model string in format 'format/model' (e.g., 'openai/gpt-4',
            'anthropic/claude-3'). Only the first '/' delimits the provider,
            so model names may themselves contain slashes.

    Returns:
        Tuple of (format, model_name)

    Raises:
        HTTPException: If model format is invalid or format is not supported
    """

    def raise_error() -> NoReturn:
        raise HTTPException(
            status_code=HTTPStatus.BAD_REQUEST,
            detail=f"Invalid model format '{model}'. Expected format: '<provider>/<model>' (e.g., 'openai/gpt-4', 'anthropic/claude-3', 'google/gemini-pro')",
        )

    # A single split covers the no-slash case too: without a '/',
    # split() yields one part, which fails the two-non-empty-parts
    # check below (the original had a redundant pre-check).
    parts = model.split("/", 1)
    if len(parts) != 2 or not parts[0] or not parts[1]:
        raise_error()

    format_name, model_name = parts

    # Only these provider prefixes are recognized by this server.
    supported_formats = {"openai", "anthropic", "google", "bedrock"}
    if format_name not in supported_formats:
        raise HTTPException(
            status_code=HTTPStatus.BAD_REQUEST,
            detail=f"Unsupported provider '{format_name}'. Supported providers: {', '.join(sorted(supported_formats))}",
        )

    return format_name, model_name


@router.post("/completion")
@requires("edit")
async def ai_completion(
Expand Down Expand Up @@ -283,3 +322,88 @@ async def invoke_tool(
)

return JSONResponse(content=asdict(response))


@router.post("/compat/chat/completions")
@requires("edit")
async def compat_chat_completions(
    *,
    request: Request,
) -> Any:
    """
    OpenAI-compatible chat completions endpoint.

    requestBody:
        description: The request body for chat completions
        required: true
        content:
            application/json:
                schema:
                    $ref: "#/components/schemas/ChatCompletionRequest"
    responses:
        200:
            description: Chat completion response
            content:
                application/json:
                    schema:
                        type: object
                        properties:
                            choices:
                                type: array
                                items:
                                    type: object
                                additionalProperties: true
    """
    app_state = AppState(request)
    app_state.require_current_session()
    config = app_state.app_config_manager.get_config(hide_secrets=False)
    body = await request.json()
    ai_config = get_ai_config(config)

    full_model = body.pop("model", None)
    if full_model is None:
        raise HTTPException(
            status_code=HTTPStatus.BAD_REQUEST,
            detail="Model is required",
        )
    # Parse and validate model format ('<provider>/<model>')
    format_name, model_name = parse_model_format(full_model)

    # Config selection: next-edit-prediction requests (identified by a
    # custom header set by the frontend) reuse the copilot completion
    # config; all other requests resolve a provider from the AI config.
    scope = request.headers.get("x-marimo-ai-scope", None)
    if scope == "next-edit-prediction":
        provider_config = AnyProviderConfig.for_completion(
            config["completion"],
        )
    else:
        provider_config = AnyProviderConfig.for_model(model_name, ai_config)

    # Only the OpenAI wire format is implemented by this endpoint.
    supported_providers = {"openai"}
    if format_name not in supported_providers:
        raise HTTPException(
            status_code=HTTPStatus.BAD_REQUEST,
            detail=f"Unsupported provider '{format_name}'. Supported providers: {', '.join(sorted(supported_providers))}",
        )

    # Imported lazily so the server can start without the optional
    # 'openai' dependency installed.
    from openai import AsyncOpenAI

    client = AsyncOpenAI(
        api_key=provider_config.api_key,
        base_url=provider_config.base_url,
    )

    stream = body.pop("stream", False)
    if stream:
        # Fix: the original passed **stream (a bool) instead of **body,
        # and returned the un-awaited coroutine directly from the
        # endpoint. Stream back Server-Sent Events in OpenAI's format.
        import json

        from starlette.responses import StreamingResponse

        async def event_stream():
            response = await client.chat.completions.create(
                **body,
                model=model_name,
                stream=True,
            )
            async for chunk in response:
                yield f"data: {json.dumps(chunk.to_dict())}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(
            event_stream(), media_type="text/event-stream"
        )

    response = await client.chat.completions.create(
        **body,
        model=model_name,
        stream=False,
    )
    return JSONResponse(content=response.to_dict())
17 changes: 12 additions & 5 deletions pnpm-lock.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading