diff --git a/frontend/src/plugins/impl/vega/__tests__/__snapshots__/make-selectable.test.ts.snap b/frontend/src/plugins/impl/vega/__tests__/__snapshots__/make-selectable.test.ts.snap index 3d648dc5eb1..729fab95e8a 100644 --- a/frontend/src/plugins/impl/vega/__tests__/__snapshots__/make-selectable.test.ts.snap +++ b/frontend/src/plugins/impl/vega/__tests__/__snapshots__/make-selectable.test.ts.snap @@ -1,5 +1,268 @@ // Vitest Snapshot v1, https://vitest.dev/guide/snapshot.html +exports[`makeSelectable > should add bin_coloring param for 2D binned histogram 1`] = ` +{ + "encoding": { + "color": { + "aggregate": "count", + "type": "quantitative", + }, + "opacity": { + "condition": { + "test": { + "and": [ + { + "param": "bin_coloring", + }, + ], + }, + "value": 1, + }, + "value": 0.2, + }, + "x": { + "bin": true, + "field": "x", + "type": "quantitative", + }, + "y": { + "bin": true, + "field": "y", + "type": "quantitative", + }, + }, + "mark": { + "cursor": "pointer", + "tooltip": true, + "type": "rect", + }, + "params": [ + { + "name": "select_point", + "select": { + "encodings": [ + "x", + "y", + ], + "on": "click[!event.metaKey]", + "type": "point", + }, + }, + { + "name": "bin_coloring", + "select": { + "on": "click[!event.metaKey]", + "type": "point", + }, + }, + { + "bind": "scales", + "name": "pan_zoom", + "select": { + "on": "[mousedown[event.metaKey], window:mouseup] > window:mousemove!", + "translate": "[mousedown[event.metaKey], window:mouseup] > window:mousemove!", + "type": "interval", + "zoom": "wheel![event.metaKey]", + }, + }, + ], +} +`; + +exports[`makeSelectable > should add bin_coloring param for binned charts 1`] = ` +{ + "encoding": { + "opacity": { + "condition": { + "test": { + "and": [ + { + "param": "bin_coloring", + }, + ], + }, + "value": 1, + }, + "value": 0.2, + }, + "x": { + "bin": true, + "field": "x", + "type": "quantitative", + }, + "y": { + "aggregate": "count", + "type": "quantitative", + }, + }, + "mark": { + "cursor": "pointer", + "tooltip": true, + "type": "bar", + }, + "params": [ + { + "name": "select_point", + "select": { + "encodings": [ + "x", + ], + "on": "click[!event.metaKey]", + "type": "point", + }, + }, + { + "name": "bin_coloring", + "select": { + "on": "click[!event.metaKey]", + "type": "point", + }, + }, + { + "bind": "scales", + "name": "pan_zoom", + "select": { + "on": "[mousedown[event.metaKey], window:mouseup] > window:mousemove!", + "translate": "[mousedown[event.metaKey], window:mouseup] > window:mousemove!", + "type": "interval", + "zoom": "wheel![event.metaKey]", + }, + }, + ], +} +`; + +exports[`makeSelectable > should add bin_coloring param for layered binned charts 1`] = ` +{ + "layer": [ + { + "encoding": { + "opacity": { + "condition": { + "test": { + "and": [ + { + "param": "bin_coloring_0", + }, + ], + }, + "value": 1, + }, + "value": 0.2, + }, + "x": { + "bin": true, + "field": "x", + "type": "quantitative", + }, + "y": { + "aggregate": "count", + "type": "quantitative", + }, + }, + "mark": { + "cursor": "pointer", + "tooltip": true, + "type": "bar", + }, + "params": [ + { + "name": "select_point_0", + "select": { + "encodings": [ + "x", + ], + "on": "click[!event.metaKey]", + "type": "point", + }, + }, + { + "name": "bin_coloring_0", + "select": { + "on": "click[!event.metaKey]", + "type": "point", + }, + }, + { + "bind": "scales", + "name": "pan_zoom", + "select": { + "on": "[mousedown[event.metaKey], window:mouseup] > window:mousemove!", + "translate": "[mousedown[event.metaKey], window:mouseup] > window:mousemove!", + "type": "interval", + "zoom": "wheel![event.metaKey]", + }, + }, + ], + }, + { + "encoding": { + "color": { + "value": "red", + }, + "opacity": { + "condition": { + "test": { + "and": [ + { + "param": "select_point_1", + }, + { + "param": "select_interval_1", + }, + ], + }, + "value": 1, + }, + "value": 0.2, + }, + "x": { + "aggregate": "mean", + "field": "x", + "type": "quantitative", + }, + }, + "mark": { + "cursor": "pointer", + "tooltip": true, + "type": "rule", + }, + "params": [ + { + "name": "select_point_1", + "select": { + "encodings": [ + "x", + "y", + ], + "on": "click[!event.metaKey]", + "type": "point", + }, + }, + { + "name": "select_interval_1", + "select": { + "encodings": [ + "x", + "y", + ], + "mark": { + "fill": "#669EFF", + "fillOpacity": 0.07, + "stroke": "#669EFF", + "strokeOpacity": 0.4, + }, + "on": "[mousedown[!event.metaKey], mouseup] > mousemove[!event.metaKey]", + "translate": "[mousedown[!event.metaKey], mouseup] > mousemove[!event.metaKey]", + "type": "interval", + }, + }, + ], + }, + ], +} +`; + exports[`makeSelectable > should add legend selection to composite charts (issue #6676) 1`] = ` { "layer": [ diff --git a/frontend/src/plugins/impl/vega/__tests__/encodings.test.ts b/frontend/src/plugins/impl/vega/__tests__/encodings.test.ts index c13b4a52a2e..33f09a37ed9 100644 --- a/frontend/src/plugins/impl/vega/__tests__/encodings.test.ts +++ b/frontend/src/plugins/impl/vega/__tests__/encodings.test.ts @@ -169,4 +169,97 @@ describe("makeEncodingInteractive", () => { ), ).toEqual(expected); }); + + it("should use only bin_coloring param when present", () => { + const encodings: Encodings = { + color: { + field: "someField", + type: "quantitative", + }, + }; + + const expected = { + ...encodings, + opacity: { + condition: { + test: { + and: [{ param: "bin_coloring" }], + }, + value: 1, + }, + value: 0.2, + }, + }; + + expect( + makeEncodingInteractive( + "opacity", + encodings, + ["select_point", "bin_coloring"], + "point", + ), + ).toEqual(expected); + }); + + it("should use only bin_coloring params when multiple are present", () => { + const encodings: Encodings = { + color: { + field: "someField", + type: "quantitative", + }, + }; + + const expected = { + ...encodings, + opacity: { + condition: { + test: { + and: [{ param: "bin_coloring_0" }, { param: "bin_coloring_1" }], + }, + value: 1, + }, + value: 0.2, + }, + }; + + expect( + makeEncodingInteractive( + "opacity", + encodings, + ["select_point", "bin_coloring_0", "select_interval", "bin_coloring_1"], + "point", + ), + ).toEqual(expected); + }); + + it("should fall back to all params when no bin_coloring params present", () => { + const encodings: Encodings = { + color: { + field: "someField", + type: "quantitative", + }, + }; + + const expected = { + ...encodings, + opacity: { + condition: { + test: { + and: [{ param: "select_point" }, { param: "select_interval" }], + }, + value: 1, + }, + value: 0.2, + }, + }; + + expect( + makeEncodingInteractive( + "opacity", + encodings, + ["select_point", "select_interval"], + "point", + ), + ).toEqual(expected); + }); }); diff --git a/frontend/src/plugins/impl/vega/__tests__/make-selectable.test.ts b/frontend/src/plugins/impl/vega/__tests__/make-selectable.test.ts index 6cefe701989..3d439aa8530 100644 --- a/frontend/src/plugins/impl/vega/__tests__/make-selectable.test.ts +++ b/frontend/src/plugins/impl/vega/__tests__/make-selectable.test.ts @@ -651,6 +651,109 @@ describe("makeSelectable", () => { expect(legendParams[0]).toBe("legend_selection_category"); }); + it("should add bin_coloring param for binned charts", () => { + const spec = { + mark: "bar", + encoding: { + x: { field: "x", bin: true, type: "quantitative" }, + y: { aggregate: "count", type: "quantitative" }, + }, + } as VegaLiteSpec; + + const newSpec = makeSelectable(spec, { chartSelection: true }); + expect(newSpec).toMatchSnapshot(); + const paramNames = getSelectionParamNames(newSpec); + + // Should have point selection and bin_coloring param + expect(paramNames).toContain("select_point"); + expect(paramNames).toContain("bin_coloring"); + // Should NOT have interval selection for binned charts + expect(paramNames).not.toContain("select_interval"); + }); + + it("should add bin_coloring param for 2D binned histogram", () => { + const spec = { + mark: "rect", + encoding: { + x: { field: "x", bin: true, type: "quantitative" }, + y: { field: "y", bin: true, type: "quantitative" }, + color: { aggregate: "count", type: "quantitative" }, + }, + } as VegaLiteSpec; + + const newSpec = makeSelectable(spec, { chartSelection: true }); + expect(newSpec).toMatchSnapshot(); + const paramNames = getSelectionParamNames(newSpec); + + // Should have point selection and bin_coloring param + expect(paramNames).toContain("select_point"); + expect(paramNames).toContain("bin_coloring"); + }); + + it("should add bin_coloring param for layered binned charts", () => { + const spec = { + layer: [ + { + mark: "bar", + encoding: { + x: { field: "x", bin: true, type: "quantitative" }, + y: { aggregate: "count", type: "quantitative" }, + }, + }, + { + mark: "rule", + encoding: { + x: { aggregate: "mean", field: "x", type: "quantitative" }, + color: { value: "red" }, + }, + }, + ], + } as VegaLiteSpec; + + const newSpec = makeSelectable(spec, { chartSelection: true }); + expect(newSpec).toMatchSnapshot(); + const paramNames = getSelectionParamNames(newSpec); + + // First layer should have point selection and bin_coloring + expect(paramNames).toContain("select_point_0"); + expect(paramNames).toContain("bin_coloring_0"); + }); + + it("should prefer point selection for binned charts even when chartSelection is true", () => { + const spec = { + mark: "bar", + encoding: { + x: { field: "x", bin: { maxbins: 20 }, type: "quantitative" }, + y: { aggregate: "count", type: "quantitative" }, + }, + } as VegaLiteSpec; + + const newSpec = makeSelectable(spec, { chartSelection: true }); + const paramNames = getSelectionParamNames(newSpec); + + // Should only have point selection for binned charts (not interval) + expect(paramNames).toContain("select_point"); + expect(paramNames).not.toContain("select_interval"); + expect(paramNames).toContain("bin_coloring"); + }); + + it("should not add bin_coloring when chartSelection is false", () => { + const spec = { + mark: "bar", + encoding: { + x: { field: "x", bin: true, type: "quantitative" }, + y: { aggregate: "count", type: "quantitative" }, + }, + } as VegaLiteSpec; + + const newSpec = makeSelectable(spec, { chartSelection: false }); + const paramNames = getSelectionParamNames(newSpec); + + // Should not have any chart selection params + expect(paramNames).not.toContain("select_point"); + expect(paramNames).not.toContain("bin_coloring"); + }); + it("should collect legend fields from multiple layers with different fields", () => { const spec = { layer: [ diff --git a/frontend/src/plugins/impl/vega/__tests__/params.test.ts b/frontend/src/plugins/impl/vega/__tests__/params.test.ts index 54fba1a828c..592aff92e88 100644 --- a/frontend/src/plugins/impl/vega/__tests__/params.test.ts +++ b/frontend/src/plugins/impl/vega/__tests__/params.test.ts @@ -1,6 +1,6 @@ /* Copyright 2024 Marimo. All rights reserved. */ import { describe, expect, it } from "vitest"; -import { getDirectionOfBar } from "../params"; +import { getBinnedFields, getDirectionOfBar, ParamNames } from "../params"; import type { VegaLiteUnitSpec } from "../types"; describe("getDirectionOfBar", () => { @@ -50,3 +50,123 @@ describe("getDirectionOfBar", () => { expect(getDirectionOfBar(spec)).toBeUndefined(); }); }); + +describe("getBinnedFields", () => { + it("should return empty array if spec has no encoding", () => { + const spec = { mark: "point" } as VegaLiteUnitSpec; + expect(getBinnedFields(spec)).toEqual([]); + }); + + it("should return empty array if no fields are binned", () => { + const spec = { + mark: "point", + encoding: { + x: { field: "x", type: "quantitative" }, + y: { field: "y", type: "quantitative" }, + }, + } as VegaLiteUnitSpec; + expect(getBinnedFields(spec)).toEqual([]); + }); + + it("should return binned field name for x encoding", () => { + const spec = { + mark: "bar", + encoding: { + x: { field: "x", bin: true, type: "quantitative" }, + y: { aggregate: "count", type: "quantitative" }, + }, + } as VegaLiteUnitSpec; + expect(getBinnedFields(spec)).toEqual(["x"]); + }); + + it("should return binned field name for y encoding", () => { + const spec = { + mark: "bar", + encoding: { + x: { aggregate: "count", type: "quantitative" }, + y: { field: "y", bin: true, type: "quantitative" }, + }, + } as VegaLiteUnitSpec; + expect(getBinnedFields(spec)).toEqual(["y"]); + }); + + it("should return multiple binned fields", () => { + const spec = { + mark: "rect", + encoding: { + x: { field: "x", bin: true, type: "quantitative" }, + y: { field: "y", bin: true, type: "quantitative" }, + color: { aggregate: "count", type: "quantitative" }, + }, + } as VegaLiteUnitSpec; + expect(getBinnedFields(spec)).toEqual(["x", "y"]); + }); + + it("should handle bin with custom configuration", () => { + const spec = { + mark: "bar", + encoding: { + x: { + field: "temperature", + bin: { maxbins: 20 }, + type: "quantitative", + }, + y: { aggregate: "count", type: "quantitative" }, + }, + } as VegaLiteUnitSpec; + expect(getBinnedFields(spec)).toEqual(["temperature"]); + }); + + it("should ignore encodings without field property", () => { + const spec = { + mark: "bar", + encoding: { + x: { bin: true, type: "quantitative" }, + y: { field: "y", type: "quantitative" }, + }, + } as VegaLiteUnitSpec; + expect(getBinnedFields(spec)).toEqual([]); + }); +}); + +describe("ParamNames", () => { + describe("binColoring", () => { + it("should generate bin_coloring for undefined layer", () => { + expect(ParamNames.binColoring(undefined)).toBe("bin_coloring"); + }); + + it("should generate bin_coloring_0 for layer 0", () => { + expect(ParamNames.binColoring(0)).toBe("bin_coloring_0"); + }); + + it("should generate bin_coloring_1 for layer 1", () => { + expect(ParamNames.binColoring(1)).toBe("bin_coloring_1"); + }); + }); + + describe("isBinColoring", () => { + it("should return true for bin_coloring", () => { + expect(ParamNames.isBinColoring("bin_coloring")).toBe(true); + }); + + it("should return true for bin_coloring_0", () => { + expect(ParamNames.isBinColoring("bin_coloring_0")).toBe(true); + }); + + it("should return true for bin_coloring_123", () => { + expect(ParamNames.isBinColoring("bin_coloring_123")).toBe(true); + }); + + it("should return false for select_point", () => { + expect(ParamNames.isBinColoring("select_point")).toBe(false); + }); + + it("should return false for select_interval", () => { + expect(ParamNames.isBinColoring("select_interval")).toBe(false); + }); + + it("should return false for legend_selection_field", () => { + expect(ParamNames.isBinColoring("legend_selection_field")).toBe(false); + }); + }); +}); diff --git a/frontend/src/plugins/impl/vega/encodings.ts b/frontend/src/plugins/impl/vega/encodings.ts index a3b6cc4604f..419109d2f6f 100644 --- a/frontend/src/plugins/impl/vega/encodings.ts +++ b/frontend/src/plugins/impl/vega/encodings.ts @@ -1,5 +1,6 @@ /* Copyright 2024 Marimo. All rights reserved. */ import { Marks } from "./marks"; +import { ParamNames } from "./params"; import type { AnyMark, Encodings, @@ -84,8 +85,21 @@ export function makeEncodingInteractive( paramNames: string[], mark: AnyMark | undefined, ): SharedCompositeEncoding { + // For binned charts, we only use the bin_coloring param for opacity. + // The regular selection params (point/interval) are used for backend filtering only. + // This separation allows us to control which params trigger visual feedback vs data filtering. + // NOTE: Bin + interval selection does not change the opacity at all. + const opacityParams = paramNames.filter((paramName) => + ParamNames.isBinColoring(paramName), + ); + + // If we have bin_coloring params, use only those for opacity. + // Otherwise, use all params (non-binned chart behavior). + const relevantParams = opacityParams.length > 0 ? opacityParams : paramNames; + + // Use AND so all conditions must be met const test = { - and: paramNames.map((paramName) => ({ + and: relevantParams.map((paramName) => ({ param: paramName, })), }; diff --git a/frontend/src/plugins/impl/vega/make-selectable.ts b/frontend/src/plugins/impl/vega/make-selectable.ts index bc39bc9b060..5e2a5129fca 100644 --- a/frontend/src/plugins/impl/vega/make-selectable.ts +++ b/frontend/src/plugins/impl/vega/make-selectable.ts @@ -2,7 +2,7 @@ import { findEncodedFields, makeEncodingInteractive } from "./encodings"; import { Marks } from "./marks"; -import { Params } from "./params"; +import { getBinnedFields, Params } from "./params"; import type { GenericVegaSpec, Mark, @@ -193,10 +193,19 @@ function makeChartSelectable( return spec; } - const resolvedChartSelection = - chartSelection === true ? getBestSelectionForMark(mark) : [chartSelection]; + const binnedFields = getBinnedFields(spec); - if (!resolvedChartSelection) { + // If chartSelection is true, we use the best selection for based on the spec + // For binned charts, we use point selection + // Otherwise, we use the best selection for the mark + const resolvedChartSelection: SelectionType[] | undefined = + chartSelection === true + ? binnedFields.length > 0 + ? ["point"] + : getBestSelectionForMark(mark) + : [chartSelection]; + + if (!resolvedChartSelection || resolvedChartSelection.length === 0) { return spec; } @@ -208,6 +217,16 @@ function makeChartSelectable( const nextParams = [...(spec.params || []), ...params]; + // For binned charts, we need TWO params: + // 1. The regular selection param (point/interval) - sends signals to backend for filtering + // 2. The bin_coloring param - controls opacity/coloring, NO signal listener + // This separation allows us to filter on binned ranges while providing visual feedback + if (binnedFields.length > 0) { + if (resolvedChartSelection.includes("point")) { + nextParams.push(Params.binColoring(layerNum)); + } + } + return { ...spec, params: nextParams, diff --git a/frontend/src/plugins/impl/vega/params.ts b/frontend/src/plugins/impl/vega/params.ts index 80c4845eec5..0ea704b31a5 100644 --- a/frontend/src/plugins/impl/vega/params.ts +++ b/frontend/src/plugins/impl/vega/params.ts @@ -22,6 +22,14 @@ export const ParamNames = { legendSelection(field: string) { return `legend_selection_${field}`; }, + /** + * Special param for binned charts that controls opacity/coloring. + * This param is used for visual feedback only and does NOT send signals to the backend. + * The actual selection param (point/interval) handles backend filtering. + */ + binColoring(layerNum: number | undefined) { + return layerNum == null ? "bin_coloring" : `bin_coloring_${layerNum}`; + }, HIGHLIGHT: "highlight", PAN_ZOOM: "pan_zoom", hasPoint(names: string[]) { @@ -36,6 +44,9 @@ export const ParamNames = { hasPanZoom(names: string[]) { return names.some((name) => name.startsWith("pan_zoom")); }, + isBinColoring(name: string) { + return name.startsWith("bin_coloring"); + }, }; export const Params = { @@ -80,6 +91,20 @@ export const Params = { }, }; }, + /** + * Creates a param for binned charts that controls opacity/coloring only. + * This param does NOT send signals to the backend - it's purely for visual feedback. + * The regular selection param (point/interval) handles backend filtering. + */ + binColoring(layerNum: number | undefined): SelectionParameter<"point"> { + return { + name: ParamNames.binColoring(layerNum), + select: { + type: "point", + on: "click[!event.metaKey]", + }, + }; + }, legend(field: string): SelectionParameter<"point"> { return { name: ParamNames.legendSelection(field), @@ -198,3 +223,31 @@ export function getDirectionOfBar( return undefined; } + +/** + * Returns the binned field names from the spec. + * For binned charts, selections need to use fields instead of encodings. + */ +export function getBinnedFields(spec: VegaLiteUnitSpec): string[] { + if (!spec.encoding) { + return []; + } + + const fields: string[] = []; + + for (const channel of Object.values(spec.encoding)) { + if (channel && typeof channel === "object") { + // Check for binning + if ( + "bin" in channel && + channel.bin && + "field" in channel && + typeof channel.field === "string" + ) { + fields.push(channel.field); + } + } + } + + return fields; +} diff --git a/frontend/src/plugins/impl/vega/vega-component.tsx b/frontend/src/plugins/impl/vega/vega-component.tsx index fa986ff4492..4e887028e14 100644 --- a/frontend/src/plugins/impl/vega/vega-component.tsx +++ b/frontend/src/plugins/impl/vega/vega-component.tsx @@ -145,6 +145,13 @@ const LoadedVegaComponent = ({ return acc; } + // bin_coloring params are used ONLY for opacity/visual feedback. + // They should NOT send signals to the backend for filtering. + // The regular selection params (point/interval) handle backend filtering. + if (ParamNames.isBinColoring(name)) { + return acc; + } + acc.push({ signalName: name, handler: (signalName, signalValue) => diff --git a/marimo/_plugins/ui/_impl/altair_chart.py b/marimo/_plugins/ui/_impl/altair_chart.py index ab9a5e9ddf3..8d04182075b 100644 --- a/marimo/_plugins/ui/_impl/altair_chart.py +++ b/marimo/_plugins/ui/_impl/altair_chart.py @@ -74,6 +74,26 @@ def _has_binning(spec: VegaSpec) -> bool: return False +def _get_binned_fields(spec: VegaSpec) -> dict[str, Any]: + """Return a dictionary of field names that have binning enabled. + + Returns: + dict mapping field name to bin configuration + """ + binned_fields: dict[str, Any] = {} + if "encoding" not in spec: + return binned_fields + + for encoding in spec["encoding"].values(): + if "bin" in encoding and encoding["bin"]: + # Get the field name + field = encoding.get("field") + if field: + binned_fields[field] = encoding["bin"] + + return binned_fields + + def _has_geoshape(spec: altair.TopLevelMixin) -> bool: """Return True if the spec has geoshape.""" try: @@ -100,7 +120,10 @@ def _using_vegafusion() -> bool: def _filter_dataframe( - native_df: Union[IntoDataFrame, IntoLazyFrame], selection: ChartSelection + native_df: Union[IntoDataFrame, IntoLazyFrame], + *, + selection: ChartSelection, + binned_fields: Optional[dict[str, Any]] = None, ) -> Union[IntoDataFrame, IntoLazyFrame]: # Use lazy evaluation for efficient chained filtering base = nw.from_native(native_df) @@ -110,6 +133,9 @@ def _filter_dataframe( if not isinstance(selection, dict): raise TypeError("Input 'selection' must be a dictionary") + if binned_fields is None: + binned_fields = {} + for channel, fields in selection.items(): if not isinstance(channel, str) or not isinstance(fields, dict): raise ValueError( @@ -148,6 +174,9 @@ def _filter_dataframe( if field in ("vlPoint", "_vgsid_"): continue + # If the field is binned, we treat it as a range selection + is_binned = field in binned_fields + # Need to collect schema to check columns and dtypes schema = df.collect_schema() if field not in schema.names(): @@ -155,7 +184,8 @@ def _filter_dataframe( dtype = schema[field] resolved_values = _resolve_values(values, dtype) - if is_point_selection: + + if is_point_selection and not is_binned: df = df.filter(nw.col(field).is_in(resolved_values)) elif len(resolved_values) == 1: df = df.filter(nw.col(field) == resolved_values[0]) @@ -164,10 +194,41 @@ def _filter_dataframe( resolved_values[0] ): left_value, right_value = resolved_values - df = df.filter( - (nw.col(field) >= left_value) - & (nw.col(field) <= right_value) - ) + + # For binned fields, we need to check if this is the last bin + # by comparing the right boundary to the maximum value in the dataset. + # If they're equal (or right boundary >= max), use inclusive right boundary. + if is_binned: + # Get the maximum value in the dataset for this field + max_value_df = df.select(nw.col(field).max()) + max_value_collected = ( + max_value_df.collect() + if is_narwhals_lazyframe(max_value_df) + else max_value_df + ) + max_value = max_value_collected[field][0] + + # If right boundary >= max value, this is the last bin + is_last_bin = right_value >= max_value + + if is_last_bin: + # Last bin: use inclusive right boundary + df = df.filter( + (nw.col(field) >= left_value) + & (nw.col(field) <= right_value) + ) + else: + # Not last bin: use exclusive right boundary + df = df.filter( + (nw.col(field) >= left_value) + & (nw.col(field) < right_value) + ) + else: + # Non-binned fields: use inclusive right boundary + df = df.filter( + (nw.col(field) >= left_value) + & (nw.col(field) <= right_value) + ) # Multi-selection via range # This can happen when you use an interval selection # on categorical data @@ -394,21 +455,14 @@ def __init__( ) legend_selection = False - # Selection for binned charts is not yet implemented has_chart_selection = chart_selection is not False has_legend_selection = legend_selection is not False - if _has_binning(vega_spec) and ( - has_chart_selection or has_legend_selection - ): + if _has_binning(vega_spec) and chart_selection == "interval": sys.stderr.write( - "Binning + selection is not yet supported in " - "marimo.ui.chart.\n" - "If you'd like this feature, please file an issue: " - "https://github.com/marimo-team/marimo/issues\n" + "Binning + interval selection does not highlight the bins. " + "You can use point selection instead." ) - chart_selection = False - legend_selection = False - if _has_geoshape(chart) and (has_chart_selection): + if _has_geoshape(chart) and has_chart_selection: sys.stderr.write( "Geoshapes + chart selection is not yet supported in " "marimo.ui.chart.\n" @@ -433,6 +487,7 @@ def __init__( ) self._spec = vega_spec + self._binned_fields = _get_binned_fields(vega_spec) super().__init__( component_name="marimo-vega", @@ -509,7 +564,9 @@ def _convert_value(self, value: ChartSelection) -> ChartDataType: if _has_transforms(self._spec): try: df: Any = self._chart.transformed_data() - return _filter_dataframe(df, value) + return _filter_dataframe( + df, selection=value, binned_fields=self._binned_fields + ) except ImportError as e: sys.stderr.write( "Failed to filter dataframe that includes a transform. " @@ -517,9 +574,15 @@ def _convert_value(self, value: ChartSelection) -> ChartDataType: + e.msg ) # Fall back to the untransformed dataframe - return _filter_dataframe(self.dataframe, value) + return _filter_dataframe( + self.dataframe, + selection=value, + binned_fields=self._binned_fields, + ) - return _filter_dataframe(self.dataframe, value) + return _filter_dataframe( + self.dataframe, selection=value, binned_fields=self._binned_fields + ) def apply_selection(self, df: ChartDataType) -> ChartDataType: """Apply the selection to a DataFrame. @@ -559,7 +622,9 @@ def apply_selection(self, df: ChartDataType) -> ChartDataType: ``` """ assert assert_can_narwhalify(df) - return _filter_dataframe(df, self.selections) + return _filter_dataframe( + df, selection=self.selections, binned_fields=self._binned_fields + ) # Proxy all of altair's attributes def __getattr__(self, name: str) -> Any: diff --git a/marimo/_smoke_tests/altair_examples/binning.py b/marimo/_smoke_tests/altair_examples/binning.py new file mode 100644 index 00000000000..08a35ab90d6 --- /dev/null +++ b/marimo/_smoke_tests/altair_examples/binning.py @@ -0,0 +1,368 @@ +# /// script +# requires-python = ">=3.11" +# dependencies = [ +# "marimo", +# "pandas", +# "numpy", +# "vega-datasets", +# "altair", +# "pyarrow==22.0.0", +# "polars==1.34.0", +# ] +# /// + +import marimo + +__generated_with = "0.17.2" +app = marimo.App(width="medium") + + +@app.cell(hide_code=True) +def _(): + import marimo as mo + return (mo,) + + +@app.cell(hide_code=True) +def _(mo): + mo.md(r"""# Binning Examples""") + return + + +@app.cell +def _(): + import altair as alt + import numpy as np + import pandas as pd + import polars as pl + import pyarrow + from vega_datasets import data + return alt, data, pl + + +@app.cell +def _(data, pl): + # Load datasets + movies = pl.DataFrame(data.movies().drop(columns=["Title"])) + cars = pl.DataFrame(data.cars()) + return cars, movies + + +@app.cell(hide_code=True) +def _(mo): + mo.md(r"""## Basic Histogram with Auto-binning""") + return + + +@app.cell +def _(alt, mo, movies): + # Simple histogram with automatic binning + basic_histogram = mo.ui.altair_chart( + alt.Chart(movies.head(100)) + .mark_bar() + .encode( + x=alt.X("IMDB_Rating:Q", bin=True, title="IMDB Rating"), + y=alt.Y("count()", title="Count of Movies"), + ) + .properties( + width="container", height=300, title="Distribution of IMDB Ratings" + ), + ) + basic_histogram + return (basic_histogram,) + + +@app.cell +def _(basic_histogram, mo): + mo.vstack([basic_histogram.value]) + return + + +@app.cell +def _(): + # basic_histogram.selections + return + + +@app.cell(hide_code=True) +def _(mo): + mo.md(r"""## Histogram with Custom Bin Parameters""") + return + + +@app.cell +def _(alt, cars, mo): + # Histogram with custom maxbins parameter + custom_bins = mo.ui.altair_chart( + alt.Chart(cars) + .mark_bar() + .encode( + x=alt.X( + "Miles_per_Gallon:Q", + bin=alt.Bin(maxbins=20), + title="Miles per Gallon", + ), + y=alt.Y("count()", title="Count"), + ) + .properties( + width="container", height=300, title="MPG Distribution (20 bins max)" + ) + ) + custom_bins + return (custom_bins,) + + +@app.cell +def _(custom_bins, mo): + mo.vstack([custom_bins.value]) + return + + +@app.cell(hide_code=True) +def _(mo): + mo.md(r"""## Histogram with Fixed Bin Step""") + return + + +@app.cell +def _(alt, cars, mo): + # Histogram with fixed step size + step_bins = mo.ui.altair_chart( + alt.Chart(cars) + .mark_bar() + .encode( + x=alt.X( + "Horsepower:Q", + bin=alt.Bin(step=25), + title="Horsepower", + ), + y=alt.Y("count()", title="Count"), + ) + .properties(width="container", height=300, title="Horsepower (bins of 25)") + ) + step_bins + return (step_bins,) + + +@app.cell +def _(mo, step_bins): + mo.vstack([step_bins.value]) + return + + +@app.cell(hide_code=True) +def _(mo): + mo.md(r"""## 2D Binning (Heatmap)""") + return + + +@app.cell +def _(alt, mo, movies): + # 2D histogram using rect mark + heatmap = mo.ui.altair_chart( + alt.Chart(movies) + .mark_rect() + .encode( + x=alt.X("IMDB_Rating:Q", bin=alt.Bin(maxbins=10), title="IMDB Rating"), + y=alt.Y( + "Rotten_Tomatoes_Rating:Q", + bin=alt.Bin(maxbins=10), + title="Rotten Tomatoes Rating", + ), + color=alt.Color("count()", scale=alt.Scale(scheme="viridis")), + ) + .properties( + width="container", + height=300, + title="2D Distribution of Movie Ratings", + ) + ) + heatmap + return (heatmap,) + + +@app.cell +def _(heatmap, mo): + mo.vstack([heatmap.value]) + return + + +@app.cell(hide_code=True) +def _(mo): + mo.md(r"""## Binned Color Encoding""") + return + + +@app.cell +def _(alt, cars, mo): + # Scatter plot with binned color encoding + binned_color = mo.ui.altair_chart( + alt.Chart(cars) + .mark_circle(size=60) + .encode( + x=alt.X("Horsepower:Q", title="Horsepower"), + y=alt.Y("Miles_per_Gallon:Q", title="MPG"), + color=alt.Color( + "Acceleration:Q", + bin=alt.Bin(maxbins=5), + title="Acceleration (binned)", + ), + ) + .properties( + width="container", + height=400, + title="Car Performance with Binned Acceleration", + ) + ) + binned_color + return (binned_color,) + + +@app.cell +def _(binned_color, mo): + mo.vstack([binned_color.value]) + return + + +@app.cell(hide_code=True) +def _(mo): + mo.md(r"""## Binned Size Encoding""") + return + + +@app.cell +def _(alt, cars, mo): + # Scatter plot with binned size encoding + binned_size = mo.ui.altair_chart( + alt.Chart(cars) + .mark_point() + .encode( + x=alt.X("Horsepower:Q", title="Horsepower"), + y=alt.Y("Miles_per_Gallon:Q", title="MPG"), + size=alt.Size( + "Displacement:Q", + bin=alt.Bin(maxbins=4), + title="Displacement (binned)", + ), + color=alt.Color("Origin:N"), + ) + .properties( + width="container", + height=400, + title="Car Performance with Binned Displacement", + ) + ) + binned_size + return (binned_size,) + + +@app.cell +def _(binned_size, mo): + mo.vstack([binned_size.value]) + return + + +@app.cell(hide_code=True) +def _(mo): + mo.md(r"""## Stacked Histogram with Binning""") + return + + +@app.cell +def _(alt, cars, mo): + # Stacked histogram showing distribution by origin + stacked_histogram = mo.ui.altair_chart( + alt.Chart(cars) + .mark_bar() + .encode( + x=alt.X("Horsepower:Q", bin=alt.Bin(maxbins=15), title="Horsepower"), + y=alt.Y("count()", title="Count"), + color=alt.Color("Origin:N", title="Origin"), + ) + .properties( + width="container", + height=300, + title="Horsepower Distribution by Origin", + ) + ) + stacked_histogram + return (stacked_histogram,) + + +@app.cell +def _(mo, stacked_histogram): + mo.vstack([stacked_histogram.value]) + return + + +@app.cell(hide_code=True) +def _(mo): + mo.md(r"""## Binning with Selection""") + return + + +@app.cell +def _(alt, mo, movies): + # Interactive histogram with selection + binned_selection = mo.ui.altair_chart( + alt.Chart(movies) + .mark_bar() + .encode( + x=alt.X("IMDB_Rating:Q", bin=True, title="IMDB Rating"), + y=alt.Y("count()", title="Count"), + ) + .properties( + width="container", + height=300, + title="IMDB Rating Distribution (Interactive)", + ), + chart_selection="interval", + ) + binned_selection + return (binned_selection,) + + +@app.cell +def _(binned_selection, mo): + mo.vstack([binned_selection.value, binned_selection.selections]) + return + + +@app.cell(hide_code=True) +def _(mo): + mo.md(r"""## Custom Bin Extent""") + return + + +@app.cell +def _(alt, cars, mo): + # Histogram with custom extent (min and max range) + extent_bins = mo.ui.altair_chart( + alt.Chart(cars) + .mark_bar() + .encode( + x=alt.X( + "Miles_per_Gallon:Q", + bin=alt.Bin(extent=[10, 50], step=5), + title="Miles per Gallon", + ), + y=alt.Y("count()", title="Count"), + ) + .properties( + width="container", + height=300, + title="MPG Distribution (10-50 range, step of 5)", + ) + ) + extent_bins + return (extent_bins,) + + +@app.cell +def _(extent_bins, mo): + mo.vstack([extent_bins.value]) + return + + +if __name__ == "__main__": + app.run() diff --git a/tests/_plugins/ui/_impl/test_altair_chart.py b/tests/_plugins/ui/_impl/test_altair_chart.py index e8cf20242c7..ba817c32e06 100644 --- a/tests/_plugins/ui/_impl/test_altair_chart.py +++ b/tests/_plugins/ui/_impl/test_altair_chart.py @@ -17,6 +17,7 @@ ChartDataType, ChartSelection, _filter_dataframe, + _get_binned_fields, _has_binning, _has_geoshape, _has_legend_param, @@ -87,7 +88,7 @@ def test_filter_dataframe(df: ChartDataType) -> None: "signal_channel_1": {"vlPoint": [1], "field": ["value1", "value2"]} } # Filter the DataFrame with the point selection - assert get_len(_filter_dataframe(df, point_selection)) == 2 + assert get_len(_filter_dataframe(df, selection=point_selection)) == 2 # Point selected with a no fields point_selection = { @@ -97,7 +98,7 @@ def test_filter_dataframe(df: ChartDataType) -> None: }, } # Filter the DataFrame with the point selection - filtered_df = _filter_dataframe(df, point_selection) + filtered_df = _filter_dataframe(df, selection=point_selection) assert get_len(filtered_df) == 2 first, second = maybe_collect(filtered_df)["field"] assert str(first) == "value2" @@ -108,7 +109,7 @@ def test_filter_dataframe(df: ChartDataType) -> None: "signal_channel_2": {"field_2": [1, 3]} } # Filter the DataFrame with the interval selection - filtered_df = _filter_dataframe(df, interval_selection) + filtered_df = _filter_dataframe(df, selection=interval_selection) assert get_len(filtered_df) == 3 # Define an interval selection with multiple fields @@ -116,7 +117,7 @@ def test_filter_dataframe(df: ChartDataType) -> None: "signal_channel_1": {"field_2": [1, 3], "field_3": [30, 40]} } # Filter the DataFrame with the multi-field selection - filtered_df = _filter_dataframe(df, multi_field_selection) + filtered_df = _filter_dataframe(df, selection=multi_field_selection) assert get_len(filtered_df) == 1 # Define an interval selection with multiple fields @@ -125,7 +126,9 @@ def test_filter_dataframe(df: ChartDataType) -> None: "signal_channel_2": {"vlPoint": [1], "color": ["red"]}, } # Filter the DataFrame with the multi-field selection - filtered_df = _filter_dataframe(df, interval_and_point_selection) + filtered_df = _filter_dataframe( + df, selection=interval_and_point_selection + ) assert get_len(filtered_df) == 1 @staticmethod @@ -182,7 +185,7 @@ def test_filter_dataframe_with_dates( } } # Filter the DataFrame with the interval selection - filtered_df = _filter_dataframe(df, interval_selection) + filtered_df = _filter_dataframe(df, selection=interval_selection) assert get_len(filtered_df) == 2 first, second = maybe_collect(filtered_df)["field"] assert str(first) == "value1" @@ -197,7 +200,7 @@ def test_filter_dataframe_with_dates( ] } } - filtered_df = _filter_dataframe(df, interval_selection) + filtered_df = _filter_dataframe(df, selection=interval_selection) assert get_len(filtered_df) == 2 first, second = maybe_collect(filtered_df)["field"] assert str(first) == "value1" @@ -218,7 +221,7 @@ def test_filter_dataframe_with_dates( ] } } - filtered_df = _filter_dataframe(df, interval_selection) + filtered_df = _filter_dataframe(df, selection=interval_selection) assert get_len(filtered_df) == 2 first, second = maybe_collect(filtered_df)["field"] assert str(first) == "value1" @@ -235,7 +238,7 @@ def test_filter_dataframe_with_dates( } } # Filter the DataFrame with the interval selection - filtered_df = _filter_dataframe(df, interval_selection) + filtered_df = _filter_dataframe(df, selection=interval_selection) assert get_len(filtered_df) == 2 first, second = maybe_collect(filtered_df)["field"] assert str(first) == "value1" @@ -271,7 +274,7 @@ def test_filter_dataframe_with_datetimes_as_strings( get_len( _filter_dataframe( df, - { + selection={ "select_point": { "datetime_column_utc": [ datetime.datetime( @@ -295,7 +298,7 @@ def test_filter_dataframe_with_datetimes_as_strings( get_len( _filter_dataframe( df, - { + selection={ "select_interval": { "datetime_column_utc": [ datetime.datetime( @@ -327,7 +330,7 @@ def test_filter_dataframe_with_datetimes_as_strings( get_len( _filter_dataframe( df, - { + selection={ "select_interval": { "datetime_column_utc": [ datetime.datetime( @@ -361,7 +364,7 @@ def test_filter_dataframe_with_datetimes_as_strings( get_len( _filter_dataframe( df, - { + selection={ "select_interval": { "datetime_column": [ datetime.datetime(2019, 12, 29).isoformat(), @@ -379,7 +382,7 @@ def test_filter_dataframe_with_datetimes_as_strings( get_len( _filter_dataframe( df, - { + selection={ "select_interval": { "datetime_column": [ datetime.datetime( @@ -430,7 +433,7 @@ def to_milliseconds(seconds: int) -> int: get_len( _filter_dataframe( df, - { + selection={ "select_interval": { "datetime_column_utc": [ 0, @@ -446,7 +449,7 @@ def to_milliseconds(seconds: int) -> int: get_len( _filter_dataframe( df, - { + selection={ "select_interval": { "datetime_column_utc": [ milliseconds_since_epoch, @@ -463,7 +466,7 @@ def to_milliseconds(seconds: int) -> int: get_len( _filter_dataframe( df, - { + selection={ "select_interval": { "datetime_column_utc": [ milliseconds_since_epoch, @@ -482,7 +485,7 @@ def to_milliseconds(seconds: int) -> int: get_len( _filter_dataframe( df, - { + selection={ "select_interval": { "datetime_column": [ 0, @@ -502,7 +505,7 @@ def to_milliseconds(seconds: int) -> int: get_len( _filter_dataframe( df, - { + selection={ "select_interval": { "datetime_column": [ datetime.datetime( @@ -769,6 +772,119 @@ def test_parse_spec_geopandas() -> None: snapshot("parse_spec_geopandas.txt", json.dumps(spec, indent=2)) +@pytest.mark.skipif(not HAS_DEPS, reason="optional dependencies not installed") +def test_get_binned_fields() -> None: + """Test _get_binned_fields detection for various binning configurations.""" + import altair as alt + + # Case 1: No binning - should return empty dict + spec_no_binning = _parse_spec( + alt.Chart(pd.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]})) + .mark_point() + .encode(x="x", y="y") + ) + binned_fields = _get_binned_fields(spec_no_binning) + assert binned_fields == {} + + # Case 2: Single field with bin=True + spec_bin_true = _parse_spec( + alt.Chart(pd.DataFrame({"values": range(100)})) + .mark_bar() + .encode(x=alt.X("values", bin=True), y="count()") + ) + binned_fields = _get_binned_fields(spec_bin_true) + assert "values" in binned_fields + assert binned_fields["values"] is True + + # Case 3: Single field with bin configuration + spec_bin_config = _parse_spec( + alt.Chart(pd.DataFrame({"values": range(100)})) + .mark_bar() + .encode(x=alt.X("values", bin=alt.Bin(maxbins=20)), y="count()") + ) + binned_fields = _get_binned_fields(spec_bin_config) + assert "values" in binned_fields + assert isinstance(binned_fields["values"], dict) + assert binned_fields["values"]["maxbins"] == 20 + + # Case 4: Bin configuration with step + spec_bin_step = _parse_spec( + alt.Chart(pd.DataFrame({"values": range(100)})) + .mark_bar() + .encode(x=alt.X("values", bin=alt.Bin(step=10)), y="count()") + ) + binned_fields = _get_binned_fields(spec_bin_step) + assert "values" in binned_fields + assert isinstance(binned_fields["values"], dict) + assert binned_fields["values"]["step"] == 10 + + # Case 5: Multiple binned fields (2D histogram) + spec_multiple_bins = _parse_spec( + alt.Chart(pd.DataFrame({"x": range(100), "y": range(100)})) + .mark_rect() + .encode( + x=alt.X("x", bin=True), + y=alt.Y("y", bin=alt.Bin(maxbins=15)), + color="count()", + ) + ) + binned_fields = _get_binned_fields(spec_multiple_bins) + assert "x" in binned_fields + assert "y" in binned_fields + assert binned_fields["x"] is True + assert isinstance(binned_fields["y"], dict) + assert binned_fields["y"]["maxbins"] == 15 + + # Case 6: Mix of binned and non-binned fields + spec_mixed = _parse_spec( + alt.Chart( + pd.DataFrame( + { + "x": range(100), + "y": range(100), + "color": ["A"] * 50 + ["B"] * 50, + } + ) + ) + .mark_bar() + .encode( + x=alt.X("x", bin=True), + y="count()", + color="color:N", # Not binned + ) + ) + binned_fields = _get_binned_fields(spec_mixed) + assert "x" in binned_fields + assert "color" not in binned_fields + assert binned_fields["x"] is True + + # Case 7: Binned field on y-axis + spec_y_binned = _parse_spec( + alt.Chart(pd.DataFrame({"values": range(100)})) + .mark_bar() + .encode(x="count()", y=alt.Y("values", bin=True)) + ) + binned_fields = _get_binned_fields(spec_y_binned) + assert "values" in binned_fields + assert binned_fields["values"] is True + + # Case 8: Spec with no encoding (should not error) + spec_no_encoding = {"mark": "point"} + binned_fields = _get_binned_fields(spec_no_encoding) + assert binned_fields == {} + + # Case 9: Bin with extent + spec_bin_extent = _parse_spec( + alt.Chart(pd.DataFrame({"values": range(100)})) + .mark_bar() + .encode(x=alt.X("values", bin=alt.Bin(extent=[0, 50])), y="count()") + ) + binned_fields = _get_binned_fields(spec_bin_extent) + assert "values" in binned_fields + assert isinstance(binned_fields["values"], dict) + assert binned_fields["values"]["extent"] == [0, 50] + + @pytest.mark.skipif(not HAS_DEPS, reason="optional dependencies not installed") def test_has_geoshape() -> None: import altair as alt @@ -871,8 +987,8 @@ def test_chart_with_binning(df: IntoDataFrame): marimo_chart = altair_chart(chart) assert _has_binning(marimo_chart._spec) - # Test that selection is disabled for binned charts - assert marimo_chart._component_args["chart-selection"] is False + # Test that selection is now enabled for binned charts + assert marimo_chart._component_args["chart-selection"] is not False @pytest.mark.skipif(not HAS_DEPS, reason="optional dependencies not installed") @@ -900,6 +1016,250 @@ def test_apply_selection(df: IntoDataFrame): assert all(maybe_collect(filtered_data)["category"] == "A") +@pytest.mark.skipif(not HAS_DEPS, reason="optional dependencies not installed") +@pytest.mark.parametrize( + "df", + create_dataframes( + { + "values": [10, 15, 20, 25, 30, 35, 40, 45], + "category": ["A", "A", "B", "B", "C", "C", "D", "D"], + }, + ), +) +def test_filter_dataframe_with_binned_fields(df: ChartDataType) -> None: + """Test filtering with binned fields using interval selection.""" + # Define binned fields (simulating what would come from _get_binned_fields) + binned_fields = {"values": True} + + # Interval selection on a binned field - selecting bins from 20 to 30 + # This should include values where 20 <= values < 30 + interval_selection: ChartSelection = { + "signal_channel": {"values": [20, 30]} + } + filtered_df = _filter_dataframe( + df, selection=interval_selection, binned_fields=binned_fields + ) + assert get_len(filtered_df) == 2 + collected = maybe_collect(filtered_df) + assert all(collected["values"] >= 20) + assert all(collected["values"] < 30) + + # Test with wider range (not including max value) + wider_selection: ChartSelection = {"signal_channel": {"values": [10, 40]}} + filtered_df = _filter_dataframe( + df, selection=wider_selection, binned_fields=binned_fields + ) + assert get_len(filtered_df) == 6 + collected = maybe_collect(filtered_df) + assert all(collected["values"] >= 10) + assert all(collected["values"] < 40) + + # Test boundary values - right boundary is not inclusive for non-last bin + boundary_selection: ChartSelection = { + "signal_channel": {"values": [30, 40]} + } + filtered_df = _filter_dataframe( + df, selection=boundary_selection, binned_fields=binned_fields + ) + assert get_len(filtered_df) == 2 + collected = maybe_collect(filtered_df) + assert 30 in collected["values"] + assert 35 in collected["values"] + assert 40 not in collected["values"] + + # Test last bin - right boundary SHOULD be inclusive + # When selecting to the max value (45), it should be included + last_bin_selection: ChartSelection = { + "signal_channel": {"values": [40, 45]} + } + filtered_df = _filter_dataframe( + df, selection=last_bin_selection, binned_fields=binned_fields + ) + assert get_len(filtered_df) == 2 + collected = maybe_collect(filtered_df) + assert 40 in collected["values"] + assert 45 in collected["values"] # Last bin includes right boundary + + +@pytest.mark.skipif(not HAS_DEPS, reason="optional dependencies not installed") +@pytest.mark.parametrize( + "df", + create_dataframes( + { + "values": list(range(0, 100, 10)), + "id": list(range(10)), + }, + ), +) +def test_filter_dataframe_binned_with_multiple_selections( + df: ChartDataType, +) -> None: + """Test filtering with binned fields and multiple selection channels.""" + binned_fields = {"values": True} + + # Multiple selection channels + multi_selection: ChartSelection = { + "signal_channel_1": {"values": [20, 50]}, + "signal_channel_2": {"id": [2, 6]}, + } + filtered_df = _filter_dataframe( + df, selection=multi_selection, binned_fields=binned_fields + ) + # Should have values >= 20 and < 50 AND id >= 2 and < 6 + assert get_len(filtered_df) == 3 + collected = maybe_collect(filtered_df) + assert all(collected["values"] >= 20) + assert all(collected["values"] < 50) + assert all(collected["id"] >= 2) + assert all(collected["id"] < 6) + + +@pytest.mark.skipif(not HAS_DEPS, reason="optional dependencies not installed") +@pytest.mark.parametrize( + "df", + create_dataframes( + { + "timestamp": [ + datetime.datetime(2020, 1, 1), + datetime.datetime(2020, 2, 1), + datetime.datetime(2020, 3, 1), + datetime.datetime(2020, 4, 1), + datetime.datetime(2020, 5, 1), + ], + "value": [10, 20, 30, 40, 50], + }, + ), +) +def test_filter_dataframe_binned_dates(df: ChartDataType) -> None: + """Test filtering with binned date fields.""" + binned_fields = {"timestamp": True} + + # Interval selection on binned date field (not last bin) + # Vega sends milliseconds since epoch + start = int(datetime.datetime(2020, 2, 1).timestamp() * 1000) + end = int(datetime.datetime(2020, 4, 1).timestamp() * 1000) + + interval_selection: ChartSelection = { + "signal_channel": {"timestamp": [start, end]} + } + filtered_df = _filter_dataframe( + df, selection=interval_selection, binned_fields=binned_fields + ) + assert get_len(filtered_df) == 2 + collected = maybe_collect(filtered_df) + timestamps = collected["timestamp"] + # Should include Feb and Mar, but not Apr (right boundary non-inclusive for non-last bin) + assert datetime.datetime(2020, 2, 1) in timestamps + assert datetime.datetime(2020, 3, 1) in timestamps + assert datetime.datetime(2020, 4, 1) not in timestamps + + # Test last bin - should include the right boundary + start_last = int(datetime.datetime(2020, 4, 1).timestamp() * 1000) + end_last = int(datetime.datetime(2020, 5, 1).timestamp() * 1000) + + last_bin_selection: ChartSelection = { + "signal_channel": {"timestamp": [start_last, end_last]} + } + filtered_df = _filter_dataframe( + df, selection=last_bin_selection, binned_fields=binned_fields + ) + assert get_len(filtered_df) == 2 + collected = maybe_collect(filtered_df) + timestamps = collected["timestamp"] + # Last bin should include May (right boundary inclusive) + assert datetime.datetime(2020, 4, 1) in timestamps + assert datetime.datetime(2020, 5, 1) in timestamps + + +@pytest.mark.skipif(not HAS_DEPS, reason="optional dependencies not installed") +@pytest.mark.parametrize( + "df", + create_dataframes( + { + "values": [5, 10, 15, 20, 25, 30], + "category": ["A", "A", "B", "B", "C", "C"], + }, + ), +) +def test_filter_dataframe_binned_with_point_selection( + df: ChartDataType, +) -> None: + """Test that point selection works correctly with binned fields.""" + binned_fields = {"values": True} + + # Point selection should still work even with binned fields + # However, point selections on binned fields should be treated as intervals + point_selection: ChartSelection = { + "signal_channel": { + "vlPoint": [1], + "values": [10, 20], + } + } + filtered_df = _filter_dataframe( + df, selection=point_selection, binned_fields=binned_fields + ) + # With binning, should filter as a range + assert get_len(filtered_df) == 2 + collected = maybe_collect(filtered_df) + assert all(collected["values"] >= 10) + assert all(collected["values"] < 20) + + +@pytest.mark.skipif(not HAS_DEPS, reason="optional dependencies not installed") +@pytest.mark.parametrize( + "df", + create_dataframes({"values": range(100)}, exclude=["lazy-polars"]), +) +def test_chart_binning_end_to_end(df: IntoDataFrame): + """Test binning with selection end-to-end through altair_chart.""" + import altair as alt + + chart = ( + alt.Chart(df) + .mark_bar() + .encode(x=alt.X("values", bin=True), y="count()") + ) + + marimo_chart = altair_chart(chart) + + # Simulate a selection from the frontend (bin from 20 to 30, not last bin) + marimo_chart._chart_selection = {"select_interval": {"values": [20, 30]}} + + # Get filtered data + filtered = marimo_chart._convert_value(marimo_chart._chart_selection) + assert get_len(filtered) == 10 + collected = maybe_collect(filtered) + assert all(collected["values"] >= 20) + assert all(collected["values"] < 30) + + # Test last bin (should include right boundary) + marimo_chart._chart_selection = {"select_interval": {"values": [90, 99]}} + filtered = marimo_chart._convert_value(marimo_chart._chart_selection) + assert get_len(filtered) == 10 + collected = maybe_collect(filtered) + assert all(collected["values"] >= 90) + assert all(collected["values"] <= 99) + assert 99 in collected["values"] # Last bin includes max value + + +@pytest.mark.skipif(not HAS_DEPS, reason="optional dependencies not installed") +def test_filter_dataframe_without_binned_fields() -> None: + """Test that filtering works normally when binned_fields is None.""" + df = pd.DataFrame({"values": [10, 20, 30, 40, 50]}) + + # Without binned_fields (default behavior) + interval_selection: ChartSelection = { + "signal_channel": {"values": [20, 40]} + } + filtered_df = _filter_dataframe(df, selection=interval_selection) + # Without binning flag, should use inclusive right boundary + assert get_len(filtered_df) == 3 + collected = maybe_collect(filtered_df) + assert 20 in collected["values"] + assert 30 in collected["values"] + assert 40 in collected["values"] + + @pytest.mark.skipif(not HAS_DEPS, reason="optional dependencies not installed") def test_value_is_not_available() -> None: import altair as alt