@@ -343,6 +343,241 @@ class QueryStageSuite extends SparkFunSuite with BeforeAndAfterAll {
343343 }
344344 }
345345
// Skewed-join handling for outer joins when only the right side (df2) is skewed:
// every df2 row has key 0, while df1 holds 10 rows spread over keys 0..4.
test("adaptive skewed join: left/right outer join and skewed on right side") {
  val spark = defaultSparkSession
  spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_JOIN_ENABLED.key, "false")
  spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_SKEWED_JOIN_ENABLED.key, "true")
  spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_SKEWED_PARTITION_ROW_COUNT_THRESHOLD.key, 10)
  withSparkSession(spark) { spark: SparkSession =>
    val df1 =
      spark
        .range(0, 10, 1, 2)
        .selectExpr("id % 5 as key1", "id as value1")
    val df2 =
      spark
        .range(0, 1000, 1, numInputPartitions)
        .selectExpr("id % 1 as key2", "id as value2")

    val leftOuterJoin =
      df1.join(df2, col("key1") === col("key2"), "left").select(col("key1"), col("value2"))
    val rightOuterJoin =
      df1.join(df2, col("key1") === col("key2"), "right").select(col("key1"), col("value2"))

    // Before execution, each join plan contains exactly one SortMergeJoin.
    val smjBeforeExecutionForLeftOuter = leftOuterJoin.queryExecution.executedPlan.collect {
      case smj: SortMergeJoinExec => smj
    }
    assert(smjBeforeExecutionForLeftOuter.length === 1)

    // Inspect the right outer join's own plan here (previously this re-checked leftOuterJoin).
    val smjBeforeExecutionForRightOuter = rightOuterJoin.queryExecution.executedPlan.collect {
      case smj: SortMergeJoinExec => smj
    }
    assert(smjBeforeExecutionForRightOuter.length === 1)

    // Check the answer.
    // Left outer: the two df1 rows with key 0 each match all 1000 df2 rows; the remaining
    // eight df1 rows (keys 1..4) have no match and yield a null value2.
    val expectedAnswerForLeftOuter =
      spark
        .range(0, 1000)
        .selectExpr("0 as key", "id as value")
        .union(spark.range(0, 1000).selectExpr("0 as key", "id as value"))
        .union(spark.range(0, 10, 1).filter(_ % 5 != 0).selectExpr("id % 5 as key1", "null"))
    checkAnswer(
      leftOuterJoin,
      expectedAnswerForLeftOuter.collect())

    // Right outer: every df2 row has key 0 and matches the two df1 rows with key 0.
    val expectedAnswerForRightOuter =
      spark
        .range(0, 1000)
        .selectExpr("0 as key", "id as value")
        .union(spark.range(0, 1000).selectExpr("0 as key", "id as value"))
    checkAnswer(
      rightOuterJoin,
      expectedAnswerForRightOuter.collect())

    // Left outer join: the skewed side is the right one, which does not correspond to the
    // join type, so the SMJ is not split into sub-joins and a single SMJ remains.
    val smjAfterExecutionForLeftOuter = leftOuterJoin.queryExecution.executedPlan.collect {
      case smj: SortMergeJoinExec => smj
    }
    assert(smjAfterExecutionForLeftOuter.length === 1)

    // Right outer join: the skewed side is the right one, which corresponds to the join
    // type, so the SMJ is changed to a Union of SMJ + 5 SMJs (6 SMJ nodes in total).
    val smjAfterExecutionForRightOuter = rightOuterJoin.queryExecution.executedPlan.collect {
      case smj: SortMergeJoinExec => smj
    }
    assert(smjAfterExecutionForRightOuter.length === 6)

    // Both shuffle inputs of the split join must report the same skewed partition (0).
    val queryStageInputs = rightOuterJoin.queryExecution.executedPlan.collect {
      case q: ShuffleQueryStageInput => q
    }
    assert(queryStageInputs.length === 2)
    assert(queryStageInputs(0).skewedPartitions === queryStageInputs(1).skewedPartitions)
    assert(queryStageInputs(0).skewedPartitions === Some(Set(0)))
  }
}
422+
// Skewed-join handling for outer joins when only the left side (df1) is skewed:
// every df1 row has key 0, while df2 holds 10 rows spread over keys 0..4.
test("adaptive skewed join: left/right outer join and skewed on left side") {
  val spark = defaultSparkSession
  spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_JOIN_ENABLED.key, "false")
  spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_SKEWED_JOIN_ENABLED.key, "true")
  spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_SKEWED_PARTITION_ROW_COUNT_THRESHOLD.key, 10)
  withSparkSession(spark) { spark: SparkSession =>
    val df1 =
      spark
        .range(0, 1000, 1, numInputPartitions)
        .selectExpr("id % 1 as key1", "id as value1")
    val df2 =
      spark
        .range(0, 10, 1, 2)
        .selectExpr("id % 5 as key2", "id as value2")

    val leftOuterJoin =
      df1.join(df2, col("key1") === col("key2"), "left").select(col("key1"), col("value1"))
    val rightOuterJoin =
      df1.join(df2, col("key1") === col("key2"), "right").select(col("key1"), col("value1"))

    // Before execution, each join plan contains exactly one SortMergeJoin.
    val smjBeforeExecutionForLeftOuter = leftOuterJoin.queryExecution.executedPlan.collect {
      case smj: SortMergeJoinExec => smj
    }
    assert(smjBeforeExecutionForLeftOuter.length === 1)

    // Inspect the right outer join's own plan here (previously this re-checked leftOuterJoin).
    val smjBeforeExecutionForRightOuter = rightOuterJoin.queryExecution.executedPlan.collect {
      case smj: SortMergeJoinExec => smj
    }
    assert(smjBeforeExecutionForRightOuter.length === 1)

    // Check the answer.
    // Left outer: every df1 row has key 0 and matches the two df2 rows with key 0.
    val expectedAnswerForLeftOuter =
      spark
        .range(0, 1000)
        .selectExpr("0 as key", "id as value")
        .union(spark.range(0, 1000).selectExpr("0 as key", "id as value"))
    checkAnswer(
      leftOuterJoin,
      expectedAnswerForLeftOuter.collect())

    // Right outer: in addition to the matches above, the eight df2 rows with keys 1..4 have
    // no match, so both selected left-side columns (key1, value1) are null for them.
    val expectedAnswerForRightOuter =
      spark
        .range(0, 1000)
        .selectExpr("0 as key", "id as value")
        .union(spark.range(0, 1000).selectExpr("0 as key", "id as value"))
        .union(spark.range(0, 10, 1).filter(_ % 5 != 0).selectExpr("null", "null"))
    checkAnswer(
      rightOuterJoin,
      expectedAnswerForRightOuter.collect())

    // Left outer join: the skewed side is the left one, which corresponds to the join
    // type, so the SMJ is changed to a Union of SMJ + 5 SMJs (6 SMJ nodes in total).
    val smjAfterExecutionForLeftOuter = leftOuterJoin.queryExecution.executedPlan.collect {
      case smj: SortMergeJoinExec => smj
    }
    assert(smjAfterExecutionForLeftOuter.length === 6)

    // Right outer join: the skewed side is the left one, which does not correspond to the
    // join type, so the SMJ is not split into sub-joins and a single SMJ remains.
    val smjAfterExecutionForRightOuter = rightOuterJoin.queryExecution.executedPlan.collect {
      case smj: SortMergeJoinExec => smj
    }
    assert(smjAfterExecutionForRightOuter.length === 1)

    // Both shuffle inputs of the split (left outer) join must report the same skewed
    // partition (0).
    val queryStageInputs = leftOuterJoin.queryExecution.executedPlan.collect {
      case q: ShuffleQueryStageInput => q
    }
    assert(queryStageInputs.length === 2)
    assert(queryStageInputs(0).skewedPartitions === queryStageInputs(1).skewedPartitions)
    assert(queryStageInputs(0).skewedPartitions === Some(Set(0)))
  }
}
500+
// Skewed-join handling for outer joins when both sides are skewed:
// df1 and df2 each hold 100 rows that all share key 0.
test("adaptive skewed join: left/right outer join and skewed on both sides") {
  val spark = defaultSparkSession
  spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_JOIN_ENABLED.key, "false")
  spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_SKEWED_JOIN_ENABLED.key, "true")
  spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_SKEWED_PARTITION_ROW_COUNT_THRESHOLD.key, 10)
  withSparkSession(spark) { spark: SparkSession =>
    // Needed for the encoder used by flatMap when building the expected answers.
    import spark.implicits._
    val df1 =
      spark
        .range(0, 100, 1, numInputPartitions)
        .selectExpr("id % 1 as key1", "id as value1")
    val df2 =
      spark
        .range(0, 100, 1, numInputPartitions)
        .selectExpr("id % 1 as key2", "id as value2")

    val leftOuterJoin =
      df1.join(df2, col("key1") === col("key2"), "left").select(col("key1"), col("value2"))
    val rightOuterJoin =
      df1.join(df2, col("key1") === col("key2"), "right").select(col("key1"), col("value2"))

    // Before execution, each join plan contains exactly one SortMergeJoin.
    val smjBeforeExecutionForLeftOuter = leftOuterJoin.queryExecution.executedPlan.collect {
      case smj: SortMergeJoinExec => smj
    }
    assert(smjBeforeExecutionForLeftOuter.length === 1)

    // Inspect the right outer join's own plan here (previously this re-checked leftOuterJoin).
    val smjBeforeExecutionForRightOuter = rightOuterJoin.queryExecution.executedPlan.collect {
      case smj: SortMergeJoinExec => smj
    }
    assert(smjBeforeExecutionForRightOuter.length === 1)

    // Check the answer.
    // All keys are 0 on both sides, so each of the 100 value2 values appears 100 times.
    val expectedAnswerForLeftOuter =
      spark
        .range(0, 100)
        .flatMap(i => Seq.fill(100)(i))
        .selectExpr("0 as key", "value")
    checkAnswer(
      leftOuterJoin,
      expectedAnswerForLeftOuter.collect())

    val expectedAnswerForRightOuter =
      spark
        .range(0, 100)
        .flatMap(i => Seq.fill(100)(i))
        .selectExpr("0 as key", "value")
    checkAnswer(
      rightOuterJoin,
      expectedAnswerForRightOuter.collect())

    // Left outer join: although the right side is skewed too, the left side is also skewed,
    // so the left skewed partition is split and the SMJ is still changed to a Union of
    // SMJ + 5 SMJs (6 SMJ nodes in total).
    val smjAfterExecutionForLeftOuter = leftOuterJoin.queryExecution.executedPlan.collect {
      case smj: SortMergeJoinExec => smj
    }
    assert(smjAfterExecutionForLeftOuter.length === 6)

    // Right outer join: although the left side is skewed too, the right side is also skewed,
    // so the right skewed partition is split and the SMJ is still changed to a Union of
    // SMJ + 5 SMJs (6 SMJ nodes in total).
    val smjAfterExecutionForRightOuter = rightOuterJoin.queryExecution.executedPlan.collect {
      case smj: SortMergeJoinExec => smj
    }
    assert(smjAfterExecutionForRightOuter.length === 6)

    // Both shuffle inputs of the split join must report the same skewed partition (0).
    val queryStageInputs = rightOuterJoin.queryExecution.executedPlan.collect {
      case q: ShuffleQueryStageInput => q
    }
    assert(queryStageInputs.length === 2)
    assert(queryStageInputs(0).skewedPartitions === queryStageInputs(1).skewedPartitions)
    assert(queryStageInputs(0).skewedPartitions === Some(Set(0)))
  }
}
580+
346581 test(" row count statistics, compressed" ) {
347582 val spark = defaultSparkSession
348583 withSparkSession(spark) { spark : SparkSession =>
0 commit comments