Merged
26 changes: 17 additions & 9 deletions dask_sql/context.py
@@ -16,6 +16,7 @@
from dask_sql.mappings import python_to_sql_type
from dask_sql.physical.rel import RelConverter, logical
from dask_sql.physical.rex import RexConverter, core
from dask_sql.datacontainer import DataContainer, ColumnContainer
from dask_sql.utils import ParsingException

FunctionDescription = namedtuple(
@@ -75,13 +76,15 @@ def __init__(self):

def register_dask_table(self, df: dd.DataFrame, name: str):
"""
Registering a dask table makes it usable in SQl queries.
Registering a dask table makes it usable in SQL queries.
The name you give here can be used as table name in the SQL later.

Please note, that the table is stored as it is now.
If you change the table later, you need to re-register.
"""
self.tables[name.lower()] = df.copy()
self.tables[name.lower()] = DataContainer(
df.copy(), ColumnContainer(df.columns)
)

def register_function(
self,
@@ -167,7 +170,7 @@ def sql(self, sql: str, debug: bool = False) -> dd.DataFrame:
"""
try:
rel, select_names = self._get_ral(sql, debug=debug)
df = RelConverter.convert(rel, context=self)
dc = RelConverter.convert(rel, context=self)
except (ValidationException, SqlParseException) as e:
if debug:
from_chained_exception = e
@@ -182,12 +185,16 @@ def sql(self, sql: str, debug: bool = False) -> dd.DataFrame:

if select_names:
# Rename any columns named EXPR$* to a more human readable name
df.columns = [
df_col if not df_col.startswith("EXPR$") else select_name
for df_col, select_name in zip(df.columns, select_names)
]
cc = dc.column_container
cc = cc.rename(
{
df_col: df_col if not df_col.startswith("EXPR$") else select_name
for df_col, select_name in zip(cc.columns, select_names)
}
)
dc = DataContainer(dc.df, cc)

return df
return dc.assign()

def _prepare_schema(self):
"""
@@ -196,8 +203,9 @@ def _prepare_schema(self):
"""
schema = DaskSchema("schema")

for name, df in self.tables.items():
for name, dc in self.tables.items():
table = DaskTable(name)
df = dc.df
for column in df.columns:
data_type = df[column].dtype
sql_data_type = python_to_sql_type(data_type)
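To see how the `context.py` changes fit together end to end, here is a minimal usage sketch (not part of the PR). It assumes the top-level `Context` import exposed by the package and uses made-up table and column names; the point is that registration now stores a `DataContainer`, and `sql()` renames Calcite's `EXPR$*` placeholder columns through the `ColumnContainer` before `dc.assign()` builds the returned dataframe.

```python
import dask.dataframe as dd
import pandas as pd

from dask_sql import Context  # assuming the top-level re-export of Context

df = dd.from_pandas(
    pd.DataFrame({"a": [1, 2, 3], "b": [10.0, 20.0, 30.0]}), npartitions=1
)

c = Context()
# Internally this now stores DataContainer(df.copy(), ColumnContainer(df.columns))
# instead of a bare dask dataframe.
c.register_dask_table(df, "my_table")

# Any EXPR$* placeholder column produced by Calcite is renamed to the SQL select
# name via ColumnContainer.rename before dc.assign() materializes the result.
result = c.sql('SELECT a + b AS "total" FROM my_table')
print(result.compute())
```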
161 changes: 161 additions & 0 deletions dask_sql/datacontainer.py
@@ -0,0 +1,161 @@
from typing import List, Dict, Tuple, Union

import dask.dataframe as dd


ColumnType = Union[str, int]


class ColumnContainer:
# Forward declaration
pass


class ColumnContainer:
"""
Helper class to store a list of columns,
which do not necessarily have to be the ones of the dask dataframe.
Instead, the container also stores a mapping from "frontend"
columns (columns with the names and order expected by SQL)
to "backend" columns (the real column names used by dask)
to prevent unnecessary renames.
"""

def __init__(
self,
frontend_columns: List[str],
frontend_backend_mapping: Union[Dict[str, ColumnType], None] = None,
):
assert all(
isinstance(col, str) for col in frontend_columns
), "All frontend columns need to be of string type"
self._frontend_columns = list(frontend_columns)
if frontend_backend_mapping is None:
self._frontend_backend_mapping = {
col: col for col in self._frontend_columns
}
else:
self._frontend_backend_mapping = frontend_backend_mapping

def _copy(self) -> ColumnContainer:
"""
Internal function to copy this container
"""
return ColumnContainer(self._frontend_columns, self._frontend_backend_mapping)

def limit_to(self, fields: List[str]) -> ColumnContainer:
"""
Create a new ColumnContainer, which has frontend columns
limited to only the ones given as parameter.
Also uses the order of these as the new column order.
"""
assert all(f in self._frontend_backend_mapping for f in fields)
cc = self._copy()
cc._frontend_columns = [str(x) for x in fields]
return cc

def rename(self, columns: Dict[str, str]) -> ColumnContainer:
"""
Return a new ColumnContainer where the frontend columns
are renamed according to the given mapping.
Columns not present in the mapping are not touched,
the order is preserved.
"""
cc = self._copy()
for column_from, column_to in columns.items():
backend_column = self._frontend_backend_mapping[str(column_from)]
cc._frontend_backend_mapping[str(column_to)] = backend_column

cc._frontend_columns = [
str(columns[col]) if col in columns else col
for col in self._frontend_columns
]

return cc

def mapping(self) -> List[Tuple[str, ColumnType]]:
"""
The mapping from frontend columns to backend columns.
"""
return list(self._frontend_backend_mapping.items())

@property
def columns(self) -> List[str]:
"""
The stored frontend columns in the correct order
"""
return self._frontend_columns

def add(
self, frontend_column: str, backend_column: Union[str, None] = None
) -> ColumnContainer:
"""
Return a new ColumnContainer with the
given column added.
The column is added at the last position in the column list.
"""
cc = self._copy()

frontend_column = str(frontend_column)

cc._frontend_backend_mapping[frontend_column] = str(
backend_column or frontend_column
)
cc._frontend_columns.append(frontend_column)

return cc

def get_backend_by_frontend_index(self, index: int) -> str:
"""
Get back the dask column, which is referenced by the
frontend (SQL) column with the given index.
"""
frontend_column = self._frontend_columns[index]
backend_column = self._frontend_backend_mapping[frontend_column]
return backend_column

def make_unique(self, prefix="col"):
"""
Make sure we have unique column names by calling each column

<prefix>_<number>

where <number> is the column index.
"""
return self.rename(
columns={str(col): f"{prefix}_{i}" for i, col in enumerate(self.columns)}
)


class DataContainer:
"""
In SQL, every column operation or reference is done via
the column index. Some dask operations, such as grouping,
joining or concatenating, can return the columns in a different
order than SQL would expect.
However, we do not want to change the column data itself
all the time (because this would lead to computational overhead),
but still would like to keep the columns accessible by name and index.
For this, we add an additional `ColumnContainer` to each dataframe,
which does all the column mapping between "frontend"
(what SQL expects, also in the correct order)
and "backend" (what dask has).
"""

def __init__(self, df: dd.DataFrame, column_container: ColumnContainer):
self.df = df
self.column_container = column_container

def assign(self) -> dd.DataFrame:
"""
Combine the column mapping with the actual data and return
a dataframe which has the columns specified in the
stored ColumnContainer.
"""
df = self.df.assign(
**{
col_from: self.df[col_to]
for col_from, col_to in self.column_container.mapping()
}
)
return df[self.column_container.columns]
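Since `datacontainer.py` is the heart of this PR, a small standalone sketch (invented column names, using only the API added in this file) may help: renames and reorders only touch the frontend mapping, and the dask graph is only extended once `assign()` is called.

```python
import dask.dataframe as dd
import pandas as pd

from dask_sql.datacontainer import ColumnContainer, DataContainer

df = dd.from_pandas(pd.DataFrame({"a": [1, 2], "b": [3, 4]}), npartitions=1)

# Frontend columns start out identical to the backend (dask) columns.
cc = ColumnContainer(df.columns)

# Rename "a" -> "x" on the frontend only; the dask dataframe is untouched.
cc = cc.rename({"a": "x"})
assert cc.columns == ["x", "b"]

# Reorder/limit the frontend view; still no dask operation has happened.
cc = cc.limit_to(["b", "x"])

# Only when handing the result back do we materialize the mapping.
dc = DataContainer(df, cc)
print(dc.assign().compute())  # columns b, x -- x holds the data of backend column "a"
```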
25 changes: 7 additions & 18 deletions dask_sql/physical/rel/base.py
@@ -2,6 +2,8 @@

import dask.dataframe as dd

from dask_sql.datacontainer import ColumnContainer


class BaseRelPlugin:
"""
@@ -22,20 +24,20 @@ def convert(

@staticmethod
def fix_column_to_row_type(
df: dd.DataFrame, row_type: "org.apache.calcite.rel.type.RelDataType"
) -> dd.DataFrame:
cc: ColumnContainer, row_type: "org.apache.calcite.rel.type.RelDataType"
) -> ColumnContainer:
"""
Make sure that the given dask dataframe
Make sure that the given column container
has the column names specified by the row type.
We assume that the column order is already correct
and will just "blindly" rename the columns.
"""
field_names = [str(x) for x in row_type.getFieldNames()]

df = df.rename(columns=dict(zip(df.columns, field_names)))
cc = cc.rename(columns=dict(zip(cc.columns, field_names)))

# TODO: We can also check for the types here and do any conversions if needed
return df[field_names]
return cc.limit_to(field_names)

@staticmethod
def check_columns_from_row_type(
@@ -73,16 +75,3 @@ def assert_inputs(
from dask_sql.physical.rel.convert import RelConverter

return [RelConverter.convert(input_rel, context) for input_rel in input_rels]

@staticmethod
def make_unique(df, prefix="col"):
"""
Make sure we have unique column names by calling each column

prefix_number

where number is the column index.
"""
return df.rename(
columns={col: f"{prefix}_{i}" for i, col in enumerate(df.columns)}
)
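The reworked `fix_column_to_row_type` no longer renames the dask dataframe itself; it only rewrites the frontend view. A sketch of the same rename-then-limit pattern, with hypothetical field names standing in for `row_type.getFieldNames()`:

```python
from dask_sql.datacontainer import ColumnContainer

# Hypothetical stand-in for row_type.getFieldNames() from Calcite.
field_names = ["id", "name"]

cc = ColumnContainer(["col_0", "col_1"])
cc = cc.rename(columns=dict(zip(cc.columns, field_names)))
cc = cc.limit_to(field_names)

assert cc.columns == ["id", "name"]
# The data still lives in the original backend columns; only the mapping changed.
assert dict(cc.mapping())["id"] == "col_0"
```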
43 changes: 29 additions & 14 deletions dask_sql/physical/rel/logical/aggregate.py
@@ -2,10 +2,12 @@
from collections import defaultdict
from functools import reduce
from typing import Callable, Dict, List, Tuple, Union
import uuid

import dask.dataframe as dd

from dask_sql.physical.rel.base import BaseRelPlugin
from dask_sql.datacontainer import DataContainer, ColumnContainer


class GroupDatasetDescription:
@@ -87,26 +89,38 @@ class LogicalAggregatePlugin(BaseRelPlugin):

def convert(
self, rel: "org.apache.calcite.rel.RelNode", context: "dask_sql.Context"
) -> dd.DataFrame:
(df,) = self.assert_inputs(rel, 1, context)
) -> DataContainer:
(dc,) = self.assert_inputs(rel, 1, context)

df = dc.df
cc = dc.column_container

# We make our life easier with having unique column names
df = self.make_unique(df)
cc = cc.make_unique()

# I have no idea what that is, but so far it was always of length 1
assert len(rel.getGroupSets()) == 1, "Do not know how to handle this case!"

# Extract the information, which columns we need to group for
group_column_indices = [int(i) for i in rel.getGroupSet()]
group_columns = [df.columns[i] for i in group_column_indices]
group_columns = [
cc.get_backend_by_frontend_index(i) for i in group_column_indices
]

# Always keep an additional column around for empty groups and aggregates
additional_column_name = str(len(df.columns))
additional_column_name = str(uuid.uuid4())

# NOTE: it might be the case that
# we do not need this additional
# column, but hopefully adding a single
# column of 1 is not so problematic...
df = df.assign(**{additional_column_name: 1})
cc = cc.add(additional_column_name)
dc = DataContainer(df, cc)

# Collect all aggregates
filtered_aggregations, output_column_order = self._collect_aggregations(
rel, df, group_columns, additional_column_name, context
rel, dc, group_columns, additional_column_name, context
)

if not group_columns:
@@ -143,16 +157,15 @@ def convert(

# Fix the column names and the order of them, as this was messed with during the aggregations
df_agg.columns = df_agg.columns.get_level_values(-1)
df_agg = df_agg[output_column_order]

df_agg = self.fix_column_to_row_type(df_agg, rel.getRowType())
cc = ColumnContainer(df_agg.columns).limit_to(output_column_order)

return df_agg
cc = self.fix_column_to_row_type(cc, rel.getRowType())
return DataContainer(df_agg, cc)

def _collect_aggregations(
self,
rel: "org.apache.calcite.rel.RelNode",
df: dd.DataFrame,
dc: DataContainer,
group_columns: List[str],
additional_column_name: str,
context: "dask_sql.Context",
@@ -165,6 +178,8 @@ def _collect_aggregations(
"""
aggregations = defaultdict(lambda: defaultdict(dict))
output_column_order = []
df = dc.df
cc = dc.column_container

# SQL needs to copy the old content also. As the values of the group columns
# are the same for a single group anyways, we just use the first row
@@ -178,8 +193,8 @@
expr = agg_call.getKey()

if expr.hasFilter():
filter_column = expr.filterArg
filter_expression = df.iloc[:, filter_column]
filter_column = cc.get_backend_by_frontend_index(expr.filterArg)
filter_expression = df[filter_column]
filtered_df = df[filter_expression]

grouped_df = GroupDatasetDescription(filtered_df, filter_column)
@@ -205,7 +220,7 @@

inputs = expr.getArgList()
if len(inputs) == 1:
input_col = df.columns[inputs[0]]
input_col = cc.get_backend_by_frontend_index(inputs[0])
elif len(inputs) == 0:
input_col = additional_column_name
else:
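Stripped of the Calcite plumbing, the new aggregate flow boils down to: resolve each SQL column index to a backend column through the `ColumnContainer`, and add a uuid-named helper column so that aggregations without a GROUP BY (or with filters) still have a stable column to work on. A rough sketch with made-up data and column names, not the plugin code itself:

```python
import uuid

import dask.dataframe as dd
import pandas as pd

from dask_sql.datacontainer import ColumnContainer, DataContainer

df = dd.from_pandas(
    pd.DataFrame({"grp": ["a", "a", "b"], "val": [1, 2, 3]}), npartitions=1
)
cc = ColumnContainer(df.columns).make_unique()  # frontend names: col_0, col_1

# SQL refers to columns by index; index 0 (the GROUP BY key) maps back to "grp".
group_column = cc.get_backend_by_frontend_index(0)

# A uuid-based name cannot collide with user columns, unlike the old
# str(len(df.columns)) naming scheme.
additional_column_name = str(uuid.uuid4())
df = df.assign(**{additional_column_name: 1})
cc = cc.add(additional_column_name)
dc = DataContainer(df, cc)

# Aggregate on backend column names; the helper column gives a COUNT something to count.
result = dc.df.groupby(group_column).agg(
    {"val": "sum", additional_column_name: "count"}
)
print(result.compute())
```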