Commit 017b73e

sameeragarwal authored and rxin committed
[SPARK-12662][SQL] Fix DataFrame.randomSplit to avoid creating overlapping splits
https://issues.apache.org/jira/browse/SPARK-12662

cc yhuai

Author: Sameer Agarwal <[email protected]>

Closes #10626 from sameeragarwal/randomsplit.

(cherry picked from commit f194d99)
Signed-off-by: Reynold Xin <[email protected]>
1 parent 69a885a commit 017b73e

2 files changed: 28 additions & 1 deletion
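
For context, DataFrame.randomSplit(weights, seed) returns one DataFrame per weight, and the weights are normalized internally, so they need not sum to 1. A minimal usage sketch (the df, train, and test names are illustrative, not part of this commit):

  val Array(train, test) = df.randomSplit(Array(2.0, 3.0), seed = 1)
  // train samples roughly 40% of the rows and test the remaining 60%.
  // Before this fix, the two splits could overlap if the underlying partitions
  // re-materialized their rows in a different order between the two jobs.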


sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala

Lines changed: 6 additions & 1 deletion
@@ -1107,10 +1107,15 @@ class DataFrame private[sql](
    * @since 1.4.0
    */
   def randomSplit(weights: Array[Double], seed: Long): Array[DataFrame] = {
+    // It is possible that the underlying dataframe doesn't guarantee the ordering of rows in its
+    // constituent partitions each time a split is materialized, which could result in
+    // overlapping splits. To prevent this, we explicitly sort each input partition to make the
+    // ordering deterministic.
+    val sorted = Sort(logicalPlan.output.map(SortOrder(_, Ascending)), global = false, logicalPlan)
     val sum = weights.sum
     val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
     normalizedCumWeights.sliding(2).map { x =>
-      new DataFrame(sqlContext, Sample(x(0), x(1), withReplacement = false, seed, logicalPlan))
+      new DataFrame(sqlContext, Sample(x(0), x(1), withReplacement = false, seed, sorted))
     }.toArray
   }
 
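
To make the sampling-boundary arithmetic above concrete, here is a small Spark-free sketch of how the normalized cumulative weights become disjoint half-open intervals (the variable names are illustrative; the Array(2.0, 3.0) weights match the new test below):

  val weights = Array(2.0, 3.0)
  val sum = weights.sum                                   // 5.0
  val bounds = weights.map(_ / sum).scanLeft(0.0)(_ + _)  // Array(0.0, 0.4, 1.0)
  // sliding(2) pairs adjacent bounds: (0.0, 0.4) and (0.4, 1.0). Each Sample
  // keeps the rows whose per-row random draw lands in its interval, so the
  // intervals tile [0, 1) without overlap, provided every split sees the rows
  // in the same order, which is exactly what the new Sort guarantees.
  val intervals = bounds.sliding(2).map { case Array(lo, hi) => (lo, hi) }.toList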

sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala

Lines changed: 22 additions & 0 deletions
@@ -62,6 +62,28 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
     }
   }
 
+  test("randomSplit on reordered partitions") {
+    // This test ensures that randomSplit does not create overlapping splits even when the
+    // underlying dataframe (such as the one below) doesn't guarantee a deterministic ordering of
+    // rows in each partition.
+    val data =
+      sparkContext.parallelize(1 to 600, 2).mapPartitions(scala.util.Random.shuffle(_)).toDF("id")
+    val splits = data.randomSplit(Array[Double](2, 3), seed = 1)
+
+    assert(splits.length == 2, "wrong number of splits")
+
+    // Verify that the splits span the entire dataset
+    assert(splits.flatMap(_.collect()).toSet == data.collect().toSet)
+
+    // Verify that the splits don't overlap
+    assert(splits(0).intersect(splits(1)).collect().isEmpty)
+
+    // Verify that the results are deterministic across multiple runs
+    val firstRun = splits.toSeq.map(_.collect().toSeq)
+    val secondRun = data.randomSplit(Array[Double](2, 3), seed = 1).toSeq.map(_.collect().toSeq)
+    assert(firstRun == secondRun)
+  }
+
   test("pearson correlation") {
     val df = Seq.tabulate(10)(i => (i, 2 * i, i * -1.0)).toDF("a", "b", "c")
     val corr1 = df.stat.corr("a", "b", "pearson")
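
Why sorting helps: each split is an independent Sample over the same child plan, and every materialization replays the same seeded random sequence against whatever row order the partition happens to produce that time. A simplified, Spark-free model of the failure mode (the sample helper below is illustrative, not Spark's actual Sample operator, which draws from a per-partition RNG, but the order sensitivity is the same):

  import scala.util.Random

  // Keep the k-th row iff the k-th draw of the seeded RNG falls in [lo, hi).
  def sample(rows: Seq[Int], lo: Double, hi: Double, seed: Long): Seq[Int] = {
    val rng = new Random(seed)
    rows.filter { _ => val x = rng.nextDouble(); x >= lo && x < hi }
  }

  val rows = (1 to 10).toList
  val reordered = Random.shuffle(rows) // a partition re-materializing in a new order
  val lower = sample(rows, 0.0, 0.4, seed = 1)
  val upper = sample(reordered, 0.4, 1.0, seed = 1)
  // lower and upper may overlap (and jointly miss rows): the draw sequence is
  // identical, but the k-th draw now lands on a different row.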
