Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions pandasai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,26 +114,30 @@ def create(
"Please provide either a DataFrame, a Source or a View"
)

parsed_columns = [Column(**column) for column in columns] if columns else None

if df is not None:
schema = df.schema
schema.name = dataset_name
if (
parsed_columns
): # if no columns are passed it automatically parse the columns from the df
schema.columns = parsed_columns
parquet_file_path_abs_path = file_manager.abs_path(parquet_file_path)
df.to_parquet(parquet_file_path_abs_path, index=False)
elif view:
_relation = [Relation(**relation) for relation in relations or ()]
schema: SemanticLayerSchema = SemanticLayerSchema(
name=dataset_name, relations=_relation, view=True
name=dataset_name, relations=_relation, view=True, columns=parsed_columns
)
elif source.get("table"):
schema: SemanticLayerSchema = SemanticLayerSchema(
name=dataset_name, source=Source(**source)
name=dataset_name, source=Source(**source), columns=parsed_columns
)
else:
raise InvalidConfigError("Unable to create schema with the provided params")

schema.description = description or schema.description
if columns:
schema.columns = [Column(**column) for column in columns]

file_manager.write(schema_path, schema.to_yaml())

Expand Down
13 changes: 7 additions & 6 deletions pandasai/data_loader/semantic_layer_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ def check_columns_relations(self):

# unpack columns info
_columns = self.columns

_column_names = [col.name for col in _columns or ()]
_tables_names_in_columns = {
column_name.split(".")[0] for column_name in _column_names or ()
Expand All @@ -309,8 +310,10 @@ def check_columns_relations(self):
for column_name in _column_names_in_relations or ()
}

if not self.relations:
raise ValueError("At least one relation must be defined for view.")
if not self.relations and not self.columns:
raise ValueError(
"At least a relation or a column must be defined for view."
)

if not all(
is_view_column_name(column_name) for column_name in _column_names
Expand All @@ -327,10 +330,8 @@ def check_columns_relations(self):
"All params 'from' and 'to' in the relations must be in the format '[dataset].[column]'."
)

if (
uncovered_tables := _tables_names_in_columns
- _tables_names_in_relations
):
uncovered_tables = _tables_names_in_columns - _tables_names_in_relations
if uncovered_tables and len(_tables_names_in_columns) > 1:
raise ValueError(
f"No relations provided for the following tables {uncovered_tables}."
)
Expand Down
2 changes: 1 addition & 1 deletion pandasai/data_loader/view_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def _get_dependencies_datasets(self) -> set[str]:
table.split(".")[0]
for relation in self.schema.relations
for table in (relation.from_, relation.to)
}
} or {self.schema.columns[0].name.split(".")[0]}

def _get_dependencies_schemas(self) -> dict[str, DatasetLoader]:
dependency_dict = {
Expand Down
1 change: 0 additions & 1 deletion pandasai/dataframe/virtual_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ def __init__(self, *args, **kwargs):
self._head = None

super().__init__(
self.get_head(),
*args,
**kwargs,
)
Expand Down
2 changes: 1 addition & 1 deletion pandasai/helpers/sql_sanitizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import sqlglot


def sanitize_relation_name(relation_name: str) -> str:
def sanitize_view_column_name(relation_name: str) -> str:
    """Sanitize every dot-separated segment of a view column reference.

    The input is split on ``.``, each piece is passed through
    ``sanitize_sql_table_name``, and the pieces are joined back together
    with ``.`` so the number of segments is preserved.

    Args:
        relation_name: Dotted reference such as ``"dataset-name.column"``.

    Returns:
        The dotted reference rebuilt from the sanitized segments.
    """
    segments = relation_name.split(".")
    return ".".join(sanitize_sql_table_name(segment) for segment in segments)


Expand Down
9 changes: 3 additions & 6 deletions pandasai/query_builders/base_query_builder.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
import re
from typing import Any, List
from typing import List

import sqlglot
from sqlglot import from_, pretty, select
from sqlglot.expressions import Limit, cast
from sqlglot import select
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers

from pandasai.data_loader.semantic_layer_schema import SemanticLayerSchema, Source
Expand Down Expand Up @@ -42,7 +39,7 @@ def _get_columns(self) -> list[str]:
return ["*"]

def _get_table_expression(self) -> str:
return normalize_identifiers(self.schema.name).sql()
return normalize_identifiers(self.schema.name).sql(pretty=True)

@staticmethod
def check_compatible_sources(sources: List[Source]) -> bool:
Expand Down
9 changes: 5 additions & 4 deletions pandasai/query_builders/sql_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,16 @@ def transform_node(node):
if isinstance(node, exp.Table):
original_name = node.name
if original_name in table_mapping:
alias = node.alias or original_name
mapped_value = parsed_mapping[original_name]
if isinstance(mapped_value, exp.Alias):
return exp.Subquery(
this=mapped_value.this.this,
alias=node.alias or original_name,
alias=alias,
)
return exp.Subquery(
this=mapped_value, alias=node.alias or original_name
)
elif isinstance(mapped_value, exp.Column):
return exp.Table(this=mapped_value.this, alias=alias)
return exp.Subquery(this=mapped_value, alias=alias)

return node

Expand Down
25 changes: 20 additions & 5 deletions pandasai/query_builders/view_query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from ..data_loader.loader import DatasetLoader
from ..data_loader.semantic_layer_schema import SemanticLayerSchema
from ..helpers.sql_sanitizer import sanitize_relation_name
from ..helpers.sql_sanitizer import sanitize_view_column_name
from .base_query_builder import BaseQueryBuilder


Expand All @@ -19,10 +19,20 @@ def __init__(
super().__init__(schema)
self.schema_dependencies_dict = schema_dependencies_dict

@staticmethod
def normalize_view_column_name(name: str) -> str:
    """Return the SQL text of a sanitized, identifier-normalized column name.

    The name is first cleaned with ``sanitize_view_column_name``, then parsed
    with ``parse_one`` into an expression, passed through
    ``normalize_identifiers``, and finally rendered with ``.sql()``.

    Args:
        name: View column name (presumably in ``dataset.column`` form;
            verify against callers).

    Returns:
        The rendered SQL string for the normalized expression.
    """
    sanitized = sanitize_view_column_name(name)
    expression = parse_one(sanitized)
    normalized = normalize_identifiers(expression)
    return normalized.sql()

@staticmethod
def normalize_view_column_alias(name: str) -> str:
    """Return the alias used for a view column in generated SQL.

    The name is cleaned with ``sanitize_view_column_name``, every ``.`` in the
    result is replaced by ``_``, and the flattened string is passed through
    ``normalize_identifiers`` before being rendered with ``.sql()``.

    Args:
        name: View column name (presumably in ``dataset.column`` form;
            verify against callers).

    Returns:
        The rendered SQL string for the normalized alias.
    """
    sanitized = sanitize_view_column_name(name)
    # Dots are flattened to underscores so the alias is a single identifier.
    flattened = sanitized.replace(".", "_")
    return normalize_identifiers(flattened).sql()

def _get_columns(self) -> list[str]:
if self.schema.columns:
return [
normalize_identifiers(col.name.replace(".", "_")).sql()
self.normalize_view_column_alias(col.name)
for col in self.schema.columns
]
else:
Expand All @@ -34,13 +44,18 @@ def _get_sub_query_from_loader(self, loader: DatasetLoader) -> Subquery:

def _get_table_expression(self) -> str:
relations = self.schema.relations
first_dataset = relations[0].from_.split(".")[0]
columns = self.schema.columns
first_dataset = (
relations[0].from_.split(".")[0]
if relations
else columns[0].name.split(".")[0]
)
first_loader = self.schema_dependencies_dict[first_dataset]
first_query = self._get_sub_query_from_loader(first_loader)

if self.schema.columns:
columns = [
f"{normalize_identifiers(col.name).sql()} AS {normalize_identifiers(col.name.replace('.', '_'))}"
f"{self.normalize_view_column_name(col.name)} AS {self.normalize_view_column_alias(col.name)}"
for col in self.schema.columns
]
else:
Expand All @@ -54,7 +69,7 @@ def _get_table_expression(self) -> str:
subquery = self._get_sub_query_from_loader(loader)
query = query.join(
subquery,
on=f"{sanitize_relation_name(relation.from_)} = {sanitize_relation_name(relation.to)}",
on=f"{sanitize_view_column_name(relation.from_)} = {sanitize_view_column_name(relation.to)}",
append=True,
)
alias = normalize_identifiers(self.schema.name).sql()
Expand Down
4 changes: 2 additions & 2 deletions tests/unit_tests/helpers/test_sql_sanitizer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from pandasai.helpers.sql_sanitizer import (
is_sql_query_safe,
sanitize_file_name,
sanitize_relation_name,
sanitize_view_column_name,
)


Expand All @@ -25,7 +25,7 @@ def test_sanitize_file_name_long_name(self):
def test_sanitize_relation_name_valid(self):
relation = "dataset-name.column"
expected = "dataset_name.column"
assert sanitize_relation_name(relation) == expected
assert sanitize_view_column_name(relation) == expected

def test_safe_select_query(self):
query = "SELECT * FROM users WHERE username = 'admin';"
Expand Down
16 changes: 16 additions & 0 deletions tests/unit_tests/query_builders/test_query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,3 +250,19 @@ def test_table_name_time_based_injection(self, mysql_schema):
created_at DESC
LIMIT 100"""
)

@pytest.mark.parametrize(
    "injection",
    [
        "users; DROP TABLE users;",
        "users UNION SELECT 1,2,3;",
        'users"; SELECT * FROM sensitive_data; --',
        "users; TRUNCATE users; SELECT * FROM users WHERE 't'='t",
        "users' AND (SELECT * FROM (SELECT(SLEEP(5)))test); --",
    ],
)
def test_order_by_injection(self, injection, mysql_schema):
    """Building a query must fail when ``order_by`` holds an injection payload.

    Each payload is installed as the only ``order_by`` entry of the schema,
    and ``build_query`` is expected to raise a sqlglot ``ParseError`` or
    ``TokenError`` rather than return SQL.
    """
    mysql_schema.order_by = [injection]
    builder = BaseQueryBuilder(mysql_schema)
    rejected_with = (sqlglot.errors.ParseError, sqlglot.errors.TokenError)
    with pytest.raises(rejected_with):
        builder.build_query()
58 changes: 58 additions & 0 deletions tests/unit_tests/query_builders/test_sql_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import pytest

from pandasai.query_builders.sql_parser import SQLParser


class TestSqlParser:
    """Tests for ``SQLParser.replace_table_and_column_names``."""

    @staticmethod
    @pytest.mark.parametrize(
        "query, table_mapping, expected",
        [
            # Plain table mapped to another table name: the original name is
            # kept as the alias.
            (
                "SELECT * FROM customers",
                {"customers": "clients"},
                """SELECT
*
FROM "clients" AS customers""",
            ),
            # Table mapped to a parenthesized sub-select: the sub-select is
            # inlined and aliased with the original table name.
            (
                "SELECT * FROM orders",
                {"orders": "(SELECT * FROM sales)"},
                """SELECT
*
FROM (
(
SELECT
*
FROM "sales"
)
) AS orders""",
            ),
            # An alias already present in the query is preserved.
            (
                "SELECT * FROM customers c",
                {"customers": "clients"},
                """SELECT
*
FROM "clients" AS c""",
            ),
            # Join mixing a table-name mapping and a sub-select mapping.
            (
                "SELECT c.id, o.amount FROM customers c JOIN orders o ON c.id = o.customer_id",
                {"customers": "clients", "orders": "(SELECT * FROM sales)"},
                '''SELECT
"c"."id",
"o"."amount"
FROM "clients" AS c
JOIN (
(
SELECT
*
FROM "sales"
)
) AS o
ON "c"."id" = "o"."customer_id"''',
            ),
        ],
    )
    def test_replace_table_names(query, table_mapping, expected):
        """Check the rewritten SQL for each query/mapping pair.

        The result and the expected text are compared after stripping
        leading and trailing whitespace.
        """
        result = SQLParser.replace_table_and_column_names(query, table_mapping)
        assert result.strip() == expected.strip()
Loading
Loading