From 0820c04bf26be840d0137b730e497ce4305938b1 Mon Sep 17 00:00:00 2001 From: Christoph Sawade Date: Mon, 15 Sep 2014 16:00:02 +0200 Subject: [PATCH] Use SquaredL2Updater in LogisticRegressionWithSGD SimpleUpdater ignores the regularizer, which leads to an unregularized LogReg. To enable the common L2 regularizer (and the corresponding regularization parameter) for logistic regression the SquaredL2Updater has to be used in SGD (see, e.g., [SVMWithSGD]) --- .../classification/LogisticRegression.scala | 2 +- .../LogisticRegressionSuite.scala | 44 +++++++++++++++++-- 2 files changed, 42 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index 486bdbfa9cb47..84d3c7cebd7c8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -84,7 +84,7 @@ class LogisticRegressionWithSGD private ( extends GeneralizedLinearAlgorithm[LogisticRegressionModel] with Serializable { private val gradient = new LogisticGradient() - private val updater = new SimpleUpdater() + private val updater = new SquaredL2Updater() override val optimizer = new GradientDescent(gradient, updater) .setStepSize(stepSize) .setNumIterations(numIterations) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala index 862178694a50e..e954baaf7d91e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -43,7 +43,7 @@ object LogisticRegressionSuite { offset: Double, scale: Double, nPoints: Int, - seed: Int): Seq[LabeledPoint] = { + seed: Int): Seq[LabeledPoint] = { val rnd = new Random(seed) val x1 = Array.fill[Double](nPoints)(rnd.nextGaussian()) @@ -58,12 +58,15 @@ object LogisticRegressionSuite { } class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Matchers { - def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { + def validatePrediction( + predictions: Seq[Double], + input: Seq[LabeledPoint], + expectedAcc: Double = 0.83) { val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => prediction != expected.label } // At least 83% of the predictions should be on. - ((input.length - numOffPredictions).toDouble / input.length) should be > 0.83 + ((input.length - numOffPredictions).toDouble / input.length) should be > expectedAcc } // Test if we can correctly learn A, B where Y = logistic(A + B*X) @@ -155,6 +158,41 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match validatePrediction(validationData.map(row => model.predict(row.features)), validationData) } + test("logistic regression with initial weights and non-default regularization parameter") { + val nPoints = 10000 + val A = 2.0 + val B = -1.5 + + val testData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 42) + + val initialB = -1.0 + val initialWeights = Vectors.dense(initialB) + + val testRDD = sc.parallelize(testData, 2) + testRDD.cache() + + // Use half as many iterations as the previous test. + val lr = new LogisticRegressionWithSGD().setIntercept(true) + lr.optimizer. + setStepSize(10.0). + setNumIterations(10). + setRegParam(1.0) + + val model = lr.run(testRDD, initialWeights) + + // Test the weights + assert(model.weights(0) ~== -430000.0 relTol 20000.0) + assert(model.intercept ~== 370000.0 relTol 20000.0) + + val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17) + val validationRDD = sc.parallelize(validationData, 2) + // Test prediction on RDD. + validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData, 0.8) + + // Test prediction on Array. + validatePrediction(validationData.map(row => model.predict(row.features)), validationData, 0.8) + } + test("logistic regression with initial weights with LBFGS") { val nPoints = 10000 val A = 2.0