Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 49 additions & 1 deletion marimo/_plugins/ui/_impl/plotly.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,57 @@ def __init__(
"Using an empty configuration."
)

initial_value: dict[str, Any] = {}

def add_selection(selection: go.layout.Selection) -> None:
if not all(
hasattr(selection, k) for k in ["x0", "x1", "y0", "y1"]
):
return

initial_value["range"] = {
"x": [selection.x0, selection.x1],
"y": [selection.y0, selection.y1],
}

# Find points within the selection range
selected_points = []
selected_indices = []

x_axes: list[go.layout.XAxis] = []
figure.for_each_xaxis(x_axes.append)
[x_axis] = x_axes if len(x_axes) == 1 else [None]
y_axes: list[go.layout.YAxis] = []
figure.for_each_yaxis(y_axes.append)
[y_axis] = y_axes if len(y_axes) == 1 else [None]

for trace in figure.data:
x_data = getattr(trace, "x", None)
y_data = getattr(trace, "y", None)
if x_data is None or y_data is None:
continue
for point_idx, (x, y) in enumerate(zip(x_data, y_data)):
if (
selection.x0 <= x <= selection.x1
and selection.y0 <= y <= selection.y1
):
selected_points.append(
{
axis.title.text: val
for axis, val in [(x_axis, x), (y_axis, y)]
if axis and axis.title.text
}
)
selected_indices.append(point_idx)

initial_value["points"] = selected_points
initial_value["indices"] = selected_indices

figure.for_each_selection(add_selection)

super().__init__(
component_name=plotly.name,
initial_value={},
initial_value=initial_value,
label=label,
args={
"figure": json.loads(json_str),
Expand Down
Loading