diff --git a/frontend/src/plugins/impl/vega/__tests__/__snapshots__/make-selectable.test.ts.snap b/frontend/src/plugins/impl/vega/__tests__/__snapshots__/make-selectable.test.ts.snap index 6747ff56dd6..3d648dc5eb1 100644 --- a/frontend/src/plugins/impl/vega/__tests__/__snapshots__/make-selectable.test.ts.snap +++ b/frontend/src/plugins/impl/vega/__tests__/__snapshots__/make-selectable.test.ts.snap @@ -1,5 +1,176 @@ // Vitest Snapshot v1, https://vitest.dev/guide/snapshot.html +exports[`makeSelectable > should add legend selection to composite charts (issue #6676) 1`] = ` +{ + "layer": [ + { + "encoding": { + "color": { + "field": "category", + "type": "nominal", + }, + "opacity": { + "condition": { + "test": { + "and": [ + { + "param": "legend_selection_category", + }, + { + "param": "select_point_0", + }, + { + "param": "select_interval_0", + }, + ], + }, + "value": 1, + }, + "value": 0.2, + }, + "x": { + "field": "x_value", + "type": "quantitative", + }, + "y": { + "field": "upper", + "type": "quantitative", + }, + "y2": { + "field": "lower", + }, + }, + "mark": { + "cursor": "pointer", + "tooltip": true, + "type": "rule", + }, + "params": [ + { + "bind": "legend", + "name": "legend_selection_category", + "select": { + "fields": [ + "category", + ], + "type": "point", + }, + }, + { + "name": "select_point_0", + "select": { + "encodings": [ + "x", + "y", + ], + "on": "click[!event.metaKey]", + "type": "point", + }, + }, + { + "name": "select_interval_0", + "select": { + "encodings": [ + "x", + "y", + ], + "mark": { + "fill": "#669EFF", + "fillOpacity": 0.07, + "stroke": "#669EFF", + "strokeOpacity": 0.4, + }, + "on": "[mousedown[!event.metaKey], mouseup] > mousemove[!event.metaKey]", + "translate": "[mousedown[!event.metaKey], mouseup] > mousemove[!event.metaKey]", + "type": "interval", + }, + }, + { + "bind": "scales", + "name": "pan_zoom", + "select": { + "on": "[mousedown[event.metaKey], window:mouseup] > window:mousemove!", + "translate": "[mousedown[event.metaKey], window:mouseup] > window:mousemove!", + "type": "interval", + "zoom": "wheel![event.metaKey]", + }, + }, + ], + }, + { + "encoding": { + "color": { + "field": "category", + "type": "nominal", + }, + "opacity": { + "condition": { + "test": { + "and": [ + { + "param": "select_point_1", + }, + { + "param": "select_interval_1", + }, + ], + }, + "value": 1, + }, + "value": 0.2, + }, + "x": { + "field": "x_value", + "type": "quantitative", + }, + "y": { + "field": "value", + "type": "quantitative", + }, + }, + "mark": { + "cursor": "pointer", + "filled": true, + "size": 60, + "tooltip": true, + "type": "point", + }, + "params": [ + { + "name": "select_point_1", + "select": { + "encodings": [ + "x", + "y", + ], + "on": "click[!event.metaKey]", + "type": "point", + }, + }, + { + "name": "select_interval_1", + "select": { + "encodings": [ + "x", + "y", + ], + "mark": { + "fill": "#669EFF", + "fillOpacity": 0.07, + "stroke": "#669EFF", + "strokeOpacity": 0.4, + }, + "on": "[mousedown[!event.metaKey], mouseup] > mousemove[!event.metaKey]", + "translate": "[mousedown[!event.metaKey], mouseup] > mousemove[!event.metaKey]", + "type": "interval", + }, + }, + ], + }, + ], +} +`; + exports[`makeSelectable > should return correctly if existing legend selection 1`] = ` { "config": { @@ -774,6 +945,19 @@ exports[`makeSelectable > should work for layered charts, with existing selectio }, "value": "lightgray", }, + "opacity": { + "condition": { + "test": { + "and": [ + { + "param": "legend_selection_stage", + }, + ], + }, + "value": 1, + }, + "value": 0.2, + }, "x": { "field": "Level1", "sort": { @@ -789,10 +973,22 @@ exports[`makeSelectable > should work for layered charts, with existing selectio }, }, "mark": { + "cursor": "pointer", + "tooltip": true, "type": "bar", }, "name": "view_21", "params": [ + { + "bind": "legend", + "name": "legend_selection_stage", + "select": { + "fields": [ + "stage", + ], + "type": "point", + }, + }, { "bind": "scales", "name": "pan_zoom", diff --git a/frontend/src/plugins/impl/vega/__tests__/make-selectable.test.ts b/frontend/src/plugins/impl/vega/__tests__/make-selectable.test.ts index ae0d2cdd08a..6cefe701989 100644 --- a/frontend/src/plugins/impl/vega/__tests__/make-selectable.test.ts +++ b/frontend/src/plugins/impl/vega/__tests__/make-selectable.test.ts @@ -572,4 +572,211 @@ describe("makeSelectable", () => { expect(getSelectionParamNames(newSpec)).toEqual([]); }, ); + + it("should add legend selection to composite charts (issue #6676)", () => { + // Test case from https://github.com/marimo-team/marimo/issues/6676 + const spec = { + layer: [ + { + mark: "rule", + encoding: { + x: { field: "x_value", type: "quantitative" }, + y: { field: "upper", type: "quantitative" }, + y2: { field: "lower" }, + color: { + field: "category", + type: "nominal", + }, + }, + }, + { + mark: { type: "point", filled: true, size: 60 }, + encoding: { + x: { field: "x_value", type: "quantitative" }, + y: { field: "value", type: "quantitative" }, + color: { + field: "category", + type: "nominal", + }, + }, + }, + ], + } as VegaLiteSpec; + + const newSpec = makeSelectable(spec, { + chartSelection: true, + fieldSelection: true, + }); + + expect(newSpec).toMatchSnapshot(); + const paramNames = getSelectionParamNames(newSpec); + // Should have legend selection for category field + expect(paramNames).toContain("legend_selection_category"); + // Should NOT have duplicate legend params + expect( + paramNames.filter((name) => name === "legend_selection_category"), + ).toHaveLength(1); + }); + + it("should not duplicate legend params when multiple layers have same color field", () => { + const spec = { + layer: [ + { + mark: "line", + encoding: { + x: { field: "x", type: "quantitative" }, + y: { field: "y", type: "quantitative" }, + color: { field: "category", type: "nominal" }, + }, + }, + { + mark: { type: "point", size: 100 }, + encoding: { + x: { field: "x", type: "quantitative" }, + y: { field: "y", type: "quantitative" }, + color: { field: "category", type: "nominal" }, + }, + }, + ], + } as VegaLiteSpec; + + const newSpec = makeSelectable(spec, {}); + const paramNames = getSelectionParamNames(newSpec); + + // Should have exactly one legend_selection_category param + const legendParams = paramNames.filter((name) => + name.startsWith("legend_selection_category"), + ); + expect(legendParams).toHaveLength(1); + expect(legendParams[0]).toBe("legend_selection_category"); + }); + + it("should collect legend fields from multiple layers with different fields", () => { + const spec = { + layer: [ + { + mark: "point", + encoding: { + x: { field: "x", type: "quantitative" }, + y: { field: "y", type: "quantitative" }, + color: { field: "category", type: "nominal" }, + }, + }, + { + mark: "point", + encoding: { + x: { field: "x", type: "quantitative" }, + y: { field: "y", type: "quantitative" }, + size: { field: "size_field", type: "quantitative" }, + }, + }, + ], + } as VegaLiteSpec; + + const newSpec = makeSelectable(spec, {}); + const paramNames = getSelectionParamNames(newSpec); + + // Should have legend params for both fields + expect(paramNames).toContain("legend_selection_category"); + expect(paramNames).toContain("legend_selection_size_field"); + }); + + it("should add legend selection to vconcat specs", () => { + const spec = { + vconcat: [ + { + mark: "point", + encoding: { + x: { field: "x", type: "quantitative" }, + y: { field: "y", type: "quantitative" }, + color: { field: "category", type: "nominal" }, + }, + }, + { + mark: "bar", + encoding: { + x: { field: "x", type: "nominal" }, + y: { field: "y", type: "quantitative" }, + color: { field: "category", type: "nominal" }, + }, + }, + ], + } as VegaLiteSpec; + + const newSpec = makeSelectable(spec, {}); + const paramNames = getSelectionParamNames(newSpec); + + // Should have legend selection for category field + expect(paramNames).toContain("legend_selection_category"); + }); + + it("should add legend selection to hconcat specs", () => { + const spec = { + hconcat: [ + { + mark: "point", + encoding: { + x: { field: "x", type: "quantitative" }, + y: { field: "y", type: "quantitative" }, + color: { field: "series", type: "nominal" }, + }, + }, + { + mark: "line", + encoding: { + x: { field: "x", type: "quantitative" }, + y: { field: "y", type: "quantitative" }, + color: { field: "series", type: "nominal" }, + }, + }, + ], + } as VegaLiteSpec; + + const newSpec = makeSelectable(spec, {}); + const paramNames = getSelectionParamNames(newSpec); + + // Should have legend selection for series field + expect(paramNames).toContain("legend_selection_series"); + }); + + it("should add legend selection to nested vconcat(hconcat(...)) specs", () => { + const spec = { + vconcat: [ + { + hconcat: [ + { + mark: "point", + encoding: { + x: { field: "x", type: "quantitative" }, + y: { field: "y", type: "quantitative" }, + color: { field: "category", type: "nominal" }, + }, + }, + { + mark: "bar", + encoding: { + x: { field: "x", type: "nominal" }, + y: { field: "y", type: "quantitative" }, + color: { field: "category", type: "nominal" }, + }, + }, + ], + }, + { + mark: "line", + encoding: { + x: { field: "x", type: "quantitative" }, + y: { field: "y", type: "quantitative" }, + color: { field: "category", type: "nominal" }, + }, + }, + ], + } as VegaLiteSpec; + + const newSpec = makeSelectable(spec, {}); + const paramNames = getSelectionParamNames(newSpec); + + // Should have legend selection for category field + expect(paramNames).toContain("legend_selection_category"); + }); }); diff --git a/frontend/src/plugins/impl/vega/make-selectable.ts b/frontend/src/plugins/impl/vega/make-selectable.ts index ecd3ed7d79e..bc39bc9b060 100644 --- a/frontend/src/plugins/impl/vega/make-selectable.ts +++ b/frontend/src/plugins/impl/vega/make-selectable.ts @@ -38,7 +38,12 @@ export function makeSelectable( if ("vconcat" in spec) { const subSpecs = spec.vconcat.map((subSpec) => - "mark" in subSpec ? makeChartInteractive(subSpec) : subSpec, + "mark" in subSpec + ? makeSelectable(subSpec as VegaLiteUnitSpec, { + chartSelection, + fieldSelection, + }) + : subSpec, ); // No pan/zoom for vconcat return { ...spec, vconcat: subSpecs }; @@ -46,18 +51,59 @@ export function makeSelectable( if ("hconcat" in spec) { const subSpecs = spec.hconcat.map((subSpec) => - "mark" in subSpec ? makeChartInteractive(subSpec) : subSpec, + "mark" in subSpec + ? makeSelectable(subSpec as VegaLiteUnitSpec, { + chartSelection, + fieldSelection, + }) + : subSpec, ); // No pan/zoom for hconcat return { ...spec, hconcat: subSpecs }; } if ("layer" in spec) { + // Check if legend params already exist at the top level + const hasTopLevelLegendParam = spec.params?.some( + (param) => param.bind === "legend", + ); + const shouldAddLegendSelection = + fieldSelection !== false && !hasTopLevelLegendParam; + + // Collect all unique legend fields from all layers to avoid duplicates + let legendFields: string[] = []; + if (shouldAddLegendSelection) { + const allFields = spec.layer.flatMap((subSpec) => { + if (!("mark" in subSpec)) { + return []; + } + return findEncodedFields(subSpec as VegaLiteUnitSpec); + }); + legendFields = [...new Set(allFields)]; // Remove duplicates + + // If fieldSelection is an array, filter the fields + if (Array.isArray(fieldSelection)) { + legendFields = legendFields.filter((field) => + fieldSelection.includes(field), + ); + } + } + const subSpecs = spec.layer.map((subSpec, idx) => { if (!("mark" in subSpec)) { return subSpec; } let resolvedSpec = subSpec as VegaLiteUnitSpec; + + // Only add legend params to the first layer to avoid duplicates + if (idx === 0 && legendFields.length > 0) { + const legendParams = legendFields.map((field) => Params.legend(field)); + resolvedSpec = { + ...resolvedSpec, + params: [...(resolvedSpec.params || []), ...legendParams], + }; + } + resolvedSpec = makeChartSelectable(resolvedSpec, chartSelection, idx); resolvedSpec = makeChartInteractive(resolvedSpec); if (idx === 0) { diff --git a/frontend/src/plugins/impl/vega/params.ts b/frontend/src/plugins/impl/vega/params.ts index a0d262f709c..2db1e073433 100644 --- a/frontend/src/plugins/impl/vega/params.ts +++ b/frontend/src/plugins/impl/vega/params.ts @@ -1,6 +1,10 @@ /* Copyright 2024 Marimo. All rights reserved. */ import type { TopLevelSpec } from "vega-lite"; -import type { LayerSpec, UnitSpec } from "vega-lite/build/src/spec"; +import type { + LayerSpec, + NonNormalizedSpec, + UnitSpec, +} from "vega-lite/build/src/spec"; import { Marks } from "./marks"; import { type Field, @@ -136,7 +140,7 @@ export function getEncodingAxisForMark( } export function getSelectionParamNames( - spec: TopLevelSpec | LayerSpec | UnitSpec, + spec: TopLevelSpec | LayerSpec | UnitSpec | NonNormalizedSpec, ): string[] { if ("params" in spec && spec.params && spec.params.length > 0) { const params = spec.params; @@ -156,6 +160,12 @@ export function getSelectionParamNames( if ("layer" in spec) { return [...new Set(spec.layer.flatMap(getSelectionParamNames))]; } + if ("vconcat" in spec) { + return [...new Set(spec.vconcat.flatMap(getSelectionParamNames))]; + } + if ("hconcat" in spec) { + return [...new Set(spec.hconcat.flatMap(getSelectionParamNames))]; + } return []; } diff --git a/marimo/_smoke_tests/altair_examples/composite_legend_selection.py b/marimo/_smoke_tests/altair_examples/composite_legend_selection.py new file mode 100644 index 00000000000..c2fb9870f58 --- /dev/null +++ b/marimo/_smoke_tests/altair_examples/composite_legend_selection.py @@ -0,0 +1,278 @@ +import marimo + +__generated_with = "0.17.2" +app = marimo.App(width="medium") + + +@app.cell +def _(): + import altair as alt + import pandas as pd + import marimo as mo + return alt, mo, pd + + +@app.cell(hide_code=True) +def _(mo): + mo.md( + r""" + ## Composite Chart Legend Selection Bug + + Test case for https://github.com/marimo-team/marimo/issues/6676 + + Legend selection should work on composite charts created with the `+` operator. + Clicking legend items should filter the chart. + """ + ) + return + + +@app.cell +def _(pd): + sample_data = [ + [2, 1, 4, 10, "a"], + [3, 0, 6, 12, "b"], + [8, 5, 12, 15, "c"], + ] + sample_df = pd.DataFrame( + sample_data, columns=["value", "lower", "upper", "x_value", "category"] + ) + sample_df + return (sample_df,) + + +@app.cell(hide_code=True) +def _(mo): + mo.md(r"""### Composite chart with error bars and points""") + return + + +@app.cell +def _(alt, mo, sample_df): + sample_color = alt.Color( + field="category", + type="nominal", + legend=alt.Legend( + title="category", + labelLimit=0, + symbolLimit=0, + ), + ) + + sample_base_chart = alt.Chart(sample_df, title="Sample Error Bars") + + sample_rule = sample_base_chart.mark_rule().encode( + x=alt.X("x_value"), + y=alt.Y("upper"), + y2="lower", + color=sample_color, + ) + + sample_upper_tick = sample_base_chart.mark_tick( + orient="horizontal", size=5 + ).encode( + x="x_value:Q", + y="upper:Q", + color=sample_color, + ) + sample_tick = sample_upper_tick.encode(y="lower:Q") + + sample_lines = sample_rule + sample_upper_tick + sample_tick + + sample_dots = sample_base_chart.mark_point(filled=True, size=60).encode( + x=alt.X("x_value"), + y=alt.Y("value"), + color=sample_color, + ) + + alt_chart = sample_dots + sample_lines + + sample_mo_chart = mo.ui.altair_chart(alt_chart) + + sample_mo_chart + return (sample_mo_chart,) + + +@app.cell +def _(sample_mo_chart): + sample_mo_chart.value + return + + +@app.cell(hide_code=True) +def _(mo): + mo.md(r"""### Simple layered chart (for comparison)""") + return + + +@app.cell +def _(alt, mo, pd): + # Create a simple layered chart to compare + simple_data = pd.DataFrame( + { + "x": [1, 2, 3, 4, 5] * 3, + "y": [1, 2, 3, 4, 5, 2, 3, 4, 5, 6, 3, 4, 5, 6, 7], + "category": ["A"] * 5 + ["B"] * 5 + ["C"] * 5, + } + ) + + simple_chart = alt.Chart(simple_data).mark_line().encode( + x="x:Q", y="y:Q", color="category:N" + ) + alt.Chart(simple_data).mark_point(size=100).encode( + x="x:Q", y="y:Q", color="category:N" + ) + + simple_mo_chart = mo.ui.altair_chart(simple_chart) + simple_mo_chart + return (simple_mo_chart,) + + +@app.cell +def _(simple_mo_chart): + simple_mo_chart.value + return + + +@app.cell(hide_code=True) +def _(mo): + mo.md(r"""### Workaround with explicit legend selection""") + return + + +@app.cell +def _(alt, mo, sample_df): + # Workaround from the issue: explicit legend selection + legend_select = alt.selection_point(fields=["category"], bind="legend") + + workaround_color = alt.Color( + field="category", + type="nominal", + legend=alt.Legend( + title="category", + labelLimit=0, + symbolLimit=0, + ), + ) + + workaround_base = alt.Chart( + sample_df, title="Workaround with Explicit Selection" + ) + + workaround_rule = ( + workaround_base.mark_rule() + .encode( + x=alt.X("x_value"), + y=alt.Y("upper"), + y2="lower", + color=workaround_color, + opacity=alt.condition(legend_select, alt.value(1), alt.value(0.2)), + ) + .add_params(legend_select) + ) + + workaround_dots = ( + workaround_base.mark_point(filled=True, size=60) + .encode( + x=alt.X("x_value"), + y=alt.Y("value"), + color=workaround_color, + opacity=alt.condition(legend_select, alt.value(1), alt.value(0.2)), + ) + .add_params(legend_select) + ) + + workaround_chart = workaround_dots + workaround_rule + + workaround_mo_chart = mo.ui.altair_chart( + workaround_chart, chart_selection=None, legend_selection="legend_select" + ) + workaround_mo_chart + return (workaround_mo_chart,) + + +@app.cell +def _(workaround_mo_chart): + workaround_mo_chart.value + return + + +@app.cell(hide_code=True) +def _(mo): + mo.md(r"""### vconcat with legend selection""") + return + + +@app.cell +def _(alt, mo, pd): + # Test vconcat with same color field + vconcat_data = pd.DataFrame({ + 'x': list(range(10)) * 3, + 'y': list(range(10)) + list(range(5, 15)) + list(range(10, 20)), + 'category': ['A'] * 10 + ['B'] * 10 + ['C'] * 10 + }) + + vconcat_chart = alt.vconcat( + alt.Chart(vconcat_data).mark_point().encode( + x='x:Q', + y='y:Q', + color='category:N' + ), + alt.Chart(vconcat_data).mark_bar().encode( + x='category:N', + y='mean(y):Q', + color='category:N' + ) + ) + + vconcat_mo_chart = mo.ui.altair_chart(vconcat_chart) + vconcat_mo_chart + return (vconcat_mo_chart,) + + +@app.cell +def _(vconcat_mo_chart): + vconcat_mo_chart.value + return + + +@app.cell(hide_code=True) +def _(mo): + mo.md(r"""### hconcat with legend selection""") + return + + +@app.cell +def _(alt, mo, pd): + # Test hconcat with same color field + hconcat_data = pd.DataFrame({ + 'x': list(range(10)) * 3, + 'y': list(range(10)) + list(range(5, 15)) + list(range(10, 20)), + 'series': ['X'] * 10 + ['Y'] * 10 + ['Z'] * 10 + }) + + hconcat_chart = alt.hconcat( + alt.Chart(hconcat_data).mark_line().encode( + x='x:Q', + y='y:Q', + color='series:N' + ), + alt.Chart(hconcat_data).mark_point(size=100).encode( + x='x:Q', + y='y:Q', + color='series:N' + ) + ) + + hconcat_mo_chart = mo.ui.altair_chart(hconcat_chart) + hconcat_mo_chart + return (hconcat_mo_chart,) + + +@app.cell +def _(hconcat_mo_chart): + hconcat_mo_chart.value + return + + +if __name__ == "__main__": + app.run()