diff --git a/frontend/src/components/data-table/__tests__/__snapshots__/chart-spec-model.test.ts.snap b/frontend/src/components/data-table/__tests__/__snapshots__/chart-spec-model.test.ts.snap index 6d2f56d9454..f397e4f6530 100644 --- a/frontend/src/components/data-table/__tests__/__snapshots__/chart-spec-model.test.ts.snap +++ b/frontend/src/components/data-table/__tests__/__snapshots__/chart-spec-model.test.ts.snap @@ -123,7 +123,7 @@ exports[`ColumnChartSpecModel > file URL handling > should handle arrow data 1`] } `; -exports[`ColumnChartSpecModel > should expect bin values to be used for number and integer columns when feat flag is true 1`] = ` +exports[`ColumnChartSpecModel > should expect bin values to be used for number and integer columns 1`] = ` { "background": "transparent", "config": { @@ -153,13 +153,14 @@ exports[`ColumnChartSpecModel > should expect bin values to be used for number a "layer": [ { "encoding": { - "strokeWidth": { - "condition": { - "empty": false, - "param": "hover", - "value": 0.5, - }, - "value": 0, + "opacity": { + "condition": [ + { + "param": "hover", + "value": 1, + }, + ], + "value": 0.6, }, "x": { "bin": { @@ -181,20 +182,8 @@ exports[`ColumnChartSpecModel > should expect bin values to be used for number a }, "mark": { "color": "#027864", - "stroke": "#027864", - "strokeWidth": 0, "type": "bar", }, - "params": [ - { - "name": "hover", - "select": { - "clear": "mouseout", - "on": "mouseover", - "type": "point", - }, - }, - ], }, { "encoding": { @@ -242,6 +231,17 @@ exports[`ColumnChartSpecModel > should expect bin values to be used for number a "opacity": 0, "type": "bar", }, + "params": [ + { + "name": "hover", + "select": { + "clear": "mouseout", + "nearest": true, + "on": "mouseover", + "type": "point", + }, + }, + ], "transform": [ { "as": "bin_range", @@ -254,7 +254,7 @@ exports[`ColumnChartSpecModel > should expect bin values to be used for number a } `; -exports[`ColumnChartSpecModel > should handle boolean stats when feat flag is true 1`] = ` +exports[`ColumnChartSpecModel > should handle boolean stats 1`] = ` { "background": "transparent", "config": { @@ -378,16 +378,13 @@ exports[`ColumnChartSpecModel > should handle boolean stats when feat flag is tr } `; -exports[`ColumnChartSpecModel > should handle datetime bin values when feat flag is true 1`] = ` +exports[`ColumnChartSpecModel > should handle datetime bin values 1`] = ` { "background": "transparent", "config": { "axis": { "domain": false, }, - "concat": { - "spacing": 0, - }, "view": { "stroke": "transparent", }, @@ -412,177 +409,98 @@ exports[`ColumnChartSpecModel > should handle datetime bin values when feat flag }, ], }, - "hconcat": [ + "height": 100, + "layer": [ { - "height": 30, - "layer": [ - { - "encoding": { - "x": { - "axis": null, - "field": "bin_start", - "type": "nominal", - }, - "y": { - "axis": null, - "field": "count", - "type": "quantitative", - }, - }, - "mark": { - "color": "#cc4e00", - "type": "bar", + "encoding": { + "color": { + "condition": { + "test": "datum['bin_start'] === null && datum['bin_end'] === null", + "value": "#cc4e00", }, + "value": "#027864", }, - { - "encoding": { - "tooltip": [ - { - "field": "count", - "format": ",d", - "title": "nulls", - "type": "quantitative", - }, - ], - "x": { - "axis": null, - "field": "bin_start", - "type": "nominal", - }, - "y": { - "aggregate": "max", - "axis": null, - "type": "quantitative", + "opacity": { + "condition": [ + { + "param": "hover", + "value": 1, }, - }, - "mark": { - "opacity": 0, - "type": "bar", - }, + ], + "value": 0.6, }, - ], - "transform": [ + "x": { + "axis": null, + "field": "bin_start", + "type": "ordinal", + }, + "y": { + "axis": null, + "field": "count", + "type": "quantitative", + }, + }, + "mark": { + "color": "#027864", + "type": "bar", + }, + "params": [ { - "filter": "datum['bin_start'] === null && datum['bin_end'] === null", + "name": "hover", + "select": { + "clear": "mouseout", + "on": "mouseover", + "type": "point", + }, }, ], - "width": 5, }, { - "height": 30, - "layer": [ - { - "encoding": { - "strokeWidth": { - "condition": { - "empty": false, - "param": "hover", - "value": 0.5, - }, - "value": 0, - }, - "x": { - "axis": null, - "bin": { - "binned": true, - "step": 86400000, - }, - "field": "bin_start", - "type": "temporal", - }, - "x2": { - "axis": null, - "field": "bin_end", - "type": "temporal", - }, - "y": { - "axis": null, - "field": "count", - "type": "quantitative", - }, - }, - "mark": { - "color": "#027864", - "stroke": "#027864", - "strokeWidth": 0, - "type": "bar", + "encoding": { + "tooltip": [ + { + "field": "bin_start", + "timeUnit": "yearmonthdate", + "title": "datetime (start)", + "type": "temporal", }, - "params": [ - { - "name": "hover", - "select": { - "clear": "mouseout", - "on": "mouseover", - "type": "point", - }, - }, - ], - }, - { - "encoding": { - "tooltip": [ - { - "field": "bin_start", - "timeUnit": "yearmonthdate", - "title": "datetime (start)", - "type": "temporal", - }, - { - "field": "bin_end", - "timeUnit": "yearmonthdate", - "title": "datetime (end)", - "type": "temporal", - }, - { - "field": "count", - "format": ",d", - "title": "Count", - "type": "quantitative", - }, - ], - "x": { - "axis": null, - "bin": { - "binned": true, - "step": 86400000, - }, - "field": "bin_start", - "type": "temporal", - }, - "x2": { - "axis": null, - "bin": { - "binned": true, - "step": 86400000, - }, - "field": "bin_end", - "type": "temporal", - }, - "y": { - "aggregate": "max", - "axis": null, - "type": "quantitative", - }, + { + "field": "bin_end", + "timeUnit": "yearmonthdate", + "title": "datetime (end)", + "type": "temporal", }, - "mark": { - "opacity": 0, - "type": "bar", + { + "field": "count", + "format": ",d", + "title": "Count", + "type": "quantitative", }, + ], + "x": { + "axis": null, + "field": "bin_start", + "type": "ordinal", }, - ], - "width": 70, + "y": { + "aggregate": "max", + "axis": null, + "type": "quantitative", + }, + }, + "mark": { + "opacity": 0, + "type": "bar", + "width": { + "band": 1.2, + }, + }, }, ], - "height": 100, - "resolve": { - "scale": { - "y": "shared", - }, - }, } `; -exports[`ColumnChartSpecModel > should handle string value counts when feat flag is true 1`] = ` +exports[`ColumnChartSpecModel > should handle string value counts 1`] = ` { "background": "transparent", "config": { @@ -619,18 +537,20 @@ exports[`ColumnChartSpecModel > should handle string value counts when feat flag { "encoding": { "color": { + "condition": { + "test": "datum.value == "None" || datum.value == "null"", + "value": "#cc4e00", + }, + "value": "#027864", + }, + "opacity": { "condition": [ { "param": "hover_bar", - "value": "#027864", - }, - { - "test": "datum.value == "None" || datum.value == "null"", - "value": "#cc4e00", + "value": 1, }, ], - "legend": null, - "value": "#4cbba5", + "value": 0.6, }, "tooltip": [ { @@ -1077,3 +997,372 @@ exports[`ColumnChartSpecModel > snapshot > url data 1`] = ` ], } `; + +exports[`ColumnChartSpecModel > snapshot with legacy data spec > array 1`] = ` +{ + "background": "transparent", + "config": { + "axis": { + "domain": false, + }, + "view": { + "stroke": "transparent", + }, + }, + "data": { + "values": [ + "a", + "b", + "c", + ], + }, + "height": 100, + "layer": [ + { + "encoding": { + "x": { + "bin": true, + "field": "a", + "type": "quantitative", + }, + "y": { + "aggregate": "count", + "axis": null, + "type": "quantitative", + }, + }, + "mark": { + "color": "#027864", + "type": "bar", + }, + }, + { + "encoding": { + "tooltip": [ + { + "bin": true, + "field": "a", + "format": ".2f", + "title": "a", + "type": "quantitative", + }, + { + "aggregate": "count", + "format": ",d", + "title": "Count", + "type": "quantitative", + }, + ], + "x": { + "axis": { + "labelExpr": "(datum.value >= 10000 || datum.value <= -10000) ? format(datum.value, '.2e') : format(datum.value, '.2~f')", + "labelFontSize": 8.5, + "labelOpacity": 0.5, + "title": null, + }, + "bin": true, + "field": "a", + "type": "quantitative", + }, + "y": { + "aggregate": "max", + "axis": null, + "type": "quantitative", + }, + }, + "mark": { + "opacity": 0, + "type": "bar", + }, + }, + ], +} +`; + +exports[`ColumnChartSpecModel > snapshot with legacy data spec > csv data 1`] = ` +{ + "background": "transparent", + "config": { + "axis": { + "domain": false, + }, + "view": { + "stroke": "transparent", + }, + }, + "data": { + "values": [ + { + "a": 1, + "b": 2, + "c": 3, + }, + { + "a": 4, + "b": 5, + "c": 6, + }, + ], + }, + "height": 100, + "layer": [ + { + "encoding": { + "x": { + "bin": true, + "field": "a", + "type": "quantitative", + }, + "y": { + "aggregate": "count", + "axis": null, + "type": "quantitative", + }, + }, + "mark": { + "color": "#027864", + "type": "bar", + }, + }, + { + "encoding": { + "tooltip": [ + { + "bin": true, + "field": "a", + "format": ".2f", + "title": "a", + "type": "quantitative", + }, + { + "aggregate": "count", + "format": ",d", + "title": "Count", + "type": "quantitative", + }, + ], + "x": { + "axis": { + "labelExpr": "(datum.value >= 10000 || datum.value <= -10000) ? format(datum.value, '.2e') : format(datum.value, '.2~f')", + "labelFontSize": 8.5, + "labelOpacity": 0.5, + "title": null, + }, + "bin": true, + "field": "a", + "type": "quantitative", + }, + "y": { + "aggregate": "max", + "axis": null, + "type": "quantitative", + }, + }, + "mark": { + "opacity": 0, + "type": "bar", + }, + }, + ], +} +`; + +exports[`ColumnChartSpecModel > snapshot with legacy data spec > csv string 1`] = ` +{ + "background": "transparent", + "config": { + "axis": { + "domain": false, + }, + "view": { + "stroke": "transparent", + }, + }, + "data": { + "values": [ + { + "a": 1, + "b": 2, + "c": 3, + }, + { + "a": 4, + "b": 5, + "c": 6, + }, + ], + }, + "height": 100, + "layer": [ + { + "encoding": { + "x": { + "bin": true, + "field": "a", + "type": "quantitative", + }, + "y": { + "aggregate": "count", + "axis": null, + "type": "quantitative", + }, + }, + "mark": { + "color": "#027864", + "type": "bar", + }, + }, + { + "encoding": { + "tooltip": [ + { + "bin": true, + "field": "a", + "format": ".2f", + "title": "a", + "type": "quantitative", + }, + { + "aggregate": "count", + "format": ",d", + "title": "Count", + "type": "quantitative", + }, + ], + "x": { + "axis": { + "labelExpr": "(datum.value >= 10000 || datum.value <= -10000) ? format(datum.value, '.2e') : format(datum.value, '.2~f')", + "labelFontSize": 8.5, + "labelOpacity": 0.5, + "title": null, + }, + "bin": true, + "field": "a", + "type": "quantitative", + }, + "y": { + "aggregate": "max", + "axis": null, + "type": "quantitative", + }, + }, + "mark": { + "opacity": 0, + "type": "bar", + }, + }, + ], +} +`; + +exports[`ColumnChartSpecModel > snapshot with legacy data spec > url data 1`] = ` +{ + "background": "transparent", + "config": { + "axis": { + "domain": false, + }, + "view": { + "stroke": "transparent", + }, + }, + "data": { + "values": [], + }, + "height": 100, + "layer": [ + { + "encoding": { + "color": { + "condition": { + "test": "datum["bin_maxbins_10_date_range"] === "null"", + "value": "#cc4e00", + }, + "value": "#027864", + }, + "x": { + "axis": null, + "bin": true, + "field": "date", + "scale": { + "align": 0, + "paddingInner": 0, + "paddingOuter": { + "expr": "length(data('data_0')) == 2 ? 1 : length(data('data_0')) == 3 ? 0.5 : length(data('data_0')) == 4 ? 0 : 0", + }, + }, + "type": "temporal", + }, + "y": { + "aggregate": "count", + "axis": null, + "type": "quantitative", + }, + }, + "mark": { + "color": "#027864", + "type": "bar", + }, + }, + { + "encoding": { + "color": { + "condition": { + "test": "datum["bin_maxbins_10_date_range"] === "null"", + "value": "#cc4e00", + }, + "value": "#027864", + }, + "tooltip": [ + { + "bin": { + "binned": true, + }, + "field": "bin_maxbins_10_date", + "format": "%Y-%m-%d", + "title": "date (start)", + "type": "temporal", + }, + { + "bin": { + "binned": true, + }, + "field": "bin_maxbins_10_date_end", + "format": "%Y-%m-%d", + "title": "date (end)", + "type": "temporal", + }, + { + "aggregate": "count", + "format": ",d", + "title": "Count", + "type": "quantitative", + }, + ], + "x": { + "axis": null, + "bin": true, + "field": "date", + "scale": { + "align": 0, + "paddingInner": 0, + "paddingOuter": { + "expr": "length(data('data_0')) == 2 ? 1 : length(data('data_0')) == 3 ? 0.5 : length(data('data_0')) == 4 ? 0 : 0", + }, + }, + "type": "temporal", + }, + "y": { + "aggregate": "max", + "axis": null, + "type": "quantitative", + }, + }, + "mark": { + "opacity": 0, + "type": "bar", + }, + }, + ], +} +`; diff --git a/frontend/src/components/data-table/__tests__/chart-spec-model.test.ts b/frontend/src/components/data-table/__tests__/chart-spec-model.test.ts index f7c920650f7..1b581980c8a 100644 --- a/frontend/src/components/data-table/__tests__/chart-spec-model.test.ts +++ b/frontend/src/components/data-table/__tests__/chart-spec-model.test.ts @@ -61,7 +61,7 @@ describe("ColumnChartSpecModel", () => { it("should create an instance", () => { const model = new ColumnChartSpecModel( - mockData, + [], mockFieldTypes, mockStats, mockBinValues, @@ -78,7 +78,7 @@ describe("ColumnChartSpecModel", () => { it("should return header summary with spec when includeCharts is true", () => { const model = new ColumnChartSpecModel( - mockData, + [], mockFieldTypes, mockStats, mockBinValues, @@ -93,7 +93,7 @@ describe("ColumnChartSpecModel", () => { it("should return header summary without spec when includeCharts is false", () => { const model = new ColumnChartSpecModel( - mockData, + [], mockFieldTypes, mockStats, mockBinValues, @@ -106,17 +106,17 @@ describe("ColumnChartSpecModel", () => { expect(numberSummary.spec).toBeUndefined(); }); - it("should return null spec for string and unknown types", () => { + it("should return null spec for unknown types", () => { const model = new ColumnChartSpecModel( - mockData, + [], mockFieldTypes, mockStats, mockBinValues, mockValueCounts, { includeCharts: true }, ); - const stringSummary = model.getHeaderSummary("string"); - expect(stringSummary.spec).toBeNull(); + const unknownSummary = model.getHeaderSummary("unknown"); + expect(unknownSummary.spec).toBeNull(); }); it("should handle special characters in column names", () => { @@ -126,11 +126,18 @@ describe("ColumnChartSpecModel", () => { const specialStats: Record> = { "column.with[special:chars]": { min: "2023-01-01", max: "2023-12-31" }, }; + const specialBinValues: Record = { + ...mockBinValues, + "column.with[special:chars]": [ + { bin_start: "2023-01-01", bin_end: "2023-06-01", count: 10 }, + { bin_start: "2023-06-01", bin_end: "2023-12-31", count: 20 }, + ], + }; const model = new ColumnChartSpecModel( - mockData, + [], specialFieldTypes, specialStats, - mockBinValues, + specialBinValues, mockValueCounts, { includeCharts: true }, ); @@ -139,17 +146,17 @@ describe("ColumnChartSpecModel", () => { expect( // @ts-expect-error layer should be available (summary.spec?.layer[0].encoding?.x as { field: string })?.field, - ).toBe("column\\.with\\[special\\:chars\\]"); + ).toBe("bin_start"); }); - it("should expect bin values to be used for number and integer columns when feat flag is true", () => { + it("should expect bin values to be used for number and integer columns", () => { const model = new ColumnChartSpecModel( - mockData, + [], mockFieldTypes, mockStats, mockBinValues, mockValueCounts, - { includeCharts: true, usePreComputedValues: true }, + { includeCharts: true }, ); const summary = model.getHeaderSummary("number"); expect(summary.spec).toBeDefined(); @@ -171,14 +178,14 @@ describe("ColumnChartSpecModel", () => { expect(summary2.spec?.data?.values).toEqual(mockBinValues.integer); }); - it("should handle datetime bin values when feat flag is true", () => { + it("should handle datetime bin values", () => { const model = new ColumnChartSpecModel( - mockData, + [], mockFieldTypes, mockStats, mockBinValues, mockValueCounts, - { includeCharts: true, usePreComputedValues: true }, + { includeCharts: true }, ); const summary = model.getHeaderSummary("datetime"); @@ -186,9 +193,6 @@ describe("ColumnChartSpecModel", () => { // @ts-expect-error data.values should be available expect(summary.spec?.data?.values).toEqual(mockBinValues.datetime); - // Expect hconcat since there are nulls - // @ts-expect-error hconcat should be available - expect(summary.spec?.hconcat).toBeDefined(); expect(summary.spec).toMatchSnapshot(); // Test again without the nulls @@ -197,12 +201,12 @@ describe("ColumnChartSpecModel", () => { datetime: { min: "2023-01-01", max: "2023-12-31" }, }; const model2 = new ColumnChartSpecModel( - mockData, + [], mockFieldTypes, mockStats2, mockBinValues, mockValueCounts, - { includeCharts: true, usePreComputedValues: true }, + { includeCharts: true }, ); const summary2 = model2.getHeaderSummary("datetime"); expect(summary2.spec).toBeDefined(); @@ -212,14 +216,14 @@ describe("ColumnChartSpecModel", () => { expect(summary2.spec?.hconcat).toBeUndefined(); }); - it("should handle boolean stats when feat flag is true", () => { + it("should handle boolean stats", () => { const model = new ColumnChartSpecModel( - mockData, + [], mockFieldTypes, mockStats, mockBinValues, mockValueCounts, - { includeCharts: true, usePreComputedValues: true }, + { includeCharts: true }, ); const summary = model.getHeaderSummary("boolean"); expect(summary.spec).toBeDefined(); @@ -234,14 +238,14 @@ describe("ColumnChartSpecModel", () => { expect(summary.spec).toMatchSnapshot(); }); - it("should handle string value counts when feat flag is true", () => { + it("should handle string value counts", () => { const model = new ColumnChartSpecModel( - mockData, + [], mockFieldTypes, mockStats, mockBinValues, mockValueCounts, - { includeCharts: true, usePreComputedValues: true }, + { includeCharts: true }, ); const summary = model.getHeaderSummary("string"); expect(summary.spec).toBeDefined(); @@ -262,7 +266,7 @@ describe("ColumnChartSpecModel", () => { expect(summary.spec).toMatchSnapshot(); }); - describe("snapshot", () => { + describe("snapshot with legacy data spec", () => { const fieldTypes: FieldTypes = { ...mockFieldTypes, a: "number", diff --git a/frontend/src/components/data-table/column-summary/chart-spec-model.tsx b/frontend/src/components/data-table/column-summary/chart-spec-model.tsx index d4b4ebbebc6..98d86a633d4 100644 --- a/frontend/src/components/data-table/column-summary/chart-spec-model.tsx +++ b/frontend/src/components/data-table/column-summary/chart-spec-model.tsx @@ -3,19 +3,8 @@ import { mint, orange, slate } from "@radix-ui/colors"; import type { TopLevelSpec } from "vega-lite"; import type { StringFieldDef } from "vega-lite/build/src/channeldef"; -// @ts-expect-error vega-typings does not include formats -import { formats } from "vega-loader"; -import { asRemoteURL } from "@/core/runtime/config"; import type { TopLevelFacetedUnitSpec } from "@/plugins/impl/data-explorer/queries/types"; -import { arrow } from "@/plugins/impl/vega/formats"; -import { parseCsvData } from "@/plugins/impl/vega/loader"; import { logNever } from "@/utils/assertNever"; -import { - byteStringToBinary, - extractBase64FromDataURL, - isDataURLString, - typedAtob, -} from "@/utils/json/base64"; import type { BinValues, ColumnHeaderStats, @@ -24,9 +13,11 @@ import type { ValueCounts, } from "../types"; import { + getDataSpecAndSourceName, getLegacyBooleanSpec, getLegacyNumericSpec, getLegacyTemporalSpec, + getScale, } from "./legacy-chart-spec"; import { calculateBinStep, getPartialTimeTooltip } from "./utils"; @@ -34,15 +25,13 @@ import { calculateBinStep, getPartialTimeTooltip } from "./utils"; const MAX_BAR_HEIGHT = 20; // px // If we are concatenating charts, we need to specify each chart's height and width. -const CHART_HEIGHT = 30; -const CHART_WIDTH = 70; -const NULL_BAR_WIDTH = 5; - -// Arrow formats have a magic number at the beginning of the file. -const ARROW_MAGIC_NUMBER = "ARROW1"; +const CONCAT_CHART_HEIGHT = 30; +const CONCAT_CHART_WIDTH = 70; +const CONCAT_NULL_BAR_WIDTH = 5; -// register arrow reader under type 'arrow' -formats("arrow", arrow); +const BAR_COLOR = mint.mint11; +const UNHOVERED_BAR_OPACITY = 0.6; +const NULL_BAR_COLOR = orange.orange11; export class ColumnChartSpecModel { private columnStats = new Map>(); @@ -57,21 +46,21 @@ export class ColumnChartSpecModel { {}, { includeCharts: false, - usePreComputedValues: false, }, ); private dataSpec: TopLevelSpec["data"]; - private sourceName: "data_0" | "source_0"; - private readonly data: T[] | string; + // Legacy data spec for fallback + private legacyDataSpec: TopLevelSpec["data"]; + private legacySourceName: "data_0" | "source_0"; + private readonly fieldTypes: FieldTypes; readonly stats: Record>; readonly binValues: Record; readonly valueCounts: Record; private readonly opts: { includeCharts: boolean; - usePreComputedValues?: boolean; }; constructor( @@ -82,56 +71,22 @@ export class ColumnChartSpecModel { valueCounts: Record, opts: { includeCharts: boolean; - usePreComputedValues?: boolean; }, ) { - this.data = data; this.fieldTypes = fieldTypes; this.stats = stats; this.binValues = binValues; this.valueCounts = valueCounts; this.opts = opts; - // Data may come in from a few different sources: - // - A URL - // - A CSV data URI (e.g. "data:text/csv;base64,...") - // - A CSV string (e.g. "a,b,c\n1,2,3\n4,5,6") - // - An array of objects - // For each case, we need to set up the data spec and source name appropriately. - // If its a file, the source name will be "source_0", otherwise it will be "data_0". - // We have a few snapshot tests to ensure that the spec is correct for each case. - if (typeof this.data === "string") { - if (this.data.startsWith("./@file") || this.data.startsWith("/@file")) { - this.dataSpec = { url: asRemoteURL(this.data).href }; - this.sourceName = "source_0"; - } else if (isDataURLString(this.data)) { - this.sourceName = "data_0"; - const base64 = extractBase64FromDataURL(this.data); - const decoded = typedAtob(base64); - - if (decoded.startsWith(ARROW_MAGIC_NUMBER)) { - // @ts-expect-error vega-typings does not include arrow format - this.dataSpec = { - values: byteStringToBinary(decoded), - format: { type: "arrow" }, - }; - } else { - // Assume it's a CSV string - this.parseCsv(decoded); - } - } else { - // Assume it's a CSV string - this.parseCsv(this.data); - this.sourceName = "data_0"; - } - } else { - this.dataSpec = { values: this.data }; - this.sourceName = "source_0"; - } - this.columnBinValues = new Map(Object.entries(binValues)); this.columnValueCounts = new Map(Object.entries(valueCounts)); this.columnStats = new Map(Object.entries(stats)); + + const { dataSpec, sourceName } = getDataSpecAndSourceName(data); + this.dataSpec = dataSpec; + this.legacyDataSpec = dataSpec; + this.legacySourceName = sourceName; } public getColumnStats(column: string) { @@ -146,42 +101,8 @@ export class ColumnChartSpecModel { }; } - private parseCsv(data: string) { - this.dataSpec = { - values: parseCsvData(data) as T[], - }; - } - - private getVegaSpec(column: string): TopLevelFacetedUnitSpec | null { - if (!this.data) { - return null; - } - - const usePreComputedValues = this.opts.usePreComputedValues; - const binValues = this.columnBinValues.get(column); - const valueCounts = this.columnValueCounts.get(column); - const hasValueCounts = valueCounts && valueCounts.length > 0; - - let data = this.dataSpec as TopLevelFacetedUnitSpec["data"]; - const stats = this.columnStats.get(column); - - if (usePreComputedValues) { - if (hasValueCounts) { - data = { values: valueCounts, name: "value_counts" }; - } else { - // Bin values can be empty if all values are nulls - if (stats?.nulls) { - binValues?.push({ - bin_start: null, - bin_end: null, - count: stats.nulls as number, - }); - } - data = { values: binValues, name: "bin_values" }; - } - } - - const base: TopLevelFacetedUnitSpec = { + private createBase(data: TopLevelSpec["data"]): TopLevelFacetedUnitSpec { + return { background: "transparent", data, config: { @@ -194,7 +115,31 @@ export class ColumnChartSpecModel { }, height: 100, } as TopLevelFacetedUnitSpec; + } + + private getVegaSpec(column: string): TopLevelFacetedUnitSpec | null { + const binValues = this.columnBinValues.get(column); + const valueCounts = this.columnValueCounts.get(column); + const hasValueCounts = valueCounts && valueCounts.length > 0; + + let data = this.dataSpec as TopLevelFacetedUnitSpec["data"]; + const stats = this.columnStats.get(column); + if (hasValueCounts) { + data = { values: valueCounts, name: "value_counts" }; + } else { + // Bin values can be empty if all values are nulls + if (stats?.nulls) { + binValues?.push({ + bin_start: null, + bin_end: null, + count: stats.nulls as number, + }); + } + data = { values: binValues, name: "bin_values" }; + } + + const base = this.createBase(data); const type = this.fieldTypes[column]; // https://github.com/vega/altair/blob/32990a597af7c09586904f40b3f5e6787f752fa5/doc/user_guide/encodings/index.rst#escaping-special-characters-in-column-names @@ -205,18 +150,16 @@ export class ColumnChartSpecModel { // escape colons in column names column = column.replaceAll(":", "\\:"); - const scale = this.getScale(); - switch (type) { case "date": case "datetime": case "time": { - if (!usePreComputedValues) { - return getLegacyTemporalSpec(column, type, base, scale); + if (!binValues) { + const legacyBase = this.createBase(this.legacyDataSpec); + const scale = getScale(this.legacySourceName); + return getLegacyTemporalSpec(column, type, legacyBase, scale); } - const binStep = calculateBinStep(binValues || []); - const tooltip = getPartialTimeTooltip(binValues || []); const singleValue = binValues?.length === 1; @@ -224,7 +167,7 @@ export class ColumnChartSpecModel { if (singleValue) { return { ...base, - mark: { type: "bar", color: mint.mint11 }, + mark: { type: "bar", color: BAR_COLOR }, encoding: { x: { field: "bin_start", @@ -255,16 +198,12 @@ export class ColumnChartSpecModel { } const histogram: TopLevelFacetedUnitSpec = { - height: CHART_HEIGHT, - width: CHART_WIDTH, // @ts-expect-error 'layer' property not in TopLevelFacetedUnitSpec layer: [ { mark: { type: "bar", - color: mint.mint11, - stroke: mint.mint11, - strokeWidth: 0, + color: BAR_COLOR, }, params: [ { @@ -279,13 +218,7 @@ export class ColumnChartSpecModel { encoding: { x: { field: "bin_start", - type: "temporal", - bin: { binned: true, step: binStep }, - axis: null, - }, - x2: { - field: "bin_end", - type: "temporal", + type: "ordinal", axis: null, }, y: { @@ -293,13 +226,21 @@ export class ColumnChartSpecModel { type: "quantitative", axis: null, }, - strokeWidth: { + color: { condition: { - param: "hover", - empty: false, - value: 0.5, + test: "datum['bin_start'] === null && datum['bin_end'] === null", + value: NULL_BAR_COLOR, }, - value: 0, + value: BAR_COLOR, + }, + opacity: { + condition: [ + { + param: "hover", + value: 1, + }, + ], + value: UNHOVERED_BAR_OPACITY, }, }, }, @@ -309,18 +250,13 @@ export class ColumnChartSpecModel { mark: { type: "bar", opacity: 0, + // Wider bars to cover gaps between bars, prevents flickering when hovering over bars + width: { band: 1.2 }, }, encoding: { x: { field: "bin_start", - type: "temporal", - bin: { binned: true, step: binStep }, - axis: null, - }, - x2: { - field: "bin_end", - type: "temporal", - bin: { binned: true, step: binStep }, + type: "ordinal", axis: null, }, y: { @@ -353,130 +289,33 @@ export class ColumnChartSpecModel { ], }; - const nullBar: TopLevelFacetedUnitSpec = { - height: CHART_HEIGHT, - width: NULL_BAR_WIDTH, - // @ts-expect-error 'layer' property not in TopLevelFacetedUnitSpec - layer: [ - { - mark: { - type: "bar", - color: orange.orange11, - }, - encoding: { - x: { - field: "bin_start", - type: "nominal", - axis: null, - }, - y: { - field: "count", - type: "quantitative", - axis: null, - }, - }, - }, - - // Invisible tooltip layer with max-height - { - mark: { - type: "bar", - opacity: 0, - }, - encoding: { - x: { - field: "bin_start", - type: "nominal", - axis: null, - }, - y: { - aggregate: "max", - type: "quantitative", - axis: null, - }, - tooltip: [ - { - field: "count", - type: "quantitative", - title: "nulls", - format: ",d", - }, - ], - }, - }, - ], - transform: [ - { - filter: - "datum['bin_start'] === null && datum['bin_end'] === null", - }, - ], - }; - - let chart: TopLevelFacetedUnitSpec = histogram; - let timeBase = base; - - if (stats?.nulls) { - timeBase = { - ...base, - config: { - ...base.config, - concat: { - spacing: 0, - }, - }, - resolve: { - scale: { - y: "shared", - }, - }, - }; - chart = { - // Temporal axis will not show nulls, so we concat 2 charts - // @ts-expect-error 'hconcat' property not in TopLevelFacetedUnitSpec - hconcat: [nullBar, histogram], - }; - } - return { - ...timeBase, - ...chart, + ...base, + ...histogram, }; } case "integer": case "number": { // Create a histogram spec that properly handles null values const format = type === "integer" ? ",d" : ".2f"; - const binStep = calculateBinStep(binValues || []); - if (!usePreComputedValues) { - return getLegacyNumericSpec(column, format, base); + if (!binValues) { + const legacyBase = this.createBase(this.legacyDataSpec); + return getLegacyNumericSpec(column, format, legacyBase); } - const stats = this.columnStats.get(column); + const binStep = calculateBinStep(binValues || []); const histogram: TopLevelFacetedUnitSpec = { - height: CHART_HEIGHT, - width: CHART_WIDTH, + height: CONCAT_CHART_HEIGHT, + width: CONCAT_CHART_WIDTH, // @ts-expect-error 'layer' property not in TopLevelFacetedUnitSpec layer: [ { mark: { type: "bar", - color: mint.mint11, - stroke: mint.mint11, - strokeWidth: 0, + color: BAR_COLOR, }, - params: [ - { - name: "hover", - select: { - type: "point", - on: "mouseover", - clear: "mouseout", - }, - }, - ], encoding: { x: { field: "bin_start", @@ -492,13 +331,14 @@ export class ColumnChartSpecModel { type: "quantitative", axis: null, }, - strokeWidth: { - condition: { - param: "hover", - empty: false, - value: 0.5, - }, - value: 0, + opacity: { + condition: [ + { + param: "hover", + value: 1, + }, + ], + value: UNHOVERED_BAR_OPACITY, }, }, }, @@ -509,6 +349,17 @@ export class ColumnChartSpecModel { type: "bar", opacity: 0, }, + params: [ + { + name: "hover", + select: { + type: "point", + on: "mouseover", + clear: "mouseout", + nearest: true, // Nearest avoids flickering when hovering over bars, but it's not perfect + }, + }, + ], encoding: { x: { field: "bin_start", @@ -564,14 +415,14 @@ export class ColumnChartSpecModel { }; const nullBar: TopLevelFacetedUnitSpec = { - height: CHART_HEIGHT, - width: NULL_BAR_WIDTH, + height: CONCAT_CHART_HEIGHT, + width: CONCAT_NULL_BAR_WIDTH, // @ts-expect-error 'layer' property not in TopLevelFacetedUnitSpec layer: [ { mark: { type: "bar", - color: orange.orange11, + color: NULL_BAR_COLOR, }, encoding: { x: { @@ -652,7 +503,7 @@ export class ColumnChartSpecModel { }; } case "boolean": { - if (!usePreComputedValues || !stats?.true || !stats?.false) { + if (!stats?.true || !stats?.false) { return getLegacyBooleanSpec(column, base, MAX_BAR_HEIGHT); } @@ -699,7 +550,7 @@ export class ColumnChartSpecModel { }, mark: { type: "bar", - color: mint.mint11, + color: BAR_COLOR, }, encoding: { y: { @@ -726,7 +577,7 @@ export class ColumnChartSpecModel { type: "nominal", scale: { domain: ["true", "false", "null"], - range: [mint.mint11, mint.mint11, orange.orange11], + range: [BAR_COLOR, BAR_COLOR, NULL_BAR_COLOR], }, legend: null, }, @@ -740,7 +591,7 @@ export class ColumnChartSpecModel { { mark: { type: "bar", - color: mint.mint11, + color: BAR_COLOR, height: BAR_HEIGHT, }, }, @@ -764,7 +615,7 @@ export class ColumnChartSpecModel { } as TopLevelFacetedUnitSpec; // "layer" not in TopLevelFacetedUnitSpec } case "string": { - if (!usePreComputedValues || !hasValueCounts) { + if (!hasValueCounts) { return null; } @@ -840,18 +691,20 @@ export class ColumnChartSpecModel { type: "quantitative", }, color: { + condition: { + test: `datum.${yField} == "None" || datum.${yField} == "null"`, + value: NULL_BAR_COLOR, + }, + value: BAR_COLOR, + }, + opacity: { condition: [ { param: "hover_bar", - value: mint.mint11, - }, - { - test: `datum.${yField} == "None" || datum.${yField} == "null"`, - value: orange.orange11, + value: 1, }, ], - value: mint.mint8, - legend: null, + value: UNHOVERED_BAR_OPACITY, }, tooltip: [ { @@ -920,14 +773,4 @@ export class ColumnChartSpecModel { return null; } } - - private getScale() { - return { - align: 0, - paddingInner: 0, - paddingOuter: { - expr: `length(data('${this.sourceName}')) == 2 ? 1 : length(data('${this.sourceName}')) == 3 ? 0.5 : length(data('${this.sourceName}')) == 4 ? 0 : 0`, - }, - }; - } } diff --git a/frontend/src/components/data-table/column-summary/column-summary.tsx b/frontend/src/components/data-table/column-summary/column-summary.tsx index d8a5abbd79d..737c772a92b 100644 --- a/frontend/src/components/data-table/column-summary/column-summary.tsx +++ b/frontend/src/components/data-table/column-summary/column-summary.tsx @@ -48,7 +48,7 @@ export const TableColumnSummary = ({ (data: string | T[]): { + dataSpec: TopLevelSpec["data"]; + sourceName: "data_0" | "source_0"; +} { + let dataSpec: TopLevelSpec["data"]; + let sourceName: "data_0" | "source_0"; + + // Data may come in from a few different sources: + // - A URL + // - A CSV data URI (e.g. "data:text/csv;base64,...") + // - A CSV string (e.g. "a,b,c\n1,2,3\n4,5,6") + // - An array of objects + // For each case, we need to set up the data spec and source name appropriately. + // If its a file, the source name will be "source_0", otherwise it will be "data_0". + // We have a few snapshot tests to ensure that the spec is correct for each case. + if (typeof data === "string") { + if (data.startsWith("./@file") || data.startsWith("/@file")) { + dataSpec = { url: asRemoteURL(data).href }; + sourceName = "source_0"; + } else if (isDataURLString(data)) { + sourceName = "data_0"; + const base64 = extractBase64FromDataURL(data); + const decoded = typedAtob(base64); + + // eslint-disable-next-line unicorn/prefer-ternary + if (decoded.startsWith(ARROW_MAGIC_NUMBER)) { + // @ts-expect-error vega-typings does not include arrow format + dataSpec = { + values: byteStringToBinary(decoded), + format: { type: "arrow" }, + }; + } else { + // Assume it's a CSV string + dataSpec = { values: parseCsvData(decoded) }; + } + } else { + // Assume it's a CSV string + dataSpec = { values: parseCsvData(data) }; + sourceName = "data_0"; + } + } else { + dataSpec = { values: data }; + sourceName = "source_0"; + } + + return { dataSpec, sourceName }; +} + +export function getScale(sourceName: string): Scale { + return { + align: 0, + paddingInner: 0, + paddingOuter: { + expr: `length(data('${sourceName}')) == 2 ? 1 : length(data('${sourceName}')) == 3 ? 0.5 : length(data('${sourceName}')) == 4 ? 0 : 0`, + }, + }; +} diff --git a/frontend/src/plugins/impl/DataTablePlugin.tsx b/frontend/src/plugins/impl/DataTablePlugin.tsx index 607008665f6..0922558c92d 100644 --- a/frontend/src/plugins/impl/DataTablePlugin.tsx +++ b/frontend/src/plugins/impl/DataTablePlugin.tsx @@ -57,7 +57,6 @@ import { Alert, AlertTitle } from "@/components/ui/alert"; import { Button } from "@/components/ui/button"; import { DelayMount } from "@/components/utils/delay-mount"; import { type CellId, findCellId } from "@/core/cells/ids"; -import { getFeatureFlag } from "@/core/config/feature-flag"; import { slotsController } from "@/core/slots/slots"; import { store } from "@/core/state/jotai"; import { isStaticNotebook } from "@/core/static/static-state"; @@ -85,15 +84,12 @@ import { type CsvURL = string; export type TableData = T[] | CsvURL; -interface ColumnSummariesArgs { - precompute: boolean; -} - interface ColumnSummaries { data: TableData | null | undefined; stats: Record; bin_values: Record; value_counts: Record; + show_charts: boolean; is_disabled?: boolean; } @@ -198,9 +194,7 @@ interface Data { // eslint-disable-next-line @typescript-eslint/consistent-type-definitions type DataTableFunctions = { download_as: DownloadAsArgs; - get_column_summaries: ( - opts: ColumnSummariesArgs, - ) => Promise>; + get_column_summaries: (opts: {}) => Promise>; search: (req: { sort?: { by: string; @@ -281,19 +275,16 @@ export const DataTablePlugin = createPlugin("marimo-table") ) .withFunctions({ download_as: DownloadAsSchema, - get_column_summaries: rpc - .input(z.object({ precompute: z.boolean() })) - .output( - z.object({ - data: z - .union([z.string(), z.array(z.object({}).passthrough())]) - .nullable(), - stats: z.record(z.string(), columnStats), - bin_values: z.record(z.string(), binValues), - value_counts: z.record(z.string(), valueCounts), - is_disabled: z.boolean().optional(), - }), - ), + get_column_summaries: rpc.input(z.looseObject({})).output( + z.object({ + data: z.union([z.string(), z.array(z.looseObject({}))]).nullable(), + stats: z.record(z.string(), columnStats), + bin_values: z.record(z.string(), binValues), + value_counts: z.record(z.string(), valueCounts), + show_charts: z.boolean(), + is_disabled: z.boolean().optional(), + }), + ), search: rpc .input( z.object({ @@ -598,8 +589,6 @@ export const LoadingDataTableComponent = memo( ); }, [data?.totalRows]); - const precompute = getFeatureFlag("performant_table_charts"); - // Column summaries const { data: columnSummaries, error: columnSummariesError } = useAsyncData< ColumnSummaries @@ -607,9 +596,15 @@ export const LoadingDataTableComponent = memo( // TODO: props.get_column_summaries is always true, // so we are unable to detect if the function is registered if (props.totalRows === 0 || !props.showColumnSummaries) { - return { data: null, stats: {}, bin_values: {}, value_counts: {} }; + return { + data: null, + stats: {}, + bin_values: {}, + value_counts: {}, + show_charts: false, + }; } - return props.get_column_summaries({ precompute }); + return props.get_column_summaries({}); }, [ props.get_column_summaries, props.showColumnSummaries, @@ -778,8 +773,7 @@ const DataTableComponent = ({ columnSummaries.bin_values, columnSummaries.value_counts, { - includeCharts: Boolean(columnSummaries.data), - usePreComputedValues: getFeatureFlag("performant_table_charts"), + includeCharts: columnSummaries.show_charts, }, ); }, [fieldTypes, columnSummaries]); diff --git a/frontend/src/plugins/impl/data-frames/DataFramePlugin.tsx b/frontend/src/plugins/impl/data-frames/DataFramePlugin.tsx index e0a2f1d43b6..b1767f3e7c4 100644 --- a/frontend/src/plugins/impl/data-frames/DataFramePlugin.tsx +++ b/frontend/src/plugins/impl/data-frames/DataFramePlugin.tsx @@ -312,9 +312,10 @@ DataFrameComponent.displayName = "DataFrameComponent"; function getColumnSummaries() { return Promise.resolve({ - stats: {}, data: null, + stats: {}, bin_values: {}, value_counts: {}, + show_charts: false, }); } diff --git a/marimo/_plugins/ui/_impl/table.py b/marimo/_plugins/ui/_impl/table.py index fc153c565c0..eefeb4ecc43 100644 --- a/marimo/_plugins/ui/_impl/table.py +++ b/marimo/_plugins/ui/_impl/table.py @@ -89,24 +89,24 @@ class DownloadAsArgs: @dataclass -class ColumnSummariesArgs: - """If enabled, we will precompute chart values.""" - - precompute: bool +class ColumnSummariesArgs: ... @dataclass class ColumnSummaries: + # If precomputed aggregations fail, we fallback to chart data data: Union[JSONType, str] stats: dict[ColumnName, ColumnStats] bin_values: dict[ColumnName, list[BinValue]] value_counts: dict[ColumnName, list[ValueCount]] + show_charts: bool # Disabled because of too many columns/rows # This will show a banner in the frontend is_disabled: Optional[bool] = None ShowColumnSummaries = Union[bool, Literal["stats", "chart"]] +CHART_MAX_ROWS_STRING_VALUE_COUNTS = 20_000 DEFAULT_MAX_COLUMNS = 50 @@ -888,6 +888,8 @@ def _get_column_summaries( If summaries are disabled or row limit is exceeded, returns empty summaries with is_disabled flag set appropriately. """ + del args + show_column_summaries = self._show_column_summaries if not show_column_summaries: @@ -899,6 +901,7 @@ def _get_column_summaries( # This is not 'disabled' because of too many rows # so we don't want to display the banner is_disabled=False, + show_charts=False, ) total_rows = self._searched_manager.get_num_rows(force=True) or 0 @@ -912,12 +915,13 @@ def _get_column_summaries( bin_values={}, value_counts={}, is_disabled=True, + show_charts=False, ) # If we are above the limit to show charts, # or if we are in stats-only mode, - # we don't return the chart data - should_get_chart_data = ( + # we don't show charts + show_charts = ( self._show_column_summaries != "stats" and total_rows <= self._column_charts_row_limit ) @@ -934,6 +938,9 @@ def _get_column_summaries( DEFAULT_BIN_SIZE = 9 DEFAULT_VALUE_COUNTS_SIZE = 15 + bin_aggregation_failed = False + cols_to_drop = [] + for column in self._manager.get_column_names(): statistic = None if should_get_stats: @@ -945,25 +952,25 @@ def _get_column_summaries( # BaseExceptions, which shouldn't crash the kernel LOGGER.warning("Failed to get stats for column %s", column) - if should_get_chart_data and args.precompute: + if show_charts: if not should_get_stats: LOGGER.warning( "Unable to compute stats for column, may not be computed correctly" ) - # For boolean columns, we can drop the column since we use stats (column_type, external_type) = self._manager.get_field_type( column ) - if column_type == "boolean": - data = data.drop_columns([column]) + # For boolean columns, we can drop the column since we use stats + if column_type == "boolean" or column_type == "unknown": + cols_to_drop.append(column) # Handle columns with all nulls first # These get empty bins regardless of type if statistic and statistic.nulls == total_rows: try: bin_values[column] = [] - data = data.drop_columns([column]) + cols_to_drop.append(column) continue except BaseException as e: LOGGER.warning( @@ -971,10 +978,15 @@ def _get_column_summaries( ) continue - # For perf, we only compute value counts for categorical columns + # For now, we only compute value counts for categorical columns and small tables external_type = external_type.lower() - if column_type == "string" and ( - "cat" in external_type or "enum" in external_type + if ( + column_type == "string" + and ("cat" in external_type or "enum" in external_type) + or ( + column_type == "string" + and total_rows <= CHART_MAX_ROWS_STRING_VALUE_COUNTS + ) ): try: val_counts = self._get_value_counts( @@ -982,7 +994,7 @@ def _get_column_summaries( ) if len(val_counts) > 0: value_counts[column] = val_counts - data = data.drop_columns([column]) + cols_to_drop.append(column) continue except BaseException as e: LOGGER.warning( @@ -1002,18 +1014,23 @@ def _get_column_summaries( continue try: + # get_bin_values is marked unstable + # https://narwhals-dev.github.io/narwhals/api-reference/series/#narwhals.series.Series.hist bins = data.get_bin_values(column, DEFAULT_BIN_SIZE) bin_values[column] = bins - # Only drop column if we got bins to visualize if len(bins) > 0: - data = data.drop_columns([column]) + cols_to_drop.append(column) continue except BaseException as e: + bin_aggregation_failed = True LOGGER.warning( "Failed to get bin values for column %s: %s", column, e ) - if should_get_chart_data: + should_fallback = show_charts and bin_aggregation_failed + if should_fallback: + LOGGER.debug("Bin aggregation failed, falling back to chart data") + data = data.drop_columns(cols_to_drop) chart_data, _ = self._to_chart_data_url(data) return ColumnSummaries( @@ -1021,6 +1038,7 @@ def _get_column_summaries( stats=stats, bin_values=bin_values, value_counts=value_counts, + show_charts=show_charts, is_disabled=False, ) diff --git a/marimo/_smoke_tests/tables/column-header-chart.py b/marimo/_smoke_tests/tables/column-header-chart.py new file mode 100644 index 00000000000..d18b8d14a3d --- /dev/null +++ b/marimo/_smoke_tests/tables/column-header-chart.py @@ -0,0 +1,72 @@ +# /// script +# requires-python = ">=3.13" +# dependencies = [ +# "polars==1.34.0", +# "requests==2.32.5", +# ] +# /// + +import marimo + +__generated_with = "0.17.0" +app = marimo.App(width="medium") + + +@app.cell +def _(): + import marimo as mo + import polars as pl + return (pl,) + + +@app.cell +def _(pl): + # Enums and categorical data types + bears_enum = pl.Enum(["Polar", "Panda", "Brown"]) + bears = pl.Series( + ["Polar", "Panda", "Brown", "Brown", "Polar"] * 30, dtype=bears_enum + ) + + enums_cats = pl.DataFrame( + { + "bears": ["Polar", "Panda", "Brown", "Brown", "Polar"] * 30, + "bears_cat": ["Polar", "Panda", "Brown", "Brown", "Polar"] * 30, + }, + schema={ + "bears": bears_enum, + "bears_cat": pl.Categorical, + }, + ) + enums_cats + return + + +@app.cell +def _(pl): + pokemon_url = "https://gist.githubusercontent.com/armgilles/194bcff35001e7eb53a2a8b441e8b2c6/raw/92200bc0a673d5ce2110aaad4544ed6c4010f687/pokemon.csv" + pl.read_csv(pokemon_url) + return + + +@app.cell +def _(pl): + import io + import requests + import zipfile + + train_parquet_link = "https://www.kaggle.com/api/v1/datasets/download/shahmirvarqha/train-stations-amsterdam" + + response = requests.get(train_parquet_link) + zip_data = io.BytesIO(response.content) + with zipfile.ZipFile(zip_data) as z: + # List files to find parquet file, assuming one parquet file is in the archive + parquet_file_name = [f for f in z.namelist() if f.endswith(".parquet")][0] + with z.open(parquet_file_name) as parquet_file: + trains_df = pl.read_parquet(parquet_file) + + trains_df[:20000] + return + + +if __name__ == "__main__": + app.run() diff --git a/marimo/_smoke_tests/tables/complex_types.py b/marimo/_smoke_tests/tables/complex_types.py index 6ae6c7a6c8e..83975e52d01 100644 --- a/marimo/_smoke_tests/tables/complex_types.py +++ b/marimo/_smoke_tests/tables/complex_types.py @@ -1,6 +1,6 @@ import marimo -__generated_with = "0.16.5" +__generated_with = "0.17.2" app = marimo.App(width="medium") @@ -97,6 +97,9 @@ def _(): ), } ) + + # df = pl.concat([df]*10, rechunk=True).with_row_count(name="idx") + mo.ui.table(df) return df, mo, pl @@ -193,5 +196,6 @@ def _(): return + if __name__ == "__main__": app.run() diff --git a/pixi.lock b/pixi.lock index 74ea5e1fa91..83b5dadd435 100644 --- a/pixi.lock +++ b/pixi.lock @@ -1101,8 +1101,8 @@ packages: timestamp: 1727801725384 - pypi: ./ name: marimo - version: 0.17.0 - sha256: 5e2b68e9ec635e489b7e57183e46a5199d745621494165518f589a5498327212 + version: 0.17.2 + sha256: d81b66fd7b9230f46e084209d14ea657198fa4a86285c07a2e1956d761d985cd requires_dist: - click>=8.0,<9 - jedi>=0.18.0 diff --git a/pyproject.toml b/pyproject.toml index 8d080e21189..33d9bd1d881 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -549,7 +549,6 @@ format_on_save = true [tool.marimo.experimental] multi_column = true -performant_table_charts = true chat_modes = true cache_panel = true diff --git a/tests/_plugins/ui/_impl/test_table.py b/tests/_plugins/ui/_impl/test_table.py index 439e964238b..3190578def3 100644 --- a/tests/_plugins/ui/_impl/test_table.py +++ b/tests/_plugins/ui/_impl/test_table.py @@ -13,6 +13,7 @@ from marimo._plugins import ui from marimo._plugins.ui._impl.dataframes.transforms.types import Condition from marimo._plugins.ui._impl.table import ( + CHART_MAX_ROWS_STRING_VALUE_COUNTS, DEFAULT_MAX_COLUMNS, MAX_COLUMNS_NOT_PROVIDED, CalculateTopKRowsArgs, @@ -849,9 +850,7 @@ def test_table_with_too_many_rows_column_summaries_disabled() -> None: data = {"a": list(range(20))} table = ui.table(data, _internal_summary_row_limit=10) - summaries_disabled = table._get_column_summaries( - ColumnSummariesArgs(precompute=False) - ) + summaries_disabled = table._get_column_summaries(ColumnSummariesArgs()) assert summaries_disabled.is_disabled is True # search results are 2 and 12 @@ -862,9 +861,7 @@ def test_table_with_too_many_rows_column_summaries_disabled() -> None: page_number=0, ) ) - summaries_enabled = table._get_column_summaries( - ColumnSummariesArgs(precompute=False) - ) + summaries_enabled = table._get_column_summaries(ColumnSummariesArgs()) assert summaries_enabled.is_disabled is False @@ -872,11 +869,9 @@ def test_with_too_many_rows_column_charts_disabled() -> None: data = {"a": list(range(20))} table = ui.table(data, _internal_column_charts_row_limit=10) - charts_disabled = table._get_column_summaries( - ColumnSummariesArgs(precompute=False) - ) + charts_disabled = table._get_column_summaries(ColumnSummariesArgs()) + assert charts_disabled.show_charts is False assert charts_disabled.is_disabled is False - assert charts_disabled.data is None # search results are 2 and 12 table._search( @@ -886,9 +881,9 @@ def test_with_too_many_rows_column_charts_disabled() -> None: page_number=0, ) ) - charts_enabled = table._get_column_summaries( - ColumnSummariesArgs(precompute=False) - ) + charts_enabled = table._get_column_summaries(ColumnSummariesArgs()) + assert charts_enabled.show_charts is True + assert charts_enabled.data is None assert charts_enabled.is_disabled is False @@ -905,13 +900,9 @@ def test_get_column_summaries_after_search() -> None: page_number=0, ) ) - summaries = table._get_column_summaries( - ColumnSummariesArgs(precompute=False) - ) + summaries = table._get_column_summaries(ColumnSummariesArgs()) + assert summaries.show_charts is True assert summaries.is_disabled is False - summaries_data = from_data_uri(summaries.data)[1].decode("utf-8") - # Result is csv or json - assert summaries_data in ["a\n2\n12\n", '[{"a":2},{"a":12}]'] # We don't have column summaries for non-dataframe data assert summaries.stats["a"].min is None assert summaries.stats["a"].max is None @@ -923,11 +914,9 @@ def test_get_column_summaries_after_search() -> None: ) def test_get_column_summaries_after_search_df(df: Any) -> None: table = ui.table(df) - summaries = table._get_column_summaries( - ColumnSummariesArgs(precompute=False) - ) + summaries = table._get_column_summaries(ColumnSummariesArgs()) + assert summaries.show_charts is True assert summaries.is_disabled is False - assert isinstance(summaries.data, str) # Different dataframe types return different formats FORMATS = [ "data:text/plain;base64,", # arrow format for polars @@ -935,7 +924,6 @@ def test_get_column_summaries_after_search_df(df: Any) -> None: "data:text/csv;base64,", ] - assert any(summaries.data.startswith(fmt) for fmt in FORMATS) assert summaries.stats["a"].min == 0 assert summaries.stats["a"].max == 19 @@ -947,12 +935,9 @@ def test_get_column_summaries_after_search_df(df: Any) -> None: page_number=0, ) ) - summaries = table._get_column_summaries( - ColumnSummariesArgs(precompute=False) - ) + summaries = table._get_column_summaries(ColumnSummariesArgs()) + assert summaries.show_charts is True assert summaries.is_disabled is False - assert isinstance(summaries.data, str) - assert any(summaries.data.startswith(fmt) for fmt in FORMATS) # We don't have column summaries for non-dataframe data assert summaries.stats["a"].min == 2 assert summaries.stats["a"].max == 12 @@ -964,40 +949,34 @@ def test_show_column_summaries_modes(): # Test stats-only mode table_stats = ui.table(data, show_column_summaries="stats") - summaries_stats = table_stats._get_column_summaries( - ColumnSummariesArgs(precompute=False) - ) + summaries_stats = table_stats._get_column_summaries(ColumnSummariesArgs()) + assert summaries_stats.show_charts is False assert summaries_stats.is_disabled is False - assert summaries_stats.data is None assert summaries_stats.bin_values == {} assert summaries_stats.value_counts == {} assert len(summaries_stats.stats) > 0 # Test chart-only mode table_chart = ui.table(data, show_column_summaries="chart") - summaries_chart = table_chart._get_column_summaries( - ColumnSummariesArgs(precompute=False) - ) + summaries_chart = table_chart._get_column_summaries(ColumnSummariesArgs()) + assert summaries_chart.show_charts is True assert summaries_chart.is_disabled is False - assert summaries_chart.data is not None assert len(summaries_chart.stats) == 0 # Test default mode (both stats and chart) table_both = ui.table(data, show_column_summaries=True) - summaries_both = table_both._get_column_summaries( - ColumnSummariesArgs(precompute=False) - ) + summaries_both = table_both._get_column_summaries(ColumnSummariesArgs()) + assert summaries_both.show_charts is True assert summaries_both.is_disabled is False - assert summaries_both.data is not None assert len(summaries_both.stats) > 0 # Test disabled mode table_disabled = ui.table(data, show_column_summaries=False) summaries_disabled = table_disabled._get_column_summaries( - ColumnSummariesArgs(precompute=False) + ColumnSummariesArgs() ) assert summaries_disabled.is_disabled is False - assert summaries_disabled.data is None + assert summaries_disabled.show_charts is False assert summaries_disabled.bin_values == {} assert summaries_disabled.value_counts == {} assert len(summaries_disabled.stats) == 0 @@ -1005,10 +984,10 @@ def test_show_column_summaries_modes(): # Test Default behavior table_default = ui.table(data) summaries_default = table_default._get_column_summaries( - ColumnSummariesArgs(precompute=False) + ColumnSummariesArgs() ) + assert summaries_default.show_charts is True assert summaries_default.is_disabled is False - assert summaries_default.data is not None assert len(summaries_default.stats) > 0 assert table_default._component_args["show-column-summaries"] is True @@ -1020,9 +999,7 @@ class TestTableBinValues: ) def test_bin_values_all_nulls(self, df: Any) -> None: table = ui.table(df) - summaries = table._get_column_summaries( - ColumnSummariesArgs(precompute=True) - ) + summaries = table._get_column_summaries(ColumnSummariesArgs()) # Returns empty list assert summaries.bin_values == {"a": []} @@ -1067,6 +1044,41 @@ def test_with_nulls(self, table: ui.table) -> None: ValueCount(value="4", count=1), ] + @pytest.mark.skipif( + not DependencyManager.pandas.has(), reason="Pandas not installed" + ) + def test_rows_string_value_counts_limit(self) -> None: + import pandas as pd + + data = pd.DataFrame( + {"a": [str(i) for i in range(CHART_MAX_ROWS_STRING_VALUE_COUNTS)]} + ) + table = ui.table(data) + summaries = table._get_column_summaries(ColumnSummariesArgs()) + assert summaries.value_counts == { + "a": [ + ValueCount( + value="unique values", + count=CHART_MAX_ROWS_STRING_VALUE_COUNTS, + ) + ] + } + assert summaries.data is None + + # If >20k rows, we should not get value_counts + data = pd.DataFrame( + { + "a": [ + str(i) + for i in range(CHART_MAX_ROWS_STRING_VALUE_COUNTS + 1) + ] + } + ) + table = ui.table(data) + summaries = table._get_column_summaries(ColumnSummariesArgs()) + assert summaries.data is None + assert summaries.value_counts == {} + def test_with_smaller_limit(self, table: ui.table) -> None: value_counts = table._get_value_counts( column="repeat", size=2, total_rows=self.total_rows @@ -1177,10 +1189,40 @@ def test_show_column_summaries_disabled(): summaries = table._get_column_summaries(EmptyArgs()) assert summaries.is_disabled is False - assert summaries.data is None assert len(summaries.stats) == 0 +@pytest.mark.skipif( + not DependencyManager.polars.has(), reason="Polars not installed" +) +def test_column_summaries_fallback(monkeypatch): + import polars as pl + + data = pl.DataFrame( + { + "a": [1, 2, 3] * 200, + "b": [4, 5, 6] * 200, + "c": [7, 8, 9] * 200, + } + ) + table = ui.table(data) + + def always_fail_get_bin_values(*_args: Any, **_kwargs: Any) -> None: + raise RuntimeError("Intentional bin failure") + + monkeypatch.setattr( + table._manager, "get_bin_values", always_fail_get_bin_values + ) + + summaries = table._get_column_summaries(ColumnSummariesArgs()) + assert summaries.is_disabled is False + assert summaries.bin_values == {} + assert summaries.value_counts == {} + assert summaries.show_charts is True + # Should have chart data + assert summaries.data is not None + + @pytest.mark.parametrize( "df", create_dataframes(