Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions dask_sql/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def __init__(self):
RelConverter.add_plugin_class(logical.LogicalTableScanPlugin, replace=False)
RelConverter.add_plugin_class(logical.LogicalUnionPlugin, replace=False)
RelConverter.add_plugin_class(logical.LogicalValuesPlugin, replace=False)
RelConverter.add_plugin_class(logical.LogicalWindowPlugin, replace=False)
RelConverter.add_plugin_class(logical.SamplePlugin, replace=False)
RelConverter.add_plugin_class(custom.AnalyzeTablePlugin, replace=False)
RelConverter.add_plugin_class(custom.CreateExperimentPlugin, replace=False)
Expand All @@ -108,7 +109,6 @@ def __init__(self):
RexConverter.add_plugin_class(core.RexCallPlugin, replace=False)
RexConverter.add_plugin_class(core.RexInputRefPlugin, replace=False)
RexConverter.add_plugin_class(core.RexLiteralPlugin, replace=False)
RexConverter.add_plugin_class(core.RexOverPlugin, replace=False)

InputUtil.add_plugin_class(input_utils.DaskInputPlugin, replace=False)
InputUtil.add_plugin_class(input_utils.PandasInputPlugin, replace=False)
Expand Down Expand Up @@ -427,7 +427,7 @@ def sql(
cc = dc.column_container
cc = cc.rename(
{
df_col: df_col if not df_col.startswith("EXPR$") else select_name
df_col: select_name
for df_col, select_name in zip(cc.columns, select_names)
}
)
Expand Down Expand Up @@ -711,12 +711,18 @@ def _get_ral(self, sql):
sqlNode = generator.getSqlNode(sql)
sqlNodeClass = get_java_class(sqlNode)

if sqlNodeClass.startswith("com.dask.sql.parser."):
rel = sqlNode
rel_string = ""
else:
select_names = None
rel = sqlNode
rel_string = ""

if not sqlNodeClass.startswith("com.dask.sql.parser."):
validatedSqlNode = generator.getValidatedNode(sqlNode)
nonOptimizedRelNode = generator.getRelationalAlgebra(validatedSqlNode)
# Optimization might remove some alias projects. Make sure to keep them here.
select_names = [
str(name)
for name in nonOptimizedRelNode.getRowType().getFieldNames()
]
rel = generator.getOptimizedRelationalAlgebra(nonOptimizedRelNode)
rel_string = str(generator.getRelationalAlgebraString(rel))
except (ValidationException, SqlParseException) as e:
Expand All @@ -741,13 +747,14 @@ def _get_ral(self, sql):
if sqlNodeClass == "org.apache.calcite.sql.SqlSelect":
select_names = [
self._to_sql_string(s, default_dialect=default_dialect)
for s in sqlNode.getSelectList()
if current_name.startswith("EXPR$")
else current_name
for s, current_name in zip(sqlNode.getSelectList(), select_names)
]
else:
logger.debug(
"Not extracting output column names as the SQL is not a SELECT call"
)
select_names = None

logger.debug(f"Extracted relational algebra:\n {rel_string}")
return rel, select_names, rel_string
Expand Down
2 changes: 2 additions & 0 deletions dask_sql/physical/rel/logical/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .table_scan import LogicalTableScanPlugin
from .union import LogicalUnionPlugin
from .values import LogicalValuesPlugin
from .window import LogicalWindowPlugin

__all__ = [
LogicalAggregatePlugin,
Expand All @@ -17,5 +18,6 @@
LogicalTableScanPlugin,
LogicalUnionPlugin,
LogicalValuesPlugin,
LogicalWindowPlugin,
SamplePlugin,
]
Loading