diff --git a/frontend/src/components/databases/namespace-icons.ts b/frontend/src/components/databases/namespace-icons.ts index 319d925c3c4..29ebf08b008 100644 --- a/frontend/src/components/databases/namespace-icons.ts +++ b/frontend/src/components/databases/namespace-icons.ts @@ -2,6 +2,7 @@ export { BookMarkedIcon as IndexIcon, + Cog as DatasourceIcon, ColumnsIcon as ColumnIcon, DatabaseIcon, EyeIcon as ViewIcon, diff --git a/frontend/src/core/ai/context/__tests__/utils.test.ts b/frontend/src/core/ai/context/__tests__/utils.test.ts index d4f1e83a88f..f6815d9bc78 100644 --- a/frontend/src/core/ai/context/__tests__/utils.test.ts +++ b/frontend/src/core/ai/context/__tests__/utils.test.ts @@ -103,7 +103,7 @@ describe("contextToXml", () => { ); }); - it("should handle complex nested data", () => { + it("should handle json string data", () => { const context = { type: "complex", data: { @@ -119,6 +119,37 @@ describe("contextToXml", () => { ); }); + it("should handle objects", () => { + const context = { + type: "object", + data: { + name: "test", + config: { key: "value", nested: { prop: "test" } }, + }, + details: "Complex configuration data", + }; + + const result = contextToXml(context); + expect(result).toMatchInlineSnapshot( + `"Complex configuration data"`, + ); + }); + + it("should handle arrays", () => { + const context = { + type: "array", + data: { + name: "test", + array: [1, 2, 3], + }, + }; + + const result = contextToXml(context); + expect(result).toMatchInlineSnapshot( + `""`, + ); + }); + it("should handle boolean values", () => { const context = { type: "flags", diff --git a/frontend/src/core/ai/context/context.ts b/frontend/src/core/ai/context/context.ts index 5228918430c..86aca1f9e4d 100644 --- a/frontend/src/core/ai/context/context.ts +++ b/frontend/src/core/ai/context/context.ts @@ -1,10 +1,14 @@ /* Copyright 2024 Marimo. All rights reserved. */ -import { allTablesAtom } from "@/core/datasets/data-source-connections"; +import { + allTablesAtom, + dataSourceConnectionsAtom, +} from "@/core/datasets/data-source-connections"; import { getRequestClient } from "@/core/network/requests"; import type { JotaiStore } from "@/core/state/jotai"; import { variablesAtom } from "@/core/variables/state"; import { CellOutputContextProvider } from "./providers/cell-output"; +import { DatasourceContextProvider } from "./providers/datasource"; import { ErrorContextProvider } from "./providers/error"; import { FileContextProvider } from "./providers/file"; import { TableContextProvider } from "./providers/tables"; @@ -12,6 +16,7 @@ import { VariableContextProvider } from "./providers/variable"; import { AIContextRegistry } from "./registry"; export function getAIContextRegistry(store: JotaiStore) { + const datasource = store.get(dataSourceConnectionsAtom); const tablesMap = store.get(allTablesAtom); const variables = store.get(variablesAtom); @@ -19,7 +24,10 @@ export function getAIContextRegistry(store: JotaiStore) { .register(new TableContextProvider(tablesMap)) .register(new VariableContextProvider(variables, tablesMap)) .register(new ErrorContextProvider(store)) - .register(new CellOutputContextProvider(store)); + .register(new CellOutputContextProvider(store)) + .register( + new DatasourceContextProvider(datasource.connectionsMap, tablesMap), + ); } export function getFileContextProvider(): FileContextProvider { diff --git a/frontend/src/core/ai/context/providers/__tests__/__snapshots__/datasource.test.ts.snap b/frontend/src/core/ai/context/providers/__tests__/__snapshots__/datasource.test.ts.snap new file mode 100644 index 00000000000..85c12b4c0f1 --- /dev/null +++ b/frontend/src/core/ai/context/providers/__tests__/__snapshots__/datasource.test.ts.snap @@ -0,0 +1,7 @@ +// Vitest Snapshot v1, https://vitest.dev/guide/snapshot.html + +exports[`DatasourceContextProvider > formatContext > should format context for datasource with multiple databases > multi-db-context 1`] = `""`; + +exports[`DatasourceContextProvider > formatContext > should format context for internal duckdb with multiple tables > internal-datasource-context 1`] = `""`; + +exports[`DatasourceContextProvider > formatContext > should format context for postgres datasource > postgres-datasource-context 1`] = `""`; diff --git a/frontend/src/core/ai/context/providers/__tests__/datasource.test.ts b/frontend/src/core/ai/context/providers/__tests__/datasource.test.ts new file mode 100644 index 00000000000..40cfb0f9506 --- /dev/null +++ b/frontend/src/core/ai/context/providers/__tests__/datasource.test.ts @@ -0,0 +1,633 @@ +/* Copyright 2024 Marimo. All rights reserved. */ +/* eslint-disable @typescript-eslint/no-explicit-any */ + +import { beforeEach, describe, expect, it } from "vitest"; +import type { + ConnectionsMap, + DatasetTablesMap, +} from "@/core/datasets/data-source-connections"; +import { DUCKDB_ENGINE } from "@/core/datasets/engines"; +import type { DataSourceConnection, DataTable } from "@/core/kernel/messages"; +import { Boosts } from "../common"; +import { DatasourceContextProvider } from "../datasource"; + +// Mock data for testing +const createMockDataSourceConnection = ( + name: string, + options: Partial = {}, +): DataSourceConnection => ({ + name, + dialect: "duckdb", + source: "duckdb", + display_name: `Test ${name}`, + databases: [ + { + name: "main", + dialect: "duckdb", + schemas: [ + { + name: "public", + tables: [ + { + name: "users", + source_type: "connection", + source: name, + num_rows: 100, + num_columns: 3, + variable_name: null, + columns: [], + }, + { + name: "orders", + source_type: "connection", + source: name, + num_rows: 50, + num_columns: 4, + variable_name: null, + columns: [], + }, + ], + }, + { + name: "analytics", + tables: [ + { + name: "events", + source_type: "connection", + source: name, + num_rows: 200, + num_columns: 5, + variable_name: null, + columns: [], + }, + ], + }, + ], + }, + ], + ...options, +}); + +const createMockConnectionsMap = ( + connections: DataSourceConnection[], +): ConnectionsMap => { + const map = new Map(); + connections.forEach((conn) => { + map.set(conn.name, conn); + }); + return map; +}; + +const createMockDataTable = ( + name: string, + options: Partial = {}, +): DataTable => ({ + name, + source: "local", + source_type: "local", + num_rows: 100, + num_columns: 3, + variable_name: name, // This makes it a dataframe + columns: [ + { + name: "id", + type: "integer", + external_type: "INTEGER", + sample_values: [1, 2, 3], + }, + { + name: "name", + type: "string", + external_type: "VARCHAR", + sample_values: ["Alice", "Bob", "Charlie"], + }, + { + name: "age", + type: "integer", + external_type: "INTEGER", + sample_values: [25, 30, 35], + }, + ], + engine: null, + indexes: null, + primary_keys: null, + type: "table", + ...options, +}); + +const createMockTablesMap = (tables: DataTable[]): DatasetTablesMap => { + const map = new Map(); + tables.forEach((table) => { + map.set(table.name, table); + }); + return map; +}; + +describe("DatasourceContextProvider", () => { + let provider: DatasourceContextProvider; + let connectionsMap: ConnectionsMap; + let tablesMap: DatasetTablesMap; + + beforeEach(() => { + connectionsMap = createMockConnectionsMap([ + createMockDataSourceConnection(DUCKDB_ENGINE), + createMockDataSourceConnection("postgres", { + dialect: "postgresql", + source: "postgresql", + display_name: "PostgreSQL Database", + databases: [ + { + name: "production", + dialect: "postgresql", + schemas: [ + { + name: "public", + tables: [ + { + name: "customers", + source_type: "connection", + source: "postgres", + num_rows: 1000, + num_columns: 8, + variable_name: null, + columns: [], + }, + { + name: "products", + source_type: "connection", + source: "postgres", + num_rows: 500, + num_columns: 6, + variable_name: null, + columns: [], + }, + { + name: "sales", + source_type: "connection", + source: "postgres", + num_rows: 5000, + num_columns: 10, + variable_name: null, + columns: [], + }, + ], + }, + ], + }, + ], + }), + ]); + tablesMap = createMockTablesMap([ + createMockDataTable("users"), + createMockDataTable("orders"), + createMockDataTable("events"), + ]); + provider = new DatasourceContextProvider(connectionsMap, tablesMap); + }); + + describe("provider properties", () => { + it("should have correct provider properties", () => { + expect(provider.title).toBe("Datasource"); + expect(provider.mentionPrefix).toBe("@"); + expect(provider.contextType).toBe("datasource"); + }); + }); + + describe("getItems", () => { + it("should return empty array when no connections", () => { + const emptyMap = createMockConnectionsMap([]); + const emptyProvider = new DatasourceContextProvider(emptyMap, new Map()); + const items = emptyProvider.getItems(); + expect(items).toEqual([]); + }); + + it("should return datasource items when connections exist", () => { + const items = provider.getItems(); + + expect(items).toHaveLength(2); + + // Check first item (duckdb) + expect(items[0]).toMatchObject({ + name: DUCKDB_ENGINE, + type: "datasource", + data: { + connection: { + name: DUCKDB_ENGINE, + dialect: "duckdb", + source: "duckdb", + display_name: `Test ${DUCKDB_ENGINE}`, + }, + }, + }); + + // Check second item (postgres) + expect(items[1]).toMatchObject({ + name: "postgres", + type: "datasource", + data: { + connection: { + name: "postgres", + dialect: "postgresql", + source: "postgresql", + display_name: "PostgreSQL Database", + }, + }, + }); + + // Check URIs are properly formatted + expect(items[0].uri).toBe(`datasource://${DUCKDB_ENGINE}`); + expect(items[1].uri).toBe("datasource://postgres"); + }); + + it("should include dataframes for internal SQL engines", () => { + const items = provider.getItems(); + const duckdbItem = items.find((item) => item.name === DUCKDB_ENGINE)!; + + // DuckDB is an internal engine, so it should have tables (dataframes) + expect(duckdbItem.data.tables).toBeDefined(); + expect(duckdbItem.data.tables).toHaveLength(3); + expect(duckdbItem.data.tables?.map((t) => t.name)).toEqual([ + "users", + "orders", + "events", + ]); + + // PostgreSQL is external, so it should not have tables + const postgresItem = items.find((item) => item.name === "postgres")!; + expect(postgresItem.data.tables).toBeUndefined(); + }); + + it("should handle connections with no databases", () => { + const emptyDbConnection = createMockDataSourceConnection("empty", { + databases: [], + }); + const mapWithEmpty = createMockConnectionsMap([emptyDbConnection]); + const providerWithEmpty = new DatasourceContextProvider( + mapWithEmpty, + new Map(), + ); + + const items = providerWithEmpty.getItems(); + expect(items).toHaveLength(1); + expect(items[0].data.connection.databases).toEqual([]); + }); + + it("should handle connections with databases but no schemas", () => { + const emptySchemaConnection = createMockDataSourceConnection( + "empty-schema", + { + databases: [ + { + name: "empty_db", + dialect: "duckdb", + schemas: [], + }, + ], + }, + ); + const mapWithEmptySchema = createMockConnectionsMap([ + emptySchemaConnection, + ]); + const providerWithEmptySchema = new DatasourceContextProvider( + mapWithEmptySchema, + new Map(), + ); + + const items = providerWithEmptySchema.getItems(); + expect(items).toHaveLength(1); + expect(items[0].data.connection.databases[0].schemas).toEqual([]); + }); + }); + + describe("formatCompletion", () => { + it("should format completion for datasource with tables", () => { + const items = provider.getItems(); + const duckdbItem = items.find((item) => item.name === DUCKDB_ENGINE)!; + const completion = provider.formatCompletion(duckdbItem); + + expect(completion).toMatchObject({ + label: "@In-Memory", + displayLabel: "In-Memory", + detail: "DuckDB", + boost: Boosts.MEDIUM, + type: "datasource", + section: "Data Sources", + }); + + expect(completion.info).toBeDefined(); + }); + + it("should format completion for datasource with no tables", () => { + const emptyConnection = createMockDataSourceConnection("empty", { + databases: [ + { + name: "empty_db", + dialect: "duckdb", + schemas: [ + { + name: "empty_schema", + tables: [], + }, + ], + }, + ], + }); + const mapWithEmpty = createMockConnectionsMap([emptyConnection]); + const providerWithEmpty = new DatasourceContextProvider( + mapWithEmpty, + new Map(), + ); + + const items = providerWithEmpty.getItems(); + const completion = provider.formatCompletion(items[0]); + + expect(completion.detail).toBe("DuckDB"); + }); + + it("should format completion for postgres connection", () => { + const items = provider.getItems(); + const postgresItem = items.find((item) => item.name === "postgres")!; + const completion = provider.formatCompletion(postgresItem); + + expect(completion).toMatchObject({ + label: "@postgres", + displayLabel: "postgres", + detail: "PostgreSQL", + boost: Boosts.MEDIUM, + type: "datasource", + section: "Data Sources", + }); + }); + + it("should format completion for in-memory engine", () => { + const duckdbConnection = createMockDataSourceConnection(DUCKDB_ENGINE, { + dialect: "duckdb", + source: "duckdb", + display_name: "DuckDB (In-Memory)", + }); + const mapWithInMemory = createMockConnectionsMap([duckdbConnection]); + const providerWithInMemory = new DatasourceContextProvider( + mapWithInMemory, + new Map(), + ); + const items = providerWithInMemory.getItems(); + const inMemoryItem = items.find((item) => item.name === DUCKDB_ENGINE)!; + const completion = provider.formatCompletion(inMemoryItem); + + expect(completion).toMatchObject({ + label: "@In-Memory", + displayLabel: "In-Memory", + detail: "DuckDB", + boost: Boosts.MEDIUM, + type: "datasource", + section: "Data Sources", + }); + }); + }); + + describe("formatContext", () => { + it("should format context for internal duckdb with multiple tables", () => { + const items = provider.getItems(); + const duckdbItem = items.find((item) => item.name === DUCKDB_ENGINE)!; + const context = provider.formatContext(duckdbItem); + + expect(context).not.toContain('"engine_name":"__marimo_duckdb"'); + expect(context).toMatchSnapshot("internal-datasource-context"); + }); + + it("should include dataframes in context for internal engines", () => { + const items = provider.getItems(); + const duckdbItem = items.find((item) => item.name === DUCKDB_ENGINE)!; + const context = provider.formatContext(duckdbItem); + + // Should include the dataframes in the context + expect(context).toContain('"name":"users"'); + expect(context).toContain('"name":"orders"'); + expect(context).toContain('"name":"events"'); + expect(context).toContain('"variable_name":"users"'); + expect(context).toContain('"variable_name":"orders"'); + expect(context).toContain('"variable_name":"events"'); + }); + + it("should format context for postgres datasource", () => { + const items = provider.getItems(); + const postgresItem = items.find((item) => item.name === "postgres")!; + const context = provider.formatContext(postgresItem); + + expect(context).toContain('"engine_name":"postgres"'); + expect(context).toMatchSnapshot("postgres-datasource-context"); + }); + + it("should format context for datasource with no tables", () => { + const emptyConnection = createMockDataSourceConnection("empty", { + databases: [ + { + name: "empty_db", + dialect: "duckdb", + schemas: [ + { + name: "empty_schema", + tables: [], + }, + ], + }, + ], + }); + const mapWithEmpty = createMockConnectionsMap([emptyConnection]); + const providerWithEmpty = new DatasourceContextProvider( + mapWithEmpty, + new Map(), + ); + + const items = providerWithEmpty.getItems(); + const context = providerWithEmpty.formatContext(items[0]); + + expect(context).toContain('"dialect":"duckdb"'); + expect(context).not.toContain('"name":"Test empty"'); + expect(context).not.toContain('"source":"duckdb"'); + expect(context).not.toContain('"display_name":"Test empty"'); + }); + + it("should format context for datasource with multiple databases", () => { + const multiDbConnection = createMockDataSourceConnection("multi", { + default_database: "db1", + default_schema: "schema1", + databases: [ + { + name: "db1", + dialect: "duckdb", + schemas: [ + { + name: "schema1", + tables: [ + { + name: "table1", + source_type: "connection", + source: "multi", + num_rows: 10, + num_columns: 2, + variable_name: null, + columns: [], + }, + { + name: "table2", + source_type: "connection", + source: "multi", + num_rows: 20, + num_columns: 3, + variable_name: null, + columns: [], + }, + ], + }, + ], + }, + { + name: "db2", + dialect: "duckdb", + schemas: [ + { + name: "schema2", + tables: [ + { + name: "table3", + source_type: "connection", + source: "multi", + num_rows: 15, + num_columns: 4, + variable_name: null, + columns: [], + }, + ], + }, + { + name: "schema3", + tables: [], + }, + ], + }, + ], + }); + const mapWithMulti = createMockConnectionsMap([multiDbConnection]); + const providerWithMulti = new DatasourceContextProvider( + mapWithMulti, + new Map(), + ); + + const items = providerWithMulti.getItems(); + const context = provider.formatContext(items[0]); + + expect(context).toContain('"dialect":"duckdb"'); + expect(context).toContain('"default_database":"db1"'); + expect(context).toContain('"default_schema":"schema1"'); + expect(context).toMatchSnapshot("multi-db-context"); + }); + }); + + describe("edge cases", () => { + it("should highlight default database and schema", () => { + const connectionWithDefaults = createMockDataSourceConnection("test", { + default_database: "main", + default_schema: "public", + databases: [ + { + name: "main", + dialect: "duckdb", + schemas: [ + { + name: "public", + tables: [ + { + name: "users", + source_type: "connection", + source: "test", + num_rows: 100, + num_columns: 3, + variable_name: null, + columns: [], + }, + ], + }, + { + name: "analytics", + tables: [], + }, + ], + }, + ], + }); + const mapWithDefaults = createMockConnectionsMap([ + connectionWithDefaults, + ]); + const providerWithDefaults = new DatasourceContextProvider( + mapWithDefaults, + new Map(), + ); + + const items = providerWithDefaults.getItems(); + const completion = providerWithDefaults.formatCompletion(items[0]); + + expect(completion.info).toBeDefined(); + expect(typeof completion.info).toBe("function"); + }); + + it("should handle connections with special characters in names", () => { + const specialConnection = createMockDataSourceConnection("test-db_123", { + display_name: "Test DB (123)", + }); + const mapWithSpecial = createMockConnectionsMap([specialConnection]); + const providerWithSpecial = new DatasourceContextProvider( + mapWithSpecial, + new Map(), + ); + + const items = providerWithSpecial.getItems(); + expect(items).toHaveLength(1); + expect(items[0].name).toBe("test-db_123"); + expect(items[0].uri).toBe("datasource://test-db_123"); + + const context = providerWithSpecial.formatContext(items[0]); + expect(context).toContain('"dialect":"duckdb"'); + }); + + it("should handle very large numbers of tables", () => { + const largeConnection = createMockDataSourceConnection("large", { + databases: [ + { + name: "large_db", + dialect: "duckdb", + schemas: [ + { + name: "large_schema", + tables: Array.from({ length: 100 }, (_, i) => ({ + name: `table_${i}`, + source_type: "connection" as const, + source: "large", + num_rows: i + 1, + num_columns: 2, + variable_name: null, + columns: [], + })), + }, + ], + }, + ], + }); + const mapWithLarge = createMockConnectionsMap([largeConnection]); + const providerWithLarge = new DatasourceContextProvider( + mapWithLarge, + new Map(), + ); + + const items = providerWithLarge.getItems(); + const completion = providerWithLarge.formatCompletion(items[0]); + expect(completion.detail).toBe("DuckDB"); + + const context = providerWithLarge.formatContext(items[0]); + // Since we now return the entire data structure, check for basic properties + expect(context).not.toContain('"name":"Test large"'); + expect(context).toContain('"dialect":"duckdb"'); + expect(context).not.toContain('"source":"duckdb"'); + }); + }); +}); diff --git a/frontend/src/core/ai/context/providers/common.ts b/frontend/src/core/ai/context/providers/common.ts index 7e99f421332..d109a9b1075 100644 --- a/frontend/src/core/ai/context/providers/common.ts +++ b/frontend/src/core/ai/context/providers/common.ts @@ -1,5 +1,6 @@ /* Copyright 2024 Marimo. All rights reserved. */ +/** Number from -99 to 99. Higher numbers are prioritized when surfacing completions. */ export const Boosts = { LOCAL_TABLE: 5, REMOTE_TABLE: 4, diff --git a/frontend/src/core/ai/context/providers/datasource.ts b/frontend/src/core/ai/context/providers/datasource.ts new file mode 100644 index 00000000000..e3701f7ad89 --- /dev/null +++ b/frontend/src/core/ai/context/providers/datasource.ts @@ -0,0 +1,131 @@ +/* Copyright 2024 Marimo. All rights reserved. */ + +import type { Completion } from "@codemirror/autocomplete"; +import { createRoot } from "react-dom/client"; +import { dbDisplayName } from "@/components/databases/display"; +import { renderDatasourceInfo } from "@/core/codemirror/language/languages/sql/renderers"; +import { + type ConnectionsMap, + type DatasetTablesMap, + getTableType, +} from "@/core/datasets/data-source-connections"; +import { + type ConnectionName, + INTERNAL_SQL_ENGINES, +} from "@/core/datasets/engines"; +import type { DataSourceConnection, DataTable } from "@/core/kernel/messages"; +import type { AIContextItem } from "../registry"; +import { AIContextProvider } from "../registry"; +import { contextToXml } from "../utils"; +import { Boosts } from "./common"; + +type NamedDatasource = Omit< + DataSourceConnection, + "name" | "display_name" | "source" +> & { + // Easier for the AI to write mo.sql with engine name + engine_name?: ConnectionName; +}; + +export interface DatasourceContextItem extends AIContextItem { + type: "datasource"; + // For internal engine, it can have both connection and data tables + // For external engines, the data is just a DataSourceConnection + data: { + connection: DataSourceConnection; + tables?: DataTable[]; + }; +} + +export class DatasourceContextProvider extends AIContextProvider { + readonly title = "Datasource"; + readonly mentionPrefix = "@"; + readonly contextType = "datasource"; + private connectionsMap: ConnectionsMap; + private dataframes: DataTable[]; + + constructor(connectionsMap: ConnectionsMap, tablesMap: DatasetTablesMap) { + super(); + this.connectionsMap = connectionsMap; + this.dataframes = [...tablesMap.values()].filter( + (table: DataTable) => getTableType(table) === "dataframe", + ); + } + + getItems(): DatasourceContextItem[] { + return [...this.connectionsMap.values()].map((connection) => { + let description = "Database schema."; + const data: DatasourceContextItem["data"] = { + connection: connection, + }; + if (INTERNAL_SQL_ENGINES.has(connection.name)) { + data.tables = this.dataframes; + description = "Database schema and the dataframes that can be queried"; + } + + return { + uri: this.asURI(connection.name), + name: connection.name, + description: description, + type: this.contextType, + data: data, + }; + }); + } + + formatContext(item: DatasourceContextItem): string { + const data = item.data; + // Remove certain fields that are not needed in the context + const { name, display_name, source, ...filteredDatasource } = + data.connection; + + let datasource = filteredDatasource; + const isInternalEngine = INTERNAL_SQL_ENGINES.has(name as ConnectionName); + if (!isInternalEngine) { + const namedDatasource: NamedDatasource = { + ...filteredDatasource, + engine_name: name as ConnectionName, // Add the engine name for external engines + }; + datasource = namedDatasource; + } + + return contextToXml({ + type: this.contextType, + data: { + connection: datasource, + tables: data.tables, + }, + }); + } + + formatCompletion(item: DatasourceContextItem): Completion { + const datasource = item.data; + + const dataConnection = datasource.connection; + const dataframes = datasource.tables; + + let label = dataConnection.name; + if (INTERNAL_SQL_ENGINES.has(dataConnection.name as ConnectionName)) { + label = "In-Memory"; + } + + return { + label: `@${label}`, + displayLabel: label, + detail: dbDisplayName(dataConnection.dialect), + boost: Boosts.MEDIUM, + type: this.contextType, + section: "Data Sources", + info: () => { + const infoContainer = document.createElement("div"); + infoContainer.classList.add("mo-cm-tooltip", "docs-documentation"); + + // Use React to render the datasource info + const root = createRoot(infoContainer); + root.render(renderDatasourceInfo(dataConnection, dataframes)); + + return infoContainer; + }, + }; + } +} diff --git a/frontend/src/core/ai/context/providers/error.ts b/frontend/src/core/ai/context/providers/error.ts index fe96ee34554..66f35dc0c8e 100644 --- a/frontend/src/core/ai/context/providers/error.ts +++ b/frontend/src/core/ai/context/providers/error.ts @@ -5,6 +5,7 @@ import { cellErrorsAtom } from "@/core/cells/cells"; import type { CellId } from "@/core/cells/ids"; import type { MarimoError } from "@/core/kernel/messages"; import type { JotaiStore } from "@/core/state/jotai"; +import { logNever } from "@/utils/assertNever"; import { PluralWord } from "@/utils/pluralize"; import { type AIContextItem, AIContextProvider } from "../registry"; import { contextToXml } from "../utils"; @@ -56,6 +57,13 @@ function describeError(error: MarimoError): string { if (error.type === "unknown") { return error.msg; } + if (error.type === "sql-error") { + return error.msg; + } + if (error.type === "internal") { + return error.msg || "An internal error occurred"; + } + logNever(error); return "Unknown error"; } diff --git a/frontend/src/core/ai/context/providers/tables.ts b/frontend/src/core/ai/context/providers/tables.ts index c6b74f8c788..93783802842 100644 --- a/frontend/src/core/ai/context/providers/tables.ts +++ b/frontend/src/core/ai/context/providers/tables.ts @@ -1,7 +1,10 @@ /* Copyright 2024 Marimo. All rights reserved. */ import type { Completion } from "@codemirror/autocomplete"; -import type { DatasetTablesMap } from "@/core/datasets/data-source-connections"; +import { + type DatasetTablesMap, + getTableType, +} from "@/core/datasets/data-source-connections"; import type { DataTable } from "@/core/kernel/messages"; import type { AIContextItem } from "../registry"; import { AIContextProvider } from "../registry"; @@ -73,9 +76,9 @@ export class TableContextProvider extends AIContextProvider { table.source_type === "local" ? Boosts.LOCAL_TABLE : Boosts.REMOTE_TABLE, - type: table.variable_name ? "dataframe" : "table", + type: getTableType(table), apply: `@${tableName}`, - section: table.variable_name ? "Dataframe" : "Table", + section: getTableType(table) === "dataframe" ? "Dataframe" : "Table", info: () => this.createTableInfoElement(tableName, table), }; } diff --git a/frontend/src/core/ai/context/utils.ts b/frontend/src/core/ai/context/utils.ts index 09fddf54077..191e138479d 100644 --- a/frontend/src/core/ai/context/utils.ts +++ b/frontend/src/core/ai/context/utils.ts @@ -28,8 +28,13 @@ export function contextToXml(context: AiContextPayload): string { // Add data as attributes for (const [key, value] of Object.entries(data)) { - const escapedValue = escapeXml(String(value)); if (value !== undefined) { + // Serialize objects and arrays as JSON + const stringValue = + typeof value === "object" && value !== null + ? JSON.stringify(value) + : String(value); + const escapedValue = escapeXml(stringValue); xml += ` ${key}="${escapedValue}"`; } } diff --git a/frontend/src/core/codemirror/language/languages/sql/renderers.tsx b/frontend/src/core/codemirror/language/languages/sql/renderers.tsx index ceecc961bed..6f0c66d365e 100644 --- a/frontend/src/core/codemirror/language/languages/sql/renderers.tsx +++ b/frontend/src/core/codemirror/language/languages/sql/renderers.tsx @@ -2,9 +2,11 @@ import { HashIcon, InfoIcon } from "lucide-react"; import type React from "react"; +import { dbDisplayName } from "@/components/databases/display"; import { ColumnIcon, DatabaseIcon, + DatasourceIcon, IndexIcon, PrimaryKeyIcon, SchemaIcon, @@ -13,13 +15,19 @@ import { } from "@/components/databases/namespace-icons"; import { DATA_TYPE_ICON } from "@/components/datasets/icons"; import { Badge } from "@/components/ui/badge"; +import { + type ConnectionName, + INTERNAL_SQL_ENGINES, +} from "@/core/datasets/engines"; import type { Database, DatabaseSchema, + DataSourceConnection, DataTable, DataTableColumn, DataType, } from "@/core/kernel/messages"; +import { PluralWord } from "@/utils/pluralize"; // Configuration constants const PREVIEW_ITEM_LIMIT = 5; @@ -46,6 +54,12 @@ const SOURCE_TYPE_COLORS = { const CONTAINER_STYLES = "p-3 min-w-[250px] flex flex-col divide-y"; +const columnsText = new PluralWord("column", "columns"); +const rowsText = new PluralWord("row", "rows"); +const schemasText = new PluralWord("schema", "schemas"); +const tablesText = new PluralWord("table", "tables"); +const databasesText = new PluralWord("database", "databases"); + // Helper components and functions const SectionHeader: React.FC<{ icon: React.ReactNode; @@ -85,6 +99,10 @@ const PreviewList: React.FC<{ totalCount: number; limit?: number; }> = ({ title = "", items, totalCount, limit = PREVIEW_ITEM_LIMIT }) => { + if (items.length === 0) { + return null; + } + const visibleItems = items.slice(0, limit); const hasMore = totalCount > limit; @@ -203,13 +221,13 @@ export const renderTableInfo = (table: DataTable): React.ReactNode => { {table.num_columns != null && ( } - text={`${table.num_columns} columns`} + text={`${table.num_columns} ${columnsText.pluralize(table.num_columns)}`} /> )} {table.num_rows != null && ( } - text={`${table.num_rows} rows`} + text={`${table.num_rows} ${rowsText.pluralize(table.num_rows)}`} /> )} @@ -349,7 +367,7 @@ export const renderDatabaseInfo = (database: Database): React.ReactNode => { {schema.name} - {schema.tables.length} tables + {schema.tables.length} {tablesText.pluralize(schema.tables.length)} )); @@ -468,6 +486,156 @@ export const renderSchemaInfo = (schema: DatabaseSchema): React.ReactNode => { ); }; +const DefaultBadge = default; +const MAX_SCHEMAS_TO_DISPLAY = 8; +const MAX_TABLES_TO_DISPLAY = 3; + +export const renderDatasourceInfo = ( + connection: DataSourceConnection, + dataframes?: DataTable[], +): React.ReactNode => { + const databaseCount = connection.databases.length; + const schemasCount = connection.databases.reduce( + (count, db) => count + db.schemas.length, + 0, + ); + + const renderSchema = (schema: DatabaseSchema, isDefaultDb: boolean) => { + if (schema.tables.length === 0) { + return null; + } + + const isDefaultSchema = + schema.name === connection.default_schema && isDefaultDb; + + let tableItems: React.ReactNode[] = []; + // Don't display table items if there are many schemas + if (schemasCount < MAX_SCHEMAS_TO_DISPLAY) { + tableItems = schema.tables + .slice(0, MAX_TABLES_TO_DISPLAY + 1) + .map((table) => { + return ( +
+ + {table.name} +
+ ); + }); + } + + return ( +
+
+ + {schema.name} + {isDefaultSchema && DefaultBadge} + + {schema.tables.length} tables + +
+ +
+ ); + }; + + const databaseItems = connection.databases.map((db) => { + const isDefaultDb = + db.name === connection.default_database || + connection.databases.length === 1; + + const schemaItems = db.schemas.map((schema) => + renderSchema(schema, isDefaultDb), + ); + + return ( +
+
+ + {db.name} + {isDefaultDb && DefaultBadge} +
+ {schemaItems && ( + + )} +
+ ); + }); + + let title = connection.name; + if (INTERNAL_SQL_ENGINES.has(connection.name as ConnectionName)) { + title = "In-Memory"; + } + + const dataframeItems = dataframes?.map((table) => ( +
+ + {table.name} +
+ )); + + return ( +
+ } + title={title} + /> + + {/* Metadata */} +
+ + {dbDisplayName(connection.dialect)} + + } + /> + + {connection.source} + + } + /> +
+ + {/* Statistics */} +
+ } + text={`${databaseCount} ${databasesText.pluralize(databaseCount)}`} + /> + } + text={`${schemasCount} ${schemasText.pluralize(schemasCount)}`} + /> +
+ + {/* Database Preview */} + {databaseCount > 0 && ( + + )} + + {/* Tables Preview */} + {dataframeItems && dataframeItems.length > 0 && ( + + )} +
+ ); +}; + export const renderEmptyInfo = ( type: "column" | "table" | "schema" | "database", ) => { diff --git a/frontend/src/core/datasets/data-source-connections.ts b/frontend/src/core/datasets/data-source-connections.ts index e9ea8de93fc..0c3837b3d93 100644 --- a/frontend/src/core/datasets/data-source-connections.ts +++ b/frontend/src/core/datasets/data-source-connections.ts @@ -36,7 +36,7 @@ export interface DataSourceConnection name: ConnectionName; } -type ConnectionsMap = ReadonlyMap; +export type ConnectionsMap = ReadonlyMap; export interface DataSourceState { latestEngineSelected: ConnectionName; @@ -331,4 +331,12 @@ export const allTablesAtom = atom((get) => { return tableNames; }); +/** + * Dataframes are tables that are created from local Python dataframes + * In-memory engines can access dataframes + */ +export function getTableType(table: DataTable): "table" | "dataframe" { + return table.variable_name ? "dataframe" : "table"; +} + export type DatasetTablesMap = ReturnType<(typeof allTablesAtom)["read"]>; diff --git a/frontend/src/css/app/codemirror-completions.css b/frontend/src/css/app/codemirror-completions.css index dbfadf3301a..1a6b9c50155 100644 --- a/frontend/src/css/app/codemirror-completions.css +++ b/frontend/src/css/app/codemirror-completions.css @@ -181,6 +181,24 @@ } } + .cm-completionIcon-error { + color: var(--red-11); + background-color: var(--red-3); + + &::after { + content: "𝑒"; /* Italic e for errors */ + } + } + + .cm-completionIcon-datasource { + color: var(--purple-11); + background-color: var(--purple-3); + + &::after { + content: "⚙"; /* Gear icon for datasources */ + } + } + /* Completion info tooltip */ .cm-completionInfo { @apply max-w-md border-l border-border;