diff --git a/marimo/_ai/_tools/tools/datasource.py b/marimo/_ai/_tools/tools/datasource.py new file mode 100644 index 00000000000..f62afd16527 --- /dev/null +++ b/marimo/_ai/_tools/tools/datasource.py @@ -0,0 +1,111 @@ +# Copyright 2025 Marimo. All rights reserved. + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Optional + +from marimo import _loggers +from marimo._ai._tools.base import ToolBase +from marimo._ai._tools.types import SuccessResult +from marimo._ai._tools.utils.exceptions import ToolExecutionError +from marimo._data.models import DataTable +from marimo._server.sessions import Session +from marimo._types.ids import SessionId +from marimo._utils.fuzzy_match import compile_regex, is_fuzzy_match + +LOGGER = _loggers.marimo_logger() + + +@dataclass +class GetDatabaseTablesArgs: + session_id: SessionId + query: Optional[str] = None + + +@dataclass +class TableDetails: + connection: str + database: str + schema: str + table: DataTable + + +@dataclass +class GetDatabaseTablesOutput(SuccessResult): + tables: list[TableDetails] = field(default_factory=list) + + +class GetDatabaseTables( + ToolBase[GetDatabaseTablesArgs, GetDatabaseTablesOutput] +): + """ + Get information about tables in a database. + + Args: + session_id: The session id. + query (optional): The query to match the database, schemas, and tables. Regex is supported. + + If a query is provided, it will fuzzy match the query to the database, schemas, and tables available. If no query is provided, all tables are returned. Don't provide a query if you need to see the entire schema view. + + The tables returned contain information about the database, schema and connection name to use in forming SQL queries. + """ + + def handle(self, args: GetDatabaseTablesArgs) -> GetDatabaseTablesOutput: + session_id = args.session_id + session = self.context.get_session(session_id) + + return self._get_tables(session, args.query) + + def _get_tables( + self, session: Session, query: Optional[str] + ) -> list[TableDetails]: + session_view = session.session_view + data_connectors = session_view.data_connectors + + if len(data_connectors.connections) == 0: + raise ToolExecutionError( + message="No databases found. Please create a connection first.", + code="NO_DATABASES_FOUND", + is_retryable=False, + ) + + tables: list[TableDetails] = [] + + # Pre-compile regex if query exists + compiled_pattern = None + is_regex = False + if query: + compiled_pattern, is_regex = compile_regex(query) + + for connection in data_connectors.connections: + for database in connection.databases: + for schema in database.schemas: + # If query is None, match all schemas + if query is None or is_fuzzy_match( + query, schema.name, compiled_pattern, is_regex + ): + for table in schema.tables: + tables.append( + TableDetails( + connection=connection.name, + database=database.name, + schema=schema.name, + table=table, + ) + ) + continue + for table in schema.tables: + if is_fuzzy_match( + query, table.name, compiled_pattern, is_regex + ): + tables.append( + TableDetails( + connection=connection.name, + database=database.name, + schema=schema.name, + table=table, + ) + ) + + return GetDatabaseTablesOutput(tables=tables) diff --git a/marimo/_ai/_tools/tools_registry.py b/marimo/_ai/_tools/tools_registry.py index b5eae33c06b..e023c718f03 100644 --- a/marimo/_ai/_tools/tools_registry.py +++ b/marimo/_ai/_tools/tools_registry.py @@ -6,6 +6,7 @@ GetCellRuntimeData, GetLightweightCellMap, ) +from marimo._ai._tools.tools.datasource import GetDatabaseTables from marimo._ai._tools.tools.notebooks import GetActiveNotebooks from marimo._ai._tools.tools.tables_and_variables import GetTablesAndVariables @@ -14,4 +15,5 @@ GetCellRuntimeData, GetLightweightCellMap, GetTablesAndVariables, + GetDatabaseTables, ] diff --git a/marimo/_messaging/ops.py b/marimo/_messaging/ops.py index a9e4bb1a88a..bfba179b649 100644 --- a/marimo/_messaging/ops.py +++ b/marimo/_messaging/ops.py @@ -585,10 +585,19 @@ class Datasets(Op, tag="datasets"): clear_channel: Optional[DataTableSource] = None +class SQLMetadata(msgspec.Struct): + """Metadata for a SQL database.""" + + connection: str + database: str + schema: str + + class SQLTablePreview(Op, tag="sql-table-preview"): """Preview of a table in a SQL database.""" name: ClassVar[str] = "sql-table-preview" + metadata: SQLMetadata request_id: RequestId table: Optional[DataTable] error: Optional[str] = None @@ -598,6 +607,7 @@ class SQLTableListPreview(Op, tag="sql-table-list-preview"): """Preview of a list of tables in a schema.""" name: ClassVar[str] = "sql-table-list-preview" + metadata: SQLMetadata request_id: RequestId tables: list[DataTable] = msgspec.field(default_factory=list) error: Optional[str] = None diff --git a/marimo/_runtime/runtime.py b/marimo/_runtime/runtime.py index 94de6431328..274513abbf0 100644 --- a/marimo/_runtime/runtime.py +++ b/marimo/_runtime/runtime.py @@ -65,6 +65,7 @@ PackageStatusType, RemoveUIElements, SecretKeysResult, + SQLMetadata, SQLTableListPreview, SQLTablePreview, VariableDeclaration, @@ -2287,11 +2288,19 @@ async def preview_sql_table(self, request: PreviewSQLTableRequest) -> None: database_name = request.database schema_name = request.schema table_name = request.table_name + metadata = SQLMetadata( + connection=variable_name, + database=database_name, + schema=schema_name, + ) engine, error = self._get_engine_catalog(variable_name) if error is not None or engine is None: SQLTablePreview( - request_id=request.request_id, table=None, error=error + metadata=metadata, + request_id=request.request_id, + table=None, + error=error, ).broadcast() return @@ -2303,7 +2312,7 @@ async def preview_sql_table(self, request: PreviewSQLTableRequest) -> None: ) SQLTablePreview( - request_id=request.request_id, table=table + metadata=metadata, request_id=request.request_id, table=table ).broadcast() except Exception as e: LOGGER.exception( @@ -2312,6 +2321,7 @@ async def preview_sql_table(self, request: PreviewSQLTableRequest) -> None: schema_name, ) SQLTablePreview( + metadata=metadata, request_id=request.request_id, table=None, error="Failed to get table details: " + str(e), @@ -2332,11 +2342,19 @@ async def preview_sql_table_list( variable_name = cast(VariableName, request.engine) database_name = request.database schema_name = request.schema + metadata = SQLMetadata( + connection=variable_name, + database=database_name, + schema=schema_name, + ) engine, error = self._get_engine_catalog(variable_name) if error is not None or engine is None: SQLTableListPreview( - request_id=request.request_id, tables=[], error=error + metadata=metadata, + request_id=request.request_id, + tables=[], + error=error, ).broadcast() return @@ -2347,17 +2365,20 @@ async def preview_sql_table_list( include_table_details=False, ) SQLTableListPreview( - request_id=request.request_id, tables=table_list + metadata=metadata, + request_id=request.request_id, + tables=table_list, ).broadcast() except Exception as e: LOGGER.exception( "Failed to get table list for schema %s", schema_name ) SQLTableListPreview( + metadata=metadata, request_id=request.request_id, tables=[], error="Failed to get table list: " + str(e), - ) + ).broadcast() @kernel_tracer.start_as_current_span("preview_datasource_connection") async def preview_datasource_connection( diff --git a/marimo/_server/session/session_view.py b/marimo/_server/session/session_view.py index 29c5dfd7aa6..1ab01692060 100644 --- a/marimo/_server/session/session_view.py +++ b/marimo/_server/session/session_view.py @@ -15,6 +15,8 @@ Interrupted, MessageOperation, SendUIElementMessage, + SQLTableListPreview, + SQLTablePreview, StartupLogs, UpdateCellCodes, UpdateCellIdsRequest, @@ -31,6 +33,10 @@ ExecutionRequest, SetUIElementValueRequest, ) +from marimo._sql.connection_utils import ( + update_table_in_connection, + update_table_list_in_connection, +) from marimo._sql.engines.duckdb import INTERNAL_DUCKDB_ENGINE from marimo._types.ids import CellId_t, WidgetModelId from marimo._utils.lists import as_list @@ -247,6 +253,27 @@ def add_operation(self, operation: MessageOperation) -> None: connections=list(connections.values()) ) + elif isinstance(operation, SQLTablePreview): + sql_table_preview = operation + sql_metadata = sql_table_preview.metadata + table_preview_connections = self.data_connectors.connections + if sql_table_preview.table is not None: + update_table_in_connection( + table_preview_connections, + sql_metadata, + sql_table_preview.table, + ) + + elif isinstance(operation, SQLTableListPreview): + sql_table_list_preview = operation + sql_metadata = sql_table_list_preview.metadata + table_list_connections = self.data_connectors.connections + update_table_list_in_connection( + table_list_connections, + sql_metadata, + sql_table_list_preview.tables, + ) + elif isinstance(operation, UpdateCellIdsRequest): self.cell_ids = operation diff --git a/marimo/_sql/connection_utils.py b/marimo/_sql/connection_utils.py new file mode 100644 index 00000000000..52b1be8a7de --- /dev/null +++ b/marimo/_sql/connection_utils.py @@ -0,0 +1,62 @@ +# Copyright 2025 Marimo. All rights reserved. + +from marimo._data.models import DataSourceConnection, DataTable +from marimo._messaging.ops import SQLMetadata + + +def update_table_in_connection( + connections: list[DataSourceConnection], + sql_metadata: SQLMetadata, + updated_table: DataTable, +) -> None: + """Update a table in the connection hierarchy in-place + + Args: + connections: List of data source connections + sql_metadata: SQL metadata containing connection, database, schema info + updated_table: The updated table to replace the existing one + """ + for connection in connections: + if connection.name != sql_metadata.connection: + continue + + for database in connection.databases: + if database.name != sql_metadata.database: + continue + + for schema in database.schemas: + if schema.name != sql_metadata.schema: + continue + + for i, table in enumerate(schema.tables): + if table.name == updated_table.name: + schema.tables[i] = updated_table + return + + +def update_table_list_in_connection( + connections: list[DataSourceConnection], + sql_metadata: SQLMetadata, + updated_table_list: list[DataTable], +) -> None: + """Update a list of tables in the connection hierarchy, updates in-place. + + Args: + connections: List of data source connections + sql_metadata: SQL metadata containing connection, database, schema info + updated_table_list: The updated list of tables to replace the existing ones + """ + for connection in connections: + if connection.name != sql_metadata.connection: + continue + + for database in connection.databases: + if database.name != sql_metadata.database: + continue + + for schema in database.schemas: + if schema.name != sql_metadata.schema: + continue + + schema.tables = updated_table_list + return diff --git a/marimo/_utils/fuzzy_match.py b/marimo/_utils/fuzzy_match.py new file mode 100644 index 00000000000..7fd5e17e8a1 --- /dev/null +++ b/marimo/_utils/fuzzy_match.py @@ -0,0 +1,36 @@ +# Copyright 2025 Marimo. All rights reserved. + +from __future__ import annotations + +import re + + +def compile_regex(query: str) -> tuple[re.Pattern[str] | None, bool]: + """ + Returns compiled regex pattern and whether the query is a valid regex. + """ + try: + return re.compile(query, re.IGNORECASE), True + except re.error: + return None, False + + +def is_fuzzy_match( + query: str, + name: str, + compiled_pattern: re.Pattern[str] | None, + is_regex: bool, +) -> bool: + """ + Fuzzy match using pre-compiled regex. If is not regex, fallback to substring match. + + Args: + query: The query to match. + name: The name to match against. + compiled_pattern: Pre-compiled regex pattern (None if not regex). + is_regex: Whether the query is a valid regex. + """ + if is_regex and compiled_pattern: + return bool(compiled_pattern.search(name)) + else: + return query.lower() in name.lower() diff --git a/tests/_ai/tools/tools/test_datasource_tool.py b/tests/_ai/tools/tools/test_datasource_tool.py new file mode 100644 index 00000000000..91218315db0 --- /dev/null +++ b/tests/_ai/tools/tools/test_datasource_tool.py @@ -0,0 +1,467 @@ +# Copyright 2025 Marimo. All rights reserved. + +from __future__ import annotations + +from dataclasses import dataclass + +import pytest + +from marimo._ai._tools.base import ToolContext +from marimo._ai._tools.tools.datasource import ( + GetDatabaseTables, + GetDatabaseTablesArgs, + TableDetails, +) +from marimo._data.models import Database, DataTable, DataTableColumn, Schema +from marimo._messaging.ops import DataSourceConnections + + +@dataclass +class MockDataSourceConnection: + name: str + dialect: str + databases: list[Database] + + +@dataclass +class MockSessionView: + data_connectors: DataSourceConnections + + +@dataclass +class MockSession: + session_view: MockSessionView + + +@pytest.fixture +def tool() -> GetDatabaseTables: + """Create a GetDatabaseTables tool instance.""" + return GetDatabaseTables(ToolContext()) + + +@pytest.fixture +def sample_table() -> DataTable: + """Sample table for testing.""" + return DataTable( + source_type="connection", + source="postgresql", + name="users", + num_rows=100, + num_columns=3, + variable_name=None, + columns=[ + DataTableColumn("id", "int", "INTEGER", [1, 2, 3]), + DataTableColumn("name", "str", "VARCHAR", ["Alice", "Bob"]), + DataTableColumn("email", "str", "VARCHAR", ["alice@example.com"]), + ], + ) + + +@pytest.fixture +def sample_schema(sample_table: DataTable) -> Schema: + """Sample schema for testing.""" + return Schema( + name="public", + tables=[sample_table], + ) + + +@pytest.fixture +def sample_database(sample_schema: Schema) -> Database: + """Sample database for testing.""" + return Database( + name="test_db", + dialect="postgresql", + schemas=[sample_schema], + ) + + +@pytest.fixture +def sample_connection(sample_database: Database) -> MockDataSourceConnection: + """Sample connection for testing.""" + return MockDataSourceConnection( + name="postgres_conn", + dialect="postgresql", + databases=[sample_database], + ) + + +@pytest.fixture +def sample_session(sample_connection: MockDataSourceConnection) -> MockSession: + """Sample session with data connectors.""" + return MockSession( + session_view=MockSessionView( + data_connectors=DataSourceConnections( + connections=[sample_connection] + ) + ) + ) + + +@pytest.fixture +def multi_table_session() -> MockSession: + """Session with multiple tables for testing filtering.""" + tables = [ + DataTable( + source_type="connection", + source="mysql", + name="users", + num_rows=100, + num_columns=2, + variable_name=None, + columns=[ + DataTableColumn("id", "int", "INTEGER", [1, 2]), + DataTableColumn("name", "str", "VARCHAR", ["Alice"]), + ], + ), + DataTable( + source_type="connection", + source="mysql", + name="orders", + num_rows=50, + num_columns=2, + variable_name=None, + columns=[ + DataTableColumn("order_id", "int", "INTEGER", [1]), + DataTableColumn("user_id", "int", "INTEGER", [1]), + ], + ), + DataTable( + source_type="connection", + source="mysql", + name="products", + num_rows=25, + num_columns=2, + variable_name=None, + columns=[ + DataTableColumn("product_id", "int", "INTEGER", [1]), + DataTableColumn("name", "str", "VARCHAR", ["Widget"]), + ], + ), + ] + + schema = Schema(name="public", tables=tables) + database = Database(name="ecommerce", dialect="mysql", schemas=[schema]) + connection = MockDataSourceConnection( + name="mysql_conn", dialect="mysql", databases=[database] + ) + + return MockSession( + session_view=MockSessionView( + data_connectors=DataSourceConnections(connections=[connection]) + ) + ) + + +def test_get_tables_no_query( + tool: GetDatabaseTables, sample_session: MockSession +): + """Test getting all tables when no query is provided.""" + + # Mock the session + def mock_get_session(_session_id): + return sample_session + + tool.context.get_session = mock_get_session + + args = GetDatabaseTablesArgs( + session_id="test_session", + query=None, + ) + + result = tool.handle(args) + + assert isinstance(result, tool.Output) + assert len(result.tables) == 1 + + table_detail = result.tables[0] + assert isinstance(table_detail, TableDetails) + assert table_detail.connection == "postgres_conn" + assert table_detail.database == "test_db" + assert table_detail.schema == "public" + assert table_detail.table.name == "users" + + +def test_get_tables_with_simple_query( + tool: GetDatabaseTables, multi_table_session: MockSession +): + """Test getting tables with simple text query.""" + + # Mock the session + def mock_get_session(_session_id): + return multi_table_session + + tool.context.get_session = mock_get_session + + args = GetDatabaseTablesArgs( + session_id="test_session", + query="user", + ) + + result = tool.handle(args) + + assert isinstance(result, tool.Output) + assert len(result.tables) == 1 # Only "users" table matches "user" + + table_names = {td.table.name for td in result.tables} + assert "users" in table_names + assert "orders" not in table_names # "orders" doesn't contain "user" + assert "products" not in table_names + + +def test_get_tables_with_regex_query( + tool: GetDatabaseTables, multi_table_session: MockSession +): + """Test getting tables with regex query.""" + + # Mock the session + def mock_get_session(_session_id): + return multi_table_session + + tool.context.get_session = mock_get_session + + args = GetDatabaseTablesArgs( + session_id="test_session", + query="^user.*", + ) + + result = tool.handle(args) + + assert isinstance(result, tool.Output) + assert len(result.tables) == 1 + + table_detail = result.tables[0] + assert table_detail.table.name == "users" + + +def test_get_tables_with_schema_match( + tool: GetDatabaseTables, multi_table_session: MockSession +): + """Test getting tables by schema name match.""" + + # Mock the session + def mock_get_session(_session_id): + return multi_table_session + + tool.context.get_session = mock_get_session + + args = GetDatabaseTablesArgs( + session_id="test_session", + query="pub", + ) + + result = tool.handle(args) + + assert isinstance(result, tool.Output) + assert len(result.tables) == 3 # All tables in public schema + + table_names = {td.table.name for td in result.tables} + assert "users" in table_names + assert "orders" in table_names + assert "products" in table_names + + +def test_get_tables_empty_connections(tool: GetDatabaseTables): + """Test getting tables when no connections exist.""" + empty_session = MockSession( + session_view=MockSessionView( + data_connectors=DataSourceConnections(connections=[]) + ) + ) + + # Mock the session + def mock_get_session(_session_id): + return empty_session + + tool.context.get_session = mock_get_session + + args = GetDatabaseTablesArgs( + session_id="test_session", + query=None, + ) + + result = tool.handle(args) + + assert isinstance(result, tool.Output) + assert len(result.tables) == 0 + assert "No databases found" in result.next_steps[0] + + +def test_get_tables_no_matches( + tool: GetDatabaseTables, sample_session: MockSession +): + """Test getting tables when query matches nothing.""" + + # Mock the session + def mock_get_session(_session_id): + return sample_session + + tool.context.get_session = mock_get_session + + args = GetDatabaseTablesArgs( + session_id="test_session", + query="nonexistent", + ) + + result = tool.handle(args) + + assert isinstance(result, tool.Output) + assert len(result.tables) == 0 + + +def test_table_details_structure( + tool: GetDatabaseTables, sample_session: MockSession +): + """Test that TableDetails is properly structured.""" + + # Mock the session + def mock_get_session(_session_id): + return sample_session + + tool.context.get_session = mock_get_session + + args = GetDatabaseTablesArgs( + session_id="test_session", + query=None, + ) + + result = tool.handle(args) + + table_detail = result.tables[0] + assert isinstance(table_detail, TableDetails) + assert table_detail.connection == "postgres_conn" + assert table_detail.database == "test_db" + assert table_detail.schema == "public" + assert isinstance(table_detail.table, DataTable) + assert table_detail.table.name == "users" + assert len(table_detail.table.columns) == 3 + + +def test_multiple_connections(tool: GetDatabaseTables): + """Test with multiple connections.""" + # Create two connections with different databases + table1 = DataTable( + source_type="connection", + source="postgresql", + name="table1", + num_rows=10, + num_columns=0, + variable_name=None, + columns=[], + ) + table2 = DataTable( + source_type="connection", + source="mysql", + name="table2", + num_rows=20, + num_columns=0, + variable_name=None, + columns=[], + ) + + schema1 = Schema(name="schema1", tables=[table1]) + schema2 = Schema(name="schema2", tables=[table2]) + + db1 = Database(name="db1", dialect="postgresql", schemas=[schema1]) + db2 = Database(name="db2", dialect="mysql", schemas=[schema2]) + + conn1 = MockDataSourceConnection( + name="conn1", dialect="postgresql", databases=[db1] + ) + conn2 = MockDataSourceConnection( + name="conn2", dialect="mysql", databases=[db2] + ) + + multi_conn_session = MockSession( + session_view=MockSessionView( + data_connectors=DataSourceConnections(connections=[conn1, conn2]) + ) + ) + + # Mock the session + def mock_get_session(_session_id): + return multi_conn_session + + tool.context.get_session = mock_get_session + + args = GetDatabaseTablesArgs( + session_id="test_session", + query=None, + ) + + result = tool.handle(args) + + assert isinstance(result, tool.Output) + assert len(result.tables) == 2 + + connections = {td.connection for td in result.tables} + assert "conn1" in connections + assert "conn2" in connections + + databases = {td.database for td in result.tables} + assert "db1" in databases + assert "db2" in databases + + +def test_query_matches_multiple_levels(tool: GetDatabaseTables): + """Test query that matches at different levels (schema and table).""" + # Create tables with overlapping names + user_table = DataTable( + source_type="connection", + source="postgresql", + name="user", + num_rows=5, + num_columns=0, + variable_name=None, + columns=[], + ) + user_schema_table = DataTable( + source_type="connection", + source="postgresql", + name="orders", + num_rows=10, + num_columns=0, + variable_name=None, + columns=[], + ) + + user_schema = Schema(name="user", tables=[user_table]) + public_schema = Schema(name="public", tables=[user_schema_table]) + + database = Database( + name="testdb", + dialect="postgresql", + schemas=[user_schema, public_schema], + ) + + connection = MockDataSourceConnection( + name="conn", dialect="postgresql", databases=[database] + ) + + session = MockSession( + session_view=MockSessionView( + data_connectors=DataSourceConnections(connections=[connection]) + ) + ) + + # Mock the session + def mock_get_session(_session_id): + return session + + tool.context.get_session = mock_get_session + + args = GetDatabaseTablesArgs( + session_id="test_session", + query="user", + ) + + result = tool.handle(args) + + assert isinstance(result, tool.Output) + assert len(result.tables) == 1 # Only the "user" table matches "user" + + table_names = {td.table.name for td in result.tables} + assert "user" in table_names + # The "orders" table is in the "public" schema, not the "user" schema + # So it won't be included when query matches "user" + assert "orders" not in table_names diff --git a/tests/_server/session/test_session_view.py b/tests/_server/session/test_session_view.py index d4824624360..e73fb99efcd 100644 --- a/tests/_server/session/test_session_view.py +++ b/tests/_server/session/test_session_view.py @@ -5,7 +5,7 @@ from unittest.mock import patch from marimo._ast.cell import RuntimeStateType -from marimo._data.models import DataTable, DataTableColumn +from marimo._data.models import Database, DataTable, DataTableColumn, Schema from marimo._messaging.cell_output import CellChannel, CellOutput from marimo._messaging.msgspec_encoder import asdict as serialize from marimo._messaging.ops import ( @@ -15,6 +15,9 @@ DataSourceConnections, InstallingPackageAlert, SendUIElementMessage, + SQLMetadata, + SQLTableListPreview, + SQLTablePreview, StartupLogs, UpdateCellCodes, UpdateCellIdsRequest, @@ -32,7 +35,7 @@ ) from marimo._server.session.session_view import SessionView from marimo._sql.engines.duckdb import INTERNAL_DUCKDB_ENGINE -from marimo._types.ids import CellId_t, WidgetModelId +from marimo._types.ids import CellId_t, RequestId, WidgetModelId from marimo._utils.parse_dataclass import parse_raw cell_id = CellId_t("cell_1") @@ -574,6 +577,125 @@ def test_add_data_source_connections() -> None: assert INTERNAL_DUCKDB_ENGINE in session_view_names +def test_add_sql_table_previews() -> None: + session_view = SessionView() + + # Add initial connections + session_view.add_raw_operation( + serialize_kernel_message( + DataSourceConnections( + connections=[ + DataSourceConnection( + source="duckdb", + name="connection1", + dialect="duckdb", + display_name="duckdb (connection1)", + databases=[ + Database( + name="db1", + dialect="duckdb", + schemas=[ + Schema( + name="db1", + tables=[ + DataTable( + name="table1", + source_type="connection", + source="db1", + columns=[], + num_rows=0, + num_columns=0, + variable_name=None, + ) + ], + ) + ], + ) + ], + ) + ], + ) + ) + ) + + session_view_connections = session_view.data_connectors.connections + assert session_view_connections[0].databases[0].schemas[0].tables == [ + DataTable( + source_type="connection", + source="db1", + name="table1", + num_rows=0, + num_columns=0, + variable_name=None, + columns=[], + ) + ] + + session_view.add_raw_operation( + serialize_kernel_message( + SQLTablePreview( + metadata=SQLMetadata( + connection="connection1", database="db1", schema="db1" + ), + request_id=RequestId("request_id"), + table=DataTable( + name="table1", + source_type="connection", + source="db1", + num_rows=10, # Updated + num_columns=0, + variable_name=None, + columns=[], + ), + ) + ) + ) + session_view_connections = session_view.data_connectors.connections + assert ( + session_view_connections[0].databases[0].schemas[0].tables[0].num_rows + == 10 + ) + + # Add sql table preview list + session_view.add_raw_operation( + serialize_kernel_message( + SQLTableListPreview( + metadata=SQLMetadata( + connection="connection1", database="db1", schema="db1" + ), + request_id=RequestId("request_id"), + tables=[ + DataTable( + name="table2", + source_type="connection", + source="db1", + num_rows=20, + num_columns=10, + variable_name="var", + columns=[], + ) + ], + ) + ) + ) + + assert session_view_connections[0].databases[0].schemas[0].tables == [ + DataTable( + source_type="connection", + source="db1", + name="table2", + num_rows=20, + num_columns=10, + variable_name="var", + columns=[], + engine=None, + type="table", + primary_keys=None, + indexes=None, + ) + ] + + def test_add_cell_op() -> None: session_view = SessionView() session_view.add_raw_operation( diff --git a/tests/_utils/test_fuzzy_match.py b/tests/_utils/test_fuzzy_match.py new file mode 100644 index 00000000000..d9c957383fc --- /dev/null +++ b/tests/_utils/test_fuzzy_match.py @@ -0,0 +1,57 @@ +# Copyright 2025 Marimo. All rights reserved. + + +from marimo._utils.fuzzy_match import compile_regex, is_fuzzy_match + + +def test_compile_regex_valid_pattern(): + """Test _compile_regex with valid regex pattern.""" + pattern, is_regex = compile_regex("^user.*") + + assert pattern is not None + assert is_regex is True + assert pattern.search("users") is not None + assert pattern.search("orders") is None + + +def test_compile_regex_invalid_pattern(): + """Test _compile_regex with invalid regex pattern.""" + pattern, is_regex = compile_regex("[invalid") + + assert pattern is None + assert is_regex is False + + +def test_compile_regex_simple_text(): + """Test _compile_regex with simple text (valid regex).""" + pattern, is_regex = compile_regex("user") + + assert pattern is not None + assert is_regex is True + assert pattern.search("users") is not None + assert pattern.search("orders") is None + + +def test_is_fuzzy_match_with_regex(): + """Test is_fuzzy_match with compiled regex pattern.""" + pattern, is_regex = compile_regex("^user.*") + + assert is_fuzzy_match("^user.*", "users", pattern, is_regex) is True + assert is_fuzzy_match("^user.*", "orders", pattern, is_regex) is False + + +def test_is_fuzzy_match_without_regex(): + """Test is_fuzzy_match with invalid regex (fallback to substring).""" + pattern, is_regex = compile_regex("[invalid") + + assert is_fuzzy_match("[invalid", "users", pattern, is_regex) is False + assert is_fuzzy_match("[invalid", "[invalid", pattern, is_regex) is True + + +def test_is_fuzzy_match_case_insensitive(): + """Test that matching is case insensitive.""" + pattern, is_regex = compile_regex("USER") + + assert is_fuzzy_match("USER", "users", pattern, is_regex) is True + assert is_fuzzy_match("USER", "USERS", pattern, is_regex) is True + assert is_fuzzy_match("USER", "orders", pattern, is_regex) is False