Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion frontend/src/core/dom/events.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ export const MarimoIncomingMessageEvent = defineCustomEvent(
)<{
objectId: UIElementId;
message: unknown;
buffers: DataView[] | undefined;
buffers: readonly DataView[];
}>();
export type MarimoIncomingMessageEventType = ReturnType<
typeof MarimoIncomingMessageEvent.create
Expand Down
10 changes: 2 additions & 8 deletions frontend/src/core/dom/uiregistry.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
/* Copyright 2024 Marimo. All rights reserved. */

import { byteStringToDataView } from "@/utils/data-views";
import type { Base64String } from "@/utils/json/base64";
import { typedAtob } from "@/utils/json/base64";
import { Logger } from "@/utils/Logger";
import type { CellId, UIElementId } from "../cells/ids";
import {
Expand Down Expand Up @@ -136,15 +133,12 @@ export class UIElementRegistry {
broadcastMessage(
objectId: UIElementId,
message: unknown,
buffers: Base64String[] | undefined | null,
buffers: readonly DataView[],
): void {
const entry = this.entries.get(objectId);
if (entry === undefined) {
Logger.warn("UIElementRegistry missing entry", objectId);
} else {
const toDataView = (base64: Base64String) => {
return byteStringToDataView(typedAtob(base64));
};
entry.elements.forEach((element) => {
element.dispatchEvent(
MarimoIncomingMessageEvent.create({
Expand All @@ -153,7 +147,7 @@ export class UIElementRegistry {
detail: {
objectId: objectId,
message: message,
buffers: buffers ? buffers.map(toDataView) : undefined,
buffers: buffers,
},
}),
);
Expand Down
4 changes: 2 additions & 2 deletions frontend/src/core/islands/main.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import { renderHTML } from "@/plugins/core/RenderHTML";
import { initializePlugins } from "@/plugins/plugins";
import { logNever } from "@/utils/assertNever";
import { Functions } from "@/utils/functions";
import type { Base64String } from "@/utils/json/base64";
import { safeExtractSetUIElementMessageBuffers } from "@/utils/json/base64";
import { jsonParseWithSpecialChar } from "@/utils/json/json-parser";
import { Logger } from "@/utils/Logger";
import {
Expand Down Expand Up @@ -145,7 +145,7 @@ export async function initialize() {
UI_ELEMENT_REGISTRY.broadcastMessage(
msg.data.ui_element as UIElementId,
msg.data.message,
msg.data.buffers as Base64String[],
safeExtractSetUIElementMessageBuffers(msg.data),
);
return;

Expand Down
7 changes: 5 additions & 2 deletions frontend/src/core/websocket/useMarimoWebSocket.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@ import {
} from "@/plugins/impl/anywidget/model";
import { logNever } from "@/utils/assertNever";
import { prettyError } from "@/utils/errors";
import type { Base64String, JsonString } from "@/utils/json/base64";
import {
type JsonString,
safeExtractSetUIElementMessageBuffers,
} from "@/utils/json/base64";
import { jsonParseWithSpecialChar } from "@/utils/json/json-parser";
import { Logger } from "@/utils/Logger";
import { reloadSafe } from "@/utils/reload-safe";
Expand Down Expand Up @@ -112,7 +115,7 @@ export function useMarimoWebSocket(opts: {
const modelId = msg.data.model_id;
const uiElement = msg.data.ui_element;
const message = msg.data.message;
const buffers = (msg.data.buffers ?? []) as Base64String[];
const buffers = safeExtractSetUIElementMessageBuffers(msg.data);

if (modelId && isMessageWidgetState(message)) {
handleWidgetMessage({
Expand Down
32 changes: 25 additions & 7 deletions frontend/src/plugins/impl/anywidget/AnyWidgetPlugin.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
/* eslint-disable @typescript-eslint/no-explicit-any */

import type { AnyWidget, Experimental } from "@anywidget/types";
import { isEqual } from "lodash-es";
import { get, isEqual, set } from "lodash-es";
import { useEffect, useMemo, useRef } from "react";
import { z } from "zod";
import { MarimoIncomingMessageEvent } from "@/core/dom/events";
Expand All @@ -16,7 +16,11 @@ import {
import { createPlugin } from "@/plugins/core/builder";
import { rpc } from "@/plugins/core/rpc";
import type { IPluginProps } from "@/plugins/types";
import { updateBufferPaths } from "@/utils/data-views";
import {
type Base64String,
byteStringToBinary,
typedAtob,
} from "@/utils/json/base64";
import { Logger } from "@/utils/Logger";
import { ErrorBanner } from "../common/error-banner";
import { MODEL_MANAGER, Model } from "./model";
Expand Down Expand Up @@ -59,6 +63,11 @@ type Props = IPluginProps<T, Data, PluginFunctions>;

const AnyWidgetSlot = (props: Props) => {
const { css, jsUrl, jsHash, bufferPaths } = props.data;

const valueWithBuffers = useMemo(() => {
return resolveInitialValue(props.value, bufferPaths ?? []);
}, [props.value, bufferPaths]);

// JS is an ESM file with a render function on it
// export function render({ model, el }) {
// ...
Expand All @@ -85,10 +94,6 @@ const AnyWidgetSlot = (props: Props) => {
}
}, [hasError, jsUrl]);

const valueWithBuffer = useMemo(() => {
return updateBufferPaths(props.value, bufferPaths);
}, [props.value, bufferPaths]);

// Mount the CSS
useEffect(() => {
const shadowRoot = props.host.shadowRoot;
Expand Down Expand Up @@ -157,7 +162,7 @@ const AnyWidgetSlot = (props: Props) => {
key={key}
{...props}
widget={module.default}
value={valueWithBuffer}
value={valueWithBuffers}
/>
);
};
Expand Down Expand Up @@ -284,3 +289,16 @@ export const visibleForTesting = {
isAnyWidgetModule,
getDirtyFields,
};

export function resolveInitialValue(
raw: Record<string, any>,
bufferPaths: ReadonlyArray<ReadonlyArray<string | number>>,
) {
const out = structuredClone(raw);
for (const bufferPath of bufferPaths) {
const base64String: Base64String = get(raw, bufferPath);
const bytes = byteStringToBinary(typedAtob(base64String));
set(out, bufferPath, new DataView(bytes.buffer));
}
return out;
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@ import { beforeEach, describe, expect, it, vi } from "vitest";
import { TestUtils } from "@/__tests__/test-helpers";
import type { UIElementId } from "@/core/cells/ids";
import { MarimoIncomingMessageEvent } from "@/core/dom/events";
import { getDirtyFields, visibleForTesting } from "../AnyWidgetPlugin";
import {
getDirtyFields,
resolveInitialValue,
visibleForTesting,
} from "../AnyWidgetPlugin";
import { Model } from "../model";

const { LoadedSlot } = visibleForTesting;
Expand Down Expand Up @@ -129,7 +133,7 @@ describe("LoadedSlot", () => {
method: "update",
state: { count: 10 },
},
buffers: undefined,
buffers: [],
},
bubbles: false,
composed: true,
Expand Down Expand Up @@ -179,3 +183,55 @@ describe("LoadedSlot", () => {
});
});
});

describe("resolveInitialValue", () => {
it("should convert base64 strings to DataView at specified paths", () => {
const result = resolveInitialValue(
{
a: 10,
b: "aGVsbG8=", // "hello" in base64
c: [1, "d29ybGQ="], // "world" in base64
d: {
foo: "bWFyaW1vCg==", // "marimo" in base64
baz: 20,
},
},
[["b"], ["c", 1], ["d", "foo"]],
);

expect(result).toMatchInlineSnapshot(`
{
"a": 10,
"b": DataView [
104,
101,
108,
108,
111,
],
"c": [
1,
DataView [
119,
111,
114,
108,
100,
],
],
"d": {
"baz": 20,
"foo": DataView [
109,
97,
114,
105,
109,
111,
10,
],
},
}
`);
});
});
7 changes: 3 additions & 4 deletions frontend/src/plugins/impl/anywidget/__tests__/model.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import {
vi,
} from "vitest";
import { TestUtils } from "@/__tests__/test-helpers";
import type { Base64String } from "@/utils/json/base64";
import {
type AnyWidgetMessage,
handleWidgetMessage,
Expand Down Expand Up @@ -246,7 +245,7 @@ describe("Model", () => {
content,
});

expect(callback).toHaveBeenCalledWith(content, undefined);
expect(callback).toHaveBeenCalledWith(content, []);
});

it("should handle custom messages with buffers", () => {
Expand Down Expand Up @@ -286,7 +285,7 @@ describe("ModelManager", () => {
}: {
modelId: string;
message: AnyWidgetMessage;
buffers: Base64String[];
buffers: readonly DataView[];
}) => {
return handleWidgetMessage({
modelId,
Expand Down Expand Up @@ -353,7 +352,7 @@ describe("ModelManager", () => {
message: { method: "custom", content: { count: 1 } },
buffers: [],
});
expect(callback).toHaveBeenCalledWith({ count: 1 }, undefined);
expect(callback).toHaveBeenCalledWith({ count: 1 }, []);
});

it("should handle close messages", async () => {
Expand Down
5 changes: 2 additions & 3 deletions frontend/src/plugins/impl/anywidget/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import { assertNever } from "@/utils/assertNever";
import { Deferred } from "@/utils/Deferred";
import { updateBufferPaths } from "@/utils/data-views";
import { throwNotImplemented } from "@/utils/functions";
import type { Base64String } from "@/utils/json/base64";
import { Logger } from "@/utils/Logger";

export type EventHandler = (...args: any[]) => void;
Expand Down Expand Up @@ -171,7 +170,7 @@ export class Model<T extends Record<string, any>> implements AnyModel<T> {
* When receiving a message from the backend.
* We want to notify all listeners with `msg:custom`
*/
receiveCustomMessage(message: any, buffers?: DataView[]): void {
receiveCustomMessage(message: any, buffers: readonly DataView[] = []): void {
const response = AnyWidgetMessageSchema.safeParse(message);
if (response.success) {
const data = response.data;
Expand Down Expand Up @@ -260,7 +259,7 @@ export async function handleWidgetMessage({
}: {
modelId: string;
msg: AnyWidgetMessage;
buffers: Base64String[];
buffers: readonly DataView[];
modelManager: ModelManager;
}): Promise<void> {
if (msg.method === "echo_update") {
Expand Down
Loading
Loading