diff --git a/dask_planner/src/expression.rs b/dask_planner/src/expression.rs index f457d0ee9..f482ed9c0 100644 --- a/dask_planner/src/expression.rs +++ b/dask_planner/src/expression.rs @@ -4,7 +4,7 @@ use crate::sql::types::RexType; use pyo3::prelude::*; use std::convert::{From, Into}; -use datafusion::error::DataFusionError; +use datafusion::error::{DataFusionError, Result}; use datafusion::arrow::datatypes::DataType; use datafusion_expr::{lit, BuiltinScalarFunction, Expr}; @@ -15,6 +15,9 @@ pub use datafusion_expr::LogicalPlan; use datafusion::prelude::Column; +use crate::sql::exceptions::py_runtime_err; +use datafusion::common::DFField; +use datafusion::logical_plan::exprlist_to_fields; use std::sync::Arc; /// An PyExpr that can be used on a DataFrame @@ -61,85 +64,9 @@ impl PyExpr { } } - fn _column_name(&self, plan: LogicalPlan) -> String { - match &self.expr { - Expr::Alias(expr, name) => { - // Only certain LogicalPlan variants are valid in this nested Alias scenario so we - // extract the valid ones and error on the invalid ones - match expr.as_ref() { - Expr::Column(col) => { - // First we must iterate the current node before getting its input - match plan { - LogicalPlan::Projection(proj) => { - match proj.input.as_ref() { - LogicalPlan::Aggregate(agg) => { - let mut exprs = agg.group_expr.clone(); - exprs.extend_from_slice(&agg.aggr_expr); - let col_index: usize = - proj.input.schema().index_of_column(col).unwrap(); - // match &exprs[plan.get_index(col)] { - match &exprs[col_index] { - Expr::AggregateFunction { args, .. } => { - match &args[0] { - Expr::Column(col) => { - println!( - "AGGREGATE COLUMN IS {}", - col.name - ); - col.name.clone() - } - _ => name.clone(), - } - } - _ => name.clone(), - } - } - _ => { - println!("Encountered a non-Aggregate type"); - - name.clone() - } - } - } - _ => name.clone(), - } - } - _ => { - println!("Encountered a non Expr::Column instance"); - name.clone() - } - } - } - Expr::Column(column) => column.name.clone(), - Expr::ScalarVariable(..) => unimplemented!("ScalarVariable!!!"), - Expr::Literal(..) => unimplemented!("Literal!!!"), - Expr::BinaryExpr { - left: _, - op: _, - right: _, - } => { - // /// TODO: Examine this more deeply about whether name comes from the left or right - // self.column_name(left) - unimplemented!("BinaryExpr HERE!!!") - } - Expr::Not(..) => unimplemented!("Not!!!"), - Expr::IsNotNull(..) => unimplemented!("IsNotNull!!!"), - Expr::Negative(..) => unimplemented!("Negative!!!"), - Expr::GetIndexedField { .. } => unimplemented!("GetIndexedField!!!"), - Expr::IsNull(..) => unimplemented!("IsNull!!!"), - Expr::Between { .. } => unimplemented!("Between!!!"), - Expr::Case { .. } => unimplemented!("Case!!!"), - Expr::Cast { .. } => unimplemented!("Cast!!!"), - Expr::TryCast { .. } => unimplemented!("TryCast!!!"), - Expr::Sort { .. } => unimplemented!("Sort!!!"), - Expr::ScalarFunction { .. } => unimplemented!("ScalarFunction!!!"), - Expr::AggregateFunction { .. } => unimplemented!("AggregateFunction!!!"), - Expr::WindowFunction { .. } => unimplemented!("WindowFunction!!!"), - Expr::AggregateUDF { .. } => unimplemented!("AggregateUDF!!!"), - Expr::InList { .. } => unimplemented!("InList!!!"), - Expr::Wildcard => unimplemented!("Wildcard!!!"), - _ => panic!("Nothing found!!!"), - } + fn _column_name(&self, plan: LogicalPlan) -> Result { + let field = expr_to_field(&self.expr, &plan)?; + Ok(field.unqualified_column().name.clone()) } } @@ -171,7 +98,7 @@ impl PyExpr { let input: &Option> = &self.input_plan; match input { Some(plan) => { - let name: Result = self.expr.name(plan.schema()); + let name: Result = self.expr.name(plan.schema()); match name { Ok(fq_name) => Ok(plan .schema() @@ -248,8 +175,9 @@ impl PyExpr { } /// Python friendly shim code to get the name of a column referenced by an expression - pub fn column_name(&self, mut plan: logical::PyLogicalPlan) -> String { + pub fn column_name(&self, mut plan: logical::PyLogicalPlan) -> PyResult { self._column_name(plan.current_node()) + .map_err(|e| py_runtime_err(e)) } /// Gets the operands for a BinaryExpr call @@ -540,3 +468,11 @@ impl PyExpr { } } } + +/// Create a [DFField] representing an [Expr], given an input [LogicalPlan] to resolve against +pub fn expr_to_field(expr: &Expr, input_plan: &LogicalPlan) -> Result { + // TODO this is not the implementation that we really want and will be improved + // once some changes are made in DataFusion + let fields = exprlist_to_fields(&[expr.clone()], &input_plan.schema())?; + Ok(fields[0].clone()) +} diff --git a/dask_planner/src/sql/exceptions.rs b/dask_planner/src/sql/exceptions.rs index 567225f4f..f6eae1b25 100644 --- a/dask_planner/src/sql/exceptions.rs +++ b/dask_planner/src/sql/exceptions.rs @@ -6,3 +6,7 @@ create_exception!(rust, ParsingException, pyo3::exceptions::PyException); pub fn py_type_err(e: DataFusionError) -> PyErr { PyErr::new::(format!("{:?}", e)) } + +pub fn py_runtime_err(e: DataFusionError) -> PyErr { + PyErr::new::(format!("{:?}", e)) +}