Skip to content

Commit fe09f91

Browse files
Handle columns in with_new_exprs with a Join (apache#15055) (#384)
apache#15055 * handle columns in with_new_exprs with Join * test doesn't return result * take join from result * clippy * make test fallible * accept any pair of expressions for new_on in with_new_exprs for Join * use with_capacity Co-authored-by: delamarch3 <[email protected]>
1 parent 724e220 commit fe09f91

File tree

1 file changed

+118
-12
lines changed
  • datafusion/expr/src/logical_plan

1 file changed

+118
-12
lines changed

datafusion/expr/src/logical_plan/plan.rs

Lines changed: 118 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -890,7 +890,7 @@ impl LogicalPlan {
890890
let (left, right) = self.only_two_inputs(inputs)?;
891891
let schema = build_join_schema(left.schema(), right.schema(), join_type)?;
892892

893-
let equi_expr_count = on.len();
893+
let equi_expr_count = on.len() * 2;
894894
assert!(expr.len() >= equi_expr_count);
895895

896896
// Assume that the last expr, if any,
@@ -904,17 +904,16 @@ impl LogicalPlan {
904904
// The first part of expr is equi-exprs,
905905
// and the struct of each equi-expr is like `left-expr = right-expr`.
906906
assert_eq!(expr.len(), equi_expr_count);
907-
let new_on = expr.into_iter().map(|equi_expr| {
907+
let mut new_on = Vec::with_capacity(on.len());
908+
let mut iter = expr.into_iter();
909+
while let Some(left) = iter.next() {
910+
let Some(right) = iter.next() else {
911+
internal_err!("Expected a pair of expressions to construct the join on expression")?
912+
};
913+
908914
// SimplifyExpression rule may add alias to the equi_expr.
909-
let unalias_expr = equi_expr.clone().unalias();
910-
if let Expr::BinaryExpr(BinaryExpr { left, op: Operator::Eq, right }) = unalias_expr {
911-
Ok((*left, *right))
912-
} else {
913-
internal_err!(
914-
"The front part expressions should be an binary equality expression, actual:{equi_expr}"
915-
)
916-
}
917-
}).collect::<Result<Vec<(Expr, Expr)>>>()?;
915+
new_on.push((left.unalias(), right.unalias()));
916+
}
918917

919918
Ok(LogicalPlan::Join(Join {
920919
left: Arc::new(left),
@@ -3516,7 +3515,8 @@ mod tests {
35163515
use crate::builder::LogicalTableSource;
35173516
use crate::logical_plan::table_scan;
35183517
use crate::{
3519-
col, exists, in_subquery, lit, placeholder, scalar_subquery, GroupingSet,
3518+
binary_expr, col, exists, in_subquery, lit, placeholder, scalar_subquery,
3519+
GroupingSet,
35203520
};
35213521

35223522
use datafusion_common::tree_node::{
@@ -4347,4 +4347,110 @@ digraph {
43474347
plan.rewrite_with_subqueries(&mut rewriter).unwrap();
43484348
assert!(!rewriter.filter_found);
43494349
}
4350+
4351+
// Exercises `LogicalPlan::with_new_exprs` on a `Join` node in four scenarios:
// equi-pairs only, filter only, equi-pairs plus filter, and explicitly
// supplied replacement expressions. Each scenario rebuilds the join and
// asserts on the resulting `on` and `filter` fields.
#[test]
4352+
fn test_join_with_new_exprs() -> Result<()> {
4353+
// Builds an inner join (`JoinConstraint::On`) between table scans `t1` and
// `t2`, using the caller-supplied `on` pairs and optional `filter`.
fn create_test_join(
4354+
on: Vec<(Expr, Expr)>,
4355+
filter: Option<Expr>,
4356+
) -> Result<LogicalPlan> {
4357+
// Both sides share the same two Int32 columns, `a` and `b`.
let schema = Schema::new(vec![
4358+
Field::new("a", DataType::Int32, false),
4359+
Field::new("b", DataType::Int32, false),
4360+
]);
4361+
4362+
// Qualify the shared schema once per side so columns can be referenced
// as `t1.<col>` and `t2.<col>`.
let left_schema = DFSchema::try_from_qualified_schema("t1", &schema)?;
4363+
let right_schema = DFSchema::try_from_qualified_schema("t2", &schema)?;
4364+
4365+
Ok(LogicalPlan::Join(Join {
4366+
left: Arc::new(
4367+
table_scan(Some("t1"), left_schema.as_arrow(), None)?.build()?,
4368+
),
4369+
right: Arc::new(
4370+
table_scan(Some("t2"), right_schema.as_arrow(), None)?.build()?,
4371+
),
4372+
on,
4373+
filter,
4374+
join_type: JoinType::Inner,
4375+
join_constraint: JoinConstraint::On,
4376+
// The join's output schema is the left schema joined with the right one.
schema: Arc::new(left_schema.join(&right_schema)?),
4377+
null_equals_null: false,
4378+
}))
4379+
}
4380+
4381+
// Scenario 1: one equi-join pair and no filter. Feeding the join's own
// expressions and inputs back in must leave `on` unchanged and `filter` empty.
{
4382+
let join = create_test_join(vec![(col("t1.a"), (col("t2.a")))], None)?;
4383+
let LogicalPlan::Join(join) = join.with_new_exprs(
4384+
join.expressions(),
4385+
join.inputs().into_iter().cloned().collect(),
4386+
)?
4387+
else {
4388+
unreachable!()
4389+
};
4390+
assert_eq!(join.on, vec![(col("t1.a"), (col("t2.a")))]);
4391+
assert_eq!(join.filter, None);
4392+
}
4393+
4394+
// Scenario 2: no equi-join pairs, only a filter. The round trip must keep
// `on` empty and preserve the filter expression.
{
4395+
let join = create_test_join(vec![], Some(col("t1.a").gt(col("t2.a"))))?;
4396+
let LogicalPlan::Join(join) = join.with_new_exprs(
4397+
join.expressions(),
4398+
join.inputs().into_iter().cloned().collect(),
4399+
)?
4400+
else {
4401+
unreachable!()
4402+
};
4403+
assert_eq!(join.on, vec![]);
4404+
assert_eq!(join.filter, Some(col("t1.a").gt(col("t2.a"))));
4405+
}
4406+
4407+
// Scenario 3: one equi-join pair together with a filter. The round trip
// must preserve both.
{
4408+
let join = create_test_join(
4409+
vec![(col("t1.a"), (col("t2.a")))],
4410+
Some(col("t1.b").gt(col("t2.b"))),
4411+
)?;
4412+
let LogicalPlan::Join(join) = join.with_new_exprs(
4413+
join.expressions(),
4414+
join.inputs().into_iter().cloned().collect(),
4415+
)?
4416+
else {
4417+
unreachable!()
4418+
};
4419+
assert_eq!(join.on, vec![(col("t1.a"), (col("t2.a")))]);
4420+
assert_eq!(join.filter, Some(col("t1.b").gt(col("t2.b"))));
4421+
}
4422+
4423+
// Scenario 4: the join starts with two equi-join pairs and no filter, but
// is rebuilt from five explicit expressions. The assertions below expect
// the first four to be paired up in order as (left, right) `on` entries,
// including arbitrary non-column expressions, and the trailing fifth
// expression to become the filter.
{
4424+
let join = create_test_join(
4425+
vec![(col("t1.a"), (col("t2.a"))), (col("t1.b"), (col("t2.b")))],
4426+
None,
4427+
)?;
4428+
let LogicalPlan::Join(join) = join.with_new_exprs(
4429+
vec![
4430+
binary_expr(col("t1.a"), Operator::Plus, lit(1)),
4431+
binary_expr(col("t2.a"), Operator::Plus, lit(2)),
4432+
col("t1.b"),
4433+
col("t2.b"),
4434+
lit(true),
4435+
],
4436+
join.inputs().into_iter().cloned().collect(),
4437+
)?
4438+
else {
4439+
unreachable!()
4440+
};
4441+
assert_eq!(
4442+
join.on,
4443+
vec![
4444+
(
4445+
binary_expr(col("t1.a"), Operator::Plus, lit(1)),
4446+
binary_expr(col("t2.a"), Operator::Plus, lit(2))
4447+
),
4448+
(col("t1.b"), (col("t2.b")))
4449+
]
4450+
);
4451+
assert_eq!(join.filter, Some(lit(true)));
4452+
}
4453+
4454+
Ok(())
4455+
}
43504456
}

0 commit comments

Comments
 (0)