Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import org.apache.spark.sql.catalyst.analysis.{DecimalPrecision, FunctionRegistr
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.trees.UnaryLike
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._

@ExpressionDescription(
Expand All @@ -40,10 +39,15 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit

// Display name for this aggregate; falls back to "avg" when no FUNC_ALIAS tag was set.
override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("avg")

override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
// avg accepts any numeric type plus ANSI year-month and day-time interval types.
override def inputTypes: Seq[AbstractDataType] =
  Seq(TypeCollection(NumericType, YearMonthIntervalType, DayTimeIntervalType))

override def checkInputDataTypes(): TypeCheckResult =
TypeUtils.checkForNumericExpr(child.dataType, "function average")
/**
 * Validates the input type of avg: numeric types, the ANSI interval types,
 * and NullType are accepted; anything else is rejected with an error message
 * naming the offending type.
 */
override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
  case _: NumericType =>
    TypeCheckResult.TypeCheckSuccess
  case YearMonthIntervalType | DayTimeIntervalType | NullType =>
    TypeCheckResult.TypeCheckSuccess
  case other =>
    TypeCheckResult.TypeCheckFailure(
      s"function average requires numeric or interval types, not ${other.catalogString}")
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code is similar to 12abfe7#diff-310e24b112a4665955ce5db769e4afe57cf2cd07cf9c19ef8b8173d324285ecfR51-R56. Could you move it into a function in TypeUtils?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea.


// The result of avg may be null (e.g. when no non-null input rows exist), so always nullable.
override def nullable: Boolean = true

Expand All @@ -53,11 +57,15 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
// Result type of avg: decimals widen precision and scale by 4, the ANSI
// interval types map to themselves, and all other inputs produce a double.
private lazy val resultType = child.dataType match {
  case DecimalType.Fixed(precision, scale) => DecimalType.bounded(precision + 4, scale + 4)
  case _: YearMonthIntervalType => YearMonthIntervalType
  case _: DayTimeIntervalType => DayTimeIntervalType
  case _ => DoubleType
}

// Buffer type for the running sum: decimals widen precision by 10, the ANSI
// interval types keep their own type, and all other inputs are summed as double.
// (Dropped the redundant `_ @` pattern binder so this match reads the same way
// as the `resultType` match above.)
private lazy val sumDataType = child.dataType match {
  case DecimalType.Fixed(p, s) => DecimalType.bounded(p + 10, s)
  case _: YearMonthIntervalType => YearMonthIntervalType
  case _: DayTimeIntervalType => DayTimeIntervalType
  case _ => DoubleType
}

Expand All @@ -82,6 +90,8 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
case _: DecimalType =>
DecimalPrecision.decimalAndDecimal(
Divide(sum, count.cast(DecimalType.LongDecimal), failOnError = false)).cast(resultType)
case _: YearMonthIntervalType => DivideYMInterval(sum, count)
case _: DayTimeIntervalType => DivideDTInterval(sum, count)
case _ =>
Divide(sum.cast(resultType), count.cast(resultType), failOnError = false)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about a global aggregate with empty input? Should we return null or fail?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I fixed it with #32358.

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
assertError(Min(Symbol("mapField")), "min does not support ordering on type")
assertError(Max(Symbol("mapField")), "max does not support ordering on type")
assertError(Sum(Symbol("booleanField")), "function sum requires numeric or interval types")
assertError(Average(Symbol("booleanField")), "function average requires numeric type")
assertError(Average(Symbol("booleanField")),
"function average requires numeric or interval types")
}

test("check types for others") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1151,6 +1151,44 @@ class DataFrameAggregateSuite extends QueryTest
}
assert(error2.toString contains "java.lang.ArithmeticException: long overflow")
}

// SPARK-34837: `avg` over ANSI year-month (Period) and day-time (Duration) interval columns.
test("SPARK-34837: Support ANSI SQL intervals by the aggregate function `avg`") {
// Rows keyed by `class`; class 2 includes a null interval row to exercise null-skipping.
val df = Seq((1, Period.ofMonths(10), Duration.ofDays(10)),
(2, Period.ofMonths(1), Duration.ofDays(1)),
(2, null, null),
(3, Period.ofMonths(-3), Duration.ofDays(-6)),
(3, Period.ofMonths(21), Duration.ofDays(-5)))
.toDF("class", "year-month", "day-time")

// Values chosen so that summing the two rows overflows the interval's backing integer type.
val df2 = Seq((Period.ofMonths(Int.MaxValue), Duration.ofDays(106751991)),
(Period.ofMonths(10), Duration.ofDays(10)))
.toDF("year-month", "day-time")

// Global aggregate: result schema carries the interval types themselves.
val avgDF = df.select(avg($"year-month"), avg($"day-time"))
checkAnswer(avgDF, Row(Period.ofMonths(7), Duration.ofDays(0)))
assert(find(avgDF.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined)
assert(avgDF.schema == StructType(Seq(StructField("avg(year-month)", YearMonthIntervalType),
StructField("avg(day-time)", DayTimeIntervalType))))

// Grouped aggregate: the null interval row in class 2 is ignored in that group's average.
val avgDF2 = df.groupBy($"class").agg(avg($"year-month"), avg($"day-time"))
checkAnswer(avgDF2, Row(1, Period.ofMonths(10), Duration.ofDays(10)) ::
Row(2, Period.ofMonths(1), Duration.ofDays(1)) ::
Row(3, Period.ofMonths(9), Duration.ofDays(-5).plusHours(-12)) ::Nil)
assert(find(avgDF2.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined)
assert(avgDF2.schema == StructType(Seq(StructField("class", IntegerType, false),
StructField("avg(year-month)", YearMonthIntervalType),
StructField("avg(day-time)", DayTimeIntervalType))))

// Year-month overflow during the sum surfaces as an ArithmeticException inside SparkException.
val error = intercept[SparkException] {
checkAnswer(df2.select(avg($"year-month")), Nil)
}
assert(error.toString contains "java.lang.ArithmeticException: integer overflow")

// Day-time overflow likewise surfaces as a long-overflow ArithmeticException.
val error2 = intercept[SparkException] {
checkAnswer(df2.select(avg($"day-time")), Nil)
}
assert(error2.toString contains "java.lang.ArithmeticException: long overflow")
}
}

// Test helper: record with a single optional Double field (used by tests in this suite).
case class B(c: Option[Double])
Expand Down