diff --git a/datafusion/optimizer/src/optimize_unions.rs b/datafusion/optimizer/src/optimize_unions.rs index cfabd512b427..23a6fe95e579 100644 --- a/datafusion/optimizer/src/optimize_unions.rs +++ b/datafusion/optimizer/src/optimize_unions.rs @@ -21,7 +21,7 @@ use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::tree_node::Transformed; use datafusion_common::Result; use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema; -use datafusion_expr::{Distinct, LogicalPlan, Union}; +use datafusion_expr::{Distinct, LogicalPlan, Projection, Union}; use itertools::Itertools; use std::sync::Arc; @@ -105,6 +105,38 @@ fn extract_plans_from_union(plan: Arc) -> Vec { .into_iter() .map(Arc::unwrap_or_clone) .collect::>(), + // While unnesting, unwrap a Projection whose input is a nested Union, + // flatten the inner Union, and push the same Projection down onto + // each of the nested Union’s children. + // + // Example: + // Union { Projection { Union { plan1, plan2 } }, plan3 } + // => Union { Projection { plan1 }, Projection { plan2 }, plan3 } + LogicalPlan::Projection(Projection { + expr, + input, + schema, + .. + }) => match Arc::unwrap_or_clone(input) { + LogicalPlan::Union(Union { inputs, .. }) => inputs + .into_iter() + .map(Arc::unwrap_or_clone) + .map(|plan| { + LogicalPlan::Projection( + Projection::try_new_with_schema( + expr.clone(), + Arc::new(plan), + Arc::clone(&schema), + ) + .unwrap(), + ) + }) + .collect::>(), + + plan => vec![LogicalPlan::Projection( + Projection::try_new_with_schema(expr, Arc::new(plan), schema).unwrap(), + )], + }, plan => vec![plan], } } @@ -331,6 +363,27 @@ mod tests { ") } + #[test] + fn eliminate_nested_union_in_projection() -> Result<()> { + let plan_builder = table_scan(Some("table"), &schema(), None)?; + + let plan = plan_builder + .clone() + .union(plan_builder.clone().build()?)? + .project(vec![col("id").alias("table_id"), col("key"), col("value")])? + .union(plan_builder.build()?)? + .build()?; + + assert_optimized_plan_equal!(plan, @r" + Union + Projection: id AS table_id, key, value + TableScan: table + Projection: id AS table_id, key, value + TableScan: table + TableScan: table + ") + } + #[test] fn eliminate_nested_union_with_type_cast_projection() -> Result<()> { let table_1 = table_scan(