Skip to content

Commit 119f6a0

Browse files
committed
[SPARK-22883][ML][TEST] Streaming tests for spark.ml.feature, from A to H
## What changes were proposed in this pull request? Adds structured streaming tests using testTransformer for these suites: * BinarizerSuite * BucketedRandomProjectionLSHSuite * BucketizerSuite * ChiSqSelectorSuite * CountVectorizerSuite * DCTSuite.scala * ElementwiseProductSuite * FeatureHasherSuite * HashingTFSuite ## How was this patch tested? It tests itself because it is a bunch of tests! Author: Joseph K. Bradley <joseph@databricks.com> Closes apache#20111 from jkbradley/SPARK-22883-streaming-featureAM.
1 parent 34811e0 commit 119f6a0

9 files changed

Lines changed: 126 additions & 101 deletions

mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,12 @@
1717

1818
package org.apache.spark.ml.feature
1919

20-
import org.apache.spark.SparkFunSuite
2120
import org.apache.spark.ml.linalg.{Vector, Vectors}
2221
import org.apache.spark.ml.param.ParamsSuite
23-
import org.apache.spark.ml.util.DefaultReadWriteTest
24-
import org.apache.spark.mllib.util.MLlibTestSparkContext
22+
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
2523
import org.apache.spark.sql.{DataFrame, Row}
2624

27-
class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
25+
class BinarizerSuite extends MLTest with DefaultReadWriteTest {
2826

2927
import testImplicits._
3028

@@ -47,7 +45,7 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
4745
.setInputCol("feature")
4846
.setOutputCol("binarized_feature")
4947

50-
binarizer.transform(dataFrame).select("binarized_feature", "expected").collect().foreach {
48+
testTransformer[(Double, Double)](dataFrame, binarizer, "binarized_feature", "expected") {
5149
case Row(x: Double, y: Double) =>
5250
assert(x === y, "The feature value is not correct after binarization.")
5351
}

mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,15 @@ package org.apache.spark.ml.feature
2020
import breeze.numerics.{cos, sin}
2121
import breeze.numerics.constants.Pi
2222

23-
import org.apache.spark.SparkFunSuite
2423
import org.apache.spark.ml.linalg.{Vector, Vectors}
2524
import org.apache.spark.ml.param.ParamsSuite
26-
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
25+
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
2726
import org.apache.spark.ml.util.TestingUtils._
28-
import org.apache.spark.mllib.util.MLlibTestSparkContext
29-
import org.apache.spark.sql.Dataset
27+
import org.apache.spark.sql.{Dataset, Row}
3028

31-
class BucketedRandomProjectionLSHSuite
32-
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
29+
class BucketedRandomProjectionLSHSuite extends MLTest with DefaultReadWriteTest {
30+
31+
import testImplicits._
3332

3433
@transient var dataset: Dataset[_] = _
3534

@@ -98,6 +97,21 @@ class BucketedRandomProjectionLSHSuite
9897
MLTestingUtils.checkCopyAndUids(brp, brpModel)
9998
}
10099

100+
test("BucketedRandomProjectionLSH: streaming transform") {
101+
val brp = new BucketedRandomProjectionLSH()
102+
.setNumHashTables(2)
103+
.setInputCol("keys")
104+
.setOutputCol("values")
105+
.setBucketLength(1.0)
106+
.setSeed(12345)
107+
val brpModel = brp.fit(dataset)
108+
109+
testTransformer[Tuple1[Vector]](dataset.toDF(), brpModel, "values") {
110+
case Row(values: Seq[_]) =>
111+
assert(values.length === brp.getNumHashTables)
112+
}
113+
}
114+
101115
test("BucketedRandomProjectionLSH: test of LSH property") {
102116
// Project from 2 dimensional Euclidean Space to 1 dimensions
103117
val brp = new BucketedRandomProjectionLSH()

mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,13 @@ import org.apache.spark.{SparkException, SparkFunSuite}
2323
import org.apache.spark.ml.Pipeline
2424
import org.apache.spark.ml.linalg.Vectors
2525
import org.apache.spark.ml.param.ParamsSuite
26-
import org.apache.spark.ml.util.DefaultReadWriteTest
26+
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
2727
import org.apache.spark.ml.util.TestingUtils._
28-
import org.apache.spark.mllib.util.MLlibTestSparkContext
2928
import org.apache.spark.sql.{DataFrame, Row}
3029
import org.apache.spark.sql.functions._
3130
import org.apache.spark.sql.types._
3231

33-
class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
32+
class BucketizerSuite extends MLTest with DefaultReadWriteTest {
3433

3534
import testImplicits._
3635

@@ -50,7 +49,7 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
5049
.setOutputCol("result")
5150
.setSplits(splits)
5251

53-
bucketizer.transform(dataFrame).select("result", "expected").collect().foreach {
52+
testTransformer[(Double, Double)](dataFrame, bucketizer, "result", "expected") {
5453
case Row(x: Double, y: Double) =>
5554
assert(x === y,
5655
s"The feature value is not correct after bucketing. Expected $y but found $x")
@@ -84,7 +83,7 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
8483
.setOutputCol("result")
8584
.setSplits(splits)
8685

87-
bucketizer.transform(dataFrame).select("result", "expected").collect().foreach {
86+
testTransformer[(Double, Double)](dataFrame, bucketizer, "result", "expected") {
8887
case Row(x: Double, y: Double) =>
8988
assert(x === y,
9089
s"The feature value is not correct after bucketing. Expected $y but found $x")
@@ -103,7 +102,7 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
103102
.setSplits(splits)
104103

105104
bucketizer.setHandleInvalid("keep")
106-
bucketizer.transform(dataFrame).select("result", "expected").collect().foreach {
105+
testTransformer[(Double, Double)](dataFrame, bucketizer, "result", "expected") {
107106
case Row(x: Double, y: Double) =>
108107
assert(x === y,
109108
s"The feature value is not correct after bucketing. Expected $y but found $x")

mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,15 @@
1717

1818
package org.apache.spark.ml.feature
1919

20-
import org.apache.spark.SparkFunSuite
2120
import org.apache.spark.ml.linalg.{Vector, Vectors}
2221
import org.apache.spark.ml.param.ParamsSuite
23-
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
22+
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
2423
import org.apache.spark.ml.util.TestingUtils._
25-
import org.apache.spark.mllib.util.MLlibTestSparkContext
2624
import org.apache.spark.sql.{Dataset, Row}
2725

28-
class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext
29-
with DefaultReadWriteTest {
26+
class ChiSqSelectorSuite extends MLTest with DefaultReadWriteTest {
27+
28+
import testImplicits._
3029

3130
@transient var dataset: Dataset[_] = _
3231

@@ -119,32 +118,32 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext
119118
test("Test Chi-Square selector: numTopFeatures") {
120119
val selector = new ChiSqSelector()
121120
.setOutputCol("filtered").setSelectorType("numTopFeatures").setNumTopFeatures(1)
122-
val model = ChiSqSelectorSuite.testSelector(selector, dataset)
121+
val model = testSelector(selector, dataset)
123122
MLTestingUtils.checkCopyAndUids(selector, model)
124123
}
125124

126125
test("Test Chi-Square selector: percentile") {
127126
val selector = new ChiSqSelector()
128127
.setOutputCol("filtered").setSelectorType("percentile").setPercentile(0.17)
129-
ChiSqSelectorSuite.testSelector(selector, dataset)
128+
testSelector(selector, dataset)
130129
}
131130

132131
test("Test Chi-Square selector: fpr") {
133132
val selector = new ChiSqSelector()
134133
.setOutputCol("filtered").setSelectorType("fpr").setFpr(0.02)
135-
ChiSqSelectorSuite.testSelector(selector, dataset)
134+
testSelector(selector, dataset)
136135
}
137136

138137
test("Test Chi-Square selector: fdr") {
139138
val selector = new ChiSqSelector()
140139
.setOutputCol("filtered").setSelectorType("fdr").setFdr(0.12)
141-
ChiSqSelectorSuite.testSelector(selector, dataset)
140+
testSelector(selector, dataset)
142141
}
143142

144143
test("Test Chi-Square selector: fwe") {
145144
val selector = new ChiSqSelector()
146145
.setOutputCol("filtered").setSelectorType("fwe").setFwe(0.12)
147-
ChiSqSelectorSuite.testSelector(selector, dataset)
146+
testSelector(selector, dataset)
148147
}
149148

150149
test("read/write") {
@@ -163,18 +162,19 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext
163162
assert(expected.selectedFeatures === actual.selectedFeatures)
164163
}
165164
}
166-
}
167165

168-
object ChiSqSelectorSuite {
169-
170-
private def testSelector(selector: ChiSqSelector, dataset: Dataset[_]): ChiSqSelectorModel = {
171-
val selectorModel = selector.fit(dataset)
172-
selectorModel.transform(dataset).select("filtered", "topFeature").collect()
173-
.foreach { case Row(vec1: Vector, vec2: Vector) =>
166+
private def testSelector(selector: ChiSqSelector, data: Dataset[_]): ChiSqSelectorModel = {
167+
val selectorModel = selector.fit(data)
168+
testTransformer[(Double, Vector, Vector)](data.toDF(), selectorModel,
169+
"filtered", "topFeature") {
170+
case Row(vec1: Vector, vec2: Vector) =>
174171
assert(vec1 ~== vec2 absTol 1e-1)
175-
}
172+
}
176173
selectorModel
177174
}
175+
}
176+
177+
object ChiSqSelectorSuite {
178178

179179
/**
180180
* Mapping from all Params to valid settings which differ from the defaults.

mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,13 @@
1616
*/
1717
package org.apache.spark.ml.feature
1818

19-
import org.apache.spark.SparkFunSuite
2019
import org.apache.spark.ml.linalg.{Vector, Vectors}
2120
import org.apache.spark.ml.param.ParamsSuite
22-
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
21+
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
2322
import org.apache.spark.ml.util.TestingUtils._
24-
import org.apache.spark.mllib.util.MLlibTestSparkContext
2523
import org.apache.spark.sql.Row
2624

27-
class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
28-
with DefaultReadWriteTest {
25+
class CountVectorizerSuite extends MLTest with DefaultReadWriteTest {
2926

3027
import testImplicits._
3128

@@ -50,7 +47,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
5047
val cv = new CountVectorizerModel(Array("a", "b", "c", "d"))
5148
.setInputCol("words")
5249
.setOutputCol("features")
53-
cv.transform(df).select("features", "expected").collect().foreach {
50+
testTransformer[(Int, Seq[String], Vector)](df, cv, "features", "expected") {
5451
case Row(features: Vector, expected: Vector) =>
5552
assert(features ~== expected absTol 1e-14)
5653
}
@@ -72,7 +69,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
7269
MLTestingUtils.checkCopyAndUids(cv, cvm)
7370
assert(cvm.vocabulary.toSet === Set("a", "b", "c", "d", "e"))
7471

75-
cvm.transform(df).select("features", "expected").collect().foreach {
72+
testTransformer[(Int, Seq[String], Vector)](df, cvm, "features", "expected") {
7673
case Row(features: Vector, expected: Vector) =>
7774
assert(features ~== expected absTol 1e-14)
7875
}
@@ -100,7 +97,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
10097
.fit(df)
10198
assert(cvModel2.vocabulary === Array("a", "b"))
10299

103-
cvModel2.transform(df).select("features", "expected").collect().foreach {
100+
testTransformer[(Int, Seq[String], Vector)](df, cvModel2, "features", "expected") {
104101
case Row(features: Vector, expected: Vector) =>
105102
assert(features ~== expected absTol 1e-14)
106103
}
@@ -113,7 +110,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
113110
.fit(df)
114111
assert(cvModel3.vocabulary === Array("a", "b"))
115112

116-
cvModel3.transform(df).select("features", "expected").collect().foreach {
113+
testTransformer[(Int, Seq[String], Vector)](df, cvModel3, "features", "expected") {
117114
case Row(features: Vector, expected: Vector) =>
118115
assert(features ~== expected absTol 1e-14)
119116
}
@@ -219,7 +216,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
219216
.setInputCol("words")
220217
.setOutputCol("features")
221218
.setMinTF(3)
222-
cv.transform(df).select("features", "expected").collect().foreach {
219+
testTransformer[(Int, Seq[String], Vector)](df, cv, "features", "expected") {
223220
case Row(features: Vector, expected: Vector) =>
224221
assert(features ~== expected absTol 1e-14)
225222
}
@@ -238,7 +235,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
238235
.setInputCol("words")
239236
.setOutputCol("features")
240237
.setMinTF(0.3)
241-
cv.transform(df).select("features", "expected").collect().foreach {
238+
testTransformer[(Int, Seq[String], Vector)](df, cv, "features", "expected") {
242239
case Row(features: Vector, expected: Vector) =>
243240
assert(features ~== expected absTol 1e-14)
244241
}
@@ -258,7 +255,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
258255
.setOutputCol("features")
259256
.setBinary(true)
260257
.fit(df)
261-
cv.transform(df).select("features", "expected").collect().foreach {
258+
testTransformer[(Int, Seq[String], Vector)](df, cv, "features", "expected") {
262259
case Row(features: Vector, expected: Vector) =>
263260
assert(features ~== expected absTol 1e-14)
264261
}
@@ -268,7 +265,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
268265
.setInputCol("words")
269266
.setOutputCol("features")
270267
.setBinary(true)
271-
cv2.transform(df).select("features", "expected").collect().foreach {
268+
testTransformer[(Int, Seq[String], Vector)](df, cv2, "features", "expected") {
272269
case Row(features: Vector, expected: Vector) =>
273270
assert(features ~== expected absTol 1e-14)
274271
}

mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,14 @@ import scala.beans.BeanInfo
2121

2222
import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D
2323

24-
import org.apache.spark.SparkFunSuite
2524
import org.apache.spark.ml.linalg.{Vector, Vectors}
26-
import org.apache.spark.ml.util.DefaultReadWriteTest
27-
import org.apache.spark.mllib.util.MLlibTestSparkContext
25+
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
2826
import org.apache.spark.sql.Row
2927

3028
@BeanInfo
3129
case class DCTTestData(vec: Vector, wantedVec: Vector)
3230

33-
class DCTSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
31+
class DCTSuite extends MLTest with DefaultReadWriteTest {
3432

3533
import testImplicits._
3634

@@ -72,11 +70,9 @@ class DCTSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
7270
.setOutputCol("resultVec")
7371
.setInverse(inverse)
7472

75-
transformer.transform(dataset)
76-
.select("resultVec", "wantedVec")
77-
.collect()
78-
.foreach { case Row(resultVec: Vector, wantedVec: Vector) =>
79-
assert(Vectors.sqdist(resultVec, wantedVec) < 1e-6)
73+
testTransformer[(Vector, Vector)](dataset, transformer, "resultVec", "wantedVec") {
74+
case Row(resultVec: Vector, wantedVec: Vector) =>
75+
assert(Vectors.sqdist(resultVec, wantedVec) < 1e-6)
8076
}
8177
}
8278
}

mllib/src/test/scala/org/apache/spark/ml/feature/ElementwiseProductSuite.scala

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,31 @@
1717

1818
package org.apache.spark.ml.feature
1919

20-
import org.apache.spark.SparkFunSuite
21-
import org.apache.spark.ml.linalg.Vectors
22-
import org.apache.spark.ml.util.DefaultReadWriteTest
23-
import org.apache.spark.mllib.util.MLlibTestSparkContext
20+
import org.apache.spark.ml.linalg.{Vector, Vectors}
21+
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
22+
import org.apache.spark.ml.util.TestingUtils._
23+
import org.apache.spark.sql.Row
2424

25-
class ElementwiseProductSuite
26-
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
25+
class ElementwiseProductSuite extends MLTest with DefaultReadWriteTest {
26+
27+
import testImplicits._
28+
29+
test("streaming transform") {
30+
val scalingVec = Vectors.dense(0.1, 10.0)
31+
val data = Seq(
32+
(Vectors.dense(0.1, 1.0), Vectors.dense(0.01, 10.0)),
33+
(Vectors.dense(0.0, -1.1), Vectors.dense(0.0, -11.0))
34+
)
35+
val df = spark.createDataFrame(data).toDF("features", "expected")
36+
val ep = new ElementwiseProduct()
37+
.setInputCol("features")
38+
.setOutputCol("actual")
39+
.setScalingVec(scalingVec)
40+
testTransformer[(Vector, Vector)](df, ep, "actual", "expected") {
41+
case Row(actual: Vector, expected: Vector) =>
42+
assert(actual ~== expected relTol 1e-14)
43+
}
44+
}
2745

2846
test("read/write") {
2947
val ep = new ElementwiseProduct()

0 commit comments

Comments (0)