[SPARK-30276][SQL] Support Filter expression allows simultaneous use of DISTINCT #29291
Status: Closed
Commits (25, all by beliefer):

4a6f903 Reuse completeNextStageWithFetchFailure
96456e2 Merge remote-tracking branch 'upstream/master'
4314005 Merge remote-tracking branch 'upstream/master'
d6af4a7 Merge remote-tracking branch 'upstream/master'
f69094f Merge remote-tracking branch 'upstream/master'
b86a42d Merge remote-tracking branch 'upstream/master'
2ac5159 Merge branch 'master' of github.com:beliefer/spark
9021d6c Merge remote-tracking branch 'upstream/master'
74a2ef4 Merge branch 'master' of github.com:beliefer/spark
199aa6f Support single distinct group with filter.
a73f11e Support distinct agg with filter
72e95f1 Supplement doc and comment.
8e82e83 Add test case and regenerate golden files.
4ba808b Add test case and regenerate golden files.
145a9dd Optimize code
0fcf643 Update doc
92a37a9 Optimize code.
7362dfb Optimize code.
9828158 Merge remote-tracking branch 'upstream/master'
fbb051b Merge branch 'master' into support-distinct-with-filter
9939ea7 Add tests case like distinct 1
2dc6f32 Optimize code
abafc20 Optimize code
39583dd Optimize code
883973b Optimize code
Changes to RewriteDistinctAggregates.scala (package org.apache.spark.sql.catalyst.optimizer):
```diff
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst.optimizer
 
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Complete}
+import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.types.IntegerType
```
```diff
@@ -81,10 +81,10 @@ import org.apache.spark.sql.types.IntegerType
  *     COUNT(DISTINCT cat1) as cat1_cnt,
  *     COUNT(DISTINCT cat2) as cat2_cnt,
  *     SUM(value) FILTER (WHERE id > 1) AS total
- * FROM
- *   data
- * GROUP BY
- *   key
+ *   FROM
+ *     data
+ *   GROUP BY
+ *     key
  * }}}
  *
  * This translates to the following (pseudo) logical plan:
```
```diff
@@ -93,7 +93,7 @@ import org.apache.spark.sql.types.IntegerType
  *      key = ['key]
  *      functions = [COUNT(DISTINCT 'cat1),
  *                   COUNT(DISTINCT 'cat2),
- *                   sum('value) with FILTER('id > 1)]
+ *                   sum('value) FILTER (WHERE 'id > 1)]
  *      output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
  *   LocalTableScan [...]
  * }}}
```
```diff
@@ -108,7 +108,7 @@ import org.apache.spark.sql.types.IntegerType
  *      output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
  *   Aggregate(
  *      key = ['key, 'cat1, 'cat2, 'gid]
- *      functions = [sum('value) with FILTER('id > 1)]
+ *      functions = [sum('value) FILTER (WHERE 'id > 1)]
  *      output = ['key, 'cat1, 'cat2, 'gid, 'total])
  *   Expand(
  *      projections = [('key, null, null, 0, cast('value as bigint), 'id),
```
```diff
@@ -118,6 +118,49 @@ import org.apache.spark.sql.types.IntegerType
  *   LocalTableScan [...]
  * }}}
  *
+ * Third example: aggregate function with distinct and filter clauses (in sql):
+ * {{{
+ *   SELECT
+ *     COUNT(DISTINCT cat1) FILTER (WHERE id > 1) as cat1_cnt,
+ *     COUNT(DISTINCT cat2) FILTER (WHERE id > 2) as cat2_cnt,
+ *     SUM(value) FILTER (WHERE id > 3) AS total
+ *   FROM
+ *     data
+ *   GROUP BY
+ *     key
+ * }}}
+ *
+ * This translates to the following (pseudo) logical plan:
+ * {{{
+ *   Aggregate(
+ *      key = ['key]
+ *      functions = [COUNT(DISTINCT 'cat1) FILTER (WHERE 'id > 1),
+ *                   COUNT(DISTINCT 'cat2) FILTER (WHERE 'id > 2),
+ *                   sum('value) FILTER (WHERE 'id > 3)]
+ *      output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
+ *   LocalTableScan [...]
+ * }}}
+ *
+ * This rule rewrites this logical plan to the following (pseudo) logical plan:
+ * {{{
+ *   Aggregate(
+ *      key = ['key]
+ *      functions = [count(if (('gid = 1) and 'max_cond1) 'cat1 else null),
+ *                   count(if (('gid = 2) and 'max_cond2) 'cat2 else null),
+ *                   first(if (('gid = 0)) 'total else null) ignore nulls]
+ *      output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
+ *   Aggregate(
+ *      key = ['key, 'cat1, 'cat2, 'gid]
+ *      functions = [max('cond1), max('cond2), sum('value) FILTER (WHERE 'id > 3)]
+ *      output = ['key, 'cat1, 'cat2, 'gid, 'max_cond1, 'max_cond2, 'total])
+ *   Expand(
+ *      projections = [('key, null, null, 0, null, null, cast('value as bigint), 'id),
+ *                     ('key, 'cat1, null, 1, 'id > 1, null, null, null),
+ *                     ('key, null, 'cat2, 2, null, 'id > 2, null, null)]
+ *      output = ['key, 'cat1, 'cat2, 'gid, 'cond1, 'cond2, 'value, 'id])
+ *   LocalTableScan [...]
+ * }}}
+ *
  * The rule does the following things here:
  * 1. Expand the data. There are three aggregation groups in this query:
  *    i. the non-distinct group;
```
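Reviewer note: the third example can be reproduced end to end with public APIs. This is a minimal sketch, assuming a Spark build that contains this patch; the data values are invented, and the view name `data` follows the scaladoc above.

```scala
// Minimal reproduction of the third example (hypothetical data values).
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("distinct-filter").getOrCreate()
import spark.implicits._

Seq(
  ("a", "ca1", "cb1", 10L, 1L),
  ("a", "ca1", "cb2", 5L, 2L),
  ("b", "ca1", "cb1", 13L, 3L),
  ("b", "ca2", "cb2", 7L, 4L)
).toDF("key", "cat1", "cat2", "value", "id").createOrReplaceTempView("data")

spark.sql(
  """SELECT
    |  COUNT(DISTINCT cat1) FILTER (WHERE id > 1) AS cat1_cnt,
    |  COUNT(DISTINCT cat2) FILTER (WHERE id > 2) AS cat2_cnt,
    |  SUM(value) FILTER (WHERE id > 3) AS total
    |FROM data
    |GROUP BY key""".stripMargin).show()
```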
```diff
@@ -126,15 +169,20 @@ import org.apache.spark.sql.types.IntegerType
  *    An expand operator is inserted to expand the child data for each group. The expand will null
  *    out all unused columns for the given group; this must be done in order to ensure correctness
  *    later on. Groups can by identified by a group id (gid) column added by the expand operator.
+ *    If distinct group exists filter clause, the expand will calculate the filter and output it's
+ *    result which will be used to calculate the global conditions equivalent to filter clauses.
  * 2. De-duplicate the distinct paths and aggregate the non-aggregate path. The group by clause of
  *    this aggregate consists of the original group by clause, all the requested distinct columns
  *    and the group id. Both de-duplication of distinct column and the aggregation of the
  *    non-distinct group take advantage of the fact that we group by the group id (gid) and that we
- *    have nulled out all non-relevant columns the given group.
+ *    have nulled out all non-relevant columns the given group. If distinct group exists filter
+ *    clause, we will use max to aggregate the results of the filter output in the previous step.
+ *    These aggregate values are equivalent to filter clauses.
  * 3. Aggregating the distinct groups and combining this with the results of the non-distinct
- *    aggregation. In this step we use the group id to filter the inputs for the aggregate
- *    functions. The result of the non-distinct group are 'aggregated' by using the first operator,
- *    it might be more elegant to use the native UDAF merge mechanism for this in the future.
+ *    aggregation. In this step we use the group id and the global condition to filter the inputs
+ *    for the aggregate functions. The result of the non-distinct group are 'aggregated' by using
+ *    the first operator, it might be more elegant to use the native UDAF merge mechanism for this
+ *    in the future.
  *
  * This rule duplicates the input data by two or more times (# distinct groups + an optional
  * non-distinct group). This will put quite a bit of memory pressure of the used aggregate and
```
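Step 2 is the subtle part of the new logic: after de-duplication, one distinct value can correspond to several input rows, some of which pass the FILTER and some of which do not. Aggregating the per-row filter result with max is a logical OR, i.e. "at least one source row for this value satisfied the filter". A plain-Scala sketch of the idea, with invented data and no Spark dependency:

```scala
// Why max() over the per-row filter result recovers the per-value condition.
// For COUNT(DISTINCT cat1) FILTER (WHERE id > 1), a distinct value of cat1
// should be counted iff at least one of its source rows passes the filter.
case class Row(key: String, cat1: String, id: Long)

val rows = Seq(Row("a", "x", 1L), Row("a", "x", 2L), Row("a", "y", 1L))

// First aggregate: group by (key, cat1); condition = max over rows of (id > 1).
val perValueCond: Map[(String, String), Boolean] =
  rows.groupBy(r => (r.key, r.cat1))
    .map { case (k, rs) => k -> rs.map(_.id > 1).max } // Boolean max == logical OR

// Second aggregate: count the distinct values whose condition is true.
val cat1Cnt = perValueCond.count { case ((key, _), cond) => key == "a" && cond }
// cat1Cnt == 1: "x" qualifies (its row with id = 2), "y" does not.
```

The second aggregate then only counts the distinct values whose max-condition is true, which matches the per-row FILTER semantics of the original query.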
```diff
@@ -144,28 +192,24 @@ import org.apache.spark.sql.types.IntegerType
  */
 object RewriteDistinctAggregates extends Rule[LogicalPlan] {
 
-  private def mayNeedtoRewrite(exprs: Seq[Expression]): Boolean = {
-    val distinctAggs = exprs.flatMap { _.collect {
-      case ae: AggregateExpression if ae.isDistinct => ae
-    }}
-    // We need at least two distinct aggregates for this rule because aggregation
-    // strategy can handle a single distinct group.
+  private def mayNeedtoRewrite(a: Aggregate): Boolean = {
+    val aggExpressions = collectAggregateExprs(a)
+    val distinctAggs = aggExpressions.filter(_.isDistinct)
+    // We need at least two distinct aggregates or the single distinct aggregate group exists filter
+    // clause for this rule because aggregation strategy can handle a single distinct aggregate
+    // group without filter clause.
     // This check can produce false-positives, e.g., SUM(DISTINCT a) & COUNT(DISTINCT a).
-    distinctAggs.size > 1
+    distinctAggs.size > 1 || distinctAggs.exists(_.filter.isDefined)
   }
 
   def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
-    case a: Aggregate if mayNeedtoRewrite(a.aggregateExpressions) => rewrite(a)
+    case a: Aggregate if mayNeedtoRewrite(a) => rewrite(a)
   }
 
   def rewrite(a: Aggregate): Aggregate = {
 
-    // Collect all aggregate expressions.
-    val aggExpressions = a.aggregateExpressions.flatMap { e =>
-      e.collect {
-        case ae: AggregateExpression => ae
-      }
-    }
+    val aggExpressions = collectAggregateExprs(a)
+    val distinctAggs = aggExpressions.filter(_.isDistinct)
 
     // Extract distinct aggregate expressions.
     val distinctAggGroups = aggExpressions.filter(_.isDistinct).groupBy { e =>
```
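Paraphrasing the new trigger with a toy model (a hypothetical case class, not the Catalyst type; the real check walks the expression tree via collectAggregateExprs):

```scala
// Toy model of mayNeedtoRewrite: fire for two or more distinct aggregates,
// or for a single distinct aggregate that carries a FILTER clause.
final case class AggExpr(isDistinct: Boolean, filter: Option[String])

def mayNeedRewrite(aggs: Seq[AggExpr]): Boolean = {
  val distinct = aggs.filter(_.isDistinct)
  distinct.size > 1 || distinct.exists(_.filter.isDefined)
}

// COUNT(DISTINCT a): the aggregation strategy plans this directly, no rewrite.
assert(!mayNeedRewrite(Seq(AggExpr(isDistinct = true, filter = None))))
// COUNT(DISTINCT a) FILTER (WHERE b > 1): now handled by this rule.
assert(mayNeedRewrite(Seq(AggExpr(isDistinct = true, filter = Some("b > 1")))))
```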
```diff
@@ -184,8 +228,8 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
       }
     }
 
-    // Aggregation strategy can handle queries with a single distinct group.
-    if (distinctAggGroups.size > 1) {
+    // Aggregation strategy can handle queries with a single distinct group without filter clause.
+    if (distinctAggGroups.size > 1 || distinctAggs.exists(_.filter.isDefined)) {
       // Create the attributes for the grouping id and the group by clause.
       val gid = AttributeReference("gid", IntegerType, nullable = false)()
       val groupByMap = a.groupingExpressions.collect {
```
```diff
@@ -195,7 +239,13 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
       val groupByAttrs = groupByMap.map(_._2)
 
       // Functions used to modify aggregate functions and their inputs.
-      def evalWithinGroup(id: Literal, e: Expression) = If(EqualTo(gid, id), e, nullify(e))
+      def evalWithinGroup(id: Literal, e: Expression, condition: Option[Expression]) =
+        if (condition.isDefined) {
+          If(And(EqualTo(gid, id), condition.get), e, nullify(e))
+        } else {
+          If(EqualTo(gid, id), e, nullify(e))
+        }
+
       def patchAggregateFunctionChildren(
           af: AggregateFunction)(
           attrs: Expression => Option[Expression]): AggregateFunction = {
```
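A plain-Scala model of the extended helper may make the null-out semantics clearer. This is a sketch only, with a hypothetical signature; the real method builds Catalyst If/And expressions over attributes:

```scala
// Model of evalWithinGroup: a value survives only in its own group and, when
// a condition is present, only if the max-aggregated filter condition holds.
def evalWithinGroup[A](gid: Int, id: Int, cond: Option[Boolean], e: Option[A]): Option[A] =
  cond match {
    case Some(c) => if (gid == id && c) e else None // If(And(EqualTo(gid, id), cond), e, null)
    case None    => if (gid == id) e else None      // If(EqualTo(gid, id), e, null)
  }

assert(evalWithinGroup(1, 1, Some(true), Some("ca1")) == Some("ca1")) // counted
assert(evalWithinGroup(1, 1, Some(false), Some("ca1")) == None)      // filtered out
```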
```diff
@@ -207,13 +257,30 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
       val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq.distinct
       val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair)
       val distinctAggChildAttrs = distinctAggChildAttrMap.map(_._2)
+      // Setup all the filters in distinct aggregate.
+      val distinctAggExprs = aggExpressions
+        .filter(e => e.isDistinct && e.children.exists(!_.foldable))
+      val (distinctAggFilters, distinctAggFilterAttrs, maxConds) = distinctAggExprs.collect {
+        case AggregateExpression(_, _, _, filter, _) if filter.isDefined =>
+          val (e, attr) = expressionAttributePair(filter.get)
+          val aggregateExp = AggregateExpression(Max(attr), Partial, false)
+          (e, attr, Alias(aggregateExp, attr.name)())
+      }.unzip3
 
       // Setup expand & aggregate operators for distinct aggregate expressions.
       val distinctAggChildAttrLookup = distinctAggChildAttrMap.toMap
+      val distinctAggFilterAttrLookup = distinctAggFilters.zip(maxConds.map(_.toAttribute)).toMap
       val distinctAggOperatorMap = distinctAggGroups.toSeq.zipWithIndex.map {
         case ((group, expressions), i) =>
           val id = Literal(i + 1)
 
+          // Expand projection for filter
+          val filters = expressions.filter(_.filter.isDefined).map(_.filter.get)
+          val filterProjection = distinctAggFilters.map {
+            case e if filters.contains(e) => e
+            case e => nullify(e)
+          }
+
           // Expand projection
           val projection = distinctAggChildren.map {
             case e if group.contains(e) => e
```
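The filter projection mirrors what the expand already does for distinct children: every filter predicate gets its own expand column, and each group nulls out every predicate that is not its own. A toy sketch, with strings standing in for Catalyst expressions and None for nullify:

```scala
// Per-group expand projection for filter columns (illustrative values).
val allFilters = Seq("id > 1", "id > 2") // one expand column per distinct filter
val filtersOfGroup1 = Set("id > 1")      // predicates attached to group 1

val filterProjection = allFilters.map {
  case e if filtersOfGroup1.contains(e) => Some(e) // keep this group's predicate
  case _ => None                                   // nullified for other groups
}
assert(filterProjection == Seq(Some("id > 1"), None))
```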
```diff
@@ -224,12 +291,17 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
           val operators = expressions.map { e =>
             val af = e.aggregateFunction
             val naf = patchAggregateFunctionChildren(af) { x =>
-              distinctAggChildAttrLookup.get(x).map(evalWithinGroup(id, _))
+              val condition = if (e.filter.isDefined) {
+                e.filter.map(distinctAggFilterAttrLookup.get(_)).get
+              } else {
+                None
+              }
+              distinctAggChildAttrLookup.get(x).map(evalWithinGroup(id, _, condition))
             }
-            (e, e.copy(aggregateFunction = naf, isDistinct = false))
+            (e, e.copy(aggregateFunction = naf, isDistinct = false, filter = None))
           }
 
-          (projection, operators)
+          (projection ++ filterProjection, operators)
       }
 
       // Setup expand for the 'regular' aggregate expressions.
```
```diff
@@ -257,7 +329,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
 
           // Select the result of the first aggregate in the last aggregate.
           val result = AggregateExpression(
-            aggregate.First(evalWithinGroup(regularGroupId, operator.toAttribute), true),
+            aggregate.First(evalWithinGroup(regularGroupId, operator.toAttribute, None), true),
             mode = Complete,
             isDistinct = false)
 
```
```diff
@@ -280,6 +352,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
         Seq(a.groupingExpressions ++
           distinctAggChildren.map(nullify) ++
           Seq(regularGroupId) ++
+          distinctAggFilters.map(nullify) ++
           regularAggChildren)
       } else {
         Seq.empty[Seq[Expression]]
```
```diff
@@ -297,15 +370,16 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
       // Construct the expand operator.
       val expand = Expand(
         regularAggProjection ++ distinctAggProjections,
-        groupByAttrs ++ distinctAggChildAttrs ++ Seq(gid) ++ regularAggChildAttrMap.map(_._2),
+        groupByAttrs ++ distinctAggChildAttrs ++ Seq(gid) ++ distinctAggFilterAttrs ++
+          regularAggChildAttrMap.map(_._2),
         a.child)
 
       // Construct the first aggregate operator. This de-duplicates all the children of
       // distinct operators, and applies the regular aggregate operators.
       val firstAggregateGroupBy = groupByAttrs ++ distinctAggChildAttrs :+ gid
       val firstAggregate = Aggregate(
         firstAggregateGroupBy,
-        firstAggregateGroupBy ++ regularAggOperatorMap.map(_._2),
+        firstAggregateGroupBy ++ maxConds ++ regularAggOperatorMap.map(_._2),
         expand)
 
       // Construct the second aggregate
```
```diff
@@ -331,6 +405,13 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
     }
   }
 
+  private def collectAggregateExprs(a: Aggregate): Seq[AggregateExpression] = {
+    // Collect all aggregate expressions.
+    a.aggregateExpressions.flatMap { _.collect {
+      case ae: AggregateExpression => ae
+    }}
+  }
+
   private def nullify(e: Expression) = Literal.create(null, e.dataType)
 
   private def expressionAttributePair(e: Expression) =
```
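Factoring the traversal into collectAggregateExprs removes the flatMap/collect duplication between mayNeedtoRewrite and rewrite. Outside Catalyst, the pattern is an ordinary typed tree fold; a toy ADT sketch (not the Catalyst types):

```scala
// Plain-Scala analogue of the collect pattern in collectAggregateExprs:
// gather every node of one type from an expression tree.
sealed trait Expr
final case class Add(l: Expr, r: Expr) extends Expr
final case class Agg(name: String, isDistinct: Boolean) extends Expr
final case class Lit(v: Int) extends Expr

def collectAggs(e: Expr): Seq[Agg] = e match {
  case a: Agg    => Seq(a)
  case Add(l, r) => collectAggs(l) ++ collectAggs(r)
  case _: Lit    => Nil
}

// COUNT(DISTINCT x) + 1 contains exactly one aggregate expression.
assert(collectAggs(Add(Agg("count", isDistinct = true), Lit(1))) ==
  Seq(Agg("count", isDistinct = true)))
```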
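To observe the rewrite on a build containing this patch, explain a single-distinct-with-filter query, reusing the hypothetical `spark` session and `data` view from the sketch earlier on this page; before this change such a query was rejected at analysis time. The optimized plan should show an Expand followed by two Aggregate operators, with a max('cond) in the first of them.

```scala
spark.sql(
  "SELECT key, COUNT(DISTINCT cat1) FILTER (WHERE id > 1) AS cnt FROM data GROUP BY key"
).explain(true) // prints the parsed, analyzed, optimized and physical plans
```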