113 changes: 73 additions & 40 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -38,18 +38,18 @@ import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.catalog.CatalogRelation
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JSONOptions}
import org.apache.spark.sql.catalyst.optimizer.CombineUnions
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection}
import org.apache.spark.sql.catalyst.util.{usePrettyExpression, DateTimeUtils}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.command._
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.python.EvaluatePython
import org.apache.spark.sql.execution.stat.StatFunctions
import org.apache.spark.sql.streaming.DataStreamWriter
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
@@ -224,7 +224,7 @@ class Dataset[T] private[sql](
}
}

private def aggregatableColumns: Seq[Expression] = {
private[sql] def aggregatableColumns: Seq[Expression] = {
schema.fields
.filter(f => f.dataType.isInstanceOf[NumericType] || f.dataType.isInstanceOf[StringType])
.map { n =>
@@ -2185,9 +2185,9 @@ class Dataset[T] private[sql](
}

/**
* Computes statistics for numeric and string columns, including count, mean, stddev, min, and
* max. If no columns are given, this function computes statistics for all numerical or string
* columns.
* Computes basic statistics for numeric and string columns, including count, mean, stddev, min,
* and max. If no columns are given, this function computes statistics for all numerical or
* string columns.
*
* This function is meant for exploratory data analysis, as we make no guarantee about the
* backward compatibility of the schema of the resulting Dataset. If you want to
@@ -2205,46 +2205,79 @@
* // max 92.0 192.0
* }}}
*
* Use [[summary]] for expanded statistics and control over which statistics to compute.
*
* @param cols Columns to compute statistics on.
*
* @group action
* @since 1.6.0
*/
@scala.annotation.varargs
def describe(cols: String*): DataFrame = withPlan {

// The list of summary statistics to compute, in the form of expressions.
val statistics = List[(String, Expression => Expression)](
"count" -> ((child: Expression) => Count(child).toAggregateExpression()),
"mean" -> ((child: Expression) => Average(child).toAggregateExpression()),
"stddev" -> ((child: Expression) => StddevSamp(child).toAggregateExpression()),
"min" -> ((child: Expression) => Min(child).toAggregateExpression()),
"max" -> ((child: Expression) => Max(child).toAggregateExpression()))

val outputCols =
(if (cols.isEmpty) aggregatableColumns.map(usePrettyExpression(_).sql) else cols).toList

val ret: Seq[Row] = if (outputCols.nonEmpty) {
val aggExprs = statistics.flatMap { case (_, colToAgg) =>
outputCols.map(c => Column(Cast(colToAgg(Column(c).expr), StringType)).as(c))
}

val row = groupBy().agg(aggExprs.head, aggExprs.tail: _*).head().toSeq

// Pivot the data so each summary is one row
row.grouped(outputCols.size).toSeq.zip(statistics).map { case (aggregation, (statistic, _)) =>
Row(statistic :: aggregation.toList: _*)
}
} else {
// If there are no output columns, just output a single column that contains the stats.
statistics.map { case (name, _) => Row(name) }
}

// All columns are string type
val schema = StructType(
StructField("summary", StringType) :: outputCols.map(StructField(_, StringType))).toAttributes
// `toArray` forces materialization to make the seq serializable
LocalRelation.fromExternalRows(schema, ret.toArray.toSeq)
def describe(cols: String*): DataFrame = {
val selected = if (cols.isEmpty) this else select(cols.head, cols.tail: _*)
selected.summary("count", "mean", "stddev", "min", "max")
}
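// A hedged usage sketch, not part of the patch: assuming a Dataset `ds` with a numeric
// "age" column, describe now simply restricts the columns and delegates to summary, so
//   ds.describe("age")
// should produce the same rows as
//   ds.select("age").summary("count", "mean", "stddev", "min", "max")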

/**
* Computes specified statistics for numeric and string columns. Available statistics are:
*
* - count
* - mean
* - stddev
* - min
* - max
* - arbitrary approximate percentiles specified as a percentage (e.g. 75%)
*
* If no statistics are given, this function computes count, mean, stddev, min,
* approximate quartiles (percentiles at 25%, 50%, and 75%), and max.
*
* This function is meant for exploratory data analysis, as we make no guarantee about the
* backward compatibility of the schema of the resulting Dataset. If you want to
* programmatically compute summary statistics, use the `agg` function instead.
*
* {{{
* ds.summary().show()
*
* // output:
* // summary age height
* // count 10.0 10.0
* // mean 53.3 178.05
* // stddev 11.6 15.7
* // min 18.0 163.0
* // 25% 24.0 176.0
* // 50% 24.0 176.0
* // 75% 32.0 180.0
* // max 92.0 192.0
* }}}
*
* {{{
* ds.summary("count", "min", "25%", "75%", "max").show()
*
* // output:
* // summary age height
* // count 10.0 10.0
* // min 18.0 163.0
* // 25% 24.0 176.0
* // 75% 32.0 180.0
* // max 92.0 192.0
* }}}
*
* To do a summary for specific columns, first select them:
*
* {{{
* ds.select("age", "height").summary().show()
* }}}
*
* See also [[describe]] for basic statistics.
*
* @param statistics Statistics from the list above to be computed.
*
* @group action
* @since 2.3.0
*/
@scala.annotation.varargs
def summary(statistics: String*): DataFrame = StatFunctions.summary(this, statistics.toSeq)
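// A hedged usage sketch, assuming a Dataset `ds` with numeric columns: statistic names are
// matched against the list above, and any token ending in "%" is treated as an approximate
// percentile, e.g.
//   ds.summary("count", "50%", "99%").show()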

/**
* Returns the first `n` rows.
*
sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
@@ -19,9 +19,10 @@ package org.apache.spark.sql.execution.stat

import org.apache.spark.internal.Logging
import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
import org.apache.spark.sql.catalyst.expressions.{Cast, GenericInternalRow}
import org.apache.spark.sql.catalyst.expressions.{Cast, CreateArray, Expression, GenericInternalRow, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.catalyst.util.QuantileSummaries
import org.apache.spark.sql.catalyst.util.{usePrettyExpression, QuantileSummaries}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -220,4 +221,97 @@ object StatFunctions extends Logging {

Dataset.ofRows(df.sparkSession, LocalRelation(schema.toAttributes, table)).na.fill(0.0)
}

/** Calculate selected summary statistics for a dataset */
def summary(ds: Dataset[_], statistics: Seq[String]): DataFrame = {

val defaultStatistics = Seq("count", "mean", "stddev", "min", "25%", "50%", "75%", "max")
val selectedStatistics = if (statistics.nonEmpty) statistics else defaultStatistics

val hasPercentiles = selectedStatistics.exists(_.endsWith("%"))
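// Tokens ending in "%" (e.g. "75%") are split off from the named statistics and converted to
// fractions in [0, 1] for the approximate-percentile aggregate; a "%" token that does not
// parse as a number is rejected with an IllegalArgumentException.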
val (percentiles, percentileNames, remainingAggregates) = if (hasPercentiles) {
val (pStrings, rest) = selectedStatistics.partition(a => a.endsWith("%"))
val percentiles = pStrings.map { p =>
try {
p.stripSuffix("%").toDouble / 100.0
} catch {
case e: NumberFormatException =>
throw new IllegalArgumentException(s"Unable to parse $p as a percentile", e)
}
}
require(percentiles.forall(p => p >= 0 && p <= 1), "Percentiles must be in the range [0, 1]")
(percentiles, pStrings, rest)
} else {
(Seq(), Seq(), selectedStatistics)
}


// The list of summary statistics to compute, in the form of expressions.
val availableStatistics = Map[String, Expression => Expression](
"count" -> ((child: Expression) => Count(child).toAggregateExpression()),
"mean" -> ((child: Expression) => Average(child).toAggregateExpression()),
"stddev" -> ((child: Expression) => StddevSamp(child).toAggregateExpression()),
"min" -> ((child: Expression) => Min(child).toAggregateExpression()),
"max" -> ((child: Expression) => Max(child).toAggregateExpression()))

val statisticFns = remainingAggregates.map { agg =>
require(availableStatistics.contains(agg), s"$agg is not a recognised statistic")
agg -> availableStatistics(agg)
}
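// The percentileAgg helper below evaluates all requested percentiles in a single
// ApproximatePercentile pass; its result is an array of values in the same order as
// `percentiles`.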

def percentileAgg(child: Expression): Expression =
new ApproximatePercentile(child, CreateArray(percentiles.map(Literal(_))))
.toAggregateExpression()

val outputCols = ds.aggregatableColumns.map(usePrettyExpression(_).sql).toList

val ret: Seq[Row] = if (outputCols.nonEmpty) {
var aggExprs = statisticFns.toList.flatMap { case (_, colToAgg) =>
outputCols.map(c => Column(Cast(colToAgg(Column(c).expr), StringType)).as(c))
}
if (hasPercentiles) {
aggExprs = outputCols.map(c => Column(percentileAgg(Column(c).expr)).as(c)) ++ aggExprs
}

val row = ds.groupBy().agg(aggExprs.head, aggExprs.tail: _*).head().toSeq

// Pivot the data so each summary is one row
val grouped: Seq[Seq[Any]] = row.grouped(outputCols.size).toSeq
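// When percentiles were requested they were prepended to aggExprs, so the first chunk of
// `grouped` holds the per-column percentile arrays and each remaining chunk holds one basic
// statistic.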

val basicStats = if (hasPercentiles) grouped.tail else grouped

val rows = basicStats.zip(statisticFns).map { case (aggregation, (statistic, _)) =>
Row(statistic :: aggregation.toList: _*)
}
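// Each percentile aggregate returned one array per column; the transpose below turns that
// into one row per requested percentile, and the final sort restores the order in which the
// statistics were requested.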

if (hasPercentiles) {
def nullSafeString(x: Any) = if (x == null) null else x.toString
val percentileRows = grouped.head
.map {
case a: Seq[Any] => a
case _ => Seq.fill(percentiles.length)(null: Any)
}
.transpose
.zip(percentileNames)
.map { case (values: Seq[Any], name) =>
Row(name :: values.map(nullSafeString).toList: _*)
}
(rows ++ percentileRows)
.sortWith((left, right) =>
selectedStatistics.indexOf(left(0)) < selectedStatistics.indexOf(right(0)))
} else {
rows
}
} else {
// If there are no output columns, just output a single column that contains the stats.
selectedStatistics.map(Row(_))
}

// All columns are string type
val schema = StructType(
StructField("summary", StringType) :: outputCols.map(StructField(_, StringType))).toAttributes
// `toArray` forces materialization to make the seq serializable
Dataset.ofRows(ds.sparkSession, LocalRelation.fromExternalRows(schema, ret.toArray.toSeq))
}

}