@@ -890,7 +890,7 @@ impl LogicalPlan {
890890 let ( left, right) = self . only_two_inputs ( inputs) ?;
891891 let schema = build_join_schema ( left. schema ( ) , right. schema ( ) , join_type) ?;
892892
893- let equi_expr_count = on. len ( ) ;
893+ let equi_expr_count = on. len ( ) * 2 ;
894894 assert ! ( expr. len( ) >= equi_expr_count) ;
895895
896896 // Assume that the last expr, if any,
@@ -904,17 +904,16 @@ impl LogicalPlan {
904904 // The first part of expr is equi-exprs,
905905 // and each equi-expr is a consecutive pair of expressions: `left-expr` followed by `right-expr`.
906906 assert_eq ! ( expr. len( ) , equi_expr_count) ;
907- let new_on = expr. into_iter ( ) . map ( |equi_expr| {
907+ let mut new_on = Vec :: with_capacity ( on. len ( ) ) ;
908+ let mut iter = expr. into_iter ( ) ;
909+ while let Some ( left) = iter. next ( ) {
910+ let Some ( right) = iter. next ( ) else {
911+ internal_err ! ( "Expected a pair of expressions to construct the join on expression" ) ?
912+ } ;
913+
908914 // SimplifyExpression rule may add alias to the equi_expr.
909- let unalias_expr = equi_expr. clone ( ) . unalias ( ) ;
910- if let Expr :: BinaryExpr ( BinaryExpr { left, op : Operator :: Eq , right } ) = unalias_expr {
911- Ok ( ( * left, * right) )
912- } else {
913- internal_err ! (
914- "The front part expressions should be an binary equality expression, actual:{equi_expr}"
915- )
916- }
917- } ) . collect :: < Result < Vec < ( Expr , Expr ) > > > ( ) ?;
915+ new_on. push ( ( left. unalias ( ) , right. unalias ( ) ) ) ;
916+ }
918917
919918 Ok ( LogicalPlan :: Join ( Join {
920919 left : Arc :: new ( left) ,
@@ -3516,7 +3515,8 @@ mod tests {
35163515 use crate :: builder:: LogicalTableSource ;
35173516 use crate :: logical_plan:: table_scan;
35183517 use crate :: {
3519- col, exists, in_subquery, lit, placeholder, scalar_subquery, GroupingSet ,
3518+ binary_expr, col, exists, in_subquery, lit, placeholder, scalar_subquery,
3519+ GroupingSet ,
35203520 } ;
35213521
35223522 use datafusion_common:: tree_node:: {
@@ -4347,4 +4347,110 @@ digraph {
43474347 plan. rewrite_with_subqueries ( & mut rewriter) . unwrap ( ) ;
43484348 assert ! ( !rewriter. filter_found) ;
43494349 }
4350+
// Exercises `with_new_exprs` on `LogicalPlan::Join` nodes: each case rebuilds a join
// from a list of expressions plus the join's inputs, then checks the resulting
// `on` pairs and `filter`.
4351+ #[ test]
4352+ fn test_join_with_new_exprs ( ) -> Result < ( ) > {
// Helper: builds an inner join (constraint `On`, `null_equals_null: false`) between
// table scans "t1" and "t2". Both sides use the same two-column schema
// (`a`, `b`: non-nullable Int32); `on` and `filter` are passed through unchanged.
4353+ fn create_test_join (
4354+ on : Vec < ( Expr , Expr ) > ,
4355+ filter : Option < Expr > ,
4356+ ) -> Result < LogicalPlan > {
4357+ let schema = Schema :: new ( vec ! [
4358+ Field :: new( "a" , DataType :: Int32 , false ) ,
4359+ Field :: new( "b" , DataType :: Int32 , false ) ,
4360+ ] ) ;
4361+
4362+ let left_schema = DFSchema :: try_from_qualified_schema ( "t1" , & schema) ?;
4363+ let right_schema = DFSchema :: try_from_qualified_schema ( "t2" , & schema) ?;
4364+
4365+ Ok ( LogicalPlan :: Join ( Join {
4366+ left : Arc :: new (
4367+ table_scan ( Some ( "t1" ) , left_schema. as_arrow ( ) , None ) ?. build ( ) ?,
4368+ ) ,
4369+ right : Arc :: new (
4370+ table_scan ( Some ( "t2" ) , right_schema. as_arrow ( ) , None ) ?. build ( ) ?,
4371+ ) ,
4372+ on,
4373+ filter,
4374+ join_type : JoinType :: Inner ,
4375+ join_constraint : JoinConstraint :: On ,
4376+ schema : Arc :: new ( left_schema. join ( & right_schema) ?) ,
4377+ null_equals_null : false ,
4378+ } ) )
4379+ }
4380+
// Case 1: a single equi-join pair and no filter. Feeding the join's own
// `expressions()` back in must reproduce the same `on` and keep `filter` as `None`.
4381+ {
4382+ let join = create_test_join ( vec ! [ ( col( "t1.a" ) , ( col( "t2.a" ) ) ) ] , None ) ?;
4383+ let LogicalPlan :: Join ( join) = join. with_new_exprs (
4384+ join. expressions ( ) ,
4385+ join. inputs ( ) . into_iter ( ) . cloned ( ) . collect ( ) ,
4386+ ) ?
4387+ else {
4388+ unreachable ! ( )
4389+ } ;
4390+ assert_eq ! ( join. on, vec![ ( col( "t1.a" ) , ( col( "t2.a" ) ) ) ] ) ;
4391+ assert_eq ! ( join. filter, None ) ;
4392+ }
4393+
// Case 2: no equi-join pairs, only a filter. The round trip must leave `on` empty
// and preserve the filter expression.
4394+ {
4395+ let join = create_test_join ( vec ! [ ] , Some ( col ( "t1.a" ) . gt ( col ( "t2.a" ) ) ) ) ?;
4396+ let LogicalPlan :: Join ( join) = join. with_new_exprs (
4397+ join. expressions ( ) ,
4398+ join. inputs ( ) . into_iter ( ) . cloned ( ) . collect ( ) ,
4399+ ) ?
4400+ else {
4401+ unreachable ! ( )
4402+ } ;
4403+ assert_eq ! ( join. on, vec![ ] ) ;
4404+ assert_eq ! ( join. filter, Some ( col( "t1.a" ) . gt( col( "t2.a" ) ) ) ) ;
4405+ }
4406+
// Case 3: one equi-join pair together with a filter. Both must survive the round trip.
4407+ {
4408+ let join = create_test_join (
4409+ vec ! [ ( col( "t1.a" ) , ( col( "t2.a" ) ) ) ] ,
4410+ Some ( col ( "t1.b" ) . gt ( col ( "t2.b" ) ) ) ,
4411+ ) ?;
4412+ let LogicalPlan :: Join ( join) = join. with_new_exprs (
4413+ join. expressions ( ) ,
4414+ join. inputs ( ) . into_iter ( ) . cloned ( ) . collect ( ) ,
4415+ ) ?
4416+ else {
4417+ unreachable ! ( )
4418+ } ;
4419+ assert_eq ! ( join. on, vec![ ( col( "t1.a" ) , ( col( "t2.a" ) ) ) ] ) ;
4420+ assert_eq ! ( join. filter, Some ( col( "t1.b" ) . gt( col( "t2.b" ) ) ) ) ;
4421+ }
4422+
// Case 4: replacement expressions are supplied by hand as a flat list of five.
// The asserts below expect the first four to be paired up in order (left, right)
// into `on`, and the trailing `lit(true)` to become the filter, even though the
// original join had no filter.
4423+ {
4424+ let join = create_test_join (
4425+ vec ! [ ( col( "t1.a" ) , ( col( "t2.a" ) ) ) , ( col( "t1.b" ) , ( col( "t2.b" ) ) ) ] ,
4426+ None ,
4427+ ) ?;
4428+ let LogicalPlan :: Join ( join) = join. with_new_exprs (
4429+ vec ! [
4430+ binary_expr( col( "t1.a" ) , Operator :: Plus , lit( 1 ) ) ,
4431+ binary_expr( col( "t2.a" ) , Operator :: Plus , lit( 2 ) ) ,
4432+ col( "t1.b" ) ,
4433+ col( "t2.b" ) ,
4434+ lit( true ) ,
4435+ ] ,
4436+ join. inputs ( ) . into_iter ( ) . cloned ( ) . collect ( ) ,
4437+ ) ?
4438+ else {
4439+ unreachable ! ( )
4440+ } ;
4441+ assert_eq ! (
4442+ join. on,
4443+ vec![
4444+ (
4445+ binary_expr( col( "t1.a" ) , Operator :: Plus , lit( 1 ) ) ,
4446+ binary_expr( col( "t2.a" ) , Operator :: Plus , lit( 2 ) )
4447+ ) ,
4448+ ( col( "t1.b" ) , ( col( "t2.b" ) ) )
4449+ ]
4450+ ) ;
4451+ assert_eq ! ( join. filter, Some ( lit( true ) ) ) ;
4452+ }
4453+
4454+ Ok ( ( ) )
4455+ }
43504456}
0 commit comments