diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index f93f08574906..cdad7e331427 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -508,7 +508,7 @@ fn field_for_index( /// cast subquery in InSubquery/ScalarSubquery to a given type. pub fn cast_subquery(subquery: Subquery, cast_to_type: &DataType) -> Result { - if subquery.subquery.schema().field(0).data_type() == cast_to_type { + if subquery.data_type() == cast_to_type { return Ok(subquery); } diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 23f5280377a3..8199327b212a 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -52,6 +52,7 @@ use datafusion_common::{ // backwards compatibility use crate::display::PgJsonVisitor; +use crate::logical_plan::tree_node::unwrap_arc; pub use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan}; pub use datafusion_common::{JoinConstraint, JoinType}; @@ -467,6 +468,200 @@ impl LogicalPlan { self.with_new_exprs(self.expressions(), inputs.to_vec()) } + /// Recomputes schema and type information for this LogicalPlan if needed. + /// + /// Some `LogicalPlan`s may need to recompute their schema if the number or + /// type of expressions have been changed (for example due to type + /// coercion). For example [`LogicalPlan::Projection`]s schema depends on + /// its expressions. + /// + /// Some `LogicalPlan`s schema is unaffected by any changes to their + /// expressions. For example [`LogicalPlan::Filter`] schema is always the + /// same as its input schema. + /// + /// # Return value + /// Returns an error if there is some issue recomputing the schema. + /// + /// # Notes + /// + /// * Does not recursively recompute schema for input (child) plans. + pub fn recompute_schema(self) -> Result { + match self { + // Since expr may be different than the previous expr, schema of the projection + // may change. We need to use try_new method instead of try_new_with_schema method. + LogicalPlan::Projection(Projection { + expr, + input, + schema: _, + }) => Projection::try_new(expr, input).map(LogicalPlan::Projection), + LogicalPlan::Dml(_) => Ok(self), + LogicalPlan::Copy(_) => Ok(self), + LogicalPlan::Values(Values { schema, values }) => { + // todo it isn't clear why the schema is not recomputed here + Ok(LogicalPlan::Values(Values { schema, values })) + } + LogicalPlan::Filter(Filter { predicate, input }) => { + // todo: should this logic be moved to Filter::try_new? + + // filter predicates should not contain aliased expressions so we remove any aliases + // before this logic was added we would have aliases within filters such as for + // benchmark q6: + // + // lineitem.l_shipdate >= Date32(\"8766\") + // AND lineitem.l_shipdate < Date32(\"9131\") + // AND CAST(lineitem.l_discount AS Decimal128(30, 15)) AS lineitem.l_discount >= + // Decimal128(Some(49999999999999),30,15) + // AND CAST(lineitem.l_discount AS Decimal128(30, 15)) AS lineitem.l_discount <= + // Decimal128(Some(69999999999999),30,15) + // AND lineitem.l_quantity < Decimal128(Some(2400),15,2) + + let predicate = predicate + .transform_down(|expr| { + match expr { + Expr::Exists { .. } + | Expr::ScalarSubquery(_) + | Expr::InSubquery(_) => { + // subqueries could contain aliases so we don't recurse into those + Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump)) + } + Expr::Alias(_) => Ok(Transformed::new( + expr.unalias(), + true, + TreeNodeRecursion::Jump, + )), + _ => Ok(Transformed::no(expr)), + } + }) + .data()?; + + Filter::try_new(predicate, input).map(LogicalPlan::Filter) + } + LogicalPlan::Repartition(_) => Ok(self), + LogicalPlan::Window(Window { + input, + window_expr, + schema: _, + }) => Window::try_new(window_expr, input).map(LogicalPlan::Window), + LogicalPlan::Aggregate(Aggregate { + input, + group_expr, + aggr_expr, + schema: _, + }) => Aggregate::try_new(input, group_expr, aggr_expr) + .map(LogicalPlan::Aggregate), + LogicalPlan::Sort(_) => Ok(self), + LogicalPlan::Join(Join { + left, + right, + filter, + join_type, + join_constraint, + on, + schema: _, + null_equals_null, + }) => { + let schema = + build_join_schema(left.schema(), right.schema(), &join_type)?; + + let new_on: Vec<_> = on + .into_iter() + .map(|equi_expr| { + // SimplifyExpression rule may add alias to the equi_expr. + (equi_expr.0.unalias(), equi_expr.1.unalias()) + }) + .collect(); + + Ok(LogicalPlan::Join(Join { + left, + right, + join_type, + join_constraint, + on: new_on, + filter, + schema: DFSchemaRef::new(schema), + null_equals_null, + })) + } + LogicalPlan::CrossJoin(CrossJoin { + left, + right, + schema: _, + }) => { + let join_schema = + build_join_schema(left.schema(), right.schema(), &JoinType::Inner)?; + + Ok(LogicalPlan::CrossJoin(CrossJoin { + left, + right, + schema: join_schema.into(), + })) + } + LogicalPlan::Subquery(_) => Ok(self), + LogicalPlan::SubqueryAlias(SubqueryAlias { + input, + alias, + schema: _, + }) => SubqueryAlias::try_new(input, alias).map(LogicalPlan::SubqueryAlias), + LogicalPlan::Limit(_) => Ok(self), + LogicalPlan::Ddl(_) => Ok(self), + LogicalPlan::Extension(Extension { node }) => { + // todo make an API that does not require cloning + // This requires a copy of the extension nodes expressions and inputs + let expr = node.expressions(); + let inputs: Vec<_> = node.inputs().into_iter().cloned().collect(); + Ok(LogicalPlan::Extension(Extension { + node: node.from_template(&expr, &inputs), + })) + } + LogicalPlan::Union(Union { inputs, schema }) => { + let input_schema = inputs[0].schema(); + // If inputs are not pruned do not change schema + // TODO this seems wrong (shouldn't we always use the schema of the input?) + let schema = if schema.fields().len() == input_schema.fields().len() { + schema.clone() + } else { + input_schema.clone() + }; + Ok(LogicalPlan::Union(Union { inputs, schema })) + } + LogicalPlan::Distinct(distinct) => { + let distinct = match distinct { + Distinct::All(input) => Distinct::All(input), + Distinct::On(DistinctOn { + on_expr, + select_expr, + sort_expr, + input, + schema: _, + }) => Distinct::On(DistinctOn::try_new( + on_expr, + select_expr, + sort_expr, + input, + )?), + }; + Ok(LogicalPlan::Distinct(distinct)) + } + LogicalPlan::RecursiveQuery(_) => Ok(self), + LogicalPlan::Analyze(_) => Ok(self), + LogicalPlan::Explain(_) => Ok(self), + LogicalPlan::Prepare(_) => Ok(self), + LogicalPlan::TableScan(_) => Ok(self), + LogicalPlan::EmptyRelation(_) => Ok(self), + LogicalPlan::Statement(_) => Ok(self), + LogicalPlan::DescribeTable(_) => Ok(self), + LogicalPlan::Unnest(Unnest { + input, + columns, + schema: _, + options, + }) => { + // Update schema with unnested column type. + unnest_with_options(unwrap_arc(input), columns, options) + } + } + } + /// Returns a new `LogicalPlan` based on `self` with inputs and /// expressions replaced. /// @@ -2490,6 +2685,11 @@ impl Subquery { outer_ref_columns: self.outer_ref_columns.clone(), } } + + /// Returns the type of the first column of the subquery + pub fn data_type(&self) -> &DataType { + self.subquery.schema().fields()[0].data_type() + } } impl Debug for Subquery { diff --git a/datafusion/optimizer/src/analyzer/mod.rs b/datafusion/optimizer/src/analyzer/mod.rs index fb0eb14da659..98f664e48269 100644 --- a/datafusion/optimizer/src/analyzer/mod.rs +++ b/datafusion/optimizer/src/analyzer/mod.rs @@ -22,7 +22,7 @@ use log::debug; use datafusion_common::config::ConfigOptions; use datafusion_common::instant::Instant; -use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::expr::Exists; use datafusion_expr::expr::InSubquery; @@ -62,6 +62,17 @@ pub trait AnalyzerRule { /// Rewrite `plan` fn analyze(&self, plan: LogicalPlan, config: &ConfigOptions) -> Result; + /// Rewrite a plan indicating if the plan was modified + /// The default implementation calls `analyze` + fn rewrite( + &self, + plan: LogicalPlan, + config: &ConfigOptions, + ) -> Result> { + let new_plan = self.analyze(plan, config)?; + Ok(Transformed::yes(new_plan)) + } + /// A human readable name for this analyzer rule fn name(&self) -> &str; } diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 9295b08f419e..825dc77aa0bc 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -22,7 +22,9 @@ use std::sync::Arc; use arrow::datatypes::{DataType, IntervalUnit}; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNodeRewriter}; +use datafusion_common::tree_node::{ + Transformed, TreeNode, TreeNodeIterator, TreeNodeRewriter, +}; use datafusion_common::{ exec_err, internal_err, not_impl_err, plan_datafusion_err, plan_err, DFSchema, DataFusionError, Result, ScalarValue, @@ -31,8 +33,8 @@ use datafusion_expr::expr::{ self, AggregateFunctionDefinition, Between, BinaryExpr, Case, Exists, InList, InSubquery, Like, ScalarFunction, WindowFunction, }; -use datafusion_expr::expr_rewriter::rewrite_preserving_name; use datafusion_expr::expr_schema::cast_subquery; +use datafusion_expr::logical_plan::tree_node::unwrap_arc; use datafusion_expr::logical_plan::Subquery; use datafusion_expr::type_coercion::binary::{ comparison_coercion, get_input_types, like_coercion, @@ -45,12 +47,13 @@ use datafusion_expr::type_coercion::{is_datetime, is_utf8_or_large_utf8}; use datafusion_expr::utils::merge_schema; use datafusion_expr::{ is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, not, - type_coercion, AggregateFunction, Expr, ExprSchemable, LogicalPlan, Operator, + type_coercion, AggregateFunction, Expr, ExprSchemable, Join, LogicalPlan, Operator, ScalarFunctionDefinition, ScalarUDF, Signature, WindowFrame, WindowFrameBound, WindowFrameUnits, }; use crate::analyzer::AnalyzerRule; +use crate::utils::NamePreserver; #[derive(Default)] pub struct TypeCoercion {} @@ -67,58 +70,100 @@ impl AnalyzerRule for TypeCoercion { } fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result { - analyze_internal(&DFSchema::empty(), &plan) + let empty = DFSchema::empty(); + let mut plan_rewriter = TypeCoercionPlanRewriter::new(&empty); + + let transformed_plan = plan.rewrite_with_subqueries(&mut plan_rewriter)?.data; + + Ok(transformed_plan) } } -fn analyze_internal( - // use the external schema to handle the correlated subqueries case - external_schema: &DFSchema, - plan: &LogicalPlan, -) -> Result { - // optimize child plans first - let new_inputs = plan - .inputs() - .iter() - .map(|p| analyze_internal(external_schema, p)) - .collect::>>()?; - // get schema representing all available input fields. This is used for data type - // resolution only, so order does not matter here - let mut schema = merge_schema(new_inputs.iter().collect()); - - if let LogicalPlan::TableScan(ts) = plan { - let source_schema = DFSchema::try_from_qualified_schema( - ts.table_name.clone(), - &ts.source.schema(), - )?; - schema.merge(&source_schema); +/// Rewrites plans to ensure that all expressions have valid types +/// tracking if any input plans have been transformed +pub struct TypeCoercionPlanRewriter<'a> { + /// were any child plans transformed? If so, we need to recompute the schema + /// of the parent plan as some plan outputs are dependent on the schema of + /// the children + any_plan_transformed: bool, + /// The outer query schema, if any + external_schema: &'a DFSchema, +} + +impl<'a> TypeCoercionPlanRewriter<'a> { + fn new(external_schema: &'a DFSchema) -> Self { + Self { + any_plan_transformed: false, + external_schema, + } } +} - // merge the outer schema for correlated subqueries - // like case: - // select t2.c2 from t1 where t1.c1 in (select t2.c1 from t2 where t2.c2=t1.c3) - schema.merge(external_schema); +impl<'a> TreeNodeRewriter for TypeCoercionPlanRewriter<'a> { + type Node = LogicalPlan; - let mut expr_rewrite = TypeCoercionRewriter { schema: &schema }; + fn f_up(&mut self, plan: Self::Node) -> Result> { + // get schema representing all available input fields. This is used for data type + // resolution only, so order does not matter here + let mut schema = merge_schema(plan.inputs()); - let new_expr = plan - .expressions() - .into_iter() - .map(|expr| { - // ensure aggregate names don't change: - // https://github.com/apache/datafusion/issues/3555 - rewrite_preserving_name(expr, &mut expr_rewrite) - }) - .collect::>>()?; + if let LogicalPlan::TableScan(ts) = &plan { + let source_schema = DFSchema::try_from_qualified_schema( + ts.table_name.clone(), + &ts.source.schema(), + )?; + schema.merge(&source_schema); + } + + // merge the outer schema for correlated subqueries + // like case: + // select t2.c2 from t1 where t1.c1 in (select t2.c1 from t2 where t2.c2=t1.c3) + schema.merge(self.external_schema); + + let mut expr_rewrite = TypeCoercionExprRewriter::new(&schema); + let name_preserver = NamePreserver::new(&plan); + // apply coercion rewrite to all expressions in the plan individually + let transformed_plan = plan + .map_expressions(|expr| { + let original_name = name_preserver.save(&expr)?; + expr.rewrite(&mut expr_rewrite)? + .map_data(|expr| original_name.restore(expr)) + })? + // coerce join expressions specially + .transform_data(|plan| expr_rewrite.coerce_joins(plan))?; + + // Note: We must recompute the schema if any of expressions (or inputs) + // have been rewritten, as the types may have changed. + + if transformed_plan.transformed { + self.any_plan_transformed = true; + } - plan.with_new_exprs(new_expr, new_inputs) + // Hack: For subquery alias need to recompute the schema for unknown reasons + if matches!(transformed_plan.data, LogicalPlan::SubqueryAlias(_)) { + self.any_plan_transformed = true; + } + + if self.any_plan_transformed { + transformed_plan + .transform_data(|plan| plan.recompute_schema().map(Transformed::yes)) + } else { + Ok(transformed_plan) + } + } } -pub(crate) struct TypeCoercionRewriter<'a> { +pub(crate) struct TypeCoercionExprRewriter<'a> { pub(crate) schema: &'a DFSchema, } -impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { +impl<'a> TypeCoercionExprRewriter<'a> { + fn new(schema: &'a DFSchema) -> Self { + Self { schema } + } +} + +impl<'a> TreeNodeRewriter for TypeCoercionExprRewriter<'a> { type Node = Expr; fn f_up(&mut self, expr: Expr) -> Result> { @@ -126,70 +171,67 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { Expr::Unnest(_) => not_impl_err!( "Unnest should be rewritten to LogicalPlan::Unnest before type coercion" ), - Expr::ScalarSubquery(Subquery { - subquery, - outer_ref_columns, - }) => { - let new_plan = analyze_internal(self.schema, &subquery)?; - Ok(Transformed::yes(Expr::ScalarSubquery(Subquery { - subquery: Arc::new(new_plan), - outer_ref_columns, - }))) - } - Expr::Exists(Exists { subquery, negated }) => { - let new_plan = analyze_internal(self.schema, &subquery.subquery)?; - Ok(Transformed::yes(Expr::Exists(Exists { - subquery: Subquery { - subquery: Arc::new(new_plan), - outer_ref_columns: subquery.outer_ref_columns, - }, - negated, - }))) - } + Expr::ScalarSubquery(subsquery) => self + .coerce_in_subquery(subsquery)? + .map_data(|subquery| Ok(Expr::ScalarSubquery(subquery))), + Expr::Exists(Exists { subquery, negated }) => self + .coerce_in_subquery(subquery)? + .map_data(|subquery| Ok(Expr::Exists(Exists { subquery, negated }))), Expr::InSubquery(InSubquery { expr, subquery, negated, }) => { - let new_plan = analyze_internal(self.schema, &subquery.subquery)?; - let expr_type = expr.get_type(self.schema)?; - let subquery_type = new_plan.schema().field(0).data_type(); - let common_type = comparison_coercion(&expr_type, subquery_type).ok_or(plan_datafusion_err!( - "expr type {expr_type:?} can't cast to {subquery_type:?} in InSubquery" - ), - )?; - let new_subquery = Subquery { - subquery: Arc::new(new_plan), - outer_ref_columns: subquery.outer_ref_columns, - }; - Ok(Transformed::yes(Expr::InSubquery(InSubquery::new( - Box::new(expr.cast_to(&common_type, self.schema)?), - cast_subquery(new_subquery, &common_type)?, - negated, - )))) + self.coerce_in_subquery(subquery)? + .transform_data(|subquery| { + let new_plan = &subquery.subquery; + let expr_type = expr.get_type(self.schema)?; + let subquery_type = new_plan.schema().field(0).data_type(); + let common_type = comparison_coercion(&expr_type, subquery_type).ok_or(plan_datafusion_err!( + "expr type {expr_type:?} can't cast to {subquery_type:?} in InSubquery" + ))?; + // coerce the expr and subquery to the common type if needed + let expr = self.maybe_cast_from(*expr, &expr_type, &common_type)?; + let subquery = if subquery.data_type() != &common_type { + Transformed::yes(cast_subquery(subquery, &common_type)?) + } else { + Transformed::no(subquery) + }; + // transformation was applied if either the expr or the + // subquery transformed + let transformed = expr.transformed || subquery.transformed; + // create output + let new_expr = Expr::InSubquery(InSubquery::new( + Box::new(expr.data), + subquery.data, + negated, + )); + if transformed { + Ok(Transformed::yes(new_expr)) + } else { + Ok(Transformed::no(new_expr)) + } + }) } - Expr::Not(expr) => Ok(Transformed::yes(not(get_casted_expr_for_bool_op( - *expr, - self.schema, - )?))), - Expr::IsTrue(expr) => Ok(Transformed::yes(is_true( - get_casted_expr_for_bool_op(*expr, self.schema)?, - ))), - Expr::IsNotTrue(expr) => Ok(Transformed::yes(is_not_true( - get_casted_expr_for_bool_op(*expr, self.schema)?, - ))), - Expr::IsFalse(expr) => Ok(Transformed::yes(is_false( - get_casted_expr_for_bool_op(*expr, self.schema)?, - ))), - Expr::IsNotFalse(expr) => Ok(Transformed::yes(is_not_false( - get_casted_expr_for_bool_op(*expr, self.schema)?, - ))), - Expr::IsUnknown(expr) => Ok(Transformed::yes(is_unknown( - get_casted_expr_for_bool_op(*expr, self.schema)?, - ))), - Expr::IsNotUnknown(expr) => Ok(Transformed::yes(is_not_unknown( - get_casted_expr_for_bool_op(*expr, self.schema)?, - ))), + Expr::Not(expr) => self.coerce_bool_op(*expr)?.map_data(|expr| Ok(not(expr))), + Expr::IsTrue(expr) => self + .coerce_bool_op(*expr)? + .map_data(|expr| Ok(is_true(expr))), + Expr::IsNotTrue(expr) => self + .coerce_bool_op(*expr)? + .map_data(|expr| Ok(is_not_true(expr))), + Expr::IsFalse(expr) => self + .coerce_bool_op(*expr)? + .map_data(|expr| Ok(is_false(expr))), + Expr::IsNotFalse(expr) => self + .coerce_bool_op(*expr)? + .map_data(|expr| Ok(is_not_false(expr))), + Expr::IsUnknown(expr) => self + .coerce_bool_op(*expr)? + .map_data(|expr| Ok(is_unknown(expr))), + Expr::IsNotUnknown(expr) => self + .coerce_bool_op(*expr)? + .map_data(|expr| Ok(is_not_unknown(expr))), Expr::Like(Like { negated, expr, @@ -209,28 +251,32 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { "There isn't a common type to coerce {left_type} and {right_type} in {op_name} expression" ) })?; - let expr = Box::new(expr.cast_to(&coerced_type, self.schema)?); - let pattern = Box::new(pattern.cast_to(&coerced_type, self.schema)?); - Ok(Transformed::yes(Expr::Like(Like::new( + let mut caster = self.caster(); + + let expr = + Box::new(caster.maybe_cast_from(*expr, &left_type, &coerced_type)?); + let pattern = Box::new(caster.maybe_cast_from( + *pattern, + &right_type, + &coerced_type, + )?); + caster.build(Expr::Like(Like::new( negated, expr, pattern, escape_char, case_insensitive, - )))) - } - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - let (left_type, right_type) = get_input_types( - &left.get_type(self.schema)?, - &op, - &right.get_type(self.schema)?, - )?; - Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr::new( - Box::new(left.cast_to(&left_type, self.schema)?), - op, - Box::new(right.cast_to(&right_type, self.schema)?), - )))) + ))) } + Expr::BinaryExpr(BinaryExpr { left, op, right }) => self + .coerce_binary_op(*left, op, *right)? + .map_data(|(left, right)| { + Ok(Expr::BinaryExpr(BinaryExpr::new( + Box::new(left), + op, + Box::new(right), + ))) + }), Expr::Between(Between { expr, negated, @@ -259,12 +305,16 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { "Failed to coerce types {expr_type} and {high_type} in BETWEEN expression" )) })?; - Ok(Transformed::yes(Expr::Between(Between::new( - Box::new(expr.cast_to(&coercion_type, self.schema)?), + let mut caster = self.caster(); + let expr = caster.maybe_cast_from(*expr, &expr_type, &coercion_type)?; + let low = caster.maybe_cast_from(*low, &low_type, &coercion_type)?; + let high = caster.maybe_cast_from(*high, &high_type, &coercion_type)?; + caster.build(Expr::Between(Between::new( + Box::new(expr), negated, - Box::new(low.cast_to(&coercion_type, self.schema)?), - Box::new(high.cast_to(&coercion_type, self.schema)?), - )))) + Box::new(low), + Box::new(high), + ))) } Expr::InList(InList { expr, @@ -278,43 +328,37 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { .collect::>>()?; let result_type = get_coerce_type_for_list(&expr_data_type, &list_data_types); - match result_type { - None => plan_err!( + let Some(coerced_type) = result_type else { + return plan_err!( "Can not find compatible types to compare {expr_data_type:?} with {list_data_types:?}" - ), - Some(coerced_type) => { - // find the coerced type - let cast_expr = expr.cast_to(&coerced_type, self.schema)?; - let cast_list_expr = list - .into_iter() - .map(|list_expr| { - list_expr.cast_to(&coerced_type, self.schema) - }) - .collect::>>()?; - Ok(Transformed::yes(Expr::InList(InList ::new( - Box::new(cast_expr), - cast_list_expr, - negated, - )))) - } - } - } - Expr::Case(case) => { - let case = coerce_case_expression(case, self.schema)?; - Ok(Transformed::yes(Expr::Case(case))) + ); + }; + // find the coerced type + let mut caster = self.caster(); + let cast_expr = + caster.maybe_cast_from(*expr, &expr_data_type, &coerced_type)?; + let cast_list_expr = list + .into_iter() + .map(|list_expr| caster.maybe_cast(list_expr, &coerced_type)) + .collect::>>()?; + caster.build(Expr::InList(InList::new( + Box::new(cast_expr), + cast_list_expr, + negated, + ))) } + Expr::Case(case) => self + .coerce_case_expression(case)? + .map_data(|case| Ok(Expr::Case(case))), Expr::ScalarFunction(ScalarFunction { func_def, args }) => match func_def { - ScalarFunctionDefinition::UDF(fun) => { - let new_expr = coerce_arguments_for_signature( - args, - self.schema, - fun.signature(), - )?; - let new_expr = coerce_arguments_for_fun(new_expr, self.schema, &fun)?; - Ok(Transformed::yes(Expr::ScalarFunction( - ScalarFunction::new_udf(fun, new_expr), - ))) - } + ScalarFunctionDefinition::UDF(fun) => self + .coerce_arguments_for_signature(args, fun.signature())? + .transform_data(|new_expr| { + self.coerce_arguments_for_fun(new_expr, &fun) + })? + .map_data(|new_expr| { + Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fun, new_expr))) + }), }, Expr::AggregateFunction(expr::AggregateFunction { func_def, @@ -324,41 +368,30 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { order_by, null_treatment, }) => match func_def { - AggregateFunctionDefinition::BuiltIn(fun) => { - let new_expr = coerce_agg_exprs_for_signature( - &fun, - args, - self.schema, - &fun.signature(), - )?; - Ok(Transformed::yes(Expr::AggregateFunction( - expr::AggregateFunction::new( + AggregateFunctionDefinition::BuiltIn(fun) => self + .coerce_agg_exprs_for_signature(&fun, args, &fun.signature())? + .map_data(|new_expr| { + Ok(Expr::AggregateFunction(expr::AggregateFunction::new( fun, new_expr, distinct, filter, order_by, null_treatment, - ), - ))) - } - AggregateFunctionDefinition::UDF(fun) => { - let new_expr = coerce_arguments_for_signature( - args, - self.schema, - fun.signature(), - )?; - Ok(Transformed::yes(Expr::AggregateFunction( - expr::AggregateFunction::new_udf( + ))) + }), + AggregateFunctionDefinition::UDF(fun) => self + .coerce_arguments_for_signature(args, fun.signature())? + .map_data(|new_expr| { + Ok(Expr::AggregateFunction(expr::AggregateFunction::new_udf( fun, new_expr, false, filter, order_by, null_treatment, - ), - ))) - } + ))) + }), AggregateFunctionDefinition::Name(_) => { internal_err!("Function `Expr` with name should be resolved.") } @@ -371,29 +404,29 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { window_frame, null_treatment, }) => { - let window_frame = - coerce_window_frame(window_frame, self.schema, &order_by)?; + let window_frame = self.coerce_window_frame(window_frame, &order_by)?; let args = match &fun { expr::WindowFunctionDefinition::AggregateFunction(fun) => { - coerce_agg_exprs_for_signature( - fun, - args, - self.schema, - &fun.signature(), - )? + self.coerce_agg_exprs_for_signature(fun, args, &fun.signature())? } - _ => args, + _ => Transformed::no(args), }; - - Ok(Transformed::yes(Expr::WindowFunction(WindowFunction::new( + let transformed = window_frame.transformed | args.transformed; + let window_func = Expr::WindowFunction(WindowFunction::new( fun, - args, + args.data, partition_by, order_by, - window_frame, + window_frame.data, null_treatment, - )))) + )); + + if transformed { + Ok(Transformed::yes(window_func)) + } else { + Ok(Transformed::no(window_func)) + } } Expr::Alias(_) | Expr::Column(_) @@ -414,313 +447,512 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { } } } +impl<'a> TypeCoercionExprRewriter<'a> { + /// Casts expr to new_type if it is different from the current type + /// + /// Note: current_type is passed in as an optimization to avoid recomputing the type + fn maybe_cast_from( + &self, + expr: Expr, + cur_type: &DataType, + new_type: &DataType, + ) -> Result> { + let mut caster = self.caster(); + let expr = caster.maybe_cast_from(expr, cur_type, new_type)?; + caster.build(expr) + } + + /// Create a [`Caster`] for tracking when expressions have been changed + fn caster(&self) -> Caster { + Caster::new(self.schema) + } + + /// Coerce the subquery recursively + fn coerce_in_subquery(&self, subquery: Subquery) -> Result> { + let Subquery { + subquery, + outer_ref_columns, + } = subquery; + + // use TypeCoercionPlanRewriter rather than calling analyze_internal + // directly to pass along info if any of the subquery inputs were + // transformed + let mut subquery_rewriter = TypeCoercionPlanRewriter::new(self.schema); + unwrap_arc(subquery) + .rewrite_with_subqueries(&mut subquery_rewriter)? + .map_data(|new_plan| { + Ok(Subquery { + subquery: Arc::new(new_plan), + outer_ref_columns, + }) + }) + } -/// Casts the given `value` to `target_type`. Note that this function -/// only considers `Null` or `Utf8` values. -fn coerce_scalar(target_type: &DataType, value: &ScalarValue) -> Result { - match value { - // Coerce Utf8 values: - ScalarValue::Utf8(Some(val)) => { - ScalarValue::try_from_string(val.clone(), target_type) + /// Coerce join equality expressions + /// + /// Joins must be treated specially as their equality expressions are stored + /// as a parallel list of left and right expressions, rather than a single + /// equality expression + /// + /// For example, on_exprs like `t1.a = t2.b AND t1.x = t2.y` will be stored + /// as a list of `(t1.a, t2.b), (t1.x, t2.y)` + fn coerce_joins(&mut self, plan: LogicalPlan) -> Result> { + let LogicalPlan::Join(Join { + left, + right, + on, + filter, + join_type, + join_constraint, + schema, + null_equals_null, + }) = plan + else { + return Ok(Transformed::no(plan)); + }; + + // apply the coercion to each equality expression + let on = on.into_iter().map_until_stop_and_collect(|(lhs, rhs)| { + // coerce the arguments as though they were a single binary equality + // expression + self.coerce_binary_op(lhs, Operator::Eq, rhs) + })?; + + // pass the transformed flag back up + on.map_data(|on| { + Ok(LogicalPlan::Join(Join { + left, + right, + on, + filter, + join_type, + join_constraint, + schema, + null_equals_null, + })) + }) + } + fn coerce_binary_op( + &self, + left: Expr, + op: Operator, + right: Expr, + ) -> Result> { + let initial_left_type = left.get_type(self.schema)?; + let initial_right_type = right.get_type(self.schema)?; + + let (left_type, right_type) = + get_input_types(&initial_left_type, &op, &initial_right_type)?; + + let mut caster = Caster::new(self.schema); + let result = ( + caster.maybe_cast_from(left, &initial_left_type, &left_type)?, + caster.maybe_cast_from(right, &initial_right_type, &right_type)?, + ); + caster.build(result) + } + + /// Support the `IsTrue` `IsNotTrue` `IsFalse` `IsNotFalse` type coercion by coercing + /// the input expr to `Boolean` type. + /// The above Exprs will be rewrite to the binary op when creating the physical op. + fn coerce_bool_op(&self, expr: Expr) -> Result> { + let left_type = expr.get_type(self.schema)?; + // error check + get_input_types(&left_type, &Operator::IsDistinctFrom, &DataType::Boolean)?; + self.maybe_cast_from(expr, &left_type, &DataType::Boolean) + } + + /// Casts the given `value` to `target_type`. Note that this function + /// only considers `Null` or `Utf8` values. + fn coerce_scalar( + &self, + target_type: &DataType, + value: &ScalarValue, + ) -> Result> { + match value { + // Coerce Utf8 values: + ScalarValue::Utf8(Some(val)) => { + ScalarValue::try_from_string(val.clone(), target_type) + .map(Transformed::yes) + } + s => { + if s.is_null() { + // Coerce `Null` values: + ScalarValue::try_from(target_type).map(Transformed::yes) + } else { + // Values except `Utf8`/`Null` variants already have the right type + // (casted before) since we convert `sqlparser` outputs to `Utf8` + // for all possible cases. Therefore, we return a clone here. + Ok(Transformed::no(s.clone())) + } + } } - s => { - if s.is_null() { - // Coerce `Null` values: - ScalarValue::try_from(target_type) + } + + /// This function coerces `value` to `target_type` in a range-aware fashion. + /// If the coercion is successful, we return an `Ok` value with the result. + /// If the coercion fails because `target_type` is not wide enough (i.e. we + /// can not coerce to `target_type`, but we can to a wider type in the same + /// family), we return a `Null` value of this type to signal this situation. + /// Downstream code uses this signal to treat these values as *unbounded*. + fn coerce_scalar_range_aware( + &self, + target_type: &DataType, + value: ScalarValue, + ) -> Result> { + self.coerce_scalar(target_type, &value).or_else(|err| { + // If type coercion fails, check if the largest type in family works: + if let Some(largest_type) = self.get_widest_type_in_family(target_type) { + self.coerce_scalar(largest_type, &value).map_or_else( + |_| exec_err!("Cannot cast {value:?} to {target_type:?}"), + |_| ScalarValue::try_from(target_type).map(Transformed::yes), + ) } else { - // Values except `Utf8`/`Null` variants already have the right type - // (casted before) since we convert `sqlparser` outputs to `Utf8` - // for all possible cases. Therefore, we return a clone here. - Ok(s.clone()) + Err(err) } - } + }) } -} -/// This function coerces `value` to `target_type` in a range-aware fashion. -/// If the coercion is successful, we return an `Ok` value with the result. -/// If the coercion fails because `target_type` is not wide enough (i.e. we -/// can not coerce to `target_type`, but we can to a wider type in the same -/// family), we return a `Null` value of this type to signal this situation. -/// Downstream code uses this signal to treat these values as *unbounded*. -fn coerce_scalar_range_aware( - target_type: &DataType, - value: ScalarValue, -) -> Result { - coerce_scalar(target_type, &value).or_else(|err| { - // If type coercion fails, check if the largest type in family works: - if let Some(largest_type) = get_widest_type_in_family(target_type) { - coerce_scalar(largest_type, &value).map_or_else( - |_| exec_err!("Cannot cast {value:?} to {target_type:?}"), - |_| ScalarValue::try_from(target_type), - ) - } else { - Err(err) + /// This function returns the widest type in the family of `given_type`. + /// If the given type is already the widest type, it returns `None`. + /// For example, if `given_type` is `Int8`, it returns `Int64`. + fn get_widest_type_in_family(&self, given_type: &DataType) -> Option<&DataType> { + match given_type { + DataType::UInt8 | DataType::UInt16 | DataType::UInt32 => { + Some(&DataType::UInt64) + } + DataType::Int8 | DataType::Int16 | DataType::Int32 => Some(&DataType::Int64), + DataType::Float16 | DataType::Float32 => Some(&DataType::Float64), + _ => None, } - }) -} - -/// This function returns the widest type in the family of `given_type`. -/// If the given type is already the widest type, it returns `None`. -/// For example, if `given_type` is `Int8`, it returns `Int64`. -fn get_widest_type_in_family(given_type: &DataType) -> Option<&DataType> { - match given_type { - DataType::UInt8 | DataType::UInt16 | DataType::UInt32 => Some(&DataType::UInt64), - DataType::Int8 | DataType::Int16 | DataType::Int32 => Some(&DataType::Int64), - DataType::Float16 | DataType::Float32 => Some(&DataType::Float64), - _ => None, } -} -/// Coerces the given (window frame) `bound` to `target_type`. -fn coerce_frame_bound( - target_type: &DataType, - bound: WindowFrameBound, -) -> Result { - match bound { - WindowFrameBound::Preceding(v) => { - coerce_scalar_range_aware(target_type, v).map(WindowFrameBound::Preceding) - } - WindowFrameBound::CurrentRow => Ok(WindowFrameBound::CurrentRow), - WindowFrameBound::Following(v) => { - coerce_scalar_range_aware(target_type, v).map(WindowFrameBound::Following) - } + /// Coerces the given (window frame) `bound` to `target_type`. + fn coerce_frame_bound( + &self, + target_type: &DataType, + bound: WindowFrameBound, + ) -> Result> { + Ok(match bound { + WindowFrameBound::Preceding(v) => self + .coerce_scalar_range_aware(target_type, v)? + .update_data(WindowFrameBound::Preceding), + WindowFrameBound::CurrentRow => Transformed::no(WindowFrameBound::CurrentRow), + WindowFrameBound::Following(v) => self + .coerce_scalar_range_aware(target_type, v)? + .update_data(WindowFrameBound::Following), + }) } -} -// Coerces the given `window_frame` to use appropriate natural types. -// For example, ROWS and GROUPS frames use `UInt64` during calculations. -fn coerce_window_frame( - window_frame: WindowFrame, - schema: &DFSchema, - expressions: &[Expr], -) -> Result { - let mut window_frame = window_frame; - let current_types = expressions - .iter() - .map(|e| e.get_type(schema)) - .collect::>>()?; - let target_type = match window_frame.units { - WindowFrameUnits::Range => { - if let Some(col_type) = current_types.first() { - if col_type.is_numeric() - || is_utf8_or_large_utf8(col_type) - || matches!(col_type, DataType::Null) - { - col_type - } else if is_datetime(col_type) { - &DataType::Interval(IntervalUnit::MonthDayNano) + // Coerces the given `window_frame` to use appropriate natural types. + // For example, ROWS and GROUPS frames use `UInt64` during calculations. + fn coerce_window_frame( + &self, + mut window_frame: WindowFrame, + expressions: &[Expr], + ) -> Result> { + let current_types = expressions + .iter() + .map(|e| e.get_type(self.schema)) + .collect::>>()?; + let target_type = match window_frame.units { + WindowFrameUnits::Range => { + if let Some(col_type) = current_types.first() { + if col_type.is_numeric() + || is_utf8_or_large_utf8(col_type) + || matches!(col_type, DataType::Null) + { + col_type + } else if is_datetime(col_type) { + &DataType::Interval(IntervalUnit::MonthDayNano) + } else { + return internal_err!( + "Cannot run range queries on datatype: {col_type:?}" + ); + } } else { - return internal_err!( - "Cannot run range queries on datatype: {col_type:?}" - ); + return internal_err!("ORDER BY column cannot be empty"); } - } else { - return internal_err!("ORDER BY column cannot be empty"); } - } - WindowFrameUnits::Rows | WindowFrameUnits::Groups => &DataType::UInt64, - }; - window_frame.start_bound = coerce_frame_bound(target_type, window_frame.start_bound)?; - window_frame.end_bound = coerce_frame_bound(target_type, window_frame.end_bound)?; - Ok(window_frame) -} + WindowFrameUnits::Rows | WindowFrameUnits::Groups => &DataType::UInt64, + }; -// Support the `IsTrue` `IsNotTrue` `IsFalse` `IsNotFalse` type coercion. -// The above op will be rewrite to the binary op when creating the physical op. -fn get_casted_expr_for_bool_op(expr: Expr, schema: &DFSchema) -> Result { - let left_type = expr.get_type(schema)?; - get_input_types(&left_type, &Operator::IsDistinctFrom, &DataType::Boolean)?; - expr.cast_to(&DataType::Boolean, schema) -} + let start_bound = + self.coerce_frame_bound(target_type, window_frame.start_bound)?; + let end_bound = self.coerce_frame_bound(target_type, window_frame.end_bound)?; + let transformed = start_bound.transformed | end_bound.transformed; -/// Returns `expressions` coerced to types compatible with -/// `signature`, if possible. -/// -/// See the module level documentation for more detail on coercion. -fn coerce_arguments_for_signature( - expressions: Vec, - schema: &DFSchema, - signature: &Signature, -) -> Result> { - if expressions.is_empty() { - return Ok(expressions); + window_frame.start_bound = start_bound.data; + window_frame.end_bound = end_bound.data; + + if transformed { + Ok(Transformed::yes(window_frame)) + } else { + Ok(Transformed::no(window_frame)) + } } - let current_types = expressions - .iter() - .map(|e| e.get_type(schema)) - .collect::>>()?; + /// Returns `expressions` coerced to types compatible with + /// `signature`, if possible. + /// + /// See the module level documentation for more detail on coercion. + fn coerce_arguments_for_signature( + &self, + expressions: Vec, + signature: &Signature, + ) -> Result>> { + if expressions.is_empty() { + return Ok(Transformed::no(expressions)); + } - let new_types = data_types(¤t_types, signature)?; + let current_types = expressions + .iter() + .map(|e| e.get_type(self.schema)) + .collect::>>()?; - expressions - .into_iter() - .enumerate() - .map(|(i, expr)| expr.cast_to(&new_types[i], schema)) - .collect() -} + let new_types = data_types(¤t_types, signature)?; -fn coerce_arguments_for_fun( - expressions: Vec, - schema: &DFSchema, - fun: &Arc, -) -> Result> { - // Cast Fixedsizelist to List for array functions - if fun.name() == "make_array" { - expressions + let mut caster = self.caster(); + let new_expressions = expressions .into_iter() - .map(|expr| { - let data_type = expr.get_type(schema).unwrap(); - if let DataType::FixedSizeList(field, _) = data_type { - let to_type = DataType::List(field.clone()); - expr.cast_to(&to_type, schema) - } else { - Ok(expr) - } + .enumerate() + .map(|(i, expr)| { + caster.maybe_cast_from(expr, ¤t_types[i], &new_types[i]) }) - .collect() - } else { - Ok(expressions) - } -} + .collect::>>()?; -/// Returns the coerced exprs for each `input_exprs`. -/// Get the coerced data type from `aggregate_rule::coerce_types` and add `try_cast` if the -/// data type of `input_exprs` need to be coerced. -fn coerce_agg_exprs_for_signature( - agg_fun: &AggregateFunction, - input_exprs: Vec, - schema: &DFSchema, - signature: &Signature, -) -> Result> { - if input_exprs.is_empty() { - return Ok(input_exprs); + caster.build(new_expressions) } - let current_types = input_exprs - .iter() - .map(|e| e.get_type(schema)) - .collect::>>()?; - - let coerced_types = - type_coercion::aggregates::coerce_types(agg_fun, ¤t_types, signature)?; - - input_exprs - .into_iter() - .enumerate() - .map(|(i, expr)| expr.cast_to(&coerced_types[i], schema)) - .collect() -} -fn coerce_case_expression(case: Case, schema: &DFSchema) -> Result { - // Given expressions like: - // - // CASE a1 - // WHEN a2 THEN b1 - // WHEN a3 THEN b2 - // ELSE b3 - // END - // - // or: - // - // CASE - // WHEN x1 THEN b1 - // WHEN x2 THEN b2 - // ELSE b3 - // END - // - // Then all aN (a1, a2, a3) must be converted to a common data type in the first example - // (case-when expression coercion) - // - // All xN (x1, x2) must be converted to a boolean data type in the second example - // (when-boolean expression coercion) - // - // And all bN (b1, b2, b3) must be converted to a common data type in both examples - // (then-else expression coercion) - // - // If any fail to find and cast to a common/specific data type, will return error - // - // Note that case-when and when-boolean expression coercions are mutually exclusive - // Only one or the other can occur for a case expression, whilst then-else expression coercion will always occur - - // prepare types - let case_type = case - .expr - .as_ref() - .map(|expr| expr.get_type(schema)) - .transpose()?; - let then_types = case - .when_then_expr - .iter() - .map(|(_when, then)| then.get_type(schema)) - .collect::>>()?; - let else_type = case - .else_expr - .as_ref() - .map(|expr| expr.get_type(schema)) - .transpose()?; - - // find common coercible types - let case_when_coerce_type = case_type - .as_ref() - .map(|case_type| { - let when_types = case - .when_then_expr - .iter() - .map(|(when, _then)| when.get_type(schema)) + fn coerce_arguments_for_fun( + &self, + expressions: Vec, + fun: &Arc, + ) -> Result>> { + // Cast Fixedsizelist to List for array functions + if fun.name() == "make_array" { + let mut caster = self.caster(); + let new_expressions = expressions + .into_iter() + .map(|expr| { + let data_type = expr.get_type(self.schema)?; + if let DataType::FixedSizeList(field, _) = &data_type { + let to_type = DataType::List(field.clone()); + caster.maybe_cast_from(expr, &data_type, &to_type) + } else { + Ok(expr) + } + }) .collect::>>()?; - let coerced_type = - get_coerce_type_for_case_expression(&when_types, Some(case_type)); - coerced_type.ok_or_else(|| { - plan_datafusion_err!( + // Hack: force the schema to be computed again even if seemingly + // nothing changed + // caster.build(new_expressions) + Ok(Transformed::yes(new_expressions)) + } else { + Ok(Transformed::no(expressions)) + } + } + + /// Returns the coerced exprs for each `input_exprs`. + /// Get the coerced data type from `aggregate_rule::coerce_types` and add `try_cast` if the + /// data type of `input_exprs` need to be coerced. + fn coerce_agg_exprs_for_signature( + &self, + agg_fun: &AggregateFunction, + input_exprs: Vec, + signature: &Signature, + ) -> Result>> { + if input_exprs.is_empty() { + return Ok(Transformed::no(input_exprs)); + } + let current_types = input_exprs + .iter() + .map(|e| e.get_type(self.schema)) + .collect::>>()?; + + let coerced_types = + type_coercion::aggregates::coerce_types(agg_fun, ¤t_types, signature)?; + + let mut caster = Caster::new(self.schema); + let new_exprs = input_exprs + .into_iter() + .enumerate() + .map(|(i, expr)| { + caster.maybe_cast_from(expr, ¤t_types[i], &coerced_types[i]) + }) + .collect::>>()?; + caster.build(new_exprs) + } + + /// Given expressions like: + /// + /// ```sql + /// CASE a1 + /// WHEN a2 THEN b1 + /// WHEN a3 THEN b2 + /// ELSE b3 + /// END + /// ``` + /// + /// or: + /// + /// ```sql + /// CASE + /// WHEN x1 THEN b1 + /// WHEN x2 THEN b2 + /// ELSE b3 + /// END + /// ``` + /// + /// Then all aN (a1, a2, a3) must be converted to a common data type in the first example + /// (case-when expression coercion) + /// + /// All xN (x1, x2) must be converted to a boolean data type in the second example + /// (when-boolean expression coercion) + /// + /// And all bN (b1, b2, b3) must be converted to a common data type in both examples + /// (then-else expression coercion) + /// + /// If any fail to find and cast to a common/specific data type, will return error + /// + /// Note that case-when and when-boolean expression coercions are mutually + /// exclusive Only one or the other can occur for a case expression, whilst + /// then-else expression coercion will always occur + fn coerce_case_expression(&self, case: Case) -> Result> { + // prepare types + let case_type = case + .expr + .as_ref() + .map(|expr| expr.get_type(self.schema)) + .transpose()?; + let then_types = case + .when_then_expr + .iter() + .map(|(_when, then)| then.get_type(self.schema)) + .collect::>>()?; + let else_type = case + .else_expr + .as_ref() + .map(|expr| expr.get_type(self.schema)) + .transpose()?; + + // find common coercible types + let case_when_coerce_type = case_type + .as_ref() + .map(|case_type| { + let when_types = case + .when_then_expr + .iter() + .map(|(when, _then)| when.get_type(self.schema)) + .collect::>>()?; + let coerced_type = + get_coerce_type_for_case_expression(&when_types, Some(case_type)); + coerced_type.ok_or_else(|| { + plan_datafusion_err!( "Failed to coerce case ({case_type:?}) and when ({when_types:?}) \ to common types in CASE WHEN expression" ) + }) }) - }) - .transpose()?; - let then_else_coerce_type = - get_coerce_type_for_case_expression(&then_types, else_type.as_ref()).ok_or_else( - || { - plan_datafusion_err!( + .transpose()?; + let then_else_coerce_type = + get_coerce_type_for_case_expression(&then_types, else_type.as_ref()) + .ok_or_else(|| { + plan_datafusion_err!( "Failed to coerce then ({then_types:?}) and else ({else_type:?}) \ to common types in CASE WHEN expression" ) - }, - )?; - - // do cast if found common coercible types - let case_expr = case - .expr - .zip(case_when_coerce_type.as_ref()) - .map(|(case_expr, coercible_type)| case_expr.cast_to(coercible_type, schema)) - .transpose()? - .map(Box::new); - let when_then = case - .when_then_expr - .into_iter() - .map(|(when, then)| { - let when_type = case_when_coerce_type.as_ref().unwrap_or(&DataType::Boolean); - let when = when.cast_to(when_type, schema).map_err(|e| { - DataFusionError::Context( - format!( - "WHEN expressions in CASE couldn't be \ + })?; + + // do cast if found common coercible types + let mut caster = Caster::new(self.schema); + let case_expr = case + .expr + .zip(case_when_coerce_type.as_ref()) + .map(|(case_expr, coercible_type)| { + caster.maybe_cast(*case_expr, coercible_type) + }) + .transpose()? + .map(Box::new); + let when_then = case + .when_then_expr + .into_iter() + .map(|(when, then)| { + let when_type = + case_when_coerce_type.as_ref().unwrap_or(&DataType::Boolean); + let when = caster.maybe_cast(*when, when_type).map_err(|e| { + DataFusionError::Context( + format!( + "WHEN expressions in CASE couldn't be \ converted to common type ({when_type})" - ), - Box::new(e), - ) - })?; - let then = then.cast_to(&then_else_coerce_type, schema)?; - Ok((Box::new(when), Box::new(then))) + ), + Box::new(e), + ) + })?; + let then = caster.maybe_cast(*then, &then_else_coerce_type)?; + Ok((Box::new(when), Box::new(then))) + }) + .collect::>>()?; + let else_expr = case + .else_expr + .map(|expr| caster.maybe_cast(*expr, &then_else_coerce_type)) + .transpose()? + .map(Box::new); + + caster.build(Case::new(case_expr, when_then, else_expr)) + } +} + +/// Casts exprs to new types, tracking if any types have been changed +/// +/// This is used when casting multiple expressions to track if any have been +/// transformed to set the transformed flag correctly +#[derive(Debug)] +struct Caster<'a> { + schema: &'a DFSchema, + transformed: bool, +} + +impl<'a> Caster<'a> { + fn new(schema: &'a DFSchema) -> Self { + Self { + schema, + transformed: false, + } + } + + /// cast expr to new_type if it is different from the current type + fn maybe_cast(&mut self, expr: Expr, new_type: &DataType) -> Result { + let original_type = expr.get_type(self.schema)?; + self.maybe_cast_from(expr, &original_type, new_type) + } + + /// Casts expr to new_type if it is different from the current type + /// + /// Note: current_type is passed in as an optimziation to avoid recomputing the type + fn maybe_cast_from( + &mut self, + expr: Expr, + cur_type: &DataType, + new_type: &DataType, + ) -> Result { + if cur_type == new_type { + Ok(expr) + } else { + self.transformed = true; + expr.cast_to(new_type, self.schema) + } + } + + /// Returns a Transformed::yes/Transformed::no based on whether any expr has + /// been transformed + fn build(self, arg: T) -> Result> { + Ok(if self.transformed { + Transformed::yes(arg) + } else { + Transformed::no(arg) }) - .collect::>>()?; - let else_expr = case - .else_expr - .map(|expr| expr.cast_to(&then_else_coerce_type, schema)) - .transpose()? - .map(Box::new); - - Ok(Case::new(case_expr, when_then, else_expr)) + } } #[cfg(test)] @@ -744,7 +976,7 @@ mod test { use datafusion_physical_expr::expressions::AvgAccumulator; use crate::analyzer::type_coercion::{ - coerce_case_expression, TypeCoercion, TypeCoercionRewriter, + coerce_case_expression, TypeCoercion, TypeCoercionExprRewriter, }; use crate::test::assert_analyzed_plan_eq; @@ -1235,7 +1467,7 @@ mod test { vec![Field::new("a", DataType::Int64, true)].into(), std::collections::HashMap::new(), )?); - let mut rewriter = TypeCoercionRewriter { schema: &schema }; + let mut rewriter = TypeCoercionExprRewriter { schema: &schema }; let expr = is_true(lit(12i32).gt(lit(13i64))); let expected = is_true(cast(lit(12i32), DataType::Int64).gt(lit(13i64))); let result = expr.rewrite(&mut rewriter).data()?; @@ -1246,7 +1478,7 @@ mod test { vec![Field::new("a", DataType::Int64, true)].into(), std::collections::HashMap::new(), )?); - let mut rewriter = TypeCoercionRewriter { schema: &schema }; + let mut rewriter = TypeCoercionExprRewriter { schema: &schema }; let expr = is_true(lit(12i32).eq(lit(13i64))); let expected = is_true(cast(lit(12i32), DataType::Int64).eq(lit(13i64))); let result = expr.rewrite(&mut rewriter).data()?; @@ -1257,7 +1489,7 @@ mod test { vec![Field::new("a", DataType::Int64, true)].into(), std::collections::HashMap::new(), )?); - let mut rewriter = TypeCoercionRewriter { schema: &schema }; + let mut rewriter = TypeCoercionExprRewriter { schema: &schema }; let expr = is_true(lit(12i32).lt(lit(13i64))); let expected = is_true(cast(lit(12i32), DataType::Int64).lt(lit(13i64))); let result = expr.rewrite(&mut rewriter).data()?; diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 4d7a207afb1b..b6979f45a730 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -41,7 +41,7 @@ use datafusion_expr::{ use datafusion_expr::{expr::ScalarFunction, interval_arithmetic::NullableInterval}; use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionProps}; -use crate::analyzer::type_coercion::TypeCoercionRewriter; +use crate::analyzer::type_coercion::TypeCoercionExprRewriter; use crate::simplify_expressions::guarantees::GuaranteeRewriter; use crate::simplify_expressions::regex::simplify_regex_expr; use crate::simplify_expressions::SimplifyInfo; @@ -207,7 +207,7 @@ impl ExprSimplifier { /// See the [type coercion module](datafusion_expr::type_coercion) /// documentation for more details on type coercion pub fn coerce(&self, expr: Expr, schema: &DFSchema) -> Result { - let mut expr_rewrite = TypeCoercionRewriter { schema }; + let mut expr_rewrite = TypeCoercionExprRewriter { schema }; expr.rewrite(&mut expr_rewrite).data() }