@@ -296,8 +296,8 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
296296 assert(model.intercept ~== interceptR relTol 1E-2 )
297297 assert(model.weights(0 ) ~== weightsR(0 ) relTol 1E-3 )
298298 assert(model.weights(1 ) ~== weightsR(1 ) relTol 1E-3 )
299- assert(model.weights(2 ) ~== weightsR(2 ) relTol 1E-3 )
300- assert(model.weights(3 ) ~== weightsR(3 ) relTol 1E -2 )
299+ assert(model.weights(2 ) ~== weightsR(2 ) relTol 1E-2 )
300+ assert(model.weights(3 ) ~== weightsR(3 ) relTol 2E -2 )
301301 }
302302
303303 test(" binary logistic regression without intercept with L1 regularization" ) {
@@ -423,10 +423,10 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
423423 val interceptR = 0.57734851
424424 val weightsR = Array (- 0.05310287 , 0.0 , - 0.08849250 , - 0.15458796 )
425425
426- assert(model.intercept ~== interceptR relTol 1E-3 )
427- assert(model.weights(0 ) ~== weightsR(0 ) relTol 1E-3 )
426+ assert(model.intercept ~== interceptR relTol 1E-2 )
427+ assert(model.weights(0 ) ~== weightsR(0 ) relTol 1E-2 )
428428 assert(model.weights(1 ) ~== weightsR(1 ) relTol 1E-3 )
429- assert(model.weights(2 ) ~== weightsR(2 ) relTol 1E-3 )
429+ assert(model.weights(2 ) ~== weightsR(2 ) relTol 1E-2 )
430430 assert(model.weights(3 ) ~== weightsR(3 ) relTol 1E-3 )
431431 }
432432
@@ -462,4 +462,66 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
462462 assert(model.weights(2 ) ~== weightsR(2 ) relTol 1E-3 )
463463 assert(model.weights(3 ) ~== weightsR(3 ) relTol 1E-2 )
464464 }
465+
466+ test(" binary logistic regression with intercept with strong L1 regularization" ) {
467+ val trainer = (new LogisticRegression ).setFitIntercept(true )
468+ .setElasticNetParam(1.0 ).setRegParam(6.0 )
469+ val model = trainer.fit(binaryDataset)
470+
471+ val histogram = binaryDataset.map { case Row (label : Double , features : Vector ) => label }
472+ .treeAggregate(new MultiClassSummarizer )(
473+ seqOp = (c, v) => (c, v) match {
474+ case (classSummarizer : MultiClassSummarizer , label : Double ) => classSummarizer.add(label)
475+ },
476+ combOp = (c1, c2) => (c1, c2) match {
477+ case (classSummarizer1 : MultiClassSummarizer , classSummarizer2 : MultiClassSummarizer ) =>
478+ classSummarizer1.merge(classSummarizer2)
479+ }).histogram
480+
481+ /**
482+ * For binary logistic regression with strong L1 regularization, all the weights will be zeros.
483+ * As a result,
484+ * {{{
485+ * P(0) = 1 / (1 + \exp(b)), and
486+ * P(1) = \exp(b) / (1 + \exp(b))
487+ * }}}, hence
488+ * {{{
489+ * b = \log{P(1) / P(0)} = \log{count_1 / count_0}
490+ * }}}
491+ */
492+ val interceptTheory = Math .log(histogram(1 ).toDouble / histogram(0 ).toDouble)
493+ val weightsTheory = Array (0.0 , 0.0 , 0.0 , 0.0 )
494+
495+ assert(model.intercept ~== interceptTheory relTol 1E-3 )
496+ assert(model.weights(0 ) ~== weightsTheory(0 ) absTol 1E-6 )
497+ assert(model.weights(1 ) ~== weightsTheory(1 ) absTol 1E-6 )
498+ assert(model.weights(2 ) ~== weightsTheory(2 ) absTol 1E-6 )
499+ assert(model.weights(3 ) ~== weightsTheory(3 ) absTol 1E-6 )
500+
501+ /**
502+ * Using the following R code to load the data and train the model using glmnet package.
503+ *
504+ * > library("glmnet")
505+ * > data <- read.csv("path", header=FALSE)
506+ * > label = factor(data$V1)
507+ * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
508+ * > weights = coef(glmnet(features,label, family="binomial", alpha = 1.0, lambda = 6.0))
509+ * > weights
510+ * 5 x 1 sparse Matrix of class "dgCMatrix"
511+ * s0
512+ * (Intercept) -0.2480643
513+ * data.V2 0.0000000
514+ * data.V3 .
515+ * data.V4 .
516+ * data.V5 .
517+ */
518+ val interceptR = - 0.248065
519+ val weightsR = Array (0.0 , 0.0 , 0.0 , 0.0 )
520+
521+ assert(model.intercept ~== interceptR relTol 1E-3 )
522+ assert(model.weights(0 ) ~== weightsR(0 ) absTol 1E-6 )
523+ assert(model.weights(1 ) ~== weightsR(1 ) absTol 1E-6 )
524+ assert(model.weights(2 ) ~== weightsR(2 ) absTol 1E-6 )
525+ assert(model.weights(3 ) ~== weightsR(3 ) absTol 1E-6 )
526+ }
465527}
0 commit comments