Skip to content

Commit b1228b4

Browse files
authored
tests: mo.ui.plotly tests (#6757)
Follow tests for #6747 cc @yairchu
1 parent 3321239 commit b1228b4

File tree

1 file changed

+353
-0
lines changed

1 file changed

+353
-0
lines changed
Lines changed: 353 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,353 @@
1+
# Copyright 2024 Marimo. All rights reserved.
2+
from __future__ import annotations
3+
4+
from typing import Any
5+
6+
import pytest
7+
8+
pytest.importorskip("plotly.express")
9+
pytest.importorskip("plotly.graph_objects")
10+
11+
import plotly.express as px
12+
import plotly.graph_objects as go
13+
14+
from marimo._plugins.ui._impl.plotly import plotly
15+
16+
17+
def test_basic_scatter_plot() -> None:
18+
"""Test creating a basic scatter plot."""
19+
fig = go.Figure(data=go.Scatter(x=[1, 2, 3], y=[4, 5, 6], mode="markers"))
20+
plot = plotly(fig)
21+
22+
assert plot is not None
23+
assert plot.value == []
24+
assert plot.ranges == {}
25+
assert plot.points == []
26+
assert plot.indices == []
27+
28+
29+
def test_plotly_express_scatter() -> None:
30+
"""Test creating a plot with plotly express."""
31+
import pandas as pd
32+
33+
df = pd.DataFrame(
34+
{"x": [1, 2, 3], "y": [4, 5, 6], "color": ["A", "B", "A"]}
35+
)
36+
fig = px.scatter(df, x="x", y="y", color="color")
37+
plot = plotly(fig)
38+
39+
assert plot is not None
40+
assert plot.value == []
41+
42+
43+
def test_plotly_with_config() -> None:
44+
"""Test creating a plot with custom configuration."""
45+
fig = go.Figure(data=go.Scatter(x=[1, 2, 3], y=[4, 5, 6]))
46+
config = {"staticPlot": True, "displayModeBar": False}
47+
plot = plotly(fig, config=config)
48+
49+
assert plot is not None
50+
assert plot._component_args["config"] == config
51+
52+
53+
def test_plotly_with_label() -> None:
54+
"""Test creating a plot with a label."""
55+
fig = go.Figure(data=go.Scatter(x=[1, 2, 3], y=[4, 5, 6]))
56+
plot = plotly(fig, label="My Plot")
57+
58+
assert plot is not None
59+
60+
61+
def test_plotly_with_on_change() -> None:
62+
"""Test creating a plot with on_change callback."""
63+
fig = go.Figure(data=go.Scatter(x=[1, 2, 3], y=[4, 5, 6]))
64+
callback_called = []
65+
66+
def on_change(value: Any) -> None:
67+
callback_called.append(value)
68+
69+
plot = plotly(fig, on_change=on_change)
70+
assert plot is not None
71+
72+
73+
def test_initial_selection() -> None:
74+
"""Test that initial selection is properly set."""
75+
fig = go.Figure(data=go.Scatter(x=[1, 2, 3, 4], y=[1, 2, 3, 4]))
76+
77+
# Add a selection to the figure
78+
fig.add_selection(x0=1, x1=3, y0=1, y1=3, xref="x", yref="y")
79+
80+
# Update layout to include axis titles
81+
fig.update_xaxes(title_text="X Axis")
82+
fig.update_yaxes(title_text="Y Axis")
83+
84+
plot = plotly(fig)
85+
86+
# Check that initial value contains the selection
87+
initial_value = plot._args.initial_value
88+
assert "range" in initial_value
89+
assert "x" in initial_value["range"]
90+
assert "y" in initial_value["range"]
91+
assert initial_value["range"]["x"] == [1, 3]
92+
assert initial_value["range"]["y"] == [1, 3]
93+
94+
# Check that points within the selection are included
95+
assert "points" in initial_value
96+
assert "indices" in initial_value
97+
# Points at (1,1), (2,2), and (3,3) should be selected (using <= comparisons)
98+
assert len(initial_value["indices"]) == 3
99+
assert initial_value["indices"] == [0, 1, 2]
100+
101+
102+
def test_selection_with_axis_titles() -> None:
103+
"""Test that selection properly extracts axis titles."""
104+
fig = go.Figure(data=go.Scatter(x=[1, 2, 3], y=[4, 5, 6]))
105+
fig.update_xaxes(title_text="Time")
106+
fig.update_yaxes(title_text="Value")
107+
fig.add_selection(x0=1, x1=2, y0=4, y1=5, xref="x", yref="y")
108+
109+
plot = plotly(fig)
110+
111+
# Check that points have the correct axis labels
112+
initial_value = plot._args.initial_value
113+
if initial_value["points"]:
114+
point = initial_value["points"][0]
115+
assert "Time" in point or "Value" in point
116+
117+
118+
def test_selection_without_axis_titles() -> None:
119+
"""Test selection when axes don't have titles."""
120+
fig = go.Figure(data=go.Scatter(x=[1, 2, 3], y=[4, 5, 6]))
121+
fig.add_selection(x0=1, x1=2, y0=4, y1=5, xref="x", yref="y")
122+
123+
plot = plotly(fig)
124+
125+
# Should still work, but points might be empty or have generic labels
126+
initial_value = plot._args.initial_value
127+
assert "points" in initial_value
128+
129+
130+
def test_convert_value_with_selection() -> None:
131+
"""Test _convert_value method with selection data."""
132+
fig = go.Figure(data=go.Scatter(x=[1, 2, 3], y=[4, 5, 6]))
133+
plot = plotly(fig)
134+
135+
selection = {
136+
"points": [{"x": 1, "y": 4}, {"x": 2, "y": 5}],
137+
"range": {"x": [1, 2], "y": [4, 5]},
138+
"indices": [0, 1],
139+
}
140+
141+
result = plot._convert_value(selection)
142+
143+
# _convert_value should return the points
144+
assert result == selection["points"]
145+
assert plot.ranges == {"x": [1, 2], "y": [4, 5]}
146+
assert plot.indices == [0, 1]
147+
148+
149+
def test_convert_value_empty_selection() -> None:
150+
"""Test _convert_value with empty selection."""
151+
fig = go.Figure(data=go.Scatter(x=[1, 2, 3], y=[4, 5, 6]))
152+
plot = plotly(fig)
153+
154+
result = plot._convert_value({})
155+
156+
assert result == []
157+
assert plot.ranges == {}
158+
assert plot.points == []
159+
assert plot.indices == []
160+
161+
162+
def test_ranges_property() -> None:
163+
"""Test the ranges property."""
164+
fig = go.Figure(data=go.Scatter(x=[1, 2, 3], y=[4, 5, 6]))
165+
plot = plotly(fig)
166+
167+
# Initially empty
168+
assert plot.ranges == {}
169+
170+
# Set selection data
171+
plot._convert_value({"range": {"x": [1, 2], "y": [4, 5]}})
172+
assert plot.ranges == {"x": [1, 2], "y": [4, 5]}
173+
174+
175+
def test_points_property() -> None:
176+
"""Test the points property."""
177+
fig = go.Figure(data=go.Scatter(x=[1, 2, 3], y=[4, 5, 6]))
178+
plot = plotly(fig)
179+
180+
# Initially empty
181+
assert plot.points == []
182+
183+
# Set selection data
184+
plot._convert_value({"points": [{"x": 1, "y": 4}]})
185+
assert plot.points == [{"x": 1, "y": 4}]
186+
187+
188+
def test_indices_property() -> None:
189+
"""Test the indices property."""
190+
fig = go.Figure(data=go.Scatter(x=[1, 2, 3], y=[4, 5, 6]))
191+
plot = plotly(fig)
192+
193+
# Initially empty
194+
assert plot.indices == []
195+
196+
# Set selection data
197+
plot._convert_value({"indices": [0, 2]})
198+
assert plot.indices == [0, 2]
199+
200+
201+
def test_treemap() -> None:
202+
"""Test that treemaps can be created (supported chart type)."""
203+
fig = go.Figure(
204+
go.Treemap(
205+
labels=["A", "B", "C"],
206+
parents=["", "A", "A"],
207+
values=[10, 5, 5],
208+
)
209+
)
210+
plot = plotly(fig)
211+
212+
assert plot is not None
213+
214+
215+
def test_sunburst() -> None:
216+
"""Test that sunburst charts can be created (supported chart type)."""
217+
fig = go.Figure(
218+
go.Sunburst(
219+
labels=["A", "B", "C"],
220+
parents=["", "A", "A"],
221+
values=[10, 5, 5],
222+
)
223+
)
224+
plot = plotly(fig)
225+
226+
assert plot is not None
227+
228+
229+
def test_multiple_traces() -> None:
230+
"""Test plot with multiple traces."""
231+
fig = go.Figure()
232+
fig.add_trace(go.Scatter(x=[1, 2, 3], y=[4, 5, 6], name="Trace 1"))
233+
fig.add_trace(go.Scatter(x=[1, 2, 3], y=[6, 5, 4], name="Trace 2"))
234+
235+
plot = plotly(fig)
236+
assert plot is not None
237+
238+
239+
def test_selection_across_multiple_traces() -> None:
240+
"""Test that selection works across multiple traces."""
241+
fig = go.Figure()
242+
fig.add_trace(go.Scatter(x=[1, 2], y=[1, 2], name="Trace 1"))
243+
fig.add_trace(go.Scatter(x=[2, 3], y=[2, 3], name="Trace 2"))
244+
fig.update_xaxes(title_text="X")
245+
fig.update_yaxes(title_text="Y")
246+
fig.add_selection(x0=1.5, x1=2.5, y0=1.5, y1=2.5, xref="x", yref="y")
247+
248+
plot = plotly(fig)
249+
250+
# Should select points from both traces
251+
initial_value = plot._args.initial_value
252+
assert len(initial_value["indices"]) >= 1
253+
254+
255+
def test_selection_with_no_data() -> None:
256+
"""Test selection on a plot with no data."""
257+
fig = go.Figure()
258+
fig.add_selection(x0=1, x1=2, y0=1, y1=2, xref="x", yref="y")
259+
260+
plot = plotly(fig)
261+
262+
# Should not error, but should have empty selection
263+
initial_value = plot._args.initial_value
264+
assert initial_value["points"] == []
265+
assert initial_value["indices"] == []
266+
267+
268+
def test_selection_partial_attributes() -> None:
269+
"""Test that selection without all required attributes is ignored."""
270+
fig = go.Figure(data=go.Scatter(x=[1, 2, 3], y=[4, 5, 6]))
271+
272+
plot = plotly(fig)
273+
assert plot is not None
274+
275+
276+
def test_figure_serialization() -> None:
277+
"""Test that the figure is properly serialized to JSON."""
278+
fig = go.Figure(data=go.Scatter(x=[1, 2, 3], y=[4, 5, 6]))
279+
plot = plotly(fig)
280+
281+
# Check that figure is in args as a dictionary
282+
assert "figure" in plot._component_args
283+
assert isinstance(plot._component_args["figure"], dict)
284+
assert "data" in plot._component_args["figure"]
285+
286+
287+
def test_default_config_from_renderer() -> None:
288+
"""Test that default config is pulled from renderer when not provided."""
289+
import plotly.io as pio
290+
291+
# Save original renderer
292+
original_renderer = pio.renderers.default
293+
294+
try:
295+
# Set a renderer with custom config
296+
pio.renderers.default = "browser"
297+
298+
fig = go.Figure(data=go.Scatter(x=[1, 2, 3], y=[4, 5, 6]))
299+
plot = plotly(fig)
300+
301+
# Should have some config (exact config depends on renderer)
302+
assert "config" in plot._component_args
303+
304+
finally:
305+
# Restore original renderer
306+
pio.renderers.default = original_renderer
307+
308+
309+
def test_explicit_config_overrides_renderer() -> None:
310+
"""Test that explicit config takes precedence over renderer config."""
311+
fig = go.Figure(data=go.Scatter(x=[1, 2, 3], y=[4, 5, 6]))
312+
custom_config = {"displaylogo": False}
313+
plot = plotly(fig, config=custom_config)
314+
315+
assert plot._component_args["config"] == custom_config
316+
317+
318+
def test_value_returns_points() -> None:
319+
"""Test that .value returns the points list."""
320+
fig = go.Figure(data=go.Scatter(x=[1, 2, 3], y=[4, 5, 6]))
321+
plot = plotly(fig)
322+
323+
selection = {
324+
"points": [{"x": 1, "y": 4}],
325+
"range": {"x": [1, 2], "y": [4, 5]},
326+
"indices": [0],
327+
}
328+
329+
# _convert_value returns the points
330+
result = plot._convert_value(selection)
331+
assert result == [{"x": 1, "y": 4}]
332+
333+
334+
def test_plotly_name() -> None:
335+
"""Test that the component name is correct."""
336+
assert plotly.name == "marimo-plotly"
337+
338+
339+
def test_selection_boundary_conditions() -> None:
340+
"""Test selection at exact boundaries."""
341+
fig = go.Figure(data=go.Scatter(x=[1, 2, 3], y=[1, 2, 3]))
342+
fig.update_xaxes(title_text="X")
343+
fig.update_yaxes(title_text="Y")
344+
345+
# Selection that exactly matches point (2, 2)
346+
fig.add_selection(x0=2, x1=2, y0=2, y1=2, xref="x", yref="y")
347+
348+
plot = plotly(fig)
349+
350+
# Point at exactly (2, 2) should be selected (using <= comparisons)
351+
initial_value = plot._args.initial_value
352+
assert len(initial_value["indices"]) == 1
353+
assert 1 in initial_value["indices"]

0 commit comments

Comments
 (0)