Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
298 changes: 120 additions & 178 deletions marimo/_plugins/stateless/mpl/_mpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,19 @@

import asyncio
import io
import json
import mimetypes
import socket
import threading
from pathlib import Path
from typing import Any, Optional
from typing import Any, Optional, Tuple

import tornado
import tornado.httpserver
import tornado.ioloop
import tornado.netutil
import tornado.web
import tornado.websocket
import uvicorn
from matplotlib.figure import Figure
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import HTMLResponse, Response
from starlette.routing import Mount, Route, WebSocketRoute
from starlette.staticfiles import StaticFiles
from starlette.websockets import WebSocket

from marimo._output.builder import h
from marimo._output.hypertext import Html
Expand All @@ -32,160 +32,106 @@
RuntimeContext,
get_context,
)
from marimo._server.utils import find_free_port


def create_application(
figure: Any,
host: str,
port: int,
) -> Starlette:
import matplotlib as mpl # type: ignore[import-not-found,import-untyped,unused-ignore] # noqa: E501
from matplotlib.backends.backend_webagg import (
FigureManagerWebAgg,
new_figure_manager_given_figure,
)


class MplApplication(tornado.web.Application):
# Figure Manager, Any type because matplotlib doesn't have typings
manager: Any

class MainPage(tornado.web.RequestHandler):
"""
Serves the main HTML page.
"""

application: MplApplication

def get(self) -> None:
manager = self.application.manager
ws_uri = f"ws://{self.request.host}/"
content = html_content % {
"ws_uri": ws_uri,
"fig_id": manager.num,
"custom_css": css_content,
}
self.write(content)

class MplJs(tornado.web.RequestHandler):
"""
Serves the generated matplotlib javascript file. The content
is dynamically generated based on which toolbar functions the
user has defined. Call `FigureManagerWebAgg` to get its
content.
"""

application: MplApplication

def get(self) -> None:
from matplotlib.backends.backend_webagg import ( # type: ignore
FigureManagerWebAgg,
)

self.set_header("Content-Type", "application/javascript")
js_content = FigureManagerWebAgg.get_javascript()
manager: Any = new_figure_manager_given_figure(id(figure), figure)

self.write(js_content)
async def main_page(request: Request):
ws_uri = f"ws://{host}:{port}/ws"

class Download(tornado.web.RequestHandler):
"""
Handles downloading of the figure in various file formats.
"""

application: MplApplication
content = html_content % {
"ws_uri": ws_uri,
"fig_id": manager.num,
"custom_css": css_content,
}
# return HTMLResponse(content="Hello World")
return HTMLResponse(content=content)

def get(self, fmt: str) -> None:
manager = self.application.manager
self.set_header(
"Content-Type", mimetypes.types_map.get(fmt, "binary")
)
buff = io.BytesIO()
manager.canvas.figure.savefig(buff, format=fmt)
self.write(buff.getvalue())

class WebSocket(tornado.websocket.WebSocketHandler):
"""
A websocket for interactive communication between the plot in
the browser and the server.

In addition to the methods required by tornado, it is required to
have two callback methods:

- ``send_json(json_content)`` is called by matplotlib when
it needs to send json to the browser. `json_content` is
a JSON tree (Python dictionary), and it is the responsibility
of this implementation to encode it as a string to send over
the socket.

- ``send_binary(blob)`` is called to send binary image data
to the browser.
"""

application: MplApplication
supports_binary = True

def open(self, *args: str, **kwargs: str) -> None:
del args
del kwargs
# Register the websocket with the FigureManager.
manager = self.application.manager
manager.add_web_socket(self)
if hasattr(self, "set_nodelay"):
self.set_nodelay(True)

def on_close(self) -> None:
# When the socket is closed, deregister the websocket with
# the FigureManager.
manager = self.application.manager
manager.remove_web_socket(self)

def on_message(self, message: Any) -> None:
# The 'supports_binary' message is relevant to the
# websocket itself. The other messages get passed along
# to matplotlib as-is.

# Every message has a "type" and a "figure_id".
message = json.loads(message)
if message["type"] == "supports_binary":
self.supports_binary = message["value"]
else:
manager = self.application.manager
manager.handle_json(message)

def send_json(self, content: str) -> None:
self.write_message(json.dumps(content))

def send_binary(self, blob: Any) -> None:
if self.supports_binary:
self.write_message(blob, binary=True)
else:
data_uri = "data:image/png;base64," + blob.encode(
"base64"
).replace("\n", "")
self.write_message(data_uri)

def __init__(self, figure: Any) -> None:
import matplotlib as mpl # type: ignore[import-not-found,import-untyped,unused-ignore] # noqa: E501
from matplotlib.backends.backend_webagg import (
FigureManagerWebAgg,
new_figure_manager_given_figure,
async def mpl_js(request: Request):
return Response(
content=FigureManagerWebAgg.get_javascript(),
media_type="application/javascript",
)

self.figure = figure
self.manager = new_figure_manager_given_figure(id(figure), figure)

super().__init__(
[
# Static files for the CSS and JS
(
r"/mpl/_static/(.*)",
tornado.web.StaticFileHandler,
{"path": FigureManagerWebAgg.get_static_file_path()},
async def download(request: Request):
fmt = request.path_params["fmt"]
mime_type = mimetypes.types_map.get(fmt, "binary")
buff = io.BytesIO()
manager.canvas.figure.savefig(buff, format=fmt)
return Response(content=buff.getvalue(), media_type=mime_type)

async def websocket_endpoint(websocket: WebSocket):
await websocket.accept()

queue = asyncio.Queue[Tuple[Any, str]]()

class SyncWebSocket:
def send_json(self, content: str) -> None:
queue.put_nowait((content, "json"))

def send_binary(self, blob: Any) -> None:
queue.put_nowait((blob, "binary"))

manager.add_web_socket(SyncWebSocket())

async def receive() -> None:
try:
while True:
data = await websocket.receive_json()
manager.handle_json(data)
except Exception:
pass
finally:
await websocket.close()

async def send() -> None:
try:
while True:
(data, mode) = await queue.get()
if mode == "json":
await websocket.send_json(data)
else:
await websocket.send_bytes(data)
except Exception:
pass
finally:
await websocket.close()

await asyncio.gather(receive(), send())

return Starlette(
routes=[
Route("/", main_page, methods=["GET"]),
Route("/mpl/mpl.js", mpl_js, methods=["GET"]),
Route("/download.{fmt}", download, methods=["GET"]),
WebSocketRoute("/ws", websocket_endpoint),
Mount(
"/mpl/_static",
StaticFiles(
directory=FigureManagerWebAgg.get_static_file_path()
),
# Static images for the toolbar
(
r"/_images/(.*)",
tornado.web.StaticFileHandler,
{"path": Path(mpl.get_data_path(), "images")},
),
# The page that contains all of the pieces
("/", self.MainPage),
("/mpl/mpl.js", self.MplJs),
# Sends images and events to the browser, and receives
# events from the browser
("/ws", self.WebSocket),
# Handles the downloading (i.e., saving) of static images
(r"/download.([a-z0-9.]+)", self.Download),
]
)
name="mpl_static",
),
Mount(
"/_images",
StaticFiles(directory=Path(mpl.get_data_path(), "images")),
name="images",
),
]
)


class CleanupHandle(CellLifecycleItem):
Expand Down Expand Up @@ -248,39 +194,29 @@ def interactive(figure: "Figure | Axes") -> Html: # type: ignore[name-defined]
"marimo.mpl.interactive can't be used when running as a script."
) from err

host = "localhost"
port = find_free_port(10_000)

# TODO(akshayka): Proxy this server through the marimo server to help with
# deployment.
application = MplApplication(figure)
application = create_application(figure, host, port)
cleanup_handle = CleanupHandle()
sockets = tornado.netutil.bind_sockets(0, "")

async def main() -> None:
# create the shutdown event in the coroutine for py3.8, 3.9 compat
cleanup_handle.shutdown_event = asyncio.Event()
http_server = tornado.httpserver.HTTPServer(application)
http_server.add_sockets(sockets)
await cleanup_handle.shutdown_event.wait()
http_server.stop()
await http_server.close_all_connections()

def start_server() -> None:
asyncio.run(main())

addr: Optional[str] = None
port: Optional[int] = None
for s in sockets:
addr, port = s.getsockname()[:2]
if s.family is socket.AF_INET6:
addr = f"[{addr}]"
if addr is None or port is None:
raise RuntimeError("Failed to create sockets for mpl.interactive.")
uvicorn.Server(
uvicorn.Config(
app=application,
port=port,
host=host,
)
).run()

assert ctx.kernel.execution_context is not None
ctx.cell_lifecycle_registry.add(cleanup_handle)
threading.Thread(target=start_server).start()
return Html(
h.iframe(
src=f"http://{addr}:{port}/",
src=f"http://{host}:{port}/",
width="100%",
height="550px",
)
Expand Down Expand Up @@ -316,7 +252,7 @@ def start_server() -> None:
ready(
function() {
var websocket_type = mpl.get_websocket_type();
var websocket = new websocket_type("%(ws_uri)sws");
var websocket = new websocket_type("%(ws_uri)s");

// mpl.figure creates a new figure on the webpage.
var fig = new mpl.figure(
Expand Down Expand Up @@ -347,6 +283,12 @@ def start_server() -> None:
css_content = """
body {
background-color: transparent;
height: 400px;
width: 100%;
}
#figure, mlp-canvas {
height: 400px;
width: 100%;
}
.ui-dialog-titlebar + div {
border-radius: 4px;
Expand Down
4 changes: 3 additions & 1 deletion marimo/_server2/api/endpoints/assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ async def index(request: Request):
title = parse_title(app_state.filename)
user_config = get_configuration()
app_config = (
app_state.app_config.dict() if app_state.app_config is not None else {}
app_state.app_config.asdict()
if app_state.app_config is not None
else {}
)

index_html = os.path.join(root, "index.html")
Expand Down