|
18 | 18 | package org.apache.spark.mllib.util |
19 | 19 |
|
20 | 20 | import java.io.File |
| 21 | +import scala.math |
| 22 | +import scala.util.Random |
21 | 23 |
|
22 | 24 | import org.scalatest.FunSuite |
23 | 25 |
|
@@ -136,19 +138,30 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext { |
136 | 138 | new LinearRegressionModel(Array(1.0), 0) |
137 | 139 | } |
138 | 140 |
|
139 | | - test("kfoldRdd") { |
| 141 | + test("kFold") { |
140 | 142 | val data = sc.parallelize(1 to 100, 2) |
141 | 143 | val collectedData = data.collect().sorted |
142 | | - val twoFoldedRdd = MLUtils.kFoldRdds(data, 2, 1) |
| 144 | + val twoFoldedRdd = MLUtils.kFold(data, 2, 1) |
143 | 145 | assert(twoFoldedRdd(0)._1.collect().sorted === twoFoldedRdd(1)._2.collect().sorted) |
144 | 146 | assert(twoFoldedRdd(0)._2.collect().sorted === twoFoldedRdd(1)._1.collect().sorted) |
145 | 147 | for (folds <- 2 to 10) { |
146 | 148 | for (seed <- 1 to 5) { |
147 | | - val foldedRdds = MLUtils.kFoldRdds(data, folds, seed) |
| 149 | + val foldedRdds = MLUtils.kFold(data, folds, seed) |
148 | 150 | assert(foldedRdds.size === folds) |
149 | 151 | foldedRdds.map{case (test, train) => |
150 | 152 | val result = test.union(train).collect().sorted |
151 | | - assert(test.collect().size > 0, "Non empty test data") |
| 153 | + val testSize = test.collect().size.toFloat |
| 154 | + assert(testSize > 0, "Non empty test data") |
| 155 | + val p = 1 / folds.toFloat |
| 156 | + // Model the test-fold size as Binomial(100, p); it must lie within 3 standard deviations of the mean |
| 157 | + val range = 3 * math.sqrt(100 * p * (1-p)) |
| 158 | + val expected = 100 * p |
| 159 | + val lowerBound = expected - range |
| 160 | + val upperBound = expected + range |
| 161 | + assert(testSize > lowerBound, |
| 162 | + "Test data (" + testSize + ") smaller than expected (" + lowerBound +")" ) |
| 163 | + assert(testSize < upperBound, |
| 164 | + "Test data (" + testSize + ") larger than expected (" + upperBound +")" ) |
152 | 165 | assert(train.collect().size > 0, "Non empty training data") |
153 | 166 | assert(result === collectedData, |
154 | 167 | "Each training+test set combined contains all of the data") |
|
0 commit comments