@@ -24,8 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._
2424import org .apache .spark .sql .catalyst .plans .physical ._
2525import org .apache .spark .sql .catalyst .rules .Rule
2626import org .apache .spark .sql .execution ._
27- import org .apache .spark .sql .execution .joins .{BroadcastHashJoinExec , ShuffledHashJoinExec ,
28- SortMergeJoinExec }
27+ import org .apache .spark .sql .execution .joins .{ShuffledHashJoinExec , SortMergeJoinExec }
2928import org .apache .spark .sql .internal .SQLConf
3029
3130/**
@@ -117,25 +116,41 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
117116 }
118117
119118 private def reorder (
120- leftKeys : Seq [Expression ],
121- rightKeys : Seq [Expression ],
119+ leftKeys : IndexedSeq [Expression ],
120+ rightKeys : IndexedSeq [Expression ],
122121 expectedOrderOfKeys : Seq [Expression ],
123122 currentOrderOfKeys : Seq [Expression ]): (Seq [Expression ], Seq [Expression ]) = {
124- val leftKeysBuffer = ArrayBuffer [Expression ]()
125- val rightKeysBuffer = ArrayBuffer [Expression ]()
126- val pickedIndexes = mutable.Set [Int ]()
127- val keysAndIndexes = currentOrderOfKeys.zipWithIndex
128-
129- expectedOrderOfKeys.foreach(expression => {
130- val index = keysAndIndexes.find { case (e, idx) =>
131- // As we may have the same key used many times, we need to filter out its occurrence we
132- // have already used.
133- e.semanticEquals(expression) && ! pickedIndexes.contains(idx)
134- }.map(_._2).get
135- pickedIndexes += index
136- leftKeysBuffer.append(leftKeys(index))
137- rightKeysBuffer.append(rightKeys(index))
138- })
123+ if (expectedOrderOfKeys.size != currentOrderOfKeys.size) {
124+ return (leftKeys, rightKeys)
125+ }
126+
127+ // Build a lookup between an expression and the positions its holds in the current key seq.
128+ val keyToIndexMap = mutable.Map .empty[Expression , mutable.BitSet ]
129+ currentOrderOfKeys.zipWithIndex.foreach {
130+ case (key, index) =>
131+ keyToIndexMap.getOrElseUpdate(key.canonicalized, mutable.BitSet .empty).add(index)
132+ }
133+
134+ // Reorder the keys.
135+ val leftKeysBuffer = new ArrayBuffer [Expression ](leftKeys.size)
136+ val rightKeysBuffer = new ArrayBuffer [Expression ](rightKeys.size)
137+ val iterator = expectedOrderOfKeys.iterator
138+ while (iterator.hasNext) {
139+ // Lookup the current index of this key.
140+ keyToIndexMap.get(iterator.next().canonicalized) match {
141+ case Some (indices) if indices.nonEmpty =>
142+ // Take the first available index from the map.
143+ val index = indices.firstKey
144+ indices.remove(index)
145+
146+ // Add the keys for that index to the reordered keys.
147+ leftKeysBuffer += leftKeys(index)
148+ rightKeysBuffer += rightKeys(index)
149+ case _ =>
150+ // The expression cannot be found, or we have exhausted all indices for that expression.
151+ return (leftKeys, rightKeys)
152+ }
153+ }
139154 (leftKeysBuffer, rightKeysBuffer)
140155 }
141156
@@ -145,20 +160,13 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
145160 leftPartitioning : Partitioning ,
146161 rightPartitioning : Partitioning ): (Seq [Expression ], Seq [Expression ]) = {
147162 if (leftKeys.forall(_.deterministic) && rightKeys.forall(_.deterministic)) {
148- leftPartitioning match {
149- case HashPartitioning (leftExpressions, _)
150- if leftExpressions.length == leftKeys.length &&
151- leftKeys.forall(x => leftExpressions.exists(_.semanticEquals(x))) =>
152- reorder(leftKeys, rightKeys, leftExpressions, leftKeys)
153-
154- case _ => rightPartitioning match {
155- case HashPartitioning (rightExpressions, _)
156- if rightExpressions.length == rightKeys.length &&
157- rightKeys.forall(x => rightExpressions.exists(_.semanticEquals(x))) =>
158- reorder(leftKeys, rightKeys, rightExpressions, rightKeys)
159-
160- case _ => (leftKeys, rightKeys)
161- }
163+ (leftPartitioning, rightPartitioning) match {
164+ case (HashPartitioning (leftExpressions, _), _) =>
165+ reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leftExpressions, leftKeys)
166+ case (_, HashPartitioning (rightExpressions, _)) =>
167+ reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, rightExpressions, rightKeys)
168+ case _ =>
169+ (leftKeys, rightKeys)
162170 }
163171 } else {
164172 (leftKeys, rightKeys)
0 commit comments