diff --git a/dask_sql/__init__.py b/dask_sql/__init__.py index 566610005..985a7022c 100644 --- a/dask_sql/__init__.py +++ b/dask_sql/__init__.py @@ -1,6 +1,9 @@ from ._version import get_version from .cmd import cmd_loop from .context import Context +from .datacontainer import Statistics from .server.app import run_server __version__ = get_version() + +__all__ = ["__version__", "cmd_loop", "Context", "run_server", "Statistics"] diff --git a/dask_sql/context.py b/dask_sql/context.py index 787aca847..adce9eaf1 100644 --- a/dask_sql/context.py +++ b/dask_sql/context.py @@ -2,7 +2,7 @@ import inspect import logging import warnings -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple, Union +from typing import Any, Callable, Dict, List, Tuple, Union import dask.dataframe as dd import pandas as pd @@ -20,28 +20,16 @@ DataContainer, FunctionDescription, SchemaContainer, + Statistics, ) from dask_sql.input_utils import InputType, InputUtil from dask_sql.integrations.ipython import ipython_integration -from dask_sql.java import ( - DaskAggregateFunction, - DaskScalarFunction, - DaskSchema, - DaskTable, - RelationalAlgebraGenerator, - RelationalAlgebraGeneratorBuilder, - SqlParseException, - ValidationException, - get_java_class, -) +from dask_sql.java import com, get_java_class, org from dask_sql.mappings import python_to_sql_type from dask_sql.physical.rel import RelConverter, custom, logical from dask_sql.physical.rex import RexConverter, core from dask_sql.utils import ParsingException -if TYPE_CHECKING: - from dask_sql.java import org - logger = logging.getLogger(__name__) @@ -90,15 +78,16 @@ def __init__(self): self.sql_server = None # Register any default plugins, if nothing was registered before. - RelConverter.add_plugin_class(logical.LogicalAggregatePlugin, replace=False) - RelConverter.add_plugin_class(logical.LogicalFilterPlugin, replace=False) - RelConverter.add_plugin_class(logical.LogicalJoinPlugin, replace=False) - RelConverter.add_plugin_class(logical.LogicalProjectPlugin, replace=False) - RelConverter.add_plugin_class(logical.LogicalSortPlugin, replace=False) - RelConverter.add_plugin_class(logical.LogicalTableScanPlugin, replace=False) - RelConverter.add_plugin_class(logical.LogicalUnionPlugin, replace=False) - RelConverter.add_plugin_class(logical.LogicalValuesPlugin, replace=False) - RelConverter.add_plugin_class(logical.LogicalWindowPlugin, replace=False) + RelConverter.add_plugin_class(logical.DaskAggregatePlugin, replace=False) + RelConverter.add_plugin_class(logical.DaskFilterPlugin, replace=False) + RelConverter.add_plugin_class(logical.DaskJoinPlugin, replace=False) + RelConverter.add_plugin_class(logical.DaskLimitPlugin, replace=False) + RelConverter.add_plugin_class(logical.DaskProjectPlugin, replace=False) + RelConverter.add_plugin_class(logical.DaskSortPlugin, replace=False) + RelConverter.add_plugin_class(logical.DaskTableScanPlugin, replace=False) + RelConverter.add_plugin_class(logical.DaskUnionPlugin, replace=False) + RelConverter.add_plugin_class(logical.DaskValuesPlugin, replace=False) + RelConverter.add_plugin_class(logical.DaskWindowPlugin, replace=False) RelConverter.add_plugin_class(logical.SamplePlugin, replace=False) RelConverter.add_plugin_class(custom.AnalyzeTablePlugin, replace=False) RelConverter.add_plugin_class(custom.CreateExperimentPlugin, replace=False) @@ -140,6 +129,7 @@ def create_table( format: str = None, persist: bool = False, schema_name: str = None, + statistics: Statistics = None, gpu: bool = False, **kwargs, ): @@ -202,6 +192,11 @@
def create_table( If set to "memory", load the data from a published dataset in the dask cluster. persist (:obj:`bool`): Only used when passing a string into the ``input`` parameter. Set to true to turn on loading the file data directly into memory. + schema_name: (:obj:`str`): in which schema to create the table. By default, will use the currently selected schema. + statistics: (:obj:`Statistics`): if given, use these statistics during the cost-based optimization. If no + statistics are provided, we will just assume 100 rows. + gpu: (:obj:`bool`): if set to true, use dask-cudf to run the data frame calculations on your GPU. + Please note that the GPU support is currently not covering all of dask-sql's SQL language. **kwargs: Additional arguments for specific formats. See :ref:`data_input` for more information. """ @@ -220,6 +215,8 @@ def create_table( **kwargs, ) self.schema[schema_name].tables[table_name.lower()] = dc + if statistics: + self.schema[schema_name].statistics[table_name.lower()] = statistics def register_dask_table(self, df: dd.DataFrame, name: str, *args, **kwargs): """ @@ -765,6 +762,11 @@ def _prepare_schemas(self): """ schema_list = [] + DaskTable = com.dask.sql.schema.DaskTable + DaskAggregateFunction = com.dask.sql.schema.DaskAggregateFunction + DaskScalarFunction = com.dask.sql.schema.DaskScalarFunction + DaskSchema = com.dask.sql.schema.DaskSchema + for schema_name, schema in self.schema.items(): java_schema = DaskSchema(schema_name) @@ -772,7 +774,14 @@ def _prepare_schemas(self): logger.warning("No tables are registered.") for name, dc in schema.tables.items(): - table = DaskTable(name) + row_count = ( + schema.statistics[name].row_count + if name in schema.statistics + else None + ) + if row_count is not None: + row_count = float(row_count) + table = DaskTable(name, row_count) df = dc.df logger.debug( f"Adding table '{name}' to schema with columns: {list(df.columns)}" @@ -824,6 +833,10 @@ def _get_ral(self, sql): # get the schema of what we currently have registered schemas = self._prepare_schemas() + RelationalAlgebraGeneratorBuilder = ( + com.dask.sql.application.RelationalAlgebraGeneratorBuilder + ) + # True if the SQL query should be case sensitive and False otherwise case_sensitive = ( self.schema[self.schema_name] @@ -835,12 +848,16 @@ def _get_ral(self, sql): self.schema_name, case_sensitive ) for schema in schemas: - generator_builder.addSchema(schema) + generator_builder = generator_builder.addSchema(schema) generator = generator_builder.build() default_dialect = generator.getDialect() logger.debug(f"Using dialect: {get_java_class(default_dialect)}") + ValidationException = org.apache.calcite.tools.ValidationException + SqlParseException = org.apache.calcite.sql.parser.SqlParseException + CalciteContextException = org.apache.calcite.runtime.CalciteContextException + try: sqlNode = generator.getSqlNode(sql) sqlNodeClass = get_java_class(sqlNode) @@ -850,8 +867,7 @@ def _get_ral(self, sql): rel_string = "" if not sqlNodeClass.startswith("com.dask.sql.parser."): - validatedSqlNode = generator.getValidatedNode(sqlNode) - nonOptimizedRelNode = generator.getRelationalAlgebra(validatedSqlNode) + nonOptimizedRelNode = generator.getRelationalAlgebra(sqlNode) # Optimization might remove some alias projects. Make sure to keep them here. 
select_names = [ str(name) @@ -859,7 +875,7 @@ ] rel = generator.getOptimizedRelationalAlgebra(nonOptimizedRelNode) rel_string = str(generator.getRelationalAlgebraString(rel)) - except (ValidationException, SqlParseException) as e: + except (ValidationException, SqlParseException, CalciteContextException) as e: logger.debug(f"Original exception raised by Java:\n {e}") # We do not want to re-raise an exception here # as this would print the full java stack trace @@ -895,7 +911,9 @@ def _get_ral(self, sql): def _to_sql_string(self, s: "org.apache.calcite.sql.SqlNode", default_dialect=None): if default_dialect is None: - default_dialect = RelationalAlgebraGenerator.getDialect() + default_dialect = ( + com.dask.sql.application.RelationalAlgebraGenerator.getDialect() + ) try: return str(s.toSqlString(default_dialect)) diff --git a/dask_sql/datacontainer.py b/dask_sql/datacontainer.py index eb2604fc1..f9605625f 100644 --- a/dask_sql/datacontainer.py +++ b/dask_sql/datacontainer.py @@ -231,10 +231,22 @@ def __hash__(self): return (self.func, self.row_udf).__hash__() +class Statistics: + """ + Statistics are used during the cost-based optimization. + Currently, only the row count is supported; more + properties might follow. They need to be provided by the user. + """ + + def __init__(self, row_count: int) -> None: + self.row_count = row_count + + class SchemaContainer: def __init__(self, name: str): self.__name__ = name self.tables: Dict[str, DataContainer] = {} + self.statistics: Dict[str, Statistics] = {} self.experiments: Dict[str, pd.DataFrame] = {} self.models: Dict[str, Tuple[Any, List[str]]] = {} self.functions: Dict[str, UDF] = {} diff --git a/dask_sql/java.py b/dask_sql/java.py index 5101315e2..f4084e7c7 100644 --- a/dask_sql/java.py +++ b/dask_sql/java.py @@ -85,18 +85,6 @@ def _set_or_check_java_home(): org = jpype.JPackage("org") java = jpype.JPackage("java") -DaskTable = com.dask.sql.schema.DaskTable -DaskAggregateFunction = com.dask.sql.schema.DaskAggregateFunction -DaskScalarFunction = com.dask.sql.schema.DaskScalarFunction -DaskSchema = com.dask.sql.schema.DaskSchema -RelationalAlgebraGenerator = com.dask.sql.application.RelationalAlgebraGenerator -RelationalAlgebraGeneratorBuilder = ( - com.dask.sql.application.RelationalAlgebraGeneratorBuilder -) -SqlTypeName = org.apache.calcite.sql.type.SqlTypeName -ValidationException = org.apache.calcite.tools.ValidationException -SqlParseException = org.apache.calcite.sql.parser.SqlParseException - def get_java_class(instance): """Get the stringified class name of a java object""" diff --git a/dask_sql/mappings.py b/dask_sql/mappings.py index 5df6a42bd..09828998c 100644 --- a/dask_sql/mappings.py +++ b/dask_sql/mappings.py @@ -8,10 +8,10 @@ import pandas as pd from dask_sql._compat import FLOAT_NAN_IMPLEMENTED -from dask_sql.java import SqlTypeName +from dask_sql.java import org logger = logging.getLogger(__name__) - +SqlTypeName = org.apache.calcite.sql.type.SqlTypeName # Default mapping between python types and SQL types _PYTHON_TO_SQL = { diff --git a/dask_sql/physical/rel/base.py b/dask_sql/physical/rel/base.py index 20b2ee69b..ee9e65d28 100644 --- a/dask_sql/physical/rel/base.py +++ b/dask_sql/physical/rel/base.py @@ -42,6 +42,8 @@ def fix_column_to_row_type( """ field_names = [str(x) for x in row_type.getFieldNames()] + logger.debug(f"Renaming {cc.columns} to {field_names}") + cc = cc.rename(columns=dict(zip(cc.columns, field_names))) # TODO: We can also check for the types here and do any conversions if needed
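For illustration, a minimal sketch of how the new ``Statistics`` container and the ``statistics`` argument of ``create_table`` introduced above fit together; the table name, data and row count are made-up placeholders, only ``Context``, ``Statistics`` and ``create_table`` come from this change set:

    import dask.dataframe as dd
    import pandas as pd
    from dask_sql import Context, Statistics

    c = Context()
    df = pd.DataFrame({"a": [1, 2, 3], "b": [4.0, 5.0, 6.0]})
    ddf = dd.from_pandas(df, npartitions=1)

    # Passing an (estimated) row count lets the cost-based optimizer work with
    # this value instead of the default assumption of 100 rows.
    c.create_table("my_table", ddf, statistics=Statistics(row_count=len(df)))

    print(c.sql("SELECT SUM(a) FROM my_table").compute())
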
diff --git a/dask_sql/physical/rel/custom/create_experiment.py b/dask_sql/physical/rel/custom/create_experiment.py index 29537b4ff..1bfd27a89 100644 --- a/dask_sql/physical/rel/custom/create_experiment.py +++ b/dask_sql/physical/rel/custom/create_experiment.py @@ -5,12 +5,12 @@ import pandas as pd from dask_sql.datacontainer import ColumnContainer, DataContainer -from dask_sql.java import org from dask_sql.physical.rel.base import BaseRelPlugin from dask_sql.utils import convert_sql_kwargs, import_class if TYPE_CHECKING: import dask_sql + from dask_sql.java import org logger = logging.getLogger(__name__) diff --git a/dask_sql/physical/rel/custom/create_model.py b/dask_sql/physical/rel/custom/create_model.py index c6a56a5d7..5d9c8020c 100644 --- a/dask_sql/physical/rel/custom/create_model.py +++ b/dask_sql/physical/rel/custom/create_model.py @@ -4,12 +4,12 @@ from dask import delayed from dask_sql.datacontainer import DataContainer -from dask_sql.java import org from dask_sql.physical.rel.base import BaseRelPlugin from dask_sql.utils import convert_sql_kwargs, import_class if TYPE_CHECKING: import dask_sql + from dask_sql.java import org logger = logging.getLogger(__name__) diff --git a/dask_sql/physical/rel/logical/__init__.py b/dask_sql/physical/rel/logical/__init__.py index 0cbe20f65..4ce1aa365 100644 --- a/dask_sql/physical/rel/logical/__init__.py +++ b/dask_sql/physical/rel/logical/__init__.py @@ -1,23 +1,25 @@ -from .aggregate import LogicalAggregatePlugin -from .filter import LogicalFilterPlugin -from .join import LogicalJoinPlugin -from .project import LogicalProjectPlugin +from .aggregate import DaskAggregatePlugin +from .filter import DaskFilterPlugin +from .join import DaskJoinPlugin +from .limit import DaskLimitPlugin +from .project import DaskProjectPlugin from .sample import SamplePlugin -from .sort import LogicalSortPlugin -from .table_scan import LogicalTableScanPlugin -from .union import LogicalUnionPlugin -from .values import LogicalValuesPlugin -from .window import LogicalWindowPlugin +from .sort import DaskSortPlugin +from .table_scan import DaskTableScanPlugin +from .union import DaskUnionPlugin +from .values import DaskValuesPlugin +from .window import DaskWindowPlugin __all__ = [ - LogicalAggregatePlugin, - LogicalFilterPlugin, - LogicalJoinPlugin, - LogicalProjectPlugin, - LogicalSortPlugin, - LogicalTableScanPlugin, - LogicalUnionPlugin, - LogicalValuesPlugin, - LogicalWindowPlugin, + DaskAggregatePlugin, + DaskFilterPlugin, + DaskJoinPlugin, + DaskLimitPlugin, + DaskProjectPlugin, + DaskSortPlugin, + DaskTableScanPlugin, + DaskUnionPlugin, + DaskValuesPlugin, + DaskWindowPlugin, SamplePlugin, ] diff --git a/dask_sql/physical/rel/logical/aggregate.py b/dask_sql/physical/rel/logical/aggregate.py index cfdaca82e..d71458eed 100644 --- a/dask_sql/physical/rel/logical/aggregate.py +++ b/dask_sql/physical/rel/logical/aggregate.py @@ -95,9 +95,9 @@ def get_supported_aggregation(self, series): return self.custom_aggregation -class LogicalAggregatePlugin(BaseRelPlugin): +class DaskAggregatePlugin(BaseRelPlugin): """ - A LogicalAggregate is used in GROUP BY clauses, but also + A DaskAggregate is used in GROUP BY clauses, but also when aggregating a function over the full dataset. In the first case we need to find out which columns we need to @@ -119,7 +119,7 @@ class LogicalAggregatePlugin(BaseRelPlugin): these things via HINTs. 
""" - class_name = "org.apache.calcite.rel.logical.LogicalAggregate" + class_name = "com.dask.sql.nodes.DaskAggregate" AGGREGATION_MAPPING = { "$sum0": AggregationSpecification("sum", AggregationOnPandas("sum")), diff --git a/dask_sql/physical/rel/logical/filter.py b/dask_sql/physical/rel/logical/filter.py index 053f50ffa..87c99e3e0 100644 --- a/dask_sql/physical/rel/logical/filter.py +++ b/dask_sql/physical/rel/logical/filter.py @@ -34,13 +34,13 @@ def filter_or_scalar(df: dd.DataFrame, filter_condition: Union[np.bool_, dd.Seri return df[filter_condition] -class LogicalFilterPlugin(BaseRelPlugin): +class DaskFilterPlugin(BaseRelPlugin): """ - LogicalFilter is used on WHERE clauses. + DaskFilter is used on WHERE clauses. We just evaluate the filter (which is of type RexNode) and apply it """ - class_name = "org.apache.calcite.rel.logical.LogicalFilter" + class_name = "com.dask.sql.nodes.DaskFilter" def convert( self, rel: "org.apache.calcite.rel.RelNode", context: "dask_sql.Context" diff --git a/dask_sql/physical/rel/logical/join.py b/dask_sql/physical/rel/logical/join.py index 3b75e8e49..b84373d1f 100644 --- a/dask_sql/physical/rel/logical/join.py +++ b/dask_sql/physical/rel/logical/join.py @@ -20,9 +20,9 @@ logger = logging.getLogger(__name__) -class LogicalJoinPlugin(BaseRelPlugin): +class DaskJoinPlugin(BaseRelPlugin): """ - A LogicalJoin is used when (surprise) joining two tables. + A DaskJoin is used when (surprise) joining two tables. SQL allows for quite complicated joins with difficult conditions. dask/pandas only knows about equijoins on a specific column. @@ -36,7 +36,7 @@ class LogicalJoinPlugin(BaseRelPlugin): but so far, it is the only solution... """ - class_name = "org.apache.calcite.rel.logical.LogicalJoin" + class_name = "com.dask.sql.nodes.DaskJoin" JOIN_TYPE_MAPPING = { "INNER": "inner", diff --git a/dask_sql/physical/rel/logical/limit.py b/dask_sql/physical/rel/logical/limit.py new file mode 100644 index 000000000..468ac6f0d --- /dev/null +++ b/dask_sql/physical/rel/logical/limit.py @@ -0,0 +1,107 @@ +from typing import TYPE_CHECKING + +import dask.dataframe as dd + +from dask_sql.datacontainer import DataContainer +from dask_sql.physical.rel.base import BaseRelPlugin +from dask_sql.physical.rex import RexConverter +from dask_sql.physical.utils.map import map_on_partition_index + +if TYPE_CHECKING: + import dask_sql + from dask_sql.java import org + + +class DaskLimitPlugin(BaseRelPlugin): + """ + Limit is used to only get a certain part of the dataframe + (LIMIT). + """ + + class_name = "com.dask.sql.nodes.DaskLimit" + + def convert( + self, rel: "org.apache.calcite.rel.RelNode", context: "dask_sql.Context" + ) -> DataContainer: + (dc,) = self.assert_inputs(rel, 1, context) + df = dc.df + cc = dc.column_container + + offset = rel.getOffset() + if offset: + offset = RexConverter.convert(offset, df, context=context) + + end = rel.getFetch() + if end: + end = RexConverter.convert(end, df, context=context) + + if offset: + end += offset + + df = self._apply_offset(df, offset, end) + + cc = self.fix_column_to_row_type(cc, rel.getRowType()) + # No column type has changed, so no need to cast again + return DataContainer(df, cc) + + def _apply_offset(self, df: dd.DataFrame, offset: int, end: int) -> dd.DataFrame: + """ + Limit the dataframe to the window [offset, end]. + That is unfortunately, not so simple as we do not know how many + items we have in each partition. We have therefore no other way than to + calculate (!!!) the sizes of each partition. 
+ + After that, we can create a new dataframe from the old + dataframe by calculating for each partition if and how much + it should be used. + We do this via generating our own dask computation graph as + we need to pass the partition number to the selection + function, which is not possible with normal "map_partitions". + """ + df = df.persist() + if not offset: + # We do a (hopefully) very quick check: if the first partition + # is already enough, we will just use this + first_partition_length = len(df.partitions[0]) + if first_partition_length >= end: + return df.head(end, compute=False) + + # First, we need to find out which partitions we want to use. + # Therefore we count the total number of entries + partition_borders = df.map_partitions(lambda x: len(x)) + + # Now we let each of the partitions figure out, how much it needs to return + # using these partition borders + # For this, we generate out own dask computation graph (as it does not really + # fit well with one of the already present methods). + + # (a) we define a method to be calculated on each partition + # This method returns the part of the partition, which falls between [offset, fetch] + # Please note that the dask object "partition_borders", will be turned into + # its pandas representation at this point and we can calculate the cumsum + # (which is not possible on the dask object). Recalculating it should not cost + # us much, as we assume the number of partitions is rather small. + def select_from_to(df, partition_index, partition_borders): + partition_borders = partition_borders.cumsum().to_dict() + this_partition_border_left = ( + partition_borders[partition_index - 1] if partition_index > 0 else 0 + ) + this_partition_border_right = partition_borders[partition_index] + + if (end and end < this_partition_border_left) or ( + offset and offset >= this_partition_border_right + ): + return df.iloc[0:0] + + from_index = max(offset - this_partition_border_left, 0) if offset else 0 + to_index = ( + min(end, this_partition_border_right) + if end + else this_partition_border_right + ) - this_partition_border_left + + return df.iloc[from_index:to_index] + + # (b) Now we just need to apply the function on every partition + # We do this via the delayed interface, which seems the easiest one. + return map_on_partition_index(df, select_from_to, partition_borders) diff --git a/dask_sql/physical/rel/logical/project.py b/dask_sql/physical/rel/logical/project.py index 737e795f2..7e1c71d03 100644 --- a/dask_sql/physical/rel/logical/project.py +++ b/dask_sql/physical/rel/logical/project.py @@ -13,14 +13,14 @@ logger = logging.getLogger(__name__) -class LogicalProjectPlugin(BaseRelPlugin): +class DaskProjectPlugin(BaseRelPlugin): """ - A LogicalProject is used to + A DaskProject is used to (a) apply expressions to the columns and (b) only select a subset of the columns """ - class_name = "org.apache.calcite.rel.logical.LogicalProject" + class_name = "com.dask.sql.nodes.DaskProject" def convert( self, rel: "org.apache.calcite.rel.RelNode", context: "dask_sql.Context" diff --git a/dask_sql/physical/rel/logical/sample.py b/dask_sql/physical/rel/logical/sample.py index f78d46f25..87db98cb1 100644 --- a/dask_sql/physical/rel/logical/sample.py +++ b/dask_sql/physical/rel/logical/sample.py @@ -32,7 +32,7 @@ class SamplePlugin(BaseRelPlugin): the expected. 
""" - class_name = "org.apache.calcite.rel.core.Sample" + class_name = "com.dask.sql.nodes.DaskSample" def convert( self, rel: "org.apache.calcite.rel.RelNode", context: "dask_sql.Context" diff --git a/dask_sql/physical/rel/logical/sort.py b/dask_sql/physical/rel/logical/sort.py index f10fc2531..94101bca9 100644 --- a/dask_sql/physical/rel/logical/sort.py +++ b/dask_sql/physical/rel/logical/sort.py @@ -1,26 +1,20 @@ from typing import TYPE_CHECKING -import dask.dataframe as dd - from dask_sql.datacontainer import DataContainer from dask_sql.java import org from dask_sql.physical.rel.base import BaseRelPlugin -from dask_sql.physical.rex import RexConverter -from dask_sql.physical.utils.map import map_on_partition_index from dask_sql.physical.utils.sort import apply_sort if TYPE_CHECKING: import dask_sql -class LogicalSortPlugin(BaseRelPlugin): +class DaskSortPlugin(BaseRelPlugin): """ - LogicalSort is used to sort by columns (ORDER BY) - as well as to only get a certain part of the dataframe - (LIMIT). + DaskSort is used to sort by columns (ORDER BY). """ - class_name = "org.apache.calcite.rel.logical.LogicalSort" + class_name = "com.dask.sql.nodes.DaskSort" def convert( self, rel: "org.apache.calcite.rel.RelNode", context: "dask_sql.Context" @@ -35,91 +29,14 @@ def convert( for x in sort_collation ] - if sort_columns: - ASCENDING = org.apache.calcite.rel.RelFieldCollation.Direction.ASCENDING - FIRST = org.apache.calcite.rel.RelFieldCollation.NullDirection.FIRST - sort_ascending = [x.getDirection() == ASCENDING for x in sort_collation] - sort_null_first = [x.nullDirection == FIRST for x in sort_collation] - - df = df.persist() - df = apply_sort(df, sort_columns, sort_ascending, sort_null_first) - - offset = rel.offset - if offset: - offset = RexConverter.convert(offset, df, context=context) - - end = rel.fetch - if end: - end = RexConverter.convert(end, df, context=context) + ASCENDING = org.apache.calcite.rel.RelFieldCollation.Direction.ASCENDING + FIRST = org.apache.calcite.rel.RelFieldCollation.NullDirection.FIRST + sort_ascending = [x.getDirection() == ASCENDING for x in sort_collation] + sort_null_first = [x.nullDirection == FIRST for x in sort_collation] - if offset: - end += offset - - if offset is not None or end is not None: - df = self._apply_offset(df, offset, end) + df = df.persist() + df = apply_sort(df, sort_columns, sort_ascending, sort_null_first) cc = self.fix_column_to_row_type(cc, rel.getRowType()) # No column type has changed, so no need to cast again return DataContainer(df, cc) - - def _apply_offset(self, df: dd.DataFrame, offset: int, end: int) -> dd.DataFrame: - """ - Limit the dataframe to the window [offset, end]. - That is unfortunately, not so simple as we do not know how many - items we have in each partition. We have therefore no other way than to - calculate (!!!) the sizes of each partition. - - After that, we can create a new dataframe from the old - dataframe by calculating for each partition if and how much - it should be used. - We do this via generating our own dask computation graph as - we need to pass the partition number to the selection - function, which is not possible with normal "map_partitions". - """ - df = df.persist() - if not offset: - # We do a (hopefully) very quick check: if the first partition - # is already enough, we will just use this - first_partition_length = len(df.partitions[0]) - if first_partition_length >= end: - return df.head(end, compute=False) - - # First, we need to find out which partitions we want to use. 
- # Therefore we count the total number of entries - partition_borders = df.map_partitions(lambda x: len(x)) - - # Now we let each of the partitions figure out, how much it needs to return - # using these partition borders - # For this, we generate out own dask computation graph (as it does not really - # fit well with one of the already present methods). - - # (a) we define a method to be calculated on each partition - # This method returns the part of the partition, which falls between [offset, fetch] - # Please note that the dask object "partition_borders", will be turned into - # its pandas representation at this point and we can calculate the cumsum - # (which is not possible on the dask object). Recalculating it should not cost - # us much, as we assume the number of partitions is rather small. - def select_from_to(df, partition_index, partition_borders): - partition_borders = partition_borders.cumsum().to_dict() - this_partition_border_left = ( - partition_borders[partition_index - 1] if partition_index > 0 else 0 - ) - this_partition_border_right = partition_borders[partition_index] - - if (end and end < this_partition_border_left) or ( - offset and offset >= this_partition_border_right - ): - return df.iloc[0:0] - - from_index = max(offset - this_partition_border_left, 0) if offset else 0 - to_index = ( - min(end, this_partition_border_right) - if end - else this_partition_border_right - ) - this_partition_border_left - - return df.iloc[from_index:to_index] - - # (b) Now we just need to apply the function on every partition - # We do this via the delayed interface, which seems the easiest one. - return map_on_partition_index(df, select_from_to, partition_borders) diff --git a/dask_sql/physical/rel/logical/table_scan.py b/dask_sql/physical/rel/logical/table_scan.py index 378858a05..8453ab1f8 100644 --- a/dask_sql/physical/rel/logical/table_scan.py +++ b/dask_sql/physical/rel/logical/table_scan.py @@ -8,9 +8,9 @@ from dask_sql.java import org -class LogicalTableScanPlugin(BaseRelPlugin): +class DaskTableScanPlugin(BaseRelPlugin): """ - A LogicalTableScal is the main ingredient: it will get the data + A DaskTableScal is the main ingredient: it will get the data from the database. It is always used, when the SQL looks like SELECT .... FROM table .... @@ -20,7 +20,7 @@ class LogicalTableScanPlugin(BaseRelPlugin): Calcite will always refer to columns via index. """ - class_name = "org.apache.calcite.rel.logical.LogicalTableScan" + class_name = "com.dask.sql.nodes.DaskTableScan" def convert( self, rel: "org.apache.calcite.rel.RelNode", context: "dask_sql.Context" diff --git a/dask_sql/physical/rel/logical/union.py b/dask_sql/physical/rel/logical/union.py index b8153fc28..39e96b861 100644 --- a/dask_sql/physical/rel/logical/union.py +++ b/dask_sql/physical/rel/logical/union.py @@ -10,13 +10,13 @@ from dask_sql.java import org -class LogicalUnionPlugin(BaseRelPlugin): +class DaskUnionPlugin(BaseRelPlugin): """ - LogicalUnion is used on UNION clauses. + DaskUnion is used on UNION clauses. It just concatonates the two data frames. 
""" - class_name = "org.apache.calcite.rel.logical.LogicalUnion" + class_name = "com.dask.sql.nodes.DaskUnion" def convert( self, rel: "org.apache.calcite.rel.RelNode", context: "dask_sql.Context" diff --git a/dask_sql/physical/rel/logical/values.py b/dask_sql/physical/rel/logical/values.py index 1d8363201..ca95375c9 100644 --- a/dask_sql/physical/rel/logical/values.py +++ b/dask_sql/physical/rel/logical/values.py @@ -12,9 +12,9 @@ from dask_sql.java import org -class LogicalValuesPlugin(BaseRelPlugin): +class DaskValuesPlugin(BaseRelPlugin): """ - A LogicalValue is a table just consisting of + A DaskValue is a table just consisting of raw values (nothing database-dependent). For example @@ -26,7 +26,7 @@ class LogicalValuesPlugin(BaseRelPlugin): data samples. """ - class_name = "org.apache.calcite.rel.logical.LogicalValues" + class_name = "com.dask.sql.nodes.DaskValues" def convert( self, rel: "org.apache.calcite.rel.RelNode", context: "dask_sql.Context" diff --git a/dask_sql/physical/rel/logical/window.py b/dask_sql/physical/rel/logical/window.py index 7826d7a86..b4d4ee815 100644 --- a/dask_sql/physical/rel/logical/window.py +++ b/dask_sql/physical/rel/logical/window.py @@ -206,9 +206,9 @@ def map_on_each_group( return partitioned_group -class LogicalWindowPlugin(BaseRelPlugin): +class DaskWindowPlugin(BaseRelPlugin): """ - A LogicalWindow is an expression, which calculates a given function over the dataframe + A DaskWindow is an expression, which calculates a given function over the dataframe while first optionally partitoning the data and optionally sorting it. Expressions like `F OVER (PARTITION BY x ORDER BY y)` apply f on each @@ -217,7 +217,7 @@ class LogicalWindowPlugin(BaseRelPlugin): Typical examples include ROW_NUMBER and lagging. """ - class_name = "org.apache.calcite.rel.logical.LogicalWindow" + class_name = "com.dask.sql.nodes.DaskWindow" OPERATION_MAPPING = { "row_number": None, # That is the easiest one: we do not even need to have any windowing. We therefore threat it separately @@ -249,7 +249,7 @@ def convert( # Output to the right field names right away field_names = rel.getRowType().getFieldNames() - for window in rel.groups: + for window in rel.getGroups(): dc = self._apply_window( window, constants, constant_count_offset, dc, field_names, context ) @@ -283,7 +283,7 @@ def _apply_window( window, cc ) logger.debug( - "Before applying the function, sorting according to {sort_columns}." + f"Before applying the function, sorting according to {sort_columns}." ) df, group_columns = self._extract_groupby(df, window, dc, context) @@ -299,6 +299,8 @@ def _apply_window( newly_created_columns = [new_column for _, new_column, _ in operations] + logger.debug(f"Will create {newly_created_columns} new columns") + # Apply the windowing operation filled_map = partial( map_on_each_group, @@ -320,6 +322,9 @@ def _apply_window( df = df.groupby(group_columns).apply( make_pickable_without_dask_sql(filled_map), meta=meta ) + logger.debug( + f"Having created a dataframe {LoggableDataFrame(df)} after windowing. Will now drop {temporary_columns}." + ) df = df.drop(columns=temporary_columns).reset_index(drop=True) dc = DataContainer(df, cc) @@ -332,6 +337,9 @@ def _apply_window( cc = cc.add(field_name, c) dc = DataContainer(df, cc) + logger.debug( + f"Removed unneeded columns and registered new ones: {LoggableDataFrame(dc)}." 
+ ) return dc def _extract_groupby( diff --git a/dask_sql/physical/rex/base.py b/dask_sql/physical/rex/base.py index 530449143..c5f6097da 100644 --- a/dask_sql/physical/rex/base.py +++ b/dask_sql/physical/rex/base.py @@ -1,10 +1,12 @@ -from typing import Any, Union +from typing import TYPE_CHECKING, Any, Union import dask.dataframe as dd import dask_sql from dask_sql.datacontainer import DataContainer -from dask_sql.java import org + +if TYPE_CHECKING: + from dask_sql.java import org class BaseRexPlugin: @@ -20,7 +22,7 @@ class BaseRexPlugin: def convert( self, - rex: org.apache.calcite.rex.RexNode, + rex: "org.apache.calcite.rex.RexNode", dc: DataContainer, context: "dask_sql.Context", ) -> Union[dd.Series, Any]: diff --git a/dask_sql/physical/rex/core/call.py b/dask_sql/physical/rex/core/call.py index 4b042ede0..48b9a73c3 100644 --- a/dask_sql/physical/rex/core/call.py +++ b/dask_sql/physical/rex/core/call.py @@ -1,3 +1,4 @@ +import datetime import logging import operator import re @@ -148,6 +149,32 @@ def div(self, lhs, rhs, rex=None): return result +class IntDivisionOperator(Operation): + """ + Truncated integer division (so -1 / 2 = 0). + This is only used for internal calculations, + which are created by Calcite. + """ + + def __init__(self): + super().__init__(self.div) + + def div(self, lhs, rhs): + result = lhs / rhs + + # Specialized code for literals like "1000µs" + # For some reasons, Calcite decides to represent + # 1000µs as 1000µs * 1000 / 1000 + # We do not need to truncate in this case + # So far, I did not spot any other occurrence + # of this function. + if isinstance(result, datetime.timedelta): + return result + else: # pragma: no cover + result = da.trunc(result) + return result + + class CaseOperation(Operation): """The case operator (basically an if then else)""" @@ -193,7 +220,7 @@ def __init__(self): super().__init__(self.cast) def cast(self, operand, rex=None) -> SeriesOrScalar: - if not is_frame(operand): + if not is_frame(operand): # pragma: no cover return operand output_type = str(rex.getType()) @@ -701,6 +728,7 @@ class RexCallPlugin(BaseRexPlugin): "*": ReduceOperation(operation=operator.mul), "is distinct from": NotOperation().of(IsNotDistinctOperation()), "is not distinct from": IsNotDistinctOperation(), + "/int": IntDivisionOperator(), # special operations "cast": CastOperation(), "case": CaseOperation(), diff --git a/dask_sql/utils.py b/dask_sql/utils.py index a67b4deda..61c5c724d 100644 --- a/dask_sql/utils.py +++ b/dask_sql/utils.py @@ -182,6 +182,8 @@ def __str__(self): df = self.df if isinstance(df, pd.Series) or isinstance(df, dd.Series): return f"Series: {(df.name, df.dtype)}" + if isinstance(df, pd.DataFrame) or isinstance(df, dd.DataFrame): + return f"DataFrame: {[(col, dtype) for col, dtype in zip(df.columns, df.dtypes)]}" elif isinstance(df, DataContainer): cols = df.column_container.columns diff --git a/planner/src/main/java/com/dask/sql/application/DaskPlanner.java b/planner/src/main/java/com/dask/sql/application/DaskPlanner.java new file mode 100644 index 000000000..47e7599b8 --- /dev/null +++ b/planner/src/main/java/com/dask/sql/application/DaskPlanner.java @@ -0,0 +1,77 @@ +package com.dask.sql.application; + +import com.dask.sql.rules.DaskAggregateRule; +import com.dask.sql.rules.DaskFilterRule; +import com.dask.sql.rules.DaskJoinRule; +import com.dask.sql.rules.DaskProjectRule; +import com.dask.sql.rules.DaskSampleRule; +import com.dask.sql.rules.DaskSortLimitRule; +import com.dask.sql.rules.DaskTableScanRule; +import 
com.dask.sql.rules.DaskUnionRule; +import com.dask.sql.rules.DaskValuesRule; +import com.dask.sql.rules.DaskWindowRule; + +import org.apache.calcite.plan.ConventionTraitDef; +import org.apache.calcite.plan.volcano.VolcanoPlanner; +import org.apache.calcite.rel.rules.CoreRules; +import org.apache.calcite.rel.rules.DateRangeRules; +import org.apache.calcite.rel.rules.JoinPushThroughJoinRule; +import org.apache.calcite.rel.rules.PruneEmptyRules; + +/** + * DaskPlanner is a cost-based optimizer based on the Calcite VolcanoPlanner. + * + * Its only difference to the raw Volcano planner, are the predefined rules (for + * converting logical into dask nodes and some basic core rules so far), as well + * as the null executor. + */ +public class DaskPlanner extends VolcanoPlanner { + public DaskPlanner() { + // Allow transformation between logical and dask nodes + addRule(DaskAggregateRule.INSTANCE); + addRule(DaskFilterRule.INSTANCE); + addRule(DaskJoinRule.INSTANCE); + addRule(DaskProjectRule.INSTANCE); + addRule(DaskSampleRule.INSTANCE); + addRule(DaskSortLimitRule.INSTANCE); + addRule(DaskTableScanRule.INSTANCE); + addRule(DaskUnionRule.INSTANCE); + addRule(DaskValuesRule.INSTANCE); + addRule(DaskWindowRule.INSTANCE); + + // Set of core rules + addRule(PruneEmptyRules.UNION_INSTANCE); + addRule(PruneEmptyRules.INTERSECT_INSTANCE); + addRule(PruneEmptyRules.MINUS_INSTANCE); + addRule(PruneEmptyRules.PROJECT_INSTANCE); + addRule(PruneEmptyRules.FILTER_INSTANCE); + addRule(PruneEmptyRules.SORT_INSTANCE); + addRule(PruneEmptyRules.AGGREGATE_INSTANCE); + addRule(PruneEmptyRules.JOIN_LEFT_INSTANCE); + addRule(PruneEmptyRules.JOIN_RIGHT_INSTANCE); + addRule(PruneEmptyRules.SORT_FETCH_ZERO_INSTANCE); + addRule(DateRangeRules.FILTER_INSTANCE); + addRule(CoreRules.INTERSECT_TO_DISTINCT); + addRule(CoreRules.PROJECT_FILTER_TRANSPOSE); + addRule(CoreRules.FILTER_PROJECT_TRANSPOSE); + addRule(CoreRules.FILTER_INTO_JOIN); + addRule(CoreRules.JOIN_CONDITION_PUSH); + addRule(CoreRules.JOIN_PUSH_EXPRESSIONS); + addRule(CoreRules.FILTER_AGGREGATE_TRANSPOSE); + addRule(CoreRules.PROJECT_WINDOW_TRANSPOSE); + addRule(CoreRules.JOIN_COMMUTE); + addRule(CoreRules.FILTER_INTO_JOIN); + addRule(CoreRules.PROJECT_JOIN_TRANSPOSE); + addRule(JoinPushThroughJoinRule.RIGHT); + addRule(JoinPushThroughJoinRule.LEFT); + addRule(CoreRules.SORT_PROJECT_TRANSPOSE); + addRule(CoreRules.SORT_JOIN_TRANSPOSE); + addRule(CoreRules.SORT_UNION_TRANSPOSE); + + // Enable conventions to turn from logical to dask + addRelTraitDef(ConventionTraitDef.INSTANCE); + + // We do not want to execute any SQL + setExecutor(null); + } +} diff --git a/planner/src/main/java/com/dask/sql/application/DaskProgram.java b/planner/src/main/java/com/dask/sql/application/DaskProgram.java new file mode 100644 index 000000000..054a6e7c4 --- /dev/null +++ b/planner/src/main/java/com/dask/sql/application/DaskProgram.java @@ -0,0 +1,151 @@ +package com.dask.sql.application; + +import java.util.Arrays; +import java.util.List; + +import com.dask.sql.nodes.DaskRel; + +import org.apache.calcite.plan.RelOptLattice; +import org.apache.calcite.plan.RelOptMaterialization; +import org.apache.calcite.plan.RelOptPlanner; +import org.apache.calcite.plan.RelOptRule; +import org.apache.calcite.plan.RelTraitSet; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.RelFactories; +import org.apache.calcite.rel.metadata.DefaultRelMetadataProvider; +import org.apache.calcite.rel.rules.CoreRules; +import org.apache.calcite.sql2rel.RelDecorrelator; +import 
org.apache.calcite.sql2rel.RelFieldTrimmer; +import org.apache.calcite.tools.Program; +import org.apache.calcite.tools.Programs; +import org.apache.calcite.tools.RelBuilder; + +/** + * DaskProgram is the optimization program which is executed on a tree of + * relational algebras. + * + * It consists of five steps: decorrelation, trimming (removing unneeded + * fields), executing a fixed set of rules, converting into the correct trait + * (which means in our case: from logical to dask) and finally a cost-based + * optimization. + */ +public class DaskProgram { + private final Program mainProgram; + + public DaskProgram(RelOptPlanner planner) { + final DecorrelateProgram decorrelateProgram = new DecorrelateProgram(); + final TrimFieldsProgram trimProgram = new TrimFieldsProgram(); + final FixedRulesProgram fixedRulesProgram = new FixedRulesProgram(); + final ConvertProgram convertProgram = new ConvertProgram(planner); + final CostBasedOptimizationProgram costBasedOptimizationProgram = new CostBasedOptimizationProgram(planner); + + this.mainProgram = Programs.sequence(decorrelateProgram, trimProgram, fixedRulesProgram, convertProgram, + costBasedOptimizationProgram); + } + + public RelNode run(RelNode rel) { + final RelTraitSet desiredTraits = rel.getTraitSet().replace(DaskRel.CONVENTION).simplify(); + return this.mainProgram.run(null, rel, desiredTraits, Arrays.asList(), Arrays.asList()); + } + + /** + * DaskProgramWrapper is a helper for auto-filling unneeded arguments in + * Programs + */ + private static interface DaskProgramWrapper extends Program { + public RelNode run(RelNode rel, RelTraitSet relTraitSet); + + @Override + public default RelNode run(RelOptPlanner planner, RelNode rel, RelTraitSet requiredOutputTraits, + List materializations, List lattices) { + return run(rel, requiredOutputTraits); + } + } + + /** + * DecorrelateProgram decorrelates a query, by tunring them into e.g. 
JOINs + */ + private static class DecorrelateProgram implements DaskProgramWrapper { + @Override + public RelNode run(RelNode rel, RelTraitSet relTraitSet) { + final RelBuilder relBuilder = RelFactories.LOGICAL_BUILDER.create(rel.getCluster(), null); + return RelDecorrelator.decorrelateQuery(rel, relBuilder); + } + } + + /** + * TrimFieldsProgram removes unneeded fields from the REL steps + */ + private static class TrimFieldsProgram implements DaskProgramWrapper { + public RelNode run(RelNode rel, RelTraitSet relTraitSet) { + final RelBuilder relBuilder = RelFactories.LOGICAL_BUILDER.create(rel.getCluster(), null); + return new RelFieldTrimmer(null, relBuilder).trim(rel); + } + } + + /** + * FixedRulesProgram applies a fixed set of conversion rules, which we always + */ + private static class FixedRulesProgram implements Program { + static private final List RULES = Arrays.asList(CoreRules.AGGREGATE_PROJECT_PULL_UP_CONSTANTS, + CoreRules.AGGREGATE_ANY_PULL_UP_CONSTANTS, CoreRules.AGGREGATE_PROJECT_MERGE, + CoreRules.AGGREGATE_REDUCE_FUNCTIONS, CoreRules.AGGREGATE_MERGE, + CoreRules.AGGREGATE_EXPAND_DISTINCT_AGGREGATES_TO_JOIN, CoreRules.AGGREGATE_JOIN_REMOVE, + CoreRules.PROJECT_MERGE, CoreRules.FILTER_MERGE, CoreRules.PROJECT_REMOVE, + CoreRules.PROJECT_REDUCE_EXPRESSIONS, CoreRules.FILTER_REDUCE_EXPRESSIONS, + CoreRules.FILTER_EXPAND_IS_NOT_DISTINCT_FROM, CoreRules.PROJECT_TO_LOGICAL_PROJECT_AND_WINDOW); + + @Override + public RelNode run(RelOptPlanner planner, RelNode rel, RelTraitSet requiredOutputTraits, + List materializations, List lattices) { + + Program fixedRulesProgram = Programs.hep(RULES, true, DefaultRelMetadataProvider.INSTANCE); + return fixedRulesProgram.run(planner, rel, requiredOutputTraits, materializations, lattices); + } + } + + /** + * ConvertProgram marks the rel as "to-be-converted-into-the-dask-trait" by the + * upcoming volcano planner. + */ + private static class ConvertProgram implements DaskProgramWrapper { + private RelOptPlanner planner; + + public ConvertProgram(RelOptPlanner planner) { + this.planner = planner; + } + + @Override + public RelNode run(final RelNode rel, final RelTraitSet requiredOutputTraits) { + planner.setRoot(rel); + + if (rel.getTraitSet().equals(requiredOutputTraits)) { + return rel; + } + final RelNode convertedRel = planner.changeTraits(rel, requiredOutputTraits); + assert convertedRel != null; + return convertedRel; + } + } + + /** + * CostBasedOptimizationProgram applies a cost-based optimization, which can + * take the size of the inputs into account (not implemented now). 
+ */ + private static class CostBasedOptimizationProgram implements DaskProgramWrapper { + private RelOptPlanner planner; + + public CostBasedOptimizationProgram(RelOptPlanner planner) { + this.planner = planner; + } + + @Override + public RelNode run(RelNode rel, RelTraitSet requiredOutputTraits) { + planner.setRoot(rel); + final RelOptPlanner planner2 = planner.chooseDelegate(); + final RelNode optimizedRel = planner2.findBestExp(); + assert optimizedRel != null : "could not implement exp"; + return optimizedRel; + } + } +} diff --git a/planner/src/main/java/com/dask/sql/application/DaskSqlDialect.java b/planner/src/main/java/com/dask/sql/application/DaskSqlDialect.java index e2254bb15..88a4aab2e 100644 --- a/planner/src/main/java/com/dask/sql/application/DaskSqlDialect.java +++ b/planner/src/main/java/com/dask/sql/application/DaskSqlDialect.java @@ -9,6 +9,10 @@ import org.apache.calcite.sql.dialect.PostgresqlSqlDialect; import org.apache.calcite.sql.type.SqlTypeName; +/** + * DaskSqlDialect is the specific SQL dialect we use in dask-sql. It is mainly a + * postgreSQL dialect with a bit tuning. + */ public class DaskSqlDialect { public static final RelDataTypeSystem DASKSQL_TYPE_SYSTEM = new RelDataTypeSystemImpl() { @Override diff --git a/planner/src/main/java/com/dask/sql/application/DaskSqlParser.java b/planner/src/main/java/com/dask/sql/application/DaskSqlParser.java new file mode 100644 index 000000000..14dda852f --- /dev/null +++ b/planner/src/main/java/com/dask/sql/application/DaskSqlParser.java @@ -0,0 +1,26 @@ +package com.dask.sql.application; + +import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.parser.SqlParseException; +import org.apache.calcite.sql.parser.SqlParser; +import org.apache.calcite.sql.validate.SqlConformanceEnum; + +/** + * DaskSqlParser can turn a SQL string into a tree of SqlNodes. It uses the + * SqlParser from calcite for this. + */ +public class DaskSqlParser { + private SqlParser.Config DEFAULT_CONFIG; + + public DaskSqlParser() { + DEFAULT_CONFIG = DaskSqlDialect.DEFAULT.configureParser(SqlParser.Config.DEFAULT) + .withConformance(SqlConformanceEnum.DEFAULT) + .withParserFactory(new DaskSqlParserImplFactory()); + } + + public SqlNode parse(String sql) throws SqlParseException { + final SqlParser parser = SqlParser.create(sql, DEFAULT_CONFIG); + final SqlNode sqlNode = parser.parseStmt(); + return sqlNode; + } +} diff --git a/planner/src/main/java/com/dask/sql/application/DaskSqlParserImplFactory.java b/planner/src/main/java/com/dask/sql/application/DaskSqlParserImplFactory.java index 249e8ac43..2fe143f5b 100644 --- a/planner/src/main/java/com/dask/sql/application/DaskSqlParserImplFactory.java +++ b/planner/src/main/java/com/dask/sql/application/DaskSqlParserImplFactory.java @@ -6,6 +6,10 @@ import org.apache.calcite.sql.parser.SqlParserImplFactory; import com.dask.sql.parser.DaskSqlParserImpl; +/** + * DaskSqlParserImplFactory is the bridge between the Java code written by us + * and the code generated by the freetype code template. 
+ */ public class DaskSqlParserImplFactory implements SqlParserImplFactory { @Override diff --git a/planner/src/main/java/com/dask/sql/application/DaskSqlToRelConverter.java b/planner/src/main/java/com/dask/sql/application/DaskSqlToRelConverter.java new file mode 100644 index 000000000..6f825880d --- /dev/null +++ b/planner/src/main/java/com/dask/sql/application/DaskSqlToRelConverter.java @@ -0,0 +1,135 @@ +package com.dask.sql.application; + +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; +import java.util.Properties; + +import com.dask.sql.schema.DaskSchema; + +import org.apache.calcite.adapter.java.JavaTypeFactory; +import org.apache.calcite.config.CalciteConnectionConfig; +import org.apache.calcite.config.CalciteConnectionConfigImpl; +import org.apache.calcite.jdbc.CalciteConnection; +import org.apache.calcite.jdbc.CalciteSchema; +import org.apache.calcite.jdbc.JavaTypeFactoryImpl; +import org.apache.calcite.plan.RelOptCluster; +import org.apache.calcite.plan.RelOptPlanner; +import org.apache.calcite.prepare.CalciteCatalogReader; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.schema.SchemaPlus; +import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.SqlOperatorTable; +import org.apache.calcite.sql.fun.SqlLibrary; +import org.apache.calcite.sql.fun.SqlLibraryOperatorTableFactory; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.util.SqlOperatorTables; +import org.apache.calcite.sql.validate.SqlValidator; +import org.apache.calcite.sql.validate.SqlValidatorImpl; +import org.apache.calcite.sql2rel.SqlToRelConverter; +import org.apache.calcite.sql2rel.StandardConvertletTable; + +/** + * DaskSqlToRelConverter turns a tree of SqlNodes into a first, suboptimal + * version of relational algebra nodes. It needs to know about all the tables + * and schemas, therefore the configuration looks a bit more complicated. 
+ */ +public class DaskSqlToRelConverter { + private final SqlToRelConverter sqlToRelConverter; + private final boolean caseSensitive; + + public DaskSqlToRelConverter(final RelOptPlanner optimizer, final String rootSchemaName, + final List schemas, boolean caseSensitive) throws SQLException { + this.caseSensitive = caseSensitive; + final SchemaPlus rootSchema = createRootSchema(rootSchemaName, schemas); + + final JavaTypeFactoryImpl typeFactory = createTypeFactory(); + final CalciteCatalogReader calciteCatalogReader = createCatalogReader(rootSchemaName, rootSchema, typeFactory, this.caseSensitive); + final SqlValidator validator = createValidator(typeFactory, calciteCatalogReader); + final RelOptCluster cluster = RelOptCluster.create(optimizer, new RexBuilder(typeFactory)); + final SqlToRelConverter.Config config = SqlToRelConverter.config().withTrimUnusedFields(true).withExpand(true); + + this.sqlToRelConverter = new SqlToRelConverter(null, validator, calciteCatalogReader, cluster, + StandardConvertletTable.INSTANCE, config); + } + + public RelNode convert(SqlNode sqlNode) { + RelNode root = sqlToRelConverter.convertQuery(sqlNode, true, true).project(true); + return root; + } + + private JavaTypeFactoryImpl createTypeFactory() { + return new JavaTypeFactoryImpl(DaskSqlDialect.DASKSQL_TYPE_SYSTEM); + } + + private SqlValidator createValidator(final JavaTypeFactoryImpl typeFactory, + final CalciteCatalogReader calciteCatalogReader) { + final SqlOperatorTable operatorTable = createOperatorTable(calciteCatalogReader); + final CalciteConnectionConfig connectionConfig = calciteCatalogReader.getConfig(); + final SqlValidator.Config validatorConfig = SqlValidator.Config.DEFAULT + .withLenientOperatorLookup(connectionConfig.lenientOperatorLookup()) + .withSqlConformance(connectionConfig.conformance()) + .withDefaultNullCollation(connectionConfig.defaultNullCollation()).withIdentifierExpansion(true); + final SqlValidator validator = new CalciteSqlValidator(operatorTable, calciteCatalogReader, typeFactory, + validatorConfig); + return validator; + } + + private SchemaPlus createRootSchema(final String rootSchemaName, final List schemas) + throws SQLException { + final CalciteConnection calciteConnection = createConnection(rootSchemaName); + final SchemaPlus rootSchema = calciteConnection.getRootSchema(); + for (DaskSchema schema : schemas) { + rootSchema.add(schema.getName(), schema); + } + return rootSchema; + } + + private CalciteConnection createConnection(final String schemaName) throws SQLException { + // Taken from https://calcite.apache.org/docs/ + final Properties info = new Properties(); + info.setProperty("lex", "JAVA"); + + final Connection connection = DriverManager.getConnection("jdbc:calcite:", info); + final CalciteConnection calciteConnection = connection.unwrap(CalciteConnection.class); + + calciteConnection.setSchema(schemaName); + return calciteConnection; + } + + private CalciteCatalogReader createCatalogReader(final String schemaName, final SchemaPlus schemaPlus, + final JavaTypeFactoryImpl typeFactory, final boolean caseSensitive) throws SQLException { + final CalciteSchema calciteSchema = CalciteSchema.from(schemaPlus); + + final Properties props = new Properties(); + props.setProperty("defaultSchema", schemaName); + props.setProperty("caseSensitive", String.valueOf(caseSensitive)); + + final List defaultSchema = new ArrayList(); + defaultSchema.add(schemaName); + + final CalciteCatalogReader calciteCatalogReader = new CalciteCatalogReader(calciteSchema, defaultSchema, + 
typeFactory, new CalciteConnectionConfigImpl(props)); + return calciteCatalogReader; + } + + private SqlOperatorTable createOperatorTable(final CalciteCatalogReader calciteCatalogReader) { + final List sqlOperatorTables = new ArrayList<>(); + sqlOperatorTables.add(SqlStdOperatorTable.instance()); + sqlOperatorTables.add(SqlLibraryOperatorTableFactory.INSTANCE.getOperatorTable(SqlLibrary.POSTGRESQL)); + sqlOperatorTables.add(calciteCatalogReader); + + SqlOperatorTable operatorTable = SqlOperatorTables.chain(sqlOperatorTables); + return operatorTable; + } + + static class CalciteSqlValidator extends SqlValidatorImpl { + CalciteSqlValidator(SqlOperatorTable opTab, CalciteCatalogReader catalogReader, JavaTypeFactory typeFactory, + Config config) { + super(opTab, catalogReader, typeFactory, config); + } + } +} diff --git a/planner/src/main/java/com/dask/sql/application/RelationalAlgebraGenerator.java b/planner/src/main/java/com/dask/sql/application/RelationalAlgebraGenerator.java index e4e04de1b..dd8eee5e7 100644 --- a/planner/src/main/java/com/dask/sql/application/RelationalAlgebraGenerator.java +++ b/planner/src/main/java/com/dask/sql/application/RelationalAlgebraGenerator.java @@ -1,227 +1,58 @@ package com.dask.sql.application; -import java.sql.Connection; -import java.sql.DriverManager; import java.sql.SQLException; -import java.util.ArrayList; import java.util.List; -import java.util.Properties; import com.dask.sql.schema.DaskSchema; -import org.apache.calcite.sql.SqlDialect; -import org.apache.calcite.config.CalciteConnectionConfig; -import org.apache.calcite.config.CalciteConnectionConfigImpl; -import org.apache.calcite.config.CalciteConnectionProperty; -import org.apache.calcite.jdbc.CalciteConnection; -import org.apache.calcite.jdbc.CalciteSchema; -import org.apache.calcite.jdbc.JavaTypeFactoryImpl; -import org.apache.calcite.plan.Context; -import org.apache.calcite.plan.Contexts; + import org.apache.calcite.plan.RelOptUtil; -import org.apache.calcite.plan.hep.HepPlanner; -import org.apache.calcite.plan.hep.HepProgram; -import org.apache.calcite.plan.hep.HepProgramBuilder; -import org.apache.calcite.prepare.CalciteCatalogReader; import org.apache.calcite.rel.RelNode; -import org.apache.calcite.rel.rules.AggregateExpandDistinctAggregatesRule; -import org.apache.calcite.rel.rules.AggregateReduceFunctionsRule; -import org.apache.calcite.rel.rules.CoreRules; -import org.apache.calcite.rel.rules.FilterAggregateTransposeRule; -import org.apache.calcite.rel.rules.FilterJoinRule; -import org.apache.calcite.rel.rules.FilterMergeRule; -import org.apache.calcite.rel.rules.FilterRemoveIsNotDistinctFromRule; -import org.apache.calcite.rel.rules.ProjectJoinTransposeRule; -import org.apache.calcite.rel.rules.ProjectMergeRule; -import org.apache.calcite.rel.rules.ReduceExpressionsRule; -import org.apache.calcite.rex.RexExecutorImpl; -import org.apache.calcite.schema.SchemaPlus; +import org.apache.calcite.sql.SqlDialect; +import org.apache.calcite.sql.SqlExplainLevel; import org.apache.calcite.sql.SqlNode; -import org.apache.calcite.sql.SqlOperatorTable; -import org.apache.calcite.sql.fun.SqlLibrary; -import org.apache.calcite.sql.fun.SqlLibraryOperatorTableFactory; -import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.sql.parser.SqlParseException; -import org.apache.calcite.sql.parser.SqlParser; -import org.apache.calcite.sql.parser.SqlParser.Config; -import org.apache.calcite.sql.util.SqlOperatorTables; -import org.apache.calcite.sql.validate.SqlConformanceEnum; -import 
org.apache.calcite.tools.FrameworkConfig; -import org.apache.calcite.tools.Frameworks; -import org.apache.calcite.tools.Planner; import org.apache.calcite.tools.RelConversionException; import org.apache.calcite.tools.ValidationException; /** - * The core of the calcite program: the generator for the relational algebra. - * Using a passed schema, it generates (optimized) relational algebra out of SQL - * query strings or throws an exception. - * - * This class is taken (in parts) from the blazingSQL project. + * The core of the calcite program which has references to all other helper + * classes. Using a passed schema, it generates (optimized) relational algebra + * out of SQL query strings or throws an exception. */ public class RelationalAlgebraGenerator { - final Planner planner; - final HepPlanner hepPlanner; - /// Create a new relational algebra generator from a schema - public RelationalAlgebraGenerator(final String rootSchemaName, final List schemas, final boolean case_sensitive) throws ClassNotFoundException, SQLException { - // Taken from https://calcite.apache.org/docs/ and blazingSQL - final SchemaPlus rootSchema = createRootSchema(rootSchemaName, schemas); + final private DaskPlanner planner; + final private DaskSqlToRelConverter sqlToRelConverter; + final private DaskProgram program; + final private DaskSqlParser parser; - final JavaTypeFactoryImpl typeFactory = createTypeFactory(); - final CalciteCatalogReader calciteCatalogReader = createCatalogReader(rootSchemaName, rootSchema, typeFactory); - final SqlOperatorTable operatorTable = createOperatorTable(calciteCatalogReader); - final SqlParser.Config parserConfig = createParserConfig(case_sensitive); - final SchemaPlus schemaPlus = rootSchema.getSubSchema(rootSchemaName); - final FrameworkConfig frameworkConfig = createFrameworkConfig(schemaPlus, operatorTable, parserConfig); - - this.planner = createPlanner(frameworkConfig); - this.hepPlanner = createHepPlanner(frameworkConfig); + public RelationalAlgebraGenerator(final String rootSchemaName, + final List schemas, + final boolean case_sensitive) throws SQLException { + this.planner = new DaskPlanner(); + this.sqlToRelConverter = new DaskSqlToRelConverter(this.planner, rootSchemaName, schemas, case_sensitive); + this.program = new DaskProgram(this.planner); + this.parser = new DaskSqlParser(); } - /// Return the default dialect used static public SqlDialect getDialect() { return DaskSqlDialect.DEFAULT; } - /// Parse a sql string into a sql tree - public SqlNode getSqlNode(final String sql) throws SqlParseException { - try { - return this.planner.parse(sql); - } catch (final SqlParseException e) { - this.planner.close(); - throw e; - } + public SqlNode getSqlNode(final String sql) throws SqlParseException, ValidationException { + final SqlNode sqlNode = this.parser.parse(sql); + return sqlNode; } - /// Validate a sql node - public SqlNode getValidatedNode(final SqlNode sqlNode) throws ValidationException { - try { - return this.planner.validate(sqlNode); - } catch (final ValidationException e) { - this.planner.close(); - throw e; - } + public RelNode getRelationalAlgebra(final SqlNode sqlNode) throws RelConversionException { + return sqlToRelConverter.convert(sqlNode); } - /// Turn a validated sql node into a rel node - public RelNode getRelationalAlgebra(final SqlNode validatedSqlNode) throws RelConversionException { - try { - return this.planner.rel(validatedSqlNode).project(true); - } catch (final RelConversionException e) { - this.planner.close(); - throw e; - } + public 
RelNode getOptimizedRelationalAlgebra(final RelNode rel) { + return this.program.run(rel); } - /// Turn a non-optimized algebra into an optimized one - public RelNode getOptimizedRelationalAlgebra(final RelNode nonOptimizedPlan) { - this.hepPlanner.setRoot(nonOptimizedPlan); - this.planner.close(); - - return this.hepPlanner.findBestExp(); - } - - /// Return the string representation of a rel node public String getRelationalAlgebraString(final RelNode relNode) { - return RelOptUtil.toString(relNode); - } - - private Planner createPlanner(final FrameworkConfig config) { - return Frameworks.getPlanner(config); - } - - private JavaTypeFactoryImpl createTypeFactory() { - return new JavaTypeFactoryImpl(DaskSqlDialect.DASKSQL_TYPE_SYSTEM); - } - - private SchemaPlus createRootSchema(final String rootSchemaName, final List schemas) throws SQLException { - final CalciteConnection calciteConnection = createConnection(rootSchemaName); - final SchemaPlus rootSchema = calciteConnection.getRootSchema(); - for(DaskSchema schema : schemas) { - rootSchema.add(schema.getName(), schema); - } - return rootSchema; - } - - private CalciteConnection createConnection(final String schemaName) throws SQLException { - // Taken from https://calcite.apache.org/docs/ - final Properties info = new Properties(); - info.setProperty("lex", "JAVA"); - - final Connection connection = DriverManager.getConnection("jdbc:calcite:", info); - final CalciteConnection calciteConnection = connection.unwrap(CalciteConnection.class); - - calciteConnection.setSchema(schemaName); - return calciteConnection; + return RelOptUtil.toString(relNode, SqlExplainLevel.ALL_ATTRIBUTES); } - - private CalciteCatalogReader createCatalogReader(final String schemaName, final SchemaPlus schemaPlus, - final JavaTypeFactoryImpl typeFactory) { - final CalciteSchema calciteSchema = CalciteSchema.from(schemaPlus); - - final Properties props = new Properties(); - props.setProperty("defaultSchema", schemaName); - - final List defaultSchema = new ArrayList(); - defaultSchema.add(schemaName); - - final CalciteCatalogReader calciteCatalogReader = new CalciteCatalogReader(calciteSchema, defaultSchema, - typeFactory, new CalciteConnectionConfigImpl(props)); - return calciteCatalogReader; - } - - private SqlOperatorTable createOperatorTable(final CalciteCatalogReader calciteCatalogReader) { - final List sqlOperatorTables = new ArrayList<>(); - sqlOperatorTables.add(SqlStdOperatorTable.instance()); - sqlOperatorTables.add(SqlLibraryOperatorTableFactory.INSTANCE.getOperatorTable(SqlLibrary.POSTGRESQL)); - sqlOperatorTables.add(calciteCatalogReader); - - SqlOperatorTable operatorTable = SqlOperatorTables.chain(sqlOperatorTables); - return operatorTable; - } - - private Config createParserConfig(boolean case_sensitive) { - return getDialect().configureParser(SqlParser.Config.DEFAULT).withConformance(SqlConformanceEnum.DEFAULT) - .withCaseSensitive(case_sensitive) - .withParserFactory(new DaskSqlParserImplFactory()); - } - - private FrameworkConfig createFrameworkConfig(final SchemaPlus schemaPlus, SqlOperatorTable operatorTable, - final SqlParser.Config parserConfig) { - // Use our defined type system - final Context defaultContext = Contexts.of(CalciteConnectionConfig.DEFAULT.set( - CalciteConnectionProperty.TYPE_SYSTEM, "com.dask.sql.application.DaskSqlDialect#DASKSQL_TYPE_SYSTEM")); - - return Frameworks.newConfigBuilder().context(defaultContext).defaultSchema(schemaPlus) - .parserConfig(parserConfig).executor(new 
RexExecutorImpl(null)).operatorTable(operatorTable).build(); - } - - private HepPlanner createHepPlanner(final FrameworkConfig config) { - final HepProgram program = new HepProgramBuilder() - .addRuleInstance(CoreRules.AGGREGATE_PROJECT_MERGE) - .addRuleInstance(CoreRules.AGGREGATE_PROJECT_PULL_UP_CONSTANTS) - .addRuleInstance(CoreRules.AGGREGATE_ANY_PULL_UP_CONSTANTS) - .addRuleInstance(CoreRules.AGGREGATE_REDUCE_FUNCTIONS) - .addRuleInstance(CoreRules.AGGREGATE_MERGE) - .addRuleInstance(CoreRules.AGGREGATE_EXPAND_DISTINCT_AGGREGATES_TO_JOIN) - .addRuleInstance(CoreRules.AGGREGATE_JOIN_REMOVE) - .addRuleInstance(CoreRules.FILTER_AGGREGATE_TRANSPOSE) - .addRuleInstance(CoreRules.JOIN_CONDITION_PUSH) - .addRuleInstance(CoreRules.FILTER_INTO_JOIN) - .addRuleInstance(CoreRules.PROJECT_MERGE) - .addRuleInstance(CoreRules.FILTER_MERGE) - .addRuleInstance(CoreRules.PROJECT_JOIN_TRANSPOSE) - // In principle, not a bad idea. But we need to keep the most - // outer project - because otherwise the column name information is lost - // in cases such as SELECT x AS a, y AS B FROM df - // .addRuleInstance(ProjectRemoveRule.Config.DEFAULT.toRule()) - .addRuleInstance(CoreRules.PROJECT_REDUCE_EXPRESSIONS) - .addRuleInstance(CoreRules.FILTER_REDUCE_EXPRESSIONS) - .addRuleInstance(CoreRules.FILTER_EXPAND_IS_NOT_DISTINCT_FROM) - .addRuleInstance(CoreRules.PROJECT_TO_LOGICAL_PROJECT_AND_WINDOW) - .build(); - - return new HepPlanner(program, config.getContext()); - } - } diff --git a/planner/src/main/java/com/dask/sql/application/RelationalAlgebraGeneratorBuilder.java b/planner/src/main/java/com/dask/sql/application/RelationalAlgebraGeneratorBuilder.java index 57490d1c4..6da2a3115 100644 --- a/planner/src/main/java/com/dask/sql/application/RelationalAlgebraGeneratorBuilder.java +++ b/planner/src/main/java/com/dask/sql/application/RelationalAlgebraGeneratorBuilder.java @@ -6,10 +6,14 @@ import com.dask.sql.schema.DaskSchema; +/** + * RelationalAlgebraGeneratorBuilder is a Builder-pattern to make creating a + * RelationalAlgebraGenerator easier from Python. 
+ */ public class RelationalAlgebraGeneratorBuilder { private final String rootSchemaName; private final List<DaskSchema> schemas; - private final boolean case_sensitive; // True if case should be ignored when comparing SQLNode(s) + private final boolean case_sensitive; // False if case should be ignored when comparing SQLNode(s) public RelationalAlgebraGeneratorBuilder(final String rootSchemaName, final boolean case_sensitive) { this.rootSchemaName = rootSchemaName; @@ -23,6 +27,6 @@ public RelationalAlgebraGeneratorBuilder addSchema(final DaskSchema schema) { } public RelationalAlgebraGenerator build() throws ClassNotFoundException, SQLException { - return new RelationalAlgebraGenerator(rootSchemaName, schemas, case_sensitive); + return new RelationalAlgebraGenerator(rootSchemaName, schemas, this.case_sensitive); } } diff --git a/planner/src/main/java/com/dask/sql/nodes/DaskAggregate.java b/planner/src/main/java/com/dask/sql/nodes/DaskAggregate.java new file mode 100644 index 000000000..5eb2dea5f --- /dev/null +++ b/planner/src/main/java/com/dask/sql/nodes/DaskAggregate.java @@ -0,0 +1,31 @@ +package com.dask.sql.nodes; + +import java.util.List; + +import com.google.common.collect.ImmutableList; + +import org.apache.calcite.plan.RelOptCluster; +import org.apache.calcite.plan.RelTraitSet; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.Aggregate; +import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.util.ImmutableBitSet; + +public class DaskAggregate extends Aggregate implements DaskRel { + private DaskAggregate(RelOptCluster cluster, RelTraitSet traitSet, RelNode input, ImmutableBitSet groupSet, + List<ImmutableBitSet> groupSets, List<AggregateCall> aggCalls) { + super(cluster, traitSet, ImmutableList.of(), input, groupSet, groupSets, aggCalls); + } + + @Override + public Aggregate copy(RelTraitSet traitSet, RelNode input, ImmutableBitSet groupSet, + List<ImmutableBitSet> groupSets, List<AggregateCall> aggCalls) { + return DaskAggregate.create(getCluster(), traitSet, input, groupSet, groupSets, aggCalls); + } + + public static DaskAggregate create(RelOptCluster cluster, RelTraitSet traitSet, RelNode input, + ImmutableBitSet groupSet, List<ImmutableBitSet> groupSets, List<AggregateCall> aggCalls) { + assert traitSet.getConvention() == DaskRel.CONVENTION; + return new DaskAggregate(cluster, traitSet, input, groupSet, groupSets, aggCalls); + } +} diff --git a/planner/src/main/java/com/dask/sql/nodes/DaskConvention.java b/planner/src/main/java/com/dask/sql/nodes/DaskConvention.java new file mode 100644 index 000000000..57e427056 --- /dev/null +++ b/planner/src/main/java/com/dask/sql/nodes/DaskConvention.java @@ -0,0 +1,44 @@ +package com.dask.sql.nodes; + +import org.apache.calcite.plan.Convention; +import org.apache.calcite.plan.ConventionTraitDef; +import org.apache.calcite.plan.RelTraitSet; +import org.apache.calcite.rel.RelNode; + +public class DaskConvention extends Convention.Impl { + static public DaskConvention INSTANCE = new DaskConvention(); + + private DaskConvention() { + super("DASK", DaskRel.class); + } + + @Override + public RelNode enforce(final RelNode input, final RelTraitSet required) { + RelNode rel = input; + if (input.getConvention() != INSTANCE) { + rel = ConventionTraitDef.INSTANCE.convert(input.getCluster().getPlanner(), input, INSTANCE, true); + } + return rel; + } + + public boolean canConvertConvention(Convention toConvention) { + return false; + } + + public boolean useAbstractConvertersForConversion(RelTraitSet fromTraits, RelTraitSet toTraits) { + // Note: there seem to be two possibilities for how to handle traits +
// during optimization. + // One: set planner.setTopDownOpt(false) (the default). + // In this mode, we need to return true here and let Calcite include + // abstract converters whenever needed. We then need rules to + // turn the abstract converters into actual rels (like Exchange and Sort). + // Two: set planner.setTopDownOpt(true) + // Here, Calcite will propagate the needed traits from the top to the bottom. + // Each rel can decide on its own whether it can propagate the traits + // (example: a project will keep the collation and distribution traits, + // so it can propagate them). Whenever Calcite sees that traits cannot be + // propagated, it will call the enforce method of this convention. + // Here, we want to return false in this function. + return true; + } +} diff --git a/planner/src/main/java/com/dask/sql/nodes/DaskFilter.java b/planner/src/main/java/com/dask/sql/nodes/DaskFilter.java new file mode 100644 index 000000000..82ad1cdbb --- /dev/null +++ b/planner/src/main/java/com/dask/sql/nodes/DaskFilter.java @@ -0,0 +1,24 @@ +package com.dask.sql.nodes; + +import org.apache.calcite.plan.RelOptCluster; +import org.apache.calcite.plan.RelTraitSet; +import org.apache.calcite.rel.RelCollationTraitDef; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.Filter; +import org.apache.calcite.rex.RexNode; + +public class DaskFilter extends Filter implements DaskRel { + private DaskFilter(RelOptCluster cluster, RelTraitSet traitSet, RelNode child, RexNode condition) { + super(cluster, traitSet, child, condition); + } + + @Override + public Filter copy(RelTraitSet traitSet, RelNode input, RexNode condition) { + return DaskFilter.create(getCluster(), traitSet, input, condition); + } + + public static DaskFilter create(RelOptCluster cluster, RelTraitSet traitSet, RelNode child, RexNode condition) { + assert traitSet.getConvention() == DaskRel.CONVENTION; + return new DaskFilter(cluster, traitSet, child, condition); + } +} diff --git a/planner/src/main/java/com/dask/sql/nodes/DaskJoin.java b/planner/src/main/java/com/dask/sql/nodes/DaskJoin.java new file mode 100644 index 000000000..b03cd6c2d --- /dev/null +++ b/planner/src/main/java/com/dask/sql/nodes/DaskJoin.java @@ -0,0 +1,32 @@ +package com.dask.sql.nodes; + +import java.util.Set; + +import com.google.common.collect.ImmutableList; + +import org.apache.calcite.plan.RelOptCluster; +import org.apache.calcite.plan.RelTraitSet; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.CorrelationId; +import org.apache.calcite.rel.core.Join; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rex.RexNode; + +public class DaskJoin extends Join implements DaskRel { + private DaskJoin(RelOptCluster cluster, RelTraitSet traitSet, RelNode left, RelNode right, RexNode condition, + Set<CorrelationId> variablesSet, JoinRelType joinType) { + super(cluster, traitSet, ImmutableList.of(), left, right, condition, variablesSet, joinType); + } + + @Override + public Join copy(RelTraitSet traitSet, RexNode conditionExpr, RelNode left, RelNode right, JoinRelType joinType, + boolean semiJoinDone) { + return new DaskJoin(getCluster(), traitSet, left, right, condition, variablesSet, joinType); + } + + public static DaskJoin create(RelOptCluster cluster, RelTraitSet traitSet, RelNode left, RelNode right, + RexNode condition, Set<CorrelationId> variablesSet, JoinRelType joinType) { + assert traitSet.getConvention() == DaskRel.CONVENTION; + return new DaskJoin(cluster, traitSet, left, right, condition, variablesSet, joinType); +
} +} diff --git a/planner/src/main/java/com/dask/sql/nodes/DaskLimit.java b/planner/src/main/java/com/dask/sql/nodes/DaskLimit.java new file mode 100644 index 000000000..891b8c60e --- /dev/null +++ b/planner/src/main/java/com/dask/sql/nodes/DaskLimit.java @@ -0,0 +1,51 @@ +package com.dask.sql.nodes; + +import java.util.List; + +import org.apache.calcite.plan.RelOptCluster; +import org.apache.calcite.plan.RelTraitSet; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.RelWriter; +import org.apache.calcite.rel.SingleRel; +import org.apache.calcite.rex.RexNode; +import org.checkerframework.checker.nullness.qual.Nullable; + +public class DaskLimit extends SingleRel implements DaskRel { + private final @Nullable RexNode offset; + private final @Nullable RexNode fetch; + + private DaskLimit(RelOptCluster cluster, RelTraitSet traitSet, RelNode input, @Nullable RexNode offset, + @Nullable RexNode fetch) { + super(cluster, traitSet, input); + this.offset = offset; + this.fetch = fetch; + } + + public DaskLimit copy(RelTraitSet traitSet, RelNode input, @Nullable RexNode offset, @Nullable RexNode fetch) { + return new DaskLimit(getCluster(), traitSet, input, offset, fetch); + } + + @Override + public RelNode copy(RelTraitSet traitSet, List inputs) { + return copy(traitSet, sole(inputs), this.getOffset(), this.getFetch()); + } + + public @Nullable RexNode getFetch() { + return this.fetch; + } + + public @Nullable RexNode getOffset() { + return this.offset; + } + + @Override + public RelWriter explainTerms(RelWriter pw) { + return super.explainTerms(pw).item("offset", this.getOffset()).item("fetch", this.getFetch()); + } + + public static DaskLimit create(RelOptCluster cluster, RelTraitSet traitSet, RelNode input, @Nullable RexNode offset, + @Nullable RexNode fetch) { + assert traitSet.getConvention() == DaskRel.CONVENTION; + return new DaskLimit(cluster, traitSet, input, offset, fetch); + } +} diff --git a/planner/src/main/java/com/dask/sql/nodes/DaskProject.java b/planner/src/main/java/com/dask/sql/nodes/DaskProject.java new file mode 100644 index 000000000..0b1f249e1 --- /dev/null +++ b/planner/src/main/java/com/dask/sql/nodes/DaskProject.java @@ -0,0 +1,30 @@ +package com.dask.sql.nodes; + +import java.util.List; + +import com.google.common.collect.ImmutableList; + +import org.apache.calcite.plan.RelOptCluster; +import org.apache.calcite.plan.RelTraitSet; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.Project; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rex.RexNode; + +public class DaskProject extends Project implements DaskRel { + private DaskProject(RelOptCluster cluster, RelTraitSet traits, RelNode input, List projects, + RelDataType rowType) { + super(cluster, traits, ImmutableList.of(), input, projects, rowType); + } + + @Override + public Project copy(RelTraitSet traitSet, RelNode input, List projects, RelDataType rowType) { + return DaskProject.create(getCluster(), traitSet, input, projects, rowType); + } + + public static DaskProject create(RelOptCluster cluster, RelTraitSet traits, RelNode input, + List projects, RelDataType rowType) { + assert traits.getConvention() == DaskRel.CONVENTION; + return new DaskProject(cluster, traits, input, projects, rowType); + } +} diff --git a/planner/src/main/java/com/dask/sql/nodes/DaskRel.java b/planner/src/main/java/com/dask/sql/nodes/DaskRel.java new file mode 100644 index 000000000..83eafeced --- /dev/null +++ b/planner/src/main/java/com/dask/sql/nodes/DaskRel.java @@ 
-0,0 +1,8 @@ +package com.dask.sql.nodes; + +import org.apache.calcite.plan.Convention; +import org.apache.calcite.rel.RelNode; + +public interface DaskRel extends RelNode { + Convention CONVENTION = DaskConvention.INSTANCE; +} diff --git a/planner/src/main/java/com/dask/sql/nodes/DaskSample.java b/planner/src/main/java/com/dask/sql/nodes/DaskSample.java new file mode 100644 index 000000000..e7d2a540b --- /dev/null +++ b/planner/src/main/java/com/dask/sql/nodes/DaskSample.java @@ -0,0 +1,27 @@ +package com.dask.sql.nodes; + +import java.util.List; + +import org.apache.calcite.plan.RelOptCluster; +import org.apache.calcite.plan.RelOptSamplingParameters; +import org.apache.calcite.plan.RelTraitSet; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.Sample; + +public class DaskSample extends Sample implements DaskRel { + private DaskSample(RelOptCluster cluster, RelTraitSet traits, RelNode child, RelOptSamplingParameters params) { + super(cluster, child, params); + this.traitSet = traits; + } + + @Override + public RelNode copy(RelTraitSet traitSet, List inputs) { + return DaskSample.create(getCluster(), traitSet, sole(inputs), getSamplingParameters()); + } + + public static DaskSample create(RelOptCluster cluster, RelTraitSet traits, RelNode input, + RelOptSamplingParameters params) { + assert traits.getConvention() == DaskRel.CONVENTION; + return new DaskSample(cluster, traits, input, params); + } +} diff --git a/planner/src/main/java/com/dask/sql/nodes/DaskSort.java b/planner/src/main/java/com/dask/sql/nodes/DaskSort.java new file mode 100644 index 000000000..3a1c8bca7 --- /dev/null +++ b/planner/src/main/java/com/dask/sql/nodes/DaskSort.java @@ -0,0 +1,29 @@ +package com.dask.sql.nodes; + +import org.apache.calcite.plan.RelOptCluster; +import org.apache.calcite.plan.RelTraitSet; +import org.apache.calcite.rel.RelCollation; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.Sort; +import org.apache.calcite.rex.RexNode; +import org.checkerframework.checker.nullness.qual.Nullable; + +public class DaskSort extends Sort implements DaskRel { + private DaskSort(RelOptCluster cluster, RelTraitSet traits, RelNode child, RelCollation collation) { + super(cluster, traits, child, collation); + } + + @Override + public Sort copy(RelTraitSet traitSet, RelNode newInput, RelCollation newCollation, @Nullable RexNode offset, + @Nullable RexNode fetch) { + assert offset == null; + assert fetch == null; + return DaskSort.create(getCluster(), traitSet, newInput, newCollation); + } + + static public DaskSort create(RelOptCluster cluster, RelTraitSet traitSet, RelNode input, RelCollation collation) { + assert traitSet.getConvention() == DaskRel.CONVENTION; + return new DaskSort(cluster, traitSet, input, collation); + } + +} diff --git a/planner/src/main/java/com/dask/sql/nodes/DaskTableScan.java b/planner/src/main/java/com/dask/sql/nodes/DaskTableScan.java new file mode 100644 index 000000000..6d0ad4587 --- /dev/null +++ b/planner/src/main/java/com/dask/sql/nodes/DaskTableScan.java @@ -0,0 +1,31 @@ +package com.dask.sql.nodes; + +import java.util.List; + +import com.google.common.collect.ImmutableList; + +import org.apache.calcite.plan.RelOptCluster; +import org.apache.calcite.plan.RelOptTable; +import org.apache.calcite.plan.RelTraitSet; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.TableScan; + +public class DaskTableScan extends TableScan implements DaskRel { + private DaskTableScan(RelOptCluster cluster, RelTraitSet traitSet, 
RelOptTable table) { + super(cluster, traitSet, ImmutableList.of(), table); + } + + public static DaskTableScan create(RelOptCluster cluster, RelTraitSet traitSet, RelOptTable table) { + assert traitSet.getConvention() == DaskRel.CONVENTION; + return new DaskTableScan(cluster, traitSet, table); + } + + @Override + public RelNode copy(RelTraitSet traitSet, List inputs) { + return copy(traitSet, this.getTable()); + } + + private RelNode copy(RelTraitSet traitSet, RelOptTable table) { + return new DaskTableScan(getCluster(), traitSet, table); + } +} diff --git a/planner/src/main/java/com/dask/sql/nodes/DaskUnion.java b/planner/src/main/java/com/dask/sql/nodes/DaskUnion.java new file mode 100644 index 000000000..2f41fdc6b --- /dev/null +++ b/planner/src/main/java/com/dask/sql/nodes/DaskUnion.java @@ -0,0 +1,26 @@ +package com.dask.sql.nodes; + +import java.util.List; + +import org.apache.calcite.plan.RelOptCluster; +import org.apache.calcite.plan.RelTraitSet; +import org.apache.calcite.rel.RelCollations; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.SetOp; +import org.apache.calcite.rel.core.Union; + +public class DaskUnion extends Union implements DaskRel { + private DaskUnion(RelOptCluster cluster, RelTraitSet traits, List inputs, boolean all) { + super(cluster, traits, inputs, all); + } + + @Override + public SetOp copy(RelTraitSet traitSet, List inputs, boolean all) { + return new DaskUnion(getCluster(), traitSet, inputs, all); + } + + public static DaskUnion create(RelOptCluster cluster, RelTraitSet traits, List inputs, boolean all) { + assert traits.getConvention() == DaskRel.CONVENTION; + return new DaskUnion(cluster, traits, inputs, all); + } +} diff --git a/planner/src/main/java/com/dask/sql/nodes/DaskValues.java b/planner/src/main/java/com/dask/sql/nodes/DaskValues.java new file mode 100644 index 000000000..d352887b0 --- /dev/null +++ b/planner/src/main/java/com/dask/sql/nodes/DaskValues.java @@ -0,0 +1,21 @@ +package com.dask.sql.nodes; + +import org.apache.calcite.plan.RelOptCluster; +import org.apache.calcite.plan.RelTraitSet; +import org.apache.calcite.rel.core.Values; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rex.RexLiteral; +import com.google.common.collect.ImmutableList; + +public class DaskValues extends Values implements DaskRel { + private DaskValues(RelOptCluster cluster, RelDataType rowType, ImmutableList> tuples, + RelTraitSet traits) { + super(cluster, rowType, tuples, traits); + } + + public static DaskValues create(RelOptCluster cluster, RelDataType rowType, + ImmutableList> tuples, RelTraitSet traits) { + assert traits.getConvention() == DaskRel.CONVENTION; + return new DaskValues(cluster, rowType, tuples, traits); + } +} diff --git a/planner/src/main/java/com/dask/sql/nodes/DaskWindow.java b/planner/src/main/java/com/dask/sql/nodes/DaskWindow.java new file mode 100644 index 000000000..2cacfe102 --- /dev/null +++ b/planner/src/main/java/com/dask/sql/nodes/DaskWindow.java @@ -0,0 +1,38 @@ +package com.dask.sql.nodes; + +import java.util.List; + +import org.apache.calcite.plan.RelOptCluster; +import org.apache.calcite.plan.RelTraitSet; +import org.apache.calcite.rel.RelCollationTraitDef; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.Window; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rex.RexLiteral; + +public class DaskWindow extends Window implements DaskRel { + public DaskWindow(RelOptCluster cluster, RelTraitSet traitSet, RelNode input, List 
constants, + RelDataType rowType, List groups) { + super(cluster, traitSet, input, constants, rowType, groups); + } + + public List getGroups() { + return this.groups; + } + + @Override + public RelNode copy(RelTraitSet traitSet, List inputs) { + return copy(traitSet, sole(inputs), this.getConstants(), this.getRowType(), this.getGroups()); + } + + public Window copy(RelTraitSet traitSet, RelNode input, List constants, RelDataType rowType, + List groups) { + return DaskWindow.create(getCluster(), traitSet, input, constants, rowType, groups); + } + + public static DaskWindow create(RelOptCluster cluster, RelTraitSet traits, RelNode input, + List constants, RelDataType rowType, List groups) { + assert traits.getConvention() == DaskRel.CONVENTION; + return new DaskWindow(cluster, traits, input, constants, rowType, groups); + } +} diff --git a/planner/src/main/java/com/dask/sql/rules/DaskAggregateRule.java b/planner/src/main/java/com/dask/sql/rules/DaskAggregateRule.java new file mode 100644 index 000000000..3a7aae2d2 --- /dev/null +++ b/planner/src/main/java/com/dask/sql/rules/DaskAggregateRule.java @@ -0,0 +1,30 @@ +package com.dask.sql.rules; + +import com.dask.sql.nodes.DaskRel; +import com.dask.sql.nodes.DaskAggregate; + +import org.apache.calcite.plan.Convention; +import org.apache.calcite.plan.RelTraitSet; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.convert.ConverterRule; +import org.apache.calcite.rel.logical.LogicalAggregate; + +public class DaskAggregateRule extends ConverterRule { + public static final DaskAggregateRule INSTANCE = Config.INSTANCE + .withConversion(LogicalAggregate.class, Convention.NONE, DaskRel.CONVENTION, "DaskAggregateRule") + .withRuleFactory(DaskAggregateRule::new).toRule(DaskAggregateRule.class); + + DaskAggregateRule(Config config) { + super(config); + } + + @Override + public RelNode convert(RelNode rel) { + final LogicalAggregate agg = (LogicalAggregate) rel; + final RelTraitSet traitSet = agg.getTraitSet().replace(out); + RelNode transformedInput = convert(agg.getInput(), agg.getInput().getTraitSet().replace(DaskRel.CONVENTION)); + + return DaskAggregate.create(agg.getCluster(), traitSet, transformedInput, agg.getGroupSet(), agg.getGroupSets(), + agg.getAggCallList()); + } +} diff --git a/planner/src/main/java/com/dask/sql/rules/DaskFilterRule.java b/planner/src/main/java/com/dask/sql/rules/DaskFilterRule.java new file mode 100644 index 000000000..2ea46af30 --- /dev/null +++ b/planner/src/main/java/com/dask/sql/rules/DaskFilterRule.java @@ -0,0 +1,29 @@ +package com.dask.sql.rules; + +import com.dask.sql.nodes.DaskRel; +import com.dask.sql.nodes.DaskFilter; + +import org.apache.calcite.plan.Convention; +import org.apache.calcite.plan.RelTraitSet; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.convert.ConverterRule; +import org.apache.calcite.rel.logical.LogicalFilter; + +public class DaskFilterRule extends ConverterRule { + public static final DaskFilterRule INSTANCE = Config.INSTANCE.withConversion(LogicalFilter.class, + f -> !f.containsOver(), Convention.NONE, DaskRel.CONVENTION, "DaskFilterRule") + .withRuleFactory(DaskFilterRule::new).toRule(DaskFilterRule.class); + + DaskFilterRule(Config config) { + super(config); + } + + @Override + public RelNode convert(RelNode rel) { + final LogicalFilter filter = (LogicalFilter) rel; + final RelTraitSet traitSet = filter.getTraitSet().replace(out); + RelNode transformedInput = convert(filter.getInput(), + 
filter.getInput().getTraitSet().replace(DaskRel.CONVENTION)); + return DaskFilter.create(filter.getCluster(), traitSet, transformedInput, filter.getCondition()); + } +} diff --git a/planner/src/main/java/com/dask/sql/rules/DaskJoinRule.java b/planner/src/main/java/com/dask/sql/rules/DaskJoinRule.java new file mode 100644 index 000000000..97b29e2bc --- /dev/null +++ b/planner/src/main/java/com/dask/sql/rules/DaskJoinRule.java @@ -0,0 +1,31 @@ +package com.dask.sql.rules; + +import com.dask.sql.nodes.DaskRel; +import com.dask.sql.nodes.DaskJoin; + +import org.apache.calcite.plan.Convention; +import org.apache.calcite.plan.RelTraitSet; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.convert.ConverterRule; +import org.apache.calcite.rel.logical.LogicalJoin; + +public class DaskJoinRule extends ConverterRule { + public static final DaskJoinRule INSTANCE = Config.INSTANCE + .withConversion(LogicalJoin.class, Convention.NONE, DaskRel.CONVENTION, "DaskJoinRule") + .withRuleFactory(DaskJoinRule::new).toRule(DaskJoinRule.class); + + DaskJoinRule(Config config) { + super(config); + } + + @Override + public RelNode convert(RelNode rel) { + final LogicalJoin join = (LogicalJoin) rel; + final RelTraitSet traitSet = join.getTraitSet().replace(out); + RelNode transformedLeft = convert(join.getLeft(), join.getLeft().getTraitSet().replace(DaskRel.CONVENTION)); + RelNode transformedRight = convert(join.getRight(), join.getRight().getTraitSet().replace(DaskRel.CONVENTION)); + + return DaskJoin.create(join.getCluster(), traitSet, transformedLeft, transformedRight, join.getCondition(), + join.getVariablesSet(), join.getJoinType()); + } +} diff --git a/planner/src/main/java/com/dask/sql/rules/DaskProjectRule.java b/planner/src/main/java/com/dask/sql/rules/DaskProjectRule.java new file mode 100644 index 000000000..7b8af2177 --- /dev/null +++ b/planner/src/main/java/com/dask/sql/rules/DaskProjectRule.java @@ -0,0 +1,31 @@ +package com.dask.sql.rules; + +import com.dask.sql.nodes.DaskRel; +import com.dask.sql.nodes.DaskProject; + +import org.apache.calcite.plan.Convention; +import org.apache.calcite.plan.RelTraitSet; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.convert.ConverterRule; +import org.apache.calcite.rel.logical.LogicalProject; + +public class DaskProjectRule extends ConverterRule { + public static final DaskProjectRule INSTANCE = Config.INSTANCE.withConversion(LogicalProject.class, + p -> !p.containsOver(), Convention.NONE, DaskRel.CONVENTION, "DaskProjectRule") + .withRuleFactory(DaskProjectRule::new).toRule(DaskProjectRule.class); + + DaskProjectRule(Config config) { + super(config); + } + + @Override + public RelNode convert(RelNode rel) { + final LogicalProject project = (LogicalProject) rel; + final RelTraitSet traitSet = project.getTraitSet().replace(out); + RelNode transformedInput = convert(project.getInput(), + project.getInput().getTraitSet().replace(DaskRel.CONVENTION)); + + return DaskProject.create(project.getCluster(), traitSet, transformedInput, project.getProjects(), + project.getRowType()); + } +} diff --git a/planner/src/main/java/com/dask/sql/rules/DaskSampleRule.java b/planner/src/main/java/com/dask/sql/rules/DaskSampleRule.java new file mode 100644 index 000000000..8aee0fc54 --- /dev/null +++ b/planner/src/main/java/com/dask/sql/rules/DaskSampleRule.java @@ -0,0 +1,30 @@ +package com.dask.sql.rules; + +import com.dask.sql.nodes.DaskRel; +import com.dask.sql.nodes.DaskSample; + +import org.apache.calcite.plan.Convention; +import 
org.apache.calcite.plan.RelTraitSet; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.convert.ConverterRule; +import org.apache.calcite.rel.core.Sample; + +public class DaskSampleRule extends ConverterRule { + public static final DaskSampleRule INSTANCE = Config.INSTANCE + .withConversion(Sample.class, Convention.NONE, DaskRel.CONVENTION, "DaskSampleRule") + .withRuleFactory(DaskSampleRule::new).toRule(DaskSampleRule.class); + + DaskSampleRule(Config config) { + super(config); + } + + @Override + public RelNode convert(RelNode rel) { + final Sample sample = (Sample) rel; + final RelTraitSet traitSet = sample.getTraitSet().replace(out); + RelNode transformedInput = convert(sample.getInput(), + sample.getInput().getTraitSet().replace(DaskRel.CONVENTION)); + + return DaskSample.create(sample.getCluster(), traitSet, transformedInput, sample.getSamplingParameters()); + } +} diff --git a/planner/src/main/java/com/dask/sql/rules/DaskSortLimitRule.java b/planner/src/main/java/com/dask/sql/rules/DaskSortLimitRule.java new file mode 100644 index 000000000..323a2bacb --- /dev/null +++ b/planner/src/main/java/com/dask/sql/rules/DaskSortLimitRule.java @@ -0,0 +1,39 @@ +package com.dask.sql.rules; + +import com.dask.sql.nodes.DaskLimit; +import com.dask.sql.nodes.DaskRel; +import com.dask.sql.nodes.DaskSort; + +import org.apache.calcite.plan.Convention; +import org.apache.calcite.plan.RelTraitSet; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.convert.ConverterRule; +import org.apache.calcite.rel.logical.LogicalSort; + +public class DaskSortLimitRule extends ConverterRule { + public static final DaskSortLimitRule INSTANCE = Config.INSTANCE + .withConversion(LogicalSort.class, Convention.NONE, DaskRel.CONVENTION, "DaskSortLimitRule") + .withRuleFactory(DaskSortLimitRule::new).toRule(DaskSortLimitRule.class); + + DaskSortLimitRule(Config config) { + super(config); + } + + @Override + public RelNode convert(RelNode rel) { + final LogicalSort sort = (LogicalSort) rel; + final RelTraitSet traitSet = sort.getTraitSet().replace(out); + RelNode transformedInput = convert(sort.getInput(), sort.getInput().getTraitSet().replace(DaskRel.CONVENTION)); + + if (!sort.getCollation().getFieldCollations().isEmpty()) { + // Create a sort with the same sort key, but no offset or fetch. 
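+ // A LogicalSort carries both the collation and any fetch/offset; only the collation part becomes a DaskSort here, + // while a fetch/offset (if present) is handled by the separate DaskLimit created at the end of this method.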
+ transformedInput = DaskSort.create(transformedInput.getCluster(), traitSet, transformedInput, + sort.getCollation()); + } + if (sort.fetch == null && sort.offset == null) { + return transformedInput; + } + + return DaskLimit.create(sort.getCluster(), traitSet, transformedInput, sort.offset, sort.fetch); + } +} diff --git a/planner/src/main/java/com/dask/sql/rules/DaskTableScanRule.java b/planner/src/main/java/com/dask/sql/rules/DaskTableScanRule.java new file mode 100644 index 000000000..3d7e9c83f --- /dev/null +++ b/planner/src/main/java/com/dask/sql/rules/DaskTableScanRule.java @@ -0,0 +1,27 @@ +package com.dask.sql.rules; + +import com.dask.sql.nodes.DaskRel; +import com.dask.sql.nodes.DaskTableScan; + +import org.apache.calcite.plan.Convention; +import org.apache.calcite.plan.RelTraitSet; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.convert.ConverterRule; +import org.apache.calcite.rel.logical.LogicalTableScan; + +public class DaskTableScanRule extends ConverterRule { + public static final DaskTableScanRule INSTANCE = Config.INSTANCE + .withConversion(LogicalTableScan.class, Convention.NONE, DaskRel.CONVENTION, "DaskTableScanRule") + .withRuleFactory(DaskTableScanRule::new).toRule(DaskTableScanRule.class); + + DaskTableScanRule(Config config) { + super(config); + } + + @Override + public RelNode convert(RelNode rel) { + final LogicalTableScan scan = (LogicalTableScan) rel; + final RelTraitSet traitSet = scan.getTraitSet().replace(out); + return DaskTableScan.create(scan.getCluster(), traitSet, scan.getTable()); + } +} diff --git a/planner/src/main/java/com/dask/sql/rules/DaskUnionRule.java b/planner/src/main/java/com/dask/sql/rules/DaskUnionRule.java new file mode 100644 index 000000000..5fd047224 --- /dev/null +++ b/planner/src/main/java/com/dask/sql/rules/DaskUnionRule.java @@ -0,0 +1,32 @@ +package com.dask.sql.rules; + +import java.util.List; +import java.util.stream.Collectors; + +import com.dask.sql.nodes.DaskRel; +import com.dask.sql.nodes.DaskUnion; + +import org.apache.calcite.plan.Convention; +import org.apache.calcite.plan.RelTraitSet; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.convert.ConverterRule; +import org.apache.calcite.rel.logical.LogicalUnion; + +public class DaskUnionRule extends ConverterRule { + public static final DaskUnionRule INSTANCE = Config.INSTANCE + .withConversion(LogicalUnion.class, Convention.NONE, DaskRel.CONVENTION, "DaskUnionRule") + .withRuleFactory(DaskUnionRule::new).toRule(DaskUnionRule.class); + + DaskUnionRule(Config config) { + super(config); + } + + @Override + public RelNode convert(RelNode rel) { + final LogicalUnion union = (LogicalUnion) rel; + final RelTraitSet traitSet = union.getTraitSet().replace(out); + List transformedInputs = union.getInputs().stream() + .map(c -> convert(c, c.getTraitSet().replace(DaskRel.CONVENTION))).collect(Collectors.toList()); + return DaskUnion.create(union.getCluster(), traitSet, transformedInputs, union.all); + } +} diff --git a/planner/src/main/java/com/dask/sql/rules/DaskValuesRule.java b/planner/src/main/java/com/dask/sql/rules/DaskValuesRule.java new file mode 100644 index 000000000..070f09b6a --- /dev/null +++ b/planner/src/main/java/com/dask/sql/rules/DaskValuesRule.java @@ -0,0 +1,27 @@ +package com.dask.sql.rules; + +import com.dask.sql.nodes.DaskRel; +import com.dask.sql.nodes.DaskValues; + +import org.apache.calcite.plan.Convention; +import org.apache.calcite.plan.RelTraitSet; +import org.apache.calcite.rel.RelNode; +import 
org.apache.calcite.rel.convert.ConverterRule; +import org.apache.calcite.rel.logical.LogicalValues; + +public class DaskValuesRule extends ConverterRule { + public static final DaskValuesRule INSTANCE = Config.INSTANCE + .withConversion(LogicalValues.class, Convention.NONE, DaskRel.CONVENTION, "DaskValuesRule") + .withRuleFactory(DaskValuesRule::new).toRule(DaskValuesRule.class); + + DaskValuesRule(Config config) { + super(config); + } + + @Override + public RelNode convert(RelNode rel) { + final LogicalValues values = (LogicalValues) rel; + final RelTraitSet traitSet = values.getTraitSet().replace(out); + return DaskValues.create(values.getCluster(), values.getRowType(), values.getTuples(), traitSet); + } +} diff --git a/planner/src/main/java/com/dask/sql/rules/DaskWindowRule.java b/planner/src/main/java/com/dask/sql/rules/DaskWindowRule.java new file mode 100644 index 000000000..06688ce6f --- /dev/null +++ b/planner/src/main/java/com/dask/sql/rules/DaskWindowRule.java @@ -0,0 +1,32 @@ +package com.dask.sql.rules; + +import com.dask.sql.nodes.DaskRel; +import com.dask.sql.nodes.DaskWindow; + +import org.apache.calcite.plan.Convention; +import org.apache.calcite.plan.RelTraitSet; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.convert.ConverterRule; +import org.apache.calcite.rel.core.Window.Group; +import org.apache.calcite.rel.logical.LogicalWindow; + +public class DaskWindowRule extends ConverterRule { + public static final DaskWindowRule INSTANCE = Config.INSTANCE + .withConversion(LogicalWindow.class, Convention.NONE, DaskRel.CONVENTION, "DaskWindowRule") + .withRuleFactory(DaskWindowRule::new).toRule(DaskWindowRule.class); + + DaskWindowRule(Config config) { + super(config); + } + + @Override + public RelNode convert(RelNode rel) { + final LogicalWindow window = (LogicalWindow) rel; + final RelTraitSet traitSet = window.getTraitSet().replace(out); + RelNode transformedInput = convert(window.getInput(), + window.getInput().getTraitSet().replace(DaskRel.CONVENTION)); + + return DaskWindow.create(window.getCluster(), traitSet, transformedInput, window.getConstants(), + window.getRowType(), window.groups); + } +} diff --git a/planner/src/main/java/com/dask/sql/schema/DaskTable.java b/planner/src/main/java/com/dask/sql/schema/DaskTable.java index 31681f9fe..61412ed9e 100644 --- a/planner/src/main/java/com/dask/sql/schema/DaskTable.java +++ b/planner/src/main/java/com/dask/sql/schema/DaskTable.java @@ -3,16 +3,18 @@ import java.util.ArrayList; import java.util.List; -import org.apache.calcite.DataContext; import org.apache.calcite.config.CalciteConnectionConfig; -import org.apache.calcite.linq4j.Enumerable; +import org.apache.calcite.plan.RelOptTable; +import org.apache.calcite.plan.RelTraitSet; +import org.apache.calcite.plan.RelOptTable.ToRelContext; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.logical.LogicalTableScan; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; -import org.apache.calcite.rex.RexNode; -import org.apache.calcite.schema.ProjectableFilterableTable; import org.apache.calcite.schema.Schema; import org.apache.calcite.schema.Statistic; import org.apache.calcite.schema.Statistics; +import org.apache.calcite.schema.TranslatableTable; import org.apache.calcite.sql.SqlCall; import org.apache.calcite.sql.SqlNode; import org.apache.calcite.sql.type.SqlTypeName; @@ -23,16 +25,24 @@ * * Basically just a list of columns, each column being a column name and a type. 
*/ -public class DaskTable implements ProjectableFilterableTable { +public class DaskTable implements TranslatableTable { // List of columns (name, column type) private final ArrayList> tableColumns; // Name of this table private final String name; + // Any statistics information we have + private final DaskStatistics statistics; - /// Construct a new table with the given name - public DaskTable(final String name) { + /// Construct a new table with the given name and estimated row count + public DaskTable(final String name, final Double rowCount) { this.name = name; this.tableColumns = new ArrayList>(); + this.statistics = new DaskStatistics(rowCount); + } + + /// Construct a new table with the given name + public DaskTable(final String name) { + this(name, null); } /// Add a column with the given type @@ -61,7 +71,7 @@ public RelDataType getRowType(final RelDataTypeFactory relDataTypeFactory) { /// calcite method: statistics of this table (not implemented) @Override public Statistic getStatistic() { - return Statistics.UNKNOWN; + return this.statistics; } /// calcite method: the type -> it is a table @@ -83,12 +93,22 @@ public boolean rolledUpColumnValidInsideAgg(final String string, final SqlCall s throw new AssertionError("This should not be called!"); } - /** - * calcite method: normally, this would return the actual data - but we do not - * use the computation engine from calcite - */ @Override - public Enumerable scan(final DataContext root, final List filters, final int[] projects) { - return null; + public RelNode toRel(ToRelContext context, RelOptTable relOptTable) { + RelTraitSet traitSet = context.getCluster().traitSet(); + return new LogicalTableScan(context.getCluster(), traitSet, context.getTableHints(), relOptTable); + } + + private final class DaskStatistics implements Statistic { + private final Double rowCount; + + public DaskStatistics(final Double rowCount) { + this.rowCount = rowCount; + } + + @Override + public Double getRowCount() { + return this.rowCount; + } } } diff --git a/tests/integration/test_filter.py b/tests/integration/test_filter.py index b668d1c51..9168a217c 100644 --- a/tests/integration/test_filter.py +++ b/tests/integration/test_filter.py @@ -53,7 +53,7 @@ def test_filter_with_nan(c): return_df = return_df.compute() if INT_NAN_IMPLEMENTED: - expected_df = pd.DataFrame({"c": [3]}, dtype="int8") + expected_df = pd.DataFrame({"c": [3]}, dtype="Int8") else: expected_df = pd.DataFrame({"c": [3]}, dtype="float") assert_frame_equal( diff --git a/tests/integration/test_rex.py b/tests/integration/test_rex.py index dea6d5604..106c6edab 100644 --- a/tests/integration/test_rex.py +++ b/tests/integration/test_rex.py @@ -538,7 +538,8 @@ def test_date_functions(c): TIMESTAMPADD(HOUR, 1, d) as "plus_1_hour", TIMESTAMPADD(MINUTE, 1, d) as "plus_1_min", TIMESTAMPADD(SECOND, 1, d) as "plus_1_sec", - TIMESTAMPADD(MICROSECOND, 1000, d) as "plus_1000_millisec", + TIMESTAMPADD(MICROSECOND, 999*1000, d) as "plus_999_millisec", + TIMESTAMPADD(MICROSECOND, 999, d) as "plus_999_microsec", TIMESTAMPADD(QUARTER, 1, d) as "plus_1_qt", CEIL(d TO DAY) as ceil_to_day, @@ -582,7 +583,8 @@ def test_date_functions(c): "plus_1_hour": [datetime(2021, 10, 3, 16, 53, 42, 47)], "plus_1_min": [datetime(2021, 10, 3, 15, 54, 42, 47)], "plus_1_sec": [datetime(2021, 10, 3, 15, 53, 43, 47)], - "plus_1000_millisec": [datetime(2021, 10, 3, 15, 53, 42, 1047)], + "plus_999_millisec": [datetime(2021, 10, 3, 15, 53, 42, 1000 * 999 + 47)], + "plus_999_microsec": [datetime(2021, 10, 3, 15, 53, 42, 1046)], 
"plus_1_qt": [datetime(2022, 1, 3, 15, 53, 42, 47)], "ceil_to_day": [datetime(2021, 10, 4)], "ceil_to_hour": [datetime(2021, 10, 3, 16)], diff --git a/tests/integration/test_sort.py b/tests/integration/test_sort.py index 6d03220f6..5825e3567 100644 --- a/tests/integration/test_sort.py +++ b/tests/integration/test_sort.py @@ -364,6 +364,11 @@ def test_limit(c, long_table): assert_frame_equal(df_result, long_table.iloc[101 : 101 + 101]) + df_result = c.sql("SELECT * FROM long_table OFFSET 101") + df_result = df_result.compute() + + assert_frame_equal(df_result, long_table.iloc[101:]) + @pytest.mark.gpu def test_sort_gpu(c, gpu_user_table_1, gpu_df): diff --git a/tests/unit/test_context.py b/tests/unit/test_context.py index cbbc09e0c..a9b0d3fe6 100644 --- a/tests/unit/test_context.py +++ b/tests/unit/test_context.py @@ -1,4 +1,3 @@ -import os import warnings import dask.dataframe as dd @@ -6,6 +5,7 @@ import pytest from dask_sql import Context +from dask_sql.datacontainer import Statistics try: import cudf @@ -62,9 +62,16 @@ def test_explain(gpu): sql_string = c.explain("SELECT * FROM df") - assert ( - sql_string - == f"LogicalProject(a=[$0]){os.linesep} LogicalTableScan(table=[[root, df]]){os.linesep}" + assert sql_string.startswith( + "DaskTableScan(table=[[root, df]]): rowcount = 100.0, cumulative cost = {100.0 rows, 101.0 cpu, 0.0 io}, id = " + ) + + c.create_table("df", data_frame, statistics=Statistics(row_count=1337)) + + sql_string = c.explain("SELECT * FROM df") + + assert sql_string.startswith( + "DaskTableScan(table=[[root, df]]): rowcount = 1337.0, cumulative cost = {1337.0 rows, 1338.0 cpu, 0.0 io}, id = " ) c = Context() @@ -78,9 +85,8 @@ def test_explain(gpu): "SELECT * FROM other_df", dataframes={"other_df": data_frame} ) - assert ( - sql_string - == f"LogicalProject(a=[$0]){os.linesep} LogicalTableScan(table=[[root, other_df]]){os.linesep}" + assert sql_string.startswith( + "DaskTableScan(table=[[root, other_df]]): rowcount = 100.0, cumulative cost = {100.0 rows, 101.0 cpu, 0.0 io}, id = " )