diff --git a/datafusion/core/benches/sql_planner.rs b/datafusion/core/benches/sql_planner.rs index 653bfbd51378..25129a354ee8 100644 --- a/datafusion/core/benches/sql_planner.rs +++ b/datafusion/core/benches/sql_planner.rs @@ -519,9 +519,6 @@ fn criterion_benchmark(c: &mut Criterion) { }; let raw_tpcds_sql_queries = (1..100) - // skip query 75 until it is fixed - // https://github.com/apache/datafusion/issues/17801 - .filter(|q| *q != 75) .map(|q| std::fs::read_to_string(format!("{tests_path}tpc-ds/{q}.sql")).unwrap()) .collect::>(); diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 914d411ad6aa..2ae5aed30df9 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -60,6 +60,7 @@ use crate::schema_equivalence::schema_satisfied_by; use arrow::array::{builder::StringBuilder, RecordBatch}; use arrow::compute::SortOptions; use arrow::datatypes::Schema; +use arrow_schema::Field; use datafusion_catalog::ScanArgs; use datafusion_common::display::ToStringifiedPlan; use datafusion_common::format::ExplainAnalyzeLevel; @@ -2520,7 +2521,9 @@ impl<'a> OptimizationInvariantChecker<'a> { previous_schema: &Arc, ) -> Result<()> { // if the rule is not permitted to change the schema, confirm that it did not change. - if self.rule.schema_check() && plan.schema() != *previous_schema { + if self.rule.schema_check() + && !is_allowed_schema_change(previous_schema.as_ref(), plan.schema().as_ref()) + { internal_err!("PhysicalOptimizer rule '{}' failed. Schema mismatch. Expected original schema: {:?}, got new schema: {:?}", self.rule.name(), previous_schema, @@ -2536,6 +2539,38 @@ impl<'a> OptimizationInvariantChecker<'a> { } } +/// Checks if the change from `old` schema to `new` is allowed or not. +/// +/// The current implementation only allows nullability of individual fields to change +/// from 'nullable' to 'not nullable'. 
This can happen due to physical expressions knowing +/// more about their null-ness than their logical counterparts. +/// This change is allowed because for any field the non-nullable domain `F` is a strict subset +/// of the nullable domain `F ∪ { NULL }`. A physical schema that guarantees a stricter subset +/// of values will not violate any assumptions made based on the less strict schema. +fn is_allowed_schema_change(old: &Schema, new: &Schema) -> bool { + if new.metadata != old.metadata { + return false; + } + + if new.fields.len() != old.fields.len() { + return false; + } + + let new_fields = new.fields.iter().map(|f| f.as_ref()); + let old_fields = old.fields.iter().map(|f| f.as_ref()); + old_fields + .zip(new_fields) + .all(|(old, new)| is_allowed_field_change(old, new)) +} + +fn is_allowed_field_change(old_field: &Field, new_field: &Field) -> bool { + new_field.name() == old_field.name() + && new_field.data_type() == old_field.data_type() + && new_field.metadata() == old_field.metadata() + && (new_field.is_nullable() == old_field.is_nullable() + || !new_field.is_nullable()) +} + impl<'n> TreeNodeVisitor<'n> for OptimizationInvariantChecker<'_> { type Node = Arc; diff --git a/datafusion/core/tests/tpcds_planning.rs b/datafusion/core/tests/tpcds_planning.rs index 252d76d0f9d9..3ad74962bc2c 100644 --- a/datafusion/core/tests/tpcds_planning.rs +++ b/datafusion/core/tests/tpcds_planning.rs @@ -1052,9 +1052,12 @@ async fn regression_test(query_no: u8, create_physical: bool) -> Result<()> { for sql in &sql { let df = ctx.sql(sql).await?; let (state, plan) = df.into_parts(); - let plan = state.optimize(&plan)?; if create_physical { let _ = state.create_physical_plan(&plan).await?; + } else { + // Run the logical optimizer even if we are not creating the physical plan + // to ensure it will properly succeed + let _ = state.optimize(&plan)?; } } diff --git a/datafusion/expr-common/src/interval_arithmetic.rs b/datafusion/expr-common/src/interval_arithmetic.rs index 
407cc1032dc7..b9f8102f341a 100644 --- a/datafusion/expr-common/src/interval_arithmetic.rs +++ b/datafusion/expr-common/src/interval_arithmetic.rs @@ -1822,18 +1822,86 @@ impl NullableInterval { self == &Self::TRUE } + /// Returns the set of possible values after applying the `is true` test on all + /// values in this set. + /// The resulting set can only contain 'TRUE' and/or 'FALSE', never 'UNKNOWN'. + pub fn is_true(&self) -> Result { + let (t, f, u) = self.is_true_false_unknown()?; + + match (t, f, u) { + (true, false, false) => Ok(Self::TRUE), + (true, _, _) => Ok(Self::TRUE_OR_FALSE), + (false, _, _) => Ok(Self::FALSE), + } + } + /// Return true if the value is definitely false (and not null). pub fn is_certainly_false(&self) -> bool { self == &Self::FALSE } + /// Returns the set of possible values after applying the `is false` test on all + /// values in this set. + /// The resulting set can only contain 'TRUE' and/or 'FALSE', never 'UNKNOWN'. + pub fn is_false(&self) -> Result { + let (t, f, u) = self.is_true_false_unknown()?; + + match (t, f, u) { + (false, true, false) => Ok(Self::TRUE), + (_, true, _) => Ok(Self::TRUE_OR_FALSE), + (_, false, _) => Ok(Self::FALSE), + } + } + /// Return true if the value is definitely null (and not true or false). pub fn is_certainly_unknown(&self) -> bool { self == &Self::UNKNOWN } - /// Perform logical negation on a boolean nullable interval. - fn not(&self) -> Result { + /// Returns the set of possible values after applying the `is unknown` test on all + /// values in this set. + /// The resulting set can only contain 'TRUE' and/or 'FALSE', never 'UNKNOWN'. 
+ pub fn is_unknown(&self) -> Result { + let (t, f, u) = self.is_true_false_unknown()?; + + match (t, f, u) { + (false, false, true) => Ok(Self::TRUE), + (_, _, true) => Ok(Self::TRUE_OR_FALSE), + (_, _, false) => Ok(Self::FALSE), + } + } + + /// Returns a tuple of booleans indicating if this interval contains the + /// true, false, and unknown truth values respectively. + fn is_true_false_unknown(&self) -> Result<(bool, bool, bool), DataFusionError> { + Ok(match self { + NullableInterval::Null { .. } => (false, false, true), + NullableInterval::MaybeNull { values } => ( + values.contains_value(ScalarValue::Boolean(Some(true)))?, + values.contains_value(ScalarValue::Boolean(Some(false)))?, + true, + ), + NullableInterval::NotNull { values } => ( + values.contains_value(ScalarValue::Boolean(Some(true)))?, + values.contains_value(ScalarValue::Boolean(Some(false)))?, + false, + ), + }) + } + + /// Returns an interval representing the set of possible values after applying + /// SQL three-valued logical NOT on each possible value in this interval. + /// + /// This method uses the following truth table. + /// + /// ```text + /// A | ¬A + /// ----|---- + /// F | T + /// U | U + /// T | F + /// ``` + pub fn not(&self) -> Result { match self { Self::Null { datatype } => { assert_eq_or_internal_err!( @@ -1852,8 +1920,20 @@ impl NullableInterval { } } - /// Compute the logical conjunction of this (boolean) interval with the - /// given boolean interval. + /// Returns an interval representing the set of possible values after applying SQL + /// three-valued logical AND on each combination of possible values from `self` and `rhs`. + /// + /// This method uses the following truth table. 
+ /// + /// ```text + /// │ B + /// A ∧ B ├────── + /// │ F U T + /// ──┬───┼────── + /// │ F │ F F F + /// A │ U │ F U U + /// │ T │ F U T + /// ``` pub fn and>(&self, rhs: T) -> Result { if self == &Self::FALSE || rhs.borrow() == &Self::FALSE { return Ok(Self::FALSE); @@ -1880,8 +1960,20 @@ impl NullableInterval { } } - /// Compute the logical disjunction of this (boolean) interval with the - /// given boolean interval. + /// Returns an interval representing the set of possible values after applying SQL three-valued + /// logical OR on each combination of possible values from `self` and `rhs`. + /// + /// This method uses the following truth table. + /// + /// ```text + /// │ B + /// A ∨ B ├────── + /// │ F U T + /// ──┬───┼────── + /// │ F │ F U T + /// A │ U │ U U T + /// │ T │ T T T + /// ``` pub fn or>(&self, rhs: T) -> Result { if self == &Self::TRUE || rhs.borrow() == &Self::TRUE { return Ok(Self::TRUE); @@ -2042,6 +2134,30 @@ impl NullableInterval { } } + /// Determines if this interval contains a [`ScalarValue`] or not. + pub fn contains_value>(&self, value: T) -> Result { + match value.borrow() { + ScalarValue::Null => match self { + NullableInterval::Null { .. } | NullableInterval::MaybeNull { .. } => { + Ok(true) + } + NullableInterval::NotNull { .. } => Ok(false), + }, + s if s.is_null() => match self { + NullableInterval::Null { datatype } => Ok(datatype.eq(&s.data_type())), + NullableInterval::MaybeNull { values } => { + Ok(values.data_type().eq(&s.data_type())) + } + NullableInterval::NotNull { .. } => Ok(false), + }, + s => match self { + NullableInterval::Null { .. } => Ok(false), + NullableInterval::MaybeNull { values } + | NullableInterval::NotNull { values } => values.contains_value(s), + }, + } + } + /// If the interval has collapsed to a single value, return that value. /// Otherwise, returns `None`. 
/// @@ -4459,4 +4575,175 @@ mod tests { } Ok(()) } + + #[test] + fn nullable_interval_is_certainly_true() { + // Test cases: (interval, expected) => interval.is_certainly_true() = expected + #[rustfmt::skip] + let test_cases = vec![ + (NullableInterval::TRUE, true), + (NullableInterval::FALSE, false), + (NullableInterval::UNKNOWN, false), + (NullableInterval::TRUE_OR_FALSE, false), + (NullableInterval::TRUE_OR_UNKNOWN, false), + (NullableInterval::FALSE_OR_UNKNOWN, false), + (NullableInterval::ANY_TRUTH_VALUE, false), + ]; + + for (interval, expected) in test_cases { + let result = interval.is_certainly_true(); + assert_eq!(result, expected, "Failed for interval: {interval}",); + } + } + + #[test] + fn nullable_interval_is_true() { + // Test cases: (interval, expected) => interval.is_true() = expected + #[rustfmt::skip] + let test_cases = vec![ + (NullableInterval::TRUE, NullableInterval::TRUE), + (NullableInterval::FALSE, NullableInterval::FALSE), + (NullableInterval::UNKNOWN, NullableInterval::FALSE), + (NullableInterval::TRUE_OR_FALSE,NullableInterval::TRUE_OR_FALSE), + (NullableInterval::TRUE_OR_UNKNOWN,NullableInterval::TRUE_OR_FALSE), + (NullableInterval::FALSE_OR_UNKNOWN, NullableInterval::FALSE), + (NullableInterval::ANY_TRUTH_VALUE,NullableInterval::TRUE_OR_FALSE), + ]; + + for (interval, expected) in test_cases { + let result = interval.is_true().unwrap(); + assert_eq!(result, expected, "Failed for interval: {interval}",); + } + } + + #[test] + fn nullable_interval_is_certainly_false() { + // Test cases: (interval, expected) => interval.is_certainly_false() = expected + #[rustfmt::skip] + let test_cases = vec![ + (NullableInterval::TRUE, false), + (NullableInterval::FALSE, true), + (NullableInterval::UNKNOWN, false), + (NullableInterval::TRUE_OR_FALSE, false), + (NullableInterval::TRUE_OR_UNKNOWN, false), + (NullableInterval::FALSE_OR_UNKNOWN, false), + (NullableInterval::ANY_TRUTH_VALUE, false), + ]; + + for (interval, expected) in test_cases { + let 
result = interval.is_certainly_false(); + assert_eq!(result, expected, "Failed for interval: {interval}",); + } + } + + #[test] + fn nullable_interval_is_false() { + // Test cases: (interval, expected) => interval.is_false() = expected + #[rustfmt::skip] + let test_cases = vec![ + (NullableInterval::TRUE, NullableInterval::FALSE), + (NullableInterval::FALSE, NullableInterval::TRUE), + (NullableInterval::UNKNOWN, NullableInterval::FALSE), + (NullableInterval::TRUE_OR_FALSE,NullableInterval::TRUE_OR_FALSE), + (NullableInterval::TRUE_OR_UNKNOWN, NullableInterval::FALSE), + (NullableInterval::FALSE_OR_UNKNOWN,NullableInterval::TRUE_OR_FALSE), + (NullableInterval::ANY_TRUTH_VALUE,NullableInterval::TRUE_OR_FALSE), + ]; + + for (interval, expected) in test_cases { + let result = interval.is_false().unwrap(); + assert_eq!(result, expected, "Failed for interval: {interval}",); + } + } + + #[test] + fn nullable_interval_is_certainly_unknown() { + // Test cases: (interval, expected) => interval.is_certainly_unknown() = expected + #[rustfmt::skip] + let test_cases = vec![ + (NullableInterval::TRUE, false), + (NullableInterval::FALSE, false), + (NullableInterval::UNKNOWN, true), + (NullableInterval::TRUE_OR_FALSE, false), + (NullableInterval::TRUE_OR_UNKNOWN, false), + (NullableInterval::FALSE_OR_UNKNOWN, false), + (NullableInterval::ANY_TRUTH_VALUE, false), + ]; + + for (interval, expected) in test_cases { + let result = interval.is_certainly_unknown(); + assert_eq!(result, expected, "Failed for interval: {interval}",); + } + } + + #[test] + fn nullable_interval_is_unknown() { + // Test cases: (interval, expected) => interval.is_unknown() = expected + #[rustfmt::skip] + let test_cases = vec![ + (NullableInterval::TRUE, NullableInterval::FALSE), + (NullableInterval::FALSE, NullableInterval::FALSE), + (NullableInterval::UNKNOWN, NullableInterval::TRUE), + (NullableInterval::TRUE_OR_FALSE, NullableInterval::FALSE), + 
(NullableInterval::TRUE_OR_UNKNOWN,NullableInterval::TRUE_OR_FALSE), + (NullableInterval::FALSE_OR_UNKNOWN,NullableInterval::TRUE_OR_FALSE), + (NullableInterval::ANY_TRUTH_VALUE,NullableInterval::TRUE_OR_FALSE), + ]; + + for (interval, expected) in test_cases { + let result = interval.is_unknown().unwrap(); + assert_eq!(result, expected, "Failed for interval: {interval}",); + } + } + + #[test] + fn nullable_interval_contains_value() { + // Test cases: (interval, value, expected) => interval.contains_value(value) = expected + #[rustfmt::skip] + let test_cases = vec![ + (NullableInterval::TRUE, ScalarValue::Boolean(Some(true)), true), + (NullableInterval::TRUE, ScalarValue::Boolean(Some(false)), false), + (NullableInterval::TRUE, ScalarValue::Boolean(None), false), + (NullableInterval::TRUE, ScalarValue::Null, false), + (NullableInterval::TRUE, ScalarValue::UInt32(None), false), + (NullableInterval::FALSE, ScalarValue::Boolean(Some(true)), false), + (NullableInterval::FALSE, ScalarValue::Boolean(Some(false)), true), + (NullableInterval::FALSE, ScalarValue::Boolean(None), false), + (NullableInterval::FALSE, ScalarValue::Null, false), + (NullableInterval::FALSE, ScalarValue::UInt32(None), false), + (NullableInterval::UNKNOWN, ScalarValue::Boolean(Some(true)), false), + (NullableInterval::UNKNOWN, ScalarValue::Boolean(Some(false)), false), + (NullableInterval::UNKNOWN, ScalarValue::Boolean(None), true), + (NullableInterval::UNKNOWN, ScalarValue::Null, true), + (NullableInterval::UNKNOWN, ScalarValue::UInt32(None), false), + (NullableInterval::TRUE_OR_FALSE, ScalarValue::Boolean(Some(true)), true), + (NullableInterval::TRUE_OR_FALSE, ScalarValue::Boolean(Some(false)), true), + (NullableInterval::TRUE_OR_FALSE, ScalarValue::Boolean(None), false), + (NullableInterval::TRUE_OR_FALSE, ScalarValue::Null, false), + (NullableInterval::TRUE_OR_FALSE, ScalarValue::UInt32(None), false), + (NullableInterval::TRUE_OR_UNKNOWN, ScalarValue::Boolean(Some(true)), true), + 
(NullableInterval::TRUE_OR_UNKNOWN, ScalarValue::Boolean(Some(false)), false), + (NullableInterval::TRUE_OR_UNKNOWN, ScalarValue::Boolean(None), true), + (NullableInterval::TRUE_OR_UNKNOWN, ScalarValue::Null, true), + (NullableInterval::TRUE_OR_UNKNOWN, ScalarValue::UInt32(None), false), + (NullableInterval::FALSE_OR_UNKNOWN, ScalarValue::Boolean(Some(true)), false), + (NullableInterval::FALSE_OR_UNKNOWN, ScalarValue::Boolean(Some(false)), true), + (NullableInterval::FALSE_OR_UNKNOWN, ScalarValue::Boolean(None), true), + (NullableInterval::FALSE_OR_UNKNOWN, ScalarValue::Null, true), + (NullableInterval::FALSE_OR_UNKNOWN, ScalarValue::UInt32(None), false), + (NullableInterval::ANY_TRUTH_VALUE, ScalarValue::Boolean(Some(true)), true), + (NullableInterval::ANY_TRUTH_VALUE, ScalarValue::Boolean(Some(false)), true), + (NullableInterval::ANY_TRUTH_VALUE, ScalarValue::Boolean(None), true), + (NullableInterval::ANY_TRUTH_VALUE, ScalarValue::Null, true), + (NullableInterval::ANY_TRUTH_VALUE, ScalarValue::UInt32(None), false), + ]; + + for (interval, value, expected) in test_cases { + let result = interval.contains_value(value.clone()).unwrap(); + assert_eq!( + result, expected, + "Failed for interval: {interval} and value {value:?}", + ); + } + } } diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index c777c4978f99..94d8009ce814 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -341,6 +341,11 @@ pub fn is_null(expr: Expr) -> Expr { Expr::IsNull(Box::new(expr)) } +/// Create is not null expression +pub fn is_not_null(expr: Expr) -> Expr { + Expr::IsNotNull(Box::new(expr)) +} + /// Create is true expression pub fn is_true(expr: Expr) -> Expr { Expr::IsTrue(Box::new(expr)) diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index bbd52ac154c7..8f8720941fad 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -15,12 +15,13 @@ // specific language 
governing permissions and limitations // under the License. -use super::{Between, Expr, Like}; +use super::{predicate_bounds, Between, Expr, Like}; use crate::expr::{ AggregateFunction, AggregateFunctionParams, Alias, BinaryExpr, Cast, InList, InSubquery, Placeholder, ScalarFunction, TryCast, Unnest, WindowFunction, WindowFunctionParams, }; +use crate::expr_rewriter::rewrite_with_guarantees; use crate::type_coercion::functions::{ data_types_with_scalar_udf, fields_with_aggregate_udf, fields_with_window_udf, }; @@ -31,8 +32,9 @@ use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::metadata::FieldMetadata; use datafusion_common::{ not_impl_err, plan_datafusion_err, plan_err, Column, DataFusionError, ExprSchema, - Result, Spans, TableReference, + Result, ScalarValue, Spans, TableReference, }; +use datafusion_expr_common::interval_arithmetic::NullableInterval; use datafusion_expr_common::type_coercion::binary::BinaryTypeCoercer; use datafusion_functions_window_common::field::WindowUDFFieldArgs; use std::sync::Arc; @@ -282,15 +284,85 @@ impl ExprSchemable for Expr { Expr::OuterReferenceColumn(field, _) => Ok(field.is_nullable()), Expr::Literal(value, _) => Ok(value.is_null()), Expr::Case(case) => { - // This expression is nullable if any of the input expressions are nullable - let then_nullable = case + let nullable_then = case .when_then_expr .iter() - .map(|(_, t)| t.nullable(input_schema)) - .collect::>>()?; - if then_nullable.contains(&true) { - Ok(true) + .filter_map(|(w, t)| { + let is_nullable = match t.nullable(input_schema) { + Err(e) => return Some(Err(e)), + Ok(n) => n, + }; + + // Branches with a then expression that is not nullable do not impact the + // nullability of the case expression. 
+ if !is_nullable { + return None; + } + + // For case-with-expression assume all 'then' expressions are reachable + if case.expr.is_some() { + return Some(Ok(())); + } + + // For branches with a nullable 'then' expression, try to determine + // if the 'then' expression is ever reachable in the situation where + // it would evaluate to null. + + // First, derive a variant of the 'when' expression, where all occurrences + // of the 'then' expression have been replaced by 'NULL'. + let certainly_null_expr = unwrap_certainly_null_expr(t).clone(); + let certainly_null_type = + match certainly_null_expr.get_type(input_schema) { + Err(e) => return Some(Err(e)), + Ok(datatype) => datatype, + }; + let null_interval = NullableInterval::Null { + datatype: certainly_null_type, + }; + let guarantees = vec![(certainly_null_expr, null_interval)]; + let when_with_null = + match rewrite_with_guarantees(*w.clone(), &guarantees) { + Err(e) => return Some(Err(e)), + Ok(e) => e.data, + }; + + // Next, determine the bounds of the derived 'when' expression to see if it + // can ever evaluate to true. + let bounds = match predicate_bounds::evaluate_bounds( + &when_with_null, + input_schema, + ) { + Err(e) => return Some(Err(e)), + Ok(b) => b, + }; + + let can_be_true = match bounds + .contains_value(ScalarValue::Boolean(Some(true))) + { + Err(e) => return Some(Err(e)), + Ok(b) => b, + }; + + if !can_be_true { + // If the derived 'when' expression can never evaluate to true, the + // 'then' expression is not reachable when it would evaluate to NULL. + // The most common pattern for this is `WHEN x IS NOT NULL THEN x`. + None + } else { + // The branch might be taken + Some(Ok(())) + } + }) + .next(); + + if let Some(nullable_then) = nullable_then { + // There is at least one reachable nullable 'then' expression, so the case + // expression itself is nullable. + // Use `Result::map` to propagate the error from `nullable_then` if there is one. 
+ nullable_then.map(|_| true) } else if let Some(e) = &case.else_expr { + // There are no reachable nullable 'then' expressions, so all we still need to + // check is the 'else' expression's nullability. e.nullable(input_schema) } else { // CASE produces NULL if there is no `else` expr @@ -642,6 +714,16 @@ impl ExprSchemable for Expr { } } +/// Returns the innermost [Expr] that is provably null if `expr` is null. +fn unwrap_certainly_null_expr(expr: &Expr) -> &Expr { + match expr { + Expr::Not(e) => unwrap_certainly_null_expr(e), + Expr::Negative(e) => unwrap_certainly_null_expr(e), + Expr::Cast(e) => unwrap_certainly_null_expr(e.expr.as_ref()), + _ => expr, + } +} + impl Expr { /// Common method for window functions that applies type coercion /// to all arguments of the window function to check if it matches @@ -773,7 +855,7 @@ mod tests { use std::collections::HashMap; use super::*; - use crate::{col, lit, out_ref_col_with_metadata}; + use crate::{and, col, lit, not, or, out_ref_col_with_metadata, when}; use datafusion_common::{ assert_or_internal_err, DFSchema, DataFusionError, ScalarValue, @@ -828,6 +910,137 @@ mod tests { assert!(expr.nullable(&get_schema(false)).unwrap()); } + fn assert_nullability(expr: &Expr, schema: &dyn ExprSchema, expected: bool) { + assert_eq!( + expr.nullable(schema).unwrap(), + expected, + "Nullability of '{expr}' should be {expected}" + ); + } + + fn assert_not_nullable(expr: &Expr, schema: &dyn ExprSchema) { + assert_nullability(expr, schema, false); + } + + fn assert_nullable(expr: &Expr, schema: &dyn ExprSchema) { + assert_nullability(expr, schema, true); + } + + #[test] + fn test_case_expression_nullability() -> Result<()> { + let nullable_schema = MockExprSchema::new() + .with_data_type(DataType::Int32) + .with_nullable(true); + + let not_nullable_schema = MockExprSchema::new() + .with_data_type(DataType::Int32) + .with_nullable(false); + + // CASE WHEN x IS NOT NULL THEN x ELSE 0 + let e = when(col("x").is_not_null(), 
col("x")).otherwise(lit(0))?; + assert_not_nullable(&e, &nullable_schema); + assert_not_nullable(&e, &not_nullable_schema); + + // CASE WHEN NOT x IS NULL THEN x ELSE 0 + let e = when(not(col("x").is_null()), col("x")).otherwise(lit(0))?; + assert_not_nullable(&e, &nullable_schema); + assert_not_nullable(&e, &not_nullable_schema); + + // CASE WHEN x = 5 THEN x ELSE 0 + let e = when(col("x").eq(lit(5)), col("x")).otherwise(lit(0))?; + assert_not_nullable(&e, &nullable_schema); + assert_not_nullable(&e, &not_nullable_schema); + + // CASE WHEN x IS NOT NULL AND x = 5 THEN x ELSE 0 + let e = when(and(col("x").is_not_null(), col("x").eq(lit(5))), col("x")) + .otherwise(lit(0))?; + assert_not_nullable(&e, &nullable_schema); + assert_not_nullable(&e, &not_nullable_schema); + + // CASE WHEN x = 5 AND x IS NOT NULL THEN x ELSE 0 + let e = when(and(col("x").eq(lit(5)), col("x").is_not_null()), col("x")) + .otherwise(lit(0))?; + assert_not_nullable(&e, &nullable_schema); + assert_not_nullable(&e, &not_nullable_schema); + + // CASE WHEN x IS NOT NULL OR x = 5 THEN x ELSE 0 + let e = when(or(col("x").is_not_null(), col("x").eq(lit(5))), col("x")) + .otherwise(lit(0))?; + assert_not_nullable(&e, &nullable_schema); + assert_not_nullable(&e, &not_nullable_schema); + + // CASE WHEN x = 5 OR x IS NOT NULL THEN x ELSE 0 + let e = when(or(col("x").eq(lit(5)), col("x").is_not_null()), col("x")) + .otherwise(lit(0))?; + assert_not_nullable(&e, &nullable_schema); + assert_not_nullable(&e, &not_nullable_schema); + + // CASE WHEN (x = 5 AND x IS NOT NULL) OR (x = bar AND x IS NOT NULL) THEN x ELSE 0 + let e = when( + or( + and(col("x").eq(lit(5)), col("x").is_not_null()), + and(col("x").eq(col("bar")), col("x").is_not_null()), + ), + col("x"), + ) + .otherwise(lit(0))?; + assert_not_nullable(&e, &nullable_schema); + assert_not_nullable(&e, &not_nullable_schema); + + // CASE WHEN x = 5 OR x IS NULL THEN x ELSE 0 + let e = when(or(col("x").eq(lit(5)), col("x").is_null()), 
col("x")) + .otherwise(lit(0))?; + 
assert_nullable(&e, &nullable_schema); + assert_not_nullable(&e, &not_nullable_schema); + + // CASE WHEN x IS TRUE THEN x ELSE 0 + let e = when(col("x").is_true(), col("x")).otherwise(lit(0))?; + assert_not_nullable(&e, &nullable_schema); + assert_not_nullable(&e, &not_nullable_schema); + + // CASE WHEN x IS NOT TRUE THEN x ELSE 0 + let e = when(col("x").is_not_true(), col("x")).otherwise(lit(0))?; + assert_nullable(&e, &nullable_schema); + assert_not_nullable(&e, &not_nullable_schema); + + // CASE WHEN x IS FALSE THEN x ELSE 0 + let e = when(col("x").is_false(), col("x")).otherwise(lit(0))?; + assert_not_nullable(&e, &nullable_schema); + assert_not_nullable(&e, &not_nullable_schema); + + // CASE WHEN x IS NOT FALSE THEN x ELSE 0 + let e = when(col("x").is_not_false(), col("x")).otherwise(lit(0))?; + assert_nullable(&e, &nullable_schema); + assert_not_nullable(&e, &not_nullable_schema); + + // CASE WHEN x IS UNKNOWN THEN x ELSE 0 + let e = when(col("x").is_unknown(), col("x")).otherwise(lit(0))?; + assert_nullable(&e, &nullable_schema); + assert_not_nullable(&e, &not_nullable_schema); + + // CASE WHEN x IS NOT UNKNOWN THEN x ELSE 0 + let e = when(col("x").is_not_unknown(), col("x")).otherwise(lit(0))?; + assert_not_nullable(&e, &nullable_schema); + assert_not_nullable(&e, &not_nullable_schema); + + // CASE WHEN x LIKE 'x' THEN x ELSE 0 + let e = when(col("x").like(lit("x")), col("x")).otherwise(lit(0))?; + assert_not_nullable(&e, &nullable_schema); + assert_not_nullable(&e, &not_nullable_schema); + + // CASE WHEN 0 THEN x ELSE 0 + let e = when(lit(0), col("x")).otherwise(lit(0))?; + assert_not_nullable(&e, &nullable_schema); + assert_not_nullable(&e, &not_nullable_schema); + + // CASE WHEN 1 THEN x ELSE 0 + let e = when(lit(1), col("x")).otherwise(lit(0))?; + assert_nullable(&e, &nullable_schema); + assert_not_nullable(&e, &not_nullable_schema); + + Ok(()) + } + #[test] fn test_inlist_nullability() { let get_schema = |nullable| { diff --git a/datafusion/expr/src/lib.rs 
b/datafusion/expr/src/lib.rs index 885e582ea6d4..c82b56aa58a3 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -73,6 +73,7 @@ pub mod async_udf; pub mod statistics { pub use datafusion_expr_common::statistics::*; } +mod predicate_bounds; pub mod ptr_eq; pub mod test; pub mod tree_node; diff --git a/datafusion/expr/src/predicate_bounds.rs b/datafusion/expr/src/predicate_bounds.rs new file mode 100644 index 000000000000..e79e756a3215 --- /dev/null +++ b/datafusion/expr/src/predicate_bounds.rs @@ -0,0 +1,669 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::{Between, BinaryExpr, Expr, ExprSchemable}; +use arrow::datatypes::DataType; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; +use datafusion_common::{ExprSchema, Result, ScalarValue}; +use datafusion_expr_common::interval_arithmetic::NullableInterval; +use datafusion_expr_common::operator::Operator; + +/// Computes the output interval for the given boolean expression based on statically +/// available information. 
+/// +/// # Arguments +/// +/// * `predicate` - The boolean expression to analyze +/// * `input_schema` - Schema information for resolving expression types and nullability
+/// +/// # Nullability +/// +/// Whether a sub-expression may evaluate to NULL is derived from the expression itself and +/// from `input_schema`: +/// +/// - a literal is NULL exactly when its scalar value is null +/// - an expression that `input_schema` reports as not nullable is never NULL +/// - a `BETWEEN` is NULL when its tested expression is NULL or both of its bounds are NULL +/// - other expressions are treated as possibly NULL; for some expression kinds this is +/// refined using the nullability of their direct children +/// +/// # Return Value +/// +/// The function returns a [NullableInterval] that describes the possible boolean values the +/// predicate can evaluate to. +/// +pub(super) fn evaluate_bounds( + predicate: &Expr, + input_schema: &dyn ExprSchema, +) -> Result { + let evaluator = PredicateBoundsEvaluator { input_schema }; + evaluator.evaluate_bounds(predicate) +} + +struct PredicateBoundsEvaluator<'a> { + input_schema: &'a dyn ExprSchema, +} + +impl PredicateBoundsEvaluator<'_> { + /// Derives the bounds of the given boolean expression + fn evaluate_bounds(&self, predicate: &Expr) -> Result { + Ok(match predicate { + Expr::Literal(scalar, _) => { + // Interpret literals as boolean, coercing if necessary + match scalar { + ScalarValue::Null => NullableInterval::UNKNOWN, + ScalarValue::Boolean(b) => match b { + Some(true) => NullableInterval::TRUE, + Some(false) => NullableInterval::FALSE, + None => NullableInterval::UNKNOWN, + }, + _ => { + let b = Expr::Literal(scalar.cast_to(&DataType::Boolean)?, None); + self.evaluate_bounds(&b)? + } + } + } + Expr::IsNull(e) => { + // If `e` is not nullable, then `e IS NULL` is provably false + if !e.nullable(self.input_schema)? { + NullableInterval::FALSE + } else { + match e.get_type(self.input_schema)? 
{ + // If `e` is a boolean expression, check if `e` is provably 'unknown'. + DataType::Boolean => self.evaluate_bounds(e)?.is_unknown()?, + // If `e` is not a boolean expression, check if `e` is provably null + _ => self.is_null(e), + } + } + } + Expr::IsNotNull(e) => { + // If `e` is not nullable, then `e IS NOT NULL` is provably true + if !e.nullable(self.input_schema)? { + NullableInterval::TRUE + } else { + match e.get_type(self.input_schema)? { + // If `e` is a boolean expression, try to evaluate it and test for not unknown + DataType::Boolean => { + self.evaluate_bounds(e)?.is_unknown()?.not()? + } + // If `e` is not a boolean expression, check if `e` is provably not null + _ => self.is_null(e).not()?, + } + } + } + Expr::IsTrue(e) => self.evaluate_bounds(e)?.is_true()?, + Expr::IsNotTrue(e) => self.evaluate_bounds(e)?.is_true()?.not()?, + Expr::IsFalse(e) => self.evaluate_bounds(e)?.is_false()?, + Expr::IsNotFalse(e) => self.evaluate_bounds(e)?.is_false()?.not()?, + Expr::IsUnknown(e) => self.evaluate_bounds(e)?.is_unknown()?, + Expr::IsNotUnknown(e) => self.evaluate_bounds(e)?.is_unknown()?.not()?, + Expr::Not(e) => self.evaluate_bounds(e)?.not()?, + Expr::BinaryExpr(BinaryExpr { + left, + op: Operator::And, + right, + }) => NullableInterval::and( + &self.evaluate_bounds(left)?, + &self.evaluate_bounds(right)?, + )?, + Expr::BinaryExpr(BinaryExpr { + left, + op: Operator::Or, + right, + }) => NullableInterval::or( + &self.evaluate_bounds(left)?, + &self.evaluate_bounds(right)?, + )?, + e => { + let is_null = self.is_null(e); + + // If an expression is null, then its value is UNKNOWN + let maybe_null = + is_null.contains_value(ScalarValue::Boolean(Some(true)))?; + + let maybe_not_null = + is_null.contains_value(ScalarValue::Boolean(Some(false)))?; + + match (maybe_null, maybe_not_null) { + (true, true) | (false, false) => NullableInterval::ANY_TRUTH_VALUE, + (true, false) => NullableInterval::UNKNOWN, + (false, true) => 
NullableInterval::TRUE_OR_FALSE, + } + } 
+ }) + } + + /// Determines if the given expression can evaluate to `NULL`. + /// + /// This method only returns sets containing `TRUE`, `FALSE`, or both. + fn is_null(&self, expr: &Expr) -> NullableInterval { + // Fast path for literals + if let Expr::Literal(scalar, _) = expr { + if scalar.is_null() { + return NullableInterval::TRUE; + } else { + return NullableInterval::FALSE; + } + } + + // If `expr` is not nullable, we can be certain `expr` is not null + if let Ok(false) = expr.nullable(self.input_schema) { + return NullableInterval::FALSE; + } + + // `expr` is nullable, so our default answer for `is null` is going to be `{ TRUE, FALSE }`. + // Try to see if we can narrow it down to just one option. + match expr { + Expr::BinaryExpr(BinaryExpr { op, .. }) if op.returns_null_on_null() => { + self.is_null_if_any_child_null(expr) + } + Expr::Alias(_) + | Expr::Cast(_) + | Expr::Like(_) + | Expr::Negative(_) + | Expr::Not(_) + | Expr::SimilarTo(_) => self.is_null_if_any_child_null(expr), + Expr::Between(Between { + expr, low, high, .. + }) if self.is_null(expr).is_certainly_true() + || (self.is_null(low.as_ref()).is_certainly_true() + && self.is_null(high.as_ref()).is_certainly_true()) => + { + // Between is always null if the left side is null + // or both the low and high bounds are null + NullableInterval::TRUE + } + _ => NullableInterval::TRUE_OR_FALSE, + } + } + + fn is_null_if_any_child_null(&self, expr: &Expr) -> NullableInterval { + // These expressions are null if any of their direct children is null + // If any child is inconclusive, the result for this expression is also inconclusive + let mut is_null = NullableInterval::FALSE; + + let _ = expr.apply_children(|child| { + let child_is_null = self.is_null(child); + + if child_is_null.contains_value(ScalarValue::Boolean(Some(true)))? 
{ + // If a child might be null, then the result may also be null + is_null = NullableInterval::TRUE_OR_FALSE; + } + + if !child_is_null.contains_value(ScalarValue::Boolean(Some(false)))? { + // If the child is never not null, then the result can also never be not null + // and we can stop traversing the children + is_null = NullableInterval::TRUE; + Ok(TreeNodeRecursion::Stop) + } else { + Ok(TreeNodeRecursion::Continue) + } + }); + + is_null + } +} + +#[cfg(test)] +mod tests { + use crate::expr::ScalarFunction; + use crate::predicate_bounds::evaluate_bounds; + use crate::{ + binary_expr, col, create_udf, is_false, is_not_false, is_not_null, is_not_true, + is_not_unknown, is_null, is_true, is_unknown, lit, not, Expr, + }; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::{DFSchema, Result, ScalarValue}; + use datafusion_expr_common::columnar_value::ColumnarValue; + use datafusion_expr_common::interval_arithmetic::NullableInterval; + use datafusion_expr_common::operator::Operator::{And, Eq, Or}; + use datafusion_expr_common::signature::Volatility; + use std::ops::Neg; + use std::sync::Arc; + + fn eval_bounds(predicate: &Expr) -> Result { + let schema = DFSchema::try_from(Schema::empty())?; + evaluate_bounds(predicate, &schema) + } + + #[test] + fn evaluate_bounds_literal() { + #[rustfmt::skip] + let cases = vec![ + (lit(ScalarValue::Null), NullableInterval::UNKNOWN), + (lit(false), NullableInterval::FALSE), + (lit(true), NullableInterval::TRUE), + (lit(0), NullableInterval::FALSE), + (lit(1), NullableInterval::TRUE), + (lit(ScalarValue::Utf8(None)), NullableInterval::UNKNOWN), + ]; + + for case in cases { + assert_eq!( + eval_bounds(&case.0).unwrap(), + case.1, + "Failed for {}", + case.0 + ); + } + + assert!(eval_bounds(&lit("foo")).is_err()); + } + + #[test] + fn evaluate_bounds_and() { + let null = lit(ScalarValue::Null); + let zero = lit(0); + let one = lit(1); + let t = lit(true); + let f = lit(false); + let func = 
make_scalar_func_expr(); + + #[rustfmt::skip] + let cases = vec![ + (binary_expr(null.clone(), And, null.clone()), NullableInterval::UNKNOWN), + (binary_expr(null.clone(), And, one.clone()), NullableInterval::UNKNOWN), + (binary_expr(null.clone(), And, zero.clone()), NullableInterval::FALSE), + (binary_expr(one.clone(), And, one.clone()), NullableInterval::TRUE), + (binary_expr(one.clone(), And, zero.clone()), NullableInterval::FALSE), + (binary_expr(null.clone(), And, t.clone()), NullableInterval::UNKNOWN), + (binary_expr(t.clone(), And, null.clone()), NullableInterval::UNKNOWN), + (binary_expr(null.clone(), And, f.clone()), NullableInterval::FALSE), + (binary_expr(f.clone(), And, null.clone()), NullableInterval::FALSE), + (binary_expr(t.clone(), And, t.clone()), NullableInterval::TRUE), + (binary_expr(t.clone(), And, f.clone()), NullableInterval::FALSE), + (binary_expr(f.clone(), And, t.clone()), NullableInterval::FALSE), + (binary_expr(f.clone(), And, f.clone()), NullableInterval::FALSE), + (binary_expr(t.clone(), And, func.clone()), NullableInterval::ANY_TRUTH_VALUE), + (binary_expr(func.clone(), And, t.clone()), NullableInterval::ANY_TRUTH_VALUE), + (binary_expr(f.clone(), And, func.clone()), NullableInterval::FALSE), + (binary_expr(func.clone(), And, f.clone()), NullableInterval::FALSE), + (binary_expr(null.clone(), And, func.clone()), NullableInterval::FALSE_OR_UNKNOWN), + (binary_expr(func.clone(), And, null.clone()), NullableInterval::FALSE_OR_UNKNOWN), + ]; + + for case in cases { + assert_eq!( + eval_bounds(&case.0).unwrap(), + case.1, + "Failed for {}", + case.0 + ); + } + } + + #[test] + fn evaluate_bounds_or() { + let null = lit(ScalarValue::Null); + let zero = lit(0); + let one = lit(1); + let t = lit(true); + let f = lit(false); + let func = make_scalar_func_expr(); + + #[rustfmt::skip] + let cases = vec![ + (binary_expr(null.clone(), Or, null.clone()), NullableInterval::UNKNOWN), + (binary_expr(null.clone(), Or, one.clone()), 
NullableInterval::TRUE), + (binary_expr(null.clone(), Or, zero.clone()), NullableInterval::UNKNOWN), + (binary_expr(one.clone(), Or, one.clone()), NullableInterval::TRUE), + (binary_expr(one.clone(), Or, zero.clone()), NullableInterval::TRUE), + (binary_expr(null.clone(), Or, t.clone()), NullableInterval::TRUE), + (binary_expr(t.clone(), Or, null.clone()), NullableInterval::TRUE), + (binary_expr(null.clone(), Or, f.clone()), NullableInterval::UNKNOWN), + (binary_expr(f.clone(), Or, null.clone()), NullableInterval::UNKNOWN), + (binary_expr(t.clone(), Or, t.clone()), NullableInterval::TRUE), + (binary_expr(t.clone(), Or, f.clone()), NullableInterval::TRUE), + (binary_expr(f.clone(), Or, t.clone()), NullableInterval::TRUE), + (binary_expr(f.clone(), Or, f.clone()), NullableInterval::FALSE), + (binary_expr(t.clone(), Or, func.clone()), NullableInterval::TRUE), + (binary_expr(func.clone(), Or, t.clone()), NullableInterval::TRUE), + (binary_expr(f.clone(), Or, func.clone()), NullableInterval::ANY_TRUTH_VALUE), + (binary_expr(func.clone(), Or, f.clone()), NullableInterval::ANY_TRUTH_VALUE), + (binary_expr(null.clone(), Or, func.clone()), NullableInterval::TRUE_OR_UNKNOWN), + (binary_expr(func.clone(), Or, null.clone()), NullableInterval::TRUE_OR_UNKNOWN), + ]; + + for case in cases { + assert_eq!( + eval_bounds(&case.0).unwrap(), + case.1, + "Failed for {}", + case.0 + ); + } + } + + #[test] + fn evaluate_bounds_not() { + let null = lit(ScalarValue::Null); + let zero = lit(0); + let one = lit(1); + let t = lit(true); + let f = lit(false); + let func = make_scalar_func_expr(); + + #[rustfmt::skip] + let cases = vec![ + (not(null.clone()), NullableInterval::UNKNOWN), + (not(one.clone()), NullableInterval::FALSE), + (not(zero.clone()), NullableInterval::TRUE), + (not(t.clone()), NullableInterval::FALSE), + (not(f.clone()), NullableInterval::TRUE), + (not(func.clone()), NullableInterval::ANY_TRUTH_VALUE), + ]; + + for case in cases { + assert_eq!( + 
eval_bounds(&case.0).unwrap(), + case.1, + "Failed for {}", + case.0 + ); + } + } + + #[test] + fn evaluate_bounds_is() { + let null = lit(ScalarValue::Null); + let zero = lit(0); + let one = lit(1); + let t = lit(true); + let f = lit(false); + let col = col("col"); + let nullable_schema = DFSchema::try_from(Schema::new(vec![Field::new( + "col", + DataType::UInt8, + true, + )])) + .unwrap(); + let not_nullable_schema = DFSchema::try_from(Schema::new(vec![Field::new( + "col", + DataType::UInt8, + false, + )])) + .unwrap(); + + #[rustfmt::skip] + let cases = vec![ + (is_null(null.clone()), NullableInterval::TRUE), + (is_null(one.clone()), NullableInterval::FALSE), + (is_null(binary_expr(null.clone(), Eq, null.clone())), NullableInterval::TRUE), + (is_not_null(null.clone()), NullableInterval::FALSE), + (is_not_null(one.clone()), NullableInterval::TRUE), + (is_not_null(binary_expr(null.clone(), Eq, null.clone())), NullableInterval::FALSE), + (is_true(null.clone()), NullableInterval::FALSE), + (is_true(t.clone()), NullableInterval::TRUE), + (is_true(f.clone()), NullableInterval::FALSE), + (is_true(zero.clone()), NullableInterval::FALSE), + (is_true(one.clone()), NullableInterval::TRUE), + (is_true(binary_expr(null.clone(), Eq, null.clone())), NullableInterval::FALSE), + (is_not_true(null.clone()), NullableInterval::TRUE), + (is_not_true(t.clone()), NullableInterval::FALSE), + (is_not_true(f.clone()), NullableInterval::TRUE), + (is_not_true(zero.clone()), NullableInterval::TRUE), + (is_not_true(one.clone()), NullableInterval::FALSE), + (is_not_true(binary_expr(null.clone(), Eq, null.clone())), NullableInterval::TRUE), + (is_false(null.clone()), NullableInterval::FALSE), + (is_false(t.clone()), NullableInterval::FALSE), + (is_false(f.clone()), NullableInterval::TRUE), + (is_false(zero.clone()), NullableInterval::TRUE), + (is_false(one.clone()), NullableInterval::FALSE), + (is_false(binary_expr(null.clone(), Eq, null.clone())), NullableInterval::FALSE), + 
(is_not_false(null.clone()), NullableInterval::TRUE), + (is_not_false(t.clone()), NullableInterval::TRUE), + (is_not_false(f.clone()), NullableInterval::FALSE), + (is_not_false(zero.clone()), NullableInterval::FALSE), + (is_not_false(one.clone()), NullableInterval::TRUE), + (is_not_false(binary_expr(null.clone(), Eq, null.clone())), NullableInterval::TRUE), + (is_unknown(null.clone()), NullableInterval::TRUE), + (is_unknown(t.clone()), NullableInterval::FALSE), + (is_unknown(f.clone()), NullableInterval::FALSE), + (is_unknown(zero.clone()), NullableInterval::FALSE), + (is_unknown(one.clone()), NullableInterval::FALSE), + (is_unknown(binary_expr(null.clone(), Eq, null.clone())), NullableInterval::TRUE), + (is_not_unknown(null.clone()), NullableInterval::FALSE), + (is_not_unknown(t.clone()), NullableInterval::TRUE), + (is_not_unknown(f.clone()), NullableInterval::TRUE), + (is_not_unknown(zero.clone()), NullableInterval::TRUE), + (is_not_unknown(one.clone()), NullableInterval::TRUE), + (is_not_unknown(binary_expr(null.clone(), Eq, null.clone())), NullableInterval::FALSE), + ]; + + for case in cases { + assert_eq!( + eval_bounds(&case.0).unwrap(), + case.1, + "Failed for {}", + case.0 + ); + } + + #[rustfmt::skip] + let cases = vec![ + (is_null(col.clone()), &nullable_schema, NullableInterval::TRUE_OR_FALSE), + (is_null(col.clone()), &not_nullable_schema, NullableInterval::FALSE), + (is_null(binary_expr(col.clone(), Eq, col.clone())), &nullable_schema, NullableInterval::TRUE_OR_FALSE), + (is_null(binary_expr(col.clone(), Eq, col.clone())), &not_nullable_schema, NullableInterval::FALSE), + (is_not_null(col.clone()), &nullable_schema, NullableInterval::TRUE_OR_FALSE), + (is_not_null(col.clone()), &not_nullable_schema, NullableInterval::TRUE), + (is_not_null(binary_expr(col.clone(), Eq, col.clone())), &nullable_schema, NullableInterval::TRUE_OR_FALSE), + (is_not_null(binary_expr(col.clone(), Eq, col.clone())), &not_nullable_schema, NullableInterval::TRUE), + ]; + + for case in cases 
{ + assert_eq!( + evaluate_bounds(&case.0, case.1).unwrap(), + case.2, + "Failed for {}", + case.0 + ); + } + } + + #[test] + fn evaluate_bounds_between() { + let null = lit(ScalarValue::Null); + let zero = lit(0); + + #[rustfmt::skip] + let cases = vec![ + (zero.clone().between(zero.clone(), zero.clone()), NullableInterval::TRUE_OR_FALSE), + (null.clone().between(zero.clone(), zero.clone()), NullableInterval::UNKNOWN), + (zero.clone().between(null.clone(), zero.clone()), NullableInterval::ANY_TRUTH_VALUE), + (zero.clone().between(zero.clone(), null.clone()), NullableInterval::ANY_TRUTH_VALUE), + (zero.clone().between(null.clone(), null.clone()), NullableInterval::UNKNOWN), + (null.clone().between(null.clone(), null.clone()), NullableInterval::UNKNOWN), + ]; + + for case in cases { + assert_eq!( + eval_bounds(&case.0).unwrap(), + case.1, + "Failed for {}", + case.0 + ); + } + } + + #[test] + fn evaluate_bounds_binary_op() { + let null = lit(ScalarValue::Null); + let zero = lit(0); + let col = col("col"); + let nullable_schema = DFSchema::try_from(Schema::new(vec![Field::new( + "col", + DataType::Utf8, + true, + )])) + .unwrap(); + let not_nullable_schema = DFSchema::try_from(Schema::new(vec![Field::new( + "col", + DataType::Utf8, + false, + )])) + .unwrap(); + + #[rustfmt::skip] + let cases = vec![ + (binary_expr(zero.clone(), Eq, zero.clone()), NullableInterval::TRUE_OR_FALSE), + (binary_expr(null.clone(), Eq, zero.clone()), NullableInterval::UNKNOWN), + (binary_expr(zero.clone(), Eq, null.clone()), NullableInterval::UNKNOWN), + (binary_expr(null.clone(), Eq, null.clone()), NullableInterval::UNKNOWN), + ]; + + for case in cases { + assert_eq!( + eval_bounds(&case.0).unwrap(), + case.1, + "Failed for {}", + case.0 + ); + } + + #[rustfmt::skip] + let cases = vec![ + (binary_expr(zero.clone(), Eq, col.clone()), NullableInterval::TRUE_OR_FALSE), + (binary_expr(col.clone(), Eq, zero.clone()), NullableInterval::TRUE_OR_FALSE), + ]; + + for case in cases { + assert_eq!( 
+ evaluate_bounds(&case.0, &not_nullable_schema).unwrap(), + case.1, + "Failed for {}", + case.0 + ); + + assert_eq!( + evaluate_bounds(&case.0, &nullable_schema).unwrap(), + NullableInterval::ANY_TRUTH_VALUE, + "Failed for {}", + case.0 + ); + } + } + + #[test] + fn evaluate_bounds_negative() { + let null = lit(ScalarValue::Null); + let zero = lit(0); + + #[rustfmt::skip] + let cases = vec![ + (zero.clone().neg(), NullableInterval::TRUE_OR_FALSE), + (null.clone().neg(), NullableInterval::UNKNOWN), + ]; + + for case in cases { + assert_eq!( + eval_bounds(&case.0).unwrap(), + case.1, + "Failed for {}", + case.0 + ); + } + } + + #[test] + fn evaluate_bounds_like() { + let null = lit(ScalarValue::Null); + let expr = lit("foo"); + let pattern = lit("f.*"); + let col = col("col"); + let nullable_schema = DFSchema::try_from(Schema::new(vec![Field::new( + "col", + DataType::Utf8, + true, + )])) + .unwrap(); + let not_nullable_schema = DFSchema::try_from(Schema::new(vec![Field::new( + "col", + DataType::Utf8, + false, + )])) + .unwrap(); + + #[rustfmt::skip] + let cases = vec![ + (expr.clone().like(pattern.clone()), NullableInterval::TRUE_OR_FALSE), + (null.clone().like(pattern.clone()), NullableInterval::UNKNOWN), + (expr.clone().like(null.clone()), NullableInterval::UNKNOWN), + (null.clone().like(null.clone()), NullableInterval::UNKNOWN), + ]; + + for case in cases { + assert_eq!( + eval_bounds(&case.0).unwrap(), + case.1, + "Failed for {}", + case.0 + ); + } + + #[rustfmt::skip] + let cases = vec![ + (col.clone().like(pattern.clone()), NullableInterval::TRUE_OR_FALSE), + (expr.clone().like(col.clone()), NullableInterval::TRUE_OR_FALSE), + ]; + + for case in cases { + assert_eq!( + evaluate_bounds(&case.0, &not_nullable_schema).unwrap(), + case.1, + "Failed for {}", + case.0 + ); + + assert_eq!( + evaluate_bounds(&case.0, &nullable_schema).unwrap(), + NullableInterval::ANY_TRUTH_VALUE, + "Failed for {}", + case.0 + ); + } + } + + #[test] + fn evaluate_bounds_udf() { + let func 
= make_scalar_func_expr(); + + #[rustfmt::skip] + let cases = vec![ + (func.clone(), NullableInterval::ANY_TRUTH_VALUE), + (not(func.clone()), NullableInterval::ANY_TRUTH_VALUE), + (binary_expr(func.clone(), And, func.clone()), NullableInterval::ANY_TRUTH_VALUE), + ]; + + for case in cases { + assert_eq!(eval_bounds(&case.0).unwrap(), case.1); + } + } + + fn make_scalar_func_expr() -> Expr { + let scalar_func_impl = + |_: &[ColumnarValue]| Ok(ColumnarValue::Scalar(ScalarValue::Null)); + let udf = create_udf( + "foo", + vec![], + DataType::Boolean, + Volatility::Stable, + Arc::new(scalar_func_impl), + ); + Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(udf), vec![])) + } +} diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 9db4a51c8404..ddb0a98df537 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -17,7 +17,7 @@ use super::{Column, Literal}; use crate::expressions::case::ResultState::{Complete, Empty, Partial}; -use crate::expressions::try_cast; +use crate::expressions::{lit, try_cast}; use crate::PhysicalExpr; use arrow::array::*; use arrow::compute::kernels::zip::zip; @@ -34,13 +34,14 @@ use datafusion_common::{ DataFusionError, HashMap, HashSet, Result, ScalarValue, }; use datafusion_expr::ColumnarValue; -use datafusion_physical_expr_common::datum::compare_with_eq; -use itertools::Itertools; use std::borrow::Cow; -use std::fmt::{Debug, Formatter}; use std::hash::Hash; use std::{any::Any, sync::Arc}; +use datafusion_physical_expr_common::datum::compare_with_eq; +use itertools::Itertools; +use std::fmt::{Debug, Formatter}; + type WhenThen = (Arc, Arc); #[derive(Debug, Hash, PartialEq, Eq)] @@ -1286,16 +1287,62 @@ impl PhysicalExpr for CaseExpr { } fn nullable(&self, input_schema: &Schema) -> Result { - // this expression is nullable if any of the input expressions are nullable - let then_nullable = self + let 
nullable_then = self .body .when_then_expr .iter() - .map(|(_, t)| t.nullable(input_schema)) - .collect::>>()?; - if then_nullable.contains(&true) { - Ok(true) + .filter_map(|(w, t)| { + let is_nullable = match t.nullable(input_schema) { + // Pass on error determining nullability verbatim + Err(e) => return Some(Err(e)), + Ok(n) => n, + }; + + // Branches with a then expression that is not nullable do not impact the + // nullability of the case expression. + if !is_nullable { + return None; + } + + // For case-with-expression assume all 'then' expressions are reachable + if self.body.expr.is_some() { + return Some(Ok(())); + } + + // For branches with a nullable 'then' expression, try to determine + // if the 'then' expression is ever reachable in the situation where + // it would evaluate to null. + + // Replace the `then` expression with `NULL` in the `when` expression + let with_null = match replace_with_null(w, t.as_ref(), input_schema) { + Err(e) => return Some(Err(e)), + Ok(e) => e, + }; + + // Try to const evaluate the modified `when` expression. + let predicate_result = match evaluate_predicate(&with_null) { + Err(e) => return Some(Err(e)), + Ok(b) => b, + }; + + match predicate_result { + // Evaluation was inconclusive or true, so the 'then' expression is reachable + None | Some(true) => Some(Ok(())), + // Evaluation proves the branch will never be taken. + // The most common pattern for this is `WHEN x IS NOT NULL THEN x`. + Some(false) => None, + } + }) + .next(); + + if let Some(nullable_then) = nullable_then { + // There is at least one reachable nullable 'then' expression, so the case + // expression itself is nullable. + // Use `Result::map` to propagate the error from `nullable_then` if there is one. + nullable_then.map(|_| true) } else if let Some(e) = &self.body.else_expr { + // There are no reachable nullable 'then' expressions, so all we still need to + // check is the 'else' expression's nullability. 
e.nullable(input_schema) } else { // CASE produces NULL if there is no `else` expr @@ -1398,6 +1445,51 @@ impl PhysicalExpr for CaseExpr { } } +/// Attempts to const evaluate the given `predicate`. +/// Returns: +/// - `Some(true)` if the predicate evaluates to a truthy value. +/// - `Some(false)` if the predicate evaluates to a falsy value. +/// - `None` if the predicate could not be evaluated. +fn evaluate_predicate(predicate: &Arc) -> Result> { + // Create a dummy record with no columns and one row + let batch = RecordBatch::try_new_with_options( + Arc::new(Schema::empty()), + vec![], + &RecordBatchOptions::new().with_row_count(Some(1)), + )?; + + // Evaluate the predicate and interpret the result as a boolean + let result = match predicate.evaluate(&batch) { + // An error during evaluation means we couldn't const evaluate the predicate, so return `None` + Err(_) => None, + Ok(ColumnarValue::Array(array)) => Some( + ScalarValue::try_from_array(array.as_ref(), 0)? + .cast_to(&DataType::Boolean)?, + ), + Ok(ColumnarValue::Scalar(scalar)) => Some(scalar.cast_to(&DataType::Boolean)?), + }; + Ok(result.map(|v| matches!(v, ScalarValue::Boolean(Some(true))))) +} + +fn replace_with_null( + expr: &Arc, + expr_to_replace: &dyn PhysicalExpr, + input_schema: &Schema, +) -> Result, DataFusionError> { + let with_null = Arc::clone(expr) + .transform_down(|e| { + if e.as_ref().dyn_eq(expr_to_replace) { + let data_type = e.data_type(input_schema)?; + let null_literal = lit(ScalarValue::try_new_null(&data_type)?); + Ok(Transformed::yes(null_literal)) + } else { + Ok(Transformed::no(e)) + } + })? 
+ .data; + Ok(with_null) + } + /// Create a CASE expression pub fn case( expr: Option<Arc<dyn PhysicalExpr>>, @@ -1411,7 +1503,8 @@ pub fn case( mod tests { use super::*; - use crate::expressions::{binary, cast, col, lit, BinaryExpr}; + use crate::expressions; + use crate::expressions::{binary, cast, col, is_not_null, lit, BinaryExpr}; use arrow::buffer::Buffer; use arrow::datatypes::DataType::Float64; use arrow::datatypes::Field; @@ -1419,7 +1512,7 @@ mod tests { use datafusion_common::plan_err; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_expr::type_coercion::binary::comparison_coercion; - use datafusion_expr::Operator; + use datafusion_expr_common::operator::Operator; use datafusion_physical_expr_common::physical_expr::fmt_sql; #[test] @@ -2296,4 +2389,182 @@ mod tests { assert!(merged.is_valid(2)); assert_eq!(merged.value(2), "C"); } + + fn when_then_else( + when: &Arc<dyn PhysicalExpr>, + then: &Arc<dyn PhysicalExpr>, + els: &Arc<dyn PhysicalExpr>, + ) -> Result<Arc<dyn PhysicalExpr>> { + let case = CaseExpr::try_new( + None, + vec![(Arc::clone(when), Arc::clone(then))], + Some(Arc::clone(els)), + )?; + Ok(Arc::new(case)) + } + + #[test] + fn test_case_expression_nullability_with_nullable_column() -> Result<()> { + case_expression_nullability(true) + } + + #[test] + fn test_case_expression_nullability_with_not_nullable_column() -> Result<()> { + case_expression_nullability(false) + } + + fn case_expression_nullability(col_is_nullable: bool) -> Result<()> { + let schema = + Schema::new(vec![Field::new("foo", DataType::Int32, col_is_nullable)]); + + let foo = col("foo", &schema)?; + let foo_is_not_null = is_not_null(Arc::clone(&foo))?; + let foo_is_null = expressions::is_null(Arc::clone(&foo))?; + let not_foo_is_null = expressions::not(Arc::clone(&foo_is_null))?; + let zero = lit(0); + let foo_eq_zero = + binary(Arc::clone(&foo), Operator::Eq, Arc::clone(&zero), &schema)?; + + assert_not_nullable(when_then_else(&foo_is_not_null, &foo, &zero)?, &schema); + assert_not_nullable(when_then_else(&not_foo_is_null, &foo, 
&zero)?, &schema); + assert_not_nullable(when_then_else(&foo_eq_zero, &foo, &zero)?, &schema); + + assert_not_nullable( + when_then_else( + &binary( + Arc::clone(&foo_is_not_null), + Operator::And, + Arc::clone(&foo_eq_zero), + &schema, + )?, + &foo, + &zero, + )?, + &schema, + ); + + assert_not_nullable( + when_then_else( + &binary( + Arc::clone(&foo_eq_zero), + Operator::And, + Arc::clone(&foo_is_not_null), + &schema, + )?, + &foo, + &zero, + )?, + &schema, + ); + + assert_not_nullable( + when_then_else( + &binary( + Arc::clone(&foo_is_not_null), + Operator::Or, + Arc::clone(&foo_eq_zero), + &schema, + )?, + &foo, + &zero, + )?, + &schema, + ); + + assert_not_nullable( + when_then_else( + &binary( + Arc::clone(&foo_eq_zero), + Operator::Or, + Arc::clone(&foo_is_not_null), + &schema, + )?, + &foo, + &zero, + )?, + &schema, + ); + + assert_nullability( + when_then_else( + &binary( + Arc::clone(&foo_is_null), + Operator::Or, + Arc::clone(&foo_eq_zero), + &schema, + )?, + &foo, + &zero, + )?, + &schema, + col_is_nullable, + ); + + assert_nullability( + when_then_else( + &binary( + binary(Arc::clone(&foo), Operator::Eq, Arc::clone(&zero), &schema)?, + Operator::Or, + Arc::clone(&foo_is_null), + &schema, + )?, + &foo, + &zero, + )?, + &schema, + col_is_nullable, + ); + + assert_not_nullable( + when_then_else( + &binary( + binary( + binary( + Arc::clone(&foo), + Operator::Eq, + Arc::clone(&zero), + &schema, + )?, + Operator::And, + Arc::clone(&foo_is_not_null), + &schema, + )?, + Operator::Or, + binary( + binary( + Arc::clone(&foo), + Operator::Eq, + Arc::clone(&foo), + &schema, + )?, + Operator::And, + Arc::clone(&foo_is_not_null), + &schema, + )?, + &schema, + )?, + &foo, + &zero, + )?, + &schema, + ); + + Ok(()) + } + + fn assert_not_nullable(expr: Arc, schema: &Schema) { + assert!(!expr.nullable(schema).unwrap()); + } + + fn assert_nullable(expr: Arc, schema: &Schema) { + assert!(expr.nullable(schema).unwrap()); + } + + fn assert_nullability(expr: Arc, schema: 
&Schema, nullable: bool) { + if nullable { + assert_nullable(expr, schema); + } else { + assert_not_nullable(expr, schema); + } + } } diff --git a/datafusion/sqllogictest/test_files/case.slt b/datafusion/sqllogictest/test_files/case.slt index 1a4b6a7a2b4a..3905575d22dc 100644 --- a/datafusion/sqllogictest/test_files/case.slt +++ b/datafusion/sqllogictest/test_files/case.slt @@ -683,3 +683,23 @@ FROM ( 10 10 100 -20 20 200 NULL 30 300 + +# Case-with-expression that was incorrectly classified as not-nullable, but evaluates to null +query I +SELECT CASE 0 WHEN 0 THEN NULL WHEN SUM(1) + COUNT(*) THEN 10 ELSE 20 END +---- +NULL + +query TT +EXPLAIN SELECT CASE WHEN CASE WHEN a IS NOT NULL THEN a ELSE 1 END IS NOT NULL THEN a ELSE 1 END FROM ( + VALUES (10), (20), (30) + ) t(a); +---- +logical_plan +01)Projection: t.a AS CASE WHEN CASE WHEN t.a IS NOT NULL THEN t.a ELSE Int64(1) END IS NOT NULL THEN t.a ELSE Int64(1) END +02)--SubqueryAlias: t +03)----Projection: column1 AS a +04)------Values: (Int64(10)), (Int64(20)), (Int64(30)) +physical_plan +01)ProjectionExec: expr=[column1@0 as CASE WHEN CASE WHEN t.a IS NOT NULL THEN t.a ELSE Int64(1) END IS NOT NULL THEN t.a ELSE Int64(1) END] +02)--DataSourceExec: partitions=1, partition_sizes=[1]