Skip to content

Commit 9265436

Browse files
wangmiao1981 authored and jkbradley committed
[SPARK-19382][ML] Test sparse vectors in LinearSVCSuite
## What changes were proposed in this pull request? Add unit tests for testing SparseVector. We can't add mixed DenseVector and SparseVector test case, as discussed in JIRA 19382. def merge(other: MultivariateOnlineSummarizer): this.type = { if (this.totalWeightSum != 0.0 && other.totalWeightSum != 0.0) { require(n == other.n, s"Dimensions mismatch when merging with another summarizer. " + s"Expecting $n but got $ {other.n} .") ## How was this patch tested? Unit tests Author: wm624@hotmail.com <wm624@hotmail.com> Author: Miao Wang <wangmiao1981@users.noreply.github.com> Closes #16784 from wangmiao1981/bk.
1 parent 9991c2d commit 9265436

1 file changed

Lines changed: 22 additions & 2 deletions

File tree

mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala

Lines changed: 22 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -24,12 +24,13 @@ import breeze.linalg.{DenseVector => BDV}
2424
import org.apache.spark.SparkFunSuite
2525
import org.apache.spark.ml.classification.LinearSVCSuite._
2626
import org.apache.spark.ml.feature.{Instance, LabeledPoint}
27-
import org.apache.spark.ml.linalg.{Vector, Vectors}
27+
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
2828
import org.apache.spark.ml.param.ParamsSuite
2929
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
3030
import org.apache.spark.ml.util.TestingUtils._
3131
import org.apache.spark.mllib.util.MLlibTestSparkContext
3232
import org.apache.spark.sql.{Dataset, Row}
33+
import org.apache.spark.sql.functions.udf
3334

3435

3536
class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -41,6 +42,9 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
4142
@transient var smallValidationDataset: Dataset[_] = _
4243
@transient var binaryDataset: Dataset[_] = _
4344

45+
@transient var smallSparseBinaryDataset: Dataset[_] = _
46+
@transient var smallSparseValidationDataset: Dataset[_] = _
47+
4448
override def beforeAll(): Unit = {
4549
super.beforeAll()
4650

@@ -51,6 +55,13 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
5155
smallBinaryDataset = generateSVMInput(A, Array[Double](B, C), nPoints, 42).toDF()
5256
smallValidationDataset = generateSVMInput(A, Array[Double](B, C), nPoints, 17).toDF()
5357
binaryDataset = generateSVMInput(1.0, Array[Double](1.0, 2.0, 3.0, 4.0), 10000, 42).toDF()
58+
59+
// Dataset for testing SparseVector
60+
val toSparse: Vector => SparseVector = _.asInstanceOf[DenseVector].toSparse
61+
val sparse = udf(toSparse)
62+
smallSparseBinaryDataset = smallBinaryDataset.withColumn("features", sparse('features))
63+
smallSparseValidationDataset = smallValidationDataset.withColumn("features", sparse('features))
64+
5465
}
5566

5667
/**
@@ -68,13 +79,17 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
6879
val model = svm.fit(smallBinaryDataset)
6980
assert(model.transform(smallValidationDataset)
7081
.where("prediction=label").count() > nPoints * 0.8)
82+
val sparseModel = svm.fit(smallSparseBinaryDataset)
83+
checkModels(model, sparseModel)
7184
}
7285

7386
test("Linear SVC binary classification with regularization") {
7487
val svm = new LinearSVC()
7588
val model = svm.setRegParam(0.1).fit(smallBinaryDataset)
7689
assert(model.transform(smallValidationDataset)
7790
.where("prediction=label").count() > nPoints * 0.8)
91+
val sparseModel = svm.fit(smallSparseBinaryDataset)
92+
checkModels(model, sparseModel)
7893
}
7994

8095
test("params") {
@@ -235,7 +250,7 @@ object LinearSVCSuite {
235250
"aggregationDepth" -> 3
236251
)
237252

238-
// Generate noisy input of the form Y = signum(x.dot(weights) + intercept + noise)
253+
// Generate noisy input of the form Y = signum(x.dot(weights) + intercept + noise)
239254
def generateSVMInput(
240255
intercept: Double,
241256
weights: Array[Double],
@@ -252,5 +267,10 @@ object LinearSVCSuite {
252267
y.zip(x).map(p => LabeledPoint(p._1, Vectors.dense(p._2)))
253268
}
254269

270+
def checkModels(model1: LinearSVCModel, model2: LinearSVCModel): Unit = {
271+
assert(model1.intercept == model2.intercept)
272+
assert(model1.coefficients.equals(model2.coefficients))
273+
}
274+
255275
}
256276

0 commit comments

Comments
 (0)