diff --git a/dask_sql/context.py b/dask_sql/context.py index c6030814c..8224fda42 100644 --- a/dask_sql/context.py +++ b/dask_sql/context.py @@ -86,6 +86,7 @@ def __init__(self): RelConverter.add_plugin_class(logical.LogicalTableScanPlugin, replace=False) RelConverter.add_plugin_class(logical.LogicalUnionPlugin, replace=False) RelConverter.add_plugin_class(logical.LogicalValuesPlugin, replace=False) + RelConverter.add_plugin_class(logical.LogicalWindowPlugin, replace=False) RelConverter.add_plugin_class(logical.SamplePlugin, replace=False) RelConverter.add_plugin_class(custom.AnalyzeTablePlugin, replace=False) RelConverter.add_plugin_class(custom.CreateExperimentPlugin, replace=False) @@ -108,7 +109,6 @@ def __init__(self): RexConverter.add_plugin_class(core.RexCallPlugin, replace=False) RexConverter.add_plugin_class(core.RexInputRefPlugin, replace=False) RexConverter.add_plugin_class(core.RexLiteralPlugin, replace=False) - RexConverter.add_plugin_class(core.RexOverPlugin, replace=False) InputUtil.add_plugin_class(input_utils.DaskInputPlugin, replace=False) InputUtil.add_plugin_class(input_utils.PandasInputPlugin, replace=False) @@ -427,7 +427,7 @@ def sql( cc = dc.column_container cc = cc.rename( { - df_col: df_col if not df_col.startswith("EXPR$") else select_name + df_col: select_name for df_col, select_name in zip(cc.columns, select_names) } ) @@ -711,12 +711,18 @@ def _get_ral(self, sql): sqlNode = generator.getSqlNode(sql) sqlNodeClass = get_java_class(sqlNode) - if sqlNodeClass.startswith("com.dask.sql.parser."): - rel = sqlNode - rel_string = "" - else: + select_names = None + rel = sqlNode + rel_string = "" + + if not sqlNodeClass.startswith("com.dask.sql.parser."): validatedSqlNode = generator.getValidatedNode(sqlNode) nonOptimizedRelNode = generator.getRelationalAlgebra(validatedSqlNode) + # Optimization might remove some alias projects. Make sure to keep them here. 
+ select_names = [ + str(name) + for name in nonOptimizedRelNode.getRowType().getFieldNames() + ] rel = generator.getOptimizedRelationalAlgebra(nonOptimizedRelNode) rel_string = str(generator.getRelationalAlgebraString(rel)) except (ValidationException, SqlParseException) as e: @@ -741,13 +747,14 @@ def _get_ral(self, sql): if sqlNodeClass == "org.apache.calcite.sql.SqlSelect": select_names = [ self._to_sql_string(s, default_dialect=default_dialect) - for s in sqlNode.getSelectList() + if current_name.startswith("EXPR$") + else current_name + for s, current_name in zip(sqlNode.getSelectList(), select_names) ] else: logger.debug( "Not extracting output column names as the SQL is not a SELECT call" ) - select_names = None logger.debug(f"Extracted relational algebra:\n {rel_string}") return rel, select_names, rel_string diff --git a/dask_sql/physical/rel/logical/__init__.py b/dask_sql/physical/rel/logical/__init__.py index 99698157c..0cbe20f65 100644 --- a/dask_sql/physical/rel/logical/__init__.py +++ b/dask_sql/physical/rel/logical/__init__.py @@ -7,6 +7,7 @@ from .table_scan import LogicalTableScanPlugin from .union import LogicalUnionPlugin from .values import LogicalValuesPlugin +from .window import LogicalWindowPlugin __all__ = [ LogicalAggregatePlugin, @@ -17,5 +18,6 @@ LogicalTableScanPlugin, LogicalUnionPlugin, LogicalValuesPlugin, + LogicalWindowPlugin, SamplePlugin, ] diff --git a/dask_sql/physical/rel/logical/window.py b/dask_sql/physical/rel/logical/window.py new file mode 100644 index 000000000..2ecd10b24 --- /dev/null +++ b/dask_sql/physical/rel/logical/window.py @@ -0,0 +1,410 @@ +import logging +from collections import namedtuple +from functools import partial +from typing import Any, Callable, List, Optional, Tuple + +import dask.dataframe as dd +import numpy as np +import pandas as pd +from pandas.core.window.indexers import BaseIndexer + +from dask_sql.datacontainer import ColumnContainer, DataContainer +from dask_sql.java import org +from dask_sql.physical.rel.base import BaseRelPlugin +from dask_sql.physical.rex.convert import RexConverter +from dask_sql.physical.rex.core.literal import RexLiteralPlugin +from dask_sql.physical.utils.groupby import get_groupby_with_nulls_cols +from dask_sql.physical.utils.map import map_on_partition_index +from dask_sql.physical.utils.sort import sort_partition_func +from dask_sql.utils import ( + LoggableDataFrame, + make_pickable_without_dask_sql, + new_temporary_column, +) + +logger = logging.getLogger(__name__) + + +class OverOperation: + def __call__(self, partitioned_group, *args) -> pd.Series: + """Call the stored function""" + return self.call(partitioned_group, *args) + + +class FirstValueOperation(OverOperation): + def call(self, partitioned_group, value_col): + return partitioned_group[value_col].apply(lambda x: x.iloc[0]) + + +class LastValueOperation(OverOperation): + def call(self, partitioned_group, value_col): + return partitioned_group[value_col].apply(lambda x: x.iloc[-1]) + + +class SumOperation(OverOperation): + def call(self, partitioned_group, value_col): + return partitioned_group[value_col].sum() + + +class CountOperation(OverOperation): + def call(self, partitioned_group, value_col=None): + if value_col is None: + return partitioned_group.count().iloc[:, 0].fillna(0) + else: + return partitioned_group[value_col].count().fillna(0) + + +class MaxOperation(OverOperation): + def call(self, partitioned_group, value_col): + return partitioned_group[value_col].max() + + +class MinOperation(OverOperation): + def call(self, 
partitioned_group, value_col): + return partitioned_group[value_col].min() + + +class BoundDescription( + namedtuple( + "BoundDescription", + ["is_unbounded", "is_preceding", "is_following", "is_current_row", "offset"], + ) +): + """ + Small helper class to wrap a org.apache.calcite.rex.RexWindowBounds + Java object, as we can not ship it to to the dask workers + """ + + pass + + +def to_bound_description( + java_window: "org.apache.calcite.rex.RexWindowBounds.RexBoundedWindowBound", + constants: List[org.apache.calcite.rex.RexLiteral], + constant_count_offset: int, +) -> BoundDescription: + """Convert the java object "java_window" to a python representation, + replacing any literals or references to constants""" + offset = java_window.getOffset() + if offset: + if isinstance(offset, org.apache.calcite.rex.RexInputRef): + # For calcite, the constant pool are normal "columns", + # starting at (number of real columns + 1). + # Here, we do the de-referencing. + index = offset.getIndex() - constant_count_offset + offset = constants[index] + offset = int(RexLiteralPlugin().convert(offset, None, None)) + else: + offset = None + + return BoundDescription( + is_unbounded=bool(java_window.isUnbounded()), + is_preceding=bool(java_window.isPreceding()), + is_following=bool(java_window.isFollowing()), + is_current_row=bool(java_window.isCurrentRow()), + offset=offset, + ) + + +class Indexer(BaseIndexer): + """ + Window description used for complex windows with arbitrary start and end. + This class is directly taken from the fugue project. + """ + + def __init__(self, start: int, end: int): + super().__init__(self, start=start, end=end) + + def get_window_bounds( + self, + num_values: int = 0, + min_periods: Optional[int] = None, + center: Optional[bool] = None, + closed: Optional[str] = None, + ) -> Tuple[np.ndarray, np.ndarray]: + if self.start is None: + start = np.zeros(num_values, dtype=np.int64) + else: + start = np.arange(self.start, self.start + num_values, dtype=np.int64) + if self.start < 0: + start[: -self.start] = 0 + elif self.start > 0: + start[-self.start :] = num_values + if self.end is None: + end = np.full(num_values, num_values, dtype=np.int64) + else: + end = np.arange(self.end + 1, self.end + 1 + num_values, dtype=np.int64) + if self.end > 0: + end[-self.end :] = num_values + elif self.end < 0: + end[: -self.end] = 0 + else: # pragma: no cover + raise AssertionError( + "This case should have been handled before! 
Please report this bug" + ) + return start, end + + +def map_on_each_group( + partitioned_group: pd.DataFrame, + sort_columns: List[str], + sort_ascending: List[bool], + sort_null_first: List[bool], + lower_bound: BoundDescription, + upper_bound: BoundDescription, + operations: List[Tuple[Callable, str, List[str]]], +): + """Internal function mapped on each group of the dataframe after partitioning""" + # Apply sorting + if sort_columns: + partitioned_group = sort_partition_func( + partitioned_group, sort_columns, sort_ascending, sort_null_first + ) + + # Apply the windowing operation + if lower_bound.is_unbounded and ( + upper_bound.is_current_row or upper_bound.offset == 0 + ): + windowed_group = partitioned_group.expanding(min_periods=0) + elif lower_bound.is_preceding and ( + upper_bound.is_current_row or upper_bound.offset == 0 + ): + windowed_group = partitioned_group.rolling( + window=lower_bound.offset + 1, min_periods=0, + ) + else: + lower_offset = lower_bound.offset if not lower_bound.is_current_row else 0 + if lower_bound.is_preceding and lower_offset is not None: + lower_offset *= -1 + upper_offset = upper_bound.offset if not upper_bound.is_current_row else 0 + if upper_bound.is_preceding and upper_offset is not None: + upper_offset *= -1 + + indexer = Indexer(lower_offset, upper_offset) + windowed_group = partitioned_group.rolling(window=indexer, min_periods=0) + + # Calculate the results + new_columns = {} + for f, new_column_name, temporary_operand_columns in operations: + if f is None: + # This is the row_number operator. + # We do not need to do any windowing + column_result = range(1, len(partitioned_group) + 1) + else: + column_result = f(windowed_group, *temporary_operand_columns) + + new_columns[new_column_name] = column_result + + # Now apply all columns at once + partitioned_group = partitioned_group.assign(**new_columns) + return partitioned_group + + +class LogicalWindowPlugin(BaseRelPlugin): + """ + A LogicalWindow is an expression, which calculates a given function over the dataframe + while first optionally partitoning the data and optionally sorting it. + + Expressions like `F OVER (PARTITION BY x ORDER BY y)` apply f on each + partition separately and sort by y before applying f. The result of this + calculation has however the same length as the input dataframe - it is not an aggregation. + Typical examples include ROW_NUMBER and lagging. + """ + + class_name = "org.apache.calcite.rel.logical.LogicalWindow" + + OPERATION_MAPPING = { + "row_number": None, # That is the easiest one: we do not even need to have any windowing. We therefore threat it separately + "$sum0": SumOperation(), + "sum": SumOperation(), + # Is replaced by a sum and count by calcite: "avg": ExplodedOperation(AvgOperation()), + "count": CountOperation(), + "max": MaxOperation(), + "min": MinOperation(), + "single_value": FirstValueOperation(), + "first_value": FirstValueOperation(), + "last_value": LastValueOperation(), + } + + def convert( + self, rel: "org.apache.calcite.rel.RelNode", context: "dask_sql.Context" + ) -> DataContainer: + (dc,) = self.assert_inputs(rel, 1, context) + + # During optimization, some constants might end up in an internal + # constant pool. We need to dereference them here, as they + # are treated as "normal" columns. + # Unfortunately they are only referenced by their index, + # (which come after the real columns), so we need + # to always substract the number of real columns. 
+ constants = list(rel.getConstants()) + constant_count_offset = len(dc.column_container.columns) + + # Output to the right field names right away + field_names = rel.getRowType().getFieldNames() + + for window in rel.groups: + dc = self._apply_window( + window, constants, constant_count_offset, dc, field_names, context + ) + + # Finally, fix the output schema if needed + df = dc.df + cc = dc.column_container + + cc = self.fix_column_to_row_type(cc, rel.getRowType()) + dc = DataContainer(df, cc) + dc = self.fix_dtype_to_row_type(dc, rel.getRowType()) + + return dc + + def _apply_window( + self, + window: org.apache.calcite.rel.core.Window.Group, + constants: List[org.apache.calcite.rex.RexLiteral], + constant_count_offset: int, + dc: DataContainer, + field_names: List[str], + context: "dask_sql.Context", + ): + temporary_columns = [] + + df = dc.df + cc = dc.column_container + + # Now extract the groupby and order information + sort_columns, sort_ascending, sort_null_first = self._extract_ordering( + window, cc + ) + logger.debug( + f"Before applying the function, sorting according to {sort_columns}." + ) + + df, group_columns = self._extract_groupby(df, window, dc, context) + logger.debug( + f"Before applying the function, partitioning according to {group_columns}." + ) + # TODO: optimize by re-using already present columns + temporary_columns += group_columns + + operations, df = self._extract_operations(window, df, dc, context) + for _, _, cols in operations: + temporary_columns += cols + + newly_created_columns = [new_column for _, new_column, _ in operations] + + # Apply the windowing operation + filled_map = partial( + map_on_each_group, + sort_columns=sort_columns, + sort_ascending=sort_ascending, + sort_null_first=sort_null_first, + lower_bound=to_bound_description( + window.lowerBound, constants, constant_count_offset + ), + upper_bound=to_bound_description( + window.upperBound, constants, constant_count_offset + ), + operations=operations, + ) + + # TODO: That is a bit of a hack. 
We should really use the real column dtype + meta = df._meta.assign(**{col: 0.0 for col in newly_created_columns}) + + df = df.groupby(group_columns).apply( + make_pickable_without_dask_sql(filled_map), meta=meta + ) + df = df.drop(columns=temporary_columns).reset_index(drop=True) + + dc = DataContainer(df, cc) + df = dc.df + cc = dc.column_container + + for c in newly_created_columns: + # the fields are in the correct order by definition + field_name = field_names[len(cc.columns)] + cc = cc.add(field_name, c) + + dc = DataContainer(df, cc) + return dc + + def _extract_groupby( + self, + df: dd.DataFrame, + window: org.apache.calcite.rel.core.Window.Group, + dc: DataContainer, + context: "dask_sql.Context", + ) -> Tuple[dd.DataFrame, str]: + """Prepare grouping columns we can later use while applying the main function""" + partition_keys = list(window.keys) + if partition_keys: + group_columns = [ + df[dc.column_container.get_backend_by_frontend_index(o)] + for o in partition_keys + ] + group_columns = get_groupby_with_nulls_cols(df, group_columns) + group_columns = { + new_temporary_column(df): group_col for group_col in group_columns + } + else: + group_columns = {new_temporary_column(df): 1} + + df = df.assign(**group_columns) + group_columns = list(group_columns.keys()) + + return df, group_columns + + def _extract_ordering( + self, window: org.apache.calcite.rel.core.Window.Group, cc: ColumnContainer + ) -> Tuple[str, str, str]: + """Prepare sorting information we can later use while applying the main function""" + order_keys = list(window.orderKeys.getFieldCollations()) + sort_columns_indices = [int(i.getFieldIndex()) for i in order_keys] + sort_columns = [ + cc.get_backend_by_frontend_index(i) for i in sort_columns_indices + ] + + ASCENDING = org.apache.calcite.rel.RelFieldCollation.Direction.ASCENDING + FIRST = org.apache.calcite.rel.RelFieldCollation.NullDirection.FIRST + sort_ascending = [x.getDirection() == ASCENDING for x in order_keys] + sort_null_first = [x.nullDirection == FIRST for x in order_keys] + + return sort_columns, sort_ascending, sort_null_first + + def _extract_operations( + self, + window: org.apache.calcite.rel.core.Window.Group, + df: dd.DataFrame, + dc: DataContainer, + context: "dask_sql.Context", + ) -> List[Tuple[Callable, str, List[str]]]: + # Finally apply the actual function on each group separately + operations = [] + for agg_call in window.aggCalls: + operator = agg_call.getOperator() + operator_name = str(operator.getName()) + operator_name = operator_name.lower() + + try: + operation = self.OPERATION_MAPPING[operator_name] + except KeyError: # pragma: no cover + try: + operation = context.functions[operator_name] + except KeyError: # pragma: no cover + raise NotImplementedError(f"{operator_name} not (yet) implemented") + + logger.debug(f"Executing {operator_name} on {str(LoggableDataFrame(df))}") + + # TODO: can be optimized by re-using already present columns + temporary_operand_columns = { + new_temporary_column(df): RexConverter.convert(o, dc, context=context) + for o in agg_call.getOperands() + } + df = df.assign(**temporary_operand_columns) + temporary_operand_columns = list(temporary_operand_columns.keys()) + + operations.append( + (operation, new_temporary_column(df), temporary_operand_columns) + ) + + return operations, df diff --git a/dask_sql/physical/rex/core/__init__.py b/dask_sql/physical/rex/core/__init__.py index 1a5ecb6c1..9c34373eb 100644 --- a/dask_sql/physical/rex/core/__init__.py +++ b/dask_sql/physical/rex/core/__init__.py @@ 
-1,6 +1,5 @@ from .call import RexCallPlugin from .input_ref import RexInputRefPlugin from .literal import RexLiteralPlugin -from .over import RexOverPlugin -__all__ = [RexCallPlugin, RexInputRefPlugin, RexLiteralPlugin, RexOverPlugin] +__all__ = [RexCallPlugin, RexInputRefPlugin, RexLiteralPlugin] diff --git a/dask_sql/physical/rex/core/over.py b/dask_sql/physical/rex/core/over.py deleted file mode 100644 index 5fdb25b8b..000000000 --- a/dask_sql/physical/rex/core/over.py +++ /dev/null @@ -1,398 +0,0 @@ -import logging -from collections import namedtuple -from typing import Any, Callable, List, Optional, Tuple - -import dask.dataframe as dd -import numpy as np -import pandas as pd -from pandas.core.window.indexers import BaseIndexer - -from dask_sql.datacontainer import ColumnContainer, DataContainer -from dask_sql.java import org -from dask_sql.physical.rex.base import BaseRexPlugin -from dask_sql.physical.rex.convert import RexConverter -from dask_sql.physical.rex.core.literal import RexLiteralPlugin -from dask_sql.physical.utils.groupby import get_groupby_with_nulls_cols -from dask_sql.physical.utils.map import map_on_partition_index -from dask_sql.physical.utils.sort import sort_partition_func -from dask_sql.utils import ( - LoggableDataFrame, - make_pickable_without_dask_sql, - new_temporary_column, -) - -logger = logging.getLogger(__name__) - - -class OverOperation: - def __call__(self, partitioned_group, *args) -> pd.Series: - """Call the stored function""" - return self.call(partitioned_group, *args) - - -class FirstValueOperation(OverOperation): - def call(self, partitioned_group, value_col): - return partitioned_group[value_col].apply(lambda x: x.iloc[0]) - - -class LastValueOperation(OverOperation): - def call(self, partitioned_group, value_col): - return partitioned_group[value_col].apply(lambda x: x.iloc[-1]) - - -class SumOperation(OverOperation): - def call(self, partitioned_group, value_col): - return partitioned_group[value_col].sum() - - -class CountOperation(OverOperation): - def call(self, partitioned_group, value_col=None): - if value_col is None: - return partitioned_group.count().iloc[:, 0].fillna(0) - else: - return partitioned_group[value_col].count().fillna(0) - - -class MaxOperation(OverOperation): - def call(self, partitioned_group, value_col): - return partitioned_group[value_col].max() - - -class MinOperation(OverOperation): - def call(self, partitioned_group, value_col): - return partitioned_group[value_col].min() - - -class BoundDescription( - namedtuple( - "BoundDescription", - ["is_unbounded", "is_preceding", "is_following", "is_current_row", "offset"], - ) -): - """ - Small helper class to wrap a org.apache.calcite.rex.RexWindowBounds - Java object, as we can not ship it to to the dask workers - """ - - pass - - -def to_bound_description( - java_window: org.apache.calcite.rex.RexWindowBounds, -) -> BoundDescription: - offset = java_window.getOffset() - if offset: - offset = int(RexLiteralPlugin().convert(offset, None, None)) - else: - offset = None - - return BoundDescription( - is_unbounded=bool(java_window.isUnbounded()), - is_preceding=bool(java_window.isPreceding()), - is_following=bool(java_window.isFollowing()), - is_current_row=bool(java_window.isCurrentRow()), - offset=offset, - ) - - -class Indexer(BaseIndexer): - """ - Window description used for complex windows with arbitrary start and end. - This class is directly taken from the fugue project. 
- """ - - def __init__(self, start: int, end: int): - super().__init__(self, start=start, end=end) - - def get_window_bounds( - self, - num_values: int = 0, - min_periods: Optional[int] = None, - center: Optional[bool] = None, - closed: Optional[str] = None, - ) -> Tuple[np.ndarray, np.ndarray]: - if self.start is None: - start = np.zeros(num_values, dtype=np.int64) - else: - start = np.arange(self.start, self.start + num_values, dtype=np.int64) - if self.start < 0: - start[: -self.start] = 0 - elif self.start > 0: - start[-self.start :] = num_values - if self.end is None: - end = np.full(num_values, num_values, dtype=np.int64) - else: - end = np.arange(self.end + 1, self.end + 1 + num_values, dtype=np.int64) - if self.end > 0: - end[-self.end :] = num_values - elif self.end < 0: - end[: -self.end] = 0 - else: # pragma: no cover - raise AssertionError( - "This case should have been handled before! Please report this bug" - ) - return start, end - - -class RexOverPlugin(BaseRexPlugin): - """ - A RexOver is an expression, which calculates a given function over the dataframe - while first optionally partitoning the data and optionally sorting it. - - expressions like `F OVER (PARTITION BY x ORDER BY y)` apply f on each - partition separately and sort by y before applying f. The result of this - calculation has however the same length as the input dataframe - it is not an aggregation. - Typical examples include ROW_NUMBER and lagging. - """ - - class_name = "org.apache.calcite.rex.RexOver" - - OPERATION_MAPPING = { - "row_number": None, # That is the easiest one: we do not even need to have any windowing. We therefore threat it separately - "$sum0": SumOperation(), - "sum": SumOperation(), - # Is replaced by a sum and count by calcite: "avg": ExplodedOperation(AvgOperation()), - "count": CountOperation(), - "max": MaxOperation(), - "min": MinOperation(), - "single_value": FirstValueOperation(), - "first_value": FirstValueOperation(), - "last_value": LastValueOperation(), - } - - def convert( - self, - rex: "org.apache.calcite.rex.RexNode", - dc: DataContainer, - context: "dask_sql.Context", - ) -> Any: - window = rex.getWindow() - - df = dc.df - cc = dc.column_container - - # Store the divisions to apply them later again - known_divisions = df.divisions - - # Store the index and sort order to apply them later again - df, partition_col, index_col, sort_col = self._preserve_index_and_sort(df) - dc = DataContainer(df, cc) - - # Now extract the groupby and order information - sort_columns, sort_ascending, sort_null_first = self._extract_ordering( - window, cc - ) - logger.debug( - "Before applying the function, sorting according to {sort_columns}." - ) - - df, group_columns = self._extract_groupby(df, window, dc, context) - logger.debug( - f"Before applying the function, partitioning according to {group_columns}." 
- ) - - # Finally apply the actual function on each group separately - operator = rex.getOperator() - operator_name = str(operator.getName()) - operator_name = operator_name.lower() - - try: - operation = self.OPERATION_MAPPING[operator_name] - except KeyError: # pragma: no cover - try: - operation = context.functions[operator_name] - except KeyError: # pragma: no cover - raise NotImplementedError(f"{operator_name} not (yet) implemented") - - logger.debug(f"Executing {operator_name} on {str(LoggableDataFrame(df))}") - - # TODO: can be optimized by re-using already present columns - operands = [ - RexConverter.convert(o, dc, context=context) for o in rex.getOperands() - ] - - df, new_column_name = self._apply_function_over( - df, - operation, - operands, - window, - group_columns, - sort_columns, - sort_ascending, - sort_null_first, - ) - - # Revert back any sorting and grouping by using the previously stored information - df = self._revert_partition_and_order( - df, partition_col, index_col, sort_col, known_divisions - ) - - return df[new_column_name] - - def _preserve_index_and_sort( - self, df: dd.DataFrame - ) -> Tuple[dd.DataFrame, str, str, str]: - """Store the partition number, index and sort order separately to make any shuffling reversible""" - partition_col, index_col, sort_col = ( - new_temporary_column(df), - new_temporary_column(df), - new_temporary_column(df), - ) - - def store_index_columns(partition, partition_index): - return partition.assign( - **{ - partition_col: partition_index, - index_col: partition.index, - sort_col: range(len(partition)), - } - ) - - df = map_on_partition_index(df, store_index_columns) - - return df, partition_col, index_col, sort_col - - def _extract_groupby( - self, - df: dd.DataFrame, - window: org.apache.calcite.rex.RexWindow, - dc: DataContainer, - context: "dask_sql.Context", - ) -> Tuple[dd.DataFrame, str]: - """Prepare grouping columns we can later use while applying the main function""" - partition_keys = list(window.partitionKeys) - if partition_keys: - group_columns = [ - RexConverter.convert(o, dc, context=context) for o in partition_keys - ] - group_columns = get_groupby_with_nulls_cols(df, group_columns) - group_columns = { - new_temporary_column(df): group_col for group_col in group_columns - } - else: - group_columns = {new_temporary_column(df): 1} - - df = df.assign(**group_columns) - group_columns = list(group_columns.keys()) - - return df, group_columns - - def _extract_ordering( - self, window: org.apache.calcite.rex.RexWindow, cc: ColumnContainer - ) -> Tuple[str, str, str]: - """Prepare sorting information we can later use while applying the main function""" - order_keys = list(window.orderKeys) - sort_columns_indices = [int(i.getKey().getIndex()) for i in order_keys] - sort_columns = [ - cc.get_backend_by_frontend_index(i) for i in sort_columns_indices - ] - - ASCENDING = org.apache.calcite.rel.RelFieldCollation.Direction.ASCENDING - FIRST = org.apache.calcite.rel.RelFieldCollation.NullDirection.FIRST - sort_ascending = [x.getDirection() == ASCENDING for x in order_keys] - sort_null_first = [x.getNullDirection() == FIRST for x in order_keys] - - return sort_columns, sort_ascending, sort_null_first - - def _apply_function_over( - self, - df: dd.DataFrame, - f: Callable, - operands: List[dd.Series], - window: org.apache.calcite.rex.RexWindow, - group_columns: List[str], - sort_columns: List[str], - sort_ascending: List[bool], - sort_null_first: List[bool], - ) -> Tuple[dd.DataFrame, str]: - """Apply the given function over the 
dataframe, possibly grouped and sorted per group""" - temporary_operand_columns = { - new_temporary_column(df): operand for operand in operands - } - df = df.assign(**temporary_operand_columns) - # Important: move as few bytes as possible to the pickled function, - # which is evaluated on the workers - temporary_operand_columns = temporary_operand_columns.keys() - - # Extract the window definition - lower_bound = to_bound_description(window.getLowerBound()) - upper_bound = to_bound_description(window.getUpperBound()) - - new_column_name = new_temporary_column(df) - - @make_pickable_without_dask_sql - def map_on_each_group(partitioned_group): - # Apply sorting - if sort_columns: - partitioned_group = sort_partition_func( - partitioned_group, sort_columns, sort_ascending, sort_null_first - ) - - if f is None: - # This is the row_number operator. - # We do not need to do any windowing - column_result = range(1, len(partitioned_group) + 1) - else: - # In all other cases, apply the windowing operation - if lower_bound.is_unbounded and ( - upper_bound.is_current_row or upper_bound.offset == 0 - ): - windowed_group = partitioned_group.expanding(min_periods=0) - elif lower_bound.is_preceding and ( - upper_bound.is_current_row or upper_bound.offset == 0 - ): - windowed_group = partitioned_group.rolling( - window=lower_bound.offset + 1, min_periods=0, - ) - else: - lower_offset = ( - lower_bound.offset if not lower_bound.is_current_row else 0 - ) - if lower_bound.is_preceding and lower_offset is not None: - lower_offset *= -1 - upper_offset = ( - upper_bound.offset if not upper_bound.is_current_row else 0 - ) - if upper_bound.is_preceding and upper_offset is not None: - upper_offset *= -1 - - indexer = Indexer(lower_offset, upper_offset) - windowed_group = partitioned_group.rolling( - window=indexer, min_periods=0 - ) - - column_result = f(windowed_group, *temporary_operand_columns) - - partitioned_group = partitioned_group.assign( - **{new_column_name: column_result} - ) - - return partitioned_group - - # Currently, pandas will always return a float for windowing operations - meta = df._meta_nonempty.assign(**{new_column_name: 0.0}) - - df = df.groupby(group_columns).apply(map_on_each_group, meta=meta) - - return df, new_column_name - - def _revert_partition_and_order( - self, - df: dd.DataFrame, - partition_col: str, - index_col: str, - sort_col: str, - known_divisions: Any, - ) -> dd.DataFrame: - """Use the stored information to make revert the shuffling""" - from dask.dataframe.shuffle import set_partition - - divisions = tuple(range(len(known_divisions))) - df = set_partition(df, partition_col, divisions) - df = df.map_partitions( - lambda x: x.set_index(index_col, drop=True).sort_values(sort_col), - meta=df._meta.set_index(index_col), - ) - df.divisions = known_divisions - - return df diff --git a/planner/src/main/java/com/dask/sql/application/RelationalAlgebraGenerator.java b/planner/src/main/java/com/dask/sql/application/RelationalAlgebraGenerator.java index 65ecb20be..5c2fcdf4a 100644 --- a/planner/src/main/java/com/dask/sql/application/RelationalAlgebraGenerator.java +++ b/planner/src/main/java/com/dask/sql/application/RelationalAlgebraGenerator.java @@ -25,6 +25,7 @@ import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.rules.AggregateExpandDistinctAggregatesRule; import org.apache.calcite.rel.rules.AggregateReduceFunctionsRule; +import org.apache.calcite.rel.rules.CoreRules; import org.apache.calcite.rel.rules.FilterAggregateTransposeRule; import 
org.apache.calcite.rel.rules.FilterJoinRule; import org.apache.calcite.rel.rules.FilterMergeRule; @@ -196,24 +197,28 @@ private FrameworkConfig createFrameworkConfig(final SchemaPlus schemaPlus, SqlOp private HepPlanner createHepPlanner(final FrameworkConfig config) { final HepProgram program = new HepProgramBuilder() - .addRuleInstance(AggregateExpandDistinctAggregatesRule.Config.JOIN.toRule()) - .addRuleInstance(FilterAggregateTransposeRule.Config.DEFAULT.toRule()) - .addRuleInstance(FilterJoinRule.JoinConditionPushRule.Config.DEFAULT.toRule()) - .addRuleInstance(FilterJoinRule.FilterIntoJoinRule.Config.DEFAULT.toRule()) - .addRuleInstance(ProjectMergeRule.Config.DEFAULT.toRule()) - .addRuleInstance(FilterMergeRule.Config.DEFAULT.toRule()) - .addRuleInstance(ProjectJoinTransposeRule.Config.DEFAULT.toRule()) + .addRuleInstance(CoreRules.AGGREGATE_PROJECT_MERGE) + .addRuleInstance(CoreRules.AGGREGATE_PROJECT_PULL_UP_CONSTANTS) + .addRuleInstance(CoreRules.AGGREGATE_ANY_PULL_UP_CONSTANTS) + .addRuleInstance(CoreRules.AGGREGATE_REDUCE_FUNCTIONS) + .addRuleInstance(CoreRules.AGGREGATE_MERGE) + .addRuleInstance(CoreRules.AGGREGATE_EXPAND_DISTINCT_AGGREGATES_TO_JOIN) + .addRuleInstance(CoreRules.AGGREGATE_JOIN_REMOVE) + .addRuleInstance(CoreRules.FILTER_AGGREGATE_TRANSPOSE) + .addRuleInstance(CoreRules.JOIN_CONDITION_PUSH) + .addRuleInstance(CoreRules.FILTER_INTO_JOIN) + .addRuleInstance(CoreRules.PROJECT_MERGE) + .addRuleInstance(CoreRules.FILTER_MERGE) + .addRuleInstance(CoreRules.PROJECT_JOIN_TRANSPOSE) // In principle, not a bad idea. But we need to keep the most // outer project - because otherwise the column name information is lost // in cases such as SELECT x AS a, y AS B FROM df // .addRuleInstance(ProjectRemoveRule.Config.DEFAULT.toRule()) - .addRuleInstance(ReduceExpressionsRule.ProjectReduceExpressionsRule.Config.DEFAULT.toRule()) - // this rule might make sense, but turns a < 1 into a SEARCH expression - // which is currently not supported by dask-sql - // .addRuleInstance(ReduceExpressionsRule.FilterReduceExpressionsRule.Config.DEFAULT.toRule()) - .addRuleInstance(FilterRemoveIsNotDistinctFromRule.Config.DEFAULT.toRule()) - // TODO: remove AVG - .addRuleInstance(AggregateReduceFunctionsRule.Config.DEFAULT.toRule()).build(); + .addRuleInstance(CoreRules.PROJECT_REDUCE_EXPRESSIONS) + .addRuleInstance(CoreRules.FILTER_REDUCE_EXPRESSIONS) + .addRuleInstance(CoreRules.FILTER_EXPAND_IS_NOT_DISTINCT_FROM) + .addRuleInstance(CoreRules.PROJECT_TO_LOGICAL_PROJECT_AND_WINDOW) + .build(); return new HepPlanner(program, config.getContext()); } diff --git a/tests/integration/test_filter.py b/tests/integration/test_filter.py index c925687eb..b668d1c51 100644 --- a/tests/integration/test_filter.py +++ b/tests/integration/test_filter.py @@ -23,7 +23,7 @@ def test_filter_scalar(c, df): return_df = return_df.compute() expected_df = df.head(0) - assert_frame_equal(return_df, expected_df) + assert_frame_equal(return_df, expected_df, check_index_type=False) return_df = c.sql("SELECT * FROM df WHERE (1 = 1)") return_df = return_df.compute() @@ -35,7 +35,7 @@ def test_filter_scalar(c, df): return_df = return_df.compute() expected_df = df.head(0) - assert_frame_equal(return_df, expected_df) + assert_frame_equal(return_df, expected_df, check_index_type=False) def test_filter_complicated(c, df): diff --git a/tests/integration/test_over.py b/tests/integration/test_over.py index fbcf478da..9852be6a3 100644 --- a/tests/integration/test_over.py +++ b/tests/integration/test_over.py @@ -3,20 +3,31 @@ from 
pandas.testing import assert_frame_equal +def assert_frame_equal_after_sorting(df, expected_df, columns=None, **kwargs): + columns = columns or ["user_id", "b"] + + df = df.sort_values(columns).reset_index(drop=True) + expected_df = expected_df.sort_values(columns).reset_index(drop=True) + assert_frame_equal(df, expected_df, **kwargs) + + def test_over_with_sorting(c, user_table_1): df = c.sql( """ SELECT user_id, + b, ROW_NUMBER() OVER (ORDER BY user_id, b) AS R FROM user_table_1 """ ) df = df.compute() - expected_df = pd.DataFrame({"user_id": user_table_1.user_id, "R": [3, 1, 2, 4]}) + expected_df = pd.DataFrame( + {"user_id": user_table_1.user_id, "b": user_table_1.b, "R": [3, 1, 2, 4]} + ) expected_df["R"] = expected_df["R"].astype("Int64") - assert_frame_equal(df, expected_df) + assert_frame_equal_after_sorting(df, expected_df, columns=["user_id", "b"]) def test_over_with_partitioning(c, user_table_2): @@ -24,15 +35,18 @@ def test_over_with_partitioning(c, user_table_2): """ SELECT user_id, + c, ROW_NUMBER() OVER (PARTITION BY c) AS R FROM user_table_2 """ ) df = df.compute() - expected_df = pd.DataFrame({"user_id": user_table_2.user_id, "R": [1, 1, 1, 1]}) + expected_df = pd.DataFrame( + {"user_id": user_table_2.user_id, "c": user_table_2.c, "R": [1, 1, 1, 1]} + ) expected_df["R"] = expected_df["R"].astype("Int64") - assert_frame_equal(df, expected_df) + assert_frame_equal_after_sorting(df, expected_df, columns=["user_id", "c"]) def test_over_with_grouping_and_sort(c, user_table_1): @@ -40,15 +54,18 @@ def test_over_with_grouping_and_sort(c, user_table_1): """ SELECT user_id, + b, ROW_NUMBER() OVER (PARTITION BY user_id ORDER BY b) AS R FROM user_table_1 """ ) df = df.compute() - expected_df = pd.DataFrame({"user_id": user_table_1.user_id, "R": [2, 1, 1, 1]}) + expected_df = pd.DataFrame( + {"user_id": user_table_1.user_id, "b": user_table_1.b, "R": [2, 1, 1, 1]} + ) expected_df["R"] = expected_df["R"].astype("Int64") - assert_frame_equal(df, expected_df) + assert_frame_equal_after_sorting(df, expected_df, columns=["user_id", "b"]) def test_over_with_different(c, user_table_1): @@ -56,6 +73,7 @@ def test_over_with_different(c, user_table_1): """ SELECT user_id, + b, ROW_NUMBER() OVER (PARTITION BY user_id ORDER BY b) AS R1, ROW_NUMBER() OVER (ORDER BY user_id, b) AS R2 FROM user_table_1 @@ -64,17 +82,25 @@ def test_over_with_different(c, user_table_1): df = df.compute() expected_df = pd.DataFrame( - {"user_id": user_table_1.user_id, "R1": [2, 1, 1, 1], "R2": [3, 1, 2, 4],} + { + "user_id": user_table_1.user_id, + "b": user_table_1.b, + "R1": [2, 1, 1, 1], + "R2": [3, 1, 2, 4], + } ) for col in ["R1", "R2"]: expected_df[col] = expected_df[col].astype("Int64") - assert_frame_equal(df, expected_df) + + assert_frame_equal_after_sorting(df, expected_df, columns=["user_id", "b"]) -def test_over_calls(c): +def test_over_calls(c, user_table_1): df = c.sql( """ SELECT + user_id, + b, ROW_NUMBER() OVER (PARTITION BY user_id ORDER BY b) AS O1, FIRST_VALUE(user_id*10 - b) OVER (PARTITION BY user_id ORDER BY b) AS O2, SINGLE_VALUE(user_id*10 - b) OVER (PARTITION BY user_id ORDER BY b) AS O3, @@ -92,6 +118,8 @@ def test_over_calls(c): expected_df = pd.DataFrame( { + "user_id": user_table_1.user_id, + "b": user_table_1.b, "O1": [2, 1, 1, 1], "O2": [19, 7, 19, 27], "O3": [19, 7, 19, 27], @@ -105,11 +133,12 @@ def test_over_calls(c): } ) for col in expected_df.columns: - if col in ["06"]: + if col in ["06", "user_id", "b"]: continue expected_df[col] = expected_df[col].astype("Int64") - 
expected_df["O6"] = expected_df["O6"].astype("float64") - assert_frame_equal(df, expected_df) + expected_df["O6"] = expected_df["O6"].astype("Float64") + + assert_frame_equal_after_sorting(df, expected_df, columns=["user_id", "b"]) def test_over_with_windows(c): @@ -119,6 +148,7 @@ def test_over_with_windows(c): df = c.sql( """ SELECT + a, SUM(a) OVER (ORDER BY a ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) AS O1, SUM(a) OVER (ORDER BY a ROWS BETWEEN 2 PRECEDING AND 3 FOLLOWING) AS O2, SUM(a) OVER (ORDER BY a ROWS BETWEEN 2 PRECEDING AND UNBOUNDED FOLLOWING) AS O3, @@ -136,6 +166,7 @@ def test_over_with_windows(c): expected_df = pd.DataFrame( { + "a": df.a, "O1": [0, 1, 3, 6, 9], "O2": [6, 10, 10, 10, 9], "O3": [10, 10, 10, 10, 9], @@ -149,5 +180,8 @@ def test_over_with_windows(c): } ) for col in expected_df.columns: + if col in ["a"]: + continue expected_df[col] = expected_df[col].astype("Int64") - assert_frame_equal(df, expected_df) + + assert_frame_equal_after_sorting(df, expected_df, columns=["a"])