@@ -20,6 +20,7 @@ package org.apache.spark.sql.streaming
2020import java .util .UUID
2121
2222import org .apache .spark .rdd .RDD
23+ import org .apache .spark .sql .DataFrame
2324import org .apache .spark .sql .catalyst .InternalRow
2425import org .apache .spark .sql .catalyst .analysis .UnresolvedAttribute
2526import org .apache .spark .sql .catalyst .expressions .Attribute
@@ -32,66 +33,71 @@ import org.apache.spark.sql.test.SharedSQLContext
3233class EnsureStatefulOpPartitioningSuite extends SparkPlanTest with SharedSQLContext {
3334
3435 import testImplicits ._
35- super .beforeAll()
3636
37- private val baseDf = Seq (( 1 , " A " ), ( 2 , " b " )).toDF( " num " , " char " )
37+ private var baseDf : DataFrame = null
3838
39- testEnsureStatefulOpPartitioning(
40- " ClusteredDistribution generates Exchange with HashPartitioning" ,
41- baseDf.queryExecution.sparkPlan,
42- requiredDistribution = keys => ClusteredDistribution (keys),
43- expectedPartitioning =
44- keys => HashPartitioning (keys, spark.sessionState.conf.numShufflePartitions),
45- expectShuffle = true )
39+ override def beforeAll (): Unit = {
40+ super .beforeAll()
41+ baseDf = Seq ((1 , " A" ), (2 , " b" )).toDF(" num" , " char" )
42+ }
43+
44+ test(" ClusteredDistribution generates Exchange with HashPartitioning" ) {
45+ testEnsureStatefulOpPartitioning(
46+ baseDf.queryExecution.sparkPlan,
47+ requiredDistribution = keys => ClusteredDistribution (keys),
48+ expectedPartitioning =
49+ keys => HashPartitioning (keys, spark.sessionState.conf.numShufflePartitions),
50+ expectShuffle = true )
51+ }
4652
47- testEnsureStatefulOpPartitioning(
48- " ClusteredDistribution with coalesce(1) generates Exchange with HashPartitioning" ,
49- baseDf.coalesce(1 ).queryExecution.sparkPlan,
50- requiredDistribution = keys => ClusteredDistribution (keys),
51- expectedPartitioning =
52- keys => HashPartitioning (keys, spark.sessionState.conf.numShufflePartitions),
53- expectShuffle = true )
53+ test(" ClusteredDistribution with coalesce(1) generates Exchange with HashPartitioning" ) {
54+ testEnsureStatefulOpPartitioning(
55+ baseDf.coalesce(1 ).queryExecution.sparkPlan,
56+ requiredDistribution = keys => ClusteredDistribution (keys),
57+ expectedPartitioning =
58+ keys => HashPartitioning (keys, spark.sessionState.conf.numShufflePartitions),
59+ expectShuffle = true )
60+ }
5461
55- testEnsureStatefulOpPartitioning(
56- " AllTuples generates Exchange with SinglePartition" ,
57- baseDf.queryExecution.sparkPlan,
58- requiredDistribution = _ => AllTuples ,
59- expectedPartitioning = _ => SinglePartition ,
60- expectShuffle = true )
62+ test(" AllTuples generates Exchange with SinglePartition" ) {
63+ testEnsureStatefulOpPartitioning(
64+ baseDf.queryExecution.sparkPlan,
65+ requiredDistribution = _ => AllTuples ,
66+ expectedPartitioning = _ => SinglePartition ,
67+ expectShuffle = true )
68+ }
6169
62- testEnsureStatefulOpPartitioning(
63- " AllTuples with coalesce(1) doesn't need Exchange" ,
64- baseDf.coalesce(1 ).queryExecution.sparkPlan,
65- requiredDistribution = _ => AllTuples ,
66- expectedPartitioning = _ => SinglePartition ,
67- expectShuffle = false )
70+ test(" AllTuples with coalesce(1) doesn't need Exchange" ) {
71+ testEnsureStatefulOpPartitioning(
72+ baseDf.coalesce(1 ).queryExecution.sparkPlan,
73+ requiredDistribution = _ => AllTuples ,
74+ expectedPartitioning = _ => SinglePartition ,
75+ expectShuffle = false )
76+ }
6877
6978 /**
7079 * For `StatefulOperator` with the given `requiredChildDistribution`, and child SparkPlan
7180 * `inputPlan`, ensures that the incremental planner adds exchanges, if required, in order to
7281 * ensure the expected partitioning.
7382 */
7483 private def testEnsureStatefulOpPartitioning (
75- testName : String ,
7684 inputPlan : SparkPlan ,
7785 requiredDistribution : Seq [Attribute ] => Distribution ,
7886 expectedPartitioning : Seq [Attribute ] => Partitioning ,
7987 expectShuffle : Boolean ): Unit = {
80- test(testName) {
81- val operator = TestStatefulOperator (inputPlan, requiredDistribution(inputPlan.output.take(1 )))
82- val executed = executePlan(operator, OutputMode .Complete ())
83- if (expectShuffle) {
84- val exchange = executed.children.find(_.isInstanceOf [Exchange ])
85- if (exchange.isEmpty) {
86- fail(s " Was expecting an exchange but didn't get one in: \n $executed" )
87- }
88- assert(exchange.get ===
89- ShuffleExchange (expectedPartitioning(inputPlan.output.take(1 )), inputPlan),
90- s " Exchange didn't have expected properties: \n ${exchange.get}" )
91- } else {
92- assert(! executed.children.exists(_.isInstanceOf [Exchange ]),
93- s " Unexpected exchange found in: \n $executed" )
88+ val operator = TestStatefulOperator (inputPlan, requiredDistribution(inputPlan.output.take(1 )))
89+ val executed = executePlan(operator, OutputMode .Complete ())
90+ if (expectShuffle) {
91+ val exchange = executed.children.find(_.isInstanceOf [Exchange ])
92+ if (exchange.isEmpty) {
93+ fail(s " Was expecting an exchange but didn't get one in: \n $executed" )
9494 }
95+ assert(exchange.get ===
96+ ShuffleExchange (expectedPartitioning(inputPlan.output.take(1 )), inputPlan),
97+ s " Exchange didn't have expected properties: \n ${exchange.get}" )
98+ } else {
99+ assert(! executed.children.exists(_.isInstanceOf [Exchange ]),
100+ s " Unexpected exchange found in: \n $executed" )
95101 }
96102 }
97103
0 commit comments