diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index f4dba67f13b54..585d5605e49b0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -66,7 +66,9 @@ object DefaultOptimizer extends Optimizer {
     Batch("Decimal Optimizations", FixedPoint(100),
       DecimalAggregates) ::
     Batch("LocalRelation", FixedPoint(100),
-      ConvertToLocalRelation) :: Nil
+      ConvertToLocalRelation) ::
+    Batch("Join Order Adjustment", FixedPoint(100),
+      AdjustJoinOrderWithEqualConditions) :: Nil
   }
 
 /**
@@ -911,3 +913,123 @@ object RemoveLiteralFromGroupExpressions extends Rule[LogicalPlan] {
       a.copy(groupingExpressions = newGrouping)
   }
 }
+
+/**
+ * Reorders chains of inner joins so that relations linked by an equality predicate are joined
+ * with each other directly, which avoids a cartesian product in the physical plan.
+ *
+ * A tree of consecutive inner joins (optionally topped by a [[Filter]]) is flattened into its
+ * leaf relations and the conjuncts of all filter/join conditions. The join tree is rebuilt
+ * guided by the `EqualTo` conjuncts and every conjunct is kept in one [[Filter]] above it.
+ *
+ * NOTE(review): the rebuilt joins carry no condition themselves; the predicates only become
+ * join conditions if a predicate push-down rule runs after this one - confirm batch order.
+ */
+object AdjustJoinOrderWithEqualConditions extends Rule[LogicalPlan] with PredicateHelper {
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    // Only chains of inner joins are flattened; any other join type is treated as a leaf.
+    case f @ Filter(conds, join @ Join(_, _, Inner, _)) =>
+      val (joins, relations, joinConds) = splitJoinRelationsNodes(join)
+      newOperator(f, splitConjunctivePredicates(conds) ++ joinConds, joins, relations)
+    case join @ Join(_, _, Inner, _) =>
+      val (joins, relations, joinConds) = splitJoinRelationsNodes(join)
+      newOperator(join, joinConds, joins, relations)
+  }
+
+  /**
+   * Flattens the tree of consecutive inner joins rooted at `join`, visiting it breadth-first.
+   *
+   * @return the inner joins in visiting order, the leaf plans found below them (anything that
+   *         is not an inner join), and the conjuncts of every visited join condition.
+   */
+  def splitJoinRelationsNodes(join: Join): (Seq[Join], Seq[LogicalPlan], Seq[Expression]) = {
+    val joins = new scala.collection.mutable.ArrayBuffer[Join]()
+    val relations = new scala.collection.mutable.ArrayBuffer[LogicalPlan]()
+    val joinConds = new scala.collection.mutable.ArrayBuffer[Expression]()
+    val pending = scala.collection.mutable.Queue[Join](join)
+
+    // Inner joins are queued for expansion; every other child is a leaf relation.
+    def visit(child: LogicalPlan): Unit = child match {
+      case inner @ Join(_, _, Inner, _) => pending.enqueue(inner)
+      case leaf => relations += leaf
+    }
+
+    while (pending.nonEmpty) {
+      val current = pending.dequeue()
+      joins += current
+      current.condition.foreach(cond => joinConds ++= splitConjunctivePredicates(cond))
+      visit(current.left)
+      visit(current.right)
+    }
+    (joins, relations, joinConds)
+  }
+
+  /**
+   * Builds the replacement for `plan`. `plan` is returned unchanged when there are fewer than
+   * two joins or two relations, or when no equality predicate exists to guide the reordering.
+   *
+   * @param plan the original operator, returned as-is when no rewrite applies
+   * @param allFilterConds conjuncts gathered from the filter and from the join conditions
+   * @param joins the inner joins of the flattened tree
+   * @param relations the leaf relations of the flattened tree
+   */
+  def newOperator(
+      plan: LogicalPlan,
+      allFilterConds: Seq[Expression],
+      joins: Seq[Join],
+      relations: Seq[LogicalPlan]): LogicalPlan = {
+    val equalConds: Seq[Expression] = allFilterConds.collect { case e: EqualTo => e }
+    val tooSmall = joins.length <= 1 || relations.length < 2
+    if (tooSmall || joins.length + 1 < relations.length || equalConds.isEmpty) {
+      // Nothing to reorder, or no equality that could guide a new join order.
+      plan
+    } else {
+      Filter(allFilterConds.reduceLeft(And), shiftJoinOrder(relations, equalConds))
+    }
+  }
+
+  /**
+   * Rebuilds an inner-join tree over `relations`. For each equality in `equalConds`, in
+   * order, the relation producing its left side is joined with the relation producing its
+   * right side; relations still unconnected afterwards are combined left to right.
+   *
+   * Relations are kept in an ordered sequence and addressed by position, so the resulting
+   * order is deterministic and structurally equal relations are never collapsed into one.
+   *
+   * @param relations at least two plans to join
+   * @param equalConds equalities guiding the join order; any other expression is ignored
+   */
+  def shiftJoinOrder(relations: Seq[LogicalPlan], equalConds: Seq[Expression]): Join = {
+    require(relations.length >= 2, "At least two relations are needed to build a join")
+
+    val merged = equalConds.foldLeft(relations.toVector) {
+      case (pending, EqualTo(l, r)) =>
+        // Index of the first pending plan whose output covers all attributes of `side`, or
+        // -1 when `side` references no attribute or no single plan provides all of them.
+        def ownerOf(side: Expression): Int = {
+          val refs = side.references
+          if (refs.isEmpty) -1 else pending.indexWhere(p => refs.subsetOf(p.outputSet))
+        }
+        val (li, ri) = (ownerOf(l), ownerOf(r))
+        if (li < 0 || ri < 0 || li == ri) {
+          pending
+        } else {
+          val joined = Join(pending(li), pending(ri), Inner, None)
+          val untouched = pending.zipWithIndex.collect {
+            case (p, i) if i != li && i != ri => p
+          }
+          untouched :+ joined
+        }
+      case (pending, _) => pending
+    }
+
+    merged.reduceLeft(combineJoin) match {
+      case joined: Join => joined
+      case other => sys.error(s"Expected a join after reordering but got: $other")
+    }
+  }
+
+  /** Joins two plans with an inner join that has no condition. */
+  def combineJoin(left: LogicalPlan, right: LogicalPlan): LogicalPlan =
+    Join(left, right, Inner, None)
+}