diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index d28ff7888d12..4060b3e923f3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -38,18 +38,18 @@ import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.catalog.CatalogRelation
 import org.apache.spark.sql.catalyst.encoders._
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JSONOptions}
 import org.apache.spark.sql.catalyst.optimizer.CombineUnions
 import org.apache.spark.sql.catalyst.parser.ParseException
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection}
-import org.apache.spark.sql.catalyst.util.{usePrettyExpression, DateTimeUtils}
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.command._
 import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.execution.python.EvaluatePython
+import org.apache.spark.sql.execution.stat.StatFunctions
 import org.apache.spark.sql.streaming.DataStreamWriter
 import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
@@ -224,7 +224,7 @@ class Dataset[T] private[sql](
     }
   }
 
-  private def aggregatableColumns: Seq[Expression] = {
+  private[sql] def aggregatableColumns: Seq[Expression] = {
     schema.fields
       .filter(f => f.dataType.isInstanceOf[NumericType] || f.dataType.isInstanceOf[StringType])
       .map { n =>
@@ -2185,9 +2185,9 @@ class Dataset[T] private[sql](
   }
 
   /**
-   * Computes statistics for numeric and string columns, including count, mean, stddev, min, and
-   * max. If no columns are given, this function computes statistics for all numerical or string
-   * columns.
+   * Computes basic statistics for numeric and string columns, including count, mean, stddev, min,
+   * and max. If no columns are given, this function computes statistics for all numerical or
+   * string columns.
    *
    * This function is meant for exploratory data analysis, as we make no guarantee about the
    * backward compatibility of the schema of the resulting Dataset. If you want to
@@ -2205,46 +2205,79 @@ class Dataset[T] private[sql](
    *   // max     92.0   192.0
    * }}}
    *
+   * Use [[summary]] for expanded statistics and control over which statistics to compute.
+   *
+   * @param cols Columns to compute statistics on.
+   *
    * @group action
    * @since 1.6.0
    */
   @scala.annotation.varargs
-  def describe(cols: String*): DataFrame = withPlan {
-
-    // The list of summary statistics to compute, in the form of expressions.
-    val statistics = List[(String, Expression => Expression)](
-      "count" -> ((child: Expression) => Count(child).toAggregateExpression()),
-      "mean" -> ((child: Expression) => Average(child).toAggregateExpression()),
-      "stddev" -> ((child: Expression) => StddevSamp(child).toAggregateExpression()),
-      "min" -> ((child: Expression) => Min(child).toAggregateExpression()),
-      "max" -> ((child: Expression) => Max(child).toAggregateExpression()))
-
-    val outputCols =
-      (if (cols.isEmpty) aggregatableColumns.map(usePrettyExpression(_).sql) else cols).toList
-
-    val ret: Seq[Row] = if (outputCols.nonEmpty) {
-      val aggExprs = statistics.flatMap { case (_, colToAgg) =>
-        outputCols.map(c => Column(Cast(colToAgg(Column(c).expr), StringType)).as(c))
-      }
-
-      val row = groupBy().agg(aggExprs.head, aggExprs.tail: _*).head().toSeq
-
-      // Pivot the data so each summary is one row
-      row.grouped(outputCols.size).toSeq.zip(statistics).map { case (aggregation, (statistic, _)) =>
-        Row(statistic :: aggregation.toList: _*)
-      }
-    } else {
-      // If there are no output columns, just output a single column that contains the stats.
-      statistics.map { case (name, _) => Row(name) }
-    }
-
-    // All columns are string type
-    val schema = StructType(
-      StructField("summary", StringType) :: outputCols.map(StructField(_, StringType))).toAttributes
-    // `toArray` forces materialization to make the seq serializable
-    LocalRelation.fromExternalRows(schema, ret.toArray.toSeq)
+  def describe(cols: String*): DataFrame = {
+    val selected = if (cols.isEmpty) this else select(cols.head, cols.tail: _*)
+    selected.summary("count", "mean", "stddev", "min", "max")
   }
 
+  /**
+   * Computes specified statistics for numeric and string columns. Available statistics are:
+   *
+   * - count
+   * - mean
+   * - stddev
+   * - min
+   * - max
+   * - arbitrary approximate percentiles specified as a percentage (e.g. 75%)
+   *
+   * If no statistics are given, this function computes count, mean, stddev, min,
+   * approximate quartiles (percentiles at 25%, 50%, and 75%), and max.
+   *
+   * This function is meant for exploratory data analysis, as we make no guarantee about the
+   * backward compatibility of the schema of the resulting Dataset. If you want to
+   * programmatically compute summary statistics, use the `agg` function instead.
+   *
+   * {{{
+   *   ds.summary().show()
+   *
+   *   // output:
+   *   // summary age   height
+   *   // count   10.0  10.0
+   *   // mean    53.3  178.05
+   *   // stddev  11.6  15.7
+   *   // min     18.0  163.0
+   *   // 25%     24.0  176.0
+   *   // 50%     24.0  176.0
+   *   // 75%     32.0  180.0
+   *   // max     92.0  192.0
+   * }}}
+   *
+   * {{{
+   *   ds.summary("count", "min", "25%", "75%", "max").show()
+   *
+   *   // output:
+   *   // summary age   height
+   *   // count   10.0  10.0
+   *   // min     18.0  163.0
+   *   // 25%     24.0  176.0
+   *   // 75%     32.0  180.0
+   *   // max     92.0  192.0
+   * }}}
+   *
+   * To do a summary for specific columns, first select them:
+   *
+   * {{{
+   *   ds.select("age", "height").summary().show()
+   * }}}
+   *
+   * See also [[describe]] for basic statistics.
+   *
+   * @param statistics Statistics from the list above to be computed.
+   *
+   * @group action
+   * @since 2.3.0
+   */
+  @scala.annotation.varargs
+  def summary(statistics: String*): DataFrame = StatFunctions.summary(this, statistics.toSeq)
+
   /**
    * Returns the first `n` rows.
    *
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
index 1debad03c93f..436e18fdb5ff 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
@@ -19,9 +19,10 @@ package org.apache.spark.sql.execution.stat
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
-import org.apache.spark.sql.catalyst.expressions.{Cast, GenericInternalRow}
+import org.apache.spark.sql.catalyst.expressions.{Cast, CreateArray, Expression, GenericInternalRow, Literal}
+import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
-import org.apache.spark.sql.catalyst.util.QuantileSummaries
+import org.apache.spark.sql.catalyst.util.{usePrettyExpression, QuantileSummaries}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -220,4 +221,97 @@ object StatFunctions extends Logging {
 
     Dataset.ofRows(df.sparkSession, LocalRelation(schema.toAttributes, table)).na.fill(0.0)
   }
+
+  /** Calculate selected summary statistics for a dataset */
+  def summary(ds: Dataset[_], statistics: Seq[String]): DataFrame = {
+
+    val defaultStatistics = Seq("count", "mean", "stddev", "min", "25%", "50%", "75%", "max")
+    val selectedStatistics = if (statistics.nonEmpty) statistics else defaultStatistics
+
+    val hasPercentiles = selectedStatistics.exists(_.endsWith("%"))
+    val (percentiles, percentileNames, remainingAggregates) = if (hasPercentiles) {
+      val (pStrings, rest) = selectedStatistics.partition(a => a.endsWith("%"))
+      val percentiles = pStrings.map { p =>
+        try {
+          p.stripSuffix("%").toDouble / 100.0
+        } catch {
+          case e: NumberFormatException =>
+            throw new IllegalArgumentException(s"Unable to parse $p as a percentile", e)
+        }
+      }
+      require(percentiles.forall(p => p >= 0 && p <= 1), "Percentiles must be in the range [0, 1]")
+      (percentiles, pStrings, rest)
+    } else {
+      (Seq(), Seq(), selectedStatistics)
+    }
+
+    // The list of summary statistics to compute, in the form of expressions.
+    val availableStatistics = Map[String, Expression => Expression](
+      "count" -> ((child: Expression) => Count(child).toAggregateExpression()),
+      "mean" -> ((child: Expression) => Average(child).toAggregateExpression()),
+      "stddev" -> ((child: Expression) => StddevSamp(child).toAggregateExpression()),
+      "min" -> ((child: Expression) => Min(child).toAggregateExpression()),
+      "max" -> ((child: Expression) => Max(child).toAggregateExpression()))
+
+    val statisticFns = remainingAggregates.map { agg =>
+      require(availableStatistics.contains(agg), s"$agg is not a recognised statistic")
+      agg -> availableStatistics(agg)
+    }
+
+    def percentileAgg(child: Expression): Expression =
+      new ApproximatePercentile(child, CreateArray(percentiles.map(Literal(_))))
+        .toAggregateExpression()
+
+    val outputCols = ds.aggregatableColumns.map(usePrettyExpression(_).sql).toList
+
+    val ret: Seq[Row] = if (outputCols.nonEmpty) {
+      var aggExprs = statisticFns.toList.flatMap { case (_, colToAgg) =>
+        outputCols.map(c => Column(Cast(colToAgg(Column(c).expr), StringType)).as(c))
+      }
+      if (hasPercentiles) {
+        aggExprs = outputCols.map(c => Column(percentileAgg(Column(c).expr)).as(c)) ++ aggExprs
+      }
+
+      val row = ds.groupBy().agg(aggExprs.head, aggExprs.tail: _*).head().toSeq
+
+      // Pivot the data so each summary is one row
+      val grouped: Seq[Seq[Any]] = row.grouped(outputCols.size).toSeq
+
+      val basicStats = if (hasPercentiles) grouped.tail else grouped
+
+      val rows = basicStats.zip(statisticFns).map { case (aggregation, (statistic, _)) =>
+        Row(statistic :: aggregation.toList: _*)
+      }
+
+      if (hasPercentiles) {
+        def nullSafeString(x: Any) = if (x == null) null else x.toString
+        val percentileRows = grouped.head
+          .map {
+            case a: Seq[Any] => a
+            case _ => Seq.fill(percentiles.length)(null: Any)
+          }
+          .transpose
+          .zip(percentileNames)
+          .map { case (values: Seq[Any], name) =>
+            Row(name :: values.map(nullSafeString).toList: _*)
+          }
+        (rows ++ percentileRows)
+          .sortWith((left, right) =>
+            selectedStatistics.indexOf(left(0)) < selectedStatistics.indexOf(right(0)))
+      } else {
+        rows
+      }
+    } else {
+      // If there are no output columns, just output a single column that contains the stats.
+      selectedStatistics.map(Row(_))
+    }
+
+    // All columns are string type
+    val schema = StructType(
+      StructField("summary", StringType) :: outputCols.map(StructField(_, StringType))).toAttributes
+    // `toArray` forces materialization to make the seq serializable
+    Dataset.ofRows(ds.sparkSession, LocalRelation.fromExternalRows(schema, ret.toArray.toSeq))
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 9ea9951c24ef..2c7051bf431c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -28,8 +28,7 @@ import org.scalatest.Matchers._
 
 import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Project, Union}
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Union}
 import org.apache.spark.sql.execution.{FilterExec, QueryExecution}
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchange}
@@ -663,13 +662,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     assert(df.schema.map(_.name) === Seq("key", "valueRenamed", "newCol"))
   }
 
-  test("describe") {
-    val describeTestData = Seq(
-      ("Bob", 16, 176),
-      ("Alice", 32, 164),
-      ("David", 60, 192),
-      ("Amy", 24, 180)).toDF("name", "age", "height")
+  private lazy val person2: DataFrame = Seq(
+    ("Bob", 16, 176),
+    ("Alice", 32, 164),
+    ("David", 60, 192),
+    ("Amy", 24, 180)).toDF("name", "age", "height")
 
+  test("describe") {
     val describeResult = Seq(
       Row("count", "4", "4", "4"),
       Row("mean", null, "33.0", "178.0"),
@@ -686,32 +685,99 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
 
     def getSchemaAsSeq(df: DataFrame): Seq[String] = df.schema.map(_.name)
 
-    val describeTwoCols = describeTestData.describe("name", "age", "height")
-    assert(getSchemaAsSeq(describeTwoCols) === Seq("summary", "name", "age", "height"))
-    checkAnswer(describeTwoCols, describeResult)
-    // All aggregate value should have been cast to string
-    describeTwoCols.collect().foreach { row =>
-      assert(row.get(2).isInstanceOf[String], "expected string but found " + row.get(2).getClass)
-      assert(row.get(3).isInstanceOf[String], "expected string but found " + row.get(3).getClass)
-    }
-
-    val describeAllCols = describeTestData.describe()
+    val describeAllCols = person2.describe()
     assert(getSchemaAsSeq(describeAllCols) === Seq("summary", "name", "age", "height"))
     checkAnswer(describeAllCols, describeResult)
+    // All aggregate values should have been cast to string
+    describeAllCols.collect().foreach { row =>
+      row.toSeq.foreach { value =>
+        if (value != null) {
+          assert(value.isInstanceOf[String], "expected string but found " + value.getClass)
+        }
+      }
+    }
 
-    val describeOneCol = describeTestData.describe("age")
+    val describeOneCol = person2.describe("age")
     assert(getSchemaAsSeq(describeOneCol) === Seq("summary", "age"))
     checkAnswer(describeOneCol, describeResult.map { case Row(s, _, d, _) => Row(s, d)} )
 
-    val describeNoCol = describeTestData.select("name").describe()
-    assert(getSchemaAsSeq(describeNoCol) === Seq("summary", "name"))
-    checkAnswer(describeNoCol, describeResult.map { case Row(s, n, _, _) => Row(s, n)} )
+    val describeNoCol = person2.select().describe()
+    assert(getSchemaAsSeq(describeNoCol) === Seq("summary"))
+    checkAnswer(describeNoCol, describeResult.map { case Row(s, _, _, _) => Row(s)} )
 
-    val emptyDescription = describeTestData.limit(0).describe()
+    val emptyDescription = person2.limit(0).describe()
     assert(getSchemaAsSeq(emptyDescription) === Seq("summary", "name", "age", "height"))
     checkAnswer(emptyDescription, emptyDescribeResult)
   }
 
+  test("summary") {
+    val summaryResult = Seq(
+      Row("count", "4", "4", "4"),
+      Row("mean", null, "33.0", "178.0"),
+      Row("stddev", null, "19.148542155126762", "11.547005383792516"),
+      Row("min", "Alice", "16", "164"),
+      Row("25%", null, "24.0", "176.0"),
+      Row("50%", null, "24.0", "176.0"),
+      Row("75%", null, "32.0", "180.0"),
+      Row("max", "David", "60", "192"))
+
+    val emptySummaryResult = Seq(
+      Row("count", "0", "0", "0"),
+      Row("mean", null, null, null),
+      Row("stddev", null, null, null),
+      Row("min", null, null, null),
+      Row("25%", null, null, null),
+      Row("50%", null, null, null),
+      Row("75%", null, null, null),
+      Row("max", null, null, null))
+
+    def getSchemaAsSeq(df: DataFrame): Seq[String] = df.schema.map(_.name)
+
+    val summaryAllCols = person2.summary()
+
+    assert(getSchemaAsSeq(summaryAllCols) === Seq("summary", "name", "age", "height"))
+    checkAnswer(summaryAllCols, summaryResult)
+    // All aggregate values should have been cast to string
+    summaryAllCols.collect().foreach { row =>
+      row.toSeq.foreach { value =>
+        if (value != null) {
+          assert(value.isInstanceOf[String], "expected string but found " + value.getClass)
+        }
+      }
+    }
+
+    val summaryOneCol = person2.select("age").summary()
+    assert(getSchemaAsSeq(summaryOneCol) === Seq("summary", "age"))
+    checkAnswer(summaryOneCol, summaryResult.map { case Row(s, _, d, _) => Row(s, d)} )
+
+    val summaryNoCol = person2.select().summary()
+    assert(getSchemaAsSeq(summaryNoCol) === Seq("summary"))
+    checkAnswer(summaryNoCol, summaryResult.map { case Row(s, _, _, _) => Row(s)} )
+
+    val emptyDescription = person2.limit(0).summary()
+    assert(getSchemaAsSeq(emptyDescription) === Seq("summary", "name", "age", "height"))
+    checkAnswer(emptyDescription, emptySummaryResult)
+  }
+
+  test("summary advanced") {
+    val stats = Array("count", "50.01%", "max", "mean", "min", "25%")
+    val orderMatters = person2.summary(stats: _*)
+    assert(orderMatters.collect().map(_.getString(0)) === stats)
+
+    val onlyPercentiles = person2.summary("0.1%", "99.9%")
+    assert(onlyPercentiles.count() === 2)
+
+    val fooE = intercept[IllegalArgumentException] {
+      person2.summary("foo")
+    }
+    assert(fooE.getMessage === "requirement failed: foo is not a recognised statistic")
+
+    val parseE = intercept[IllegalArgumentException] {
+      person2.summary("foo%")
+    }
+    assert(parseE.getMessage === "Unable to parse foo% as a percentile")
+  }
+
   test("apply on query results (SPARK-5462)") {
     val df = testData.sparkSession.sql("select key from testData")
     checkAnswer(df.select(df("key")), testData.select('key).collect().toSeq)
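For reviewers who want to try the new API end to end, here is a minimal, self-contained sketch of how the patched `Dataset.summary` is used (assumes a build with this patch applied; the `SummaryExample` object, the data, and the column names are illustrative, not part of the patch):

```scala
import org.apache.spark.sql.SparkSession

object SummaryExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("summary-example")
      .master("local[*]")
      .getOrCreate()
    import spark.implicits._

    val people = Seq(
      ("Bob", 16, 176),
      ("Alice", 32, 164),
      ("David", 60, 192),
      ("Amy", 24, 180)).toDF("name", "age", "height")

    // No arguments: the default statistics, i.e.
    // count, mean, stddev, min, 25%, 50%, 75%, max.
    people.summary().show()

    // Only the requested statistics, reported in the requested order;
    // arbitrary percentiles such as "33%" are accepted.
    people.summary("count", "min", "33%", "max").show()

    // Restrict the summary to specific columns by selecting them first.
    people.select("age", "height").summary("mean", "stddev").show()

    spark.stop()
  }
}
```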
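The percentile rows come from the existing `ApproximatePercentile` aggregate that `StatFunctions.summary` wires in. One way to sanity-check the reported quartiles independently of this patch is through the `approx_percentile` SQL function that exposes the same aggregate (a sketch, reusing the `people` DataFrame from the example above):

```scala
// Cross-check summary()'s quartiles against approx_percentile,
// the SQL surface of the ApproximatePercentile aggregate.
people.selectExpr(
  "approx_percentile(age, array(0.25, 0.5, 0.75)) AS age_quartiles",
  "approx_percentile(height, array(0.25, 0.5, 0.75)) AS height_quartiles"
).show(truncate = false)
```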