-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-27986][SQL] Support ANSI SQL filter clause for aggregate expression #26656
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 51 commits
17b76e2
3f0583f
f64d14c
d521be1
8e342da
0e56d03
fd6461f
f32ac4d
5d33dab
9ea4736
4dcd0d3
060d3d4
4443883
8beff8a
4d0c3aa
14f2b21
895f6ac
675dca9
1c1cf52
fb8f477
b677268
b831855
4c644ca
255650a
518aa4f
d979509
6082e57
ed80517
3d37370
392c18d
9a127e4
967b135
07f774a
c86b691
81c9482
747b3ab
f66c180
3652aef
8bfff6f
0911a76
61bf6fd
ea472aa
583d51f
030a9dc
df643ba
14daee6
ce51461
ce53930
cb31eea
f154622
4d1413f
0d20561
bc2ad92
1297e03
f56400a
eb856df
cffe318
4523616
1cb0725
33d2b5b
c3e0f6a
6c878d3
40e31be
affb6c0
de11c4d
7c40292
258a6c6
4a494ae
8cdd92d
46c4980
d3f38f2
d40dd9f
2518692
94a4a06
9adfd2d
0a4a5a2
66ceeca
3a350cb
01f306e
87697ec
4dd7527
53a6f2a
b29ef0f
c15389d
0cabcb6
b58c126
a3bb997
e520938
a2a79d5
ff1147f
b3584c8
e03a959
998929c
a3d71f4
91fe90e
eb96463
97bb440
bb14439
c1dbb27
000ae72
27ad46e
436b1c0
d5aa8ee
5b2a1b9
6f9e839
d98ea41
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1283,8 +1283,8 @@ class Analyzer( | |
| */ | ||
| def expandStarExpression(expr: Expression, child: LogicalPlan): Expression = { | ||
| expr.transformUp { | ||
| case f1: UnresolvedFunction if containsStar(f1.children) => | ||
| f1.copy(children = f1.children.flatMap { | ||
| case f1: UnresolvedFunction if containsStar(f1.arguments) => | ||
| f1.copy(arguments = f1.arguments.flatMap { | ||
| case s: Star => s.expand(child, resolver) | ||
| case o => o :: Nil | ||
| }) | ||
|
|
@@ -1636,26 +1636,33 @@ class Analyzer( | |
| s"its class is ${other.getClass.getCanonicalName}, which is not a generator.") | ||
| } | ||
| } | ||
| case u @ UnresolvedFunction(funcId, children, isDistinct) => | ||
| case u @ UnresolvedFunction(funcId, arguments, isDistinct, filter) => | ||
| withPosition(u) { | ||
| v1SessionCatalog.lookupFunction(funcId, children) match { | ||
| v1SessionCatalog.lookupFunction(funcId, arguments) match { | ||
| // AggregateWindowFunctions are AggregateFunctions that can only be evaluated within | ||
| // the context of a Window clause. They do not need to be wrapped in an | ||
| // AggregateExpression. | ||
| case wf: AggregateWindowFunction => | ||
| if (isDistinct) { | ||
cloud-fan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| failAnalysis( | ||
| s"DISTINCT specified, but ${wf.prettyName} is not an aggregate function") | ||
| } else if (filter.isDefined) { | ||
|
||
| failAnalysis("FILTER predicate specified, " + | ||
| s"but ${wf.prettyName} is not an aggregate function") | ||
| } else { | ||
| wf | ||
| } | ||
| // We get an aggregate function, we need to wrap it in an AggregateExpression. | ||
| case agg: AggregateFunction => AggregateExpression(agg, Complete, isDistinct) | ||
| case agg: AggregateFunction => | ||
| AggregateExpression(agg, Complete, isDistinct, filter) | ||
| // This function is not an aggregate function, just return the resolved one. | ||
| case other => | ||
| if (isDistinct) { | ||
| failAnalysis( | ||
| s"DISTINCT specified, but ${other.prettyName} is not an aggregate function") | ||
| } else if (filter.isDefined) { | ||
|
||
| failAnalysis("FILTER predicate specified, " + | ||
| s"but ${other.prettyName} is not an aggregate function") | ||
| } else { | ||
| other | ||
| } | ||
|
|
@@ -2253,15 +2260,15 @@ class Analyzer( | |
|
|
||
| // Extract Windowed AggregateExpression | ||
| case we @ WindowExpression( | ||
| ae @ AggregateExpression(function, _, _, _), | ||
| ae @ AggregateExpression(function, _, _, _, _), | ||
| spec: WindowSpecDefinition) => | ||
| val newChildren = function.children.map(extractExpr) | ||
| val newFunction = function.withNewChildren(newChildren).asInstanceOf[AggregateFunction] | ||
| val newAgg = ae.copy(aggregateFunction = newFunction) | ||
| seenWindowAggregates += newAgg | ||
| WindowExpression(newAgg, spec) | ||
|
|
||
| case AggregateExpression(aggFunc, _, _, _) if hasWindowFunction(aggFunc.children) => | ||
| case AggregateExpression(aggFunc, _, _, _, _) if hasWindowFunction(aggFunc.children) => | ||
| failAnalysis("It is not allowed to use a window function inside an aggregate " + | ||
| "function. Please use the inner window function in a sub-query.") | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -33,11 +33,14 @@ import org.apache.spark.sql.types.DataType | |
| case class ResolveHigherOrderFunctions(catalog: SessionCatalog) extends Rule[LogicalPlan] { | ||
|
|
||
| override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressions { | ||
| case u @ UnresolvedFunction(fn, children, false) | ||
| case u @ UnresolvedFunction(fn, children, false, filter) | ||
| if hasLambdaAndResolvedArguments(children) => | ||
| withPosition(u) { | ||
| catalog.lookupFunction(fn, children) match { | ||
| case func: HigherOrderFunction => func | ||
| case func: HigherOrderFunction => | ||
| filter.foreach(_.failAnalysis("FILTER predicate specified, " + | ||
| s"but ${func.prettyName} is not an aggregate function")) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can you add tests for this path?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. OK |
||
| func | ||
| case other => other.failAnalysis( | ||
| "A lambda function should only be used in a higher order function. However, " + | ||
| s"its class is ${other.getClass.getCanonicalName}, which is not a " + | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -243,10 +243,16 @@ case class UnresolvedGenerator(name: FunctionIdentifier, children: Seq[Expressio | |
|
|
||
| case class UnresolvedFunction( | ||
| name: FunctionIdentifier, | ||
| children: Seq[Expression], | ||
| isDistinct: Boolean) | ||
| arguments: Seq[Expression], | ||
| isDistinct: Boolean, | ||
| filter: Option[Expression] = None) | ||
| extends Expression with Unevaluable { | ||
|
|
||
| override def children: Seq[Expression] = filter match { | ||
|
||
| case Some(expr) => arguments :+ expr | ||
| case _ => arguments | ||
| } | ||
|
|
||
| override def dataType: DataType = throw new UnresolvedException(this, "dataType") | ||
| override def foldable: Boolean = throw new UnresolvedException(this, "foldable") | ||
| override def nullable: Boolean = throw new UnresolvedException(this, "nullable") | ||
|
|
@@ -257,8 +263,8 @@ case class UnresolvedFunction( | |
| } | ||
|
|
||
| object UnresolvedFunction { | ||
| def apply(name: String, children: Seq[Expression], isDistinct: Boolean): UnresolvedFunction = { | ||
| UnresolvedFunction(FunctionIdentifier(name, None), children, isDistinct) | ||
| def apply(name: String, arguments: Seq[Expression], isDistinct: Boolean): UnresolvedFunction = { | ||
| UnresolvedFunction(FunctionIdentifier(name, None), arguments, isDistinct) | ||
| } | ||
| } | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -71,23 +71,27 @@ object AggregateExpression { | |
| def apply( | ||
| aggregateFunction: AggregateFunction, | ||
| mode: AggregateMode, | ||
| isDistinct: Boolean): AggregateExpression = { | ||
| isDistinct: Boolean, | ||
| filter: Option[Expression] = None): AggregateExpression = { | ||
| AggregateExpression( | ||
| aggregateFunction, | ||
| mode, | ||
| isDistinct, | ||
| filter, | ||
| NamedExpression.newExprId) | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * A container for an [[AggregateFunction]] with its [[AggregateMode]] and a field | ||
| * (`isDistinct`) indicating if DISTINCT keyword is specified for this function. | ||
| * (`isDistinct`) indicating if DISTINCT keyword is specified for this function and | ||
| * a field (`filter`) indicating if filter clause is specified for this function. | ||
| */ | ||
| case class AggregateExpression( | ||
| aggregateFunction: AggregateFunction, | ||
| mode: AggregateMode, | ||
| isDistinct: Boolean, | ||
| filter: Option[Expression], | ||
| resultId: ExprId) | ||
| extends Expression | ||
| with Unevaluable { | ||
|
|
@@ -104,6 +108,8 @@ case class AggregateExpression( | |
| UnresolvedAttribute(aggregateFunction.toString) | ||
| } | ||
|
|
||
| lazy val filterAttributes: AttributeSet = filter.map(_.references).getOrElse(AttributeSet.empty) | ||
|
|
||
| // We compute the same thing regardless of our final result. | ||
| override lazy val canonicalized: Expression = { | ||
| val normalizedAggFunc = mode match { | ||
|
|
@@ -119,18 +125,24 @@ case class AggregateExpression( | |
| normalizedAggFunc.canonicalized.asInstanceOf[AggregateFunction], | ||
| mode, | ||
| isDistinct, | ||
| filter, | ||
|
||
| ExprId(0)) | ||
| } | ||
|
|
||
| override def children: Seq[Expression] = aggregateFunction :: Nil | ||
| override def children: Seq[Expression] = filter match { | ||
|
||
| case Some(expr) => aggregateFunction :: expr :: Nil | ||
| case _ => aggregateFunction :: Nil | ||
| } | ||
|
|
||
| override def dataType: DataType = aggregateFunction.dataType | ||
| override def foldable: Boolean = false | ||
| override def nullable: Boolean = aggregateFunction.nullable | ||
|
|
||
| @transient | ||
| override lazy val references: AttributeSet = { | ||
| mode match { | ||
| case Partial | Complete => aggregateFunction.references | ||
| case Partial | Complete => | ||
| aggregateFunction.references ++ filter.map(_.references).getOrElse(AttributeSet.empty) | ||
|
||
| case PartialMerge | Final => AttributeSet(aggregateFunction.aggBufferAttributes) | ||
| } | ||
| } | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -169,11 +169,21 @@ class AnalysisErrorSuite extends AnalysisTest { | |
| CatalystSqlParser.parsePlan("SELECT hex(DISTINCT a) FROM TaBlE"), | ||
| "DISTINCT specified, but hex is not an aggregate function" :: Nil) | ||
|
|
||
| errorTest( | ||
| "non aggregate function with filter predicate", | ||
| CatalystSqlParser.parsePlan("SELECT hex(a) filter (where c = 1) FROM TaBlE2"), | ||
|
||
| "FILTER predicate specified, but hex is not an aggregate function" :: Nil) | ||
|
|
||
| errorTest( | ||
| "distinct window function", | ||
| CatalystSqlParser.parsePlan("SELECT percent_rank(DISTINCT a) over () FROM TaBlE"), | ||
| "DISTINCT specified, but percent_rank is not an aggregate function" :: Nil) | ||
|
|
||
| errorTest( | ||
| "window function with filter predicate", | ||
| CatalystSqlParser.parsePlan("SELECT percent_rank(a) filter (where c > 1) over () FROM TaBlE2"), | ||
| "FILTER predicate specified, but percent_rank is not an aggregate function" :: Nil) | ||
|
|
||
| errorTest( | ||
| "nested aggregate functions", | ||
| testRelation.groupBy('a)( | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -135,19 +135,27 @@ object AggUtils { | |
| } | ||
| val distinctAttributes = namedDistinctExpressions.map(_.toAttribute) | ||
| val groupingAttributes = groupingExpressions.map(_.toAttribute) | ||
| val filterWithDistinctAttributes = functionsWithDistinct.flatMap(_.filterAttributes.toSeq) | ||
|
|
||
| // 1. Create an Aggregate Operator for partial aggregations. | ||
| val partialAggregate: SparkPlan = { | ||
| val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial)) | ||
| val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) | ||
| // We will group by the original grouping expression, plus an additional expression for the | ||
| // DISTINCT column. For example, for AVG(DISTINCT value) GROUP BY key, the grouping | ||
| // expressions will be [key, value]. | ||
| // DISTINCT column and the expression in the FILTER clause associated with each aggregate | ||
|
||
| // function. For example: | ||
| // 1.for the AVG (DISTINCT value) GROUP BY key, the grouping expression will be [key, value]; | ||
| // 2.for the AVG (value) Filter (WHERE value2> 20) GROUP BY key, the grouping expression | ||
| // will be [key, value2]; | ||
|
||
| // 3.for AVG (DISTINCT value) Filter (WHERE value2> 20) GROUP BY key, the grouping expression | ||
|
||
| // will be [key, value, value2]. | ||
cloud-fan marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| createAggregate( | ||
| groupingExpressions = groupingExpressions ++ namedDistinctExpressions, | ||
| groupingExpressions = groupingExpressions ++ namedDistinctExpressions ++ | ||
| filterWithDistinctAttributes, | ||
| aggregateExpressions = aggregateExpressions, | ||
| aggregateAttributes = aggregateAttributes, | ||
| resultExpressions = groupingAttributes ++ distinctAttributes ++ | ||
| filterWithDistinctAttributes ++ | ||
| aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), | ||
| child = child) | ||
| } | ||
|
|
@@ -159,11 +167,13 @@ object AggUtils { | |
| createAggregate( | ||
| requiredChildDistributionExpressions = | ||
| Some(groupingAttributes ++ distinctAttributes), | ||
| groupingExpressions = groupingAttributes ++ distinctAttributes, | ||
| groupingExpressions = groupingAttributes ++ distinctAttributes ++ | ||
| filterWithDistinctAttributes, | ||
| aggregateExpressions = aggregateExpressions, | ||
| aggregateAttributes = aggregateAttributes, | ||
| initialInputBufferOffset = (groupingAttributes ++ distinctAttributes).length, | ||
| resultExpressions = groupingAttributes ++ distinctAttributes ++ | ||
| filterWithDistinctAttributes ++ | ||
| aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), | ||
| child = partialAggregate) | ||
| } | ||
|
|
@@ -174,7 +184,7 @@ object AggUtils { | |
| // Children of an AggregateFunction with DISTINCT keyword has already | ||
| // been evaluated. At here, we need to replace original children | ||
| // to AttributeReferences. | ||
| case agg @ AggregateExpression(aggregateFunction, mode, true, _) => | ||
| case agg @ AggregateExpression(aggregateFunction, mode, true, _, _) => | ||
| aggregateFunction.transformDown(distinctColumnAttributeLookup) | ||
| .asInstanceOf[AggregateFunction] | ||
| case agg => | ||
|
|
@@ -194,7 +204,8 @@ object AggUtils { | |
| // its input will have distinct arguments. | ||
| // We just keep the isDistinct setting to true, so when users look at the query plan, | ||
| // they still can see distinct aggregations. | ||
| val expr = AggregateExpression(func, Partial, isDistinct = true) | ||
| val filter = functionsWithDistinct(i).filter | ||
| val expr = AggregateExpression(func, Partial, isDistinct = true, filter) | ||
| // Use original AggregationFunction to lookup attributes, which is used to build | ||
| // aggregateFunctionToAttribute | ||
| val attr = functionsWithDistinct(i).resultAttribute | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's make it reserved under ANSI mode, i.e. don't put the keyword in
`ansiNonreserved`.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@cloud-fan
There is an issue, as shown below:
spark/sql/core/src/test/resources/sql-tests/results/ansi/higher-order-functions.sql.out
Line 85 in bf7215c
`filter` is a function.
It seems we can't put the keyword in
`ansiNonreserved`.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, I see. We should put it in
`functionName`, which means that it's a reserved keyword but can still be used as a function name.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I put it in
`functionName`, but the issue still occurs.