Skip to content

Commit 4b0a5d3

Browse files
committed
Cleaned up tests
- Use LocalSparkContext from mllib.util
- Clarified timing parameters and removed unnecessary delays
1 parent 74188d6 commit 4b0a5d3

File tree

1 file changed

+12
-15
lines changed

1 file changed

+12
-15
lines changed

mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,10 @@ import com.google.common.io.Files
2626
import org.scalatest.FunSuite
2727

2828
import org.apache.spark.mllib.util.{MLStreamingUtils, LinearDataGenerator, LocalSparkContext}
29-
import org.apache.spark.SparkConf
30-
import org.apache.spark.streaming.{Milliseconds, Seconds, StreamingContext}
29+
import org.apache.spark.streaming.{Milliseconds, StreamingContext}
3130
import org.apache.spark.util.Utils
3231

33-
class StreamingLinearRegressionSuite extends FunSuite {
32+
class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext {
3433

3534
// Assert that two values are equal within tolerance epsilon
3635
def assertEqual(v1: Double, v2: Double, epsilon: Double) {
@@ -51,10 +50,10 @@ class StreamingLinearRegressionSuite extends FunSuite {
5150
// Test if we can accurately learn Y = 10*X1 + 10*X2 on streaming data
5251
test("streaming linear regression parameter accuracy") {
5352

54-
val conf = new SparkConf().setMaster("local").setAppName("streaming test")
5553
val testDir = Files.createTempDir()
5654
val numBatches = 10
57-
val ssc = new StreamingContext(conf, Seconds(1))
55+
val batchDuration = Milliseconds(1000)
56+
val ssc = new StreamingContext(sc, batchDuration)
5857
val data = MLStreamingUtils.loadLabeledPointsFromText(ssc, testDir.toString)
5958
val model = StreamingLinearRegressionWithSGD.start(numFeatures=2, numIterations=50)
6059

@@ -63,16 +62,14 @@ class StreamingLinearRegressionSuite extends FunSuite {
6362
ssc.start()
6463

6564
// write data to a file stream
66-
Thread.sleep(5000)
6765
for (i <- 0 until numBatches) {
6866
val samples = LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), 100, 42 * (i + 1))
6967
val file = new File(testDir, i.toString)
7068
Files.write(samples.map(x => x.toString).mkString("\n"), file, Charset.forName("UTF-8"))
71-
Thread.sleep(Milliseconds(1000).milliseconds)
69+
Thread.sleep(batchDuration.milliseconds)
7270
}
73-
Thread.sleep(Milliseconds(5000).milliseconds)
7471

75-
ssc.stop()
72+
ssc.stop(stopSparkContext=false)
7673

7774
System.clearProperty("spark.driver.port")
7875
Utils.deleteRecursively(testDir)
@@ -90,9 +87,9 @@ class StreamingLinearRegressionSuite extends FunSuite {
9087
// Test that parameter estimates improve when learning Y = 10*X1 on streaming data
9188
test("streaming linear regression parameter convergence") {
9289

93-
val conf = new SparkConf().setMaster("local").setAppName("streaming test")
9490
val testDir = Files.createTempDir()
95-
val ssc = new StreamingContext(conf, Seconds(2))
91+
val batchDuration = Milliseconds(2000)
92+
val ssc = new StreamingContext(sc, batchDuration)
9693
val numBatches = 5
9794
val data = MLStreamingUtils.loadLabeledPointsFromText(ssc, testDir.toString)
9895
val model = StreamingLinearRegressionWithSGD.start(numFeatures=1, numIterations=50)
@@ -103,17 +100,17 @@ class StreamingLinearRegressionSuite extends FunSuite {
103100

104101
// write data to a file stream
105102
val history = new ArrayBuffer[Double](numBatches)
106-
Thread.sleep(5000)
107103
for (i <- 0 until numBatches) {
108104
val samples = LinearDataGenerator.generateLinearInput(0.0, Array(10.0), 100, 42 * (i + 1))
109105
val file = new File(testDir, i.toString)
110106
Files.write(samples.map(x => x.toString).mkString("\n"), file, Charset.forName("UTF-8"))
111-
Thread.sleep(Milliseconds(6000).milliseconds)
107+
Thread.sleep(batchDuration.milliseconds)
108+
// wait an extra few seconds to make sure the update finishes before new data arrive
109+
Thread.sleep(4000)
112110
history.append(math.abs(model.latest().weights(0) - 10.0))
113111
}
114-
Thread.sleep(Milliseconds(5000).milliseconds)
115112

116-
ssc.stop()
113+
ssc.stop(stopSparkContext=false)
117114

118115
System.clearProperty("spark.driver.port")
119116
Utils.deleteRecursively(testDir)

0 commit comments

Comments (0)