@@ -20,6 +20,9 @@ package org.apache.spark.ml.regression
2020import java .util .Locale
2121
2222import breeze .stats .{distributions => dist }
23+
24+ import org .apache .commons .lang3 .StringUtils
25+
2326import org .apache .hadoop .fs .Path
2427
2528import org .apache .spark .SparkException
@@ -34,7 +37,7 @@ import org.apache.spark.ml.param._
3437import org .apache .spark .ml .param .shared ._
3538import org .apache .spark .ml .util ._
3639import org .apache .spark .rdd .RDD
37- import org .apache .spark .sql .{Column , DataFrame , Dataset , Row , SparkSession }
40+ import org .apache .spark .sql .{Column , DataFrame , Dataset , Row }
3841import org .apache .spark .sql .functions ._
3942import org .apache .spark .sql .types .{DataType , DoubleType , StructType }
4043
@@ -1211,8 +1214,7 @@ class GeneralizedLinearRegressionSummary private[regression] (
12111214 * Name of features. If the name cannot be retrieved from attributes,
12121215 * set default names to feature column name with numbered suffix "_0", "_1", and so on.
12131216 */
1214- @ Since (" 2.2.0" )
1215- lazy val featureNames : Array [String ] = {
1217+ private [ml] lazy val featureNames : Array [String ] = {
12161218 val featureAttrs = AttributeGroup .fromStructField(
12171219 dataset.schema(model.getFeaturesCol)).attributes
12181220 if (featureAttrs == None ) {
@@ -1479,31 +1481,165 @@ class GeneralizedLinearRegressionTrainingSummary private[regression] (
14791481 }
14801482
/**
 * Coefficient matrix with feature name, coefficient, standard error,
 * tValue and pValue, one tuple per row.
 *
 * When the model was fit with an intercept, an "(Intercept)" row is added and
 * placed first, ahead of the feature rows.
 *
 * @throws UnsupportedOperationException if `isNormalSolver` is false.
 */
@Since("2.3.0")
lazy val coefficientMatrix: Array[(String, Double, Double, Double, Double)] = {
  if (isNormalSolver) {
    val featureCoefficients = model.coefficients.toArray
    val featureIdx = Array.range(0, featureCoefficients.length)
    val (names, coefficients, idx) =
      if (model.getFitIntercept) {
        // The intercept is appended as the last element of the names/coefficients arrays,
        // so its index is prepended here to make the intercept row come first.
        (featureNames :+ "(Intercept)",
          featureCoefficients :+ model.intercept,
          featureCoefficients.length +: featureIdx)
      } else {
        (featureNames, featureCoefficients, featureIdx)
      }
    idx.map { i =>
      (names(i), coefficients(i), coefficientStandardErrors(i), tValues(i), pValues(i))
    }
  } else {
    throw new UnsupportedOperationException(
      "No summary table available for this GeneralizedLinearRegressionModel")
  }
}
1508+
/**
 * Formats `x` as a decimal string rounded half-up to `digit` decimal places.
 *
 * NaN and infinite values cannot be converted to `BigDecimal` (the conversion throws
 * `NumberFormatException`), so they are rendered with `Double.toString` instead.
 *
 * @param x     value to format
 * @param digit number of decimal places to keep
 * @return the rounded value as a string, or "NaN" / "Infinity" / "-Infinity"
 */
private def round(x: Double, digit: Int): String = {
  if (x.isNaN || x.isInfinity) {
    x.toString
  } else {
    BigDecimal(x).setScale(digit, BigDecimal.RoundingMode.HALF_UP).toString()
  }
}
1512+
/**
 * Builds the printable summary: an ASCII table of the coefficient matrix followed by the
 * dispersion parameter, null/residual deviance and (except for the tweedie family) AIC.
 *
 * @param _numRows  number of coefficient rows to include; values below 1 are treated as 1
 * @param truncate  maximum cell width for data cells. If positive, longer cells are cut
 *                  (with a trailing "..." when the width is at least 4) and all cells are
 *                  right-aligned; otherwise cells are not cut and are left-aligned
 * @param numDigits number of decimal places used to round numerical values
 * @return the formatted summary
 */
private[regression] def showString(_numRows: Int, truncate: Int = 20,
    numDigits: Int = 3): String = {
  val numRows = _numRows.max(1)
  val allRows = coefficientMatrix
  val data = allRows.take(numRows)
  val hasMoreData = allRows.length > numRows

  val colNames = Array("Feature", "Estimate", "StdError", "TValue", "PValue")
  val numCols = colNames.length

  // Cuts a data cell down to `truncate` characters when truncation is enabled.
  def clip(str: String): String = {
    if (truncate > 0 && str.length > truncate) {
      // do not show ellipses for strings shorter than 4 characters.
      if (truncate < 4) str.substring(0, truncate)
      else str.substring(0, truncate - 3) + "..."
    } else {
      str
    }
  }

  // Right-aligns when truncating, left-aligns otherwise.
  def pad(cell: String, width: Int): String = {
    if (truncate > 0) StringUtils.leftPad(cell, width) else StringUtils.rightPad(cell, width)
  }

  // Header row first (never clipped), then one string row per coefficient tuple.
  val rows: Array[Array[String]] = colNames +: data.map { row =>
    row.productIterator.map {
      case s: String => clip(s)
      case n: Double => clip(round(n, numDigits))
    }.toArray
  }

  // Width of each column: the longest cell in it, with a minimum of 3.
  val colWidths = Array.tabulate(numCols) { i =>
    rows.foldLeft(3)((width, row) => math.max(width, row(i).length))
  }

  val sb = new StringBuilder

  // Writing the separator into the (still empty) builder both emits the top border and
  // yields the separator text that is reused below.
  val sep: String = colWidths.map("-" * _).addString(sb, "+", "+", "+\n").toString()

  // column names
  rows.head.zipWithIndex.map { case (cell, i) =>
    pad(cell, colWidths(i))
  }.addString(sb, "|", "|", "|\n")
  sb.append(sep)

  // data
  rows.tail.foreach { row =>
    row.zipWithIndex.map { case (cell, i) =>
      pad(cell, colWidths(i))
    }.addString(sb, "|", "|", "|\n")
  }

  // For Data that has more than "numRows" records
  if (hasMoreData) {
    sb.append("...\n")
    sb.append(sep)
    val rowsString = if (numRows == 1) "row" else "rows"
    sb.append(s"only showing top $numRows $rowsString\n")
  } else {
    sb.append(sep)
  }

  sb.append("\n")
  sb.append(s"(Dispersion parameter for ${family.name} family taken to be " +
    round(dispersion, numDigits) + ")")

  sb.append("\n")
  val nd = "Null deviance: " + round(nullDeviance, numDigits) +
    s" on $degreesOfFreedom degrees of freedom"
  val rd = "Residual deviance: " + round(deviance, numDigits) +
    s" on $residualDegreeOfFreedom degrees of freedom"
  // Pad both deviance lines to the same length so they line up on the right.
  val l = math.max(nd.length, rd.length)
  sb.append(StringUtils.leftPad(nd, l))
  sb.append("\n")
  sb.append(StringUtils.leftPad(rd, l))

  if (family.name != "tweedie") {
    sb.append("\n")
    sb.append("AIC: " + round(aic, numDigits))
  }

  sb.toString()
}
1604+
/**
 * Displays the summary of a GeneralizedLinearModel fit.
 *
 * Every row of the coefficient matrix is shown, with truncation enabled and
 * numerical values rounded to 3 decimal places.
 *
 * @since 2.3.0
 */
@Since("2.3.0")
def show(): Unit = {
  show(coefficientMatrix.length, truncate = true, numDigits = 3)
}
1614+
/**
 * Displays the top numRows rows of the summary of a GeneralizedLinearModel fit,
 * with truncation enabled and numerical values rounded to 3 decimal places.
 *
 * @param numRows Number of rows to show
 *
 * @since 2.3.0
 */
@Since("2.3.0")
def show(numRows: Int): Unit = {
  show(numRows, truncate = true, numDigits = 3)
}
1626+
/**
 * Displays the summary of a GeneralizedLinearModel fit by printing it to standard output.
 *
 * @param numRows Number of rows to show
 * @param truncate Whether to truncate long strings. If true, strings more than 20 characters
 *                 will be truncated and all cells will be aligned right; if false, nothing
 *                 is truncated and cells are aligned left
 * @param numDigits Number of decimal places used to round numerical values.
 *
 * @since 2.3.0
 */
@Since("2.3.0")
// scalastyle:off println
def show(numRows: Int, truncate: Boolean, numDigits: Int): Unit = {
  // showString takes a maximum cell width; 0 disables truncation.
  val maxCellWidth = if (truncate) 20 else 0
  println(showString(numRows, truncate = maxCellWidth, numDigits = numDigits))
}
// scalastyle:on println
15091645}
0 commit comments