diff --git a/dask_planner/src/sql/logical.rs b/dask_planner/src/sql/logical.rs index 9f4505807..9f499a565 100644 --- a/dask_planner/src/sql/logical.rs +++ b/dask_planner/src/sql/logical.rs @@ -3,6 +3,7 @@ use crate::sql::types::rel_data_type::RelDataType; use crate::sql::types::rel_data_type_field::RelDataTypeField; mod aggregate; +mod cross_join; mod explain; mod filter; mod join; @@ -63,6 +64,11 @@ impl PyLogicalPlan { to_py_plan(self.current_node.as_ref()) } + /// LogicalPlan::CrossJoin as PyCrossJoin + pub fn cross_join(&self) -> PyResult<PyCrossJoin> { + to_py_plan(self.current_node.as_ref()) + } + /// LogicalPlan::Explain as PyExplain pub fn explain(&self) -> PyResult<PyExplain> { to_py_plan(self.current_node.as_ref()) } @@ -85,12 +91,7 @@ impl PyLogicalPlan { /// LogicalPlan::Sort as PySort pub fn sort(&self) -> PyResult<PySort> { - self.current_node - .as_ref() - .map(|plan| plan.clone().into()) - .ok_or(PyErr::new::<PyRuntimeError, _>( - "current_node was None", - )) + to_py_plan(self.current_node.as_ref()) } /// Gets the "input" for the current LogicalPlan diff --git a/dask_planner/src/sql/logical/cross_join.rs b/dask_planner/src/sql/logical/cross_join.rs new file mode 100644 index 000000000..e04f2ec19 --- /dev/null +++ b/dask_planner/src/sql/logical/cross_join.rs @@ -0,0 +1,20 @@ +use datafusion::logical_plan::{CrossJoin, LogicalPlan}; + +use pyo3::prelude::*; + +#[pyclass(name = "CrossJoin", module = "dask_planner", subclass)] +#[derive(Clone)] +pub struct PyCrossJoin { + cross_join: CrossJoin, +} + +impl From<LogicalPlan> for PyCrossJoin { + fn from(logical_plan: LogicalPlan) -> PyCrossJoin { + match logical_plan { + LogicalPlan::CrossJoin(cross_join) => PyCrossJoin { + cross_join: cross_join, + }, + _ => panic!("something went wrong here"), + } + } +} diff --git a/dask_planner/src/sql/logical/join.rs b/dask_planner/src/sql/logical/join.rs index 08808dd11..f2d47e809 100644 --- a/dask_planner/src/sql/logical/join.rs +++ b/dask_planner/src/sql/logical/join.rs @@ -1,7 +1,7 @@ use crate::sql::column; use 
datafusion::logical_expr::logical_plan::Join; -pub use datafusion::logical_expr::{logical_plan::JoinType, LogicalPlan}; +use datafusion::logical_plan::{JoinType, LogicalPlan}; use pyo3::prelude::*; @@ -51,7 +51,7 @@ impl PyJoin { impl From<LogicalPlan> for PyJoin { fn from(logical_plan: LogicalPlan) -> PyJoin { match logical_plan { - LogicalPlan::Join(join) => PyJoin { join }, + LogicalPlan::Join(join) => PyJoin { join: join }, _ => panic!("something went wrong here"), } } diff --git a/dask_sql/context.py index ccb4a9a72..55f3e7f04 100644 --- a/dask_sql/context.py +++ b/dask_sql/context.py @@ -97,6 +97,7 @@ def __init__(self, logging_level=logging.INFO): RelConverter.add_plugin_class(logical.DaskAggregatePlugin, replace=False) RelConverter.add_plugin_class(logical.DaskFilterPlugin, replace=False) RelConverter.add_plugin_class(logical.DaskJoinPlugin, replace=False) + RelConverter.add_plugin_class(logical.DaskCrossJoinPlugin, replace=False) RelConverter.add_plugin_class(logical.DaskLimitPlugin, replace=False) RelConverter.add_plugin_class(logical.DaskProjectPlugin, replace=False) RelConverter.add_plugin_class(logical.DaskSortPlugin, replace=False) diff --git a/dask_sql/mappings.py index 47d8624da..3e1c895bf 100644 --- a/dask_sql/mappings.py +++ b/dask_sql/mappings.py @@ -40,7 +40,9 @@ } if FLOAT_NAN_IMPLEMENTED: # pragma: no cover - _PYTHON_TO_SQL.update({pd.Float32Dtype(): "FLOAT", pd.Float64Dtype(): "FLOAT"}) + _PYTHON_TO_SQL.update( + {pd.Float32Dtype(): SqlTypeName.FLOAT, pd.Float64Dtype(): SqlTypeName.DOUBLE} + ) # Default mapping between SQL types and python types # for values diff --git a/dask_sql/physical/rel/logical/__init__.py index ba835a65d..10fbb2fba 100644 --- a/dask_sql/physical/rel/logical/__init__.py +++ b/dask_sql/physical/rel/logical/__init__.py @@ -1,4 +1,5 @@ from .aggregate import DaskAggregatePlugin +from .cross_join import DaskCrossJoinPlugin from .explain import 
ExplainPlugin from .filter import DaskFilterPlugin from .join import DaskJoinPlugin @@ -15,6 +16,7 @@ DaskAggregatePlugin, DaskFilterPlugin, DaskJoinPlugin, + DaskCrossJoinPlugin, DaskLimitPlugin, DaskProjectPlugin, DaskSortPlugin, diff --git a/dask_sql/physical/rel/logical/cross_join.py new file mode 100644 index 000000000..a5c9cd984 --- /dev/null +++ b/dask_sql/physical/rel/logical/cross_join.py @@ -0,0 +1,41 @@ +import logging +from typing import TYPE_CHECKING + +import dask.dataframe as dd + +import dask_sql.utils as utils +from dask_sql.datacontainer import ColumnContainer, DataContainer +from dask_sql.physical.rel.base import BaseRelPlugin + +if TYPE_CHECKING: + import dask_sql + from dask_planner.rust import LogicalPlan + +logger = logging.getLogger(__name__) + + +class DaskCrossJoinPlugin(BaseRelPlugin): + """ + While similar to `DaskJoinPlugin` a `CrossJoin` has enough of a differing + structure to justify its own plugin. This in turn limits the number of + Dask tasks that are generated for `CrossJoin`'s when compared to a + standard `Join` + """ + + class_name = "CrossJoin" + + def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContainer: + # We now have two inputs (from left and right), so we fetch them both + dc_lhs, dc_rhs = self.assert_inputs(rel, 2, context) + + df_lhs = dc_lhs.df + df_rhs = dc_rhs.df + + # Create a 'key' column in both DataFrames to join on + cross_join_key = utils.new_temporary_column(df_lhs) + df_lhs[cross_join_key] = 1 + df_rhs[cross_join_key] = 1 + + result = dd.merge(df_lhs, df_rhs, on=cross_join_key).drop(cross_join_key, axis=1) + + return DataContainer(result, ColumnContainer(result.columns)) diff --git a/tests/integration/fixtures.py index e1f53c6db..8dec835c3 100644 --- a/tests/integration/fixtures.py +++ b/tests/integration/fixtures.py @@ -52,6 +52,11 @@ def df(): ) + +@pytest.fixture() +def department_table(): + return 
pd.DataFrame({"department_name": ["English", "Math", "Science"]}) + + @pytest.fixture() def user_table_1(): return pd.DataFrame({"user_id": [2, 1, 2, 3], "b": [3, 3, 1, 3]}) @@ -159,6 +164,7 @@ def c( df_simple, df_wide, df, + department_table, user_table_1, user_table_2, long_table, @@ -177,6 +183,7 @@ def c( "df_simple": df_simple, "df_wide": df_wide, "df": df, + "department_table": department_table, "user_table_1": user_table_1, "user_table_2": user_table_2, "long_table": long_table, diff --git a/tests/integration/test_join.py index 3d178bc4c..17a0182db 100644 --- a/tests/integration/test_join.py +++ b/tests/integration/test_join.py @@ -110,6 +110,22 @@ def test_join_right(c): assert_eq(return_df, expected_df, check_index=False) +def test_join_cross(c, user_table_1, department_table): + return_df = c.sql( + """ + SELECT user_id, b, department_name + FROM user_table_1, department_table + """ + ) + + user_table_1["key"] = 1 + department_table["key"] = 1 + + expected_df = dd.merge(user_table_1, department_table, on="key").drop("key", axis=1) + + assert_eq(return_df, expected_df, check_index=False) + + @pytest.mark.skip(reason="WIP DataFusion") def test_join_complex(c): return_df = c.sql(