Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,14 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
leftKeys: IndexedSeq[Expression],
rightKeys: IndexedSeq[Expression],
expectedOrderOfKeys: Seq[Expression],
currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = {
currentOrderOfKeys: Seq[Expression]): Option[(Seq[Expression], Seq[Expression])] = {
if (expectedOrderOfKeys.size != currentOrderOfKeys.size) {
return (leftKeys, rightKeys)
return None
}

// Check if the current order already satisfies the expected order.
if (expectedOrderOfKeys.zip(currentOrderOfKeys).forall(p => p._1.semanticEquals(p._2))) {
return Some(leftKeys, rightKeys)
}

// Build a lookup between an expression and the positions its holds in the current key seq.
Expand All @@ -159,10 +164,10 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
rightKeysBuffer += rightKeys(index)
case _ =>
// The expression cannot be found, or we have exhausted all indices for that expression.
return (leftKeys, rightKeys)
return None
}
}
(leftKeysBuffer, rightKeysBuffer)
Some(leftKeysBuffer, rightKeysBuffer)
}

private def reorderJoinKeys(
Expand All @@ -171,19 +176,50 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
leftPartitioning: Partitioning,
rightPartitioning: Partitioning): (Seq[Expression], Seq[Expression]) = {
if (leftKeys.forall(_.deterministic) && rightKeys.forall(_.deterministic)) {
(leftPartitioning, rightPartitioning) match {
case (HashPartitioning(leftExpressions, _), _) =>
reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leftExpressions, leftKeys)
case (_, HashPartitioning(rightExpressions, _)) =>
reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, rightExpressions, rightKeys)
case _ =>
(leftKeys, rightKeys)
}
reorderJoinKeysRecursively(leftKeys, rightKeys, leftPartitioning, rightPartitioning)
.getOrElse((leftKeys, rightKeys))
} else {
(leftKeys, rightKeys)
}
}

/**
* Recursively reorders the join keys based on partitioning. It starts reordering the
* join keys to match HashPartitioning on either side, followed by PartitioningCollection.
*/
private def reorderJoinKeysRecursively(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
leftPartitioning: Partitioning,
rightPartitioning: Partitioning): Option[(Seq[Expression], Seq[Expression])] = {
(leftPartitioning, rightPartitioning) match {
case (HashPartitioning(leftExpressions, _), _) =>
reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leftExpressions, leftKeys)
.orElse(reorderJoinKeysRecursively(
leftKeys, rightKeys, UnknownPartitioning(0), rightPartitioning))
case (_, HashPartitioning(rightExpressions, _)) =>
reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, rightExpressions, rightKeys)
.orElse(reorderJoinKeysRecursively(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be also implemented by looking at left partitioning first then move to the right partitionoing:

    (leftPartitioning, rightPartitioning) match {
      case (HashPartitioning(leftExpressions, _), _) =>
        reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leftExpressions, leftKeys)
          .orElse(reorderJoinKeysRecursively(
            leftKeys, rightKeys, UnknownPartitioning(0), rightPartitioning))
      case (PartitioningCollection(partitionings), _) =>
        partitionings.foreach { p =>
          reorderJoinKeysRecursively(leftKeys, rightKeys, p, rightPartitioning).map { k =>
            return Some(k)
          }
        }
        reorderJoinKeysRecursively(leftKeys, rightKeys, UnknownPartitioning(0), rightPartitioning)
      case (_, HashPartitioning(rightExpressions, _)) =>
        reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, rightExpressions, rightKeys)
      case (_, PartitioningCollection(partitionings)) =>
        partitionings.foreach { p =>
          reorderJoinKeysRecursively(leftKeys, rightKeys, leftPartitioning, p).map { k =>
            return Some(k)
          }
        }
        None
      case _ =>
        None
    }

However, I chose this way so that the behavior remains the same. If you have leftPartitioning = PartitioningCollection and rightPartitioning = HashPartitioning, it will match the rightPartitioning first, which is the existing behavior.

leftKeys, rightKeys, leftPartitioning, UnknownPartitioning(0)))
case (PartitioningCollection(partitionings), _) =>
partitionings.foreach { p =>
reorderJoinKeysRecursively(leftKeys, rightKeys, p, rightPartitioning).map { k =>
return Some(k)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:

partitionings.foldLeft(None) { (res, p) =>
  res.orElse(reorderJoinKeysRecursively...)
}.getOrElse(reorderJoinKeysRecursively(leftKeys, rightKeys, None, rightPartitioning))

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, updated.

}
}
reorderJoinKeysRecursively(leftKeys, rightKeys, UnknownPartitioning(0), rightPartitioning)
case (_, PartitioningCollection(partitionings)) =>
partitionings.foreach { p =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you do the same refactor here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

reorderJoinKeysRecursively(leftKeys, rightKeys, leftPartitioning, p).map { k =>
return Some(k)
}
}
None
case _ =>
None
}
}

/**
* When the physical operators are created for JOIN, the ordering of join keys is based on order
* in which the join keys appear in the user query. That might not match with the output
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -994,6 +994,88 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
}
}
}

test("EnsureRequirements.reorder should fallback to the right side HashPartitioning") {
val plan1 = DummySparkPlan(
outputPartitioning = HashPartitioning(exprA :: exprB :: exprC :: Nil, 5))
val plan2 = DummySparkPlan(
outputPartitioning = HashPartitioning(exprB :: exprC :: Nil, 5))
// The left keys cannot be reordered to match the left partitioning, and it should
// fall back to reorder the right side.
val smjExec = SortMergeJoinExec(
exprA :: exprB :: Nil, exprC :: exprB :: Nil, Inner, None, plan1, plan2)
val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(smjExec)
outputPlan match {
case SortMergeJoinExec(leftKeys, rightKeys, _, _,
SortExec(_, _,
ShuffleExchangeExec(HashPartitioning(leftPartitioningExpressions, _), _, _), _),
SortExec(_, _,
DummySparkPlan(_, _, HashPartitioning(rightPartitioningExpressions, _), _, _), _), _) =>
assert(leftKeys !== smjExec.leftKeys)
assert(rightKeys !== smjExec.rightKeys)
assert(leftKeys === leftPartitioningExpressions)
assert(rightKeys === rightPartitioningExpressions)
case _ => fail(outputPlan.toString)
}
}

test("EnsureRequirements.reorder should handle PartitioningCollection") {
// PartitioningCollection on the left side of join.
val plan1 = DummySparkPlan(
outputPartitioning = PartitioningCollection(Seq(
HashPartitioning(exprA :: exprB :: Nil, 5),
HashPartitioning(exprA :: Nil, 5))))
val plan2 = DummySparkPlan()
val smjExec1 = SortMergeJoinExec(
exprB :: exprA :: Nil, exprA :: exprB :: Nil, Inner, None, plan1, plan2)
val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(smjExec1)
outputPlan match {
case SortMergeJoinExec(leftKeys, rightKeys, _, _,
SortExec(_, _,
DummySparkPlan(_, _, PartitioningCollection(leftPartitionings), _, _), _),
SortExec(_, _,
ShuffleExchangeExec(HashPartitioning(rightPartitioningExpressions, _), _, _), _), _) =>
assert(leftKeys !== smjExec1.leftKeys)
assert(rightKeys !== smjExec1.rightKeys)
assert(leftKeys === leftPartitionings(0).asInstanceOf[HashPartitioning].expressions)
assert(rightKeys === rightPartitioningExpressions)
case _ => fail(outputPlan.toString)
}

// PartitioningCollection on the right side of join.
val smjExec2 = SortMergeJoinExec(
exprA :: exprB :: Nil, exprB :: exprA :: Nil, Inner, None, plan2, plan1)
val outputPlan2 = EnsureRequirements(spark.sessionState.conf).apply(smjExec2)
outputPlan2 match {
case SortMergeJoinExec(leftKeys, rightKeys, _, _,
SortExec(_, _,
ShuffleExchangeExec(HashPartitioning(leftPartitioningExpressions, _), _, _), _),
SortExec(_, _,
DummySparkPlan(_, _, PartitioningCollection(rightPartitionings), _, _), _), _) =>
assert(leftKeys !== smjExec2.leftKeys)
assert(rightKeys !== smjExec2.rightKeys)
assert(leftKeys === leftPartitioningExpressions)
assert(rightKeys === rightPartitionings(0).asInstanceOf[HashPartitioning].expressions)
case _ => fail(outputPlan2.toString)
}

// Both sides are PartitioningCollection and falls back to the right side.
val smjExec3 = SortMergeJoinExec(
exprA :: exprC :: Nil, exprB :: exprA :: Nil, Inner, None, plan1, plan1)
val outputPlan3 = EnsureRequirements(spark.sessionState.conf).apply(smjExec2)
outputPlan3 match {
case SortMergeJoinExec(leftKeys, rightKeys, _, _,
SortExec(_, _,
ShuffleExchangeExec(HashPartitioning(leftPartitioningExpressions, _), _, _), _),
SortExec(_, _,
DummySparkPlan(_, _, PartitioningCollection(rightPartitionings), _, _), _), _) =>
assert(leftKeys !== smjExec2.leftKeys)
assert(rightKeys !== smjExec2.rightKeys)
assert(leftKeys === leftPartitioningExpressions)
assert(rightKeys === rightPartitionings(0).asInstanceOf[HashPartitioning].expressions)
case _ => fail(outputPlan3.toString)
}
}
}

// Used for unit-testing EnsureRequirements
Expand Down