Skip to content

Commit 09de4d5

Browse files
committed
add duplicates test
1 parent 4b129bf commit 09de4d5

File tree

2 files changed

+86
-0
lines changed

2 files changed

+86
-0
lines changed

marimo/_ai/_tools/tools/datasource.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ def _get_tables(
8282
for database in connection.databases:
8383
for schema in database.schemas:
8484
# If query is None, match all schemas
85+
# If matching, add all tables to the list
8586
if query is None or is_fuzzy_match(
8687
query, schema.name, compiled_pattern, is_regex
8788
):

tests/_ai/tools/tools/test_datasource_tool.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -464,3 +464,88 @@ def mock_get_session(_session_id):
464464
# The "orders" table is in the "public" schema, not the "user" schema
465465
# So it won't be included when query matches "user"
466466
assert "orders" not in table_names
467+
468+
469+
def test_query_no_duplicates(tool: GetDatabaseTables):
470+
"""Test that schema-level matching doesn't create duplicates with table-level matching."""
471+
# Create a schema that matches the query AND has tables that also match
472+
schema1 = Schema(
473+
name="users", # This will match query "user"
474+
tables=[
475+
DataTable(
476+
source_type="connection",
477+
source="postgresql",
478+
name="user_profiles", # This would also match "user"
479+
num_rows=10,
480+
num_columns=0,
481+
variable_name=None,
482+
columns=[],
483+
),
484+
DataTable(
485+
source_type="connection",
486+
source="postgresql",
487+
name="user_settings", # This would also match "user"
488+
num_rows=20,
489+
num_columns=0,
490+
variable_name=None,
491+
columns=[],
492+
),
493+
],
494+
)
495+
496+
# Create another schema that doesn't match but has tables that do
497+
schema2 = Schema(
498+
name="products", # This won't match "user"
499+
tables=[
500+
DataTable(
501+
source_type="connection",
502+
source="postgresql",
503+
name="user_reviews", # This would match "user"
504+
num_rows=5,
505+
num_columns=0,
506+
variable_name=None,
507+
columns=[],
508+
),
509+
],
510+
)
511+
512+
database = Database(
513+
name="test_db",
514+
dialect="postgresql",
515+
schemas=[schema1, schema2],
516+
)
517+
518+
connection = MockDataSourceConnection(
519+
name="test_conn",
520+
dialect="postgresql",
521+
databases=[database],
522+
)
523+
524+
session = MockSession(
525+
session_view=MockSessionView(
526+
data_connectors=DataSourceConnections(connections=[connection])
527+
)
528+
)
529+
530+
# Query that matches both schema name and individual table names
531+
result = tool._get_tables(session, query="user")
532+
533+
# Should get all tables from the matching schema (2 tables)
534+
# plus the matching table from the non-matching schema (1 table)
535+
# Total: 3 tables, no duplicates
536+
assert len(result.tables) == 3
537+
538+
# Verify no duplicates by checking unique combinations
539+
table_identifiers = [
540+
(t.connection, t.database, t.schema, t.table.name)
541+
for t in result.tables
542+
]
543+
assert len(table_identifiers) == len(set(table_identifiers)), (
544+
"Found duplicate tables"
545+
)
546+
547+
# Verify we got the expected tables
548+
table_names = [t.table.name for t in result.tables]
549+
assert "user_profiles" in table_names
550+
assert "user_settings" in table_names
551+
assert "user_reviews" in table_names

0 commit comments

Comments
 (0)