Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 11 additions & 54 deletions datafusion/src/optimizer/constant_folding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,13 @@

//! Constant folding and algebraic simplification

use std::sync::Arc;

use arrow::datatypes::DataType;

use crate::error::Result;
use crate::execution::context::ExecutionProps;
use crate::logical_plan::{DFSchemaRef, Expr, ExprRewriter, LogicalPlan, Operator};
use crate::optimizer::optimizer::OptimizerRule;
use crate::optimizer::utils;
use crate::physical_plan::functions::BuiltinScalarFunction;
use crate::scalar::ScalarValue;

/// Simplifies plans by rewriting [`Expr`]s evaluating constants
Expand Down Expand Up @@ -61,18 +58,14 @@ impl OptimizerRule for ConstantFolding {
// children plans.
let mut simplifier = Simplifier {
schemas: plan.all_schemas(),
execution_props,
};

let mut const_evaluator = utils::ConstEvaluator::new();
let mut const_evaluator = utils::ConstEvaluator::new(execution_props);

match plan {
LogicalPlan::Filter { predicate, input } => Ok(LogicalPlan::Filter {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no reason to special-case `LogicalPlan::Filter`, as the predicate is handled by `LogicalPlan::expressions`. Also, if you look carefully, this code doesn't call `rewrite` using the `const_evaluator` (I totally missed this in #1153, but found it while updating tests in this PR).

predicate: predicate.clone().rewrite(&mut simplifier)?,
input: Arc::new(self.optimize(input, execution_props)?),
}),
// Rest: recurse into plan, apply optimization where possible
LogicalPlan::Projection { .. }
// Recurse into plan, apply optimization where possible
LogicalPlan::Filter { .. }
| LogicalPlan::Projection { .. }
| LogicalPlan::Window { .. }
| LogicalPlan::Aggregate { .. }
| LogicalPlan::Repartition { .. }
Expand Down Expand Up @@ -130,7 +123,6 @@ impl OptimizerRule for ConstantFolding {
struct Simplifier<'a> {
/// input schemas
schemas: Vec<&'a DFSchemaRef>,
execution_props: &'a ExecutionProps,
}

impl<'a> Simplifier<'a> {
Expand Down Expand Up @@ -228,15 +220,6 @@ impl<'a> ExprRewriter for Simplifier<'a> {
Expr::Not(inner)
}
}
// convert now() --> the time in `ExecutionProps`
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The point of this PR is to remove this code (it is now handled by ConstEvaluator)

Expr::ScalarFunction {
fun: BuiltinScalarFunction::Now,
..
} => Expr::Literal(ScalarValue::TimestampNanosecond(Some(
self.execution_props
.query_execution_start_time
.timestamp_nanos(),
))),
expr => {
// no additional rewrites possible
expr
Expand All @@ -248,10 +231,13 @@ impl<'a> ExprRewriter for Simplifier<'a> {

#[cfg(test)]
mod tests {
use std::sync::Arc;

use super::*;
use crate::{
assert_contains,
logical_plan::{col, lit, max, min, DFField, DFSchema, LogicalPlanBuilder},
physical_plan::functions::BuiltinScalarFunction,
};

use arrow::datatypes::*;
Expand Down Expand Up @@ -282,7 +268,6 @@ mod tests {
let schema = expr_test_schema();
let mut rewriter = Simplifier {
schemas: vec![&schema],
execution_props: &ExecutionProps::new(),
};

assert_eq!(
Expand All @@ -298,7 +283,6 @@ mod tests {
let schema = expr_test_schema();
let mut rewriter = Simplifier {
schemas: vec![&schema],
execution_props: &ExecutionProps::new(),
};

// x = null is always null
Expand Down Expand Up @@ -334,7 +318,6 @@ mod tests {
let schema = expr_test_schema();
let mut rewriter = Simplifier {
schemas: vec![&schema],
execution_props: &ExecutionProps::new(),
};

assert_eq!(col("c2").get_type(&schema)?, DataType::Boolean);
Expand Down Expand Up @@ -365,7 +348,6 @@ mod tests {
let schema = expr_test_schema();
let mut rewriter = Simplifier {
schemas: vec![&schema],
execution_props: &ExecutionProps::new(),
};

// When one of the operand is not of boolean type, folding the other boolean constant will
Expand Down Expand Up @@ -405,7 +387,6 @@ mod tests {
let schema = expr_test_schema();
let mut rewriter = Simplifier {
schemas: vec![&schema],
execution_props: &ExecutionProps::new(),
};

assert_eq!(col("c2").get_type(&schema)?, DataType::Boolean);
Expand Down Expand Up @@ -441,7 +422,6 @@ mod tests {
let schema = expr_test_schema();
let mut rewriter = Simplifier {
schemas: vec![&schema],
execution_props: &ExecutionProps::new(),
};

// when one of the operand is not of boolean type, folding the other boolean constant will
Expand Down Expand Up @@ -477,7 +457,6 @@ mod tests {
let schema = expr_test_schema();
let mut rewriter = Simplifier {
schemas: vec![&schema],
execution_props: &ExecutionProps::new(),
};

assert_eq!(
Expand Down Expand Up @@ -753,27 +732,6 @@ mod tests {
}
}

#[test]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

covered in the tests for ConstEvaluator

fn single_now_expr() {
let table_scan = test_table_scan().unwrap();
let proj = vec![now_expr()];
let time = Utc::now();
let plan = LogicalPlanBuilder::from(table_scan)
.project(proj)
.unwrap()
.build()
.unwrap();

let expected = format!(
"Projection: TimestampNanosecond({})\
\n TableScan: test projection=None",
time.timestamp_nanos()
);
let actual = get_optimized_plan_formatted(&plan, &time);

assert_eq!(expected, actual);
}

#[test]
fn multiple_now_expr() {
let table_scan = test_table_scan().unwrap();
Expand Down Expand Up @@ -838,17 +796,16 @@ mod tests {
// cast(now() as int) < cast(to_timestamp(...) as int) + 50000
let plan = LogicalPlanBuilder::from(table_scan)
.filter(
now_expr()
cast_to_int64_expr(now_expr())
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It will be pretty awesome when we get #194 so casting to do timestamp arithmetic is no longer needed

.lt(cast_to_int64_expr(to_timestamp_expr(ts_string)) + lit(50000)),
)
.unwrap()
.build()
.unwrap();

// Note that constant folder should be able to run again and fold
// this whole expression down to a single constant;
// https://github.com/apache/arrow-datafusion/issues/1160
let expected = "Filter: TimestampNanosecond(1599566400000000000) < CAST(totimestamp(Utf8(\"2020-09-08T12:05:00+00:00\")) AS Int64) + Int32(50000)\
// Note that constant folder runs and folds the entire
// expression down to a single constant (true)
let expected = "Filter: Boolean(true)\
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The filter expression has been totally simplified 🎉

\n TableScan: test projection=None";
let actual = get_optimized_plan_formatted(&plan, &time);

Expand Down
100 changes: 72 additions & 28 deletions datafusion/src/optimizer/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,10 @@ pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result<Expr> {
/// ```
/// # use datafusion::prelude::*;
/// # use datafusion::optimizer::utils::ConstEvaluator;
/// let mut const_evaluator = ConstEvaluator::new();
/// # use datafusion::execution::context::ExecutionProps;
///
/// let execution_props = ExecutionProps::new();
/// let mut const_evaluator = ConstEvaluator::new(&execution_props);
///
/// // (1 + 2) + a
/// let expr = (lit(1) + lit(2)) + col("a");
Expand Down Expand Up @@ -575,10 +578,15 @@ impl ExprRewriter for ConstEvaluator {
}

impl ConstEvaluator {
/// Create a new `ConstantEvaluator`.
pub fn new() -> Self {
/// Create a new `ConstEvaluator`. Session constants (such as
/// the time for `now()`) are taken from the passed
/// `execution_props`.
pub fn new(execution_props: &ExecutionProps) -> Self {
let planner = DefaultPhysicalPlanner::default();
let ctx_state = ExecutionContextState::new();
let ctx_state = ExecutionContextState {
execution_props: execution_props.clone(),
..ExecutionContextState::new()
};
let input_schema = DFSchema::empty();

// The dummy column name is unused and doesn't matter as only
Expand All @@ -604,9 +612,8 @@ impl ConstEvaluator {
fn volatility_ok(volatility: Volatility) -> bool {
match volatility {
Volatility::Immutable => true,
// To evaluate stable functions, need ExecutionProps, see
// Simplifier for code that does that.
Volatility::Stable => false,
// Values for functions such as now() are taken from ExecutionProps
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here is the actual change that allows the const evaluator to replace now() with a constant.

Volatility::Stable => true,
Volatility::Volatile => false,
}
}
Expand Down Expand Up @@ -689,6 +696,7 @@ mod tests {
array::{ArrayRef, Int32Array},
datatypes::DataType,
};
use chrono::{DateTime, TimeZone, Utc};
use std::collections::HashSet;

#[test]
Expand Down Expand Up @@ -799,42 +807,69 @@ mod tests {
let rand = Expr::ScalarFunction { args: vec![], fun };
let expr = (rand + lit(1)) + lit(2);
test_evaluate(expr.clone(), expr);
}

// volatile / stable functions should not be evaluated
// now() + (1 + 2) --> now() + 3
let fun = BuiltinScalarFunction::Now;
assert_eq!(fun.volatility(), Volatility::Stable);
let now = Expr::ScalarFunction { args: vec![], fun };
let expr = now.clone() + (lit(1) + lit(2));
let expected = now + lit(3);
test_evaluate(expr, expected);
#[test]
fn test_const_evaluator_now() {
    // Pin the query start time so that `now()` folds to a known constant.
    let start_nanos = 1599566400000000000i64;
    let start_time = chrono::Utc.timestamp_nanos(start_nanos);

    // now() --> literal timestamp holding the start time
    let expected_now = lit_timestamp_nano(start_nanos);
    test_evaluate_with_start_time(now_expr(), expected_now, &start_time);

    // CAST(now() AS Int64) + 100 --> start_nanos + 100
    let now_plus_100 = cast_to_int64_expr(now_expr()) + lit(100);
    test_evaluate_with_start_time(now_plus_100, lit(start_nanos + 100), &start_time);

    // CAST(now() AS Int64) < CAST(to_timestamp(...) AS Int64) + 50000 --> true
    let later_ts = "2020-09-08T12:05:00+00:00";
    let upper_bound = cast_to_int64_expr(to_timestamp_expr(later_ts)) + lit(50000);
    let comparison = cast_to_int64_expr(now_expr()).lt(upper_bound);
    test_evaluate_with_start_time(comparison, lit(true), &start_time);
}

/// Builds a call to the built-in `Now` scalar function with an empty
/// argument list.
fn now_expr() -> Expr {
    let fun = BuiltinScalarFunction::Now;
    Expr::ScalarFunction {
        fun,
        args: Vec::new(),
    }
}

/// Wraps `expr` in an `Expr::Cast` whose target type is `DataType::Int64`.
fn cast_to_int64_expr(expr: Expr) -> Expr {
    let data_type = DataType::Int64;
    Expr::Cast {
        expr: Box::new(expr),
        data_type,
    }
}

/// Builds a call to the built-in `ToTimestamp` scalar function whose single
/// argument is the string literal `arg`.
fn to_timestamp_expr(arg: impl Into<String>) -> Expr {
    let timestamp_text: String = arg.into();
    Expr::ScalarFunction {
        fun: BuiltinScalarFunction::ToTimestamp,
        args: vec![lit(timestamp_text)],
    }
}

#[test]
fn test_const_evaluator_udfs() {
fn test_evaluator_udfs() {
let args = vec![lit(1) + lit(2), lit(30) + lit(40)];
let folded_args = vec![lit(3), lit(70)];

// immutable UDF should get folded
// udf_add(1+2, 30+40) --> 70
// udf_add(1+2, 30+40) --> 73
let expr = Expr::ScalarUDF {
args: args.clone(),
fun: make_udf_add(Volatility::Immutable),
};
test_evaluate(expr, lit(73));

// stable UDF should have args folded
// udf_add(1+2, 30+40) --> udf_add(3, 70)
// stable UDF should be entirely folded
// udf_add(1+2, 30+40) --> 73
let fun = make_udf_add(Volatility::Stable);
let expr = Expr::ScalarUDF {
args: args.clone(),
fun: Arc::clone(&fun),
};
let expected_expr = Expr::ScalarUDF {
args: folded_args.clone(),
fun: Arc::clone(&fun),
};
test_evaluate(expr, expected_expr);
test_evaluate(expr, lit(73));

// volatile UDF should have args folded
// udf_add(1+2, 30+40) --> udf_add(3, 70)
Expand Down Expand Up @@ -892,11 +927,16 @@ mod tests {
))
}

// udfs
// validate that even a volatile function's arguments will be evaluated
fn test_evaluate_with_start_time(
input_expr: Expr,
expected_expr: Expr,
date_time: &DateTime<Utc>,
) {
let execution_props = ExecutionProps {
query_execution_start_time: *date_time,
};

fn test_evaluate(input_expr: Expr, expected_expr: Expr) {
let mut const_evaluator = ConstEvaluator::new();
let mut const_evaluator = ConstEvaluator::new(&execution_props);
let evaluated_expr = input_expr
.clone()
.rewrite(&mut const_evaluator)
Expand All @@ -908,4 +948,8 @@ mod tests {
input_expr, expected_expr, evaluated_expr
);
}

/// Convenience wrapper around `test_evaluate_with_start_time` that uses the
/// current UTC time as the query start time.
fn test_evaluate(input_expr: Expr, expected_expr: Expr) {
    let start_time = Utc::now();
    test_evaluate_with_start_time(input_expr, expected_expr, &start_time);
}
}