Skip to content

Commit db114d3

Browse files
Guo ChenzhaoJkSelf
authored andcommitted
Support left/right outer join in handling data skew feature (apache#60)
* Support left/right outer join in data skew feature * Style * Refactor & style * Modify logic for non-split conditions(join type and skewed side) * Refactor
1 parent 55c157c commit db114d3

3 files changed

Lines changed: 262 additions & 9 deletions

File tree

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/HandleSkewedJoin.scala

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,10 @@
1717

1818
package org.apache.spark.sql.execution.adaptive
1919

20+
import scala.collection.immutable.Nil
2021
import scala.collection.mutable
2122

22-
import org.apache.spark.sql.catalyst.plans.{Cross, Inner, JoinType, LeftSemi}
23+
import org.apache.spark.sql.catalyst.plans._
2324
import org.apache.spark.sql.catalyst.rules.Rule
2425
import org.apache.spark.sql.execution.{SortExec, SparkPlan, UnionExec}
2526
import org.apache.spark.sql.execution.joins.SortMergeJoinExec
@@ -28,6 +29,8 @@ import org.apache.spark.sql.internal.SQLConf
2829

2930
case class HandleSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
3031

32+
// Join types that the skewed-join handling accepts; supportOptimization checks membership here.
private val supportedJoinTypes = List(Inner, Cross, LeftSemi, LeftOuter, RightOuter)
33+
3134
private def isSizeSkewed(size: Long, medianSize: Long): Boolean = {
3235
size > medianSize * conf.adaptiveSkewedFactor &&
3336
size > conf.adaptiveSkewedSizeThreshold
@@ -96,15 +99,26 @@ case class HandleSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
9699
equallyDivide(queryStageInput.numMapper, numSplits).toArray
97100
}
98101

102+
/**
103+
* Base optimization support check: the join type is supported and partition statistics are
* available for both inputs.
104+
* Note that for some join types (like left outer), whether a certain partition can be optimized
105+
* also depends on the per-partition value isSkewAndSupportsSplit.
106+
*/
99107
private def supportOptimization(
100108
joinType: JoinType,
101109
left: QueryStageInput,
102110
right: QueryStageInput): Boolean = {
103-
(joinType == Inner || joinType == Cross || joinType == LeftSemi) &&
111+
supportedJoinTypes.contains(joinType) &&
104112
left.childStage.stats.getPartitionStatistics.isDefined &&
105113
right.childStage.stats.getPartitionStatistics.isDefined
106114
}
107115

116+
/**
 * Whether a skewed left partition may be split for the given join type: every join type
 * except RightOuter qualifies.
 * NOTE(review): presumably a right outer join needs the whole left partition in one sub-join
 * to decide which right rows are unmatched -- confirm.
 */
private def supportSplitOnLeftPartition(joinType: JoinType) = joinType match {
  case RightOuter => false
  case _ => true
}
117+
118+
/**
 * Whether a skewed right partition may be split for the given join type: every join type
 * except LeftOuter and LeftSemi qualifies.
 * NOTE(review): presumably these join types need the whole right partition in one sub-join
 * to decide whether a left row has a match -- confirm.
 */
private def supportSplitOnRightPartition(joinType: JoinType) = joinType match {
  case LeftOuter | LeftSemi => false
  case _ => true
}
121+
108122
private def handleSkewedJoin(
109123
operator: SparkPlan,
110124
queryStage: QueryStage): SparkPlan = operator.transformUp {
@@ -131,18 +145,21 @@ case class HandleSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
131145
for (partitionId <- 0 until numPartitions) {
132146
val isLeftSkew = isSkewed(leftStats, partitionId, leftMedSize, leftMedRowCount)
133147
val isRightSkew = isSkewed(rightStats, partitionId, rightMedSize, rightMedRowCount)
134-
if (isLeftSkew || isRightSkew) {
148+
val isSkewAndSupportsSplit =
149+
(isLeftSkew && supportSplitOnLeftPartition(joinType)) ||
150+
(isRightSkew && supportSplitOnRightPartition(joinType))
151+
152+
if (isSkewAndSupportsSplit) {
135153
skewedPartitions += partitionId
136-
val leftMapIdStartIndices = if (isLeftSkew) {
154+
val leftMapIdStartIndices = if (isLeftSkew && supportSplitOnLeftPartition(joinType)) {
137155
estimateMapIdStartIndices(left, partitionId, leftMedSize, leftMedRowCount)
138156
} else {
139157
Array(0)
140158
}
141-
val rightMapIdStartIndices = if (!isRightSkew || joinType == LeftSemi) {
142-
// For left semi join, we don't split the right partition
143-
Array(0)
144-
} else {
159+
val rightMapIdStartIndices = if (isRightSkew && supportSplitOnRightPartition(joinType)) {
145160
estimateMapIdStartIndices(right, partitionId, rightMedSize, rightMedRowCount)
161+
} else {
162+
Array(0)
146163
}
147164

148165
for (i <- 0 until leftMapIdStartIndices.length;

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStage.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ package org.apache.spark.sql.execution.adaptive
1919

2020
import scala.concurrent.{ExecutionContext, Future}
2121
import scala.concurrent.duration.Duration
22-
import org.apache.spark.{MapOutputStatistics, SparkContext, broadcast}
22+
23+
import org.apache.spark.{broadcast, MapOutputStatistics, SparkContext}
2324
import org.apache.spark.rdd.RDD
2425
import org.apache.spark.sql.catalyst.InternalRow
2526
import org.apache.spark.sql.catalyst.expressions._

sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/QueryStageSuite.scala

Lines changed: 235 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,241 @@ class QueryStageSuite extends SparkFunSuite with BeforeAndAfterAll {
343343
}
344344
}
345345

346+
test("adaptive skewed join: left/right outer join and skewed on right side") {
347+
val spark = defaultSparkSession
348+
spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_JOIN_ENABLED.key, "false")
349+
spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_SKEWED_JOIN_ENABLED.key, "true")
350+
spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_SKEWED_PARTITION_ROW_COUNT_THRESHOLD.key, 10)
351+
withSparkSession(spark) { spark: SparkSession =>
352+
val df1 =
353+
spark
354+
.range(0, 10, 1, 2)
355+
.selectExpr("id % 5 as key1", "id as value1")
356+
val df2 =
357+
spark
358+
.range(0, 1000, 1, numInputPartitions)
359+
.selectExpr("id % 1 as key2", "id as value2")
360+
361+
val leftOuterJoin =
362+
df1.join(df2, col("key1") === col("key2"), "left").select(col("key1"), col("value2"))
363+
val rightOuterJoin =
364+
df1.join(df2, col("key1") === col("key2"), "right").select(col("key1"), col("value2"))
365+
366+
// Before Execution, there is one SortMergeJoin
367+
val smjBeforeExecutionForLeftOuter = leftOuterJoin.queryExecution.executedPlan.collect {
368+
case smj: SortMergeJoinExec => smj
369+
}
370+
assert(smjBeforeExecutionForLeftOuter.length === 1)
371+
372+
// Collect the SortMergeJoin nodes of the right outer join plan before it is executed.
val smjBeforeExecutionForRightOuter = rightOuterJoin.queryExecution.executedPlan.collect {
  case smj: SortMergeJoinExec => smj
}
375+
assert(smjBeforeExecutionForRightOuter.length === 1)
376+
377+
// Check the answer.
378+
val expectedAnswerForLeftOuter =
379+
spark
380+
.range(0, 1000)
381+
.selectExpr("0 as key", "id as value")
382+
.union(spark.range(0, 1000).selectExpr("0 as key", "id as value"))
383+
.union(spark.range(0, 10, 1).filter(_ % 5 != 0).selectExpr("id % 5 as key1", "null"))
384+
checkAnswer(
385+
leftOuterJoin,
386+
expectedAnswerForLeftOuter.collect())
387+
388+
val expectedAnswerForRightOuter =
389+
spark
390+
.range(0, 1000)
391+
.selectExpr("0 as key", "id as value")
392+
.union(spark.range(0, 1000).selectExpr("0 as key", "id as value"))
393+
checkAnswer(
394+
rightOuterJoin,
395+
expectedAnswerForRightOuter.collect())
396+
397+
// For the left outer join case: during execution, the SMJ can not be translated to any sub
398+
// joins due to the skewed side is on the right but the join type is left outer
399+
// (not correspond with each other)
400+
val smjAfterExecutionForLeftOuter = leftOuterJoin.queryExecution.executedPlan.collect {
401+
case smj: SortMergeJoinExec => smj
402+
}
403+
assert(smjAfterExecutionForLeftOuter.length === 1)
404+
405+
// For the right outer join case: during execution, the SMJ is changed to Union of SMJ + 5 SMJ
406+
// joins due to the skewed side is on the right and the join type is right
407+
// outer (correspond with each other)
408+
val smjAfterExecutionForRightOuter = rightOuterJoin.queryExecution.executedPlan.collect {
409+
case smj: SortMergeJoinExec => smj
410+
}
411+
412+
assert(smjAfterExecutionForRightOuter.length === 6)
413+
val queryStageInputs = rightOuterJoin.queryExecution.executedPlan.collect {
414+
case q: ShuffleQueryStageInput => q
415+
}
416+
assert(queryStageInputs.length === 2)
417+
assert(queryStageInputs(0).skewedPartitions === queryStageInputs(1).skewedPartitions)
418+
assert(queryStageInputs(0).skewedPartitions === Some(Set(0)))
419+
420+
}
421+
}
422+
423+
test("adaptive skewed join: left/right outer join and skewed on left side") {
424+
val spark = defaultSparkSession
425+
spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_JOIN_ENABLED.key, "false")
426+
spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_SKEWED_JOIN_ENABLED.key, "true")
427+
spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_SKEWED_PARTITION_ROW_COUNT_THRESHOLD.key, 10)
428+
withSparkSession(spark) { spark: SparkSession =>
429+
val df1 =
430+
spark
431+
.range(0, 1000, 1, numInputPartitions)
432+
.selectExpr("id % 1 as key1", "id as value1")
433+
val df2 =
434+
spark
435+
.range(0, 10, 1, 2)
436+
.selectExpr("id % 5 as key2", "id as value2")
437+
438+
val leftOuterJoin =
439+
df1.join(df2, col("key1") === col("key2"), "left").select(col("key1"), col("value1"))
440+
val rightOuterJoin =
441+
df1.join(df2, col("key1") === col("key2"), "right").select(col("key1"), col("value1"))
442+
443+
// Before Execution, there is one SortMergeJoin
444+
val smjBeforeExecutionForLeftOuter = leftOuterJoin.queryExecution.executedPlan.collect {
445+
case smj: SortMergeJoinExec => smj
446+
}
447+
assert(smjBeforeExecutionForLeftOuter.length === 1)
448+
449+
// Collect the SortMergeJoin nodes of the right outer join plan before it is executed.
val smjBeforeExecutionForRightOuter = rightOuterJoin.queryExecution.executedPlan.collect {
  case smj: SortMergeJoinExec => smj
}
452+
assert(smjBeforeExecutionForRightOuter.length === 1)
453+
454+
// Check the answer.
455+
val expectedAnswerForLeftOuter =
456+
spark
457+
.range(0, 1000)
458+
.selectExpr("0 as key", "id as value")
459+
.union(spark.range(0, 1000).selectExpr("0 as key", "id as value"))
460+
checkAnswer(
461+
leftOuterJoin,
462+
expectedAnswerForLeftOuter.collect())
463+
464+
val expectedAnswerForRightOuter =
465+
spark
466+
.range(0, 1000)
467+
.selectExpr("0 as key", "id as value")
468+
.union(spark.range(0, 1000).selectExpr("0 as key", "id as value"))
469+
.union(spark.range(0, 10, 1).filter(_ % 5 != 0).selectExpr("null", "null"))
470+
471+
checkAnswer(
472+
rightOuterJoin,
473+
expectedAnswerForRightOuter.collect())
474+
475+
// For the left outer join case: during execution, the SMJ is changed to Union of SMJ + 5 SMJ
476+
// joins due to the skewed side is on the left and the join type is left outer
477+
// (correspond with each other)
478+
val smjAfterExecutionForLeftOuter = leftOuterJoin.queryExecution.executedPlan.collect {
479+
case smj: SortMergeJoinExec => smj
480+
}
481+
assert(smjAfterExecutionForLeftOuter.length === 6)
482+
483+
// For the right outer join case: during execution, the SMJ can not be translated to any sub
484+
// joins due to the skewed side is on the left but the join type is right outer
485+
// (not correspond with each other)
486+
val smjAfterExecutionForRightOuter = rightOuterJoin.queryExecution.executedPlan.collect {
487+
case smj: SortMergeJoinExec => smj
488+
}
489+
490+
assert(smjAfterExecutionForRightOuter.length === 1)
491+
val queryStageInputs = leftOuterJoin.queryExecution.executedPlan.collect {
492+
case q: ShuffleQueryStageInput => q
493+
}
494+
assert(queryStageInputs.length === 2)
495+
assert(queryStageInputs(0).skewedPartitions === queryStageInputs(1).skewedPartitions)
496+
assert(queryStageInputs(0).skewedPartitions === Some(Set(0)))
497+
498+
}
499+
}
500+
501+
test("adaptive skewed join: left/right outer join and skewed on both sides") {
502+
val spark = defaultSparkSession
503+
spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_JOIN_ENABLED.key, "false")
504+
spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_SKEWED_JOIN_ENABLED.key, "true")
505+
spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_SKEWED_PARTITION_ROW_COUNT_THRESHOLD.key, 10)
506+
withSparkSession(spark) { spark: SparkSession =>
507+
import spark.implicits._
508+
val df1 =
509+
spark
510+
.range(0, 100, 1, numInputPartitions)
511+
.selectExpr("id % 1 as key1", "id as value1")
512+
val df2 =
513+
spark
514+
.range(0, 100, 1, numInputPartitions)
515+
.selectExpr("id % 1 as key2", "id as value2")
516+
517+
val leftOuterJoin =
518+
df1.join(df2, col("key1") === col("key2"), "left").select(col("key1"), col("value2"))
519+
val rightOuterJoin =
520+
df1.join(df2, col("key1") === col("key2"), "right").select(col("key1"), col("value2"))
521+
522+
// Before Execution, there is one SortMergeJoin
523+
val smjBeforeExecutionForLeftOuter = leftOuterJoin.queryExecution.executedPlan.collect {
524+
case smj: SortMergeJoinExec => smj
525+
}
526+
assert(smjBeforeExecutionForLeftOuter.length === 1)
527+
528+
// Collect the SortMergeJoin nodes of the right outer join plan before it is executed.
val smjBeforeExecutionForRightOuter = rightOuterJoin.queryExecution.executedPlan.collect {
  case smj: SortMergeJoinExec => smj
}
531+
assert(smjBeforeExecutionForRightOuter.length === 1)
532+
533+
// Check the answer.
534+
val expectedAnswerForLeftOuter =
535+
spark
536+
.range(0, 100)
537+
.flatMap(i => Seq.fill(100)(i))
538+
.selectExpr("0 as key", "value")
539+
540+
checkAnswer(
541+
leftOuterJoin,
542+
expectedAnswerForLeftOuter.collect())
543+
544+
val expectedAnswerForRightOuter =
545+
spark
546+
.range(0, 100)
547+
.flatMap(i => Seq.fill(100)(i))
548+
.selectExpr("0 as key", "value")
549+
checkAnswer(
550+
rightOuterJoin,
551+
expectedAnswerForRightOuter.collect())
552+
553+
// For the left outer join case: during execution, although the skewed sides include the
554+
// right, the SMJ is still changed to Union of SMJ + 5 SMJ joins due to the skewed sides
555+
// also include the left, so we split the left skewed partition
556+
// (correspondence exists)
557+
val smjAfterExecutionForLeftOuter = leftOuterJoin.queryExecution.executedPlan.collect {
558+
case smj: SortMergeJoinExec => smj
559+
}
560+
assert(smjAfterExecutionForLeftOuter.length === 6)
561+
562+
// For the right outer join case: during execution, although the skewed sides include the
563+
// left, the SMJ is still changed to Union of SMJ + 5 SMJ joins due to the skewed sides
564+
// also include the right, so we split the right skewed partition
565+
// (correspondence exists)
566+
val smjAfterExecutionForRightOuter = rightOuterJoin.queryExecution.executedPlan.collect {
567+
case smj: SortMergeJoinExec => smj
568+
}
569+
570+
assert(smjAfterExecutionForRightOuter.length === 6)
571+
val queryStageInputs = rightOuterJoin.queryExecution.executedPlan.collect {
572+
case q: ShuffleQueryStageInput => q
573+
}
574+
assert(queryStageInputs.length === 2)
575+
assert(queryStageInputs(0).skewedPartitions === queryStageInputs(1).skewedPartitions)
576+
assert(queryStageInputs(0).skewedPartitions === Some(Set(0)))
577+
578+
}
579+
}
580+
346581
test("row count statistics, compressed") {
347582
val spark = defaultSparkSession
348583
withSparkSession(spark) { spark: SparkSession =>

0 commit comments

Comments
 (0)