diff --git a/.gitignore b/.gitignore index 950c92821..c25366594 100644 --- a/.gitignore +++ b/.gitignore @@ -61,3 +61,6 @@ dask-worker-space/ node_modules/ docs/source/_build/ dask_planner/Cargo.lock + +# Ignore development specific local testing files +dev_tests diff --git a/dask_planner/src/expression.rs b/dask_planner/src/expression.rs index f482ed9c0..95ed295c1 100644 --- a/dask_planner/src/expression.rs +++ b/dask_planner/src/expression.rs @@ -2,7 +2,7 @@ use crate::sql::logical; use crate::sql::types::RexType; use pyo3::prelude::*; -use std::convert::{From, Into}; +use std::convert::From; use datafusion::error::{DataFusionError, Result}; @@ -64,10 +64,38 @@ impl PyExpr { } } - fn _column_name(&self, plan: LogicalPlan) -> Result { + /// Determines the name of the `Expr` instance by examining the LogicalPlan + pub fn _column_name(&self, plan: &LogicalPlan) -> Result { let field = expr_to_field(&self.expr, &plan)?; Ok(field.unqualified_column().name.clone()) } + + fn _rex_type(&self, expr: &Expr) -> RexType { + match expr { + Expr::Alias(..) => RexType::Reference, + Expr::Column(..) => RexType::Reference, + Expr::ScalarVariable(..) => RexType::Literal, + Expr::Literal(..) => RexType::Literal, + Expr::BinaryExpr { .. } => RexType::Call, + Expr::Not(..) => RexType::Call, + Expr::IsNotNull(..) => RexType::Call, + Expr::Negative(..) => RexType::Call, + Expr::GetIndexedField { .. } => RexType::Reference, + Expr::IsNull(..) => RexType::Call, + Expr::Between { .. } => RexType::Call, + Expr::Case { .. } => RexType::Call, + Expr::Cast { .. } => RexType::Call, + Expr::TryCast { .. } => RexType::Call, + Expr::Sort { .. } => RexType::Call, + Expr::ScalarFunction { .. } => RexType::Call, + Expr::AggregateFunction { .. } => RexType::Call, + Expr::WindowFunction { .. } => RexType::Call, + Expr::AggregateUDF { .. } => RexType::Call, + Expr::InList { .. } => RexType::Call, + Expr::Wildcard => RexType::Call, + _ => RexType::Other, + } + } } #[pymethods] @@ -147,36 +175,13 @@ impl PyExpr { /// Determines the type of this Expr based on its variant #[pyo3(name = "getRexType")] - pub fn rex_type(&self) -> RexType { - match &self.expr { - Expr::Alias(expr, name) => RexType::Reference, - Expr::Column(..) => RexType::Reference, - Expr::ScalarVariable(..) => RexType::Literal, - Expr::Literal(..) => RexType::Literal, - Expr::BinaryExpr { .. } => RexType::Call, - Expr::Not(..) => RexType::Call, - Expr::IsNotNull(..) => RexType::Call, - Expr::Negative(..) => RexType::Call, - Expr::GetIndexedField { .. } => RexType::Reference, - Expr::IsNull(..) => RexType::Call, - Expr::Between { .. } => RexType::Call, - Expr::Case { .. } => RexType::Call, - Expr::Cast { .. } => RexType::Call, - Expr::TryCast { .. } => RexType::Call, - Expr::Sort { .. } => RexType::Call, - Expr::ScalarFunction { .. } => RexType::Call, - Expr::AggregateFunction { .. } => RexType::Call, - Expr::WindowFunction { .. } => RexType::Call, - Expr::AggregateUDF { .. } => RexType::Call, - Expr::InList { .. } => RexType::Call, - Expr::Wildcard => RexType::Call, - _ => RexType::Other, - } + pub fn rex_type(&self) -> PyResult { + Ok(self._rex_type(&self.expr)) } /// Python friendly shim code to get the name of a column referenced by an expression pub fn column_name(&self, mut plan: logical::PyLogicalPlan) -> PyResult { - self._column_name(plan.current_node()) + self._column_name(&plan.current_node()) .map_err(|e| py_runtime_err(e)) } @@ -184,16 +189,10 @@ impl PyExpr { #[pyo3(name = "getOperands")] pub fn get_operands(&self) -> PyResult> { match &self.expr { - Expr::BinaryExpr { left, op: _, right } => { - let mut operands: Vec = Vec::new(); - let left_desc: Expr = *left.clone(); - let py_left: PyExpr = PyExpr::from(left_desc, self.input_plan.clone()); - operands.push(py_left); - let right_desc: Expr = *right.clone(); - let py_right: PyExpr = PyExpr::from(right_desc, self.input_plan.clone()); - operands.push(py_right); - Ok(operands) - } + Expr::BinaryExpr { left, right, .. } => Ok(vec![ + PyExpr::from(*left.clone(), self.input_plan.clone()), + PyExpr::from(*right.clone(), self.input_plan.clone()), + ]), Expr::ScalarFunction { fun: _, args } => { let mut operands: Vec = Vec::new(); for arg in args { @@ -203,15 +202,44 @@ impl PyExpr { Ok(operands) } Expr::Cast { expr, data_type: _ } => { + Ok(vec![PyExpr::from(*expr.clone(), self.input_plan.clone())]) + } + Expr::Case { + expr, + when_then_expr, + else_expr, + } => { let mut operands: Vec = Vec::new(); - let ex: Expr = *expr.clone(); - let py_ex: PyExpr = PyExpr::from(ex, self.input_plan.clone()); - operands.push(py_ex); + + if let Some(e) = expr { + operands.push(PyExpr::from(*e.clone(), self.input_plan.clone())); + }; + + for (when, then) in when_then_expr { + operands.push(PyExpr::from(*when.clone(), self.input_plan.clone())); + operands.push(PyExpr::from(*then.clone(), self.input_plan.clone())); + } + + if let Some(e) = else_expr { + operands.push(PyExpr::from(*e.clone(), self.input_plan.clone())); + }; + Ok(operands) } - _ => Err(PyErr::new::( - "unknown Expr type encountered", - )), + Expr::Between { + expr, + negated: _, + low, + high, + } => Ok(vec![ + PyExpr::from(*expr.clone(), self.input_plan.clone()), + PyExpr::from(*low.clone(), self.input_plan.clone()), + PyExpr::from(*high.clone(), self.input_plan.clone()), + ]), + _ => Err(PyErr::new::(format!( + "unknown Expr type {:?} encountered", + &self.expr + ))), } } @@ -224,13 +252,13 @@ impl PyExpr { right: _, } => Ok(format!("{}", op)), Expr::ScalarFunction { fun, args: _ } => Ok(format!("{}", fun)), - Expr::Cast { - expr: _, - data_type: _, - } => Ok(String::from("cast")), - _ => Err(PyErr::new::( - "Catch all triggered ....", - )), + Expr::Cast { .. } => Ok("cast".to_string()), + Expr::Between { .. } => Ok("between".to_string()), + Expr::Case { .. } => Ok("case".to_string()), + _ => Err(PyErr::new::(format!( + "Catch all triggered for get_operator_name: {:?}", + &self.expr + ))), } } diff --git a/dask_planner/src/sql/logical.rs b/dask_planner/src/sql/logical.rs index bb4ff53b0..b6a961e67 100644 --- a/dask_planner/src/sql/logical.rs +++ b/dask_planner/src/sql/logical.rs @@ -1,7 +1,6 @@ use crate::sql::table; use crate::sql::types::rel_data_type::RelDataType; use crate::sql::types::rel_data_type_field::RelDataTypeField; -use datafusion::logical_plan::DFField; mod aggregate; mod filter; @@ -12,7 +11,6 @@ pub use datafusion_expr::LogicalPlan; use datafusion::common::Result; use datafusion::prelude::Column; -use pyo3::ffi::Py_FatalError; use crate::sql::exceptions::py_type_err; use pyo3::prelude::*; diff --git a/dask_planner/src/sql/logical/join.rs b/dask_planner/src/sql/logical/join.rs index ccb77ef6b..fbed464ca 100644 --- a/dask_planner/src/sql/logical/join.rs +++ b/dask_planner/src/sql/logical/join.rs @@ -27,7 +27,6 @@ impl PyJoin { let mut join_conditions: Vec<(column::PyColumn, column::PyColumn)> = Vec::new(); for (mut lhs, mut rhs) in self.join.on.clone() { - println!("lhs: {:?} rhs: {:?}", lhs, rhs); lhs.relation = Some(lhs_table_name.clone()); rhs.relation = Some(rhs_table_name.clone()); join_conditions.push((lhs.into(), rhs.into())); diff --git a/dask_planner/src/sql/logical/projection.rs b/dask_planner/src/sql/logical/projection.rs index c3144a15b..bbce9a137 100644 --- a/dask_planner/src/sql/logical/projection.rs +++ b/dask_planner/src/sql/logical/projection.rs @@ -17,9 +17,8 @@ impl PyProjection { let mut projs: Vec = Vec::new(); match &local_expr.expr { Expr::Alias(expr, _name) => { - let ex: Expr = *expr.clone(); - let mut py_expr: PyExpr = PyExpr::from(ex, Some(self.projection.input.clone())); - py_expr.input_plan = local_expr.input_plan.clone(); + let py_expr: PyExpr = + PyExpr::from(*expr.clone(), Some(self.projection.input.clone())); projs.extend_from_slice(self.projected_expressions(&py_expr).as_slice()); } _ => projs.push(local_expr.clone()), @@ -30,57 +29,6 @@ impl PyProjection { #[pymethods] impl PyProjection { - #[pyo3(name = "getColumnName")] - fn column_name(&mut self, expr: PyExpr) -> PyResult { - let mut val: String = String::from("OK"); - match expr.expr { - Expr::Alias(expr, _alias) => match expr.as_ref() { - Expr::Column(col) => { - let index = self.projection.input.schema().index_of_column(col).unwrap(); - match self.projection.input.as_ref() { - LogicalPlan::Aggregate(agg) => { - let mut exprs = agg.group_expr.clone(); - exprs.extend_from_slice(&agg.aggr_expr); - match &exprs[index] { - Expr::AggregateFunction { args, .. } => match &args[0] { - Expr::Column(col) => { - println!("AGGREGATE COLUMN IS {}", col.name); - val = col.name.clone(); - } - _ => unimplemented!("projection.rs column_name is unimplemented for Expr variant: {:?}", &args[0]), - }, - _ => unimplemented!("projection.rs column_name is unimplemented for Expr variant: {:?}", &exprs[index]), - } - } - LogicalPlan::TableScan(table_scan) => val = table_scan.table_name.clone(), - _ => unimplemented!("projection.rs column_name is unimplemented for LogicalPlan variant: {:?}", self.projection.input), - } - } - Expr::Cast { expr, data_type: _ } => { - let ex_type: Expr = *expr.clone(); - let py_type: PyExpr = - PyExpr::from(ex_type, Some(self.projection.input.clone())); - val = self.column_name(py_type).unwrap(); - println!("Setting col name to: {:?}", val); - } - _ => panic!("not supported: {:?}", expr), - }, - Expr::Column(col) => val = col.name.clone(), - Expr::Cast { expr, data_type: _ } => { - let ex_type: Expr = *expr; - let py_type: PyExpr = PyExpr::from(ex_type, Some(self.projection.input.clone())); - val = self.column_name(py_type).unwrap() - } - _ => { - panic!( - "column_name is unimplemented for Expr variant: {:?}", - expr.expr - ); - } - } - Ok(val) - } - #[pyo3(name = "getNamedProjects")] fn named_projects(&mut self) -> PyResult> { let mut named: Vec<(String, PyExpr)> = Vec::new(); @@ -88,8 +36,9 @@ impl PyProjection { let mut py_expr: PyExpr = PyExpr::from(expression, Some(self.projection.input.clone())); py_expr.input_plan = Some(self.projection.input.clone()); for expr in self.projected_expressions(&py_expr) { - let name: String = self.column_name(expr.clone()).unwrap(); - named.push((name, expr.clone())); + if let Ok(name) = expr._column_name(&*self.projection.input) { + named.push((name, expr.clone())); + } } } Ok(named) diff --git a/dask_planner/src/sql/table.rs b/dask_planner/src/sql/table.rs index eebc6ff7f..10b1e7ccc 100644 --- a/dask_planner/src/sql/table.rs +++ b/dask_planner/src/sql/table.rs @@ -125,7 +125,6 @@ impl DaskTable { qualified_name.push(table_scan.table_name); } _ => { - println!("Nothing matches"); qualified_name.push(self.name.clone()); } } diff --git a/dask_planner/src/sql/types.rs b/dask_planner/src/sql/types.rs index 618786d88..2765664df 100644 --- a/dask_planner/src/sql/types.rs +++ b/dask_planner/src/sql/types.rs @@ -49,8 +49,6 @@ impl DaskTypeMap { #[new] #[args(sql_type, py_kwargs = "**")] fn new(sql_type: SqlTypeName, py_kwargs: Option<&PyDict>) -> Self { - // println!("sql_type={:?} - py_kwargs={:?}", sql_type, py_kwargs); - let d_type: DataType = match sql_type { SqlTypeName::TIMESTAMP_WITH_LOCAL_TIME_ZONE => { let (unit, tz) = match py_kwargs { @@ -206,8 +204,7 @@ impl SqlTypeName { SqlTypeName::DATE => DataType::Date64, SqlTypeName::VARCHAR => DataType::Utf8, _ => { - println!("Type: {:?}", self); - todo!(); + todo!("Type: {:?}", self); } } } diff --git a/dask_sql/physical/rel/logical/limit.py b/dask_sql/physical/rel/logical/limit.py index 58cd68fe8..76773e37e 100644 --- a/dask_sql/physical/rel/logical/limit.py +++ b/dask_sql/physical/rel/logical/limit.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: import dask_sql - from dask_sql.java import org + from dask_planner.rust import LogicalPlan class DaskLimitPlugin(BaseRelPlugin): @@ -18,11 +18,9 @@ class DaskLimitPlugin(BaseRelPlugin): (LIMIT). """ - class_name = "com.dask.sql.nodes.DaskLimit" + class_name = "Limit" - def convert( - self, rel: "org.apache.calcite.rel.RelNode", context: "dask_sql.Context" - ) -> DataContainer: + def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContainer: (dc,) = self.assert_inputs(rel, 1, context) df = dc.df cc = dc.column_container diff --git a/dask_sql/physical/rex/convert.py b/dask_sql/physical/rex/convert.py index ac443a5dd..bbbeda1db 100644 --- a/dask_sql/physical/rex/convert.py +++ b/dask_sql/physical/rex/convert.py @@ -14,6 +14,15 @@ logger = logging.getLogger(__name__) +# _REX_TYPE_TO_PLUGIN = { +# "Alias": "InputRef", +# "Column": "InputRef", +# "BinaryExpr": "RexCall", +# "Literal": "RexLiteral", +# "ScalarFunction": "RexCall", +# "Cast": "RexCall", +# } + _REX_TYPE_TO_PLUGIN = { "RexType.Reference": "InputRef", "RexType.Call": "RexCall", diff --git a/dask_sql/physical/rex/core/call.py b/dask_sql/physical/rex/core/call.py index 1a74ce9d5..69fad753c 100644 --- a/dask_sql/physical/rex/core/call.py +++ b/dask_sql/physical/rex/core/call.py @@ -193,6 +193,9 @@ def case(self, *operands) -> SeriesOrScalar: if len(operands) > 3: other = self.case(*operands[2:]) + elif len(operands) == 2: + # CASE/WHEN statement without an ELSE + other = None else: other = operands[2] @@ -765,6 +768,18 @@ def date_part(self, what, df: SeriesOrScalar): raise NotImplementedError(f"Extraction of {what} is not (yet) implemented.") +class BetweenOperation(Operation): + """ + Function for finding rows between two scalar values + """ + + def __init__(self): + super().__init__(self.between) + + def between(self, series: dd.Series, low, high): + return series.between(low, high, inclusive="both") + + class RexCallPlugin(BaseRexPlugin): """ RexCall is used for expressions, which calculate something. @@ -785,6 +800,7 @@ class RexCallPlugin(BaseRexPlugin): OPERATION_MAPPING = { # "binary" functions + "between": BetweenOperation(), "and": ReduceOperation(operation=operator.and_), "or": ReduceOperation(operation=operator.or_), ">": ReduceOperation(operation=operator.gt), @@ -873,20 +889,17 @@ def convert( dc: DataContainer, context: "dask_sql.Context", ) -> SeriesOrScalar: - logger.debug(f"Expression Operands: {expr.getOperands()}") + # Prepare the operands by turning the RexNodes into python expressions operands = [ RexConverter.convert(rel, o, dc, context=context) for o in expr.getOperands() ] - logger.debug(f"Operands: {operands}") - # Now use the operator name in the mapping # TODO: obviously this needs to not be hardcoded but not sure of the best place to pull the value from currently??? schema_name = "root" operator_name = expr.getOperatorName().lower() - logger.debug(f"Operator Name: {operator_name}") try: operation = self.OPERATION_MAPPING[operator_name] diff --git a/tests/integration/test_rex.py b/tests/integration/test_rex.py index b433e3fce..89e92023c 100644 --- a/tests/integration/test_rex.py +++ b/tests/integration/test_rex.py @@ -8,9 +8,6 @@ from tests.utils import assert_eq -@pytest.mark.xfail( - reason="Bumping to Calcite 1.29.0 to address CVE-2021-44228 caused a stack overflow in this test" -) def test_case(c, df): result_df = c.sql( """ diff --git a/tests/integration/test_select.py b/tests/integration/test_select.py index defa9cc9a..f96969e15 100644 --- a/tests/integration/test_select.py +++ b/tests/integration/test_select.py @@ -12,7 +12,6 @@ def test_select(c, df): assert_eq(result_df, df) -@pytest.mark.skip(reason="WIP DataFusion") def test_select_alias(c, df): result_df = c.sql("SELECT a as b, b as a FROM df") @@ -49,7 +48,6 @@ def test_select_different_types(c): assert_eq(result_df, expected_df) -@pytest.mark.skip(reason="WIP DataFusion") def test_select_expr(c, df): result_df = c.sql("SELECT a + 1 AS a, b AS bla, a - 1 FROM df") result_df = result_df @@ -58,7 +56,7 @@ def test_select_expr(c, df): { "a": df["a"] + 1, "bla": df["b"], - '"df"."a" - 1': df["a"] - 1, + "df.a - Int64(1)": df["a"] - 1, } ) assert_eq(result_df, expected_df) @@ -200,7 +198,6 @@ def test_timestamp_casting(c, input_table, request): assert_eq(result_df, expected_df) -@pytest.mark.skip(reason="WIP DataFusion") def test_multi_case_when(c): df = pd.DataFrame({"a": [1, 6, 7, 8, 9]}) c.create_table("df", df) @@ -208,10 +205,26 @@ def test_multi_case_when(c): actual_df = c.sql( """ SELECT - CASE WHEN a BETWEEN 6 AND 8 THEN 1 ELSE 0 END AS C + CASE WHEN a BETWEEN 6 AND 8 THEN 1 ELSE 0 END AS "C" + FROM df + """ + ) + expected_df = pd.DataFrame({"C": [0, 1, 1, 1, 0]}, dtype=np.int64) + + assert_eq(actual_df, expected_df) + + +def test_case_when_no_else(c): + df = pd.DataFrame({"a": [1, 6, 7, 8, 9]}) + c.create_table("df", df) + + actual_df = c.sql( + """ + SELECT + CASE WHEN a BETWEEN 6 AND 8 THEN 1 END AS "C" FROM df """ ) - expected_df = pd.DataFrame({"C": [0, 1, 1, 1, 0]}, dtype=np.int32) + expected_df = pd.DataFrame({"C": [None, 1, 1, 1, None]}) assert_eq(actual_df, expected_df)