@@ -26,11 +26,10 @@ import com.google.common.io.Files
import org.scalatest.FunSuite

import org.apache.spark.mllib.util.{MLStreamingUtils, LinearDataGenerator, LocalSparkContext}
- import org.apache.spark.SparkConf
- import org.apache.spark.streaming.{Milliseconds, Seconds, StreamingContext}
+ import org.apache.spark.streaming.{Milliseconds, StreamingContext}
import org.apache.spark.util.Utils

- class StreamingLinearRegressionSuite extends FunSuite {
+ class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext {

  // Assert that two values are equal within tolerance epsilon
  def assertEqual(v1: Double, v2: Double, epsilon: Double) {
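Side note on the new LocalSparkContext mixin: the suite now pulls a shared `sc` from the test trait instead of building a fresh SparkConf in every test. A minimal sketch of what such a ScalaTest mixin generally looks like follows (illustrative only, not necessarily the exact trait in org.apache.spark.mllib.util):

import org.apache.spark.SparkContext
import org.scalatest.{BeforeAndAfterAll, Suite}

// Illustrative sketch: one SparkContext shared by every test in the suite.
trait LocalSparkContext extends BeforeAndAfterAll { self: Suite =>
  @transient var sc: SparkContext = _

  override def beforeAll() {
    // created once, before any test runs
    sc = new SparkContext("local", "test")
    super.beforeAll()
  }

  override def afterAll() {
    if (sc != null) {
      sc.stop()
    }
    // free the driver port for whatever suite runs next
    System.clearProperty("spark.driver.port")
    super.afterAll()
  }
}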
@@ -51,10 +50,10 @@ class StreamingLinearRegressionSuite extends FunSuite {
  // Test if we can accurately learn Y = 10*X1 + 10*X2 on streaming data
  test("streaming linear regression parameter accuracy") {

-     val conf = new SparkConf().setMaster("local").setAppName("streaming test")
    val testDir = Files.createTempDir()
    val numBatches = 10
-     val ssc = new StreamingContext(conf, Seconds(1))
+     val batchDuration = Milliseconds(1000)
+     val ssc = new StreamingContext(sc, batchDuration)
    val data = MLStreamingUtils.loadLabeledPointsFromText(ssc, testDir.toString)
    val model = StreamingLinearRegressionWithSGD.start(numFeatures=2, numIterations=50)

@@ -63,16 +62,14 @@ class StreamingLinearRegressionSuite extends FunSuite {
    ssc.start()

    // write data to a file stream
-     Thread.sleep(5000)
    for (i <- 0 until numBatches) {
      val samples = LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), 100, 42 * (i + 1))
      val file = new File(testDir, i.toString)
      Files.write(samples.map(x => x.toString).mkString("\n"), file, Charset.forName("UTF-8"))
-       Thread.sleep(Milliseconds(1000).milliseconds)
+       Thread.sleep(batchDuration.milliseconds)
    }
-     Thread.sleep(Milliseconds(5000).milliseconds)

-     ssc.stop()
+     ssc.stop(stopSparkContext = false)

    System.clearProperty("spark.driver.port")
    Utils.deleteRecursively(testDir)
@@ -90,9 +87,9 @@ class StreamingLinearRegressionSuite extends FunSuite {
  // Test that parameter estimates improve when learning Y = 10*X1 on streaming data
  test("streaming linear regression parameter convergence") {

-     val conf = new SparkConf().setMaster("local").setAppName("streaming test")
    val testDir = Files.createTempDir()
-     val ssc = new StreamingContext(conf, Seconds(2))
+     val batchDuration = Milliseconds(2000)
+     val ssc = new StreamingContext(sc, batchDuration)
    val numBatches = 5
    val data = MLStreamingUtils.loadLabeledPointsFromText(ssc, testDir.toString)
    val model = StreamingLinearRegressionWithSGD.start(numFeatures=1, numIterations=50)
@@ -103,17 +100,17 @@ class StreamingLinearRegressionSuite extends FunSuite {

    // write data to a file stream
    val history = new ArrayBuffer[Double](numBatches)
-     Thread.sleep(5000)
    for (i <- 0 until numBatches) {
      val samples = LinearDataGenerator.generateLinearInput(0.0, Array(10.0), 100, 42 * (i + 1))
      val file = new File(testDir, i.toString)
      Files.write(samples.map(x => x.toString).mkString("\n"), file, Charset.forName("UTF-8"))
-       Thread.sleep(Milliseconds(6000).milliseconds)
+       Thread.sleep(batchDuration.milliseconds)
+       // wait an extra few seconds to make sure the update finishes before new data arrive
+       Thread.sleep(4000)
      history.append(math.abs(model.latest().weights(0) - 10.0))
    }
-     Thread.sleep(Milliseconds(5000).milliseconds)

-     ssc.stop()
+     ssc.stop(stopSparkContext = false)

    System.clearProperty("spark.driver.port")
    Utils.deleteRecursively(testDir)
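For context on the `ssc.stop(stopSparkContext = false)` change: each test now builds its StreamingContext on top of the shared `sc`, so the tests must tear down only the streaming side and leave the SparkContext running for the next test. A rough sketch of that reuse pattern (names like `ssc1`/`ssc2` are illustrative, not from this suite):

import org.apache.spark.SparkContext
import org.apache.spark.streaming.{Milliseconds, StreamingContext}

object StreamingContextReuseSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local", "reuse-sketch")

    // First streaming context, backed by the shared SparkContext.
    val ssc1 = new StreamingContext(sc, Milliseconds(1000))
    // ... define streams, start, and process batches here ...
    ssc1.stop(stopSparkContext = false)  // stop streaming only; sc stays alive

    // The same SparkContext can now back a second streaming context.
    val ssc2 = new StreamingContext(sc, Milliseconds(2000))
    // ...
    ssc2.stop(stopSparkContext = false)

    sc.stop()  // release the shared context once all streaming work is done
  }
}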