diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs index 923be75748037..9d871c50ad996 100644 --- a/datafusion/optimizer/src/eliminate_cross_join.rs +++ b/datafusion/optimizer/src/eliminate_cross_join.rs @@ -18,11 +18,13 @@ //! [`EliminateCrossJoin`] converts `CROSS JOIN` to `INNER JOIN` if join predicates are available. use std::sync::Arc; -use crate::{utils, OptimizerConfig, OptimizerRule}; +use crate::{OptimizerConfig, OptimizerRule}; use crate::join_key_set::JoinKeySet; -use datafusion_common::{plan_err, Result}; +use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::{internal_err, Result}; use datafusion_expr::expr::{BinaryExpr, Expr}; +use datafusion_expr::logical_plan::tree_node::unwrap_arc; use datafusion_expr::logical_plan::{ CrossJoin, Filter, Join, JoinConstraint, JoinType, LogicalPlan, Projection, }; @@ -39,65 +41,109 @@ impl EliminateCrossJoin { } } -/// Attempt to reorder join to eliminate cross joins to inner joins. -/// for queries: -/// 'select ... from a, b where a.x = b.y and b.xx = 100;' -/// 'select ... from a, b where (a.x = b.y and b.xx = 100) or (a.x = b.y and b.xx = 200);' -/// 'select ... from a, b, c where (a.x = b.y and b.xx = 100 and a.z = c.z) -/// or (a.x = b.y and b.xx = 200 and a.z=c.z);' -/// 'select ... from a, b where a.x > b.y' +/// Eliminate cross joins by rewriting them to inner joins when possible. +/// +/// # Example +/// The initial plan for this query: +/// ```sql +/// select ... from a, b where a.x = b.y and b.xx = 100; +/// ``` +/// +/// Looks like this: +/// ```text +/// Filter(a.x = b.y AND b.xx = 100) +/// CrossJoin +/// TableScan a +/// TableScan b +/// ``` +/// +/// After the rule is applied, the plan will look like this: +/// ```text +/// Filter(b.xx = 100) +/// InnerJoin(a.x = b.y) +/// TableScan a +/// TableScan b +/// ``` +/// +/// # Other Examples +/// * 'select ... from a, b where a.x = b.y and b.xx = 100;' +/// * 'select ... from a, b where (a.x = b.y and b.xx = 100) or (a.x = b.y and b.xx = 200);' +/// * 'select ... from a, b, c where (a.x = b.y and b.xx = 100 and a.z = c.z) +/// * or (a.x = b.y and b.xx = 200 and a.z=c.z);' +/// * 'select ... from a, b where a.x > b.y' +/// /// For above queries, the join predicate is available in filters and they are moved to /// join nodes appropriately +/// /// This fix helps to improve the performance of TPCH Q19. issue#78 impl OptimizerRule for EliminateCrossJoin { fn try_optimize( &self, - plan: &LogicalPlan, - config: &dyn OptimizerConfig, + _plan: &LogicalPlan, + _config: &dyn OptimizerConfig, ) -> Result> { + internal_err!("Should have called EliminateCrossJoin::rewrite") + } + + fn supports_rewrite(&self) -> bool { + true + } + + fn rewrite( + &self, + plan: LogicalPlan, + config: &dyn OptimizerConfig, + ) -> Result> { + let plan_schema = plan.schema().clone(); let mut possible_join_keys = JoinKeySet::new(); let mut all_inputs: Vec = vec![]; - let parent_predicate = match plan { - LogicalPlan::Filter(filter) => { - let input = filter.input.as_ref(); - match input { - LogicalPlan::Join(Join { - join_type: JoinType::Inner, - .. - }) - | LogicalPlan::CrossJoin(_) => { - if !try_flatten_join_inputs( - input, - &mut possible_join_keys, - &mut all_inputs, - )? { - return Ok(None); - } - extract_possible_join_keys( - &filter.predicate, - &mut possible_join_keys, - ); - Some(&filter.predicate) - } - _ => { - return utils::optimize_children(self, plan, config); - } - } + + let parent_predicate = if let LogicalPlan::Filter(filter) = plan { + // if input isn't a join that can potentially be rewritten + // avoid unwrapping the input + let rewriteable = matches!( + filter.input.as_ref(), + LogicalPlan::Join(Join { + join_type: JoinType::Inner, + .. + }) | LogicalPlan::CrossJoin(_) + ); + + if !rewriteable { + // recursively try to rewrite children + return rewrite_children(self, LogicalPlan::Filter(filter), config); } + + if !can_flatten_join_inputs(&filter.input) { + return Ok(Transformed::no(LogicalPlan::Filter(filter))); + } + + let Filter { + input, predicate, .. + } = filter; + flatten_join_inputs( + unwrap_arc(input), + &mut possible_join_keys, + &mut all_inputs, + )?; + + extract_possible_join_keys(&predicate, &mut possible_join_keys); + Some(predicate) + } else if matches!( + plan, LogicalPlan::Join(Join { join_type: JoinType::Inner, .. - }) => { - if !try_flatten_join_inputs( - plan, - &mut possible_join_keys, - &mut all_inputs, - )? { - return Ok(None); - } - None + }) + ) { + if !can_flatten_join_inputs(&plan) { + return Ok(Transformed::no(plan)); } - _ => return utils::optimize_children(self, plan, config), + flatten_join_inputs(plan, &mut possible_join_keys, &mut all_inputs)?; + None + } else { + // recursively try to rewrite children + return rewrite_children(self, plan, config); }; // Join keys are handled locally: @@ -105,36 +151,36 @@ impl OptimizerRule for EliminateCrossJoin { let mut left = all_inputs.remove(0); while !all_inputs.is_empty() { left = find_inner_join( - &left, + left, &mut all_inputs, &possible_join_keys, &mut all_join_keys, )?; } - left = utils::optimize_children(self, &left, config)?.unwrap_or(left); + left = rewrite_children(self, left, config)?.data; - if plan.schema() != left.schema() { + if &plan_schema != left.schema() { left = LogicalPlan::Projection(Projection::new_from_schema( Arc::new(left), - plan.schema().clone(), + plan_schema.clone(), )); } let Some(predicate) = parent_predicate else { - return Ok(Some(left)); + return Ok(Transformed::yes(left)); }; // If there are no join keys then do nothing: if all_join_keys.is_empty() { - Filter::try_new(predicate.clone(), Arc::new(left)) - .map(|f| Some(LogicalPlan::Filter(f))) + Filter::try_new(predicate, Arc::new(left)) + .map(|filter| Transformed::yes(LogicalPlan::Filter(filter))) } else { // Remove join expressions from filter: - match remove_join_expressions(predicate.clone(), &all_join_keys) { + match remove_join_expressions(predicate, &all_join_keys) { Some(filter_expr) => Filter::try_new(filter_expr, Arc::new(left)) - .map(|f| Some(LogicalPlan::Filter(f))), - _ => Ok(Some(left)), + .map(|filter| Transformed::yes(LogicalPlan::Filter(filter))), + _ => Ok(Transformed::yes(left)), } } } @@ -144,49 +190,89 @@ impl OptimizerRule for EliminateCrossJoin { } } +fn rewrite_children( + optimizer: &impl OptimizerRule, + plan: LogicalPlan, + config: &dyn OptimizerConfig, +) -> Result> { + let transformed_plan = plan.map_children(|input| optimizer.rewrite(input, config))?; + + // recompute schema if the plan was transformed + if transformed_plan.transformed { + transformed_plan.map_data(|plan| plan.recompute_schema()) + } else { + Ok(transformed_plan) + } +} + /// Recursively accumulate possible_join_keys and inputs from inner joins /// (including cross joins). /// -/// Returns a boolean indicating whether the flattening was successful. -fn try_flatten_join_inputs( - plan: &LogicalPlan, +/// Assumes can_flatten_join_inputs has returned true and thus the plan can be +/// flattened. Adds all leaf inputs to `all_inputs` and join_keys to +/// possible_join_keys +fn flatten_join_inputs( + plan: LogicalPlan, possible_join_keys: &mut JoinKeySet, all_inputs: &mut Vec, -) -> Result { - let children = match plan { +) -> Result<()> { + match plan { LogicalPlan::Join(join) if join.join_type == JoinType::Inner => { + // checked in can_flatten_join_inputs if join.filter.is_some() { - // The filter of inner join will lost, skip this rule. - // issue: https://github.com/apache/datafusion/issues/4844 - return Ok(false); + return internal_err!( + "should not have filter in inner join in flatten_join_inputs" + ); } - possible_join_keys.insert_all(join.on.iter()); - vec![&join.left, &join.right] + possible_join_keys.insert_all_owned(join.on); + flatten_join_inputs(unwrap_arc(join.left), possible_join_keys, all_inputs)?; + flatten_join_inputs(unwrap_arc(join.right), possible_join_keys, all_inputs)?; } LogicalPlan::CrossJoin(join) => { - vec![&join.left, &join.right] + flatten_join_inputs(unwrap_arc(join.left), possible_join_keys, all_inputs)?; + flatten_join_inputs(unwrap_arc(join.right), possible_join_keys, all_inputs)?; } _ => { - return plan_err!("flatten_join_inputs just can call join/cross_join"); + all_inputs.push(plan); } }; + Ok(()) +} - for child in children.iter() { - let child = child.as_ref(); +/// Returns true if the plan is a Join or Cross join could be flattened with +/// `flatten_join_inputs` +/// +/// Must stay in sync with `flatten_join_inputs` +fn can_flatten_join_inputs(plan: &LogicalPlan) -> bool { + // can only flatten inner / cross joins + match plan { + LogicalPlan::Join(join) if join.join_type == JoinType::Inner => { + // The filter of inner join will lost, skip this rule. + // issue: https://github.com/apache/datafusion/issues/4844 + if join.filter.is_some() { + return false; + } + } + LogicalPlan::CrossJoin(_) => {} + _ => return false, + }; + + for child in plan.inputs() { match child { LogicalPlan::Join(Join { join_type: JoinType::Inner, .. }) | LogicalPlan::CrossJoin(_) => { - if !try_flatten_join_inputs(child, possible_join_keys, all_inputs)? { - return Ok(false); + if !can_flatten_join_inputs(child) { + return false; } } - _ => all_inputs.push(child.clone()), + // the child is not a join/cross join + _ => (), } } - Ok(true) + true } /// Finds the next to join with the left input plan, @@ -202,7 +288,7 @@ fn try_flatten_join_inputs( /// 1. Removes the first plan from `rights` /// 2. Returns `left_input CROSS JOIN right`. fn find_inner_join( - left_input: &LogicalPlan, + left_input: LogicalPlan, rights: &mut Vec, possible_join_keys: &JoinKeySet, all_join_keys: &mut JoinKeySet, @@ -237,7 +323,7 @@ fn find_inner_join( )?); return Ok(LogicalPlan::Join(Join { - left: Arc::new(left_input.clone()), + left: Arc::new(left_input), right: Arc::new(right_input), join_type: JoinType::Inner, join_constraint: JoinConstraint::On, @@ -259,7 +345,7 @@ fn find_inner_join( )?); Ok(LogicalPlan::CrossJoin(CrossJoin { - left: Arc::new(left_input.clone()), + left: Arc::new(left_input), right: Arc::new(right), schema: join_schema, })) @@ -341,12 +427,12 @@ mod tests { Operator::{And, Or}, }; - fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: Vec<&str>) { + fn assert_optimized_plan_eq(plan: LogicalPlan, expected: Vec<&str>) { + let starting_schema = plan.schema().clone(); let rule = EliminateCrossJoin::new(); - let optimized_plan = rule - .try_optimize(plan, &OptimizerContext::new()) - .unwrap() - .expect("failed to optimize plan"); + let transformed_plan = rule.rewrite(plan, &OptimizerContext::new()).unwrap(); + assert!(transformed_plan.transformed, "failed to optimize plan"); + let optimized_plan = transformed_plan.data; let formatted = optimized_plan.display_indent_schema().to_string(); let actual: Vec<&str> = formatted.trim().lines().collect(); @@ -355,13 +441,13 @@ mod tests { "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" ); - assert_eq!(plan.schema(), optimized_plan.schema()) + assert_eq!(&starting_schema, optimized_plan.schema()) } - fn assert_optimization_rule_fails(plan: &LogicalPlan) { + fn assert_optimization_rule_fails(plan: LogicalPlan) { let rule = EliminateCrossJoin::new(); - let optimized_plan = rule.try_optimize(plan, &OptimizerContext::new()).unwrap(); - assert!(optimized_plan.is_none()); + let transformed_plan = rule.rewrite(plan, &OptimizerContext::new()).unwrap(); + assert!(!transformed_plan.transformed) } #[test] @@ -386,7 +472,7 @@ mod tests { " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -414,7 +500,7 @@ mod tests { " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -441,7 +527,7 @@ mod tests { " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -471,7 +557,7 @@ mod tests { " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -501,7 +587,7 @@ mod tests { " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -527,7 +613,7 @@ mod tests { " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -551,7 +637,7 @@ mod tests { .filter(col("t1.a").gt(lit(15u32)))? .build()?; - assert_optimization_rule_fails(&plan); + assert_optimization_rule_fails(plan); Ok(()) } @@ -598,7 +684,7 @@ mod tests { " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -675,7 +761,7 @@ mod tests { " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -750,7 +836,7 @@ mod tests { " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -825,7 +911,7 @@ mod tests { " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -904,7 +990,7 @@ mod tests { " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -987,7 +1073,7 @@ mod tests { " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -1074,7 +1160,7 @@ mod tests { " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -1100,7 +1186,7 @@ mod tests { " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -1128,7 +1214,7 @@ mod tests { " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -1156,7 +1242,7 @@ mod tests { " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -1184,7 +1270,7 @@ mod tests { " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -1224,7 +1310,7 @@ mod tests { " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } diff --git a/datafusion/optimizer/src/join_key_set.rs b/datafusion/optimizer/src/join_key_set.rs index c47afa012c174..cd8ed382f0690 100644 --- a/datafusion/optimizer/src/join_key_set.rs +++ b/datafusion/optimizer/src/join_key_set.rs @@ -66,20 +66,46 @@ impl JoinKeySet { } } + /// Same as [`Self::insert`] but avoids cloning expression if they + /// are owned + pub fn insert_owned(&mut self, left: Expr, right: Expr) -> bool { + if self.contains(&left, &right) { + false + } else { + self.inner.insert((left, right)); + true + } + } + /// Inserts potentially many join keys into the set, copying only when necessary /// /// returns true if any of the pairs were inserted pub fn insert_all<'a>( &mut self, - iter: impl Iterator, + iter: impl IntoIterator, ) -> bool { let mut inserted = false; - for (left, right) in iter { + for (left, right) in iter.into_iter() { inserted |= self.insert(left, right); } inserted } + /// Same as [`Self::insert_all`] but avoids cloning expressions if they are + /// already owned + /// + /// returns true if any of the pairs were inserted + pub fn insert_all_owned( + &mut self, + iter: impl IntoIterator, + ) -> bool { + let mut inserted = false; + for (left, right) in iter.into_iter() { + inserted |= self.insert_owned(left, right); + } + inserted + } + /// Inserts any join keys that are common to both `s1` and `s2` into self pub fn insert_intersection(&mut self, s1: JoinKeySet, s2: JoinKeySet) { // note can't use inner.intersection as we need to consider both (l, r) @@ -156,6 +182,15 @@ mod test { assert_eq!(set.len(), 2); } + #[test] + fn test_insert_owned() { + let mut set = JoinKeySet::new(); + assert!(set.insert_owned(col("a"), col("b"))); + assert!(set.contains(&col("a"), &col("b"))); + assert!(set.contains(&col("b"), &col("a"))); + assert!(!set.contains(&col("a"), &col("c"))); + } + #[test] fn test_contains() { let mut set = JoinKeySet::new(); @@ -217,18 +252,34 @@ mod test { } #[test] - fn test_insert_many() { + fn test_insert_all() { let mut set = JoinKeySet::new(); // insert (a=b), (b=c), (b=a) - set.insert_all( - vec![ - &(col("a"), col("b")), - &(col("b"), col("c")), - &(col("b"), col("a")), - ] - .into_iter(), - ); + set.insert_all(vec![ + &(col("a"), col("b")), + &(col("b"), col("c")), + &(col("b"), col("a")), + ]); + assert_eq!(set.len(), 2); + assert!(set.contains(&col("a"), &col("b"))); + assert!(set.contains(&col("b"), &col("c"))); + assert!(set.contains(&col("b"), &col("a"))); + + // should not contain (a=c) + assert!(!set.contains(&col("a"), &col("c"))); + } + + #[test] + fn test_insert_all_owned() { + let mut set = JoinKeySet::new(); + + // insert (a=b), (b=c), (b=a) + set.insert_all_owned(vec![ + (col("a"), col("b")), + (col("b"), col("c")), + (col("b"), col("a")), + ]); assert_eq!(set.len(), 2); assert!(set.contains(&col("a"), &col("b"))); assert!(set.contains(&col("b"), &col("c")));