Skip to content

Commit 92ae7bb

Browse files
committed
tests: mo.ui.plotly tests
1 parent 23900ec commit 92ae7bb

File tree

1 file changed

+364
-0
lines changed

1 file changed

+364
-0
lines changed
Lines changed: 364 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,364 @@
1+
# Copyright 2024 Marimo. All rights reserved.
2+
from __future__ import annotations
3+
4+
from typing import Any
5+
6+
import pytest
7+
8+
pytest.importorskip("plotly.express")
9+
pytest.importorskip("plotly.graph_objects")
10+
11+
import plotly.express as px
12+
import plotly.graph_objects as go
13+
14+
from marimo._plugins.ui._impl.plotly import plotly
15+
16+
17+
class TestPlotly:
18+
@staticmethod
19+
def test_basic_scatter_plot() -> None:
20+
"""Test creating a basic scatter plot."""
21+
fig = go.Figure(
22+
data=go.Scatter(x=[1, 2, 3], y=[4, 5, 6], mode="markers")
23+
)
24+
plot = plotly(fig)
25+
26+
assert plot is not None
27+
assert plot.value == []
28+
assert plot.ranges == {}
29+
assert plot.points == []
30+
assert plot.indices == []
31+
32+
@staticmethod
33+
def test_plotly_express_scatter() -> None:
34+
"""Test creating a plot with plotly express."""
35+
import pandas as pd
36+
37+
df = pd.DataFrame(
38+
{"x": [1, 2, 3], "y": [4, 5, 6], "color": ["A", "B", "A"]}
39+
)
40+
fig = px.scatter(df, x="x", y="y", color="color")
41+
plot = plotly(fig)
42+
43+
assert plot is not None
44+
assert plot.value == []
45+
46+
@staticmethod
47+
def test_plotly_with_config() -> None:
48+
"""Test creating a plot with custom configuration."""
49+
fig = go.Figure(data=go.Scatter(x=[1, 2, 3], y=[4, 5, 6]))
50+
config = {"staticPlot": True, "displayModeBar": False}
51+
plot = plotly(fig, config=config)
52+
53+
assert plot is not None
54+
assert plot._component_args["config"] == config
55+
56+
@staticmethod
57+
def test_plotly_with_label() -> None:
58+
"""Test creating a plot with a label."""
59+
fig = go.Figure(data=go.Scatter(x=[1, 2, 3], y=[4, 5, 6]))
60+
plot = plotly(fig, label="My Plot")
61+
62+
assert plot is not None
63+
64+
@staticmethod
65+
def test_plotly_with_on_change() -> None:
66+
"""Test creating a plot with on_change callback."""
67+
fig = go.Figure(data=go.Scatter(x=[1, 2, 3], y=[4, 5, 6]))
68+
callback_called = []
69+
70+
def on_change(value: Any) -> None:
71+
callback_called.append(value)
72+
73+
plot = plotly(fig, on_change=on_change)
74+
assert plot is not None
75+
76+
@staticmethod
77+
def test_initial_selection() -> None:
78+
"""Test that initial selection is properly set."""
79+
fig = go.Figure(data=go.Scatter(x=[1, 2, 3, 4], y=[1, 2, 3, 4]))
80+
81+
# Add a selection to the figure
82+
fig.add_selection(x0=1, x1=3, y0=1, y1=3, xref="x", yref="y")
83+
84+
# Update layout to include axis titles
85+
fig.update_xaxes(title_text="X Axis")
86+
fig.update_yaxes(title_text="Y Axis")
87+
88+
plot = plotly(fig)
89+
90+
# Check that initial value contains the selection
91+
initial_value = plot._args.initial_value
92+
assert "range" in initial_value
93+
assert "x" in initial_value["range"]
94+
assert "y" in initial_value["range"]
95+
assert initial_value["range"]["x"] == [1, 3]
96+
assert initial_value["range"]["y"] == [1, 3]
97+
98+
# Check that points within the selection are included
99+
assert "points" in initial_value
100+
assert "indices" in initial_value
101+
# Points at (1,1), (2,2), and (3,3) should be selected (using <= comparisons)
102+
assert len(initial_value["indices"]) == 3
103+
assert initial_value["indices"] == [0, 1, 2]
104+
105+
@staticmethod
106+
def test_selection_with_axis_titles() -> None:
107+
"""Test that selection properly extracts axis titles."""
108+
fig = go.Figure(data=go.Scatter(x=[1, 2, 3], y=[4, 5, 6]))
109+
fig.update_xaxes(title_text="Time")
110+
fig.update_yaxes(title_text="Value")
111+
fig.add_selection(x0=1, x1=2, y0=4, y1=5, xref="x", yref="y")
112+
113+
plot = plotly(fig)
114+
115+
# Check that points have the correct axis labels
116+
initial_value = plot._args.initial_value
117+
if initial_value["points"]:
118+
point = initial_value["points"][0]
119+
assert "Time" in point or "Value" in point
120+
121+
@staticmethod
122+
def test_selection_without_axis_titles() -> None:
123+
"""Test selection when axes don't have titles."""
124+
fig = go.Figure(data=go.Scatter(x=[1, 2, 3], y=[4, 5, 6]))
125+
fig.add_selection(x0=1, x1=2, y0=4, y1=5, xref="x", yref="y")
126+
127+
plot = plotly(fig)
128+
129+
# Should still work, but points might be empty or have generic labels
130+
initial_value = plot._args.initial_value
131+
assert "points" in initial_value
132+
133+
@staticmethod
134+
def test_convert_value_with_selection() -> None:
135+
"""Test _convert_value method with selection data."""
136+
fig = go.Figure(data=go.Scatter(x=[1, 2, 3], y=[4, 5, 6]))
137+
plot = plotly(fig)
138+
139+
selection = {
140+
"points": [{"x": 1, "y": 4}, {"x": 2, "y": 5}],
141+
"range": {"x": [1, 2], "y": [4, 5]},
142+
"indices": [0, 1],
143+
}
144+
145+
result = plot._convert_value(selection)
146+
147+
# _convert_value should return the points
148+
assert result == selection["points"]
149+
assert plot.ranges == {"x": [1, 2], "y": [4, 5]}
150+
assert plot.indices == [0, 1]
151+
152+
@staticmethod
153+
def test_convert_value_empty_selection() -> None:
154+
"""Test _convert_value with empty selection."""
155+
fig = go.Figure(data=go.Scatter(x=[1, 2, 3], y=[4, 5, 6]))
156+
plot = plotly(fig)
157+
158+
result = plot._convert_value({})
159+
160+
assert result == []
161+
assert plot.ranges == {}
162+
assert plot.points == []
163+
assert plot.indices == []
164+
165+
@staticmethod
166+
def test_ranges_property() -> None:
167+
"""Test the ranges property."""
168+
fig = go.Figure(data=go.Scatter(x=[1, 2, 3], y=[4, 5, 6]))
169+
plot = plotly(fig)
170+
171+
# Initially empty
172+
assert plot.ranges == {}
173+
174+
# Set selection data
175+
plot._convert_value({"range": {"x": [1, 2], "y": [4, 5]}})
176+
assert plot.ranges == {"x": [1, 2], "y": [4, 5]}
177+
178+
@staticmethod
179+
def test_points_property() -> None:
180+
"""Test the points property."""
181+
fig = go.Figure(data=go.Scatter(x=[1, 2, 3], y=[4, 5, 6]))
182+
plot = plotly(fig)
183+
184+
# Initially empty
185+
assert plot.points == []
186+
187+
# Set selection data
188+
plot._convert_value({"points": [{"x": 1, "y": 4}]})
189+
assert plot.points == [{"x": 1, "y": 4}]
190+
191+
@staticmethod
192+
def test_indices_property() -> None:
193+
"""Test the indices property."""
194+
fig = go.Figure(data=go.Scatter(x=[1, 2, 3], y=[4, 5, 6]))
195+
plot = plotly(fig)
196+
197+
# Initially empty
198+
assert plot.indices == []
199+
200+
# Set selection data
201+
plot._convert_value({"indices": [0, 2]})
202+
assert plot.indices == [0, 2]
203+
204+
@staticmethod
205+
def test_treemap() -> None:
206+
"""Test that treemaps can be created (supported chart type)."""
207+
fig = go.Figure(
208+
go.Treemap(
209+
labels=["A", "B", "C"],
210+
parents=["", "A", "A"],
211+
values=[10, 5, 5],
212+
)
213+
)
214+
plot = plotly(fig)
215+
216+
assert plot is not None
217+
218+
@staticmethod
219+
def test_sunburst() -> None:
220+
"""Test that sunburst charts can be created (supported chart type)."""
221+
fig = go.Figure(
222+
go.Sunburst(
223+
labels=["A", "B", "C"],
224+
parents=["", "A", "A"],
225+
values=[10, 5, 5],
226+
)
227+
)
228+
plot = plotly(fig)
229+
230+
assert plot is not None
231+
232+
@staticmethod
233+
def test_multiple_traces() -> None:
234+
"""Test plot with multiple traces."""
235+
fig = go.Figure()
236+
fig.add_trace(go.Scatter(x=[1, 2, 3], y=[4, 5, 6], name="Trace 1"))
237+
fig.add_trace(go.Scatter(x=[1, 2, 3], y=[6, 5, 4], name="Trace 2"))
238+
239+
plot = plotly(fig)
240+
assert plot is not None
241+
242+
@staticmethod
243+
def test_selection_across_multiple_traces() -> None:
244+
"""Test that selection works across multiple traces."""
245+
fig = go.Figure()
246+
fig.add_trace(go.Scatter(x=[1, 2], y=[1, 2], name="Trace 1"))
247+
fig.add_trace(go.Scatter(x=[2, 3], y=[2, 3], name="Trace 2"))
248+
fig.update_xaxes(title_text="X")
249+
fig.update_yaxes(title_text="Y")
250+
fig.add_selection(x0=1.5, x1=2.5, y0=1.5, y1=2.5, xref="x", yref="y")
251+
252+
plot = plotly(fig)
253+
254+
# Should select points from both traces
255+
initial_value = plot._args.initial_value
256+
assert len(initial_value["indices"]) >= 1
257+
258+
@staticmethod
259+
def test_selection_with_no_data() -> None:
260+
"""Test selection on a plot with no data."""
261+
fig = go.Figure()
262+
fig.add_selection(x0=1, x1=2, y0=1, y1=2, xref="x", yref="y")
263+
264+
plot = plotly(fig)
265+
266+
# Should not error, but should have empty selection
267+
initial_value = plot._args.initial_value
268+
assert initial_value["points"] == []
269+
assert initial_value["indices"] == []
270+
271+
@staticmethod
272+
def test_selection_partial_attributes() -> None:
273+
"""Test that selection without all required attributes is ignored."""
274+
fig = go.Figure(data=go.Scatter(x=[1, 2, 3], y=[4, 5, 6]))
275+
276+
class PartialSelection:
277+
x0 = 1
278+
x1 = 2
279+
# Missing y0 and y1
280+
281+
# We can't easily test this without modifying the figure's internal state
282+
# But we can verify the plot is created without error
283+
plot = plotly(fig)
284+
assert plot is not None
285+
286+
@staticmethod
287+
def test_figure_serialization() -> None:
288+
"""Test that the figure is properly serialized to JSON."""
289+
fig = go.Figure(data=go.Scatter(x=[1, 2, 3], y=[4, 5, 6]))
290+
plot = plotly(fig)
291+
292+
# Check that figure is in args as a dictionary
293+
assert "figure" in plot._component_args
294+
assert isinstance(plot._component_args["figure"], dict)
295+
assert "data" in plot._component_args["figure"]
296+
297+
@staticmethod
298+
def test_default_config_from_renderer() -> None:
299+
"""Test that default config is pulled from renderer when not provided."""
300+
import plotly.io as pio
301+
302+
# Save original renderer
303+
original_renderer = pio.renderers.default
304+
305+
try:
306+
# Set a renderer with custom config
307+
pio.renderers.default = "browser"
308+
309+
fig = go.Figure(data=go.Scatter(x=[1, 2, 3], y=[4, 5, 6]))
310+
plot = plotly(fig)
311+
312+
# Should have some config (exact config depends on renderer)
313+
assert "config" in plot._component_args
314+
315+
finally:
316+
# Restore original renderer
317+
pio.renderers.default = original_renderer
318+
319+
@staticmethod
320+
def test_explicit_config_overrides_renderer() -> None:
321+
"""Test that explicit config takes precedence over renderer config."""
322+
fig = go.Figure(data=go.Scatter(x=[1, 2, 3], y=[4, 5, 6]))
323+
custom_config = {"displaylogo": False}
324+
plot = plotly(fig, config=custom_config)
325+
326+
assert plot._component_args["config"] == custom_config
327+
328+
@staticmethod
329+
def test_value_returns_points() -> None:
330+
"""Test that .value returns the points list."""
331+
fig = go.Figure(data=go.Scatter(x=[1, 2, 3], y=[4, 5, 6]))
332+
plot = plotly(fig)
333+
334+
selection = {
335+
"points": [{"x": 1, "y": 4}],
336+
"range": {"x": [1, 2], "y": [4, 5]},
337+
"indices": [0],
338+
}
339+
340+
# _convert_value returns the points
341+
result = plot._convert_value(selection)
342+
assert result == [{"x": 1, "y": 4}]
343+
344+
@staticmethod
345+
def test_plotly_name() -> None:
346+
"""Test that the component name is correct."""
347+
assert plotly.name == "marimo-plotly"
348+
349+
@staticmethod
350+
def test_selection_boundary_conditions() -> None:
351+
"""Test selection at exact boundaries."""
352+
fig = go.Figure(data=go.Scatter(x=[1, 2, 3], y=[1, 2, 3]))
353+
fig.update_xaxes(title_text="X")
354+
fig.update_yaxes(title_text="Y")
355+
356+
# Selection that exactly matches point (2, 2)
357+
fig.add_selection(x0=2, x1=2, y0=2, y1=2, xref="x", yref="y")
358+
359+
plot = plotly(fig)
360+
361+
# Point at exactly (2, 2) should be selected (using <= comparisons)
362+
initial_value = plot._args.initial_value
363+
assert len(initial_value["indices"]) == 1
364+
assert 1 in initial_value["indices"]

0 commit comments

Comments
 (0)