Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 111 additions & 0 deletions marimo/_ai/_tools/tools/datasource.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# Copyright 2025 Marimo. All rights reserved.

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Optional

from marimo import _loggers
from marimo._ai._tools.base import ToolBase
from marimo._ai._tools.types import SuccessResult
from marimo._ai._tools.utils.exceptions import ToolExecutionError
from marimo._data.models import DataTable
from marimo._server.sessions import Session
from marimo._types.ids import SessionId
from marimo._utils.fuzzy_match import compile_regex, is_fuzzy_match

LOGGER = _loggers.marimo_logger()


@dataclass
class GetDatabaseTablesArgs:
session_id: SessionId
query: Optional[str] = None


@dataclass
class TableDetails:
connection: str
database: str
schema: str
table: DataTable


@dataclass
class GetDatabaseTablesOutput(SuccessResult):
tables: list[TableDetails] = field(default_factory=list)


class GetDatabaseTables(
ToolBase[GetDatabaseTablesArgs, GetDatabaseTablesOutput]
):
"""
Get information about tables in a database.

Args:
session_id: The session id.
query (optional): The query to match the database, schemas, and tables. Regex is supported.

If a query is provided, it will fuzzy match the query to the database, schemas, and tables available. If no query is provided, all tables are returned. Don't provide a query if you need to see the entire schema view.

The tables returned contain information about the database, schema and connection name to use in forming SQL queries.
"""

def handle(self, args: GetDatabaseTablesArgs) -> GetDatabaseTablesOutput:
session_id = args.session_id
session = self.context.get_session(session_id)

return self._get_tables(session, args.query)

def _get_tables(
self, session: Session, query: Optional[str]
) -> list[TableDetails]:
session_view = session.session_view
data_connectors = session_view.data_connectors

if len(data_connectors.connections) == 0:
raise ToolExecutionError(
message="No databases found. Please create a connection first.",
code="NO_DATABASES_FOUND",
is_retryable=False,
)

tables: list[TableDetails] = []

# Pre-compile regex if query exists
compiled_pattern = None
is_regex = False
if query:
compiled_pattern, is_regex = compile_regex(query)

for connection in data_connectors.connections:
for database in connection.databases:
for schema in database.schemas:
# If query is None, match all schemas
if query is None or is_fuzzy_match(
query, schema.name, compiled_pattern, is_regex
):
for table in schema.tables:
tables.append(
TableDetails(
connection=connection.name,
database=database.name,
schema=schema.name,
table=table,
)
)
continue
for table in schema.tables:
if is_fuzzy_match(
query, table.name, compiled_pattern, is_regex
):
tables.append(
TableDetails(
connection=connection.name,
database=database.name,
schema=schema.name,
table=table,
)
)

return GetDatabaseTablesOutput(tables=tables)
2 changes: 2 additions & 0 deletions marimo/_ai/_tools/tools_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
GetCellRuntimeData,
GetLightweightCellMap,
)
from marimo._ai._tools.tools.datasource import GetDatabaseTables
from marimo._ai._tools.tools.notebooks import GetActiveNotebooks
from marimo._ai._tools.tools.tables_and_variables import GetTablesAndVariables

Expand All @@ -14,4 +15,5 @@
GetCellRuntimeData,
GetLightweightCellMap,
GetTablesAndVariables,
GetDatabaseTables,
]
10 changes: 10 additions & 0 deletions marimo/_messaging/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,10 +585,19 @@ class Datasets(Op, tag="datasets"):
clear_channel: Optional[DataTableSource] = None


class SQLMetadata(msgspec.Struct):
"""Metadata for a SQL database."""

connection: str
database: str
schema: str


class SQLTablePreview(Op, tag="sql-table-preview"):
"""Preview of a table in a SQL database."""

name: ClassVar[str] = "sql-table-preview"
metadata: SQLMetadata
request_id: RequestId
table: Optional[DataTable]
error: Optional[str] = None
Expand All @@ -598,6 +607,7 @@ class SQLTableListPreview(Op, tag="sql-table-list-preview"):
"""Preview of a list of tables in a schema."""

name: ClassVar[str] = "sql-table-list-preview"
metadata: SQLMetadata
request_id: RequestId
tables: list[DataTable] = msgspec.field(default_factory=list)
error: Optional[str] = None
Expand Down
31 changes: 26 additions & 5 deletions marimo/_runtime/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
PackageStatusType,
RemoveUIElements,
SecretKeysResult,
SQLMetadata,
SQLTableListPreview,
SQLTablePreview,
VariableDeclaration,
Expand Down Expand Up @@ -2287,11 +2288,19 @@ async def preview_sql_table(self, request: PreviewSQLTableRequest) -> None:
database_name = request.database
schema_name = request.schema
table_name = request.table_name
metadata = SQLMetadata(
connection=variable_name,
database=database_name,
schema=schema_name,
)

engine, error = self._get_engine_catalog(variable_name)
if error is not None or engine is None:
SQLTablePreview(
request_id=request.request_id, table=None, error=error
metadata=metadata,
request_id=request.request_id,
table=None,
error=error,
).broadcast()
return

Expand All @@ -2303,7 +2312,7 @@ async def preview_sql_table(self, request: PreviewSQLTableRequest) -> None:
)

SQLTablePreview(
request_id=request.request_id, table=table
metadata=metadata, request_id=request.request_id, table=table
).broadcast()
except Exception as e:
LOGGER.exception(
Expand All @@ -2312,6 +2321,7 @@ async def preview_sql_table(self, request: PreviewSQLTableRequest) -> None:
schema_name,
)
SQLTablePreview(
metadata=metadata,
request_id=request.request_id,
table=None,
error="Failed to get table details: " + str(e),
Expand All @@ -2332,11 +2342,19 @@ async def preview_sql_table_list(
variable_name = cast(VariableName, request.engine)
database_name = request.database
schema_name = request.schema
metadata = SQLMetadata(
connection=variable_name,
database=database_name,
schema=schema_name,
)

engine, error = self._get_engine_catalog(variable_name)
if error is not None or engine is None:
SQLTableListPreview(
request_id=request.request_id, tables=[], error=error
metadata=metadata,
request_id=request.request_id,
tables=[],
error=error,
).broadcast()
return

Expand All @@ -2347,17 +2365,20 @@ async def preview_sql_table_list(
include_table_details=False,
)
SQLTableListPreview(
request_id=request.request_id, tables=table_list
metadata=metadata,
request_id=request.request_id,
tables=table_list,
).broadcast()
except Exception as e:
LOGGER.exception(
"Failed to get table list for schema %s", schema_name
)
SQLTableListPreview(
metadata=metadata,
request_id=request.request_id,
tables=[],
error="Failed to get table list: " + str(e),
)
).broadcast()

@kernel_tracer.start_as_current_span("preview_datasource_connection")
async def preview_datasource_connection(
Expand Down
27 changes: 27 additions & 0 deletions marimo/_server/session/session_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
Interrupted,
MessageOperation,
SendUIElementMessage,
SQLTableListPreview,
SQLTablePreview,
StartupLogs,
UpdateCellCodes,
UpdateCellIdsRequest,
Expand All @@ -31,6 +33,10 @@
ExecutionRequest,
SetUIElementValueRequest,
)
from marimo._sql.connection_utils import (
update_table_in_connection,
update_table_list_in_connection,
)
from marimo._sql.engines.duckdb import INTERNAL_DUCKDB_ENGINE
from marimo._types.ids import CellId_t, WidgetModelId
from marimo._utils.lists import as_list
Expand Down Expand Up @@ -247,6 +253,27 @@ def add_operation(self, operation: MessageOperation) -> None:
connections=list(connections.values())
)

elif isinstance(operation, SQLTablePreview):
sql_table_preview = operation
sql_metadata = sql_table_preview.metadata
table_preview_connections = self.data_connectors.connections
if sql_table_preview.table is not None:
update_table_in_connection(
table_preview_connections,
sql_metadata,
sql_table_preview.table,
)

elif isinstance(operation, SQLTableListPreview):
sql_table_list_preview = operation
sql_metadata = sql_table_list_preview.metadata
table_list_connections = self.data_connectors.connections
update_table_list_in_connection(
table_list_connections,
sql_metadata,
sql_table_list_preview.tables,
)

elif isinstance(operation, UpdateCellIdsRequest):
self.cell_ids = operation

Expand Down
62 changes: 62 additions & 0 deletions marimo/_sql/connection_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright 2025 Marimo. All rights reserved.

from marimo._data.models import DataSourceConnection, DataTable
from marimo._messaging.ops import SQLMetadata


def update_table_in_connection(
connections: list[DataSourceConnection],
sql_metadata: SQLMetadata,
updated_table: DataTable,
) -> None:
"""Update a table in the connection hierarchy in-place

Args:
connections: List of data source connections
sql_metadata: SQL metadata containing connection, database, schema info
updated_table: The updated table to replace the existing one
"""
for connection in connections:
if connection.name != sql_metadata.connection:
continue

for database in connection.databases:
if database.name != sql_metadata.database:
continue

for schema in database.schemas:
if schema.name != sql_metadata.schema:
continue

for i, table in enumerate(schema.tables):
if table.name == updated_table.name:
schema.tables[i] = updated_table
return


def update_table_list_in_connection(
connections: list[DataSourceConnection],
sql_metadata: SQLMetadata,
updated_table_list: list[DataTable],
) -> None:
"""Update a list of tables in the connection hierarchy, updates in-place.

Args:
connections: List of data source connections
sql_metadata: SQL metadata containing connection, database, schema info
updated_table_list: The updated list of tables to replace the existing ones
"""
for connection in connections:
if connection.name != sql_metadata.connection:
continue

for database in connection.databases:
if database.name != sql_metadata.database:
continue

for schema in database.schemas:
if schema.name != sql_metadata.schema:
continue

schema.tables = updated_table_list
return
36 changes: 36 additions & 0 deletions marimo/_utils/fuzzy_match.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright 2025 Marimo. All rights reserved.

from __future__ import annotations

import re


def compile_regex(query: str) -> tuple[re.Pattern[str] | None, bool]:
"""
Returns compiled regex pattern and whether the query is a valid regex.
"""
try:
return re.compile(query, re.IGNORECASE), True
except re.error:
return None, False


def is_fuzzy_match(
query: str,
name: str,
compiled_pattern: re.Pattern[str] | None,
is_regex: bool,
) -> bool:
"""
Fuzzy match using pre-compiled regex. If is not regex, fallback to substring match.

Args:
query: The query to match.
name: The name to match against.
compiled_pattern: Pre-compiled regex pattern (None if not regex).
is_regex: Whether the query is a valid regex.
"""
if is_regex and compiled_pattern:
return bool(compiled_pattern.search(name))
else:
return query.lower() in name.lower()
Loading
Loading