Skip to content

Commit a16cbee

Browse files
committed
use 2D array for summary table
1 parent 1c1d3e6 commit a16cbee

3 files changed

Lines changed: 157 additions & 39 deletions

File tree

mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -115,17 +115,10 @@ private[r] object GeneralizedLinearRegressionWrapper
115115
}
116116

117117
val rCoefficients: Array[Double] = if (summary.isNormalSolver) {
118-
val rCoefficientStandardErrors =
119-
summary.summaryTable.select("StdError").collect.map(_.getDouble(0))
120-
121-
val rTValues =
122-
summary.summaryTable.select("TValue").collect.map(_.getDouble(0))
123-
124-
val rPValues =
125-
summary.summaryTable.select("PValue").collect.map(_.getDouble(0))
126-
127-
summary.summaryTable.select("Coefficient").collect.map(_.getDouble(0)) ++
128-
rCoefficientStandardErrors ++ rTValues ++ rPValues
118+
summary.coefficientMatrix.map(_._2) ++
119+
summary.coefficientMatrix.map(_._3) ++
120+
summary.coefficientMatrix.map(_._4) ++
121+
summary.coefficientMatrix.map(_._5)
129122
} else {
130123
if (glm.getFitIntercept) {
131124
Array(glm.intercept) ++ glm.coefficients.toArray

mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala

Lines changed: 148 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@ package org.apache.spark.ml.regression
2020
import java.util.Locale
2121

2222
import breeze.stats.{distributions => dist}
23+
24+
import org.apache.commons.lang3.StringUtils
25+
2326
import org.apache.hadoop.fs.Path
2427

2528
import org.apache.spark.SparkException
@@ -34,7 +37,7 @@ import org.apache.spark.ml.param._
3437
import org.apache.spark.ml.param.shared._
3538
import org.apache.spark.ml.util._
3639
import org.apache.spark.rdd.RDD
37-
import org.apache.spark.sql.{Column, DataFrame, Dataset, Row, SparkSession}
40+
import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
3841
import org.apache.spark.sql.functions._
3942
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
4043

@@ -1211,8 +1214,7 @@ class GeneralizedLinearRegressionSummary private[regression] (
12111214
* Name of features. If the name cannot be retrieved from attributes,
12121215
* set default names to feature column name with numbered suffix "_0", "_1", and so on.
12131216
*/
1214-
@Since("2.2.0")
1215-
lazy val featureNames: Array[String] = {
1217+
private[ml] lazy val featureNames: Array[String] = {
12161218
val featureAttrs = AttributeGroup.fromStructField(
12171219
dataset.schema(model.getFeaturesCol)).attributes
12181220
if (featureAttrs == None) {
@@ -1479,31 +1481,165 @@ class GeneralizedLinearRegressionTrainingSummary private[regression] (
14791481
}
14801482

14811483
/**
1482-
* Summary table with feature name, coefficient, standard error,
1484+
* Coefficient matrix with feature name, coefficient, standard error,
14831485
* tValue and pValue.
14841486
*/
1485-
@Since("2.2.0")
1486-
lazy val summaryTable: DataFrame = {
1487+
@Since("2.3.0")
1488+
lazy val coefficientMatrix: Array[(String, Double, Double, Double, Double)] = {
14871489
if (isNormalSolver) {
14881490
var featureNamesLocal = featureNames
14891491
var coefficients = model.coefficients.toArray
14901492
var idx = Array.range(0, coefficients.length)
14911493
if (model.getFitIntercept) {
1492-
featureNamesLocal = featureNamesLocal :+ Intercept
1494+
featureNamesLocal = featureNamesLocal :+ "(Intercept)"
14931495
coefficients = coefficients :+ model.intercept
14941496
// Reorder so that intercept comes first
14951497
idx = (coefficients.length - 1) +: idx
14961498
}
1497-
val result = for (i <- idx.toSeq) yield
1499+
val result = for (i <- idx) yield
14981500
(featureNamesLocal(i), coefficients(i), coefficientStandardErrors(i),
14991501
tValues(i), pValues(i))
1500-
1501-
val spark = SparkSession.builder().getOrCreate()
1502-
import spark.implicits._
1503-
result.toDF("Feature", "Coefficient", "StdError", "TValue", "PValue").repartition(1)
1502+
result
15041503
} else {
15051504
throw new UnsupportedOperationException(
15061505
"No summary table available for this GeneralizedLinearRegressionModel")
15071506
}
15081507
}
1508+
1509+
private def round(x: Double, digit: Int): String = {
1510+
BigDecimal(x).setScale(digit, BigDecimal.RoundingMode.HALF_UP).toString()
1511+
}
1512+
1513+
private[regression] def showString(_numRows: Int, truncate: Int = 20,
1514+
numDigits: Int = 3): String = {
1515+
val numRows = _numRows.max(1)
1516+
val data = coefficientMatrix.take(numRows)
1517+
val hasMoreData = coefficientMatrix.size > numRows
1518+
1519+
val colNames = Array("Feature", "Estimate", "StdError", "TValue", "PValue")
1520+
val numCols = colNames.size
1521+
1522+
val rows = colNames +: data.map( row => {
1523+
val mrow = for (cell <- row.productIterator) yield {
1524+
val str = cell match {
1525+
case s: String => s
1526+
case n: Double => round(n, numDigits).toString
1527+
}
1528+
if (truncate > 0 && str.length > truncate) {
1529+
// do not show ellipses for strings shorter than 4 characters.
1530+
if (truncate < 4) str.substring(0, truncate)
1531+
else str.substring(0, truncate - 3) + "..."
1532+
} else {
1533+
str
1534+
}
1535+
}
1536+
mrow.toArray
1537+
})
1538+
1539+
val sb = new StringBuilder
1540+
val colWidths = Array.fill(numCols)(3)
1541+
1542+
// Compute the width of each column
1543+
for (row <- rows) {
1544+
for ((cell, i) <- row.zipWithIndex) {
1545+
colWidths(i) = math.max(colWidths(i), cell.length)
1546+
}
1547+
}
1548+
1549+
// Create SeparateLine
1550+
val sep: String = colWidths.map("-" * _).addString(sb, "+", "+", "+\n").toString()
1551+
1552+
// column names
1553+
rows.head.zipWithIndex.map { case (cell, i) =>
1554+
if (truncate > 0) {
1555+
StringUtils.leftPad(cell, colWidths(i))
1556+
} else {
1557+
StringUtils.rightPad(cell, colWidths(i))
1558+
}
1559+
}.addString(sb, "|", "|", "|\n")
1560+
sb.append(sep)
1561+
1562+
// data
1563+
rows.tail.map {
1564+
_.zipWithIndex.map { case (cell, i) =>
1565+
if (truncate > 0) {
1566+
StringUtils.leftPad(cell.toString, colWidths(i))
1567+
} else {
1568+
StringUtils.rightPad(cell.toString, colWidths(i))
1569+
}
1570+
}.addString(sb, "|", "|", "|\n")
1571+
}
1572+
1573+
// For Data that has more than "numRows" records
1574+
if (hasMoreData) {
1575+
sb.append("...\n")
1576+
sb.append(sep)
1577+
val rowsString = if (numRows == 1) "row" else "rows"
1578+
sb.append(s"only showing top $numRows $rowsString\n")
1579+
} else {
1580+
sb.append(sep)
1581+
}
1582+
1583+
sb.append("\n")
1584+
sb.append(s"(Dispersion parameter for ${family.name} family taken to be " +
1585+
round(dispersion, numDigits) + ")")
1586+
1587+
sb.append("\n")
1588+
val nd = "Null deviance: " + round(nullDeviance, numDigits) +
1589+
s" on $degreesOfFreedom degrees of freedom"
1590+
val rd = "Residual deviance: " + round(deviance, numDigits) +
1591+
s" on $residualDegreeOfFreedom degrees of freedom"
1592+
val l = math.max(nd.length, rd.length)
1593+
sb.append(StringUtils.leftPad(nd, l))
1594+
sb.append("\n")
1595+
sb.append(StringUtils.leftPad(rd, l))
1596+
1597+
if (family.name != "tweedie") {
1598+
sb.append("\n")
1599+
sb.append(s"AIC: " + round(aic, numDigits))
1600+
}
1601+
1602+
sb.toString()
1603+
}
1604+
1605+
/**
1606+
* Displays the summary of a GeneralizedLinearModel fit.
1607+
*
1608+
* @since 2.3.0
1609+
*/
1610+
def show(): Unit = {
1611+
val numRows = coefficientMatrix.size
1612+
show(numRows, true, 3)
1613+
}
1614+
1615+
/**
1616+
* Displays the top numRows rows of the summary of a GeneralizedLinearModel fit.
1617+
*
1618+
* @param numRows Number of rows to show
1619+
*
1620+
* @since 2.3.0
1621+
*/
1622+
@Since("2.3.0")
1623+
def show(numRows: Int): Unit = {
1624+
show(numRows, true, 3)
1625+
}
1626+
1627+
/**
1628+
* Displays the summary of a GeneralizedLinearModel fit. Strings more than 20 characters
1629+
* will be truncated, and all cells will be aligned right.
1630+
*
1631+
* @param numRows Number of rows to show
1632+
* @param truncate Whether truncate long strings. If true, strings more than 20 characters will
1633+
* be truncated and all cells will be aligned right
1634+
* @param numDigits Number of decimal places used to round numerical values.
1635+
*
1636+
* @since 2.3.0
1637+
*/
1638+
// scalastyle:off println
1639+
def show(numRows: Int, truncate: Boolean, numDigits: Int): Unit = if (truncate) {
1640+
println(showString(numRows, truncate = 20, numDigits))
1641+
} else {
1642+
println(showString(numRows, truncate = 0, numDigits))
1643+
}
1644+
// scalastyle:on println
15091645
}

mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1524,7 +1524,6 @@ class GeneralizedLinearRegressionSuite
15241524
.fit(datasetGaussianIdentity.as[LabeledPoint])
15251525
}
15261526

1527-
15281527
test("glm summary: feature name") {
15291528
// dataset1 with no attribute
15301529
val dataset1 = Seq(
@@ -1557,7 +1556,7 @@ class GeneralizedLinearRegressionSuite
15571556
}
15581557
}
15591558

1560-
test("glm summary: summaryTable") {
1559+
test("glm summary: coefficient matrix") {
15611560
/*
15621561
R code:
15631562
@@ -1587,31 +1586,21 @@ class GeneralizedLinearRegressionSuite
15871586
Vectors.dense(0.7903, 0.2258, 0.4677))
15881587
val expectedStdError = Seq(Vectors.dense(1.724, 0.3787),
15891588
Vectors.dense(4.0129, 2.1153, 0.5815))
1590-
val expectedTValue = Seq(Vectors.dense(0.1673, 1.4205),
1591-
Vectors.dense(0.1969, 0.1067, 0.8043))
1592-
val expectedPValue = Seq(Vectors.dense(0.8778, 0.2506),
1593-
Vectors.dense(0.8621, 0.9247, 0.5056))
15941589

15951590
var idx = 0
15961591
for (fitIntercept <- Seq(false, true)) {
15971592
val trainer = new GeneralizedLinearRegression()
15981593
.setFamily("gaussian")
15991594
.setFitIntercept(fitIntercept)
16001595
val model = trainer.fit(dataset)
1601-
val summaryTable = model.summary.summaryTable
1596+
val coefficientMatrix = model.summary.coefficientMatrix
16021597

1603-
summaryTable.select("Feature").collect.map(_.getString(0))
1604-
.zip(expectedFeature(idx)).foreach{ x => assert(x._1 === x._2,
1598+
coefficientMatrix.map(_._1).zip(expectedFeature(idx)).foreach{ x => assert(x._1 === x._2,
16051599
"Feature name mismatch in summaryTable") }
1606-
assert(Vectors.dense(summaryTable.select("Coefficient").collect.map(_.getDouble(0)))
1600+
assert(Vectors.dense(coefficientMatrix.map(_._2))
16071601
~== expectedEstimate(idx) absTol 1E-3, "Coefficient mismatch in summaryTable")
1608-
assert(Vectors.dense(summaryTable.select("StdError").collect.map(_.getDouble(0)))
1602+
assert(Vectors.dense(coefficientMatrix.map(_._3))
16091603
~== expectedStdError(idx) absTol 1E-3, "Standard error mismatch in summaryTable")
1610-
assert(Vectors.dense(summaryTable.select("TValue").collect.map(_.getDouble(0)))
1611-
~== expectedTValue(idx) absTol 1E-3, "TValue mismatch in summaryTable")
1612-
assert(Vectors.dense(summaryTable.select("PValue").collect.map(_.getDouble(0)))
1613-
~== expectedPValue(idx) absTol 1E-3, "PValue mismatch in summaryTable")
1614-
16151604
idx += 1
16161605
}
16171606
}

0 commit comments

Comments
 (0)