diff --git a/dask_planner/Cargo.toml b/dask_planner/Cargo.toml
index bc7e3138a..1ebea5c08 100644
--- a/dask_planner/Cargo.toml
+++ b/dask_planner/Cargo.toml
@@ -12,8 +12,8 @@ rust-version = "1.59"
 tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync", "fs", "parking_lot"] }
 rand = "0.7"
 pyo3 = { version = "0.16", features = ["extension-module", "abi3", "abi3-py38"] }
-datafusion = { git="https://github.com/apache/arrow-datafusion/", rev = "23f1c77569d1f3b0ff42ade56f9b2afb53d44292" }
-datafusion-expr = { git="https://github.com/apache/arrow-datafusion/", rev = "23f1c77569d1f3b0ff42ade56f9b2afb53d44292" }
+datafusion = { git="https://github.com/apache/arrow-datafusion/", rev = "ef49d2858c2aba1ea7cd5fed3b1e5feb77fc2233" }
+datafusion-expr = { git="https://github.com/apache/arrow-datafusion/", rev = "ef49d2858c2aba1ea7cd5fed3b1e5feb77fc2233" }
 uuid = { version = "0.8", features = ["v4"] }
 mimalloc = { version = "*", default-features = false }
 sqlparser = "0.14.0"
diff --git a/dask_planner/src/expression.rs b/dask_planner/src/expression.rs
index 4f589f1d1..f457d0ee9 100644
--- a/dask_planner/src/expression.rs
+++ b/dask_planner/src/expression.rs
@@ -7,7 +7,7 @@ use std::convert::{From, Into};
 use datafusion::error::DataFusionError;
 
 use datafusion::arrow::datatypes::DataType;
-use datafusion_expr::{col, lit, BuiltinScalarFunction, Expr};
+use datafusion_expr::{lit, BuiltinScalarFunction, Expr};
 
 use datafusion::scalar::ScalarValue;
 
@@ -31,15 +31,6 @@ impl From<PyExpr> for Expr {
     }
 }
 
-impl From<Expr> for PyExpr {
-    fn from(expr: Expr) -> PyExpr {
-        PyExpr {
-            input_plan: None,
-            expr: expr,
-        }
-    }
-}
-
 #[pyclass(name = "ScalarValue", module = "datafusion", subclass)]
 #[derive(Debug, Clone)]
 pub struct PyScalarValue {
@@ -70,13 +61,9 @@ impl PyExpr {
         }
     }
 
-    fn _column_name(&self, mut plan: LogicalPlan) -> String {
+    fn _column_name(&self, plan: LogicalPlan) -> String {
         match &self.expr {
             Expr::Alias(expr, name) => {
-                println!("Alias encountered with name: {:?}", name);
-                // let reference: Expr = *expr.as_ref();
-                // let plan: logical::PyLogicalPlan = reference.input().clone().into();
-
                 // Only certain LogicalPlan variants are valid in this nested Alias scenario so we
                 // extract the valid ones and error on the invalid ones
                 match expr.as_ref() {
@@ -160,7 +147,7 @@ impl PyExpr {
 
     #[staticmethod]
     pub fn literal(value: PyScalarValue) -> PyExpr {
-        lit(value.scalar_value).into()
+        PyExpr::from(lit(value.scalar_value), None)
     }
 
     /// If this Expression instance references an existing
@@ -173,6 +160,11 @@
         }
     }
 
+    #[pyo3(name = "toString")]
+    pub fn to_string(&self) -> PyResult<String> {
+        Ok(format!("{}", &self.expr))
+    }
+
     /// Gets the positional index of the Expr instance from the LogicalPlan DFSchema
     #[pyo3(name = "getIndex")]
     pub fn index(&self) -> PyResult<usize> {
@@ -230,7 +222,7 @@
     #[pyo3(name = "getRexType")]
     pub fn rex_type(&self) -> RexType {
         match &self.expr {
-            Expr::Alias(..) => RexType::Reference,
+            Expr::Alias(expr, name) => RexType::Reference,
             Expr::Column(..) => RexType::Reference,
             Expr::ScalarVariable(..) => RexType::Literal,
             Expr::Literal(..) => RexType::Literal,
@@ -267,22 +259,26 @@ impl PyExpr {
             Expr::BinaryExpr { left, op: _, right } => {
                 let mut operands: Vec<PyExpr> = Vec::new();
                 let left_desc: Expr = *left.clone();
-                operands.push(left_desc.into());
+                let py_left: PyExpr = PyExpr::from(left_desc, self.input_plan.clone());
+                operands.push(py_left);
                 let right_desc: Expr = *right.clone();
-                operands.push(right_desc.into());
+                let py_right: PyExpr = PyExpr::from(right_desc, self.input_plan.clone());
+                operands.push(py_right);
                 Ok(operands)
             }
             Expr::ScalarFunction { fun: _, args } => {
                 let mut operands: Vec<PyExpr> = Vec::new();
                 for arg in args {
-                    operands.push(arg.clone().into());
+                    let py_arg: PyExpr = PyExpr::from(arg.clone(), self.input_plan.clone());
+                    operands.push(py_arg);
                 }
                 Ok(operands)
             }
             Expr::Cast { expr, data_type: _ } => {
                 let mut operands: Vec<PyExpr> = Vec::new();
                 let ex: Expr = *expr.clone();
-                operands.push(ex.into());
+                let py_ex: PyExpr = PyExpr::from(ex, self.input_plan.clone());
+                operands.push(py_ex);
                 Ok(operands)
             }
             _ => Err(PyErr::new::<PyTypeError, _>(
diff --git a/dask_planner/src/sql.rs b/dask_planner/src/sql.rs
index 05650bbb1..bfade7bb9 100644
--- a/dask_planner/src/sql.rs
+++ b/dask_planner/src/sql.rs
@@ -158,7 +158,6 @@ impl DaskSQLContext {
         statement: statement::PyStatement,
     ) -> PyResult<logical::PyLogicalPlan> {
         let planner = SqlToRel::new(self);
-
         match planner.statement_to_plan(statement.statement) {
             Ok(k) => {
                 println!("\nLogicalPlan: {:?}\n\n", k);
diff --git a/dask_planner/src/sql/logical.rs b/dask_planner/src/sql/logical.rs
index 0c9ca27f5..2359afdf2 100644
--- a/dask_planner/src/sql/logical.rs
+++ b/dask_planner/src/sql/logical.rs
@@ -122,6 +122,7 @@ impl PyLogicalPlan {
             LogicalPlan::Explain(_explain) => "Explain",
             LogicalPlan::Analyze(_analyze) => "Analyze",
             LogicalPlan::Extension(_extension) => "Extension",
+            LogicalPlan::Subquery(_sub_query) => "Subquery",
             LogicalPlan::SubqueryAlias(_sqalias) => "SubqueryAlias",
             LogicalPlan::CreateCatalogSchema(_create) => "CreateCatalogSchema",
             LogicalPlan::CreateCatalog(_create_catalog) => "CreateCatalog",
diff --git a/dask_planner/src/sql/logical/projection.rs b/dask_planner/src/sql/logical/projection.rs
index 3d2eccdd8..c3144a15b 100644
--- a/dask_planner/src/sql/logical/projection.rs
+++ b/dask_planner/src/sql/logical/projection.rs
@@ -11,6 +11,23 @@ pub struct PyProjection {
     pub(crate) projection: Projection,
 }
 
+impl PyProjection {
+    /// Projection: Gets the names of the fields that should be projected
+    fn projected_expressions(&mut self, local_expr: &PyExpr) -> Vec<PyExpr> {
+        let mut projs: Vec<PyExpr> = Vec::new();
+        match &local_expr.expr {
+            Expr::Alias(expr, _name) => {
+                let ex: Expr = *expr.clone();
+                let mut py_expr: PyExpr = PyExpr::from(ex, Some(self.projection.input.clone()));
+                py_expr.input_plan = local_expr.input_plan.clone();
+                projs.extend_from_slice(self.projected_expressions(&py_expr).as_slice());
+            }
+            _ => projs.push(local_expr.clone()),
+        }
+        projs
+    }
+}
+
 #[pymethods]
 impl PyProjection {
     #[pyo3(name = "getColumnName")]
@@ -39,9 +56,21 @@ impl PyProjection {
                         _ => unimplemented!("projection.rs column_name is unimplemented for LogicalPlan variant: {:?}", self.projection.input),
                     }
                 }
+                Expr::Cast { expr, data_type: _ } => {
+                    let ex_type: Expr = *expr.clone();
+                    let py_type: PyExpr =
+                        PyExpr::from(ex_type, Some(self.projection.input.clone()));
+                    val = self.column_name(py_type).unwrap();
+                    println!("Setting col name to: {:?}", val);
+                }
                 _ => panic!("not supported: {:?}", expr),
             },
            Expr::Column(col) => val = col.name.clone(),
+            Expr::Cast { expr, data_type: _ } => {
+                let ex_type: Expr = *expr;
+                let py_type: PyExpr = PyExpr::from(ex_type, Some(self.projection.input.clone()));
+                val = self.column_name(py_type).unwrap()
+            }
             _ => {
                 panic!(
                     "column_name is unimplemented for Expr variant: {:?}",
@@ -52,25 +81,16 @@ impl PyProjection {
         Ok(val)
     }
 
-    /// Projection: Gets the names of the fields that should be projected
-    #[pyo3(name = "getProjectedExpressions")]
-    fn projected_expressions(&mut self) -> PyResult<Vec<PyExpr>> {
-        let mut projs: Vec<PyExpr> = Vec::new();
-        for expr in &self.projection.expr {
-            projs.push(PyExpr::from(
-                expr.clone(),
-                Some(self.projection.input.clone()),
-            ));
-        }
-        Ok(projs)
-    }
-
     #[pyo3(name = "getNamedProjects")]
     fn named_projects(&mut self) -> PyResult<Vec<(String, PyExpr)>> {
         let mut named: Vec<(String, PyExpr)> = Vec::new();
-        for expr in &self.projected_expressions().unwrap() {
-            let name: String = self.column_name(expr.clone()).unwrap();
-            named.push((name, expr.clone()));
+        for expression in self.projection.expr.clone() {
+            let mut py_expr: PyExpr = PyExpr::from(expression, Some(self.projection.input.clone()));
+            py_expr.input_plan = Some(self.projection.input.clone());
+            for expr in self.projected_expressions(&py_expr) {
+                let name: String = self.column_name(expr.clone()).unwrap();
+                named.push((name, expr.clone()));
+            }
         }
         Ok(named)
     }
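Note on the new PyProjection::projected_expressions above: it flattens nested Alias expressions recursively so that getNamedProjects can pair every output name with the expression that actually produces it. A minimal Python sketch of that recursion (the tuple encoding of expressions here is hypothetical, used only to mirror the control flow of the Rust code):

    # Toy encoding: ("alias", inner_expr, name) or ("column", name).
    def projected_expressions(expr):
        if expr[0] == "alias":
            # Unwrap the alias and recurse, mirroring the Expr::Alias arm;
            # the Rust version also re-attaches the projection's input plan here.
            return projected_expressions(expr[1])
        return [expr]

    assert projected_expressions(
        ("alias", ("alias", ("column", "a"), "x"), "y")
    ) == [("column", "a")]

The base case pushes the expression itself, so non-aliased projections pass through unchanged and each projected column yields exactly one entry.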
diff --git a/dask_planner/src/sql/types.rs b/dask_planner/src/sql/types.rs
index e6bd5134c..618786d88 100644
--- a/dask_planner/src/sql/types.rs
+++ b/dask_planner/src/sql/types.rs
@@ -4,7 +4,6 @@ pub mod rel_data_type;
 pub mod rel_data_type_field;
 
 use pyo3::prelude::*;
-use pyo3::types::PyAny;
 use pyo3::types::PyDict;
 
 #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
diff --git a/dask_sql/physical/rel/base.py b/dask_sql/physical/rel/base.py
index 990c73880..aae628f18 100644
--- a/dask_sql/physical/rel/base.py
+++ b/dask_sql/physical/rel/base.py
@@ -105,9 +105,6 @@ def fix_dtype_to_row_type(dc: DataContainer, row_type: "RelDataType"):
             expected_type = sql_to_python_type(field_type.getSqlType())
             df_field_name = cc.get_backend_by_frontend_name(field_name)
 
-            logger.debug(
-                f"Before cast df_field_name: {df_field_name}, expected_type: {expected_type}"
-            )
             df = cast_column_type(df, df_field_name, expected_type)
 
         return DataContainer(df, dc.column_container)
diff --git a/dask_sql/physical/rel/logical/project.py b/dask_sql/physical/rel/logical/project.py
index 6d55977a2..0441fe486 100644
--- a/dask_sql/physical/rel/logical/project.py
+++ b/dask_sql/physical/rel/logical/project.py
@@ -63,7 +63,7 @@ def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContai
             else:
                 random_name = new_temporary_column(df)
                 new_columns[random_name] = RexConverter.convert(
-                    expr, dc, context=context
+                    rel, expr, dc, context=context
                 )
                 logger.debug(f"Adding a new column {key} out of {expr}")
                 new_mappings[key] = random_name
diff --git a/dask_sql/physical/rex/convert.py b/dask_sql/physical/rex/convert.py
index 1123e8359..ac443a5dd 100644
--- a/dask_sql/physical/rex/convert.py
+++ b/dask_sql/physical/rex/convert.py
@@ -15,12 +15,9 @@
 
 
 _REX_TYPE_TO_PLUGIN = {
-    "Alias": "InputRef",
-    "Column": "InputRef",
-    "BinaryExpr": "RexCall",
-    "Literal": "RexLiteral",
-    "ScalarFunction": "RexCall",
-    "Cast": "RexCall",
+    "RexType.Reference": "InputRef",
+    "RexType.Call": "RexCall",
+    "RexType.Literal": "RexLiteral",
 }
 
 
@@ -55,12 +52,12 @@ def convert(
         context: "dask_sql.Context",
     ) -> Union[dd.DataFrame, Any]:
         """
-        Convert the given rel (java instance)
+        Convert the given Expression
         into a python expression (a dask dataframe)
         using the stored plugins and the dictionary of
         registered dask tables.
         """
-        expr_type = _REX_TYPE_TO_PLUGIN[rex.getExprType()]
+        expr_type = _REX_TYPE_TO_PLUGIN[str(rex.getRexType())]
 
         try:
             plugin_instance = cls.get_plugin(expr_type)
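The convert.py change above re-keys plugin dispatch on the coarse RexType enum exposed from the Rust planner instead of per-variant Expr names, so any expression the planner classifies as a reference, call, or literal lands on an existing plugin without new table entries. A self-contained sketch of the dispatch, with a Python Enum standing in for the PyO3 RexType (assumed here to stringify as "RexType.<Variant>", which is what the new mapping keys imply):

    from enum import Enum

    class RexType(Enum):  # stand-in for the Rust-side enum
        Reference = 0
        Call = 1
        Literal = 2

    _REX_TYPE_TO_PLUGIN = {
        "RexType.Reference": "InputRef",
        "RexType.Call": "RexCall",
        "RexType.Literal": "RexLiteral",
    }

    def plugin_name_for(rex_type):
        # str(RexType.Call) == "RexType.Call", so the enum value itself keys the table
        return _REX_TYPE_TO_PLUGIN[str(rex_type)]

    assert plugin_name_for(RexType.Call) == "RexCall"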
diff --git a/dask_sql/physical/rex/core/call.py b/dask_sql/physical/rex/core/call.py
index 68c941c30..1a74ce9d5 100644
--- a/dask_sql/physical/rex/core/call.py
+++ b/dask_sql/physical/rex/core/call.py
@@ -225,9 +225,7 @@ def cast(self, operand, rex=None) -> SeriesOrScalar:
             return operand
 
         output_type = str(rex.getType())
-        python_type = sql_to_python_type(
-            output_type=sql_to_python_type(output_type.upper())
-        )
+        python_type = sql_to_python_type(SqlTypeName.fromString(output_type.upper()))
 
         return_column = cast_column_to_type(operand, python_type)
 
diff --git a/dask_sql/physical/rex/core/input_ref.py b/dask_sql/physical/rex/core/input_ref.py
index 152b24caf..4272c832e 100644
--- a/dask_sql/physical/rex/core/input_ref.py
+++ b/dask_sql/physical/rex/core/input_ref.py
@@ -22,12 +22,14 @@ class RexInputRefPlugin(BaseRexPlugin):
     def convert(
         self,
         rel: "LogicalPlan",
-        expr: "Expression",
+        rex: "Expression",
         dc: DataContainer,
         context: "dask_sql.Context",
     ) -> dd.Series:
         df = dc.df
+        cc = dc.column_container
 
         # The column is referenced by index
-        column_name = str(expr.column_name(rel))
-        return df[column_name]
+        index = rex.getIndex()
+        backend_column_name = cc.get_backend_by_frontend_index(index)
+        return df[backend_column_name]
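The RexInputRefPlugin rewrite above resolves a column positionally: rex.getIndex() yields the frontend position and the ColumnContainer translates it to a backend dask column name. A hedged sketch using a minimal stand-in for dask_sql's ColumnContainer (the real class tracks a richer frontend/backend mapping):

    import pandas as pd
    import dask.dataframe as dd

    class ColumnContainer:  # minimal stand-in, not dask_sql's implementation
        def __init__(self, backend_names_in_frontend_order):
            self._order = backend_names_in_frontend_order

        def get_backend_by_frontend_index(self, index):
            return self._order[index]

    df = dd.from_pandas(pd.DataFrame({"col_b": [1, 2], "col_a": [3, 4]}), npartitions=1)
    cc = ColumnContainer(["col_b", "col_a"])

    index = 1  # what rex.getIndex() would return for the second frontend column
    print(df[cc.get_backend_by_frontend_index(index)].compute())  # selects "col_a"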
diff --git a/tests/integration/test_filter.py b/tests/integration/test_filter.py
index b7f8a29f0..af7cf5bea 100644
--- a/tests/integration/test_filter.py
+++ b/tests/integration/test_filter.py
@@ -14,7 +14,6 @@ def test_filter(c, df):
     assert_eq(return_df, expected_df)
 
 
-@pytest.mark.skip(reason="WIP DataFusion")
 def test_filter_scalar(c, df):
     return_df = c.sql("SELECT * FROM df WHERE True")
 
@@ -37,7 +36,6 @@ def test_filter_scalar(c, df):
     assert_eq(return_df, expected_df, check_index_type=False)
 
 
-@pytest.mark.skip(reason="WIP DataFusion")
 def test_filter_complicated(c, df):
     return_df = c.sql("SELECT * FROM df WHERE a < 3 AND (b > 1 AND b < 3)")
 
@@ -48,7 +46,6 @@ def test_filter_complicated(c, df):
     )
 
 
-@pytest.mark.skip(reason="WIP DataFusion")
 def test_filter_with_nan(c):
     return_df = c.sql("SELECT * FROM user_table_nan WHERE c = 3")
 
@@ -62,7 +59,6 @@ def test_filter_with_nan(c):
     )
 
 
-@pytest.mark.skip(reason="WIP DataFusion")
 def test_string_filter(c, string_table):
     return_df = c.sql("SELECT * FROM string_table WHERE a = 'a normal string'")
 
@@ -77,7 +73,10 @@ def test_string_filter(c, string_table):
     "input_table",
     [
         "datetime_table",
-        pytest.param("gpu_datetime_table", marks=pytest.mark.gpu),
+        pytest.param(
+            "gpu_datetime_table",
+            marks=(pytest.mark.gpu, pytest.mark.skip(reason="WIP DataFusion")),
+        ),
     ],
 )
 def test_filter_cast_date(c, input_table, request):
@@ -101,7 +100,10 @@ def test_filter_cast_date(c, input_table, request):
     "input_table",
     [
         "datetime_table",
-        pytest.param("gpu_datetime_table", marks=pytest.mark.gpu),
+        pytest.param(
+            "gpu_datetime_table",
+            marks=(pytest.mark.gpu, pytest.mark.skip(reason="WIP DataFusion")),
+        ),
     ],
 )
 def test_filter_cast_timestamp(c, input_table, request):
@@ -206,7 +208,6 @@ def test_predicate_pushdown(c, parquet_ddf, query, df_func, filters):
     assert_eq(return_df, expected_df, check_divisions=False)
 
 
-@pytest.mark.skip(reason="WIP DataFusion")
 def test_filtered_csv(tmpdir, c):
     # Predicate pushdown is NOT supported for CSV data.
     # This test just checks that the "attempted"
diff --git a/tests/integration/test_select.py b/tests/integration/test_select.py
index 443d9d395..defa9cc9a 100644
--- a/tests/integration/test_select.py
+++ b/tests/integration/test_select.py
@@ -144,7 +144,6 @@ def test_limit(c, input_table, limit, offset, request):
     assert_eq(c.sql(query), long_table.iloc[offset : offset + limit if limit else None])
 
 
-@pytest.mark.skip(reason="WIP DataFusion")
 @pytest.mark.parametrize(
     "input_table",
     [
@@ -178,7 +177,6 @@ def test_date_casting(c, input_table, request):
     assert_eq(result_df, expected_df)
 
 
-@pytest.mark.skip(reason="WIP DataFusion")
 @pytest.mark.parametrize(
     "input_table",
     [
diff --git a/tests/unit/test_context.py b/tests/unit/test_context.py
index 12e85b69c..269713915 100644
--- a/tests/unit/test_context.py
+++ b/tests/unit/test_context.py
@@ -116,6 +116,7 @@ def test_sql(gpu):
     assert_eq(result, data_frame)
 
 
+@pytest.mark.skip(reason="WIP DataFusion - missing create statement logic")
 @pytest.mark.parametrize(
     "gpu",
     [
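One closing illustration for the call.py cast fix earlier in this patch: the SQL type string coming back from the planner is parsed into a SqlTypeName before the Python-type lookup, replacing the previous accidental nested sql_to_python_type call. A sketch with hypothetical stand-ins for SqlTypeName and sql_to_python_type (the real ones live in the Rust extension and dask_sql's mapping utilities):

    from enum import Enum

    import numpy as np

    class SqlTypeName(Enum):  # stand-in for the planner-exposed type enum
        BIGINT = "BIGINT"
        DOUBLE = "DOUBLE"

        @staticmethod
        def fromString(s):
            return SqlTypeName[s]

    _SQL_TO_PYTHON = {SqlTypeName.BIGINT: np.int64, SqlTypeName.DOUBLE: np.float64}

    def sql_to_python_type(sql_type):
        return _SQL_TO_PYTHON[sql_type]

    output_type = "double"  # e.g. what str(rex.getType()) might return
    assert sql_to_python_type(SqlTypeName.fromString(output_type.upper())) is np.float64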