Skip to content

Commit 1b30119

Browse files
committed
Simplify test
1 parent 3af6a2d commit 1b30119

1 file changed

Lines changed: 9 additions & 11 deletions

File tree

sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala

Lines changed: 9 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -63,26 +63,24 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
6363
}
6464

6565
test("randomSplit on reordered partitions") {
66-
val n = 600
6766
// This test ensures that randomSplit does not create overlapping splits even when the
6867
// underlying dataframe (such as the one below) doesn't guarantee a deterministic ordering of
6968
// rows in each partition.
7069
val data =
71-
sparkContext.parallelize(1 to n, 2).mapPartitions(scala.util.Random.shuffle(_)).toDF("id")
72-
val splits = data.randomSplit(Array[Double](1, 2, 3), seed = 1)
73-
assert(splits.length == 3, "wrong number of splits")
70+
sparkContext.parallelize(1 to 600, 2).mapPartitions(scala.util.Random.shuffle(_)).toDF("id")
71+
val splits = data.randomSplit(Array[Double](2, 3), seed = 1)
7472

75-
assert(splits.reduce((a, b) => a.unionAll(b)).sort("id").collect().toList ==
76-
data.sort($"id").collect().toList, "incomplete or wrong split")
73+
assert(splits.length == 2, "wrong number of splits")
7774

78-
for (id <- splits.indices) {
79-
assert(splits(id).intersect(splits((id + 1) % splits.length)).collect().isEmpty,
80-
s"split $id overlaps with split ${(id + 1) % splits.length}")
81-
}
75+
// Verify that the splits span the entire dataset
76+
assert(splits.flatMap(_.collect()).toSet == data.collect().toSet)
77+
78+
// Verify that the splits don't overlap
79+
assert(splits(0).intersect(splits(1)).collect().isEmpty)
8280

8381
// Verify that the results are deterministic across multiple runs
8482
val firstRun = splits.toSeq.map(_.collect().toSeq)
85-
val secondRun = data.randomSplit(Array[Double](1, 2, 3), seed = 1).toSeq.map(_.collect().toSeq)
83+
val secondRun = data.randomSplit(Array[Double](2, 3), seed = 1).toSeq.map(_.collect().toSeq)
8684
assert(firstRun == secondRun)
8785
}
8886

0 commit comments

Comments
 (0)