diff --git a/marimo/_plugins/ui/_impl/plotly.py b/marimo/_plugins/ui/_impl/plotly.py index 5db55d44e9c..c05bf7dff21 100644 --- a/marimo/_plugins/ui/_impl/plotly.py +++ b/marimo/_plugins/ui/_impl/plotly.py @@ -135,9 +135,57 @@ def __init__( "Using an empty configuration." ) + initial_value: dict[str, Any] = {} + + def add_selection(selection: go.layout.Selection) -> None: + if not all( + hasattr(selection, k) for k in ["x0", "x1", "y0", "y1"] + ): + return + + initial_value["range"] = { + "x": [selection.x0, selection.x1], + "y": [selection.y0, selection.y1], + } + + # Find points within the selection range + selected_points = [] + selected_indices = [] + + x_axes: list[go.layout.XAxis] = [] + figure.for_each_xaxis(x_axes.append) + [x_axis] = x_axes if len(x_axes) == 1 else [None] + y_axes: list[go.layout.YAxis] = [] + figure.for_each_yaxis(y_axes.append) + [y_axis] = y_axes if len(y_axes) == 1 else [None] + + for trace in figure.data: + x_data = getattr(trace, "x", None) + y_data = getattr(trace, "y", None) + if x_data is None or y_data is None: + continue + for point_idx, (x, y) in enumerate(zip(x_data, y_data)): + if ( + selection.x0 <= x <= selection.x1 + and selection.y0 <= y <= selection.y1 + ): + selected_points.append( + { + axis.title.text: val + for axis, val in [(x_axis, x), (y_axis, y)] + if axis and axis.title.text + } + ) + selected_indices.append(point_idx) + + initial_value["points"] = selected_points + initial_value["indices"] = selected_indices + + figure.for_each_selection(add_selection) + super().__init__( component_name=plotly.name, - initial_value={}, + initial_value=initial_value, label=label, args={ "figure": json.loads(json_str),