@@ -24,12 +24,13 @@ import breeze.linalg.{DenseVector => BDV}
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.classification.LinearSVCSuite._
import org.apache.spark.ml.feature.{Instance, LabeledPoint}
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{Dataset, Row}
import org.apache.spark.sql.functions.udf


class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -41,6 +42,9 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@transient var smallValidationDataset: Dataset[_] = _
@transient var binaryDataset: Dataset[_] = _

@transient var smallSparseBinaryDataset: Dataset[_] = _
@transient var smallSparseValidationDataset: Dataset[_] = _

override def beforeAll(): Unit = {
super.beforeAll()

@@ -51,6 +55,13 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
smallBinaryDataset = generateSVMInput(A, Array[Double](B, C), nPoints, 42).toDF()
smallValidationDataset = generateSVMInput(A, Array[Double](B, C), nPoints, 17).toDF()
binaryDataset = generateSVMInput(1.0, Array[Double](1.0, 2.0, 3.0, 4.0), 10000, 42).toDF()

// Dataset for testing SparseVector
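// (The cast below assumes the features column holds DenseVector values, which is true for generateSVMInput output.)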
val toSparse: Vector => SparseVector = _.asInstanceOf[DenseVector].toSparse
val sparse = udf(toSparse)
smallSparseBinaryDataset = smallBinaryDataset.withColumn("features", sparse('features))
smallSparseValidationDataset = smallValidationDataset.withColumn("features", sparse('features))

}

/**
@@ -68,13 +79,17 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
val model = svm.fit(smallBinaryDataset)
assert(model.transform(smallValidationDataset)
.where("prediction=label").count() > nPoints * 0.8)
val sparseModel = svm.fit(smallSparseBinaryDataset)
checkModels(model, sparseModel)
}

test("Linear SVC binary classification with regularization") {
val svm = new LinearSVC()
val model = svm.setRegParam(0.1).fit(smallBinaryDataset)
assert(model.transform(smallValidationDataset)
.where("prediction=label").count() > nPoints * 0.8)
val sparseModel = svm.fit(smallSparseBinaryDataset)
checkModels(model, sparseModel)
}

test("params") {
@@ -235,7 +250,7 @@ object LinearSVCSuite {
"aggregationDepth" -> 3
)

// Generate noisy input of the form Y = signum(x.dot(weights) + intercept + noise)
def generateSVMInput(
Review comment (Member): This API is strange, where the caller expects numFeatures = weights.size, but really numFeatures = 10 * weights.size if isDense=false. Please update it to construct a random dense or sparse vector first (both of length weights.size) and then compute y to make the API more consistent. (A sketch of this suggestion follows the function body below.)

intercept: Double,
weights: Array[Double],
@@ -252,5 +267,10 @@ object LinearSVCSuite {
y.zip(x).map(p => LabeledPoint(p._1, Vectors.dense(p._2)))
}
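
  // A minimal sketch of the reviewer's suggestion above, not code from this PR:
  // construct the feature vector first (dense or sparse, both of length
  // weights.size) and only then compute the label, so the caller always gets
  // numFeatures == weights.size. The isDense flag, the helper name, and the
  // sparsity pattern are illustrative assumptions; BDV, Vector, Vectors and
  // LabeledPoint come from this file's existing imports.
  def generateSVMInputConsistentSketch(
      intercept: Double,
      weights: Array[Double],
      nPoints: Int,
      seed: Int,
      isDense: Boolean = true): Seq[LabeledPoint] = {
    val rnd = new scala.util.Random(seed)
    val weightsMat = new BDV(weights)
    Seq.tabulate(nPoints) { _ =>
      // Both branches produce a vector of exactly weights.size features.
      val x: Vector =
        if (isDense) {
          Vectors.dense(Array.fill(weights.length)(rnd.nextDouble() * 2.0 - 1.0))
        } else {
          val indices = (0 until weights.length).filter(_ => rnd.nextBoolean()).toArray
          Vectors.sparse(weights.length, indices,
            indices.map(_ => rnd.nextDouble() * 2.0 - 1.0))
        }
      // Derive the noisy label from the already-built feature vector.
      val margin = new BDV(x.toArray).dot(weightsMat) + intercept + 0.01 * rnd.nextGaussian()
      LabeledPoint(if (margin < 0) 0.0 else 1.0, x)
    }
  }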

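// Check that a model trained on dense features agrees with one trained on the same data encoded as sparse vectors: identical intercept and coefficients.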
def checkModels(model1: LinearSVCModel, model2: LinearSVCModel): Unit = {
assert(model1.intercept == model2.intercept)
assert(model1.coefficients.equals(model2.coefficients))
}

}