Skip to content

Commit 7730b35

Browse files
committed
update datasource connection when manually introspecting
1 parent e6433bc commit 7730b35

File tree

7 files changed

+742
-9
lines changed

7 files changed

+742
-9
lines changed

marimo/_messaging/ops.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -587,11 +587,20 @@ class Datasets(Op, tag="datasets"):
587587
clear_channel: Optional[DataTableSource] = None
588588

589589

590+
class SQLMetadata(msgspec.Struct, tag="sql-metadata"):
591+
"""Metadata for a SQL database."""
592+
593+
connection: str
594+
database: str
595+
schema: str
596+
597+
590598
class SQLTablePreview(Op, tag="sql-table-preview"):
591599
"""Preview of a table in a SQL database."""
592600

593601
name: ClassVar[str] = "sql-table-preview"
594602
request_id: RequestId
603+
metadata: SQLMetadata
595604
table: Optional[DataTable]
596605
error: Optional[str] = None
597606

@@ -601,6 +610,7 @@ class SQLTableListPreview(Op, tag="sql-table-list-preview"):
601610

602611
name: ClassVar[str] = "sql-table-list-preview"
603612
request_id: RequestId
613+
metadata: SQLMetadata
604614
tables: list[DataTable] = msgspec.field(default_factory=list)
605615
error: Optional[str] = None
606616

marimo/_runtime/runtime.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
PackageStatusType,
6868
RemoveUIElements,
6969
SecretKeysResult,
70+
SQLMetadata,
7071
SQLTableListPreview,
7172
SQLTablePreview,
7273
ValidateSQLResult,
@@ -2454,11 +2455,19 @@ async def preview_sql_table(self, request: PreviewSQLTableRequest) -> None:
24542455
database_name = request.database
24552456
schema_name = request.schema
24562457
table_name = request.table_name
2458+
sql_metadata = SQLMetadata(
2459+
connection=variable_name,
2460+
database=database_name,
2461+
schema=schema_name,
2462+
)
24572463

24582464
engine, error = self.get_engine_catalog(variable_name)
24592465
if error is not None or engine is None:
24602466
SQLTablePreview(
2461-
request_id=request.request_id, table=None, error=error
2467+
request_id=request.request_id,
2468+
table=None,
2469+
error=error,
2470+
metadata=sql_metadata,
24622471
).broadcast()
24632472
return
24642473

@@ -2470,7 +2479,9 @@ async def preview_sql_table(self, request: PreviewSQLTableRequest) -> None:
24702479
)
24712480

24722481
SQLTablePreview(
2473-
request_id=request.request_id, table=table
2482+
request_id=request.request_id,
2483+
table=table,
2484+
metadata=sql_metadata,
24742485
).broadcast()
24752486
except Exception as e:
24762487
LOGGER.exception(
@@ -2482,6 +2493,7 @@ async def preview_sql_table(self, request: PreviewSQLTableRequest) -> None:
24822493
request_id=request.request_id,
24832494
table=None,
24842495
error="Failed to get table details: " + str(e),
2496+
metadata=sql_metadata,
24852497
).broadcast()
24862498

24872499
@kernel_tracer.start_as_current_span("preview_sql_table_list")
@@ -2499,11 +2511,19 @@ async def preview_sql_table_list(
24992511
variable_name = cast(VariableName, request.engine)
25002512
database_name = request.database
25012513
schema_name = request.schema
2514+
sql_metadata = SQLMetadata(
2515+
connection=variable_name,
2516+
database=database_name,
2517+
schema=schema_name,
2518+
)
25022519

25032520
engine, error = self.get_engine_catalog(variable_name)
25042521
if error is not None or engine is None:
25052522
SQLTableListPreview(
2506-
request_id=request.request_id, tables=[], error=error
2523+
request_id=request.request_id,
2524+
tables=[],
2525+
error=error,
2526+
metadata=sql_metadata,
25072527
).broadcast()
25082528
return
25092529

@@ -2514,7 +2534,9 @@ async def preview_sql_table_list(
25142534
include_table_details=False,
25152535
)
25162536
SQLTableListPreview(
2517-
request_id=request.request_id, tables=table_list
2537+
request_id=request.request_id,
2538+
tables=table_list,
2539+
metadata=sql_metadata,
25182540
).broadcast()
25192541
except Exception as e:
25202542
LOGGER.exception(
@@ -2524,7 +2546,8 @@ async def preview_sql_table_list(
25242546
request_id=request.request_id,
25252547
tables=[],
25262548
error="Failed to get table list: " + str(e),
2527-
)
2549+
metadata=sql_metadata,
2550+
).broadcast()
25282551

25292552
@kernel_tracer.start_as_current_span("preview_datasource_connection")
25302553
async def preview_datasource_connection(

marimo/_server/session/session_view.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
Interrupted,
1616
MessageOperation,
1717
SendUIElementMessage,
18+
SQLTableListPreview,
19+
SQLTablePreview,
1820
StartupLogs,
1921
UpdateCellCodes,
2022
UpdateCellIdsRequest,
@@ -32,6 +34,10 @@
3234
SetUIElementValueRequest,
3335
SyncGraphRequest,
3436
)
37+
from marimo._sql.connection_utils import (
38+
update_table_in_connection,
39+
update_table_list_in_connection,
40+
)
3541
from marimo._sql.engines.duckdb import INTERNAL_DUCKDB_ENGINE
3642
from marimo._types.ids import CellId_t, WidgetModelId
3743
from marimo._utils.lists import as_list
@@ -248,6 +254,27 @@ def add_operation(self, operation: MessageOperation) -> None:
248254
connections=list(connections.values())
249255
)
250256

257+
elif isinstance(operation, SQLTablePreview):
258+
sql_table_preview = operation
259+
sql_metadata = sql_table_preview.metadata
260+
table_preview_connections = self.data_connectors.connections
261+
if sql_table_preview.table is not None:
262+
update_table_in_connection(
263+
table_preview_connections,
264+
sql_metadata,
265+
sql_table_preview.table,
266+
)
267+
268+
elif isinstance(operation, SQLTableListPreview):
269+
sql_table_list_preview = operation
270+
sql_metadata = sql_table_list_preview.metadata
271+
table_list_connections = self.data_connectors.connections
272+
update_table_list_in_connection(
273+
table_list_connections,
274+
sql_metadata,
275+
sql_table_list_preview.tables,
276+
)
277+
251278
elif isinstance(operation, UpdateCellIdsRequest):
252279
self.cell_ids = operation
253280

marimo/_sql/connection_utils.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright 2025 Marimo. All rights reserved.
2+
3+
from marimo._data.models import DataSourceConnection, DataTable
4+
from marimo._messaging.ops import SQLMetadata
5+
6+
7+
def update_table_in_connection(
8+
connections: list[DataSourceConnection],
9+
sql_metadata: SQLMetadata,
10+
updated_table: DataTable,
11+
) -> None:
12+
"""Update a table in the connection hierarchy in-place
13+
14+
Args:
15+
connections: List of data source connections
16+
sql_metadata: SQL metadata containing connection, database, schema info
17+
updated_table: The updated table to replace the existing one
18+
"""
19+
for connection in connections:
20+
if connection.name != sql_metadata.connection:
21+
continue
22+
23+
for database in connection.databases:
24+
if database.name != sql_metadata.database:
25+
continue
26+
27+
for schema in database.schemas:
28+
if schema.name != sql_metadata.schema:
29+
continue
30+
31+
for i, table in enumerate(schema.tables):
32+
if table.name == updated_table.name:
33+
schema.tables[i] = updated_table
34+
return
35+
36+
37+
def update_table_list_in_connection(
38+
connections: list[DataSourceConnection],
39+
sql_metadata: SQLMetadata,
40+
updated_table_list: list[DataTable],
41+
) -> None:
42+
"""Update a list of tables in the connection hierarchy, updates in-place.
43+
44+
Args:
45+
connections: List of data source connections
46+
sql_metadata: SQL metadata containing connection, database, schema info
47+
updated_table_list: The updated list of tables to replace the existing ones
48+
"""
49+
for connection in connections:
50+
if connection.name != sql_metadata.connection:
51+
continue
52+
53+
for database in connection.databases:
54+
if database.name != sql_metadata.database:
55+
continue
56+
57+
for schema in database.schemas:
58+
if schema.name != sql_metadata.schema:
59+
continue
60+
61+
schema.tables = updated_table_list
62+
return

pixi.lock

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/_server/session/test_session_view.py

Lines changed: 124 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from unittest.mock import patch
66

77
from marimo._ast.cell import RuntimeStateType
8-
from marimo._data.models import DataTable, DataTableColumn
8+
from marimo._data.models import Database, DataTable, DataTableColumn, Schema
99
from marimo._messaging.cell_output import CellChannel, CellOutput
1010
from marimo._messaging.msgspec_encoder import asdict as serialize
1111
from marimo._messaging.ops import (
@@ -15,6 +15,9 @@
1515
DataSourceConnections,
1616
InstallingPackageAlert,
1717
SendUIElementMessage,
18+
SQLMetadata,
19+
SQLTableListPreview,
20+
SQLTablePreview,
1821
StartupLogs,
1922
UpdateCellCodes,
2023
UpdateCellIdsRequest,
@@ -32,7 +35,7 @@
3235
)
3336
from marimo._server.session.session_view import SessionView
3437
from marimo._sql.engines.duckdb import INTERNAL_DUCKDB_ENGINE
35-
from marimo._types.ids import CellId_t, WidgetModelId
38+
from marimo._types.ids import CellId_t, RequestId, VariableName, WidgetModelId
3639
from marimo._utils.parse_dataclass import parse_raw
3740

3841
cell_id = CellId_t("cell_1")
@@ -558,6 +561,125 @@ def test_add_data_source_connections(session_view: SessionView) -> None:
558561
assert INTERNAL_DUCKDB_ENGINE in session_view_names
559562

560563

564+
def test_add_sql_table_previews() -> None:
565+
session_view = SessionView()
566+
567+
# Add initial connections
568+
session_view.add_raw_operation(
569+
serialize_kernel_message(
570+
DataSourceConnections(
571+
connections=[
572+
DataSourceConnection(
573+
source="duckdb",
574+
name="connection1",
575+
dialect="duckdb",
576+
display_name="duckdb (connection1)",
577+
databases=[
578+
Database(
579+
name="db1",
580+
dialect="duckdb",
581+
schemas=[
582+
Schema(
583+
name="db1",
584+
tables=[
585+
DataTable(
586+
name="table1",
587+
source_type="connection",
588+
source="db1",
589+
columns=[],
590+
num_rows=0,
591+
num_columns=0,
592+
variable_name=None,
593+
)
594+
],
595+
)
596+
],
597+
)
598+
],
599+
)
600+
],
601+
)
602+
)
603+
)
604+
605+
session_view_connections = session_view.data_connectors.connections
606+
assert session_view_connections[0].databases[0].schemas[0].tables == [
607+
DataTable(
608+
source_type="connection",
609+
source="db1",
610+
name="table1",
611+
num_rows=0,
612+
num_columns=0,
613+
variable_name=None,
614+
columns=[],
615+
)
616+
]
617+
618+
session_view.add_raw_operation(
619+
serialize_kernel_message(
620+
SQLTablePreview(
621+
metadata=SQLMetadata(
622+
connection="connection1", database="db1", schema="db1"
623+
),
624+
request_id=RequestId("request_id"),
625+
table=DataTable(
626+
name="table1",
627+
source_type="connection",
628+
source="db1",
629+
num_rows=10, # Updated
630+
num_columns=0,
631+
variable_name=None,
632+
columns=[],
633+
),
634+
)
635+
)
636+
)
637+
session_view_connections = session_view.data_connectors.connections
638+
assert (
639+
session_view_connections[0].databases[0].schemas[0].tables[0].num_rows
640+
== 10
641+
)
642+
643+
# Add sql table preview list
644+
session_view.add_raw_operation(
645+
serialize_kernel_message(
646+
SQLTableListPreview(
647+
metadata=SQLMetadata(
648+
connection="connection1", database="db1", schema="db1"
649+
),
650+
request_id=RequestId("request_id"),
651+
tables=[
652+
DataTable(
653+
name="table2",
654+
source_type="connection",
655+
source="db1",
656+
num_rows=20,
657+
num_columns=10,
658+
variable_name=VariableName("var"),
659+
columns=[],
660+
)
661+
],
662+
)
663+
)
664+
)
665+
666+
assert session_view_connections[0].databases[0].schemas[0].tables == [
667+
DataTable(
668+
source_type="connection",
669+
source="db1",
670+
name="table2",
671+
num_rows=20,
672+
num_columns=10,
673+
variable_name=VariableName("var"),
674+
columns=[],
675+
engine=None,
676+
type="table",
677+
primary_keys=None,
678+
indexes=None,
679+
)
680+
]
681+
682+
561683
def test_add_cell_op(session_view: SessionView) -> None:
562684
session_view.add_raw_operation(
563685
serialize_kernel_message(

0 commit comments

Comments
 (0)