@@ -30,9 +30,9 @@ import org.apache.spark.broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer}
-import org.apache.spark.sql.catalyst.plans.physical.{Distribution, UnspecifiedDistribution}
+import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution}
import org.apache.spark.sql.catalyst.rules.{PlanChangeLogger, Rule}
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.catalyst.util.sideBySide
@@ -188,7 +188,7 @@ case class AdaptiveSparkPlanExec(

@volatile private var currentPhysicalPlan = initialPlan

-private var isFinalPlan = false
+@volatile private var isFinalPlan = false

private var currentStageId = 0

@@ -209,6 +209,19 @@ case class AdaptiveSparkPlanExec(

override def output: Seq[Attribute] = inputPlan.output

// Try our best to give a stable output partitioning and ordering.
Contributor:
I'm trying to understand this "best effort". AFAIK, table cache is lazy. For a query that accesses a cached query for the first time, the cached query has not been executed yet, so we don't know its output partitioning/ordering and can't optimize out shuffles. But when the cached query is accessed the next time, it has already been executed and we do know the output partitioning/ordering.
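(A minimal sketch of the timing being described; df1/df2 are placeholders, not names from this diff:)

// First access: caching is lazy, so when the outer query is planned the cached
// plan has not run yet; its final AQE output partitioning/ordering is unknown
// and no shuffle can be elided.
val cached = df1.join(df2, $"c1" === $"c2").cache()
cached.groupBy("c1").count().collect() // materializes the cache as a side effect

// Later accesses: the cached stages have executed, the adaptive plan is final,
// and its real output partitioning/ordering is available to the outer query.
cached.groupBy("c1").count().collect()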

Contributor Author:

Yes. In general, the first action on a cached plan is a count (e.g. CacheTableAsSelectExec), so I think it is not a big issue that we cannot optimize away the shuffle/sort for the first action.

The typical usage of a cache is that the user wants to reference it multiple times, and that is where this optimization helps a lot.

Contributor:

This would be of super limited use... and would cause inconsistency.

I'd only return the output partitioning if there is a user repartition op at the end. In other words, only if the AQE plan is required to preserve a user-specified partitioning.
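(An illustrative contrast of what that rule would and would not preserve; my example, not from the PR:)

// Preserved under this suggestion: an explicit user repartition at the end,
// which the AQE plan is required to honor anyway.
val withUserPartitioning = df.repartition($"c1").cache()

// Not preserved under this suggestion: whatever partitioning happens to fall
// out of the exchanges AQE picked while planning the join.
val withIncidentalPartitioning = df1.join(df2, $"c1" === $"c2").cache()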

Contributor Author:

Unfortunately we hit this case. In my experience, users often cache an arbitrary DataFrame and use the cached DataFrame to build another arbitrary DataFrame. So why can't we preserve the partitioning/ordering of the cached plan? If you really feel this is inconsistent in AdaptiveSparkPlanExec, we can probably move it to InMemoryRelationExec.

My original idea was to do both, but that feels like a bit of an overkill (requiredOrdering would need to be inferred separately, like in #35924):

requiredDistribution.map(_.createPartitioning(conf.numShufflePartitions)).getOrElse {
  if (isFinalPlan) {
    executedPlan.outputPartitioning
  } else {
    super.outputPartitioning
  }
}

A useful distribution before caching is rare in production, since repartition(col) can introduce skew.
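(A quick illustration of the skew point; hypothetical snippet, assuming import org.apache.spark.sql.functions.lit:)

// Hash-repartitioning on a column with very few distinct values sends almost
// all rows to one partition, which is why a useful pre-cache distribution is rare.
val skewed = spark.range(1000000).withColumn("k", lit("hot")).repartition($"k")
skewed.rdd.glom().map(_.length).collect() // one huge partition, the rest empty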

override def outputPartitioning: Partitioning = if (isFinalPlan) {
  executedPlan.outputPartitioning
} else {
  super.outputPartitioning
}

override def outputOrdering: Seq[SortOrder] = if (isFinalPlan) {
  executedPlan.outputOrdering
} else {
  super.outputOrdering
}

override def doCanonicalize(): SparkPlan = inputPlan.canonicalized

override def resetMetrics(): Unit = {
23 changes: 21 additions & 2 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -40,10 +40,10 @@ import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LeafNode, LocalRelation, LogicalPlan, OneRowRelation, Statistics}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.connector.FakeV2Provider
-import org.apache.spark.sql.execution.{FilterExec, LogicalRDD, QueryExecution, WholeStageCodegenExec}
+import org.apache.spark.sql.execution.{FilterExec, LogicalRDD, QueryExecution, SortExec, WholeStageCodegenExec}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
-import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec, ShuffleExchangeLike}
import org.apache.spark.sql.expressions.{Aggregator, Window}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
@@ -3513,6 +3513,25 @@ class DataFrameSuite extends QueryTest
assert(df.queryExecution.executedPlan.execute().getNumPartitions == 2)
}
}

test("SPARK-41048: Improve output partitioning and ordering with AQE cache") {
withSQLConf(
SQLConf.CAN_CHANGE_CACHED_PLAN_OUTPUT_PARTITIONING.key -> "true",
Contributor:

After this PR, we can probably turn this on by default to improve AQE coverage.

Contributor Author:

Agreed. We can also remove the internal tag.

    SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
    val df1 = spark.range(10).selectExpr("cast(id as string) c1")
    val df2 = spark.range(10).selectExpr("cast(id as string) c2")
    val cached = df1.join(df2, $"c1" === $"c2").cache()
    cached.count()
    val executedPlan = cached.groupBy("c1").agg(max($"c2")).queryExecution.executedPlan
    // Before this change, this plan contained 2 sorts and 1 shuffle.
    assert(collect(executedPlan) {
      case s: ShuffleExchangeLike => s
    }.isEmpty)
    assert(collect(executedPlan) {
      case s: SortExec => s
    }.isEmpty)
  }
}
}

case class GroupByKey(a: Int, b: Int)
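As a closing note on the new test: the same effect can be checked interactively (a hedged sketch; conf keys per SQLConf.CAN_CHANGE_CACHED_PLAN_OUTPUT_PARTITIONING and SQLConf.AUTO_BROADCASTJOIN_THRESHOLD):

spark.conf.set("spark.sql.optimizer.canChangeCachedPlanOutputPartitioning", "true")
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1") // force a shuffled join, as in the test
val df1 = spark.range(10).selectExpr("cast(id as string) c1")
val df2 = spark.range(10).selectExpr("cast(id as string) c2")
val cached = df1.join(df2, $"c1" === $"c2").cache()
cached.count() // first action: materializes the cache and finalizes the AQE plan
cached.groupBy("c1").agg(max($"c2")).explain() // should show no extra Exchange/Sort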