diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
index 82da9aa69401..68ada074e03a 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
@@ -514,7 +514,9 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging {
     val childWithAdapter = ColumnarCollapseTransformStages.wrapInputIteratorTransformer(child)
     WholeStageTransformer(
       ProjectExecTransformer(child.output ++ appendedProjections, childWithAdapter))(
-      ColumnarCollapseTransformStages.transformStageCounter.incrementAndGet()
+      ColumnarCollapseTransformStages
+        .getTransformStageCounter(childWithAdapter)
+        .incrementAndGet()
     )
   }
diff --git a/backends-clickhouse/src/test/scala/org/apache/spark/sql/execution/benchmarks/CHAggAndShuffleBenchmark.scala b/backends-clickhouse/src/test/scala/org/apache/spark/sql/execution/benchmarks/CHAggAndShuffleBenchmark.scala
index 3f7ac3eccc3d..fb6c7974a8ce 100644
--- a/backends-clickhouse/src/test/scala/org/apache/spark/sql/execution/benchmarks/CHAggAndShuffleBenchmark.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/spark/sql/execution/benchmarks/CHAggAndShuffleBenchmark.scala
@@ -156,7 +156,10 @@ object CHAggAndShuffleBenchmark extends SqlBasedBenchmark with CHSqlBasedBenchma
     // Get the `FileSourceScanExecTransformer`
     val fileScan = executedPlan.collect { case scan: FileSourceScanExecTransformer => scan }.head
     val scanStage = WholeStageTransformer(fileScan)(
-      ColumnarCollapseTransformStages.transformStageCounter.incrementAndGet())
+      ColumnarCollapseTransformStages
+        .getTransformStageCounter(fileScan)
+        .incrementAndGet()
+    )
     val scanStageRDD = scanStage.executeColumnar()

     // Get the total row count
@@ -200,7 +203,9 @@ object CHAggAndShuffleBenchmark extends SqlBasedBenchmark with CHSqlBasedBenchma
     val projectFilter = executedPlan.collect { case project: ProjectExecTransformer => project }
     if (projectFilter.nonEmpty) {
       val projectFilterStage = WholeStageTransformer(projectFilter.head)(
-        ColumnarCollapseTransformStages.transformStageCounter.incrementAndGet())
+        ColumnarCollapseTransformStages
+          .getTransformStageCounter(projectFilter.head)
+          .incrementAndGet())
       val projectFilterStageRDD = projectFilterStage.executeColumnar()

       chAllStagesBenchmark.addCase(s"Project Stage", executedCnt) {
diff --git a/backends-velox/src-delta33/main/scala/org/apache/spark/sql/delta/files/GlutenDeltaFileFormatWriter.scala b/backends-velox/src-delta33/main/scala/org/apache/spark/sql/delta/files/GlutenDeltaFileFormatWriter.scala
index 69674b40a31a..8b5a4fdc3428 100644
--- a/backends-velox/src-delta33/main/scala/org/apache/spark/sql/delta/files/GlutenDeltaFileFormatWriter.scala
+++ b/backends-velox/src-delta33/main/scala/org/apache/spark/sql/delta/files/GlutenDeltaFileFormatWriter.scala
@@ -260,9 +260,8 @@ object GlutenDeltaFileFormatWriter extends LoggingShims {
           nativeSortPlan
         }
         val newPlan = sortPlan.child match {
-          case WholeStageTransformer(wholeStageChild, materializeInput) =>
-            WholeStageTransformer(addNativeSort(wholeStageChild),
-              materializeInput)(ColumnarCollapseTransformStages.transformStageCounter.incrementAndGet())
+          case wst @ WholeStageTransformer(wholeStageChild, _) =>
+            wst.withNewChildren(Seq(addNativeSort(wholeStageChild)))
          case other =>
            Transitions.toBatchPlan(sortPlan, VeloxBatchType)
        }
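All three call sites above stop touching a process-wide counter: a hand-built WholeStageTransformer now seeds its id from the sub-plan it wraps, via the ColumnarCollapseTransformStages.getTransformStageCounter helper added later in this patch, so the ad-hoc stage receives the largest id already present plus one. A minimal sketch of that seeding rule under stand-in types (Stage, Other and findMaxStageId are illustrative only, not Gluten classes):

    import java.util.concurrent.atomic.AtomicInteger

    object StageIdSeedingSketch extends App {
      // Stage stands in for WholeStageTransformer, Other for any other operator.
      sealed trait Node { def children: Seq[Node] }
      case class Stage(id: Int, children: Seq[Node] = Nil) extends Node
      case class Other(children: Seq[Node] = Nil) extends Node

      // Same shape as findMaxTransformStageId in this patch: a stage reports its own id,
      // anything else takes the maximum over its children, defaulting to 0.
      def findMaxStageId(n: Node): Int = n match {
        case Stage(id, _) => id
        case _ => n.children.map(findMaxStageId).foldLeft(0)(Math.max)
      }

      def counterFor(plan: Node) = new AtomicInteger(findMaxStageId(plan))

      val plan = Other(Seq(Stage(3), Stage(7, Seq(Stage(5)))))
      assert(counterFor(plan).incrementAndGet() == 8) // next free id after 7
      assert(counterFor(Other()).incrementAndGet() == 1) // no existing stages, start at 1
    }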
diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala
index 7a78bb846817..3bf030bf2ca8 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala
@@ -139,6 +139,7 @@ object VeloxRuleApi {
       .getExtendedColumnarPostRules()
       .foreach(each => injector.injectPost(c => each(c.session)))
     injector.injectPost(c => ColumnarCollapseTransformStages(new GlutenConfig(c.sqlConf)))
+    injector.injectPost(_ => GenerateTransformStageId())
     injector.injectPost(c => CudfNodeValidationRule(new GlutenConfig(c.sqlConf)))
     injector.injectPost(c => GlutenNoopWriterRule(c.session))

@@ -240,6 +241,7 @@ object VeloxRuleApi {
       .getExtendedColumnarPostRules()
       .foreach(each => injector.injectPostTransform(c => each(c.session)))
     injector.injectPostTransform(c => ColumnarCollapseTransformStages(new GlutenConfig(c.sqlConf)))
+    injector.injectPostTransform(_ => GenerateTransformStageId())
     injector.injectPostTransform(c => CudfNodeValidationRule(new GlutenConfig(c.sqlConf)))
     injector.injectPostTransform(c => GlutenNoopWriterRule(c.session))
     injector.injectPostTransform(c => RemoveGlutenTableCacheColumnarToRow(c.session))
diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/TakeOrderedAndProjectExecTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/TakeOrderedAndProjectExecTransformer.scala
index 2887812c8a75..39b37a0d1877 100644
--- a/gluten-substrait/src/main/scala/org/apache/gluten/execution/TakeOrderedAndProjectExecTransformer.scala
+++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/TakeOrderedAndProjectExecTransformer.scala
@@ -137,7 +137,7 @@ case class TakeOrderedAndProjectExecTransformer(
       LimitExecTransformer(localSortPlan, limitBeforeShuffleOffset, limit)
     }
     val transformStageCounter: AtomicInteger =
-      ColumnarCollapseTransformStages.transformStageCounter
+      ColumnarCollapseTransformStages.getTransformStageCounter(child)
     val finalLimitPlan = if (hasShuffle) {
       limitBeforeShuffle
     } else {
diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/WholeStageTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/WholeStageTransformer.scala
index db42778d294e..acef5d798ea0 100644
--- a/gluten-substrait/src/main/scala/org/apache/gluten/execution/WholeStageTransformer.scala
+++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/WholeStageTransformer.scala
@@ -155,7 +155,7 @@ trait UnaryTransformSupport extends TransformSupport with UnaryExecNode {
 }

 case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = false)(
-    val transformStageId: Int
+    var transformStageId: Int
 ) extends WholeStageTransformerGenerateTreeStringShim
   with UnaryTransformSupport {
diff --git a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarCollapseTransformStages.scala b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarCollapseTransformStages.scala
index d41f06d87d9d..a770fdf7a328 100644
--- a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarCollapseTransformStages.scala
+++ b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarCollapseTransformStages.scala
@@ -139,10 +139,7 @@ case class InputIteratorTransformer(child: SparkPlan) extends UnaryTransformSupp
  * created, e.g. for special fallback handling when an existing WholeStageTransformer failed to
  * generate/compile code.
  */
-case class ColumnarCollapseTransformStages(
-    glutenConf: GlutenConfig,
-    transformStageCounter: AtomicInteger = ColumnarCollapseTransformStages.transformStageCounter)
-  extends Rule[SparkPlan] {
+case class ColumnarCollapseTransformStages(glutenConf: GlutenConfig) extends Rule[SparkPlan] {

   def apply(plan: SparkPlan): SparkPlan = {
     insertWholeStageTransformer(plan)
@@ -176,8 +173,8 @@ case class ColumnarCollapseTransformStages(
   private def insertWholeStageTransformer(plan: SparkPlan): SparkPlan = {
     plan match {
       case t if supportTransform(t) =>
-        WholeStageTransformer(t.withNewChildren(t.children.map(insertInputIteratorTransformer)))(
-          transformStageCounter.incrementAndGet())
+        // transformStageId will be updated by rule `GenerateTransformStageId`.
+        WholeStageTransformer(t.withNewChildren(t.children.map(insertInputIteratorTransformer)))(-1)
       case other =>
         other.withNewChildren(other.children.map(insertWholeStageTransformer))
     }
@@ -213,9 +210,20 @@ case class ColumnarCollapseTransformStages(
 }

 object ColumnarCollapseTransformStages {
-  val transformStageCounter = new AtomicInteger(0)
-
   def wrapInputIteratorTransformer(plan: SparkPlan): TransformSupport = {
     InputIteratorTransformer(ColumnarInputAdapter(plan))
   }
+
+  def getTransformStageCounter(plan: SparkPlan): AtomicInteger = {
+    new AtomicInteger(findMaxTransformStageId(plan))
+  }
+
+  private def findMaxTransformStageId(plan: SparkPlan): Int = {
+    plan match {
+      case wst: WholeStageTransformer =>
+        wst.transformStageId
+      case _ =>
+        plan.children.map(findMaxTransformStageId).foldLeft(0)(Math.max)
+    }
+  }
 }
diff --git a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/GenerateTransformStageId.scala b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/GenerateTransformStageId.scala
new file mode 100644
index 000000000000..b28eac8a126a
--- /dev/null
+++ b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/GenerateTransformStageId.scala
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution
+
+import org.apache.gluten.exception.GlutenException
+import org.apache.gluten.execution.WholeStageTransformer
+import org.apache.gluten.sql.shims.SparkShimLoader
+
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, BroadcastQueryStageExec, ShuffleQueryStageExec}
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, ReusedExchangeExec, ShuffleExchangeLike}
+
+import java.util
+import java.util.Collections.newSetFromMap
+import java.util.concurrent.atomic.AtomicInteger
+
+/**
+ * Generate `transformStageId` for `WholeStageTransformerExec`. This rule updates the whole plan
+ * tree with incremental and unique transform stage id before the final execution.
+ *
+ * In Spark, the whole stage id is generated by incrementing a global counter. In Gluten, it's not
+ * possible to use global counter for id generation, especially in the case of AQE.
+ */
+case class GenerateTransformStageId() extends Rule[SparkPlan] with AdaptiveSparkPlanHelper {
+  private val transformStageCounter: AtomicInteger = new AtomicInteger(0)
+
+  private val wholeStageTransformerCache =
+    newSetFromMap[WholeStageTransformer](new util.IdentityHashMap())
+
+  def apply(plan: SparkPlan): SparkPlan = {
+    updateStageId(plan)
+    plan
+  }
+
+  private def updateStageId(plan: SparkPlan): Unit = {
+    plan match {
+      case b: BroadcastQueryStageExec =>
+        b.plan match {
+          case b: BroadcastExchangeLike => updateStageId(b)
+          case _: ReusedExchangeExec =>
+          case _ =>
+            throw new GlutenException(s"wrong plan for broadcast stage:\n ${plan.treeString}")
+        }
+      case s: ShuffleQueryStageExec =>
+        s.plan match {
+          case s: ShuffleExchangeLike => updateStageId(s)
+          case _: ReusedExchangeExec =>
+          case _ =>
+            throw new GlutenException(s"wrong plan for shuffle stage:\n ${plan.treeString}")
+        }
+      case aqe: AdaptiveSparkPlanExec if SparkShimLoader.getSparkShims.isFinalAdaptivePlan(aqe) =>
+        updateStageId(stripAQEPlan(aqe))
+      case wst: WholeStageTransformer if !wholeStageTransformerCache.contains(wst) =>
+        updateStageId(wst.child)
+        wst.transformStageId = transformStageCounter.incrementAndGet()
+        wholeStageTransformerCache.add(wst)
+      case plan =>
+        plan.subqueries.foreach(updateStageId)
+        plan.children.foreach(updateStageId)
+    }
+  }
+}
diff --git a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/GlutenImplicits.scala b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/GlutenImplicits.scala
index 7267ce56ba1c..48710a9edff5 100644
--- a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/GlutenImplicits.scala
+++ b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/GlutenImplicits.scala
@@ -95,10 +95,7 @@ object GlutenImplicits {
     }

     private def isFinalAdaptivePlan(p: AdaptiveSparkPlanExec): Boolean = {
-      val args = p.argString(Int.MaxValue)
-      val index = args.indexOf("isFinalPlan=")
-      assert(index >= 0)
-      args.substring(index + "isFinalPlan=".length).trim.toBoolean
+      SparkShimLoader.getSparkShims.isFinalAdaptivePlan(p)
     }

     private def collectFallbackNodes(
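ColumnarCollapseTransformStages now stamps every collapsed stage with the placeholder -1, and the rule above is what turns those placeholders into real ids: it recurses into a stage's child before assigning, so inner stages receive smaller numbers, and the identity-based cache keeps a WholeStageTransformer that is reachable more than once (for example under a reused exchange) from being renumbered. A self-contained sketch of that visit order, using hypothetical stand-in types (MockStage and MockOp are not Gluten classes):

    import java.util
    import java.util.Collections.newSetFromMap
    import java.util.concurrent.atomic.AtomicInteger

    object StageNumberingSketch extends App {
      sealed trait Op { def children: Seq[Op] }
      // Mirrors WholeStageTransformer after this patch: the stage id is a mutable field.
      final case class MockStage(var id: Int, children: Seq[Op]) extends Op
      final case class MockOp(children: Seq[Op]) extends Op

      val counter = new AtomicInteger(0)
      // Identity semantics, as in the rule: the same object is numbered only once,
      // even though two MockStage(-1, Nil) instances compare equal structurally.
      val seen = newSetFromMap[MockStage](new util.IdentityHashMap())

      // Child first, then self, then remember the node.
      def number(op: Op): Unit = op match {
        case s: MockStage if !seen.contains(s) =>
          s.children.foreach(number)
          s.id = counter.incrementAndGet()
          seen.add(s)
        case other => other.children.foreach(number)
      }

      val inner = MockStage(-1, Nil)
      val reused = MockStage(-1, Nil) // referenced twice below, numbered once
      val root = MockStage(-1, Seq(MockOp(Seq(inner, reused)), reused))

      number(root)
      assert((inner.id, reused.id, root.id) == (1, 2, 3))
    }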
diff --git a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/datasources/GlutenFormatWriterInjectsBase.scala b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/datasources/GlutenFormatWriterInjectsBase.scala
index 53ee2855e6d8..7c9721e2f0ef 100644
--- a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/datasources/GlutenFormatWriterInjectsBase.scala
+++ b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/datasources/GlutenFormatWriterInjectsBase.scala
@@ -23,7 +23,6 @@ import org.apache.gluten.extension.columnar.heuristic.HeuristicTransform
 import org.apache.gluten.extension.columnar.transition.{InsertTransitions, RemoveTransitions}

 import org.apache.spark.sql.execution.{ColumnarCollapseTransformStages, SparkPlan}
-import org.apache.spark.sql.execution.ColumnarCollapseTransformStages.transformStageCounter

 trait GlutenFormatWriterInjectsBase extends GlutenFormatWriterInjects {
   private lazy val transform = HeuristicTransform.static()
@@ -66,7 +65,7 @@ trait GlutenFormatWriterInjectsBase extends GlutenFormatWriterInjects {
     // and cannot provide const-ness.
     val transformedWithAdapter = injectAdapter(transformed)
     val wst = WholeStageTransformer(transformedWithAdapter, materializeInput = true)(
-      transformStageCounter.incrementAndGet())
+      ColumnarCollapseTransformStages.getTransformStageCounter(transformed).incrementAndGet())
     val wstWithTransitions = BackendsApiManager.getSparkPlanExecApiInstance.genColumnarToCarrierRow(
       InsertTransitions.create(outputsColumnar = true, wst.batchType()).apply(wst))
     wstWithTransitions
diff --git a/gluten-ut/spark40/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala b/gluten-ut/spark40/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
index d1c66e66d954..24b84d58aee9 100644
--- a/gluten-ut/spark40/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
+++ b/gluten-ut/spark40/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
@@ -858,7 +858,7 @@ class VeloxTestSettings extends BackendTestSettings {
     // ORC related
     .exclude("SPARK-37965: Spark support read/write orc file with invalid char in field name")
     .exclude("SPARK-38173: Quoted column cannot be recognized correctly when quotedRegexColumnNames is true")
-    // TODO: fix in Spark-4.0
+    // Rewrite with Gluten's explained result.
     .exclude("SPARK-47939: Explain should work with parameterized queries")
   enableSuite[GlutenSQLQueryTestSuite]
   enableSuite[GlutenStatisticsCollectionSuite]
diff --git a/gluten-ut/spark40/src/test/scala/org/apache/spark/sql/GlutenSQLQuerySuite.scala b/gluten-ut/spark40/src/test/scala/org/apache/spark/sql/GlutenSQLQuerySuite.scala
index 8f397c517ef1..c75569af259a 100644
--- a/gluten-ut/spark40/src/test/scala/org/apache/spark/sql/GlutenSQLQuerySuite.scala
+++ b/gluten-ut/spark40/src/test/scala/org/apache/spark/sql/GlutenSQLQuerySuite.scala
@@ -154,4 +154,121 @@ class GlutenSQLQuerySuite extends SQLQuerySuite with GlutenSQLTestsTrait {
       assert(inputOutputPairs.map(_._1).sum == numRows)
     }
   }
+
+  testGluten("SPARK-47939: Explain should work with parameterized queries") {
+    def checkQueryPlan(df: DataFrame, plan: String): Unit = assert(
+      df.collect()
+        .map(_.getString(0))
+        .map(_.replaceAll("#[0-9]+", "#N"))
+        // Remove the backend keyword in c2r/r2c.
+        .map(_.replaceAll("[A-Za-z]*ColumnarToRow", "ColumnarToRow"))
+        .map(_.replaceAll("RowTo[A-Za-z]*Columnar", "RowToColumnar"))
+        === Array(plan.stripMargin)
+    )
+
+    checkQueryPlan(
+      spark.sql("explain select ?", Array(1)),
+      """== Physical Plan ==
+        |ColumnarToRow
+        |+- ^(1) ProjectExecTransformer [1 AS 1#N]
+        |   +- ^(1) InputIteratorTransformer[]
+        |      +- RowToColumnar
+        |         +- *(1) Scan OneRowRelation[]
+        |
+        |"""
+    )
+    checkQueryPlan(
+      spark.sql("explain select :first", Map("first" -> 1)),
+      """== Physical Plan ==
+        |ColumnarToRow
+        |+- ^(1) ProjectExecTransformer [1 AS 1#N]
+        |   +- ^(1) InputIteratorTransformer[]
+        |      +- RowToColumnar
+        |         +- *(1) Scan OneRowRelation[]
+        |
+        |"""
+    )
+
+    checkQueryPlan(
+      spark.sql("explain explain explain select ?", Array(1)),
+      """== Physical Plan ==
+        |Execute ExplainCommand
+        |   +- ExplainCommand ExplainCommand 'PosParameterizedQuery [1], SimpleMode, SimpleMode
+
+        |"""
+    )
+    checkQueryPlan(
+      spark.sql("explain explain explain select :first", Map("first" -> 1)),
+      // scalastyle:off
+      """== Physical Plan ==
+        |Execute ExplainCommand
+        |   +- ExplainCommand ExplainCommand 'NameParameterizedQuery [first], [1], SimpleMode, SimpleMode
+
+        |"""
+      // scalastyle:on
+    )
+
+    checkQueryPlan(
+      spark.sql("explain describe select ?", Array(1)),
+      """== Physical Plan ==
+        |Execute DescribeQueryCommand
+        |   +- DescribeQueryCommand select ?
+
+        |"""
+    )
+    checkQueryPlan(
+      spark.sql("explain describe select :first", Map("first" -> 1)),
+      """== Physical Plan ==
+        |Execute DescribeQueryCommand
+        |   +- DescribeQueryCommand select :first
+
+        |"""
+    )
+
+    checkQueryPlan(
+      spark.sql("explain extended select * from values (?, ?) t(x, y)", Array(1, "a")),
+      """== Parsed Logical Plan ==
+        |'PosParameterizedQuery [1, a]
+        |+- 'Project [*]
+        |   +- 'SubqueryAlias t
+        |      +- 'UnresolvedInlineTable [x, y], [[posparameter(39), posparameter(42)]]
+
+        |== Analyzed Logical Plan ==
+        |x: int, y: string
+        |Project [x#N, y#N]
+        |+- SubqueryAlias t
+        |   +- LocalRelation [x#N, y#N]
+
+        |== Optimized Logical Plan ==
+        |LocalRelation [x#N, y#N]
+
+        |== Physical Plan ==
+        |LocalTableScan [x#N, y#N]
+        |"""
+    )
+    checkQueryPlan(
+      spark.sql(
+        "explain extended select * from values (:first, :second) t(x, y)",
+        Map("first" -> 1, "second" -> "a")
+      ),
+      """== Parsed Logical Plan ==
+        |'NameParameterizedQuery [first, second], [1, a]
+        |+- 'Project [*]
+        |   +- 'SubqueryAlias t
+        |      +- 'UnresolvedInlineTable [x, y], [[namedparameter(first), namedparameter(second)]]
+
+        |== Analyzed Logical Plan ==
+        |x: int, y: string
+        |Project [x#N, y#N]
+        |+- SubqueryAlias t
+        |   +- LocalRelation [x#N, y#N]
+
+        |== Optimized Logical Plan ==
+        |LocalRelation [x#N, y#N]
+
+        |== Physical Plan ==
+        |LocalTableScan [x#N, y#N]
+        |"""
+    )
+  }
 }
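The rewritten test compares explain output only after normalizing it, so the same expectation holds regardless of expression ids and of which backend prints the columnar transitions (VeloxColumnarToRow vs. CHColumnarToRow, RowToVeloxColumnar vs. RowToCHColumnar). The normalization is just the three replaceAll calls from checkQueryPlan; a standalone sketch of what they do (the sample string is illustrative, not captured output):

    object ExplainNormalizationSketch extends App {
      // The same substitutions checkQueryPlan applies before comparing plan strings.
      def normalize(plan: String): String =
        plan
          .replaceAll("#[0-9]+", "#N") // expression ids differ between runs
          .replaceAll("[A-Za-z]*ColumnarToRow", "ColumnarToRow") // e.g. VeloxColumnarToRow
          .replaceAll("RowTo[A-Za-z]*Columnar", "RowToColumnar") // e.g. RowToVeloxColumnar

      val sample =
        "VeloxColumnarToRow\n+- ProjectExecTransformer [1 AS 1#27]\n   +- RowToVeloxColumnar"
      assert(
        normalize(sample) ==
          "ColumnarToRow\n+- ProjectExecTransformer [1 AS 1#N]\n   +- RowToColumnar")
    }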
diff --git a/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala b/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala
index 9164f4b7c43a..c2542f036882 100644
--- a/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala
+++ b/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala
@@ -39,6 +39,7 @@ import org.apache.spark.sql.connector.expressions.Transform
 import org.apache.spark.sql.connector.read.{InputPartition, Scan}
 import org.apache.spark.sql.connector.read.streaming.SparkDataStream
 import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFilters
 import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2ScanExecBase}
@@ -365,4 +366,6 @@ trait SparkShims {
       sparkSession: SparkSession,
       planner: SparkPlanner,
       plan: LogicalPlan): SparkPlan
+
+  def isFinalAdaptivePlan(p: AdaptiveSparkPlanExec): Boolean
 }
diff --git a/shims/spark32/src/main/scala/org/apache/gluten/sql/shims/spark32/Spark32Shims.scala b/shims/spark32/src/main/scala/org/apache/gluten/sql/shims/spark32/Spark32Shims.scala
index d1bf07e64b12..4d029929541c 100644
--- a/shims/spark32/src/main/scala/org/apache/gluten/sql/shims/spark32/Spark32Shims.scala
+++ b/shims/spark32/src/main/scala/org/apache/gluten/sql/shims/spark32/Spark32Shims.scala
@@ -40,6 +40,7 @@ import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec
 import org.apache.spark.sql.connector.catalog.Table
 import org.apache.spark.sql.connector.expressions.Transform
 import org.apache.spark.sql.execution.{FileSourceScanExec, PartitionedFileUtil, QueryExecution, SparkPlan, SparkPlanner}
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.FileFormatWriter.Empty2Null
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFilters
@@ -315,4 +316,11 @@ class Spark32Shims extends SparkShims {
       planner: SparkPlanner,
       plan: LogicalPlan): SparkPlan =
     QueryExecution.createSparkPlan(sparkSession, planner, plan)
+
+  override def isFinalAdaptivePlan(p: AdaptiveSparkPlanExec): Boolean = {
+    val args = p.argString(Int.MaxValue)
+    val index = args.indexOf("isFinalPlan=")
+    assert(index >= 0)
+    args.substring(index + "isFinalPlan=".length).trim.toBoolean
+  }
 }
diff --git a/shims/spark33/src/main/scala/org/apache/gluten/sql/shims/spark33/Spark33Shims.scala b/shims/spark33/src/main/scala/org/apache/gluten/sql/shims/spark33/Spark33Shims.scala
index a18fb3171991..19bb10c0eea7 100644
--- a/shims/spark33/src/main/scala/org/apache/gluten/sql/shims/spark33/Spark33Shims.scala
+++ b/shims/spark33/src/main/scala/org/apache/gluten/sql/shims/spark33/Spark33Shims.scala
@@ -41,6 +41,7 @@ import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec
 import org.apache.spark.sql.connector.catalog.Table
 import org.apache.spark.sql.connector.expressions.Transform
 import org.apache.spark.sql.execution.{FileSourceScanExec, PartitionedFileUtil, QueryExecution, SparkPlan, SparkPlanner}
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.FileFormatWriter.Empty2Null
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFilters
@@ -420,4 +421,11 @@ class Spark33Shims extends SparkShims {
       planner: SparkPlanner,
       plan: LogicalPlan): SparkPlan =
     QueryExecution.createSparkPlan(sparkSession, planner, plan)
+
+  override def isFinalAdaptivePlan(p: AdaptiveSparkPlanExec): Boolean = {
+    val args = p.argString(Int.MaxValue)
+    val index = args.indexOf("isFinalPlan=")
+    assert(index >= 0)
+    args.substring(index + "isFinalPlan=".length).trim.toBoolean
+  }
 }
diff --git a/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala b/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala
index cdbeaa47838b..199c5313da46 100644
--- a/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala
+++ b/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala
@@ -44,6 +44,7 @@ import org.apache.spark.sql.connector.catalog.Table
 import org.apache.spark.sql.connector.expressions.Transform
 import org.apache.spark.sql.connector.read.{HasPartitionKey, InputPartition, Scan}
 import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFilters
 import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2ScanExecBase}
@@ -663,4 +664,8 @@ class Spark34Shims extends SparkShims {
       planner: SparkPlanner,
       plan: LogicalPlan): SparkPlan =
     QueryExecution.createSparkPlan(sparkSession, planner, plan)
+
+  override def isFinalAdaptivePlan(p: AdaptiveSparkPlanExec): Boolean = {
+    p.isFinalPlan
+  }
 }
diff --git a/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala b/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala
index d993cc0bfd20..0619cea66ad0 100644
--- a/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala
+++ b/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala
@@ -44,6 +44,7 @@ import org.apache.spark.sql.connector.catalog.Table
 import org.apache.spark.sql.connector.expressions.Transform
 import org.apache.spark.sql.connector.read.{HasPartitionKey, InputPartition, Scan}
 import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, ParquetFilters}
 import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2ScanExecBase}
@@ -714,4 +715,8 @@ class Spark35Shims extends SparkShims {
       planner: SparkPlanner,
       plan: LogicalPlan): SparkPlan =
     QueryExecution.createSparkPlan(sparkSession, planner, plan)
+
+  override def isFinalAdaptivePlan(p: AdaptiveSparkPlanExec): Boolean = {
+    p.isFinalPlan
+  }
 }
diff --git a/shims/spark40/src/main/scala/org/apache/gluten/sql/shims/spark40/Spark40Shims.scala b/shims/spark40/src/main/scala/org/apache/gluten/sql/shims/spark40/Spark40Shims.scala
index e5258eafa46c..72c9b272934d 100644
--- a/shims/spark40/src/main/scala/org/apache/gluten/sql/shims/spark40/Spark40Shims.scala
+++ b/shims/spark40/src/main/scala/org/apache/gluten/sql/shims/spark40/Spark40Shims.scala
@@ -46,6 +46,7 @@ import org.apache.spark.sql.connector.expressions.Transform
 import org.apache.spark.sql.connector.read.{HasPartitionKey, InputPartition, Scan}
 import org.apache.spark.sql.connector.read.streaming.SparkDataStream
 import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, ParquetFilters}
 import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, BatchScanExecShim, DataSourceV2ScanExecBase}
@@ -765,4 +766,8 @@ class Spark40Shims extends SparkShims {
       planner: SparkPlanner,
       plan: LogicalPlan): SparkPlan =
     QueryExecution.createSparkPlan(sparkSession, planner, plan)
+
+  override def isFinalAdaptivePlan(p: AdaptiveSparkPlanExec): Boolean = {
+    p.isFinalPlan
+  }
 }
diff --git a/shims/spark41/src/main/scala/org/apache/gluten/sql/shims/spark41/Spark41Shims.scala b/shims/spark41/src/main/scala/org/apache/gluten/sql/shims/spark41/Spark41Shims.scala
index ea5f733614cb..8aad6394f074 100644
--- a/shims/spark41/src/main/scala/org/apache/gluten/sql/shims/spark41/Spark41Shims.scala
+++ b/shims/spark41/src/main/scala/org/apache/gluten/sql/shims/spark41/Spark41Shims.scala
@@ -45,6 +45,7 @@ import org.apache.spark.sql.connector.expressions.Transform
 import org.apache.spark.sql.connector.read.{HasPartitionKey, InputPartition, Scan}
 import org.apache.spark.sql.connector.read.streaming.SparkDataStream
 import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, ParquetFilters}
 import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, BatchScanExecShim, DataSourceV2ScanExecBase}
@@ -764,4 +765,8 @@ class Spark41Shims extends SparkShims {
       planner: SparkPlanner,
       plan: LogicalPlan): SparkPlan =
     QueryExecution.createSparkPlan(planner, plan)
+
+  override def isFinalAdaptivePlan(p: AdaptiveSparkPlanExec): Boolean = {
+    p.isFinalPlan
+  }
 }
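The shim split above exists presumably because the isFinalPlan flag on AdaptiveSparkPlanExec is not publicly readable before Spark 3.4: the Spark 3.4+ shims simply return p.isFinalPlan, while the Spark 3.2/3.3 shims keep the old fallback of parsing the plan's argString, which GlutenImplicits previously did inline. GenerateTransformStageId relies on this check to renumber stages only inside a final adaptive plan. The string-parsing fallback can be exercised on its own; the sample argString below is illustrative, not captured from Spark:

    object IsFinalPlanParsingSketch extends App {
      // Same extraction the Spark 3.2/3.3 shims perform on AdaptiveSparkPlanExec.argString.
      def isFinalPlan(args: String): Boolean = {
        val index = args.indexOf("isFinalPlan=")
        assert(index >= 0)
        args.substring(index + "isFinalPlan=".length).trim.toBoolean
      }

      assert(!isFinalPlan("(1), isFinalPlan=false"))
      assert(isFinalPlan("(1), isFinalPlan=true"))
    }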