Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pandasai/data_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from pandasai.dataframe.virtual_dataframe import VirtualDataFrame
from pandasai.exceptions import InvalidDataSourceType
from pandasai.helpers.path import find_project_root
from pandasai.helpers.sql_sanitizer import sanitize_sql_table_name

from ..constants import (
LOCAL_SOURCE_TYPES,
Expand Down Expand Up @@ -78,6 +79,7 @@ def _load_schema(self):

with open(schema_path, "r") as file:
raw_schema = yaml.safe_load(file)
raw_schema["name"] = sanitize_sql_table_name(raw_schema["name"])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ensure that there is a corresponding unit test for the sanitize_sql_table_name function to verify that the dataset name is correctly sanitized in the schema.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tests already defined in tests/unit_tests/helpers/test_sql_sanitizer.py

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@scaliseraoul maybe we could also add a test to enforce the name of the schema loaded from the loader is sanitized (like in this case)?

self.schema = SemanticLayerSchema(**raw_schema)

def _get_loader_function(self, source_type: str):
Expand Down
11 changes: 11 additions & 0 deletions tests/unit_tests/dataframe/test_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,17 @@ def test_load_schema_mysql(self, mysql_schema):
loader._load_schema()
assert loader.schema == mysql_schema

def test_load_schema_mysql_sanitized_name(self, mysql_schema):
mysql_schema.name = "non-sanitized-name"

with patch("os.path.exists", return_value=True), patch(
"builtins.open", mock_open(read_data=str(mysql_schema.to_yaml()))
):
loader = DatasetLoader()
loader.dataset_path = "test/users"
loader._load_schema()
assert loader.schema.name == "non_sanitized_name"

def test_load_schema_file_not_found(self):
with patch("os.path.exists", return_value=False):
loader = DatasetLoader()
Expand Down
Loading