File: SparkSessionExtensions.scala
@@ -47,6 +47,7 @@ import org.apache.spark.sql.execution.{ColumnarRule, SparkPlan}
* <li>Customized Parser.</li>
* <li>(External) Catalog listeners.</li>
* <li>Columnar Rules.</li>
* <li>Adaptive Query Post Planner Strategy Rules.</li>
* <li>Adaptive Query Stage Preparation Rules.</li>
* <li>Adaptive Query Execution Runtime Optimizer Rules.</li>
* <li>Adaptive Query Stage Optimizer Rules.</li>
@@ -114,12 +115,15 @@ class SparkSessionExtensions {
type ColumnarRuleBuilder = SparkSession => ColumnarRule
type QueryStagePrepRuleBuilder = SparkSession => Rule[SparkPlan]
type QueryStageOptimizerRuleBuilder = SparkSession => Rule[SparkPlan]
type QueryPostPlannerStrategyBuilder = SparkSession => Rule[SparkPlan]
@dongjoon-hyun (Member) commented on Nov 29, 2023:
Shall we put this before type QueryStagePrepRuleBuilder because this rule is supposed to be used before queryStagePreparationRules?


private[this] val columnarRuleBuilders = mutable.Buffer.empty[ColumnarRuleBuilder]
private[this] val queryStagePrepRuleBuilders = mutable.Buffer.empty[QueryStagePrepRuleBuilder]
private[this] val runtimeOptimizerRules = mutable.Buffer.empty[RuleBuilder]
private[this] val queryStageOptimizerRuleBuilders =
mutable.Buffer.empty[QueryStageOptimizerRuleBuilder]
private[this] val queryPostPlannerStrategyRuleBuilders =
mutable.Buffer.empty[QueryPostPlannerStrategyBuilder]
Member:
Ditto. Shall we put this before private[this] val queryStagePrepRuleBuilders?


/**
* Build the override rules for columnar execution.
@@ -149,6 +153,14 @@ class SparkSessionExtensions {
queryStageOptimizerRuleBuilders.map(_.apply(session)).toSeq
}

/**
* Build the override rules for the query post planner strategy phase of adaptive query execution.
*/
private[sql] def buildQueryPostPlannerStrategyRules(
session: SparkSession): Seq[Rule[SparkPlan]] = {
queryPostPlannerStrategyRuleBuilders.map(_.apply(session)).toSeq
}

Member:
Ditto. Shall we put this before private[sql] def buildQueryStagePrepRules?

/**
* Inject a rule that can override the columnar execution of an executor.
*/
@@ -185,6 +197,15 @@ class SparkSessionExtensions {
queryStageOptimizerRuleBuilders += builder
}

/**
* Inject a rule that is applied between `plannerStrategy` and `queryStagePrepRules`, so
* it can see the whole plan before exchanges are injected.
* Note that these rules can only be applied within AQE.
*/
def injectQueryPostPlannerStrategyRule(builder: QueryPostPlannerStrategyBuilder): Unit = {
queryPostPlannerStrategyRuleBuilders += builder
}

Member:
Ditto. Shall we put this before def injectQueryStagePrepRule?
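
For illustration only (not part of this PR), a minimal sketch of how an application could register a rule through this new injection point via SparkSession.builder().withExtensions; the rule and variable names below are hypothetical:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan

// Hypothetical no-op rule; a real rule would rewrite the physical plan it receives,
// which at this point is the whole planner output, before exchanges are inserted.
case class MyPostPlannerRule(session: SparkSession) extends Rule[SparkPlan] {
  override def apply(plan: SparkPlan): SparkPlan = plan
}

val spark = SparkSession.builder()
  .master("local[*]")
  .withExtensions { ext =>
    ext.injectQueryPostPlannerStrategyRule(session => MyPostPlannerRule(session))
  }
  .getOrCreate()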

private[this] val resolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder]

/**
File: AdaptiveRulesHolder.scala
@@ -29,9 +29,12 @@ import org.apache.spark.sql.execution.SparkPlan
* query stage
* @param queryStageOptimizerRules applied to a new query stage before its execution. It makes sure
* all children query stages are materialized
* @param queryPostPlannerStrategyRules applied between `plannerStrategy` and `queryStagePrepRules`,
* so they can see the whole plan before exchanges are injected.
*/
class AdaptiveRulesHolder(
val queryStagePrepRules: Seq[Rule[SparkPlan]],
val runtimeOptimizerRules: Seq[Rule[LogicalPlan]],
val queryStageOptimizerRules: Seq[Rule[SparkPlan]]) {
val queryStageOptimizerRules: Seq[Rule[SparkPlan]],
val queryPostPlannerStrategyRules: Seq[Rule[SparkPlan]]) {
}
File: AdaptiveSparkPlanExec.scala
@@ -193,9 +193,19 @@ case class AdaptiveSparkPlanExec(
optimized
}

private def applyQueryPostPlannerStrategyRules(plan: SparkPlan): SparkPlan = {
applyPhysicalRules(
plan,
context.session.sessionState.adaptiveRulesHolder.queryPostPlannerStrategyRules,
Some((planChangeLogger, "AQE Query Post Planner Strategy Rules"))
)
}

@transient val initialPlan = context.session.withActive {
applyPhysicalRules(
inputPlan, queryStagePreparationRules, Some((planChangeLogger, "AQE Preparations")))
applyQueryPostPlannerStrategyRules(inputPlan),
queryStagePreparationRules,
Some((planChangeLogger, "AQE Preparations")))
}

@volatile private var currentPhysicalPlan = initialPlan
@@ -706,7 +716,7 @@ case class AdaptiveSparkPlanExec(
val optimized = optimizer.execute(logicalPlan)
val sparkPlan = context.session.sessionState.planner.plan(ReturnAnswer(optimized)).next()
val newPlan = applyPhysicalRules(
sparkPlan,
applyQueryPostPlannerStrategyRules(sparkPlan),
preprocessingRules ++ queryStagePreparationRules,
Some((planChangeLogger, "AQE Replanning")))

File: BaseSessionStateBuilder.scala
@@ -332,7 +332,8 @@ abstract class BaseSessionStateBuilder(
new AdaptiveRulesHolder(
extensions.buildQueryStagePrepRules(session),
extensions.buildRuntimeOptimizerRules(session),
extensions.buildQueryStageOptimizerRules(session))
extensions.buildQueryStageOptimizerRules(session),
extensions.buildQueryPostPlannerStrategyRules(session))
}

protected def planNormalizationRules: Seq[Rule[LogicalPlan]] = {
File: SparkSessionExtensionSuite.scala
@@ -29,15 +29,17 @@ import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIden
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, Partial}
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface}
import org.apache.spark.sql.catalyst.plans.SQLHelper
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Limit, LocalRelation, LogicalPlan, Statistics, UnresolvedHint}
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, SinglePartition}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.connector.write.WriterCommitMessage
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, AQEShuffleReadExec, QueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.datasources.{FileFormat, WriteFilesExec, WriteFilesSpec}
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, BroadcastExchangeLike, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin}
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
@@ -516,6 +518,33 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper with Adapt
}
}
}

test("SPARK-46170: Support inject adaptive query post planner strategy rules in " +
"SparkSessionExtensions") {
val extensions = create { extensions =>
extensions.injectQueryPostPlannerStrategyRule(_ => MyQueryPostPlannerStrategyRule)
}
withSession(extensions) { session =>
assert(session.sessionState.adaptiveRulesHolder.queryPostPlannerStrategyRules
.contains(MyQueryPostPlannerStrategyRule))
import session.sqlContext.implicits._
Contributor:
How about just using session.implicits._?

withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "3",
SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "false") {
val input = Seq((10), (20), (10)).toDF("c1")
Contributor:
Why is it necessary to add parentheses to each element? For readability?

val df = input.groupBy("c1").count()
df.collect()
Contributor:
I tried it, and even if df.collect() is not executed, this test case can still pass. So, is it necessary?

Contributor (Author):
It is used to make sure we are checking the final plan.

assert(df.rdd.partitions.length == 1)
assert(find(df.queryExecution.executedPlan) {
case s: ShuffleExchangeExec if s.outputPartitioning == SinglePartition => true
Contributor:
How about case s: ShuffleExchangeExec => s.outputPartitioning == SinglePartition?

Additionally, a personal opinion unrelated to this PR: if there were an exists function in AdaptiveSparkPlanHelper, would this assertion be simpler to write? (A sketch follows after this test.)

case _ => false
}.isDefined)
assert(find(df.queryExecution.executedPlan) {
Contributor:
It seems using collectFirst could eliminate the case _ => false branch? (Also sketched after this test.)

case _: SortExec => true
case _ => false
}.isDefined)
}
}
}
}
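
Following up on the two suggestions above, a minimal sketch, inside the test body, of a hypothetical exists helper (not something AdaptiveSparkPlanHelper provides today) and of a collectFirst-based check (assuming AdaptiveSparkPlanHelper exposes a collectFirst helper analogous to TreeNode's):

// Hypothetical helper, e.g. added to AdaptiveSparkPlanHelper (not part of this PR):
def exists(p: SparkPlan)(f: SparkPlan => Boolean): Boolean = find(p)(f).isDefined

// The two assertions in the test could then read:
assert(exists(df.queryExecution.executedPlan) {
  case s: ShuffleExchangeExec => s.outputPartitioning == SinglePartition
  case _ => false
})
assert(collectFirst(df.queryExecution.executedPlan) { case _: SortExec => true }.isDefined)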

case class MyRule(spark: SparkSession) extends Rule[LogicalPlan] {
@@ -1190,3 +1219,14 @@ object RequireAtLeaseTwoPartitions extends Rule[SparkPlan] {
}
}
}

object MyQueryPostPlannerStrategyRule extends Rule[SparkPlan] {
override def apply(plan: SparkPlan): SparkPlan = {
plan.transformUp {
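// Put a single-partition shuffle on top of the partial aggregate, forcing the final aggregate into one partition.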
case h: HashAggregateExec if h.aggregateExpressions.map(_.mode).contains(Partial) =>
ShuffleExchangeExec(SinglePartition, h)
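// Add a per-partition ascending sort on the grouping keys on top of the final aggregate.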
case h: HashAggregateExec if h.aggregateExpressions.map(_.mode).contains(Final) =>
SortExec(h.groupingExpressions.map(k => SortOrder.apply(k, Ascending)), false, h)
}
}
}