From be296ab4e6075e930fd7c5fb164305204ed3ca2f Mon Sep 17 00:00:00 2001 From: Subhan <68732277+delamarch3@users.noreply.github.com> Date: Thu, 6 Mar 2025 21:04:17 +0000 Subject: [PATCH 1/7] handle columns in with_new_exprs with Join --- datafusion/expr/src/logical_plan/plan.rs | 158 +++++++++++++++++++++-- 1 file changed, 148 insertions(+), 10 deletions(-) diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index c6fd95595233..6151ee4961ec 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -906,9 +906,12 @@ impl LogicalPlan { let equi_expr_count = on.len(); assert!(expr.len() >= equi_expr_count); + let col_pair_count = + expr.iter().filter(|e| matches!(e, Expr::Column(_))).count() / 2; + // Assume that the last expr, if any, // is the filter_expr (non equality predicate from ON clause) - let filter_expr = if expr.len() > equi_expr_count { + let filter_expr = if expr.len() - col_pair_count > equi_expr_count { expr.pop() } else { None @@ -916,18 +919,30 @@ impl LogicalPlan { // The first part of expr is equi-exprs, // and the struct of each equi-expr is like `left-expr = right-expr`. - assert_eq!(expr.len(), equi_expr_count); - let new_on = expr.into_iter().map(|equi_expr| { + assert_eq!(expr.len() - col_pair_count, equi_expr_count); + let mut new_on = Vec::new(); + let mut iter = expr.into_iter(); + while let Some(equi_expr) = iter.next() { // SimplifyExpression rule may add alias to the equi_expr. let unalias_expr = equi_expr.clone().unalias(); - if let Expr::BinaryExpr(BinaryExpr { left, op: Operator::Eq, right }) = unalias_expr { - Ok((*left, *right)) - } else { - internal_err!( - "The front part expressions should be an binary equality expression, actual:{equi_expr}" - ) + match unalias_expr { + Expr::BinaryExpr(BinaryExpr { + left, + op: Operator::Eq, + right, + }) => new_on.push((*left, *right)), + left @ Expr::Column(_) => { + let Some(right) = iter.next() else { + internal_err!("Expected a pair of columns to construct the join on expression")? + }; + + new_on.push((left, right)); + } + _ => internal_err!( + "The front part expressions should be a binary equality expression or a column expression, actual:{equi_expr}" + )? } - }).collect::>>()?; + } Ok(LogicalPlan::Join(Join { left: Arc::new(left), @@ -4630,4 +4645,127 @@ digraph { let parameter_type = params.clone().get(placeholder_value).unwrap().clone(); assert_eq!(parameter_type, None); } + + #[test] + fn test_join_with_new_exprs() -> Result<()> { + fn create_test_join(on: Vec<(Expr, Expr)>, filter: Option) -> LogicalPlan { + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ]); + + let left_schema = DFSchema::try_from_qualified_schema("t1", &schema).unwrap(); + let right_schema = + DFSchema::try_from_qualified_schema("t2", &schema).unwrap(); + + LogicalPlan::Join(Join { + left: Arc::new( + table_scan(Some("t1"), left_schema.as_arrow(), None) + .unwrap() + .build() + .unwrap(), + ), + right: Arc::new( + table_scan(Some("t2"), right_schema.as_arrow(), None) + .unwrap() + .build() + .unwrap(), + ), + on, + filter, + join_type: JoinType::Inner, + join_constraint: JoinConstraint::On, + schema: Arc::new(left_schema.join(&right_schema).unwrap()), + null_equals_null: false, + }) + } + + { + let join = create_test_join(vec![(col("t1.a"), (col("t2.a")))], None); + let join = join + .with_new_exprs( + join.expressions(), + join.inputs().into_iter().map(|x| x.clone()).collect(), + ) + .unwrap(); + let LogicalPlan::Join(join) = join else { + unreachable!() + }; + assert_eq!(join.on, vec![(col("t1.a"), (col("t2.a")))]); + assert_eq!(join.filter, None); + } + + { + let join = create_test_join(vec![], Some(col("t1.a").gt(col("t2.a")))); + let join = join + .with_new_exprs( + join.expressions(), + join.inputs().into_iter().map(|x| x.clone()).collect(), + ) + .unwrap(); + let LogicalPlan::Join(join) = join else { + unreachable!() + }; + assert_eq!(join.on, vec![]); + assert_eq!(join.filter, Some(col("t1.a").gt(col("t2.a")))); + } + + { + let join = create_test_join( + vec![(col("t1.a"), (col("t2.a")))], + Some(col("t1.b").gt(col("t2.b"))), + ); + let join = join + .with_new_exprs( + join.expressions(), + join.inputs().into_iter().map(|x| x.clone()).collect(), + ) + .unwrap(); + let LogicalPlan::Join(join) = join else { + unreachable!() + }; + assert_eq!(join.on, vec![(col("t1.a"), (col("t2.a")))]); + assert_eq!(join.filter, Some(col("t1.b").gt(col("t2.b")))); + } + + { + let join = create_test_join( + vec![(col("t1.a"), (col("t2.a"))), (col("t1.b"), (col("t2.b")))], + None, + ); + let join = join + .with_new_exprs( + vec![ + col("t1.a").eq(col("t2.a")), + col("t1.b"), + col("t2.b"), + lit(true), + ], + join.inputs().into_iter().map(|x| x.clone()).collect(), + ) + .unwrap(); + let LogicalPlan::Join(join) = join else { + unreachable!() + }; + assert_eq!( + join.on, + vec![(col("t1.a"), (col("t2.a"))), (col("t1.b"), (col("t2.b")))] + ); + assert_eq!(join.filter, Some(lit(true))); + } + + { + let join = create_test_join( + vec![(col("t1.a"), (col("t2.a"))), (col("t1.b"), (col("t2.b")))], + None, + ); + let res = join.with_new_exprs( + vec![col("t1.a").eq(col("t2.a")), col("t1.b")], + join.inputs().into_iter().map(|x| x.clone()).collect(), + ); + assert!(res.is_err()); + } + + Ok(()) + } } From a8db03a9001f072383a78e5c6bed5284f359bc3f Mon Sep 17 00:00:00 2001 From: Subhan <68732277+delamarch3@users.noreply.github.com> Date: Thu, 6 Mar 2025 21:06:42 +0000 Subject: [PATCH 2/7] test doesn't return result --- datafusion/expr/src/logical_plan/plan.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 6151ee4961ec..5580f0df9030 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -4647,7 +4647,7 @@ digraph { } #[test] - fn test_join_with_new_exprs() -> Result<()> { + fn test_join_with_new_exprs() { fn create_test_join(on: Vec<(Expr, Expr)>, filter: Option) -> LogicalPlan { let schema = Schema::new(vec![ Field::new("a", DataType::Int32, false), @@ -4765,7 +4765,5 @@ digraph { ); assert!(res.is_err()); } - - Ok(()) } } From 0482e9fb5c67d27bd11526d27b91046b8c40af86 Mon Sep 17 00:00:00 2001 From: Subhan <68732277+delamarch3@users.noreply.github.com> Date: Thu, 6 Mar 2025 21:09:36 +0000 Subject: [PATCH 3/7] take join from result --- datafusion/expr/src/logical_plan/plan.rs | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 5580f0df9030..b204e15687db 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -4682,13 +4682,13 @@ digraph { { let join = create_test_join(vec![(col("t1.a"), (col("t2.a")))], None); - let join = join + let LogicalPlan::Join(join) = join .with_new_exprs( join.expressions(), join.inputs().into_iter().map(|x| x.clone()).collect(), ) - .unwrap(); - let LogicalPlan::Join(join) = join else { + .unwrap() + else { unreachable!() }; assert_eq!(join.on, vec![(col("t1.a"), (col("t2.a")))]); @@ -4697,13 +4697,13 @@ digraph { { let join = create_test_join(vec![], Some(col("t1.a").gt(col("t2.a")))); - let join = join + let LogicalPlan::Join(join) = join .with_new_exprs( join.expressions(), join.inputs().into_iter().map(|x| x.clone()).collect(), ) - .unwrap(); - let LogicalPlan::Join(join) = join else { + .unwrap() + else { unreachable!() }; assert_eq!(join.on, vec![]); @@ -4715,13 +4715,13 @@ digraph { vec![(col("t1.a"), (col("t2.a")))], Some(col("t1.b").gt(col("t2.b"))), ); - let join = join + let LogicalPlan::Join(join) = join .with_new_exprs( join.expressions(), join.inputs().into_iter().map(|x| x.clone()).collect(), ) - .unwrap(); - let LogicalPlan::Join(join) = join else { + .unwrap() + else { unreachable!() }; assert_eq!(join.on, vec![(col("t1.a"), (col("t2.a")))]); @@ -4733,7 +4733,7 @@ digraph { vec![(col("t1.a"), (col("t2.a"))), (col("t1.b"), (col("t2.b")))], None, ); - let join = join + let LogicalPlan::Join(join) = join .with_new_exprs( vec![ col("t1.a").eq(col("t2.a")), @@ -4743,8 +4743,8 @@ digraph { ], join.inputs().into_iter().map(|x| x.clone()).collect(), ) - .unwrap(); - let LogicalPlan::Join(join) = join else { + .unwrap() + else { unreachable!() }; assert_eq!( From bd8c18dc1291667eba752492e1a609091611df83 Mon Sep 17 00:00:00 2001 From: Subhan <68732277+delamarch3@users.noreply.github.com> Date: Thu, 6 Mar 2025 21:51:49 +0000 Subject: [PATCH 4/7] clippy --- datafusion/expr/src/logical_plan/plan.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index b204e15687db..045b2892d911 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -4685,7 +4685,7 @@ digraph { let LogicalPlan::Join(join) = join .with_new_exprs( join.expressions(), - join.inputs().into_iter().map(|x| x.clone()).collect(), + join.inputs().into_iter().cloned().collect(), ) .unwrap() else { @@ -4700,7 +4700,7 @@ digraph { let LogicalPlan::Join(join) = join .with_new_exprs( join.expressions(), - join.inputs().into_iter().map(|x| x.clone()).collect(), + join.inputs().into_iter().cloned().collect(), ) .unwrap() else { @@ -4718,7 +4718,7 @@ digraph { let LogicalPlan::Join(join) = join .with_new_exprs( join.expressions(), - join.inputs().into_iter().map(|x| x.clone()).collect(), + join.inputs().into_iter().cloned().collect(), ) .unwrap() else { @@ -4741,7 +4741,7 @@ digraph { col("t2.b"), lit(true), ], - join.inputs().into_iter().map(|x| x.clone()).collect(), + join.inputs().into_iter().cloned().collect(), ) .unwrap() else { @@ -4761,7 +4761,7 @@ digraph { ); let res = join.with_new_exprs( vec![col("t1.a").eq(col("t2.a")), col("t1.b")], - join.inputs().into_iter().map(|x| x.clone()).collect(), + join.inputs().into_iter().cloned().collect(), ); assert!(res.is_err()); } From 65974e6e11f8d38030c5aeaa7452ee990bbabba0 Mon Sep 17 00:00:00 2001 From: Subhan <68732277+delamarch3@users.noreply.github.com> Date: Fri, 7 Mar 2025 09:54:52 +0000 Subject: [PATCH 5/7] make test fallible --- datafusion/expr/src/logical_plan/plan.rs | 90 +++++++++++------------- 1 file changed, 40 insertions(+), 50 deletions(-) diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 045b2892d911..d474975c1711 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -4647,47 +4647,41 @@ digraph { } #[test] - fn test_join_with_new_exprs() { - fn create_test_join(on: Vec<(Expr, Expr)>, filter: Option) -> LogicalPlan { + fn test_join_with_new_exprs() -> Result<()> { + fn create_test_join( + on: Vec<(Expr, Expr)>, + filter: Option, + ) -> Result { let schema = Schema::new(vec![ Field::new("a", DataType::Int32, false), Field::new("b", DataType::Int32, false), ]); - let left_schema = DFSchema::try_from_qualified_schema("t1", &schema).unwrap(); - let right_schema = - DFSchema::try_from_qualified_schema("t2", &schema).unwrap(); + let left_schema = DFSchema::try_from_qualified_schema("t1", &schema)?; + let right_schema = DFSchema::try_from_qualified_schema("t2", &schema)?; - LogicalPlan::Join(Join { + Ok(LogicalPlan::Join(Join { left: Arc::new( - table_scan(Some("t1"), left_schema.as_arrow(), None) - .unwrap() - .build() - .unwrap(), + table_scan(Some("t1"), left_schema.as_arrow(), None)?.build()?, ), right: Arc::new( - table_scan(Some("t2"), right_schema.as_arrow(), None) - .unwrap() - .build() - .unwrap(), + table_scan(Some("t2"), right_schema.as_arrow(), None)?.build()?, ), on, filter, join_type: JoinType::Inner, join_constraint: JoinConstraint::On, - schema: Arc::new(left_schema.join(&right_schema).unwrap()), + schema: Arc::new(left_schema.join(&right_schema)?), null_equals_null: false, - }) + })) } { - let join = create_test_join(vec![(col("t1.a"), (col("t2.a")))], None); - let LogicalPlan::Join(join) = join - .with_new_exprs( - join.expressions(), - join.inputs().into_iter().cloned().collect(), - ) - .unwrap() + let join = create_test_join(vec![(col("t1.a"), (col("t2.a")))], None)?; + let LogicalPlan::Join(join) = join.with_new_exprs( + join.expressions(), + join.inputs().into_iter().cloned().collect(), + )? else { unreachable!() }; @@ -4696,13 +4690,11 @@ digraph { } { - let join = create_test_join(vec![], Some(col("t1.a").gt(col("t2.a")))); - let LogicalPlan::Join(join) = join - .with_new_exprs( - join.expressions(), - join.inputs().into_iter().cloned().collect(), - ) - .unwrap() + let join = create_test_join(vec![], Some(col("t1.a").gt(col("t2.a"))))?; + let LogicalPlan::Join(join) = join.with_new_exprs( + join.expressions(), + join.inputs().into_iter().cloned().collect(), + )? else { unreachable!() }; @@ -4714,13 +4706,11 @@ digraph { let join = create_test_join( vec![(col("t1.a"), (col("t2.a")))], Some(col("t1.b").gt(col("t2.b"))), - ); - let LogicalPlan::Join(join) = join - .with_new_exprs( - join.expressions(), - join.inputs().into_iter().cloned().collect(), - ) - .unwrap() + )?; + let LogicalPlan::Join(join) = join.with_new_exprs( + join.expressions(), + join.inputs().into_iter().cloned().collect(), + )? else { unreachable!() }; @@ -4732,18 +4722,16 @@ digraph { let join = create_test_join( vec![(col("t1.a"), (col("t2.a"))), (col("t1.b"), (col("t2.b")))], None, - ); - let LogicalPlan::Join(join) = join - .with_new_exprs( - vec![ - col("t1.a").eq(col("t2.a")), - col("t1.b"), - col("t2.b"), - lit(true), - ], - join.inputs().into_iter().cloned().collect(), - ) - .unwrap() + )?; + let LogicalPlan::Join(join) = join.with_new_exprs( + vec![ + col("t1.a").eq(col("t2.a")), + col("t1.b"), + col("t2.b"), + lit(true), + ], + join.inputs().into_iter().cloned().collect(), + )? else { unreachable!() }; @@ -4758,12 +4746,14 @@ digraph { let join = create_test_join( vec![(col("t1.a"), (col("t2.a"))), (col("t1.b"), (col("t2.b")))], None, - ); + )?; let res = join.with_new_exprs( vec![col("t1.a").eq(col("t2.a")), col("t1.b")], join.inputs().into_iter().cloned().collect(), ); assert!(res.is_err()); } + + Ok(()) } } From f97cf6c6cd67bed95a6c52f7964267dd585bc20d Mon Sep 17 00:00:00 2001 From: Subhan <68732277+delamarch3@users.noreply.github.com> Date: Fri, 7 Mar 2025 10:18:23 +0000 Subject: [PATCH 6/7] accept any pair of expression for new_on in with_new_exprs for Join --- datafusion/expr/src/logical_plan/plan.rs | 60 ++++++++---------------- 1 file changed, 20 insertions(+), 40 deletions(-) diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index d474975c1711..2ed019fa18a1 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -903,15 +903,12 @@ impl LogicalPlan { let (left, right) = self.only_two_inputs(inputs)?; let schema = build_join_schema(left.schema(), right.schema(), join_type)?; - let equi_expr_count = on.len(); + let equi_expr_count = on.len() * 2; assert!(expr.len() >= equi_expr_count); - let col_pair_count = - expr.iter().filter(|e| matches!(e, Expr::Column(_))).count() / 2; - // Assume that the last expr, if any, // is the filter_expr (non equality predicate from ON clause) - let filter_expr = if expr.len() - col_pair_count > equi_expr_count { + let filter_expr = if expr.len() > equi_expr_count { expr.pop() } else { None @@ -919,29 +916,16 @@ impl LogicalPlan { // The first part of expr is equi-exprs, // and the struct of each equi-expr is like `left-expr = right-expr`. - assert_eq!(expr.len() - col_pair_count, equi_expr_count); + assert_eq!(expr.len(), equi_expr_count); let mut new_on = Vec::new(); let mut iter = expr.into_iter(); - while let Some(equi_expr) = iter.next() { - // SimplifyExpression rule may add alias to the equi_expr. - let unalias_expr = equi_expr.clone().unalias(); - match unalias_expr { - Expr::BinaryExpr(BinaryExpr { - left, - op: Operator::Eq, - right, - }) => new_on.push((*left, *right)), - left @ Expr::Column(_) => { - let Some(right) = iter.next() else { - internal_err!("Expected a pair of columns to construct the join on expression")? - }; + while let Some(left) = iter.next() { + let Some(right) = iter.next() else { + internal_err!("Expected a pair of expressions to construct the join on expression")? + }; - new_on.push((left, right)); - } - _ => internal_err!( - "The front part expressions should be a binary equality expression or a column expression, actual:{equi_expr}" - )? - } + // SimplifyExpression rule may add alias to the equi_expr. + new_on.push((left.unalias(), right.unalias())); } Ok(LogicalPlan::Join(Join { @@ -3793,7 +3777,8 @@ mod tests { use crate::builder::LogicalTableSource; use crate::logical_plan::table_scan; use crate::{ - col, exists, in_subquery, lit, placeholder, scalar_subquery, GroupingSet, + binary_expr, col, exists, in_subquery, lit, placeholder, scalar_subquery, + GroupingSet, }; use datafusion_common::tree_node::{ @@ -4725,7 +4710,8 @@ digraph { )?; let LogicalPlan::Join(join) = join.with_new_exprs( vec![ - col("t1.a").eq(col("t2.a")), + binary_expr(col("t1.a"), Operator::Plus, lit(1)), + binary_expr(col("t2.a"), Operator::Plus, lit(2)), col("t1.b"), col("t2.b"), lit(true), @@ -4737,23 +4723,17 @@ digraph { }; assert_eq!( join.on, - vec![(col("t1.a"), (col("t2.a"))), (col("t1.b"), (col("t2.b")))] + vec![ + ( + binary_expr(col("t1.a"), Operator::Plus, lit(1)), + binary_expr(col("t2.a"), Operator::Plus, lit(2)) + ), + (col("t1.b"), (col("t2.b"))) + ] ); assert_eq!(join.filter, Some(lit(true))); } - { - let join = create_test_join( - vec![(col("t1.a"), (col("t2.a"))), (col("t1.b"), (col("t2.b")))], - None, - )?; - let res = join.with_new_exprs( - vec![col("t1.a").eq(col("t2.a")), col("t1.b")], - join.inputs().into_iter().cloned().collect(), - ); - assert!(res.is_err()); - } - Ok(()) } } From daf17445610b5822c2a66a745ce7e1d544d08406 Mon Sep 17 00:00:00 2001 From: Subhan <68732277+delamarch3@users.noreply.github.com> Date: Fri, 7 Mar 2025 15:44:25 +0000 Subject: [PATCH 7/7] use with_capacity --- datafusion/expr/src/logical_plan/plan.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 2ed019fa18a1..aa62b4a22b64 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -917,7 +917,7 @@ impl LogicalPlan { // The first part of expr is equi-exprs, // and the struct of each equi-expr is like `left-expr = right-expr`. assert_eq!(expr.len(), equi_expr_count); - let mut new_on = Vec::new(); + let mut new_on = Vec::with_capacity(on.len()); let mut iter = expr.into_iter(); while let Some(left) = iter.next() { let Some(right) = iter.next() else {