This repository was archived by the owner on Aug 29, 2024. It is now read-only.
Merged
3 changes: 3 additions & 0 deletions dask_sql/__init__.py
@@ -1,6 +1,9 @@
from ._version import get_version
from .cmd import cmd_loop
from .context import Context
from .datacontainer import Statistics
from .server.app import run_server

__version__ = get_version()

__all__ = [__version__, cmd_loop, Context, run_server, Statistics]
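
With `Statistics` exported at the package top level, the new parameter can be passed straight to `create_table`. A minimal usage sketch based on the signature change below; the table name and row count are made up for illustration:

import dask.dataframe as dd
import pandas as pd

from dask_sql import Context, Statistics

c = Context()
df = dd.from_pandas(pd.DataFrame({"a": [1, 2, 3]}), npartitions=1)

# The row count feeds the cost-based optimizer; tables registered
# without statistics are assumed to have 100 rows.
c.create_table("my_table", df, statistics=Statistics(row_count=1_000_000))
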
78 changes: 48 additions & 30 deletions dask_sql/context.py
@@ -2,7 +2,7 @@
import inspect
import logging
import warnings
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple, Union
from typing import Any, Callable, Dict, List, Tuple, Union

import dask.dataframe as dd
import pandas as pd
@@ -20,28 +20,16 @@
DataContainer,
FunctionDescription,
SchemaContainer,
Statistics,
)
from dask_sql.input_utils import InputType, InputUtil
from dask_sql.integrations.ipython import ipython_integration
from dask_sql.java import (
DaskAggregateFunction,
DaskScalarFunction,
DaskSchema,
DaskTable,
RelationalAlgebraGenerator,
RelationalAlgebraGeneratorBuilder,
SqlParseException,
ValidationException,
get_java_class,
)
from dask_sql.java import com, get_java_class, org
from dask_sql.mappings import python_to_sql_type
from dask_sql.physical.rel import RelConverter, custom, logical
from dask_sql.physical.rex import RexConverter, core
from dask_sql.utils import ParsingException

if TYPE_CHECKING:
from dask_sql.java import org

logger = logging.getLogger(__name__)


@@ -90,15 +78,16 @@ def __init__(self):
self.sql_server = None

# Register any default plugins, if nothing was registered before.
RelConverter.add_plugin_class(logical.LogicalAggregatePlugin, replace=False)
RelConverter.add_plugin_class(logical.LogicalFilterPlugin, replace=False)
RelConverter.add_plugin_class(logical.LogicalJoinPlugin, replace=False)
RelConverter.add_plugin_class(logical.LogicalProjectPlugin, replace=False)
RelConverter.add_plugin_class(logical.LogicalSortPlugin, replace=False)
RelConverter.add_plugin_class(logical.LogicalTableScanPlugin, replace=False)
RelConverter.add_plugin_class(logical.LogicalUnionPlugin, replace=False)
RelConverter.add_plugin_class(logical.LogicalValuesPlugin, replace=False)
RelConverter.add_plugin_class(logical.LogicalWindowPlugin, replace=False)
RelConverter.add_plugin_class(logical.DaskAggregatePlugin, replace=False)
RelConverter.add_plugin_class(logical.DaskFilterPlugin, replace=False)
RelConverter.add_plugin_class(logical.DaskJoinPlugin, replace=False)
RelConverter.add_plugin_class(logical.DaskLimitPlugin, replace=False)
RelConverter.add_plugin_class(logical.DaskProjectPlugin, replace=False)
RelConverter.add_plugin_class(logical.DaskSortPlugin, replace=False)
RelConverter.add_plugin_class(logical.DaskTableScanPlugin, replace=False)
RelConverter.add_plugin_class(logical.DaskUnionPlugin, replace=False)
RelConverter.add_plugin_class(logical.DaskValuesPlugin, replace=False)
RelConverter.add_plugin_class(logical.DaskWindowPlugin, replace=False)
RelConverter.add_plugin_class(logical.SamplePlugin, replace=False)
RelConverter.add_plugin_class(custom.AnalyzeTablePlugin, replace=False)
RelConverter.add_plugin_class(custom.CreateExperimentPlugin, replace=False)
@@ -140,6 +129,7 @@ def create_table(
format: str = None,
persist: bool = False,
schema_name: str = None,
statistics: Statistics = None,
gpu: bool = False,
**kwargs,
):
@@ -202,6 +192,11 @@
If set to "memory", load the data from a published dataset in the dask cluster.
persist (:obj:`bool`): Only used when passing a string into the ``input`` parameter.
Set to true to turn on loading the file data directly into memory.
schema_name (:obj:`str`): The schema in which to create the table. Defaults to the currently selected schema.
statistics (:obj:`Statistics`): If given, use these statistics during the cost-based optimization. If no
statistics are provided, a row count of 100 is assumed.
gpu (:obj:`bool`): If set to True, use dask-cudf to run the dataframe calculations on your GPU.
Please note that GPU support does not yet cover all of dask-sql's SQL language.
**kwargs: Additional arguments for specific formats. See :ref:`data_input` for more information.

"""
@@ -220,6 +215,8 @@
**kwargs,
)
self.schema[schema_name].tables[table_name.lower()] = dc
if statistics:
self.schema[schema_name].statistics[table_name.lower()] = statistics

def register_dask_table(self, df: dd.DataFrame, name: str, *args, **kwargs):
"""
@@ -765,14 +762,26 @@ def _prepare_schemas(self):
"""
schema_list = []

DaskTable = com.dask.sql.schema.DaskTable
DaskAggregateFunction = com.dask.sql.schema.DaskAggregateFunction
DaskScalarFunction = com.dask.sql.schema.DaskScalarFunction
DaskSchema = com.dask.sql.schema.DaskSchema

for schema_name, schema in self.schema.items():
java_schema = DaskSchema(schema_name)

if not schema.tables:
logger.warning("No tables are registered.")

for name, dc in schema.tables.items():
table = DaskTable(name)
row_count = (
schema.statistics[name].row_count
if name in schema.statistics
else None
)
if row_count is not None:
row_count = float(row_count)
table = DaskTable(name, row_count)
df = dc.df
logger.debug(
f"Adding table '{name}' to schema with columns: {list(df.columns)}"
@@ -824,6 +833,10 @@ def _get_ral(self, sql):
# get the schema of what we currently have registered
schemas = self._prepare_schemas()

RelationalAlgebraGeneratorBuilder = (
com.dask.sql.application.RelationalAlgebraGeneratorBuilder
)

# True if the SQL query should be case sensitive and False otherwise
case_sensitive = (
self.schema[self.schema_name]
@@ -835,12 +848,16 @@
self.schema_name, case_sensitive
)
for schema in schemas:
generator_builder.addSchema(schema)
generator_builder = generator_builder.addSchema(schema)
generator = generator_builder.build()
default_dialect = generator.getDialect()

logger.debug(f"Using dialect: {get_java_class(default_dialect)}")

ValidationException = org.apache.calcite.tools.ValidationException
SqlParseException = org.apache.calcite.sql.parser.SqlParseException
CalciteContextException = org.apache.calcite.runtime.CalciteContextException

try:
sqlNode = generator.getSqlNode(sql)
sqlNodeClass = get_java_class(sqlNode)
@@ -850,16 +867,15 @@
rel_string = ""

if not sqlNodeClass.startswith("com.dask.sql.parser."):
validatedSqlNode = generator.getValidatedNode(sqlNode)
nonOptimizedRelNode = generator.getRelationalAlgebra(validatedSqlNode)
nonOptimizedRelNode = generator.getRelationalAlgebra(sqlNode)
# Optimization might remove some alias projects. Make sure to keep them here.
select_names = [
str(name)
for name in nonOptimizedRelNode.getRowType().getFieldNames()
]
rel = generator.getOptimizedRelationalAlgebra(nonOptimizedRelNode)
rel_string = str(generator.getRelationalAlgebraString(rel))
except (ValidationException, SqlParseException) as e:
except (ValidationException, SqlParseException, CalciteContextException) as e:
logger.debug(f"Original exception raised by Java:\n {e}")
# We do not want to re-raise an exception here
# as this would print the full java stack trace
@@ -895,7 +911,9 @@

def _to_sql_string(self, s: "org.apache.calcite.sql.SqlNode", default_dialect=None):
if default_dialect is None:
default_dialect = RelationalAlgebraGenerator.getDialect()
default_dialect = (
com.dask.sql.application.RelationalAlgebraGenerator.getDialect()
)

try:
return str(s.toSqlString(default_dialect))
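
Note how the Java exceptions, now including CalciteContextException, are caught and translated rather than re-raised, so user code only ever sees a Python-level error. Judging from the `ParsingException` import above, invalid SQL presumably surfaces as that exception; a minimal sketch of catching it (the broken query is made up):

from dask_sql import Context
from dask_sql.utils import ParsingException

c = Context()
try:
    # Deliberately malformed SQL to hit the parse-error path.
    c.sql("SELEC 1")
except ParsingException as e:
    print(f"Query could not be parsed:\n{e}")
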
12 changes: 12 additions & 0 deletions dask_sql/datacontainer.py
@@ -231,10 +231,22 @@ def __hash__(self):
return (self.func, self.row_udf).__hash__()


class Statistics:
"""
Statistics are used during the cost-based optimization.
Currently, only the row count is supported; more
properties might follow. Statistics need to be provided by the user.
"""

def __init__(self, row_count: int) -> None:
self.row_count = row_count


class SchemaContainer:
def __init__(self, name: str):
self.__name__ = name
self.tables: Dict[str, DataContainer] = {}
self.statistics: Dict[str, Statistics] = {}
self.experiments: Dict[str, pd.DataFrame] = {}
self.models: Dict[str, Tuple[Any, List[str]]] = {}
self.functions: Dict[str, UDF] = {}
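
The bookkeeping itself is deliberately simple: one Statistics object per table, stored next to the table in its SchemaContainer. A minimal sketch of that data structure in isolation (the schema and table names are illustrative):

from dask_sql.datacontainer import SchemaContainer, Statistics

schema = SchemaContainer("my_schema")
schema.statistics["my_table"] = Statistics(row_count=500)

# _prepare_schemas later reads this back, falling back to None
# for tables without registered statistics.
print(schema.statistics["my_table"].row_count)  # 500
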
12 changes: 0 additions & 12 deletions dask_sql/java.py
@@ -85,18 +85,6 @@ def _set_or_check_java_home():
org = jpype.JPackage("org")
java = jpype.JPackage("java")

DaskTable = com.dask.sql.schema.DaskTable
DaskAggregateFunction = com.dask.sql.schema.DaskAggregateFunction
DaskScalarFunction = com.dask.sql.schema.DaskScalarFunction
DaskSchema = com.dask.sql.schema.DaskSchema
RelationalAlgebraGenerator = com.dask.sql.application.RelationalAlgebraGenerator
RelationalAlgebraGeneratorBuilder = (
com.dask.sql.application.RelationalAlgebraGeneratorBuilder
)
SqlTypeName = org.apache.calcite.sql.type.SqlTypeName
ValidationException = org.apache.calcite.tools.ValidationException
SqlParseException = org.apache.calcite.sql.parser.SqlParseException


def get_java_class(instance):
"""Get the stringified class name of a java object"""
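
The removed block above resolved every Java class eagerly at import time; after this change, callers walk the `com`/`org` JPackage proxies on demand instead. A rough sketch of the underlying jpype pattern, assuming a running JVM with the dask-sql jar on the classpath (the jar path is hypothetical; dask_sql/java.py does the real JVM bootstrapping):

import jpype

jpype.startJVM(classpath=["DaskSQL.jar"])  # hypothetical path

com = jpype.JPackage("com")

# Nothing is resolved until the attribute chain is walked, so merely
# importing the module no longer touches any Java class:
DaskTable = com.dask.sql.schema.DaskTable
table = DaskTable("my_table", 100.0)  # name and row count, per the diff
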
4 changes: 2 additions & 2 deletions dask_sql/mappings.py
@@ -8,10 +8,10 @@
import pandas as pd

from dask_sql._compat import FLOAT_NAN_IMPLEMENTED
from dask_sql.java import SqlTypeName
from dask_sql.java import org

logger = logging.getLogger(__name__)

SqlTypeName = org.apache.calcite.sql.type.SqlTypeName

# Default mapping between python types and SQL types
_PYTHON_TO_SQL = {
2 changes: 2 additions & 0 deletions dask_sql/physical/rel/base.py
@@ -42,6 +42,8 @@ def fix_column_to_row_type(
"""
field_names = [str(x) for x in row_type.getFieldNames()]

logger.debug(f"Renaming {cc.columns} to {field_names}")

cc = cc.rename(columns=dict(zip(cc.columns, field_names)))

# TODO: We can also check for the types here and do any conversions if needed
2 changes: 1 addition & 1 deletion dask_sql/physical/rel/custom/create_experiment.py
@@ -5,12 +5,12 @@
import pandas as pd

from dask_sql.datacontainer import ColumnContainer, DataContainer
from dask_sql.java import org
from dask_sql.physical.rel.base import BaseRelPlugin
from dask_sql.utils import convert_sql_kwargs, import_class

if TYPE_CHECKING:
import dask_sql
from dask_sql.java import org

logger = logging.getLogger(__name__)

2 changes: 1 addition & 1 deletion dask_sql/physical/rel/custom/create_model.py
@@ -4,12 +4,12 @@
from dask import delayed

from dask_sql.datacontainer import DataContainer
from dask_sql.java import org
from dask_sql.physical.rel.base import BaseRelPlugin
from dask_sql.utils import convert_sql_kwargs, import_class

if TYPE_CHECKING:
import dask_sql
from dask_sql.java import org

logger = logging.getLogger(__name__)

38 changes: 20 additions & 18 deletions dask_sql/physical/rel/logical/__init__.py
@@ -1,23 +1,25 @@
from .aggregate import LogicalAggregatePlugin
from .filter import LogicalFilterPlugin
from .join import LogicalJoinPlugin
from .project import LogicalProjectPlugin
from .aggregate import DaskAggregatePlugin
from .filter import DaskFilterPlugin
from .join import DaskJoinPlugin
from .limit import DaskLimitPlugin
from .project import DaskProjectPlugin
from .sample import SamplePlugin
from .sort import LogicalSortPlugin
from .table_scan import LogicalTableScanPlugin
from .union import LogicalUnionPlugin
from .values import LogicalValuesPlugin
from .window import LogicalWindowPlugin
from .sort import DaskSortPlugin
from .table_scan import DaskTableScanPlugin
from .union import DaskUnionPlugin
from .values import DaskValuesPlugin
from .window import DaskWindowPlugin

__all__ = [
LogicalAggregatePlugin,
LogicalFilterPlugin,
LogicalJoinPlugin,
LogicalProjectPlugin,
LogicalSortPlugin,
LogicalTableScanPlugin,
LogicalUnionPlugin,
LogicalValuesPlugin,
LogicalWindowPlugin,
DaskAggregatePlugin,
DaskFilterPlugin,
DaskJoinPlugin,
DaskLimitPlugin,
DaskProjectPlugin,
DaskSortPlugin,
DaskTableScanPlugin,
DaskUnionPlugin,
DaskValuesPlugin,
DaskWindowPlugin,
SamplePlugin,
]
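
The renamed plugins keep the old dispatch mechanism: RelConverter matches each relational node against the plugin's class_name, now a com.dask.sql.nodes.* class instead of org.apache.calcite.rel.logical.*. A hypothetical custom plugin following the same pattern (the node class and all names below are made up):

from dask_sql.physical.rel import RelConverter
from dask_sql.physical.rel.base import BaseRelPlugin


class MyNodePlugin(BaseRelPlugin):
    # Dispatch key: the fully qualified Java class of the RelNode.
    class_name = "com.dask.sql.nodes.MyNode"  # hypothetical

    def convert(self, rel, context):
        # Would translate the Java node into a dask DataContainer here.
        raise NotImplementedError


# replace=False keeps any plugin already registered for this class name.
RelConverter.add_plugin_class(MyNodePlugin, replace=False)
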
6 changes: 3 additions & 3 deletions dask_sql/physical/rel/logical/aggregate.py
@@ -95,9 +95,9 @@ def get_supported_aggregation(self, series):
return self.custom_aggregation


class LogicalAggregatePlugin(BaseRelPlugin):
class DaskAggregatePlugin(BaseRelPlugin):
"""
A LogicalAggregate is used in GROUP BY clauses, but also
A DaskAggregate is used in GROUP BY clauses, but also
when aggregating a function over the full dataset.

In the first case we need to find out which columns we need to
Expand All @@ -119,7 +119,7 @@ class LogicalAggregatePlugin(BaseRelPlugin):
these things via HINTs.
"""

class_name = "org.apache.calcite.rel.logical.LogicalAggregate"
class_name = "com.dask.sql.nodes.DaskAggregate"

AGGREGATION_MAPPING = {
"$sum0": AggregationSpecification("sum", AggregationOnPandas("sum")),
6 changes: 3 additions & 3 deletions dask_sql/physical/rel/logical/filter.py
@@ -34,13 +34,13 @@ def filter_or_scalar(df: dd.DataFrame, filter_condition: Union[np.bool_, dd.Seri
return df[filter_condition]


class LogicalFilterPlugin(BaseRelPlugin):
class DaskFilterPlugin(BaseRelPlugin):
"""
LogicalFilter is used on WHERE clauses.
DaskFilter is used on WHERE clauses.
We just evaluate the filter (which is of type RexNode) and apply it
"""

class_name = "org.apache.calcite.rel.logical.LogicalFilter"
class_name = "com.dask.sql.nodes.DaskFilter"

def convert(
self, rel: "org.apache.calcite.rel.RelNode", context: "dask_sql.Context"
6 changes: 3 additions & 3 deletions dask_sql/physical/rel/logical/join.py
@@ -20,9 +20,9 @@
logger = logging.getLogger(__name__)


class LogicalJoinPlugin(BaseRelPlugin):
class DaskJoinPlugin(BaseRelPlugin):
"""
A LogicalJoin is used when (surprise) joining two tables.
A DaskJoin is used when (surprise) joining two tables.
SQL allows for quite complicated joins with difficult conditions.
dask/pandas only knows about equijoins on a specific column.

Expand All @@ -36,7 +36,7 @@ class LogicalJoinPlugin(BaseRelPlugin):
but so far, it is the only solution...
"""

class_name = "org.apache.calcite.rel.logical.LogicalJoin"
class_name = "com.dask.sql.nodes.DaskJoin"

JOIN_TYPE_MAPPING = {
"INNER": "inner",