-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-19270][ML] Add summary table to GLM summary #16630
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 16 commits
19b8de4
93139b9
0b50f34
af2dbea
e2ac2d4
eec31b4
602c3bd
6882be4
8405501
10f0f9b
3d72cf5
07a6784
1c1d3e6
a16cbee
640d564
57f1e5c
167af01
174fc49
be11106
adb3a74
7281b77
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,12 +20,14 @@ package org.apache.spark.ml.regression | |
| import java.util.Locale | ||
|
|
||
| import breeze.stats.{distributions => dist} | ||
| import org.apache.commons.lang3.StringUtils | ||
| import org.apache.hadoop.fs.Path | ||
|
|
||
| import org.apache.spark.SparkException | ||
| import org.apache.spark.annotation.{Experimental, Since} | ||
| import org.apache.spark.internal.Logging | ||
| import org.apache.spark.ml.PredictorParams | ||
| import org.apache.spark.ml.attribute.AttributeGroup | ||
| import org.apache.spark.ml.feature.{Instance, OffsetInstance} | ||
| import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors} | ||
| import org.apache.spark.ml.optim._ | ||
|
|
@@ -37,7 +39,6 @@ import org.apache.spark.sql.{Column, DataFrame, Dataset, Row} | |
| import org.apache.spark.sql.functions._ | ||
| import org.apache.spark.sql.types.{DataType, DoubleType, StructType} | ||
|
|
||
|
|
||
| /** | ||
| * Params for Generalized Linear Regression. | ||
| */ | ||
|
|
@@ -1204,6 +1205,22 @@ class GeneralizedLinearRegressionSummary private[regression] ( | |
| @Since("2.2.0") | ||
| lazy val numInstances: Long = predictions.count() | ||
|
|
||
|
|
||
| /** | ||
| * Name of features. If the name cannot be retrieved from attributes, | ||
| * set default names to feature column name with numbered suffix "_0", "_1", and so on. | ||
| */ | ||
| private[ml] lazy val featureNames: Array[String] = { | ||
| val featureAttrs = AttributeGroup.fromStructField( | ||
| dataset.schema(model.getFeaturesCol)).attributes | ||
| if (featureAttrs == None) { | ||
| Array.tabulate[String](origModel.numFeatures)( | ||
| (x: Int) => (model.getFeaturesCol + "_" + x)) | ||
| } else { | ||
| featureAttrs.get.map(_.name.get) | ||
| } | ||
| } | ||
|
|
||
| /** The numeric rank of the fitted linear model. */ | ||
| @Since("2.0.0") | ||
| lazy val rank: Long = if (model.getFitIntercept) { | ||
|
|
@@ -1458,4 +1475,167 @@ class GeneralizedLinearRegressionTrainingSummary private[regression] ( | |
| "No p-value available for this GeneralizedLinearRegressionModel") | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Coefficient matrix with feature name, coefficient, standard error, | ||
| * tValue and pValue. | ||
| */ | ||
| @Since("2.3.0") | ||
| lazy val coefficientMatrix: Array[(String, Double, Double, Double, Double)] = { | ||
|
||
| if (isNormalSolver) { | ||
| var featureNamesLocal = featureNames | ||
| var coefficients = model.coefficients.toArray | ||
| var idx = Array.range(0, coefficients.length) | ||
| if (model.getFitIntercept) { | ||
| featureNamesLocal = featureNamesLocal :+ "(Intercept)" | ||
| coefficients = coefficients :+ model.intercept | ||
| // Reorder so that intercept comes first | ||
| idx = (coefficients.length - 1) +: idx | ||
| } | ||
| val result = for (i <- idx) yield | ||
| (featureNamesLocal(i), coefficients(i), coefficientStandardErrors(i), | ||
| tValues(i), pValues(i)) | ||
| result | ||
| } else { | ||
| throw new UnsupportedOperationException( | ||
| "No summary table available for this GeneralizedLinearRegressionModel") | ||
| } | ||
| } | ||
|
|
||
| private def round(x: Double, digit: Int): String = { | ||
| BigDecimal(x).setScale(digit, BigDecimal.RoundingMode.HALF_UP).toString() | ||
| } | ||
|
|
||
| private[regression] def showString(_numRows: Int, truncate: Int = 20, | ||
| numDigits: Int = 3): String = { | ||
|
||
| val numRows = _numRows.max(1) | ||
| val data = coefficientMatrix.take(numRows) | ||
| val hasMoreData = coefficientMatrix.size > numRows | ||
|
|
||
| val colNames = Array("Feature", "Estimate", "StdError", "TValue", "PValue") | ||
| val numCols = colNames.size | ||
|
|
||
| val rows = colNames +: data.map( row => { | ||
| val mrow = for (cell <- row.productIterator) yield { | ||
| val str = cell match { | ||
| case s: String => s | ||
| case n: Double => round(n, numDigits).toString | ||
| } | ||
| if (truncate > 0 && str.length > truncate) { | ||
| // do not show ellipses for strings shorter than 4 characters. | ||
| if (truncate < 4) str.substring(0, truncate) | ||
| else str.substring(0, truncate - 3) + "..." | ||
| } else { | ||
| str | ||
| } | ||
| } | ||
| mrow.toArray | ||
| }) | ||
|
|
||
| val sb = new StringBuilder | ||
| val colWidths = Array.fill(numCols)(3) | ||
|
|
||
| // Compute the width of each column | ||
| for (row <- rows) { | ||
| for ((cell, i) <- row.zipWithIndex) { | ||
| colWidths(i) = math.max(colWidths(i), cell.length) | ||
| } | ||
| } | ||
|
|
||
| // Create SeparateLine | ||
| val sep: String = colWidths.map("-" * _).addString(sb, "+", "+", "+\n").toString() | ||
|
|
||
| // column names | ||
| rows.head.zipWithIndex.map { case (cell, i) => | ||
| if (truncate > 0) { | ||
| StringUtils.leftPad(cell, colWidths(i)) | ||
| } else { | ||
| StringUtils.rightPad(cell, colWidths(i)) | ||
| } | ||
| }.addString(sb, "|", "|", "|\n") | ||
| sb.append(sep) | ||
|
|
||
| // data | ||
| rows.tail.map { | ||
| _.zipWithIndex.map { case (cell, i) => | ||
| if (truncate > 0) { | ||
| StringUtils.leftPad(cell.toString, colWidths(i)) | ||
| } else { | ||
| StringUtils.rightPad(cell.toString, colWidths(i)) | ||
| } | ||
| }.addString(sb, "|", "|", "|\n") | ||
| } | ||
|
|
||
| // For Data that has more than "numRows" records | ||
| if (hasMoreData) { | ||
| sb.append("...\n") | ||
| sb.append(sep) | ||
| val rowsString = if (numRows == 1) "row" else "rows" | ||
| sb.append(s"only showing top $numRows $rowsString\n") | ||
| } else { | ||
| sb.append(sep) | ||
| } | ||
|
|
||
| sb.append("\n") | ||
| sb.append(s"(Dispersion parameter for ${family.name} family taken to be " + | ||
| round(dispersion, numDigits) + ")") | ||
|
|
||
| sb.append("\n") | ||
| val nd = "Null deviance: " + round(nullDeviance, numDigits) + | ||
| s" on $degreesOfFreedom degrees of freedom" | ||
| val rd = "Residual deviance: " + round(deviance, numDigits) + | ||
| s" on $residualDegreeOfFreedom degrees of freedom" | ||
| val l = math.max(nd.length, rd.length) | ||
| sb.append(StringUtils.leftPad(nd, l)) | ||
| sb.append("\n") | ||
| sb.append(StringUtils.leftPad(rd, l)) | ||
|
|
||
| if (family.name != "tweedie") { | ||
| sb.append("\n") | ||
| sb.append(s"AIC: " + round(aic, numDigits)) | ||
| } | ||
|
|
||
| sb.toString() | ||
| } | ||
|
|
||
| /** | ||
| * Displays the summary of a GeneralizedLinearModel fit. | ||
| * | ||
| * @since 2.3.0 | ||
| */ | ||
| def show(): Unit = { | ||
| val numRows = coefficientMatrix.size | ||
| show(numRows, true, 3) | ||
| } | ||
|
|
||
| /** | ||
| * Displays the top numRows rows of the summary of a GeneralizedLinearModel fit. | ||
| * | ||
| * @param numRows Number of rows to show | ||
| * | ||
| * @since 2.3.0 | ||
| */ | ||
| @Since("2.3.0") | ||
| def show(numRows: Int): Unit = { | ||
| show(numRows, true, 3) | ||
| } | ||
|
|
||
| /** | ||
| * Displays the summary of a GeneralizedLinearModel fit. Strings more than 20 characters | ||
| * will be truncated, and all cells will be aligned right. | ||
| * | ||
| * @param numRows Number of rows to show | ||
| * @param truncate Whether truncate long strings. If true, strings more than 20 characters will | ||
| * be truncated and all cells will be aligned right | ||
| * @param numDigits Number of decimal places used to round numerical values. | ||
| * | ||
| * @since 2.3.0 | ||
| */ | ||
| // scalastyle:off println | ||
| def show(numRows: Int, truncate: Boolean, numDigits: Int): Unit = if (truncate) { | ||
|
||
| println(showString(numRows, truncate = 20, numDigits)) | ||
| } else { | ||
| println(showString(numRows, truncate = 0, numDigits)) | ||
| } | ||
| // scalastyle:on println | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,7 +22,7 @@ import scala.util.Random | |
| import org.apache.spark.SparkFunSuite | ||
| import org.apache.spark.ml.classification.LogisticRegressionSuite._ | ||
| import org.apache.spark.ml.feature.{Instance, OffsetInstance} | ||
| import org.apache.spark.ml.feature.LabeledPoint | ||
| import org.apache.spark.ml.feature.{LabeledPoint, RFormula} | ||
| import org.apache.spark.ml.linalg.{BLAS, DenseVector, Vector, Vectors} | ||
| import org.apache.spark.ml.param.{ParamMap, ParamsSuite} | ||
| import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} | ||
|
|
@@ -1524,6 +1524,87 @@ class GeneralizedLinearRegressionSuite | |
| .fit(datasetGaussianIdentity.as[LabeledPoint]) | ||
| } | ||
|
|
||
| test("glm summary: feature name") { | ||
| // dataset1 with no attribute | ||
| val dataset1 = Seq( | ||
| Instance(2.0, 1.0, Vectors.dense(0.0, 5.0)), | ||
| Instance(8.0, 2.0, Vectors.dense(1.0, 7.0)), | ||
| Instance(3.0, 3.0, Vectors.dense(2.0, 11.0)), | ||
| Instance(9.0, 4.0, Vectors.dense(3.0, 13.0)), | ||
| Instance(2.0, 5.0, Vectors.dense(2.0, 3.0)) | ||
| ).toDF() | ||
|
|
||
| // dataset2 with attribute | ||
| val datasetTmp = Seq( | ||
| (2.0, 1.0, 0.0, 5.0), | ||
| (8.0, 2.0, 1.0, 7.0), | ||
| (3.0, 3.0, 2.0, 11.0), | ||
| (9.0, 4.0, 3.0, 13.0), | ||
| (2.0, 5.0, 2.0, 3.0) | ||
| ).toDF("y", "w", "x1", "x2") | ||
| val formula = new RFormula().setFormula("y ~ x1 + x2") | ||
| val dataset2 = formula.fit(datasetTmp).transform(datasetTmp) | ||
|
|
||
| val expectedFeature = Seq(Array("features_0", "features_1"), Array("x1", "x2")) | ||
|
|
||
| var idx = 0 | ||
| for (dataset <- Seq(dataset1, dataset2)) { | ||
| val model = new GeneralizedLinearRegression().fit(dataset) | ||
| model.summary.featureNames.zip(expectedFeature(idx)) | ||
| .foreach{ x => assert(x._1 === x._2) } | ||
| idx += 1 | ||
| } | ||
| } | ||
|
|
||
| test("glm summary: coefficient matrix") { | ||
| /* | ||
| R code: | ||
|
|
||
| A <- matrix(c(0, 1, 2, 3, 2, 5, 7, 11, 13, 3), 5, 2) | ||
| b <- c(2, 8, 3, 9, 2) | ||
| df <- as.data.frame(cbind(A, b)) | ||
| model <- glm(formula = "b ~ .", data = df) | ||
| summary(model) | ||
|
|
||
| Coefficients: | ||
| Estimate Std. Error t value Pr(>|t|) | ||
| (Intercept) 0.7903 4.0129 0.197 0.862 | ||
| V1 0.2258 2.1153 0.107 0.925 | ||
| V2 0.4677 0.5815 0.804 0.506 | ||
| */ | ||
| val dataset = Seq( | ||
| Instance(2.0, 1.0, Vectors.dense(0.0, 5.0)), | ||
| Instance(8.0, 2.0, Vectors.dense(1.0, 7.0)), | ||
| Instance(3.0, 3.0, Vectors.dense(2.0, 11.0)), | ||
| Instance(9.0, 4.0, Vectors.dense(3.0, 13.0)), | ||
| Instance(2.0, 5.0, Vectors.dense(2.0, 3.0)) | ||
| ).toDF() | ||
|
|
||
| val expectedFeature = Seq(Array("features_0", "features_1"), | ||
| Array("(Intercept)", "features_0", "features_1")) | ||
| val expectedEstimate = Seq(Vectors.dense(0.2884, 0.538), | ||
|
||
| Vectors.dense(0.7903, 0.2258, 0.4677)) | ||
| val expectedStdError = Seq(Vectors.dense(1.724, 0.3787), | ||
| Vectors.dense(4.0129, 2.1153, 0.5815)) | ||
|
|
||
| var idx = 0 | ||
| for (fitIntercept <- Seq(false, true)) { | ||
| val trainer = new GeneralizedLinearRegression() | ||
| .setFamily("gaussian") | ||
|
||
| .setFitIntercept(fitIntercept) | ||
| val model = trainer.fit(dataset) | ||
| val coefficientMatrix = model.summary.coefficientMatrix | ||
|
|
||
| coefficientMatrix.map(_._1).zip(expectedFeature(idx)).foreach{ x => assert(x._1 === x._2, | ||
| "Feature name mismatch in summaryTable") } | ||
| assert(Vectors.dense(coefficientMatrix.map(_._2)) | ||
| ~== expectedEstimate(idx) absTol 1E-3, "Coefficient mismatch in summaryTable") | ||
| assert(Vectors.dense(coefficientMatrix.map(_._3)) | ||
| ~== expectedStdError(idx) absTol 1E-3, "Standard error mismatch in summaryTable") | ||
| idx += 1 | ||
| } | ||
| } | ||
|
|
||
| test("generalized linear regression: regularization parameter") { | ||
| /* | ||
| R code: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
in general I would have preferred to create a platform-level function (or use one if it exists) to format the strings in the same way, so there is no duplicate code in VectorAssembler vs here that can diverge (and which other functions in spark can generally use). However, this seems a bit out of scope of this code review, so I don't think you need to do this.