diff --git a/dask_sql/context.py b/dask_sql/context.py
index 9f159a5a6..3c600653a 100644
--- a/dask_sql/context.py
+++ b/dask_sql/context.py
@@ -16,6 +16,7 @@
 from dask_sql.mappings import python_to_sql_type
 from dask_sql.physical.rel import RelConverter, logical
 from dask_sql.physical.rex import RexConverter, core
+from dask_sql.datacontainer import DataContainer, ColumnContainer
 from dask_sql.utils import ParsingException
 
 FunctionDescription = namedtuple(
@@ -75,13 +76,15 @@ def __init__(self):
 
     def register_dask_table(self, df: dd.DataFrame, name: str):
         """
-        Registering a dask table makes it usable in SQl queries.
+        Registering a dask table makes it usable in SQL queries.
         The name you give here can be used as table name in the SQL later.
 
         Please note, that the table is stored as it is now.
         If you change the table later, you need to re-register.
         """
-        self.tables[name.lower()] = df.copy()
+        self.tables[name.lower()] = DataContainer(
+            df.copy(), ColumnContainer(df.columns)
+        )
 
     def register_function(
         self,
@@ -167,7 +170,7 @@ def sql(self, sql: str, debug: bool = False) -> dd.DataFrame:
         """
         try:
             rel, select_names = self._get_ral(sql, debug=debug)
-            df = RelConverter.convert(rel, context=self)
+            dc = RelConverter.convert(rel, context=self)
         except (ValidationException, SqlParseException) as e:
             if debug:
                 from_chained_exception = e
@@ -182,12 +185,16 @@ def sql(self, sql: str, debug: bool = False) -> dd.DataFrame:
 
         if select_names:
             # Rename any columns named EXPR$* to a more human readable name
-            df.columns = [
-                df_col if not df_col.startswith("EXPR$") else select_name
-                for df_col, select_name in zip(df.columns, select_names)
-            ]
+            cc = dc.column_container
+            cc = cc.rename(
+                {
+                    df_col: df_col if not df_col.startswith("EXPR$") else select_name
+                    for df_col, select_name in zip(cc.columns, select_names)
+                }
+            )
+            dc = DataContainer(dc.df, cc)
 
-        return df
+        return dc.assign()
 
     def _prepare_schema(self):
         """
@@ -196,8 +203,9 @@
         """
         schema = DaskSchema("schema")
 
-        for name, df in self.tables.items():
+        for name, dc in self.tables.items():
             table = DaskTable(name)
+            df = dc.df
             for column in df.columns:
                 data_type = df[column].dtype
                 sql_data_type = python_to_sql_type(data_type)
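Illustrative sketch (not part of the patch; table name and data are made up): the public API of Context is unchanged by this change, only the internal table storage becomes a DataContainer and sql() materializes the result through dc.assign().

    import pandas as pd
    import dask.dataframe as dd
    from dask_sql import Context

    # register a small table; internally it is now stored as a DataContainer
    df = dd.from_pandas(pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}), npartitions=1)
    c = Context()
    c.register_dask_table(df, "my_table")

    dc = c.tables["my_table"]
    print(dc.column_container.columns)  # ['a', 'b'] - the "frontend" view

    # sql() still returns a plain dask dataframe, built by dc.assign()
    result = c.sql("SELECT a FROM my_table")
    print(result.compute())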
diff --git a/dask_sql/datacontainer.py b/dask_sql/datacontainer.py
new file mode 100644
index 000000000..c5182566c
--- /dev/null
+++ b/dask_sql/datacontainer.py
@@ -0,0 +1,161 @@
+from typing import List, Dict, Tuple, Union
+
+import dask.dataframe as dd
+
+
+ColumnType = Union[str, int]
+
+
+class ColumnContainer:
+    # Forward declaration
+    pass
+
+
+class ColumnContainer:
+    """
+    Helper class to store a list of columns,
+    which are not necessarily the ones of the dask dataframe.
+    Instead, the container also stores a mapping from "frontend"
+    columns (columns with the names and order expected by SQL)
+    to "backend" columns (the real column names used by dask)
+    to prevent unnecessary renames.
+    """
+
+    def __init__(
+        self,
+        frontend_columns: List[str],
+        frontend_backend_mapping: Union[Dict[str, ColumnType], None] = None,
+    ):
+        assert all(
+            isinstance(col, str) for col in frontend_columns
+        ), "All frontend columns need to be of string type"
+        self._frontend_columns = list(frontend_columns)
+        if frontend_backend_mapping is None:
+            self._frontend_backend_mapping = {
+                col: col for col in self._frontend_columns
+            }
+        else:
+            self._frontend_backend_mapping = frontend_backend_mapping
+
+    def _copy(self) -> ColumnContainer:
+        """
+        Internal function to copy this container
+        """
+        return ColumnContainer(self._frontend_columns, self._frontend_backend_mapping)
+
+    def limit_to(self, fields: List[str]) -> ColumnContainer:
+        """
+        Create a new ColumnContainer, which has frontend columns
+        limited to only the ones given as parameter.
+        Also uses the order of these as the new column order.
+        """
+        assert all(f in self._frontend_backend_mapping for f in fields)
+        cc = self._copy()
+        cc._frontend_columns = [str(x) for x in fields]
+        return cc
+
+    def rename(self, columns: Dict[str, str]) -> ColumnContainer:
+        """
+        Return a new ColumnContainer where the frontend columns
+        are renamed according to the given mapping.
+        Columns not present in the mapping are not touched,
+        the order is preserved.
+        """
+        cc = self._copy()
+        for column_from, column_to in columns.items():
+            backend_column = self._frontend_backend_mapping[str(column_from)]
+            cc._frontend_backend_mapping[str(column_to)] = backend_column
+
+        cc._frontend_columns = [
+            str(columns[col]) if col in columns else col
+            for col in self._frontend_columns
+        ]
+
+        return cc
+
+    def mapping(self) -> List[Tuple[str, ColumnType]]:
+        """
+        The mapping from frontend columns to backend columns.
+        """
+        return list(self._frontend_backend_mapping.items())
+
+    @property
+    def columns(self) -> List[str]:
+        """
+        The stored frontend columns in the correct order
+        """
+        return self._frontend_columns
+
+    def add(
+        self, frontend_column: str, backend_column: Union[str, None] = None
+    ) -> ColumnContainer:
+        """
+        Return a new ColumnContainer with the
+        given column added.
+        The column is added at the last position in the column list.
+        """
+        cc = self._copy()
+
+        frontend_column = str(frontend_column)
+
+        cc._frontend_backend_mapping[frontend_column] = str(
+            backend_column or frontend_column
+        )
+        cc._frontend_columns.append(frontend_column)
+
+        return cc
+
+    def get_backend_by_frontend_index(self, index: int) -> str:
+        """
+        Get back the dask column, which is referenced by the
+        frontend (SQL) column with the given index.
+        """
+        frontend_column = self._frontend_columns[index]
+        backend_column = self._frontend_backend_mapping[frontend_column]
+        return backend_column
+
+    def make_unique(self, prefix="col"):
+        """
+        Make sure we have unique column names by renaming each column to
+
+        prefix_number
+
+        where number is the column index.
+        """
+        return self.rename(
+            columns={str(col): f"{prefix}_{i}" for i, col in enumerate(self.columns)}
+        )
+
+
+class DataContainer:
+    """
+    In SQL, every column operation or reference is done via
+    the column index. Some dask operations, such as grouping,
+    joining or concatenating return the columns in a different
+    order than SQL would expect.
+    However, we do not want to change the column data itself
+    all the time (because this would lead to computational overhead),
+    but still would like to keep the columns accessible by name and index.
+    For this, we add an additional `ColumnContainer` to each dataframe,
+    which does all the column mapping between "frontend"
+    (what SQL expects, also in the correct order)
+    and "backend" (what dask has).
+    """
+
+    def __init__(self, df: dd.DataFrame, column_container: ColumnContainer):
+        self.df = df
+        self.column_container = column_container
+
+    def assign(self) -> dd.DataFrame:
+        """
+        Combine the column mapping with the actual data and return
+        a dataframe which has the columns specified in the
+        stored ColumnContainer.
+        """
+        df = self.df.assign(
+            **{
+                col_from: self.df[col_to]
+                for col_from, col_to in self.column_container.mapping()
+            }
+        )
+        return df[self.column_container.columns]
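A short sketch of how the two new classes interact (illustrative only, column names made up): renames and re-orderings merely update the ColumnContainer mapping, and only assign() touches the actual dask dataframe.

    import pandas as pd
    import dask.dataframe as dd
    from dask_sql.datacontainer import ColumnContainer, DataContainer

    df = dd.from_pandas(pd.DataFrame({"a": [1, 2], "b": [3, 4]}), npartitions=1)
    cc = ColumnContainer(df.columns)

    # renaming is only a mapping update - the dask dataframe keeps its columns
    cc = cc.rename({"a": "x"})
    print(cc.columns)                           # ['x', 'b']
    print(cc.get_backend_by_frontend_index(0))  # 'a' - the real dask column

    # re-ordering / selecting is also purely a frontend operation
    cc = cc.limit_to(["b", "x"])

    # only assign() combines mapping and data into a real dataframe
    dc = DataContainer(df, cc)
    print(dc.assign().compute())                # columns 'b' and 'x'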
diff --git a/dask_sql/physical/rel/base.py b/dask_sql/physical/rel/base.py
index f763db0ae..7c4beaf6b 100644
--- a/dask_sql/physical/rel/base.py
+++ b/dask_sql/physical/rel/base.py
@@ -2,6 +2,8 @@
 
 import dask.dataframe as dd
 
+from dask_sql.datacontainer import ColumnContainer
+
 
 class BaseRelPlugin:
     """
@@ -22,20 +24,20 @@ def convert(
 
     @staticmethod
     def fix_column_to_row_type(
-        df: dd.DataFrame, row_type: "org.apache.calcite.rel.type.RelDataType"
-    ) -> dd.DataFrame:
+        cc: ColumnContainer, row_type: "org.apache.calcite.rel.type.RelDataType"
+    ) -> ColumnContainer:
         """
-        Make sure that the given dask dataframe
+        Make sure that the given column container
        has the column names specified by the row type.
         We assume that the column order is already correct
         and will just "blindly" rename the columns.
         """
         field_names = [str(x) for x in row_type.getFieldNames()]
 
-        df = df.rename(columns=dict(zip(df.columns, field_names)))
+        cc = cc.rename(columns=dict(zip(cc.columns, field_names)))
 
         # TODO: We can also check for the types here and do any conversions if needed
-        return df[field_names]
+        return cc.limit_to(field_names)
 
     @staticmethod
     def check_columns_from_row_type(
@@ -73,16 +75,3 @@ def assert_inputs(
         from dask_sql.physical.rel.convert import RelConverter
 
         return [RelConverter.convert(input_rel, context) for input_rel in input_rels]
-
-    @staticmethod
-    def make_unique(df, prefix="col"):
-        """
-        Make sure we have unique column names by calling each column
-
-        prefix_number
-
-        where number is the column index.
-        """
-        return df.rename(
-            columns={col: f"{prefix}_{i}" for i, col in enumerate(df.columns)}
-        )
diff --git a/dask_sql/physical/rel/logical/aggregate.py b/dask_sql/physical/rel/logical/aggregate.py
index f7394f2ab..f23532735 100644
--- a/dask_sql/physical/rel/logical/aggregate.py
+++ b/dask_sql/physical/rel/logical/aggregate.py
@@ -2,10 +2,12 @@
 from collections import defaultdict
 from functools import reduce
 from typing import Callable, Dict, List, Tuple, Union
+import uuid
 
 import dask.dataframe as dd
 
 from dask_sql.physical.rel.base import BaseRelPlugin
+from dask_sql.datacontainer import DataContainer, ColumnContainer
 
 
 class GroupDatasetDescription:
@@ -87,26 +89,38 @@ class LogicalAggregatePlugin(BaseRelPlugin):
     def convert(
         self, rel: "org.apache.calcite.rel.RelNode", context: "dask_sql.Context"
-    ) -> dd.DataFrame:
-        (df,) = self.assert_inputs(rel, 1, context)
+    ) -> DataContainer:
+        (dc,) = self.assert_inputs(rel, 1, context)
+
+        df = dc.df
+        cc = dc.column_container
 
         # We make our life easier with having unique column names
-        df = self.make_unique(df)
+        cc = cc.make_unique()
 
         # I have no idea what that is, but so far it was always of length 1
         assert len(rel.getGroupSets()) == 1, "Do not know how to handle this case!"
 
         # Extract the information, which columns we need to group for
         group_column_indices = [int(i) for i in rel.getGroupSet()]
-        group_columns = [df.columns[i] for i in group_column_indices]
+        group_columns = [
+            cc.get_backend_by_frontend_index(i) for i in group_column_indices
+        ]
 
         # Always keep an additional column around for empty groups and aggregates
-        additional_column_name = str(len(df.columns))
+        additional_column_name = str(uuid.uuid4())
+
+        # NOTE: it might be the case that
+        # we do not need this additional
+        # column, but hopefully adding a single
+        # column of 1 is not so problematic...
         df = df.assign(**{additional_column_name: 1})
+        cc = cc.add(additional_column_name)
+        dc = DataContainer(df, cc)
 
         # Collect all aggregates
         filtered_aggregations, output_column_order = self._collect_aggregations(
-            rel, df, group_columns, additional_column_name, context
+            rel, dc, group_columns, additional_column_name, context
         )
 
         if not group_columns:
@@ -143,16 +157,15 @@ def convert(
         # Fix the column names and the order of them, as this was messed with during the aggregations
         df_agg.columns = df_agg.columns.get_level_values(-1)
-        df_agg = df_agg[output_column_order]
-
-        df_agg = self.fix_column_to_row_type(df_agg, rel.getRowType())
+        cc = ColumnContainer(df_agg.columns).limit_to(output_column_order)
 
-        return df_agg
+        cc = self.fix_column_to_row_type(cc, rel.getRowType())
+        return DataContainer(df_agg, cc)
 
     def _collect_aggregations(
         self,
         rel: "org.apache.calcite.rel.RelNode",
-        df: dd.DataFrame,
+        dc: DataContainer,
         group_columns: List[str],
         additional_column_name: str,
         context: "dask_sql.Context",
@@ -165,6 +178,8 @@
         """
         aggregations = defaultdict(lambda: defaultdict(dict))
         output_column_order = []
+        df = dc.df
+        cc = dc.column_container
 
         # SQL needs to copy the old content also. As the values of the group columns
         # are the same for a single group anyways, we just use the first row
@@ -178,8 +193,8 @@
             expr = agg_call.getKey()
 
             if expr.hasFilter():
-                filter_column = expr.filterArg
-                filter_expression = df.iloc[:, filter_column]
+                filter_column = cc.get_backend_by_frontend_index(expr.filterArg)
+                filter_expression = df[filter_column]
                 filtered_df = df[filter_expression]
 
                 grouped_df = GroupDatasetDescription(filtered_df, filter_column)
@@ -205,7 +220,7 @@
             inputs = expr.getArgList()
             if len(inputs) == 1:
-                input_col = df.columns[inputs[0]]
+                input_col = cc.get_backend_by_frontend_index(inputs[0])
             elif len(inputs) == 0:
                 input_col = additional_column_name
             else:
diff --git a/dask_sql/physical/rel/logical/filter.py b/dask_sql/physical/rel/logical/filter.py
index bb49153a8..1ca624393 100644
--- a/dask_sql/physical/rel/logical/filter.py
+++ b/dask_sql/physical/rel/logical/filter.py
@@ -4,6 +4,7 @@
 
 from dask_sql.physical.rex import RexConverter
 from dask_sql.physical.rel.base import BaseRelPlugin
+from dask_sql.datacontainer import DataContainer
 
 
 class LogicalFilterPlugin(BaseRelPlugin):
@@ -16,13 +17,16 @@ class LogicalFilterPlugin(BaseRelPlugin):
     def convert(
         self, rel: "org.apache.calcite.rel.RelNode", context: "dask_sql.Context"
-    ) -> dd.DataFrame:
-        (df,) = self.assert_inputs(rel, 1, context)
-        self.check_columns_from_row_type(df, rel.getExpectedInputRowType(0))
+    ) -> DataContainer:
+        (dc,) = self.assert_inputs(rel, 1, context)
+        df = dc.df
+        cc = dc.column_container
 
+        # All the logic is handled in the RexConverter;
+        # we just need to apply it here
         condition = rel.getCondition()
-        df_condition = RexConverter.convert(condition, df, context=context)
+        df_condition = RexConverter.convert(condition, dc, context=context)
         df = df[df_condition]
 
-        df = self.fix_column_to_row_type(df, rel.getRowType())
-        return df
+        cc = self.fix_column_to_row_type(cc, rel.getRowType())
+        return DataContainer(df, cc)
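The two tricks used in the aggregate plugin, in isolation (a rough sketch, column names made up): the constant helper column gives a global aggregate something to group on, and a boolean column emulates a SQL FILTER clause by pre-filtering the frame before grouping.

    import pandas as pd
    import dask.dataframe as dd

    df = dd.from_pandas(
        pd.DataFrame({"col_0": [1, 2, 3, 4], "keep": [True, False, True, True]}),
        npartitions=1,
    )

    additional_column_name = "some-unique-name"  # the plugin uses str(uuid.uuid4())
    df = df.assign(**{additional_column_name: 1})

    # roughly "SELECT SUM(col_0) FROM df": group on the constant column
    print(df.groupby(additional_column_name).agg({"col_0": "sum"}).compute())

    # roughly "SELECT SUM(col_0) FILTER (WHERE keep) FROM df": filter first
    filtered_df = df[df["keep"]]
    print(filtered_df.groupby(additional_column_name).agg({"col_0": "sum"}).compute())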
diff --git a/dask_sql/physical/rel/logical/join.py b/dask_sql/physical/rel/logical/join.py
index 4228f277c..4af06b877 100644
--- a/dask_sql/physical/rel/logical/join.py
+++ b/dask_sql/physical/rel/logical/join.py
@@ -8,6 +8,7 @@
 from dask_sql.physical.rex import RexConverter
 from dask_sql.java import get_short_java_class
 from dask_sql.physical.rel.base import BaseRelPlugin
+from dask_sql.datacontainer import DataContainer, ColumnContainer
 
 
 class LogicalJoinPlugin(BaseRelPlugin):
@@ -37,19 +38,27 @@ class LogicalJoinPlugin(BaseRelPlugin):
     def convert(
         self, rel: "org.apache.calcite.rel.RelNode", context: "dask_sql.Context"
-    ) -> dd.DataFrame:
+    ) -> DataContainer:
         # Joining is a bit more complicated, so lets do it in steps:
 
         # 1. We now have two inputs (from left and right), so we fetch them both
-        df_lhs, df_rhs = self.assert_inputs(rel, 2, context)
+        dc_lhs, dc_rhs = self.assert_inputs(rel, 2, context)
+        cc_lhs = dc_lhs.column_container
+        cc_rhs = dc_rhs.column_container
 
         # 2. dask's merge will do some smart things with columns, which have the same name
         # on lhs an rhs (which also includes reordering).
         # However, that will confuse our column numbering in SQL.
         # So we make our life easier by converting the column names into unique names
         # We will convert back in the end
-        df_lhs_renamed = self.make_unique(df_lhs, "lhs")
-        df_rhs_renamed = self.make_unique(df_rhs, "rhs")
+        cc_lhs_renamed = cc_lhs.make_unique("lhs")
+        cc_rhs_renamed = cc_rhs.make_unique("rhs")
+
+        dc_lhs_renamed = DataContainer(dc_lhs.df, cc_lhs_renamed)
+        dc_rhs_renamed = DataContainer(dc_rhs.df, cc_rhs_renamed)
+
+        df_lhs_renamed = dc_lhs_renamed.assign()
+        df_rhs_renamed = dc_rhs_renamed.assign()
 
         join_type = rel.getJoinType()
         join_type = self.JOIN_TYPE_MAPPING[str(join_type)]
@@ -70,7 +79,7 @@
             # The given column indices are for the full, merged table which consists
             # of lhs and rhs put side-by-side (in this order)
             # We therefore need to normalize the rhs indices relative to the rhs table.
-            rhs_on = [index - len(df_lhs.columns) for index in rhs_on]
+            rhs_on = [index - len(df_lhs_renamed.columns) for index in rhs_on]
 
         # 4. dask can only merge on the same column names.
         # We therefore create new columns on purpose, which have a distinct name.
@@ -109,7 +118,22 @@
 
         # 6. So the next step is to make sure
         # we have the correct column order (and to remove the temporary join columns)
-        df = df[list(df_lhs_renamed.columns) + list(df_rhs_renamed.columns)]
+        correct_column_order = list(df_lhs_renamed.columns) + list(
+            df_rhs_renamed.columns
+        )
+        cc = ColumnContainer(df.columns).limit_to(correct_column_order)
+
+        # and to rename them like the rel specifies
+        row_type = rel.getRowType()
+        field_specifications = [str(f) for f in row_type.getFieldNames()]
+        cc = cc.rename(
+            {
+                from_col: to_col
+                for from_col, to_col in zip(cc.columns, field_specifications)
+            }
+        )
+        cc = self.fix_column_to_row_type(cc, rel.getRowType())
+        dc = DataContainer(df, cc)
 
         # 7. Last but not least we apply any filters by and-chaining together the filters
         if filter_condition:
@@ -117,16 +141,14 @@
             filter_condition = reduce(
                 operator.and_,
                 [
-                    RexConverter.convert(rex, df, context=context)
+                    RexConverter.convert(rex, dc, context=context)
                     for rex in filter_condition
                 ],
             )
             df = df[filter_condition]
+            dc = DataContainer(df, cc)
 
-        # Now we go back to the names requested by the rel
-        df = self.fix_column_to_row_type(df, rel.getRowType())
-
-        return df
+        return dc
 
     def _split_join_condition(
         self, join_condition: "org.apache.calcite.rex.RexCall"
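What steps 2 and 4 boil down to, stripped of the plugin machinery (a sketch only; the data, the lhs_*/rhs_* names and the temporary join column are made up): unique prefixes keep dask's merge from touching same-named columns, and both sides get an explicitly created join column with one shared name.

    import pandas as pd
    import dask.dataframe as dd

    df_lhs = dd.from_pandas(pd.DataFrame({"id": [1, 2], "x": [10, 20]}), npartitions=1)
    df_rhs = dd.from_pandas(pd.DataFrame({"id": [1, 2], "x": [30, 40]}), npartitions=1)

    # step 2: collision-free names on both sides (make_unique produces prefix_index)
    df_lhs_renamed = df_lhs.rename(columns={"id": "lhs_0", "x": "lhs_1"})
    df_rhs_renamed = df_rhs.rename(columns={"id": "rhs_0", "x": "rhs_1"})

    # step 4: merge on purpose-built join columns carrying the same distinct name
    df = dd.merge(
        df_lhs_renamed.assign(join_col=df_lhs_renamed["lhs_0"]),
        df_rhs_renamed.assign(join_col=df_rhs_renamed["rhs_0"]),
        on="join_col",
        how="inner",
    )
    print(df.compute())  # columns lhs_0, lhs_1, join_col, rhs_0, rhs_1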
diff --git a/dask_sql/physical/rel/logical/project.py b/dask_sql/physical/rel/logical/project.py
index fd566ce29..f3cc7c36e 100644
--- a/dask_sql/physical/rel/logical/project.py
+++ b/dask_sql/physical/rel/logical/project.py
@@ -3,7 +3,10 @@
 import dask.dataframe as dd
 
 from dask_sql.physical.rex import RexConverter
+from dask_sql.physical.rex.core.input_ref import RexInputRefPlugin
 from dask_sql.physical.rel.base import BaseRelPlugin
+from dask_sql.datacontainer import DataContainer
+from dask_sql.java import get_java_class
 
 
 class LogicalProjectPlugin(BaseRelPlugin):
@@ -17,22 +20,38 @@ class LogicalProjectPlugin(BaseRelPlugin):
     def convert(
         self, rel: "org.apache.calcite.rel.RelNode", context: "dask_sql.Context"
-    ) -> dd.DataFrame:
+    ) -> DataContainer:
         # Get the input of the previous step
-        (df,) = self.assert_inputs(rel, 1, context)
+        (dc,) = self.assert_inputs(rel, 1, context)
 
-        # It is easiest to just replace all columns with the new ones
+        df = dc.df
+        cc = dc.column_container
+
+        # Collect all (new) columns
         named_projects = rel.getNamedProjects()
 
+        column_names = []
         new_columns = {}
         for expr, key in named_projects:
-            new_columns[str(key)] = RexConverter.convert(expr, df, context=context)
-
-        df = df.drop(columns=list(df.columns)).assign(**new_columns)
+            key = str(key)
+            column_names.append(key)
+
+            # shortcut: if we have a column already, there is no need to re-assign it again
+            # this is only the case if the expr is a RexInputRef
+            if get_java_class(expr) == RexInputRefPlugin.class_name:
+                index = expr.getIndex()
+                backend_column_name = cc.get_backend_by_frontend_index(index)
+                cc = cc.add(key, backend_column_name)
+            else:
+                new_columns[key] = RexConverter.convert(expr, dc, context=context)
+                cc = cc.add(key, key)
+
+        # Actually add the new columns
+        if new_columns:
+            df = df.assign(**new_columns)
 
         # Make sure the order is correct
-        column_names = list(new_columns.keys())
-        df = df[column_names]
+        cc = cc.limit_to(column_names)
 
-        df = self.fix_column_to_row_type(df, rel.getRowType())
-        return df
+        cc = self.fix_column_to_row_type(cc, rel.getRowType())
+        return DataContainer(df, cc)
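The effect of the RexInputRef shortcut on the container, in isolation (illustrative sketch for something like "SELECT a AS x, a + 1 AS y FROM df"): a plain column reference becomes a pure mapping entry, only the computed expression ends up in df.assign().

    from dask_sql.datacontainer import ColumnContainer

    cc = ColumnContainer(["a", "b"])

    # RexInputRef case: frontend "x" simply points at the existing backend "a"
    cc = cc.add("x", "a")

    # computed expression: a real new backend column "y" (created via df.assign)
    cc = cc.add("y", "y")

    cc = cc.limit_to(["x", "y"])
    print(cc.columns)                           # ['x', 'y']
    print(cc.get_backend_by_frontend_index(0))  # 'a' - no data was copied for it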
diff --git a/dask_sql/physical/rel/logical/sort.py b/dask_sql/physical/rel/logical/sort.py
index 34fcefa97..d39e5e37f 100644
--- a/dask_sql/physical/rel/logical/sort.py
+++ b/dask_sql/physical/rel/logical/sort.py
@@ -7,6 +7,7 @@
 
 from dask_sql.physical.rex import RexConverter
 from dask_sql.physical.rel.base import BaseRelPlugin
+from dask_sql.datacontainer import DataContainer
 
 
 class LogicalSortPlugin(BaseRelPlugin):
@@ -20,12 +21,16 @@ class LogicalSortPlugin(BaseRelPlugin):
     def convert(
         self, rel: "org.apache.calcite.rel.RelNode", context: "dask_sql.Context"
-    ) -> dd.DataFrame:
-        (df,) = self.assert_inputs(rel, 1, context)
-        self.check_columns_from_row_type(df, rel.getExpectedInputRowType(0))
+    ) -> DataContainer:
+        (dc,) = self.assert_inputs(rel, 1, context)
+        df = dc.df
+        cc = dc.column_container
 
         sort_collation = rel.getCollation().getFieldCollations()
-        sort_columns = [df.columns[int(x.getFieldIndex())] for x in sort_collation]
+        sort_columns = [
+            cc.get_backend_by_frontend_index(int(x.getFieldIndex()))
+            for x in sort_collation
+        ]
         sort_ascending = [str(x.getDirection()) == "ASCENDING" for x in sort_collation]
 
         offset = rel.offset
@@ -45,8 +50,8 @@
         if offset is not None or end is not None:
             df = self._apply_offset(df, offset, end)
 
-        df = self.fix_column_to_row_type(df, rel.getRowType())
-        return df
+        cc = self.fix_column_to_row_type(cc, rel.getRowType())
+        return DataContainer(df, cc)
 
     def _apply_sort(
         self, df: dd.DataFrame, sort_columns: List[str], sort_ascending: List[bool]
diff --git a/dask_sql/physical/rel/logical/table_scan.py b/dask_sql/physical/rel/logical/table_scan.py
index c0b176d42..436d69440 100644
--- a/dask_sql/physical/rel/logical/table_scan.py
+++ b/dask_sql/physical/rel/logical/table_scan.py
@@ -1,8 +1,7 @@
 from typing import Dict
 
-import dask.dataframe as dd
-
 from dask_sql.physical.rel.base import BaseRelPlugin
+from dask_sql.datacontainer import DataContainer
 
 
 class LogicalTableScanPlugin(BaseRelPlugin):
@@ -21,7 +20,7 @@ class LogicalTableScanPlugin(BaseRelPlugin):
     def convert(
         self, rel: "org.apache.calcite.rel.RelNode", context: "dask_sql.Context"
-    ) -> dd.DataFrame:
+    ) -> DataContainer:
         # There should not be any input. This is the first step.
         self.assert_inputs(rel, 0)
 
@@ -38,12 +37,14 @@
         table_name = table_names[1]
         table_name = table_name.lower()
 
-        df = context.tables[table_name]
+        dc = context.tables[table_name]
+        df = dc.df
+        cc = dc.column_container
 
         # Make sure we only return the requested columns
         row_type = table.getRowType()
         field_specifications = [str(f) for f in row_type.getFieldNames()]
-        df = df[field_specifications]
+        cc = cc.limit_to(field_specifications)
 
-        df = self.fix_column_to_row_type(df, rel.getRowType())
-        return df
+        cc = self.fix_column_to_row_type(cc, rel.getRowType())
+        return DataContainer(df, cc)
diff --git a/dask_sql/physical/rel/logical/union.py b/dask_sql/physical/rel/logical/union.py
index 862b8765f..b62cb3e15 100644
--- a/dask_sql/physical/rel/logical/union.py
+++ b/dask_sql/physical/rel/logical/union.py
@@ -4,6 +4,7 @@
 
 from dask_sql.physical.rex import RexConverter
 from dask_sql.physical.rel.base import BaseRelPlugin
+from dask_sql.datacontainer import DataContainer, ColumnContainer
 
 
 class LogicalUnionPlugin(BaseRelPlugin):
@@ -16,25 +17,41 @@ class LogicalUnionPlugin(BaseRelPlugin):
     def convert(
         self, rel: "org.apache.calcite.rel.RelNode", context: "dask_sql.Context"
-    ) -> dd.DataFrame:
-        first_df, second_df = self.assert_inputs(rel, 2, context)
+    ) -> DataContainer:
+        first_dc, second_dc = self.assert_inputs(rel, 2, context)
+
+        first_df = first_dc.df
+        first_cc = first_dc.column_container
+
+        second_df = second_dc.df
+        second_cc = second_dc.column_container
 
         # For concatenating, they should have exactly the same fields
         output_field_names = [str(x) for x in rel.getRowType().getFieldNames()]
-        assert len(first_df.columns) == len(output_field_names)
-        first_df = first_df.rename(
+        assert len(first_cc.columns) == len(output_field_names)
+        first_cc = first_cc.rename(
             columns={
                 col: output_col
-                for col, output_col in zip(first_df.columns, output_field_names)
+                for col, output_col in zip(first_cc.columns, output_field_names)
             }
         )
-        assert len(second_df.columns) == len(output_field_names)
-        second_df = second_df.rename(
+        first_dc = DataContainer(first_df, first_cc)
+
+        assert len(second_cc.columns) == len(output_field_names)
+        second_cc = second_cc.rename(
             columns={
                 col: output_col
-                for col, output_col in zip(second_df.columns, output_field_names)
+                for col, output_col in zip(second_cc.columns, output_field_names)
             }
         )
+        second_dc = DataContainer(second_df, second_cc)
+
+        # To concat the two dataframes, we need to make sure the
+        # columns actually have the specified names in the
+        # column containers
+        # Otherwise the concat won't work
+        first_df = first_dc.assign()
+        second_df = second_dc.assign()
 
         self.check_columns_from_row_type(first_df, rel.getExpectedInputRowType(0))
         self.check_columns_from_row_type(second_df, rel.getExpectedInputRowType(1))
@@ -44,5 +61,6 @@
 
         if not rel.all:
             df = df.drop_duplicates()
 
-        df = self.fix_column_to_row_type(df, rel.getRowType())
-        return df
+        cc = ColumnContainer(df.columns)
+        cc = self.fix_column_to_row_type(cc, rel.getRowType())
+        return DataContainer(df, cc)
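Why the assign() calls are needed before the concat (a sketch with made-up inputs): dd.concat aligns on the real column names, so both sides are first materialized under the shared output names kept in their column containers.

    import pandas as pd
    import dask.dataframe as dd
    from dask_sql.datacontainer import ColumnContainer, DataContainer

    first = dd.from_pandas(pd.DataFrame({"a": [1, 2]}), npartitions=1)
    second = dd.from_pandas(pd.DataFrame({"b": [3, 4]}), npartitions=1)

    # both containers map their single column to the same output name "out"
    first_dc = DataContainer(first, ColumnContainer(first.columns).rename({"a": "out"}))
    second_dc = DataContainer(second, ColumnContainer(second.columns).rename({"b": "out"}))

    # without assign(), the frames would still be named "a" and "b" and not align
    df = dd.concat([first_dc.assign(), second_dc.assign()])
    print(df.compute())  # one column "out" with the rows of both inputs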
diff --git a/dask_sql/physical/rel/logical/values.py b/dask_sql/physical/rel/logical/values.py
index 7c396ea2b..833ebafb1 100644
--- a/dask_sql/physical/rel/logical/values.py
+++ b/dask_sql/physical/rel/logical/values.py
@@ -5,6 +5,7 @@
 
 from dask_sql.physical.rex import RexConverter
 from dask_sql.physical.rel.base import BaseRelPlugin
+from dask_sql.datacontainer import DataContainer, ColumnContainer
 
 
 class LogicalValuesPlugin(BaseRelPlugin):
@@ -25,23 +26,30 @@ class LogicalValuesPlugin(BaseRelPlugin):
     def convert(
         self, rel: "org.apache.calcite.rel.RelNode", context: "dask_sql.Context"
-    ) -> dd.DataFrame:
+    ) -> DataContainer:
         # There should not be any input. This is the first step.
         self.assert_inputs(rel, 0)
 
         rex_expression_rows = list(rel.getTuples())
         rows = []
-        for rex_expressions in rex_expression_rows:
+        for rex_expression_row in rex_expression_rows:
+            # We convert each of the cells in the row
+            # using a RexConverter.
+            # As we do not have any information on the
+            # column headers, we just name them with
+            # their index.
             rows.append(
-                [
-                    RexConverter.convert(rex, None, context=context)
-                    for rex in rex_expressions
-                ]
+                {
+                    str(i): RexConverter.convert(rex_cell, None, context=context)
+                    for i, rex_cell in enumerate(rex_expression_row)
+                }
            )
 
-        # We assume here that when using the values plan, the resulting dataframe will be quite small
         # TODO: we explicitely reference pandas and dask here -> might we worth making this more general
-        df = dd.from_pandas(pd.DataFrame(rows), npartitions=1)
-        df = self.fix_column_to_row_type(df, rel.getRowType())
+        # We assume here that when using the values plan, the resulting dataframe will be quite small
+        df = pd.DataFrame(rows)
+        df = dd.from_pandas(df, npartitions=1)
+        cc = ColumnContainer(df.columns)
 
-        return df
+        cc = self.fix_column_to_row_type(cc, rel.getRowType())
+        return DataContainer(df, cc)
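What the VALUES conversion produces, without the Calcite machinery (a sketch; the literal values are made up): every cell is evaluated to a plain python value and each row becomes a dict keyed by the column index.

    import pandas as pd
    import dask.dataframe as dd

    # e.g. for "VALUES (1, 'a'), (2, 'b')"
    rows = [{"0": 1, "1": "a"}, {"0": 2, "1": "b"}]

    df = dd.from_pandas(pd.DataFrame(rows), npartitions=1)
    print(df.compute())  # columns "0" and "1"; fix_column_to_row_type renames them later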
diff --git a/dask_sql/physical/rex/base.py b/dask_sql/physical/rex/base.py
index a93cc9bd5..840df4846 100644
--- a/dask_sql/physical/rex/base.py
+++ b/dask_sql/physical/rex/base.py
@@ -2,6 +2,8 @@
 
 import dask.dataframe as dd
 
+from dask_sql.datacontainer import DataContainer
+
 
 class BaseRexPlugin:
     """
@@ -17,7 +19,7 @@ class BaseRexPlugin:
     def convert(
         self,
         rex: "org.apache.calcite.rex.RexNode",
-        df: dd.DataFrame,
+        dc: DataContainer,
         context: "dask_sql.Context",
     ) -> Union[dd.Series, Any]:
         """Base method to implement"""
diff --git a/dask_sql/physical/rex/convert.py b/dask_sql/physical/rex/convert.py
index 7d135e746..d30f0b802 100644
--- a/dask_sql/physical/rex/convert.py
+++ b/dask_sql/physical/rex/convert.py
@@ -5,6 +5,7 @@
 from dask_sql.java import get_java_class
 from dask_sql.utils import Pluggable
 from dask_sql.physical.rex.base import BaseRexPlugin
+from dask_sql.datacontainer import DataContainer
 
 
 class RexConverter(Pluggable):
@@ -32,7 +33,7 @@ def add_plugin_class(cls, plugin_class: BaseRexPlugin, replace=True):
     def convert(
         cls,
         rex: "org.apache.calcite.rex.RexNode",
-        df: dd.DataFrame,
+        dc: DataContainer,
         context: "dask_sql.Context",
     ) -> Union[dd.DataFrame, Any]:
         """
@@ -50,5 +51,5 @@
                 f"No conversion for class {class_name} available (yet)."
             )
 
-        df = plugin_instance.convert(rex, df=df, context=context)
+        df = plugin_instance.convert(rex, dc, context=context)
         return df
diff --git a/dask_sql/physical/rex/core/call.py b/dask_sql/physical/rex/core/call.py
index 544f2b6df..f55f90b46 100644
--- a/dask_sql/physical/rex/core/call.py
+++ b/dask_sql/physical/rex/core/call.py
@@ -10,6 +10,7 @@
 from dask_sql.physical.rex import RexConverter
 from dask_sql.physical.rex.base import BaseRexPlugin
 from dask_sql.utils import is_frame
+from dask_sql.datacontainer import DataContainer
 
 
 class Operation:
@@ -202,12 +203,12 @@ class RexCallPlugin(BaseRexPlugin):
     def convert(
         self,
         rex: "org.apache.calcite.rex.RexNode",
-        df: dd.DataFrame,
+        dc: DataContainer,
         context: "dask_sql.Context",
     ) -> Union[dd.Series, Any]:
         # Prepare the operands by turning the RexNodes into python expressions
         operands = [
-            RexConverter.convert(o, df, context=context) for o in rex.getOperands()
+            RexConverter.convert(o, dc, context=context) for o in rex.getOperands()
         ]
 
         # Now use the operator name in the mapping
diff --git a/dask_sql/physical/rex/core/input_ref.py b/dask_sql/physical/rex/core/input_ref.py
index 5457e0e5a..391065026 100644
--- a/dask_sql/physical/rex/core/input_ref.py
+++ b/dask_sql/physical/rex/core/input_ref.py
@@ -1,6 +1,7 @@
 import dask.dataframe as dd
 
 from dask_sql.physical.rex.base import BaseRexPlugin
+from dask_sql.datacontainer import DataContainer
 
 
 class RexInputRefPlugin(BaseRexPlugin):
@@ -15,9 +16,13 @@ class RexInputRefPlugin(BaseRexPlugin):
     def convert(
         self,
         rex: "org.apache.calcite.rex.RexNode",
-        df: dd.DataFrame,
+        dc: DataContainer,
         context: "dask_sql.Context",
     ) -> dd.Series:
+        df = dc.df
+        cc = dc.column_container
+
         # The column is references by index
         index = rex.getIndex()
-        return df.iloc[:, index]
+        backend_column_name = cc.get_backend_by_frontend_index(index)
+        return df[backend_column_name]
diff --git a/dask_sql/physical/rex/core/literal.py b/dask_sql/physical/rex/core/literal.py
index c7d10141e..f366d1ceb 100644
--- a/dask_sql/physical/rex/core/literal.py
+++ b/dask_sql/physical/rex/core/literal.py
@@ -1,9 +1,11 @@
 from typing import Any
 
+import numpy as np
 import dask.dataframe as dd
 
 from dask_sql.physical.rex.base import BaseRexPlugin
 from dask_sql.mappings import sql_to_python_value
+from dask_sql.datacontainer import DataContainer
 
 
 class RexLiteralPlugin(BaseRexPlugin):
@@ -21,7 +23,7 @@ class RexLiteralPlugin(BaseRexPlugin):
     def convert(
         self,
         rex: "org.apache.calcite.rex.RexNode",
-        df: dd.DataFrame,
+        dc: DataContainer,
         context: "dask_sql.Context",
     ) -> Any:
         literal_value = rex.getValue()
diff --git a/tests/integration/test_select.py b/tests/integration/test_select.py
index e92c27850..b5ca3ac54 100644
--- a/tests/integration/test_select.py
+++ b/tests/integration/test_select.py
@@ -14,6 +14,15 @@ def test_select(self):
 
         assert_frame_equal(df, self.df)
 
+    def test_select_alias(self):
+        df = self.c.sql("SELECT a as b, b as a FROM df")
+        df = df.compute()
+
+        expected_df = self.df
+        expected_df = expected_df.assign(a=self.df.b, b=self.df.a)
+
+        assert_frame_equal(df, expected_df)
+
     def test_select_column(self):
         df = self.c.sql("SELECT a FROM df")
         df = df.compute()
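Finally, how an input reference now resolves through the container instead of df.iloc (a sketch with made-up names): the frontend index that Calcite hands us is translated into the real dask column name, so no positional access on the dataframe is needed any more.

    import pandas as pd
    import dask.dataframe as dd
    from dask_sql.datacontainer import ColumnContainer, DataContainer

    df = dd.from_pandas(pd.DataFrame({"a": [1, 2], "b": [3, 4]}), npartitions=1)
    cc = ColumnContainer(df.columns).rename({"a": "x"})  # frontend "x" maps to backend "a"
    dc = DataContainer(df, cc)

    index = 0  # what rex.getIndex() would return for the first SQL column
    backend_column_name = dc.column_container.get_backend_by_frontend_index(index)
    print(backend_column_name)                   # 'a'
    print(dc.df[backend_column_name].compute())  # the series formerly fetched via df.iloc[:, 0]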