-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-38204][SS] Use StatefulOpClusteredDistribution for stateful operators with respecting backward compatibility #35673
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
fa4cc46
411c372
11defa2
78a90a3
448aa0a
9afed1b
5d61b8a
dc9d1fc
0fa1fb4
181490c
cc30d92
4dd373c
6c7111b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -45,8 +45,28 @@ object AggUtils { | |
| } | ||
| } | ||
|
|
||
| private def createStreamingAggregate( | ||
| requiredChildDistributionExpressions: Option[Seq[Expression]] = None, | ||
| groupingExpressions: Seq[NamedExpression] = Nil, | ||
| aggregateExpressions: Seq[AggregateExpression] = Nil, | ||
| aggregateAttributes: Seq[Attribute] = Nil, | ||
| initialInputBufferOffset: Int = 0, | ||
| resultExpressions: Seq[NamedExpression] = Nil, | ||
| child: SparkPlan): SparkPlan = { | ||
| createAggregate( | ||
| requiredChildDistributionExpressions, | ||
| isStreaming = true, | ||
| groupingExpressions = groupingExpressions, | ||
| aggregateExpressions = aggregateExpressions, | ||
| aggregateAttributes = aggregateAttributes, | ||
| initialInputBufferOffset = initialInputBufferOffset, | ||
| resultExpressions = resultExpressions, | ||
| child = child) | ||
| } | ||
|
|
||
| private def createAggregate( | ||
| requiredChildDistributionExpressions: Option[Seq[Expression]] = None, | ||
| isStreaming: Boolean = false, | ||
| groupingExpressions: Seq[NamedExpression] = Nil, | ||
| aggregateExpressions: Seq[AggregateExpression] = Nil, | ||
| aggregateAttributes: Seq[Attribute] = Nil, | ||
|
|
@@ -60,6 +80,8 @@ object AggUtils { | |
| if (useHash && !forceSortAggregate) { | ||
| HashAggregateExec( | ||
| requiredChildDistributionExpressions = requiredChildDistributionExpressions, | ||
| isStreaming = isStreaming, | ||
| numShufflePartitions = None, | ||
| groupingExpressions = groupingExpressions, | ||
| aggregateExpressions = mayRemoveAggFilters(aggregateExpressions), | ||
| aggregateAttributes = aggregateAttributes, | ||
|
|
@@ -73,6 +95,8 @@ object AggUtils { | |
| if (objectHashEnabled && useObjectHash && !forceSortAggregate) { | ||
| ObjectHashAggregateExec( | ||
| requiredChildDistributionExpressions = requiredChildDistributionExpressions, | ||
| isStreaming = isStreaming, | ||
| numShufflePartitions = None, | ||
| groupingExpressions = groupingExpressions, | ||
| aggregateExpressions = mayRemoveAggFilters(aggregateExpressions), | ||
| aggregateAttributes = aggregateAttributes, | ||
|
|
@@ -82,6 +106,8 @@ object AggUtils { | |
| } else { | ||
| SortAggregateExec( | ||
| requiredChildDistributionExpressions = requiredChildDistributionExpressions, | ||
| isStreaming = isStreaming, | ||
| numShufflePartitions = None, | ||
| groupingExpressions = groupingExpressions, | ||
| aggregateExpressions = mayRemoveAggFilters(aggregateExpressions), | ||
| aggregateAttributes = aggregateAttributes, | ||
|
|
@@ -290,7 +316,7 @@ object AggUtils { | |
| val partialAggregate: SparkPlan = { | ||
| val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial)) | ||
| val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) | ||
| createAggregate( | ||
| createStreamingAggregate( | ||
| groupingExpressions = groupingExpressions, | ||
| aggregateExpressions = aggregateExpressions, | ||
| aggregateAttributes = aggregateAttributes, | ||
|
|
@@ -302,7 +328,7 @@ object AggUtils { | |
| val partialMerged1: SparkPlan = { | ||
| val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) | ||
| val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) | ||
| createAggregate( | ||
| createStreamingAggregate( | ||
| requiredChildDistributionExpressions = | ||
| Some(groupingAttributes), | ||
| groupingExpressions = groupingAttributes, | ||
|
|
@@ -320,7 +346,7 @@ object AggUtils { | |
| val partialMerged2: SparkPlan = { | ||
| val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) | ||
| val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) | ||
| createAggregate( | ||
| createStreamingAggregate( | ||
| requiredChildDistributionExpressions = | ||
| Some(groupingAttributes), | ||
| groupingExpressions = groupingAttributes, | ||
|
|
@@ -348,7 +374,7 @@ object AggUtils { | |
| // projection: | ||
| val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute) | ||
|
|
||
| createAggregate( | ||
| createStreamingAggregate( | ||
| requiredChildDistributionExpressions = Some(groupingAttributes), | ||
| groupingExpressions = groupingAttributes, | ||
| aggregateExpressions = finalAggregateExpressions, | ||
|
|
@@ -407,7 +433,7 @@ object AggUtils { | |
| val partialAggregate: SparkPlan = { | ||
| val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial)) | ||
| val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) | ||
| createAggregate( | ||
| createStreamingAggregate( | ||
| groupingExpressions = groupingExpressions, | ||
| aggregateExpressions = aggregateExpressions, | ||
| aggregateAttributes = aggregateAttributes, | ||
|
|
@@ -424,7 +450,8 @@ object AggUtils { | |
| // this is to reduce amount of rows to shuffle | ||
| MergingSessionsExec( | ||
| requiredChildDistributionExpressions = None, | ||
| requiredChildDistributionOption = None, | ||
| isStreaming = true, | ||
| numShufflePartitions = None, | ||
| groupingExpressions = groupingAttributes, | ||
| sessionExpression = sessionExpression, | ||
| aggregateExpressions = aggregateExpressions, | ||
|
|
@@ -447,8 +474,10 @@ object AggUtils { | |
| val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) | ||
| val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) | ||
| MergingSessionsExec( | ||
| requiredChildDistributionExpressions = None, | ||
| requiredChildDistributionOption = Some(restored.requiredChildDistribution), | ||
| requiredChildDistributionExpressions = Some(groupingWithoutSessionAttributes), | ||
| isStreaming = true, | ||
| // This will be replaced with actual value in state rule. | ||
| numShufflePartitions = None, | ||
| groupingExpressions = groupingAttributes, | ||
| sessionExpression = sessionExpression, | ||
| aggregateExpressions = aggregateExpressions, | ||
|
|
@@ -476,8 +505,8 @@ object AggUtils { | |
| // projection: | ||
| val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute) | ||
|
|
||
| createAggregate( | ||
| requiredChildDistributionExpressions = Some(groupingAttributes), | ||
| createStreamingAggregate( | ||
| requiredChildDistributionExpressions = Some(groupingWithoutSessionAttributes), | ||
| groupingExpressions = groupingAttributes, | ||
| aggregateExpressions = finalAggregateExpressions, | ||
| aggregateAttributes = finalAggregateAttributes, | ||
|
|
@@ -491,10 +520,15 @@ object AggUtils { | |
|
|
||
| private def mayAppendUpdatingSessionExec( | ||
| groupingExpressions: Seq[NamedExpression], | ||
| maybeChildPlan: SparkPlan): SparkPlan = { | ||
| maybeChildPlan: SparkPlan, | ||
| isStreaming: Boolean = false): SparkPlan = { | ||
| groupingExpressions.find(_.metadata.contains(SessionWindow.marker)) match { | ||
| case Some(sessionExpression) => | ||
| UpdatingSessionsExec( | ||
| isStreaming = isStreaming, | ||
| // numShufflePartitions will be set to None, and replaced to the actual value in the | ||
| // state rule if the query is streaming. | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. ditto
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. same here |
||
| numShufflePartitions = None, | ||
| groupingExpressions.map(_.toAttribute), | ||
| sessionExpression.toAttribute, | ||
| maybeChildPlan) | ||
|
|
@@ -506,7 +540,8 @@ object AggUtils { | |
| private def mayAppendMergingSessionExec( | ||
| groupingExpressions: Seq[NamedExpression], | ||
| aggregateExpressions: Seq[AggregateExpression], | ||
| partialAggregate: SparkPlan): SparkPlan = { | ||
| partialAggregate: SparkPlan, | ||
| isStreaming: Boolean = false): SparkPlan = { | ||
| groupingExpressions.find(_.metadata.contains(SessionWindow.marker)) match { | ||
| case Some(sessionExpression) => | ||
| val aggExpressions = aggregateExpressions.map(_.copy(mode = PartialMerge)) | ||
|
|
@@ -519,7 +554,10 @@ object AggUtils { | |
|
|
||
| MergingSessionsExec( | ||
| requiredChildDistributionExpressions = Some(groupingWithoutSessionsAttributes), | ||
| requiredChildDistributionOption = None, | ||
| isStreaming = isStreaming, | ||
| // numShufflePartitions will be set to None, and replaced to the actual value in the | ||
| // state rule if the query is streaming. | ||
| numShufflePartitions = None, | ||
| groupingExpressions = groupingAttributes, | ||
| sessionExpression = sessionExpression, | ||
| aggregateExpressions = aggExpressions, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -21,12 +21,15 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, | |
| import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Final, PartialMerge} | ||
| import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, UnspecifiedDistribution} | ||
| import org.apache.spark.sql.execution.{AliasAwareOutputPartitioning, ExplainUtils, UnaryExecNode} | ||
| import org.apache.spark.sql.execution.streaming.StatefulOperatorPartitioning | ||
|
|
||
| /** | ||
| * Holds common logic for aggregate operators | ||
| */ | ||
| trait BaseAggregateExec extends UnaryExecNode with AliasAwareOutputPartitioning { | ||
| def requiredChildDistributionExpressions: Option[Seq[Expression]] | ||
| def isStreaming: Boolean | ||
| def numShufflePartitions: Option[Int] | ||
| def groupingExpressions: Seq[NamedExpression] | ||
| def aggregateExpressions: Seq[AggregateExpression] | ||
| def aggregateAttributes: Seq[Attribute] | ||
|
|
@@ -92,7 +95,20 @@ trait BaseAggregateExec extends UnaryExecNode with AliasAwareOutputPartitioning | |
| override def requiredChildDistribution: List[Distribution] = { | ||
| requiredChildDistributionExpressions match { | ||
| case Some(exprs) if exprs.isEmpty => AllTuples :: Nil | ||
| case Some(exprs) => ClusteredDistribution(exprs) :: Nil | ||
| case Some(exprs) => | ||
| if (isStreaming) { | ||
| numShufflePartitions match { | ||
| case Some(parts) => | ||
| StatefulOperatorPartitioning.getCompatibleDistribution( | ||
| exprs, parts, conf) :: Nil | ||
|
|
||
| case _ => | ||
| throw new IllegalStateException("Expected to set the number of partitions before " + | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Maybe we can also add a [code reference lost in extraction]. Or we can define only one variable
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We create a node with numShufflePartitions = None and replace the value in the state rule. That said, we can't check the condition before the state rule has been performed.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Maybe we should use the new error framework to throw the exception in newly added code.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The error framework is for user-facing errors. This is something like "this should not be called; internal error". I just made the error message general to make our developers' lives easier.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. ok |
||
| "constructing required child distribution!") | ||
| } | ||
| } else { | ||
| ClusteredDistribution(exprs) :: Nil | ||
| } | ||
| case None => UnspecifiedDistribution :: Nil | ||
| } | ||
| } | ||
|
|
@@ -102,7 +118,8 @@ trait BaseAggregateExec extends UnaryExecNode with AliasAwareOutputPartitioning | |
| */ | ||
| def toSortAggregate: SortAggregateExec = { | ||
| SortAggregateExec( | ||
| requiredChildDistributionExpressions, groupingExpressions, aggregateExpressions, | ||
| aggregateAttributes, initialInputBufferOffset, resultExpressions, child) | ||
| requiredChildDistributionExpressions, isStreaming, numShufflePartitions, groupingExpressions, | ||
| aggregateExpressions, aggregateAttributes, initialInputBufferOffset, resultExpressions, | ||
| child) | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's link the state rule class name here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We don't create a dedicate class for state rule. See
spark/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
Lines 105 to 209 in a30575e