diff --git a/datafusion/src/optimizer/constant_folding.rs b/datafusion/src/optimizer/constant_folding.rs index 74fdc729eb690..8c29da4172bac 100644 --- a/datafusion/src/optimizer/constant_folding.rs +++ b/datafusion/src/optimizer/constant_folding.rs @@ -17,8 +17,6 @@ //! Constant folding and algebraic simplification -use std::sync::Arc; - use arrow::datatypes::DataType; use crate::error::Result; @@ -26,7 +24,6 @@ use crate::execution::context::ExecutionProps; use crate::logical_plan::{DFSchemaRef, Expr, ExprRewriter, LogicalPlan, Operator}; use crate::optimizer::optimizer::OptimizerRule; use crate::optimizer::utils; -use crate::physical_plan::functions::BuiltinScalarFunction; use crate::scalar::ScalarValue; /// Simplifies plans by rewriting [`Expr`]`s evaluating constants @@ -61,18 +58,14 @@ impl OptimizerRule for ConstantFolding { // children plans. let mut simplifier = Simplifier { schemas: plan.all_schemas(), - execution_props, }; - let mut const_evaluator = utils::ConstEvaluator::new(); + let mut const_evaluator = utils::ConstEvaluator::new(execution_props); match plan { - LogicalPlan::Filter { predicate, input } => Ok(LogicalPlan::Filter { - predicate: predicate.clone().rewrite(&mut simplifier)?, - input: Arc::new(self.optimize(input, execution_props)?), - }), - // Rest: recurse into plan, apply optimization where possible - LogicalPlan::Projection { .. } + // Recurse into plan, apply optimization where possible + LogicalPlan::Filter { .. } + | LogicalPlan::Projection { .. } | LogicalPlan::Window { .. } | LogicalPlan::Aggregate { .. } | LogicalPlan::Repartition { .. 
} @@ -130,7 +123,6 @@ impl OptimizerRule for ConstantFolding { struct Simplifier<'a> { /// input schemas schemas: Vec<&'a DFSchemaRef>, - execution_props: &'a ExecutionProps, } impl<'a> Simplifier<'a> { @@ -228,15 +220,6 @@ impl<'a> ExprRewriter for Simplifier<'a> { Expr::Not(inner) } } - // convert now() --> the time in `ExecutionProps` - Expr::ScalarFunction { - fun: BuiltinScalarFunction::Now, - .. - } => Expr::Literal(ScalarValue::TimestampNanosecond(Some( - self.execution_props - .query_execution_start_time - .timestamp_nanos(), - ))), expr => { // no additional rewrites possible expr @@ -248,10 +231,13 @@ impl<'a> ExprRewriter for Simplifier<'a> { #[cfg(test)] mod tests { + use std::sync::Arc; + use super::*; use crate::{ assert_contains, logical_plan::{col, lit, max, min, DFField, DFSchema, LogicalPlanBuilder}, + physical_plan::functions::BuiltinScalarFunction, }; use arrow::datatypes::*; @@ -282,7 +268,6 @@ mod tests { let schema = expr_test_schema(); let mut rewriter = Simplifier { schemas: vec![&schema], - execution_props: &ExecutionProps::new(), }; assert_eq!( @@ -298,7 +283,6 @@ mod tests { let schema = expr_test_schema(); let mut rewriter = Simplifier { schemas: vec![&schema], - execution_props: &ExecutionProps::new(), }; // x = null is always null @@ -334,7 +318,6 @@ mod tests { let schema = expr_test_schema(); let mut rewriter = Simplifier { schemas: vec![&schema], - execution_props: &ExecutionProps::new(), }; assert_eq!(col("c2").get_type(&schema)?, DataType::Boolean); @@ -365,7 +348,6 @@ mod tests { let schema = expr_test_schema(); let mut rewriter = Simplifier { schemas: vec![&schema], - execution_props: &ExecutionProps::new(), }; // When one of the operand is not of boolean type, folding the other boolean constant will @@ -405,7 +387,6 @@ mod tests { let schema = expr_test_schema(); let mut rewriter = Simplifier { schemas: vec![&schema], - execution_props: &ExecutionProps::new(), }; assert_eq!(col("c2").get_type(&schema)?, DataType::Boolean); @@ 
-441,7 +422,6 @@ mod tests { let schema = expr_test_schema(); let mut rewriter = Simplifier { schemas: vec![&schema], - execution_props: &ExecutionProps::new(), }; // when one of the operand is not of boolean type, folding the other boolean constant will @@ -477,7 +457,6 @@ mod tests { let schema = expr_test_schema(); let mut rewriter = Simplifier { schemas: vec![&schema], - execution_props: &ExecutionProps::new(), }; assert_eq!( @@ -753,27 +732,6 @@ mod tests { } } - #[test] - fn single_now_expr() { - let table_scan = test_table_scan().unwrap(); - let proj = vec![now_expr()]; - let time = Utc::now(); - let plan = LogicalPlanBuilder::from(table_scan) - .project(proj) - .unwrap() - .build() - .unwrap(); - - let expected = format!( - "Projection: TimestampNanosecond({})\ - \n TableScan: test projection=None", - time.timestamp_nanos() - ); - let actual = get_optimized_plan_formatted(&plan, &time); - - assert_eq!(expected, actual); - } - #[test] fn multiple_now_expr() { let table_scan = test_table_scan().unwrap(); @@ -838,17 +796,16 @@ mod tests { // now() < cast(to_timestamp(...) 
as int) + 5000000000 let plan = LogicalPlanBuilder::from(table_scan) .filter( - now_expr() + cast_to_int64_expr(now_expr()) .lt(cast_to_int64_expr(to_timestamp_expr(ts_string)) + lit(50000)), ) .unwrap() .build() .unwrap(); - // Note that constant folder should be able to run again and fold - // this whole expression down to a single constant; - // https://github.com/apache/arrow-datafusion/issues/1160 - let expected = "Filter: TimestampNanosecond(1599566400000000000) < CAST(totimestamp(Utf8(\"2020-09-08T12:05:00+00:00\")) AS Int64) + Int32(50000)\ + // Note that constant folder runs and folds the entire + // expression down to a single constant (true) + let expected = "Filter: Boolean(true)\ \n TableScan: test projection=None"; let actual = get_optimized_plan_formatted(&plan, &time); diff --git a/datafusion/src/optimizer/utils.rs b/datafusion/src/optimizer/utils.rs index fdc9a173ed5e5..00ea31e2a358e 100644 --- a/datafusion/src/optimizer/utils.rs +++ b/datafusion/src/optimizer/utils.rs @@ -506,7 +506,10 @@ pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result<Expr> { /// ``` /// # use datafusion::prelude::*; /// # use datafusion::optimizer::utils::ConstEvaluator; -/// let mut const_evaluator = ConstEvaluator::new(); +/// # use datafusion::execution::context::ExecutionProps; +/// +/// let execution_props = ExecutionProps::new(); +/// let mut const_evaluator = ConstEvaluator::new(&execution_props); /// /// // (1 + 2) + a /// let expr = (lit(1) + lit(2)) + col("a"); @@ -575,10 +578,15 @@ impl ExprRewriter for ConstEvaluator { } impl ConstEvaluator { - /// Create a new `ConstantEvaluator`. - pub fn new() -> Self { + /// Create a new `ConstEvaluator`. Session constants (such as + /// the time for `now()`) are taken from the passed + /// `execution_props`. 
+ pub fn new(execution_props: &ExecutionProps) -> Self { let planner = DefaultPhysicalPlanner::default(); - let ctx_state = ExecutionContextState::new(); + let ctx_state = ExecutionContextState { + execution_props: execution_props.clone(), + ..ExecutionContextState::new() + }; let input_schema = DFSchema::empty(); // The dummy column name is unused and doesn't matter as only @@ -604,9 +612,8 @@ impl ConstEvaluator { fn volatility_ok(volatility: Volatility) -> bool { match volatility { Volatility::Immutable => true, - // To evaluate stable functions, need ExecutionProps, see - // Simplifier for code that does that. - Volatility::Stable => false, + // Values for functions such as now() are taken from ExecutionProps + Volatility::Stable => true, Volatility::Volatile => false, } } @@ -689,6 +696,7 @@ mod tests { array::{ArrayRef, Int32Array}, datatypes::DataType, }; + use chrono::{DateTime, TimeZone, Utc}; use std::collections::HashSet; #[test] @@ -799,42 +807,69 @@ mod tests { let rand = Expr::ScalarFunction { args: vec![], fun }; let expr = (rand + lit(1)) + lit(2); test_evaluate(expr.clone(), expr); + } - // volatile / stable functions should not be evaluated - // now() + (1 + 2) --> now() + 3 - let fun = BuiltinScalarFunction::Now; - assert_eq!(fun.volatility(), Volatility::Stable); - let now = Expr::ScalarFunction { args: vec![], fun }; - let expr = now.clone() + (lit(1) + lit(2)); - let expected = now + lit(3); - test_evaluate(expr, expected); + #[test] + fn test_const_evaluator_now() { + let ts_nanos = 1599566400000000000i64; + let time = chrono::Utc.timestamp_nanos(ts_nanos); + let ts_string = "2020-09-08T12:05:00+00:00"; + + // now() --> ts + test_evaluate_with_start_time(now_expr(), lit_timestamp_nano(ts_nanos), &time); + + // CAST(now() as int64) + 100 --> ts + 100 + let expr = cast_to_int64_expr(now_expr()) + lit(100); + test_evaluate_with_start_time(expr, lit(ts_nanos + 100), &time); + + // now() < cast(to_timestamp(...) 
as int) + 50000 ---> true + let expr = cast_to_int64_expr(now_expr()) + .lt(cast_to_int64_expr(to_timestamp_expr(ts_string)) + lit(50000)); + test_evaluate_with_start_time(expr, lit(true), &time); + } + + fn now_expr() -> Expr { + Expr::ScalarFunction { + args: vec![], + fun: BuiltinScalarFunction::Now, + } + } + + fn cast_to_int64_expr(expr: Expr) -> Expr { + Expr::Cast { + expr: expr.into(), + data_type: DataType::Int64, + } + } + + fn to_timestamp_expr(arg: impl Into<String>) -> Expr { + Expr::ScalarFunction { + args: vec![lit(arg.into())], + fun: BuiltinScalarFunction::ToTimestamp, + } } #[test] - fn test_const_evaluator_udfs() { + fn test_evaluator_udfs() { let args = vec![lit(1) + lit(2), lit(30) + lit(40)]; let folded_args = vec![lit(3), lit(70)]; // immutable UDF should get folded - // udf_add(1+2, 30+40) --> 70 + // udf_add(1+2, 30+40) --> 73 let expr = Expr::ScalarUDF { args: args.clone(), fun: make_udf_add(Volatility::Immutable), }; test_evaluate(expr, lit(73)); - // stable UDF should have args folded - // udf_add(1+2, 30+40) --> udf_add(3, 70) + // stable UDF should be entirely folded + // udf_add(1+2, 30+40) --> 73 let fun = make_udf_add(Volatility::Stable); let expr = Expr::ScalarUDF { args: args.clone(), fun: Arc::clone(&fun), }; - let expected_expr = Expr::ScalarUDF { - args: folded_args.clone(), - fun: Arc::clone(&fun), - }; - test_evaluate(expr, expected_expr); + test_evaluate(expr, lit(73)); // volatile UDF should have args folded // udf_add(1+2, 30+40) --> udf_add(3, 70) @@ -892,11 +927,16 @@ mod tests { )) } - // udfs - // validate that even a volatile function's arguments will be evaluated + fn test_evaluate_with_start_time( + input_expr: Expr, + expected_expr: Expr, + date_time: &DateTime<Utc>, + ) { + let execution_props = ExecutionProps { + query_execution_start_time: *date_time, + }; - fn test_evaluate(input_expr: Expr, expected_expr: Expr) { - let mut const_evaluator = ConstEvaluator::new(); + let mut const_evaluator = 
ConstEvaluator::new(&execution_props); let evaluated_expr = input_expr .clone() .rewrite(&mut const_evaluator) @@ -908,4 +948,8 @@ mod tests { input_expr, expected_expr, evaluated_expr ); } + + fn test_evaluate(input_expr: Expr, expected_expr: Expr) { + test_evaluate_with_start_time(input_expr, expected_expr, &Utc::now()) + } }