diff --git a/datafusion/src/optimizer/constant_folding.rs b/datafusion/src/optimizer/constant_folding.rs index 4d8f06fb2844d..17125a33fa0de 100644 --- a/datafusion/src/optimizer/constant_folding.rs +++ b/datafusion/src/optimizer/constant_folding.rs @@ -15,12 +15,10 @@ // specific language governing permissions and limitations // under the License. -//! Boolean comparison rule rewrites redundant comparison expression involving boolean literal into -//! unary expression. +//! Constant folding and algebraic simplification use std::sync::Arc; -use arrow::compute::kernels::cast_utils::string_to_timestamp_nanos; use arrow::datatypes::DataType; use crate::error::Result; @@ -30,11 +28,11 @@ use crate::optimizer::optimizer::OptimizerRule; use crate::optimizer::utils; use crate::physical_plan::functions::BuiltinScalarFunction; use crate::scalar::ScalarValue; -use arrow::compute::{kernels, DEFAULT_CAST_OPTIONS}; -/// Optimizer that simplifies comparison expressions involving boolean literals. +/// Simplifies plans by rewriting [`Expr`]`s evaluating constants +/// and applying algebraic simplifications /// -/// Recursively go through all expressions and simplify the following cases: +/// Example transformations that are applied: /// * `expr = true` and `expr != false` to `expr` when `expr` is of boolean type /// * `expr = false` and `expr != true` to `!expr` when `expr` is of boolean type /// * `true = true` and `false = false` to `true` @@ -61,14 +59,16 @@ impl OptimizerRule for ConstantFolding { // projected columns. With just the projected schema, it's not possible to infer types for // expressions that references non-projected columns within the same project plan or its // children plans. - let mut rewriter = ConstantRewriter { + let mut simplifier = Simplifier { schemas: plan.all_schemas(), execution_props, }; + let mut const_evaluator = utils::ConstEvaluator::new(); + match plan { LogicalPlan::Filter { predicate, input } => Ok(LogicalPlan::Filter { - predicate: predicate.clone().rewrite(&mut rewriter)?, + predicate: predicate.clone().rewrite(&mut simplifier)?, input: Arc::new(self.optimize(input, execution_props)?), }), // Rest: recurse into plan, apply optimization where possible @@ -95,7 +95,18 @@ impl OptimizerRule for ConstantFolding { let expr = plan .expressions() .into_iter() - .map(|e| e.rewrite(&mut rewriter)) + .map(|e| { + // TODO iterate until no changes are made + // during rewrite (evaluating constants can + // enable new simplifications and + // simplifications can enable new constant + // evaluation) + let new_e = e + // fold constants and then simplify + .rewrite(&mut const_evaluator)? + .rewrite(&mut simplifier)?; + Ok(new_e) + }) .collect::>>()?; utils::from_plan(plan, &expr, &new_inputs) @@ -111,13 +122,17 @@ impl OptimizerRule for ConstantFolding { } } -struct ConstantRewriter<'a> { +/// Simplifies [`Expr`]s by applying algebraic transformation rules +/// +/// For example +/// `false && col` --> `col` where `col` is a boolean types +struct Simplifier<'a> { /// input schemas schemas: Vec<&'a DFSchemaRef>, execution_props: &'a ExecutionProps, } -impl<'a> ConstantRewriter<'a> { +impl<'a> Simplifier<'a> { fn is_boolean_type(&self, expr: &Expr) -> bool { for schema in &self.schemas { if let Ok(DataType::Boolean) = expr.get_type(schema) { @@ -129,22 +144,15 @@ impl<'a> ConstantRewriter<'a> { } } -impl<'a> ExprRewriter for ConstantRewriter<'a> { - /// rewrite the expression simplifying any constant expressions +impl<'a> ExprRewriter for Simplifier<'a> { + /// rewrite the expression using algebriac simplification rules fn mutate(&mut self, expr: Expr) -> Result { let new_expr = match expr { - Expr::BinaryExpr { left, op, right } => match op { - Operator::Eq => match (left.as_ref(), right.as_ref()) { - ( - Expr::Literal(ScalarValue::Boolean(l)), - Expr::Literal(ScalarValue::Boolean(r)), - ) => match (l, r) { - (Some(l), Some(r)) => { - Expr::Literal(ScalarValue::Boolean(Some(l == r))) - } - _ => Expr::Literal(ScalarValue::Boolean(None)), - }, - (Expr::Literal(ScalarValue::Boolean(b)), _) + Expr::BinaryExpr { left, op, right } => { + match (left.as_ref(), op, right.as_ref()) { + // = --> + // = --> ! + (Expr::Literal(ScalarValue::Boolean(b)), Operator::Eq, _) if self.is_boolean_type(&right) => { match b { @@ -153,7 +161,9 @@ impl<'a> ExprRewriter for ConstantRewriter<'a> { None => Expr::Literal(ScalarValue::Boolean(None)), } } - (_, Expr::Literal(ScalarValue::Boolean(b))) + // = --> + // = --> ! + (_, Operator::Eq, Expr::Literal(ScalarValue::Boolean(b))) if self.is_boolean_type(&left) => { match b { @@ -162,23 +172,9 @@ impl<'a> ExprRewriter for ConstantRewriter<'a> { None => Expr::Literal(ScalarValue::Boolean(None)), } } - _ => Expr::BinaryExpr { - left, - op: Operator::Eq, - right, - }, - }, - Operator::NotEq => match (left.as_ref(), right.as_ref()) { - ( - Expr::Literal(ScalarValue::Boolean(l)), - Expr::Literal(ScalarValue::Boolean(r)), - ) => match (l, r) { - (Some(l), Some(r)) => { - Expr::Literal(ScalarValue::Boolean(Some(l != r))) - } - _ => Expr::Literal(ScalarValue::Boolean(None)), - }, - (Expr::Literal(ScalarValue::Boolean(b)), _) + // != --> ! + // != --> + (Expr::Literal(ScalarValue::Boolean(b)), Operator::NotEq, _) if self.is_boolean_type(&right) => { match b { @@ -187,7 +183,9 @@ impl<'a> ExprRewriter for ConstantRewriter<'a> { None => Expr::Literal(ScalarValue::Boolean(None)), } } - (_, Expr::Literal(ScalarValue::Boolean(b))) + // != --> ! + // != --> + (_, Operator::NotEq, Expr::Literal(ScalarValue::Boolean(b))) if self.is_boolean_type(&left) => { match b { @@ -196,22 +194,18 @@ impl<'a> ExprRewriter for ConstantRewriter<'a> { None => Expr::Literal(ScalarValue::Boolean(None)), } } - _ => Expr::BinaryExpr { - left, - op: Operator::NotEq, - right, - }, - }, - _ => Expr::BinaryExpr { left, op, right }, - }, + _ => Expr::BinaryExpr { left, op, right }, + } + } + // Not(Not(expr)) --> expr Expr::Not(inner) => { - // Not(Not(expr)) --> expr if let Expr::Not(negated_inner) = *inner { *negated_inner } else { Expr::Not(inner) } } + // convert now() --> the time in `ExecutionProps` Expr::ScalarFunction { fun: BuiltinScalarFunction::Now, .. @@ -220,56 +214,8 @@ impl<'a> ExprRewriter for ConstantRewriter<'a> { .query_execution_start_time .timestamp_nanos(), ))), - Expr::ScalarFunction { - fun: BuiltinScalarFunction::ToTimestamp, - args, - } => { - if !args.is_empty() { - match &args[0] { - Expr::Literal(ScalarValue::Utf8(Some(val))) => { - match string_to_timestamp_nanos(val) { - Ok(timestamp) => Expr::Literal( - ScalarValue::TimestampNanosecond(Some(timestamp)), - ), - _ => Expr::ScalarFunction { - fun: BuiltinScalarFunction::ToTimestamp, - args, - }, - } - } - _ => Expr::ScalarFunction { - fun: BuiltinScalarFunction::ToTimestamp, - args, - }, - } - } else { - Expr::ScalarFunction { - fun: BuiltinScalarFunction::ToTimestamp, - args, - } - } - } - Expr::Cast { - expr: inner, - data_type, - } => match inner.as_ref() { - Expr::Literal(val) => { - let scalar_array = val.to_array(); - let cast_array = kernels::cast::cast_with_options( - &scalar_array, - &data_type, - &DEFAULT_CAST_OPTIONS, - )?; - let cast_scalar = ScalarValue::try_from_array(&cast_array, 0)?; - Expr::Literal(cast_scalar) - } - _ => Expr::Cast { - expr: inner, - data_type, - }, - }, expr => { - // no rewrite possible + // no additional rewrites possible expr } }; @@ -280,12 +226,13 @@ impl<'a> ExprRewriter for ConstantRewriter<'a> { #[cfg(test)] mod tests { use super::*; - use crate::logical_plan::{ - col, lit, max, min, DFField, DFSchema, LogicalPlanBuilder, + use crate::{ + assert_contains, + logical_plan::{col, lit, max, min, DFField, DFSchema, LogicalPlanBuilder}, }; use arrow::datatypes::*; - use chrono::{DateTime, Utc}; + use chrono::{DateTime, TimeZone, Utc}; fn test_table_scan() -> Result { let schema = Schema::new(vec![ @@ -310,7 +257,7 @@ mod tests { #[test] fn optimize_expr_not_not() -> Result<()> { let schema = expr_test_schema(); - let mut rewriter = ConstantRewriter { + let mut rewriter = Simplifier { schemas: vec![&schema], execution_props: &ExecutionProps::new(), }; @@ -326,7 +273,7 @@ mod tests { #[test] fn optimize_expr_null_comparison() -> Result<()> { let schema = expr_test_schema(); - let mut rewriter = ConstantRewriter { + let mut rewriter = Simplifier { schemas: vec![&schema], execution_props: &ExecutionProps::new(), }; @@ -362,7 +309,7 @@ mod tests { #[test] fn optimize_expr_eq() -> Result<()> { let schema = expr_test_schema(); - let mut rewriter = ConstantRewriter { + let mut rewriter = Simplifier { schemas: vec![&schema], execution_props: &ExecutionProps::new(), }; @@ -393,7 +340,7 @@ mod tests { #[test] fn optimize_expr_eq_skip_nonboolean_type() -> Result<()> { let schema = expr_test_schema(); - let mut rewriter = ConstantRewriter { + let mut rewriter = Simplifier { schemas: vec![&schema], execution_props: &ExecutionProps::new(), }; @@ -433,7 +380,7 @@ mod tests { #[test] fn optimize_expr_not_eq() -> Result<()> { let schema = expr_test_schema(); - let mut rewriter = ConstantRewriter { + let mut rewriter = Simplifier { schemas: vec![&schema], execution_props: &ExecutionProps::new(), }; @@ -452,24 +399,13 @@ mod tests { col("c2"), ); - // test constant - assert_eq!( - (lit(true).not_eq(lit(true))).rewrite(&mut rewriter)?, - lit(false), - ); - - assert_eq!( - (lit(true).not_eq(lit(false))).rewrite(&mut rewriter)?, - lit(true), - ); - Ok(()) } #[test] fn optimize_expr_not_eq_skip_nonboolean_type() -> Result<()> { let schema = expr_test_schema(); - let mut rewriter = ConstantRewriter { + let mut rewriter = Simplifier { schemas: vec![&schema], execution_props: &ExecutionProps::new(), }; @@ -505,7 +441,7 @@ mod tests { #[test] fn optimize_expr_case_when_then_else() -> Result<()> { let schema = expr_test_schema(); - let mut rewriter = ConstantRewriter { + let mut rewriter = Simplifier { schemas: vec![&schema], execution_props: &ExecutionProps::new(), }; @@ -668,6 +604,20 @@ mod tests { Ok(()) } + // expect optimizing will result in an error, returning the error string + fn get_optimized_plan_err(plan: &LogicalPlan, date_time: &DateTime) -> String { + let rule = ConstantFolding::new(); + let execution_props = ExecutionProps { + query_execution_start_time: *date_time, + }; + + let err = rule + .optimize(plan, &execution_props) + .expect_err("expected optimization to fail"); + + err.to_string() + } + fn get_optimized_plan_formatted( plan: &LogicalPlan, date_time: &DateTime, @@ -683,15 +633,19 @@ mod tests { return format!("{:?}", optimized_plan); } + /// Create a to_timestamp expr + fn to_timestamp_expr(arg: impl Into) -> Expr { + Expr::ScalarFunction { + args: vec![lit(arg.into())], + fun: BuiltinScalarFunction::ToTimestamp, + } + } + #[test] - fn to_timestamp_expr() { + fn to_timestamp_expr_folded() { let table_scan = test_table_scan().unwrap(); - let proj = vec![Expr::ScalarFunction { - args: vec![Expr::Literal(ScalarValue::Utf8(Some( - "2020-09-08T12:00:00+00:00".to_string(), - )))], - fun: BuiltinScalarFunction::ToTimestamp, - }]; + let proj = vec![to_timestamp_expr("2020-09-08T12:00:00+00:00")]; + let plan = LogicalPlanBuilder::from(table_scan) .project(proj) .unwrap() @@ -701,55 +655,30 @@ mod tests { let expected = "Projection: TimestampNanosecond(1599566400000000000)\ \n TableScan: test projection=None" .to_string(); - let actual = get_optimized_plan_formatted(&plan, &chrono::Utc::now()); + let actual = get_optimized_plan_formatted(&plan, &Utc::now()); assert_eq!(expected, actual); } #[test] fn to_timestamp_expr_wrong_arg() { let table_scan = test_table_scan().unwrap(); - let proj = vec![Expr::ScalarFunction { - args: vec![Expr::Literal(ScalarValue::Utf8(Some( - "I'M NOT A TIMESTAMP".to_string(), - )))], - fun: BuiltinScalarFunction::ToTimestamp, - }]; - let plan = LogicalPlanBuilder::from(table_scan) - .project(proj) - .unwrap() - .build() - .unwrap(); - - let expected = "Projection: totimestamp(Utf8(\"I\'M NOT A TIMESTAMP\"))\ - \n TableScan: test projection=None"; - let actual = get_optimized_plan_formatted(&plan, &chrono::Utc::now()); - assert_eq!(expected, actual); - } - - #[test] - fn to_timestamp_expr_no_arg() { - let table_scan = test_table_scan().unwrap(); - let proj = vec![Expr::ScalarFunction { - args: vec![], - fun: BuiltinScalarFunction::ToTimestamp, - }]; + let proj = vec![to_timestamp_expr("I'M NOT A TIMESTAMP")]; let plan = LogicalPlanBuilder::from(table_scan) .project(proj) .unwrap() .build() .unwrap(); - let expected = "Projection: totimestamp()\ - \n TableScan: test projection=None"; - let actual = get_optimized_plan_formatted(&plan, &chrono::Utc::now()); - assert_eq!(expected, actual); + let expected = "Error parsing 'I'M NOT A TIMESTAMP' as timestamp"; + let actual = get_optimized_plan_err(&plan, &Utc::now()); + assert_contains!(actual, expected); } #[test] fn cast_expr() { let table_scan = test_table_scan().unwrap(); let proj = vec![Expr::Cast { - expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some("0".to_string())))), + expr: Box::new(lit("0")), data_type: DataType::Int32, }]; let plan = LogicalPlanBuilder::from(table_scan) @@ -760,7 +689,7 @@ mod tests { let expected = "Projection: Int32(0)\ \n TableScan: test projection=None"; - let actual = get_optimized_plan_formatted(&plan, &chrono::Utc::now()); + let actual = get_optimized_plan_formatted(&plan, &Utc::now()); assert_eq!(expected, actual); } @@ -768,7 +697,7 @@ mod tests { fn cast_expr_wrong_arg() { let table_scan = test_table_scan().unwrap(); let proj = vec![Expr::Cast { - expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some("".to_string())))), + expr: Box::new(lit("")), data_type: DataType::Int32, }]; let plan = LogicalPlanBuilder::from(table_scan) @@ -777,20 +706,24 @@ mod tests { .build() .unwrap(); - let expected = "Projection: Int32(NULL)\ - \n TableScan: test projection=None"; - let actual = get_optimized_plan_formatted(&plan, &chrono::Utc::now()); - assert_eq!(expected, actual); + let expected = + "Cannot cast string '' to value of arrow::datatypes::types::Int32Type type"; + let actual = get_optimized_plan_err(&plan, &Utc::now()); + assert_contains!(actual, expected); + } + + fn now_expr() -> Expr { + Expr::ScalarFunction { + args: vec![], + fun: BuiltinScalarFunction::Now, + } } #[test] fn single_now_expr() { let table_scan = test_table_scan().unwrap(); - let proj = vec![Expr::ScalarFunction { - args: vec![], - fun: BuiltinScalarFunction::Now, - }]; - let time = chrono::Utc::now(); + let proj = vec![now_expr()]; + let time = Utc::now(); let plan = LogicalPlanBuilder::from(table_scan) .project(proj) .unwrap() @@ -810,19 +743,10 @@ mod tests { #[test] fn multiple_now_expr() { let table_scan = test_table_scan().unwrap(); - let time = chrono::Utc::now(); + let time = Utc::now(); let proj = vec![ - Expr::ScalarFunction { - args: vec![], - fun: BuiltinScalarFunction::Now, - }, - Expr::Alias( - Box::new(Expr::ScalarFunction { - args: vec![], - fun: BuiltinScalarFunction::Now, - }), - "t2".to_string(), - ), + now_expr(), + Expr::Alias(Box::new(now_expr()), "t2".to_string()), ]; let plan = LogicalPlanBuilder::from(table_scan) .project(proj) @@ -830,6 +754,7 @@ mod tests { .build() .unwrap(); + // expect the same timestamp appears in both exprs let actual = get_optimized_plan_formatted(&plan, &time); let expected = format!( "Projection: TimestampNanosecond({}), TimestampNanosecond({}) AS t2\ @@ -840,4 +765,59 @@ mod tests { assert_eq!(actual, expected); } + + #[test] + fn simplify_and_eval() { + // demonstrate a case where the evaluation needs to run prior + // to the simplifier for it to work + let table_scan = test_table_scan().unwrap(); + let time = Utc::now(); + // (true or false) != col --> !col + let proj = vec![lit(true).or(lit(false)).not_eq(col("a"))]; + let plan = LogicalPlanBuilder::from(table_scan) + .project(proj) + .unwrap() + .build() + .unwrap(); + + let actual = get_optimized_plan_formatted(&plan, &time); + let expected = "Projection: NOT #test.a\ + \n TableScan: test projection=None"; + + assert_eq!(actual, expected); + } + + fn cast_to_int64_expr(expr: Expr) -> Expr { + Expr::Cast { + expr: expr.into(), + data_type: DataType::Int64, + } + } + + #[test] + fn now_less_than_timestamp() { + let table_scan = test_table_scan().unwrap(); + + let ts_string = "2020-09-08T12:05:00+00:00"; + let time = chrono::Utc.timestamp_nanos(1599566400000000000i64); + + // now() < cast(to_timestamp(...) as int) + 5000000000 + let plan = LogicalPlanBuilder::from(table_scan) + .filter( + now_expr() + .lt(cast_to_int64_expr(to_timestamp_expr(ts_string)) + lit(50000)), + ) + .unwrap() + .build() + .unwrap(); + + // TODO constant folder hould be able to run again and fold + // this whole thing down + // TODO add ticket + let expected = "Filter: TimestampNanosecond(1599566400000000000) < CAST(totimestamp(Utf8(\"2020-09-08T12:05:00+00:00\")) AS Int64) + Int32(50000)\ + \n TableScan: test projection=None"; + let actual = get_optimized_plan_formatted(&plan, &time); + + assert_eq!(expected, actual); + } } diff --git a/datafusion/src/optimizer/utils.rs b/datafusion/src/optimizer/utils.rs index 6e64bf39b2e2d..fd33c38a1f53b 100644 --- a/datafusion/src/optimizer/utils.rs +++ b/datafusion/src/optimizer/utils.rs @@ -17,12 +17,18 @@ //! Collection of utility functions that are leveraged by the query optimizer rules +use arrow::array::new_null_array; +use arrow::datatypes::{DataType, Field, Schema}; +use arrow::record_batch::RecordBatch; + use super::optimizer::OptimizerRule; -use crate::execution::context::ExecutionProps; +use crate::execution::context::{ExecutionContextState, ExecutionProps}; use crate::logical_plan::{ - build_join_schema, Column, DFSchemaRef, Expr, LogicalPlan, LogicalPlanBuilder, - Operator, Partitioning, Recursion, + build_join_schema, Column, DFSchema, DFSchemaRef, Expr, ExprRewriter, LogicalPlan, + LogicalPlanBuilder, Operator, Partitioning, Recursion, RewriteRecursion, }; +use crate::physical_plan::functions::Volatility; +use crate::physical_plan::planner::DefaultPhysicalPlanner; use crate::prelude::lit; use crate::scalar::ScalarValue; use crate::{ @@ -468,11 +474,197 @@ pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result { } } +/// Partially evaluate `Expr`s so constant subtrees are evaluated at plan time. +/// +/// Note it does not handle other algebriac rewrites such as `(a and false)` --> `a` +/// +/// ``` +/// # use datafusion::prelude::*; +/// # use datafusion::optimizer::utils::ConstEvaluator; +/// let mut const_evaluator = ConstEvaluator::new(); +/// +/// // (1 + 2) + a +/// let expr = (lit(1) + lit(2)) + col("a"); +/// +/// // is rewritten to (3 + a); +/// let rewritten = expr.rewrite(&mut const_evaluator).unwrap(); +/// assert_eq!(rewritten, lit(3) + col("a")); +/// ``` +pub struct ConstEvaluator { + /// can_evaluate is used during the depth-first-search of the + /// Expr tree to track if any siblings (or their descendants) were + /// non evaluatable (e.g. had a column reference or volatile + /// function) + /// + /// Specifically, can_evaluate[N] represents the state of + /// traversal when we are N levels deep in the tree, one entry for + /// this Expr and each of its parents. + /// + /// After visiting all siblings if can_evauate.top() is true, that + /// means there were no non evaluatable siblings (or their + /// descendants) so this Expr can be evaluated + can_evaluate: Vec, + + ctx_state: ExecutionContextState, + planner: DefaultPhysicalPlanner, + input_schema: DFSchema, + input_batch: RecordBatch, +} + +impl ExprRewriter for ConstEvaluator { + fn pre_visit(&mut self, expr: &Expr) -> Result { + // Default to being able to evaluate this node + self.can_evaluate.push(true); + + // if this expr is not ok to evaluate, mark entire parent + // stack as not ok (as all parents have at least one child or + // descendant that is non evaluateable + + if !Self::can_evaluate(expr) { + // walk back up stack, marking first parent that is not mutable + let parent_iter = self.can_evaluate.iter_mut().rev(); + for p in parent_iter { + if !*p { + // optimization: if we find an element on the + // stack already marked, know all elements above are also marked + break; + } + *p = false; + } + } + + // NB: do not short circuit recursion even if we find a non + // evaluatable node (so we can fold other children, args to + // functions, etc) + Ok(RewriteRecursion::Continue) + } + + fn mutate(&mut self, expr: Expr) -> Result { + if self.can_evaluate.pop().unwrap() { + let scalar = self.evaluate_to_scalar(expr)?; + Ok(Expr::Literal(scalar)) + } else { + Ok(expr) + } + } +} + +impl ConstEvaluator { + /// Create a new `ConstantEvaluator`. + pub fn new() -> Self { + let planner = DefaultPhysicalPlanner::default(); + let ctx_state = ExecutionContextState::new(); + let input_schema = DFSchema::empty(); + + // The dummy column name uis used doesn't matter as only scalar + // expressions will be evaluated + static DUMMY_COL_NAME: &str = "."; + let schema = + Schema::new(vec![Field::new(DUMMY_COL_NAME, DataType::Float64, true)]); + + let col = new_null_array(&DataType::Float64, 1); + + let input_batch = + RecordBatch::try_new(std::sync::Arc::new(schema), vec![col]).unwrap(); + + Self { + can_evaluate: vec![], + ctx_state, + planner, + input_schema, + input_batch, + } + } + + /// Can a function of the specified volatility be evaluated? + fn volatility_ok(volatility: Volatility) -> bool { + match volatility { + Volatility::Immutable => true, + // To evaluate stable functions, need ExecutionProps, see + // Simplifier for code that does that. + Volatility::Stable => false, + Volatility::Volatile => false, + } + } + + /// Can the expression be evaluated at plan time, (assuming all of + /// its children can also be evaluated)? + fn can_evaluate(expr: &Expr) -> bool { + // check for reasons we can't evaluate this node + // + // NOTE all expr types are listed here so when new ones are + // added they can be checked for their ability to be evaluated + // at plan time + match expr { + // Has no runtime cost, but needed during planning + Expr::Alias(..) => false, + Expr::AggregateFunction { .. } => false, + Expr::AggregateUDF { .. } => false, + Expr::ScalarVariable(_) => false, + Expr::Column(_) => false, + Expr::ScalarFunction { fun, .. } => Self::volatility_ok(fun.volatility()), + Expr::ScalarUDF { fun, .. } => Self::volatility_ok(fun.signature.volatility), + Expr::WindowFunction { .. } => false, + Expr::Sort { .. } => false, + Expr::Wildcard => false, + + Expr::Literal(_) => true, + Expr::BinaryExpr { .. } => true, + Expr::Not(_) => true, + Expr::IsNotNull(_) => true, + Expr::IsNull(_) => true, + Expr::Negative(_) => true, + Expr::Between { .. } => true, + Expr::Case { .. } => true, + Expr::Cast { .. } => true, + Expr::TryCast { .. } => true, + Expr::InList { .. } => true, + } + } + + /// Internal helper to evaluates an Expr + fn evaluate_to_scalar(&self, expr: Expr) -> Result { + if let Expr::Literal(s) = expr { + return Ok(s); + } + + let phys_expr = self.planner.create_physical_expr( + &expr, + &self.input_schema, + &self.input_batch.schema(), + &self.ctx_state, + )?; + let col_val = phys_expr.evaluate(&self.input_batch)?; + match col_val { + crate::physical_plan::ColumnarValue::Array(a) => { + if a.len() != 1 { + Err(DataFusionError::Execution(format!( + "Could not evaluate the expressison, found a result of length {}", + a.len() + ))) + } else { + Ok(ScalarValue::try_from_array(&a, 0)?) + } + } + crate::physical_plan::ColumnarValue::Scalar(s) => Ok(s), + } + } +} + #[cfg(test)] mod tests { use super::*; - use crate::logical_plan::col; - use arrow::datatypes::DataType; + use crate::{ + logical_plan::{col, create_udf, lit_timestamp_nano}, + physical_plan::{ + functions::{make_scalar_function, BuiltinScalarFunction}, + udf::ScalarUDF, + }, + }; + use arrow::{ + array::{ArrayRef, Int32Array}, + datatypes::DataType, + }; use std::collections::HashSet; #[test] @@ -496,4 +688,206 @@ mod tests { assert!(accum.contains(&Column::from_name("a"))); Ok(()) } + + #[test] + fn test_const_evaluator() { + // true --> true + test_evaluate(lit(true), lit(true)); + // true or true --> true + test_evaluate(lit(true).or(lit(true)), lit(true)); + // true or false --> true + test_evaluate(lit(true).or(lit(false)), lit(true)); + + // "foo" == "foo" --> true + test_evaluate(lit("foo").eq(lit("foo")), lit(true)); + // "foo" != "foo" --> false + test_evaluate(lit("foo").not_eq(lit("foo")), lit(false)); + + // c = 1 --> c = 1 + test_evaluate(col("c").eq(lit(1)), col("c").eq(lit(1))); + // c = 1 + 2 --> c + 3 + test_evaluate(col("c").eq(lit(1) + lit(2)), col("c").eq(lit(3))); + // (foo != foo) OR (c = 1) --> false OR (c = 1) + test_evaluate( + (lit("foo").not_eq(lit("foo"))).or(col("c").eq(lit(1))), + lit(false).or(col("c").eq(lit(1))), + ); + // test boolean constant evaluation + + // true != true --> false + test_evaluate(lit(true).not_eq(lit(true)), lit(false)); + // true != false --> true + test_evaluate(lit(true).not_eq(lit(false)), lit(true)); + } + + #[test] + fn test_const_evaluator_scalar_functions() { + // concat("foo", "bar") --> "foobar" + let expr = Expr::ScalarFunction { + args: vec![lit("foo"), lit("bar")], + fun: BuiltinScalarFunction::Concat, + }; + test_evaluate(expr, lit("foobar")); + + // ensure arguments are also constant folded + // concat("foo", concat("bar", "baz")) --> "foobarbaz" + let concat1 = Expr::ScalarFunction { + args: vec![lit("bar"), lit("baz")], + fun: BuiltinScalarFunction::Concat, + }; + let expr = Expr::ScalarFunction { + args: vec![lit("foo"), concat1], + fun: BuiltinScalarFunction::Concat, + }; + test_evaluate(expr, lit("foobarbaz")); + + // Check non string arguments + // to_timestamp("2020-09-08T12:00:00+00:00") --> timestamp(1599566400000000000i64) + let expr = Expr::ScalarFunction { + args: vec![lit("2020-09-08T12:00:00+00:00")], + fun: BuiltinScalarFunction::ToTimestamp, + }; + test_evaluate(expr, lit_timestamp_nano(1599566400000000000i64)); + + // check that non foldable arguments are folded + // to_timestamp(a) --> to_timestamp(a) [no rewrite possible] + let expr = Expr::ScalarFunction { + args: vec![col("a")], + fun: BuiltinScalarFunction::ToTimestamp, + }; + test_evaluate(expr.clone(), expr); + + // check that non foldable arguments are folded + // to_timestamp(a) --> to_timestamp(a) [no rewrite possible] + let expr = Expr::ScalarFunction { + args: vec![col("a")], + fun: BuiltinScalarFunction::ToTimestamp, + }; + test_evaluate(expr.clone(), expr); + + // volatile / stable functions should not be evaluated + // rand() + (1 + 2) --> rand() + 3 + let fun = BuiltinScalarFunction::Random; + assert_eq!(fun.volatility(), Volatility::Volatile); + let rand = Expr::ScalarFunction { args: vec![], fun }; + let expr = rand.clone() + (lit(1) + lit(2)); + let expected = rand + lit(3); + test_evaluate(expr, expected); + + // parenthesization matters: can't rewrite + // (rand() + 1) + 2 --> (rand() + 1) + 2) + let fun = BuiltinScalarFunction::Random; + assert_eq!(fun.volatility(), Volatility::Volatile); + let rand = Expr::ScalarFunction { args: vec![], fun }; + let expr = (rand + lit(1)) + lit(2); + test_evaluate(expr.clone(), expr); + + // volatile / stable functions should not be evaluated + // now() + (1 + 2) --> now() + 3 + let fun = BuiltinScalarFunction::Now; + assert_eq!(fun.volatility(), Volatility::Stable); + let now = Expr::ScalarFunction { args: vec![], fun }; + let expr = now.clone() + (lit(1) + lit(2)); + let expected = now + lit(3); + test_evaluate(expr, expected); + } + + #[test] + fn test_const_evaluator_udfs() { + let args = vec![lit(1) + lit(2), lit(30) + lit(40)]; + let folded_args = vec![lit(3), lit(70)]; + + // immutable UDF should get folded + // udf_add(1+2, 30+40) --> 70 + let expr = Expr::ScalarUDF { + args: args.clone(), + fun: make_udf_add(Volatility::Immutable), + }; + test_evaluate(expr, lit(73)); + + // stable UDF should have args folded + // udf_add(1+2, 30+40) --> udf_add(3, 70) + let fun = make_udf_add(Volatility::Stable); + let expr = Expr::ScalarUDF { + args: args.clone(), + fun: Arc::clone(&fun), + }; + let expected_expr = Expr::ScalarUDF { + args: folded_args.clone(), + fun: Arc::clone(&fun), + }; + test_evaluate(expr, expected_expr); + + // volatile UDF should have args folded + // udf_add(1+2, 30+40) --> udf_add(3, 70) + let fun = make_udf_add(Volatility::Volatile); + let expr = Expr::ScalarUDF { + args, + fun: Arc::clone(&fun), + }; + let expected_expr = Expr::ScalarUDF { + args: folded_args, + fun: Arc::clone(&fun), + }; + test_evaluate(expr, expected_expr); + } + + // Make a UDF that adds its two values together, with the specified volatility + fn make_udf_add(volatility: Volatility) -> Arc { + let input_types = vec![DataType::Int32, DataType::Int32]; + let return_type = Arc::new(DataType::Int32); + + let fun = |args: &[ArrayRef]| { + let arg0 = &args[0] + .as_any() + .downcast_ref::() + .expect("cast failed"); + let arg1 = &args[1] + .as_any() + .downcast_ref::() + .expect("cast failed"); + + // 2. perform the computation + let array = arg0 + .iter() + .zip(arg1.iter()) + .map(|args| { + if let (Some(arg0), Some(arg1)) = args { + Some(arg0 + arg1) + } else { + // one or both args were Null + None + } + }) + .collect::(); + + Ok(Arc::new(array) as ArrayRef) + }; + + let fun = make_scalar_function(fun); + Arc::new(create_udf( + "udf_add", + input_types, + return_type, + volatility, + fun, + )) + } + + // udfs + // validate that even a volatile function's arguments will be evaluated + + fn test_evaluate(input_expr: Expr, expected_expr: Expr) { + let mut const_evaluator = ConstEvaluator::new(); + let evaluated_expr = input_expr + .clone() + .rewrite(&mut const_evaluator) + .expect("successfully evaluated"); + + assert_eq!( + evaluated_expr, expected_expr, + "Mismatch evaluating {}\n Expected:{}\n Got:{}", + input_expr, expected_expr, evaluated_expr + ); + } } diff --git a/datafusion/src/physical_plan/expressions/binary.rs b/datafusion/src/physical_plan/expressions/binary.rs index 4a3ef581e6f3b..a8be1b52354e8 100644 --- a/datafusion/src/physical_plan/expressions/binary.rs +++ b/datafusion/src/physical_plan/expressions/binary.rs @@ -342,6 +342,29 @@ macro_rules! boolean_op { }}; } +/// Invoke a boolean kernel with a scalar on an array +macro_rules! boolean_op_scalar { + ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{ + let ll = $LEFT + .as_any() + .downcast_ref::() + .expect("boolean_op_scalar failed to downcast array"); + + let result = if let ScalarValue::Boolean(scalar) = $RIGHT { + Ok( + Arc::new(paste::expr! {[<$OP _bool_scalar>]}(&ll, scalar.as_ref())?) + as ArrayRef, + ) + } else { + Err(DataFusionError::Internal(format!( + "boolean_op_scalar failed to cast literal value {}", + $RIGHT + ))) + }; + Some(result) + }}; +} + macro_rules! binary_string_array_flag_op { ($LEFT:expr, $RIGHT:expr, $OP:ident, $NOT:expr, $FLAG:expr) => {{ match $LEFT.data_type() { @@ -592,9 +615,19 @@ impl BinaryExpr { Operator::GtEq => { binary_array_op_scalar!(array, scalar.clone(), gt_eq) } - Operator::Eq => binary_array_op_scalar!(array, scalar.clone(), eq), + Operator::Eq => { + if array.data_type() == &DataType::Boolean { + boolean_op_scalar!(array, scalar.clone(), eq) + } else { + binary_array_op_scalar!(array, scalar.clone(), eq) + } + } Operator::NotEq => { - binary_array_op_scalar!(array, scalar.clone(), neq) + if array.data_type() == &DataType::Boolean { + boolean_op_scalar!(array, scalar.clone(), neq) + } else { + binary_array_op_scalar!(array, scalar.clone(), neq) + } } Operator::Like => { binary_string_array_op_scalar!(array, scalar.clone(), like) @@ -659,9 +692,19 @@ impl BinaryExpr { Operator::GtEq => { binary_array_op_scalar!(array, scalar.clone(), lt_eq) } - Operator::Eq => binary_array_op_scalar!(array, scalar.clone(), eq), + Operator::Eq => { + if array.data_type() == &DataType::Boolean { + boolean_op_scalar!(array, scalar.clone(), eq) + } else { + binary_array_op_scalar!(array, scalar.clone(), eq) + } + } Operator::NotEq => { - binary_array_op_scalar!(array, scalar.clone(), neq) + if array.data_type() == &DataType::Boolean { + boolean_op_scalar!(array, scalar.clone(), neq) + } else { + binary_array_op_scalar!(array, scalar.clone(), neq) + } } // if scalar operation is not supported - fallback to array implementation _ => None, @@ -683,8 +726,21 @@ impl BinaryExpr { Operator::LtEq => binary_array_op!(left, right, lt_eq), Operator::Gt => binary_array_op!(left, right, gt), Operator::GtEq => binary_array_op!(left, right, gt_eq), - Operator::Eq => binary_array_op!(left, right, eq), - Operator::NotEq => binary_array_op!(left, right, neq), + Operator::Eq => { + if left_data_type == &DataType::Boolean { + boolean_op!(left, right, eq_bool) + } else { + binary_array_op!(left, right, eq) + } + } + Operator::NotEq => { + if left_data_type == &DataType::Boolean { + boolean_op!(left, right, neq_bool) + } else { + binary_array_op!(left, right, neq) + } + } + Operator::IsDistinctFrom => binary_array_op!(left, right, is_distinct_from), Operator::IsNotDistinctFrom => { binary_array_op!(left, right, is_not_distinct_from) @@ -814,6 +870,60 @@ pub fn binary( Ok(Arc::new(BinaryExpr::new(l, op, r))) } +// TODO file a ticket with arrow-rs to include these kernels + +fn eq_bool(lhs: &BooleanArray, rhs: &BooleanArray) -> Result { + let arr: BooleanArray = lhs + .iter() + .zip(rhs.iter()) + .map(|v| match v { + // both lhs and rhs were non null + (Some(lhs), Some(rhs)) => Some(lhs == rhs), + _ => None, + }) + .collect(); + + Ok(arr) +} + +fn eq_bool_scalar(lhs: &BooleanArray, rhs: Option<&bool>) -> Result { + let arr: BooleanArray = lhs + .iter() + .map(|v| match (v, rhs) { + // both lhs and rhs were non null + (Some(lhs), Some(rhs)) => Some(lhs == *rhs), + _ => None, + }) + .collect(); + Ok(arr) +} + +fn neq_bool(lhs: &BooleanArray, rhs: &BooleanArray) -> Result { + let arr: BooleanArray = lhs + .iter() + .zip(rhs.iter()) + .map(|v| match v { + // both lhs and rhs were non null + (Some(lhs), Some(rhs)) => Some(lhs != rhs), + _ => None, + }) + .collect(); + + Ok(arr) +} + +fn neq_bool_scalar(lhs: &BooleanArray, rhs: Option<&bool>) -> Result { + let arr: BooleanArray = lhs + .iter() + .map(|v| match (v, rhs) { + // both lhs and rhs were non null + (Some(lhs), Some(rhs)) => Some(lhs != *rhs), + _ => None, + }) + .collect(); + Ok(arr) +} + #[cfg(test)] mod tests { use arrow::datatypes::{ArrowNumericType, Field, Int32Type, SchemaRef}; @@ -821,7 +931,7 @@ mod tests { use super::*; use crate::error::Result; - use crate::physical_plan::expressions::col; + use crate::physical_plan::expressions::{col, lit}; // Create a binary expression without coercion. Used here when we do not want to coerce the expressions // to valid types. Usage can result in an execution (after plan) error. @@ -1371,6 +1481,42 @@ mod tests { Ok(()) } + // Test `scalar arr` produces expected + fn apply_logic_op_scalar_arr( + schema: &SchemaRef, + scalar: bool, + arr: &ArrayRef, + op: Operator, + expected: &BooleanArray, + ) -> Result<()> { + let scalar = lit(scalar.into()); + + let arithmetic_op = binary_simple(scalar, op, col("a", schema)?); + let batch = RecordBatch::try_new(Arc::clone(schema), vec![Arc::clone(arr)])?; + let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows()); + assert_eq!(result.as_ref(), expected); + + Ok(()) + } + + // Test `arr scalar` produces expected + fn apply_logic_op_arr_scalar( + schema: &SchemaRef, + arr: &ArrayRef, + scalar: bool, + op: Operator, + expected: &BooleanArray, + ) -> Result<()> { + let scalar = lit(scalar.into()); + + let arithmetic_op = binary_simple(col("a", schema)?, op, scalar); + let batch = RecordBatch::try_new(Arc::clone(schema), vec![Arc::clone(arr)])?; + let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows()); + assert_eq!(result.as_ref(), expected); + + Ok(()) + } + #[test] fn and_with_nulls_op() -> Result<()> { let schema = Schema::new(vec![ @@ -1461,6 +1607,58 @@ mod tests { Ok(()) } + #[test] + fn eq_op_bool() { + let schema = Schema::new(vec![ + Field::new("a", DataType::Boolean, false), + Field::new("b", DataType::Boolean, false), + ]); + let a = BooleanArray::from(vec![Some(true), None, Some(false), None]); + let b = + BooleanArray::from(vec![Some(true), Some(false), Some(true), Some(false)]); + + let expected = BooleanArray::from(vec![Some(true), None, Some(false), None]); + apply_logic_op(Arc::new(schema), a, b, Operator::Eq, expected).unwrap(); + } + + #[test] + fn eq_op_bool_scalar() { + let schema = Schema::new(vec![Field::new("a", DataType::Boolean, false)]); + let schema = Arc::new(schema); + let a: ArrayRef = + Arc::new(BooleanArray::from(vec![Some(true), None, Some(false)])); + + let expected = BooleanArray::from(vec![Some(true), None, Some(false)]); + apply_logic_op_scalar_arr(&schema, true, &a, Operator::Eq, &expected).unwrap(); + apply_logic_op_arr_scalar(&schema, &a, true, Operator::Eq, &expected).unwrap(); + } + + #[test] + fn neq_op_bool() { + let schema = Schema::new(vec![ + Field::new("a", DataType::Boolean, false), + Field::new("b", DataType::Boolean, false), + ]); + let a = BooleanArray::from(vec![Some(true), None, Some(false), None]); + let b = + BooleanArray::from(vec![Some(true), Some(false), Some(true), Some(false)]); + + let expected = BooleanArray::from(vec![Some(false), None, Some(true), None]); + apply_logic_op(Arc::new(schema), a, b, Operator::NotEq, expected).unwrap(); + } + + #[test] + fn neq_op_bool_scalar() { + let schema = Schema::new(vec![Field::new("a", DataType::Boolean, false)]); + let schema = Arc::new(schema); + let a: ArrayRef = + Arc::new(BooleanArray::from(vec![Some(true), None, Some(false)])); + + let expected = BooleanArray::from(vec![Some(false), None, Some(true)]); + apply_logic_op_scalar_arr(&schema, true, &a, Operator::NotEq, &expected).unwrap(); + apply_logic_op_arr_scalar(&schema, &a, true, Operator::NotEq, &expected).unwrap(); + } + #[test] fn test_coersion_error() -> Result<()> { let expr = diff --git a/datafusion/src/test_util.rs b/datafusion/src/test_util.rs index 0c9498acf9207..03e0054b15706 100644 --- a/datafusion/src/test_util.rs +++ b/datafusion/src/test_util.rs @@ -88,6 +88,52 @@ macro_rules! assert_batches_sorted_eq { }; } +/// A macro to assert that one string is contained within another with +/// a nice error message if they are not. +/// +/// Usage: `assert_contains!(actual, expected)` +/// +/// Is a macro so test error +/// messages are on the same line as the failure; +/// +/// Both arguments must be convertable into Strings (Into) +#[macro_export] +macro_rules! assert_contains { + ($ACTUAL: expr, $EXPECTED: expr) => { + let actual_value: String = $ACTUAL.into(); + let expected_value: String = $EXPECTED.into(); + assert!( + actual_value.contains(&expected_value), + "Can not find expected in actual.\n\nExpected:\n{}\n\nActual:\n{}", + expected_value, + actual_value + ); + }; +} + +/// A macro to assert that one string is NOT contained within another with +/// a nice error message if they are are. +/// +/// Usage: `assert_not_contains!(actual, unexpected)` +/// +/// Is a macro so test error +/// messages are on the same line as the failure; +/// +/// Both arguments must be convertable into Strings (Into) +#[macro_export] +macro_rules! assert_not_contains { + ($ACTUAL: expr, $UNEXPECTED: expr) => { + let actual_value: String = $ACTUAL.into(); + let unexpected_value: String = $UNEXPECTED.into(); + assert!( + !actual_value.contains(&unexpected_value), + "Found unexpected in actual.\n\nUnexpected:\n{}\n\nActual:\n{}", + unexpected_value, + actual_value + ); + }; +} + /// Returns the arrow test data directory, which is by default stored /// in a git submodule rooted at `testing/data`. /// diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs index 67270c5cfb044..5c526e2709b24 100644 --- a/datafusion/tests/sql.rs +++ b/datafusion/tests/sql.rs @@ -34,6 +34,8 @@ use arrow::{ use datafusion::assert_batches_eq; use datafusion::assert_batches_sorted_eq; +use datafusion::assert_contains; +use datafusion::assert_not_contains; use datafusion::logical_plan::LogicalPlan; use datafusion::physical_plan::functions::Volatility; use datafusion::physical_plan::metrics::MetricValue; @@ -47,50 +49,6 @@ use datafusion::{ }; use datafusion::{execution::context::ExecutionContext, physical_plan::displayable}; -/// A macro to assert that one string is contained within another with -/// a nice error message if they are not. -/// -/// Usage: `assert_contains!(actual, expected)` -/// -/// Is a macro so test error -/// messages are on the same line as the failure; -/// -/// Both arguments must be convertable into Strings (Into) -macro_rules! assert_contains { - ($ACTUAL: expr, $EXPECTED: expr) => { - let actual_value: String = $ACTUAL.into(); - let expected_value: String = $EXPECTED.into(); - assert!( - actual_value.contains(&expected_value), - "Can not find expected in actual.\n\nExpected:\n{}\n\nActual:\n{}", - expected_value, - actual_value - ); - }; -} - -/// A macro to assert that one string is NOT contained within another with -/// a nice error message if they are are. -/// -/// Usage: `assert_not_contains!(actual, unexpected)` -/// -/// Is a macro so test error -/// messages are on the same line as the failure; -/// -/// Both arguments must be convertable into Strings (Into) -macro_rules! assert_not_contains { - ($ACTUAL: expr, $UNEXPECTED: expr) => { - let actual_value: String = $ACTUAL.into(); - let unexpected_value: String = $UNEXPECTED.into(); - assert!( - !actual_value.contains(&unexpected_value), - "Found unexpected in actual.\n\nUnexpected:\n{}\n\nActual:\n{}", - unexpected_value, - actual_value - ); - }; -} - #[tokio::test] async fn nyc() -> Result<()> { // schema for nyxtaxi csv files @@ -598,7 +556,7 @@ async fn select_distinct_simple_4() { async fn select_distinct_from() { let mut ctx = ExecutionContext::new(); - let sql = "select + let sql = "select 1 IS DISTINCT FROM CAST(NULL as INT) as a, 1 IS DISTINCT FROM 1 as b, 1 IS NOT DISTINCT FROM CAST(NULL as INT) as c, @@ -621,7 +579,7 @@ async fn select_distinct_from() { async fn select_distinct_from_utf8() { let mut ctx = ExecutionContext::new(); - let sql = "select + let sql = "select 'x' IS DISTINCT FROM NULL as a, 'x' IS DISTINCT FROM 'x' as b, 'x' IS NOT DISTINCT FROM NULL as c, @@ -812,6 +770,40 @@ async fn csv_query_having_without_group_by() -> Result<()> { Ok(()) } +#[tokio::test] +async fn csv_query_boolean_eq() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_simple_csv(&mut ctx).await?; + + let sql = "SELECT c3, c3 = c3 as eq, c3 != c3 as neq FROM aggregate_simple"; + let actual = execute_to_batches(&mut ctx, sql).await; + + let expected = vec![ + "+-------+------+-------+", + "| c3 | eq | neq |", + "+-------+------+-------+", + "| true | true | false |", + "| false | true | false |", + "| false | true | false |", + "| true | true | false |", + "| true | true | false |", + "| true | true | false |", + "| false | true | false |", + "| false | true | false |", + "| false | true | false |", + "| false | true | false |", + "| true | true | false |", + "| true | true | false |", + "| true | true | false |", + "| true | true | false |", + "| true | true | false |", + "+-------+------+-------+", + ]; + assert_batches_eq!(expected, &actual); + + Ok(()) +} + #[tokio::test] async fn csv_query_avg_sqrt() -> Result<()> { let mut ctx = create_ctx()?; @@ -4054,6 +4046,8 @@ macro_rules! test_expression { async fn test_boolean_expressions() -> Result<()> { test_expression!("true", "true"); test_expression!("false", "false"); + test_expression!("false = false", "true"); + test_expression!("true = false", "false"); Ok(()) }