 
 package org.apache.spark.mllib.regression
 
-import java.io.File
-import java.nio.charset.Charset
-
 import scala.collection.mutable.ArrayBuffer
 
-import com.google.common.io.Files
 import org.scalatest.FunSuite
 
 import org.apache.spark.mllib.linalg.Vectors
-import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
-import org.apache.spark.streaming.{Milliseconds, StreamingContext}
-import org.apache.spark.util.Utils
+import org.apache.spark.mllib.util.LinearDataGenerator
+import org.apache.spark.streaming.dstream.DStream
+import org.apache.spark.streaming.TestSuiteBase
+
+class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase {
 
-class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext {
+  // use longer wait time to ensure job completion
+  override def maxWaitTimeMillis = 20000
 
   // Assert that two values are equal within tolerance epsilon
   def assertEqual(v1: Double, v2: Double, epsilon: Double) {
@@ -49,35 +48,26 @@ class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext {
   }
 
   // Test if we can accurately learn Y = 10*X1 + 10*X2 on streaming data
-  test("streaming linear regression parameter accuracy") {
+  test("parameter accuracy") {
 
-    val testDir = Files.createTempDir()
-    val numBatches = 10
-    val batchDuration = Milliseconds(1000)
-    val ssc = new StreamingContext(sc, batchDuration)
-    val data = ssc.textFileStream(testDir.toString).map(LabeledPoint.parse)
+    // create model
     val model = new StreamingLinearRegressionWithSGD()
       .setInitialWeights(Vectors.dense(0.0, 0.0))
       .setStepSize(0.1)
-      .setNumIterations(50)
+      .setNumIterations(25)
 
-    model.trainOn(data)
-
-    ssc.start()
-
-    // write data to a file stream
-    for (i <- 0 until numBatches) {
-      val samples = LinearDataGenerator.generateLinearInput(
-        0.0, Array(10.0, 10.0), 100, 42 * (i + 1))
-      val file = new File(testDir, i.toString)
-      Files.write(samples.map(x => x.toString).mkString("\n"), file, Charset.forName("UTF-8"))
-      Thread.sleep(batchDuration.milliseconds)
+    // generate sequence of simulated data
+    val numBatches = 10
+    val input = (0 until numBatches).map { i =>
+      LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), 100, 42 * (i + 1))
     }
 
-    ssc.stop(stopSparkContext=false)
-
-    System.clearProperty("spark.driver.port")
-    Utils.deleteRecursively(testDir)
+    // apply model training to input stream
+    val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
+      model.trainOn(inputDStream)
+      inputDStream.count()
+    })
+    runStreams(ssc, numBatches, numBatches)
 
     // check accuracy of final parameter estimates
     assertEqual(model.latestModel().intercept, 0.0, 0.1)
@@ -91,39 +81,33 @@ class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext {
   }
 
   // Test that parameter estimates improve when learning Y = 10*X1 on streaming data
-  test("streaming linear regression parameter convergence") {
+  test("parameter convergence") {
 
-    val testDir = Files.createTempDir()
-    val batchDuration = Milliseconds(2000)
-    val ssc = new StreamingContext(sc, batchDuration)
-    val numBatches = 5
-    val data = ssc.textFileStream(testDir.toString()).map(LabeledPoint.parse)
+    // create model
     val model = new StreamingLinearRegressionWithSGD()
       .setInitialWeights(Vectors.dense(0.0))
       .setStepSize(0.1)
-      .setNumIterations(50)
-
-    model.trainOn(data)
-
-    ssc.start()
+      .setNumIterations(25)
 
-    // write data to a file stream
-    val history = new ArrayBuffer[Double](numBatches)
-    for (i <- 0 until numBatches) {
-      val samples = LinearDataGenerator.generateLinearInput(0.0, Array(10.0), 100, 42 * (i + 1))
-      val file = new File(testDir, i.toString)
-      Files.write(samples.map(x => x.toString).mkString("\n"), file, Charset.forName("UTF-8"))
-      Thread.sleep(batchDuration.milliseconds)
-      // wait an extra few seconds to make sure the update finishes before new data arrive
-      Thread.sleep(4000)
-      history.append(math.abs(model.latestModel().weights(0) - 10.0))
+    // generate sequence of simulated data
+    val numBatches = 10
+    val input = (0 until numBatches).map { i =>
+      LinearDataGenerator.generateLinearInput(0.0, Array(10.0), 100, 42 * (i + 1))
     }
 
-    ssc.stop(stopSparkContext=false)
+    // create buffer to store intermediate fits
+    val history = new ArrayBuffer[Double](numBatches)
 
-    System.clearProperty("spark.driver.port")
-    Utils.deleteRecursively(testDir)
+    // apply model training to input stream, storing the intermediate results
+    // (we add a count to ensure the result is a DStream)
+    val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
+      model.trainOn(inputDStream)
+      inputDStream.foreachRDD(x => history.append(math.abs(model.latestModel().weights(0) - 10.0)))
+      inputDStream.count()
+    })
+    runStreams(ssc, numBatches, numBatches)
 
+    // compute change in error
     val deltas = history.drop(1).zip(history.dropRight(1))
     // check error stability (it always either shrinks, or increases with small tol)
     assert(deltas.forall(x => (x._1 - x._2) <= 0.1))
@@ -132,4 +116,33 @@ class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext {
 
   }
 
+  // Test predictions on a stream
+  test("predictions") {
+
+    // create model initialized with true weights
+    val model = new StreamingLinearRegressionWithSGD()
+      .setInitialWeights(Vectors.dense(10.0, 10.0))
+      .setStepSize(0.1)
+      .setNumIterations(25)
+
+    // generate sequence of simulated data for testing
+    val numBatches = 10
+    val nPoints = 100
+    val testInput = (0 until numBatches).map { i =>
+      LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), nPoints, 42 * (i + 1))
+    }
+
+    // apply model predictions to test stream
+    val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => {
+      model.predictOnValues(inputDStream.map(x => (x.label, x.features)))
+    })
+    // collect the output as (true, estimated) tuples
+    val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches)
+
+    // compute the mean absolute error and check that it's always less than 0.1
+    val errors = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints)
+    assert(errors.forall(x => x <= 0.1))
+
+  }
+
 }
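
For context, the refactor above replaces file-stream fixtures (temp directories, Thread.sleep timing) with Spark Streaming's TestSuiteBase helpers: setupStreams builds a StreamingContext whose input DStream replays each inner Seq of the input as one batch, and runStreams drives a manual clock through the requested number of batches and returns the collected output per batch. A minimal sketch of the same pattern on a toy stream (hypothetical suite and test names; assumes the spark-streaming test artifacts are on the test classpath):

    import org.scalatest.FunSuite

    import org.apache.spark.streaming.TestSuiteBase
    import org.apache.spark.streaming.dstream.DStream

    // hypothetical example suite, not part of the change above
    class DoublingStreamSuite extends FunSuite with TestSuiteBase {

      test("map over simulated batches") {
        // each inner Seq becomes one batch on the input DStream
        val input = Seq(Seq(1, 2), Seq(3, 4), Seq(5, 6))

        // run ds.map(_ * 2) over the simulated stream
        val ssc = setupStreams(input, (ds: DStream[Int]) => ds.map(_ * 2))

        // advance through 3 batches and collect 3 batches of output
        val output: Seq[Seq[Int]] = runStreams(ssc, 3, 3)
        assert(output === Seq(Seq(2, 4), Seq(6, 8), Seq(10, 12)))
      }
    }

Because the batch clock is manual, such tests are deterministic and need no sleeps, which is what lets the suite above drop its extra waits and temp-file cleanup.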