@@ -63,26 +63,24 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
   }
 
   test("randomSplit on reordered partitions") {
-    val n = 600
     // This test ensures that randomSplit does not create overlapping splits even when the
     // underlying dataframe (such as the one below) doesn't guarantee a deterministic ordering of
     // rows in each partition.
     val data =
-      sparkContext.parallelize(1 to n, 2).mapPartitions(scala.util.Random.shuffle(_)).toDF("id")
-    val splits = data.randomSplit(Array[Double](1, 2, 3), seed = 1)
-    assert(splits.length == 3, "wrong number of splits")
+      sparkContext.parallelize(1 to 600, 2).mapPartitions(scala.util.Random.shuffle(_)).toDF("id")
+    val splits = data.randomSplit(Array[Double](2, 3), seed = 1)
 
-    assert(splits.reduce((a, b) => a.unionAll(b)).sort("id").collect().toList ==
-      data.sort($"id").collect().toList, "incomplete or wrong split")
+    assert(splits.length == 2, "wrong number of splits")
 
-    for (id <- splits.indices) {
-      assert(splits(id).intersect(splits((id + 1) % splits.length)).collect().isEmpty,
-        s"split $id overlaps with split ${(id + 1) % splits.length}")
-    }
+    // Verify that the splits span the entire dataset
+    assert(splits.flatMap(_.collect()).toSet == data.collect().toSet)
+
+    // Verify that the splits don't overlap
+    assert(splits(0).intersect(splits(1)).collect().isEmpty)
 
     // Verify that the results are deterministic across multiple runs
     val firstRun = splits.toSeq.map(_.collect().toSeq)
-    val secondRun = data.randomSplit(Array[Double](1, 2, 3), seed = 1).toSeq.map(_.collect().toSeq)
+    val secondRun = data.randomSplit(Array[Double](2, 3), seed = 1).toSeq.map(_.collect().toSeq)
     assert(firstRun == secondRun)
   }
 
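For context, a minimal standalone sketch (not part of the commit) of the randomSplit behavior this test exercises: the weights are relative, so Array(2, 3) yields roughly 40%/60% splits, and a fixed seed makes the splits reproducible. The sketch assumes a Spark 2.x+ SparkSession rather than the SharedSQLContext test harness used by the suite above; the object and app names are placeholders.

import org.apache.spark.sql.SparkSession

object RandomSplitSketch {
  def main(args: Array[String]): Unit = {
    // Local session for illustration only.
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("randomSplit-sketch")
      .getOrCreate()
    import spark.implicits._

    // Same shape of data as the test: 600 rows across 2 partitions.
    val df = spark.sparkContext.parallelize(1 to 600, 2).toDF("id")

    // Weights need not sum to 1; Spark normalizes them. The same seed
    // always produces the same, non-overlapping splits.
    val Array(smaller, larger) = df.randomSplit(Array(2.0, 3.0), seed = 1)
    println(s"smaller: ${smaller.count()} rows, larger: ${larger.count()} rows")

    spark.stop()
  }
}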