Skip to content
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
6e45ea7
Refactoring of the planner logic
nils-braun Aug 24, 2021
429acc5
Adjust the python implementations to follow the new conventions
nils-braun Aug 24, 2021
51e267b
Fix a bug in the column name creation and in the window node. While d…
nils-braun Aug 24, 2021
07a539f
Implement the missing sample node
nils-braun Aug 25, 2021
35622da
Calcite is now creating a new function: /int. Implement it
nils-braun Aug 25, 2021
e3dce64
Additional optimization rules (probably not needed)
nils-braun Aug 25, 2021
49d256e
Expanding will lead to RexSubQueries, which are harder to handle and n…
nils-braun Aug 25, 2021
cefcebe
Documentation
nils-braun Aug 25, 2021
417c26f
Allow the user to pass in row-count statistics on the tables, which a…
nils-braun Aug 25, 2021
672d063
Merge remote-tracking branch 'origin/main' into feature/cost-based-op…
nils-braun Aug 25, 2021
73ee12f
List.of is a Java 11 feature -> replace with Arrays.asList to stay ja…
nils-braun Aug 26, 2021
4e27798
Bring the coverage up again
nils-braun Aug 26, 2021
63abd27
Merge remote-tracking branch 'upstream/main' into feature/cost-based-…
charlesbluca Oct 12, 2021
395d7fa
Merge remote-tracking branch 'upstream/main' into feature/cost-based-…
charlesbluca Nov 22, 2021
20e7891
Run pre-commit hooks
charlesbluca Nov 22, 2021
cca757d
Remove formatted brackets causing explain failures
charlesbluca Nov 29, 2021
266b26f
Merge remote-tracking branch 'upstream/main' into feature/cost-based-…
charlesbluca Dec 13, 2021
8f94429
apply changes from Jeremy
galipremsagar Dec 21, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions dask_sql/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from ._version import get_version
from .cmd import cmd_loop
from .context import Context
from .datacontainer import Statistics
from .server.app import run_server

__version__ = get_version()

# ``__all__`` must contain the *names* of the public objects as strings.
# Listing the objects themselves makes ``from dask_sql import *`` fail,
# because the import machinery looks every entry up by name.
__all__ = ["__version__", "cmd_loop", "Context", "run_server", "Statistics"]
76 changes: 47 additions & 29 deletions dask_sql/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import inspect
import logging
import warnings
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple, Union
from typing import Any, Callable, Dict, List, Tuple, Union

import dask.dataframe as dd
import pandas as pd
Expand All @@ -20,28 +20,16 @@
DataContainer,
FunctionDescription,
SchemaContainer,
Statistics,
)
from dask_sql.input_utils import InputType, InputUtil
from dask_sql.integrations.ipython import ipython_integration
from dask_sql.java import (
DaskAggregateFunction,
DaskScalarFunction,
DaskSchema,
DaskTable,
RelationalAlgebraGenerator,
RelationalAlgebraGeneratorBuilder,
SqlParseException,
ValidationException,
get_java_class,
)
from dask_sql.java import com, get_java_class, org
from dask_sql.mappings import python_to_sql_type
from dask_sql.physical.rel import RelConverter, custom, logical
from dask_sql.physical.rex import RexConverter, core
from dask_sql.utils import ParsingException

if TYPE_CHECKING:
from dask_sql.java import org

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -90,15 +78,16 @@ def __init__(self):
self.sql_server = None

# Register any default plugins, if nothing was registered before.
RelConverter.add_plugin_class(logical.LogicalAggregatePlugin, replace=False)
RelConverter.add_plugin_class(logical.LogicalFilterPlugin, replace=False)
RelConverter.add_plugin_class(logical.LogicalJoinPlugin, replace=False)
RelConverter.add_plugin_class(logical.LogicalProjectPlugin, replace=False)
RelConverter.add_plugin_class(logical.LogicalSortPlugin, replace=False)
RelConverter.add_plugin_class(logical.LogicalTableScanPlugin, replace=False)
RelConverter.add_plugin_class(logical.LogicalUnionPlugin, replace=False)
RelConverter.add_plugin_class(logical.LogicalValuesPlugin, replace=False)
RelConverter.add_plugin_class(logical.LogicalWindowPlugin, replace=False)
RelConverter.add_plugin_class(logical.DaskAggregatePlugin, replace=False)
RelConverter.add_plugin_class(logical.DaskFilterPlugin, replace=False)
RelConverter.add_plugin_class(logical.DaskJoinPlugin, replace=False)
RelConverter.add_plugin_class(logical.DaskLimitPlugin, replace=False)
RelConverter.add_plugin_class(logical.DaskProjectPlugin, replace=False)
RelConverter.add_plugin_class(logical.DaskSortPlugin, replace=False)
RelConverter.add_plugin_class(logical.DaskTableScanPlugin, replace=False)
RelConverter.add_plugin_class(logical.DaskUnionPlugin, replace=False)
RelConverter.add_plugin_class(logical.DaskValuesPlugin, replace=False)
RelConverter.add_plugin_class(logical.DaskWindowPlugin, replace=False)
RelConverter.add_plugin_class(logical.SamplePlugin, replace=False)
RelConverter.add_plugin_class(custom.AnalyzeTablePlugin, replace=False)
RelConverter.add_plugin_class(custom.CreateExperimentPlugin, replace=False)
Expand Down Expand Up @@ -138,6 +127,7 @@ def create_table(
format: str = None,
persist: bool = False,
schema_name: str = None,
statistics: Statistics = None,
gpu: bool = False,
**kwargs,
):
Expand Down Expand Up @@ -200,6 +190,11 @@ def create_table(
If set to "memory", load the data from a published dataset in the dask cluster.
persist (:obj:`bool`): Only used when passing a string into the ``input`` parameter.
Set to true to turn on loading the file data directly into memory.
schema_name: (:obj:`str`): in which schema to create the table. By default, will use the currently selected schema.
statistics: (:obj:`Statistics`): if given, use these statistics during the cost-based optimization. If no
statistics are provided, we will just assume 100 rows.
gpu: (:obj:`bool`): if set to true, use dask-cudf to run the data frame calculations on your GPU.
Please note that the GPU support is currently not covering all of dask-sql's SQL language.
**kwargs: Additional arguments for specific formats. See :ref:`data_input` for more information.

"""
Expand All @@ -218,6 +213,8 @@ def create_table(
**kwargs,
)
self.schema[schema_name].tables[table_name.lower()] = dc
if statistics:
self.schema[schema_name].statistics[table_name.lower()] = statistics

def register_dask_table(self, df: dd.DataFrame, name: str, *args, **kwargs):
"""
Expand Down Expand Up @@ -676,14 +673,26 @@ def _prepare_schemas(self):
"""
schema_list = []

DaskTable = com.dask.sql.schema.DaskTable
DaskAggregateFunction = com.dask.sql.schema.DaskAggregateFunction
DaskScalarFunction = com.dask.sql.schema.DaskScalarFunction
DaskSchema = com.dask.sql.schema.DaskSchema

for schema_name, schema in self.schema.items():
java_schema = DaskSchema(schema_name)

if not schema.tables:
logger.warning("No tables are registered.")

for name, dc in schema.tables.items():
table = DaskTable(name)
row_count = (
schema.statistics[name].row_count
if name in schema.statistics
else None
)
if row_count is not None:
row_count = float(row_count)
table = DaskTable(name, row_count)
df = dc.df
logger.debug(
f"Adding table '{name}' to schema with columns: {list(df.columns)}"
Expand Down Expand Up @@ -735,6 +744,10 @@ def _get_ral(self, sql):
# get the schema of what we currently have registered
schemas = self._prepare_schemas()

RelationalAlgebraGeneratorBuilder = (
com.dask.sql.application.RelationalAlgebraGeneratorBuilder
)

# Now create a relational algebra from that
generator_builder = RelationalAlgebraGeneratorBuilder(self.schema_name)
for schema in schemas:
Expand All @@ -744,6 +757,10 @@ def _get_ral(self, sql):

logger.debug(f"Using dialect: {get_java_class(default_dialect)}")

ValidationException = org.apache.calcite.tools.ValidationException
SqlParseException = org.apache.calcite.sql.parser.SqlParseException
CalciteContextException = org.apache.calcite.runtime.CalciteContextException

try:
sqlNode = generator.getSqlNode(sql)
sqlNodeClass = get_java_class(sqlNode)
Expand All @@ -753,16 +770,15 @@ def _get_ral(self, sql):
rel_string = ""

if not sqlNodeClass.startswith("com.dask.sql.parser."):
validatedSqlNode = generator.getValidatedNode(sqlNode)
nonOptimizedRelNode = generator.getRelationalAlgebra(validatedSqlNode)
nonOptimizedRelNode = generator.getRelationalAlgebra(sqlNode)
# Optimization might remove some alias projects. Make sure to keep them here.
select_names = [
str(name)
for name in nonOptimizedRelNode.getRowType().getFieldNames()
]
rel = generator.getOptimizedRelationalAlgebra(nonOptimizedRelNode)
rel_string = str(generator.getRelationalAlgebraString(rel))
except (ValidationException, SqlParseException) as e:
except (ValidationException, SqlParseException, CalciteContextException) as e:
logger.debug(f"Original exception raised by Java:\n {e}")
# We do not want to re-raise an exception here
# as this would print the full java stack trace
Expand Down Expand Up @@ -798,7 +814,9 @@ def _get_ral(self, sql):

def _to_sql_string(self, s: "org.apache.calcite.sql.SqlNode", default_dialect=None):
if default_dialect is None:
default_dialect = RelationalAlgebraGenerator.getDialect()
default_dialect = (
com.dask.sql.application.RelationalAlgebraGenerator.getDialect()
)

try:
return str(s.toSqlString(default_dialect))
Expand Down
12 changes: 12 additions & 0 deletions dask_sql/datacontainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,10 +231,22 @@ def __hash__(self):
return (self.func, self.row_udf).__hash__()


class Statistics:
    """
    Statistics are used during the cost-based optimization.
    Currently, only the row count is supported, more
    properties might follow. It needs to be provided by the user.

    Args:
        row_count (:obj:`int`): number of rows of the table these
            statistics describe.
    """

    def __init__(self, row_count: int) -> None:
        # Stored as given; no validation or conversion is applied here.
        self.row_count = row_count

    def __repr__(self) -> str:
        # Readable representation so statistics show up meaningfully
        # in logs and debugging sessions instead of a bare object address.
        return f"{type(self).__name__}(row_count={self.row_count!r})"


class SchemaContainer:
def __init__(self, name: str):
self.__name__ = name
self.tables: Dict[str, DataContainer] = {}
self.statistics: Dict[str, Statistics] = {}
self.experiments: Dict[str, pd.DataFrame] = {}
self.models: Dict[str, Tuple[Any, List[str]]] = {}
self.functions: Dict[str, UDF] = {}
Expand Down
12 changes: 0 additions & 12 deletions dask_sql/java.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,18 +85,6 @@ def _set_or_check_java_home():
org = jpype.JPackage("org")
java = jpype.JPackage("java")

DaskTable = com.dask.sql.schema.DaskTable
DaskAggregateFunction = com.dask.sql.schema.DaskAggregateFunction
DaskScalarFunction = com.dask.sql.schema.DaskScalarFunction
DaskSchema = com.dask.sql.schema.DaskSchema
RelationalAlgebraGenerator = com.dask.sql.application.RelationalAlgebraGenerator
RelationalAlgebraGeneratorBuilder = (
com.dask.sql.application.RelationalAlgebraGeneratorBuilder
)
SqlTypeName = org.apache.calcite.sql.type.SqlTypeName
ValidationException = org.apache.calcite.tools.ValidationException
SqlParseException = org.apache.calcite.sql.parser.SqlParseException


def get_java_class(instance):
"""Get the stringified class name of a java object"""
Expand Down
4 changes: 2 additions & 2 deletions dask_sql/mappings.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
import pandas as pd

from dask_sql._compat import FLOAT_NAN_IMPLEMENTED
from dask_sql.java import SqlTypeName
from dask_sql.java import org

logger = logging.getLogger(__name__)

SqlTypeName = org.apache.calcite.sql.type.SqlTypeName

# Default mapping between python types and SQL types
_PYTHON_TO_SQL = {
Expand Down
2 changes: 2 additions & 0 deletions dask_sql/physical/rel/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ def fix_column_to_row_type(
"""
field_names = [str(x) for x in row_type.getFieldNames()]

logger.debug(f"Renaming {cc.columns} to {field_names}")

cc = cc.rename(columns=dict(zip(cc.columns, field_names)))

# TODO: We can also check for the types here and do any conversions if needed
Expand Down
2 changes: 1 addition & 1 deletion dask_sql/physical/rel/custom/create_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
import pandas as pd

from dask_sql.datacontainer import ColumnContainer, DataContainer
from dask_sql.java import org
from dask_sql.physical.rel.base import BaseRelPlugin
from dask_sql.utils import convert_sql_kwargs, import_class

if TYPE_CHECKING:
import dask_sql
from dask_sql.java import org

logger = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion dask_sql/physical/rel/custom/create_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
from typing import TYPE_CHECKING

from dask_sql.datacontainer import DataContainer
from dask_sql.java import org
from dask_sql.physical.rel.base import BaseRelPlugin
from dask_sql.utils import convert_sql_kwargs, import_class

if TYPE_CHECKING:
import dask_sql
from dask_sql.java import org

logger = logging.getLogger(__name__)

Expand Down
38 changes: 20 additions & 18 deletions dask_sql/physical/rel/logical/__init__.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,25 @@
from .aggregate import LogicalAggregatePlugin
from .filter import LogicalFilterPlugin
from .join import LogicalJoinPlugin
from .project import LogicalProjectPlugin
from .aggregate import DaskAggregatePlugin
from .filter import DaskFilterPlugin
from .join import DaskJoinPlugin
from .limit import DaskLimitPlugin
from .project import DaskProjectPlugin
from .sample import SamplePlugin
from .sort import LogicalSortPlugin
from .table_scan import LogicalTableScanPlugin
from .union import LogicalUnionPlugin
from .values import LogicalValuesPlugin
from .window import LogicalWindowPlugin
from .sort import DaskSortPlugin
from .table_scan import DaskTableScanPlugin
from .union import DaskUnionPlugin
from .values import DaskValuesPlugin
from .window import DaskWindowPlugin

# ``__all__`` must list the public names as strings, not the class objects;
# otherwise a star-import of this package raises instead of exporting them.
__all__ = [
    "LogicalAggregatePlugin",
    "LogicalFilterPlugin",
    "LogicalJoinPlugin",
    "LogicalProjectPlugin",
    "LogicalSortPlugin",
    "LogicalTableScanPlugin",
    "LogicalUnionPlugin",
    "LogicalValuesPlugin",
    "LogicalWindowPlugin",
    "DaskAggregatePlugin",
    "DaskFilterPlugin",
    "DaskJoinPlugin",
    "DaskLimitPlugin",
    "DaskProjectPlugin",
    "DaskSortPlugin",
    "DaskTableScanPlugin",
    "DaskUnionPlugin",
    "DaskValuesPlugin",
    "DaskWindowPlugin",
    "SamplePlugin",
]
6 changes: 3 additions & 3 deletions dask_sql/physical/rel/logical/aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,9 @@ def get_supported_aggregation(self, series):
return self.custom_aggregation


class LogicalAggregatePlugin(BaseRelPlugin):
class DaskAggregatePlugin(BaseRelPlugin):
"""
A LogicalAggregate is used in GROUP BY clauses, but also
A DaskAggregate is used in GROUP BY clauses, but also
when aggregating a function over the full dataset.

In the first case we need to find out which columns we need to
Expand All @@ -119,7 +119,7 @@ class LogicalAggregatePlugin(BaseRelPlugin):
these things via HINTs.
"""

class_name = "org.apache.calcite.rel.logical.LogicalAggregate"
class_name = "com.dask.sql.nodes.DaskAggregate"

AGGREGATION_MAPPING = {
"$sum0": AggregationSpecification("sum", AggregationOnPandas("sum")),
Expand Down
6 changes: 3 additions & 3 deletions dask_sql/physical/rel/logical/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,13 @@ def filter_or_scalar(df: dd.DataFrame, filter_condition: Union[np.bool_, dd.Seri
return df[filter_condition]


class LogicalFilterPlugin(BaseRelPlugin):
class DaskFilterPlugin(BaseRelPlugin):
"""
LogicalFilter is used on WHERE clauses.
DaskFilter is used on WHERE clauses.
We just evaluate the filter (which is of type RexNode) and apply it
"""

class_name = "org.apache.calcite.rel.logical.LogicalFilter"
class_name = "com.dask.sql.nodes.DaskFilter"

def convert(
self, rel: "org.apache.calcite.rel.RelNode", context: "dask_sql.Context"
Expand Down
6 changes: 3 additions & 3 deletions dask_sql/physical/rel/logical/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
logger = logging.getLogger(__name__)


class LogicalJoinPlugin(BaseRelPlugin):
class DaskJoinPlugin(BaseRelPlugin):
"""
A LogicalJoin is used when (surprise) joining two tables.
A DaskJoin is used when (surprise) joining two tables.
SQL allows for quite complicated joins with difficult conditions.
dask/pandas only knows about equijoins on a specific column.

Expand All @@ -36,7 +36,7 @@ class LogicalJoinPlugin(BaseRelPlugin):
but so far, it is the only solution...
"""

class_name = "org.apache.calcite.rel.logical.LogicalJoin"
class_name = "com.dask.sql.nodes.DaskJoin"

JOIN_TYPE_MAPPING = {
"INNER": "inner",
Expand Down
Loading