@@ -514,7 +514,9 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging {
val childWithAdapter = ColumnarCollapseTransformStages.wrapInputIteratorTransformer(child)
WholeStageTransformer(
ProjectExecTransformer(child.output ++ appendedProjections, childWithAdapter))(
ColumnarCollapseTransformStages.transformStageCounter.incrementAndGet()
ColumnarCollapseTransformStages
.getTransformStageCounter(childWithAdapter)
.incrementAndGet()
)
}

@@ -156,7 +156,10 @@ object CHAggAndShuffleBenchmark extends SqlBasedBenchmark with CHSqlBasedBenchma
// Get the `FileSourceScanExecTransformer`
val fileScan = executedPlan.collect { case scan: FileSourceScanExecTransformer => scan }.head
val scanStage = WholeStageTransformer(fileScan)(
ColumnarCollapseTransformStages.transformStageCounter.incrementAndGet())
ColumnarCollapseTransformStages
.getTransformStageCounter(fileScan)
.incrementAndGet()
)
val scanStageRDD = scanStage.executeColumnar()

// Get the total row count
@@ -200,7 +203,9 @@ object CHAggAndShuffleBenchmark extends SqlBasedBenchmark with CHSqlBasedBenchma
val projectFilter = executedPlan.collect { case project: ProjectExecTransformer => project }
if (projectFilter.nonEmpty) {
val projectFilterStage = WholeStageTransformer(projectFilter.head)(
ColumnarCollapseTransformStages.transformStageCounter.incrementAndGet())
ColumnarCollapseTransformStages
.getTransformStageCounter(projectFilter.head)
.incrementAndGet())
val projectFilterStageRDD = projectFilterStage.executeColumnar()

chAllStagesBenchmark.addCase(s"Project Stage", executedCnt) {
@@ -260,9 +260,8 @@ object GlutenDeltaFileFormatWriter extends LoggingShims {
nativeSortPlan
}
val newPlan = sortPlan.child match {
case WholeStageTransformer(wholeStageChild, materializeInput) =>
WholeStageTransformer(addNativeSort(wholeStageChild),
materializeInput)(ColumnarCollapseTransformStages.transformStageCounter.incrementAndGet())
case wst @ WholeStageTransformer(wholeStageChild, _) =>
wst.withNewChildren(Seq(addNativeSort(wholeStageChild)))
case other =>
Transitions.toBatchPlan(sortPlan, VeloxBatchType)
}
@@ -151,6 +151,7 @@ object VeloxRuleApi {
c => GlutenAutoAdjustStageResourceProfile(new GlutenConfig(c.sqlConf), c.session))
injector.injectFinal(c => GlutenFallbackReporter(new GlutenConfig(c.sqlConf), c.session))
injector.injectFinal(_ => RemoveFallbackTagRule())
injector.injectFinal(_ => GenerateTransformStageId())
}

/**
@@ -250,5 +251,6 @@ object VeloxRuleApi {
injector.injectPostTransform(
c => GlutenFallbackReporter(new GlutenConfig(c.sqlConf), c.session))
injector.injectPostTransform(_ => RemoveFallbackTagRule())
injector.injectPostTransform(_ => GenerateTransformStageId())
Member:
Can we move the rule right after ColumnarCollapseTransformStages? To make it clear that GenerateTransformStageId doesn't depend on other rules.
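A minimal sketch of the suggested placement, assuming the post-transform injection order shown in this PR's context (surrounding calls are illustrative and may not match the full rule list):

```scala
// Hypothetical ordering: inject GenerateTransformStageId right after
// ColumnarCollapseTransformStages, so the id assignment clearly depends only on
// the collapsed WholeStageTransformer nodes and not on later rules.
injector.injectPostTransform(c => ColumnarCollapseTransformStages(new GlutenConfig(c.sqlConf)))
injector.injectPostTransform(_ => GenerateTransformStageId())
// ... remaining post-transform rules (GlutenFallbackReporter, RemoveFallbackTagRule, etc.)
```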

}
}
@@ -137,7 +137,7 @@ case class TakeOrderedAndProjectExecTransformer(
LimitExecTransformer(localSortPlan, limitBeforeShuffleOffset, limit)
}
val transformStageCounter: AtomicInteger =
ColumnarCollapseTransformStages.transformStageCounter
ColumnarCollapseTransformStages.getTransformStageCounter(child)
val finalLimitPlan = if (hasShuffle) {
limitBeforeShuffle
} else {
@@ -155,7 +155,7 @@ trait UnaryTransformSupport extends TransformSupport with UnaryExecNode {
}

case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = false)(
val transformStageId: Int
var transformStageId: Int
Member:
Why change this to a var?

Contributor (Author):
When AQE is on, transformStageId always starts at 0 within each query stage, so it is not globally unique within the whole plan tree. After all query stages have been executed, the GenerateTransformStageId rule traverses the whole plan tree and updates this value incrementally to a unique id.
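For illustration, a minimal sketch of the renumbering this enables (the helper name and traversal order are illustrative; the actual logic is in GenerateTransformStageId below, which also walks query stages and subqueries):

```scala
import java.util.concurrent.atomic.AtomicInteger

import org.apache.gluten.execution.WholeStageTransformer
import org.apache.spark.sql.execution.SparkPlan

// Under AQE each query stage is planned independently, so two stages can both
// contain a WholeStageTransformer with the same id. One pass over the assembled
// final plan makes the ids unique again, which requires transformStageId to be
// mutable (a var).
def renumber(finalPlan: SparkPlan): Unit = {
  val counter = new AtomicInteger(0)
  finalPlan.foreach {
    case wst: WholeStageTransformer => wst.transformStageId = counter.incrementAndGet()
    case _ =>
  }
}
```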

) extends WholeStageTransformerGenerateTreeStringShim
with UnaryTransformSupport {

@@ -139,10 +139,7 @@ case class InputIteratorTransformer(child: SparkPlan) extends UnaryTransformSupp
* created, e.g. for special fallback handling when an existing WholeStageTransformer failed to
* generate/compile code.
*/
case class ColumnarCollapseTransformStages(
glutenConf: GlutenConfig,
transformStageCounter: AtomicInteger = ColumnarCollapseTransformStages.transformStageCounter)
extends Rule[SparkPlan] {
case class ColumnarCollapseTransformStages(glutenConf: GlutenConfig) extends Rule[SparkPlan] {

def apply(plan: SparkPlan): SparkPlan = {
insertWholeStageTransformer(plan)
@@ -176,8 +173,8 @@ case class ColumnarCollapseTransformStages(
private def insertWholeStageTransformer(plan: SparkPlan): SparkPlan = {
plan match {
case t if supportTransform(t) =>
WholeStageTransformer(t.withNewChildren(t.children.map(insertInputIteratorTransformer)))(
transformStageCounter.incrementAndGet())
// transformStageId will be updated by rule `GenerateTransformStageId`.
WholeStageTransformer(t.withNewChildren(t.children.map(insertInputIteratorTransformer)))(-1)
case other =>
other.withNewChildren(other.children.map(insertWholeStageTransformer))
}
@@ -213,9 +210,20 @@ case class ColumnarInputAdapter(child: SparkPlan)
}

object ColumnarCollapseTransformStages {
val transformStageCounter = new AtomicInteger(0)

def wrapInputIteratorTransformer(plan: SparkPlan): TransformSupport = {
InputIteratorTransformer(ColumnarInputAdapter(plan))
}

def getTransformStageCounter(plan: SparkPlan): AtomicInteger = {
new AtomicInteger(findMaxTransformStageId(plan))
}

private def findMaxTransformStageId(plan: SparkPlan): Int = {
plan match {
case wst: WholeStageTransformer =>
wst.transformStageId
case _ =>
plan.children.map(findMaxTransformStageId).foldLeft(0)(Math.max)
}
}
}
@@ -0,0 +1,76 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution

import org.apache.gluten.exception.GlutenException
import org.apache.gluten.execution.WholeStageTransformer
import org.apache.gluten.sql.shims.SparkShimLoader

import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, BroadcastQueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, ReusedExchangeExec, ShuffleExchangeLike}

import java.util
import java.util.Collections.newSetFromMap
import java.util.concurrent.atomic.AtomicInteger

/**
 * Generate `transformStageId` for `WholeStageTransformer`. This rule updates the whole plan
 * tree with incremental and unique transform stage ids before the final execution.
 *
 * In Spark, the whole stage id is generated by incrementing a global counter. In Gluten, it is
 * not possible to use a global counter for id generation, especially in the case of AQE.
 */
case class GenerateTransformStageId() extends Rule[SparkPlan] with AdaptiveSparkPlanHelper {
private val transformStageCounter: AtomicInteger = new AtomicInteger(0)

private val wholeStageTransformerCache =
newSetFromMap[WholeStageTransformer](new util.IdentityHashMap())

def apply(plan: SparkPlan): SparkPlan = {
updateStageId(plan)
plan
}

private def updateStageId(plan: SparkPlan): Unit = {
plan match {
case b: BroadcastQueryStageExec =>
b.plan match {
case b: BroadcastExchangeLike => updateStageId(b)
case _: ReusedExchangeExec =>
case _ =>
throw new GlutenException(s"wrong plan for broadcast stage:\n ${plan.treeString}")
}
case s: ShuffleQueryStageExec =>
s.plan match {
case s: ShuffleExchangeLike => updateStageId(s)
case _: ReusedExchangeExec =>
case _ =>
throw new GlutenException(s"wrong plan for shuffle stage:\n ${plan.treeString}")
}
case aqe: AdaptiveSparkPlanExec if SparkShimLoader.getSparkShims.isFinalAdaptivePlan(aqe) =>
updateStageId(stripAQEPlan(aqe))
case wst: WholeStageTransformer if !wholeStageTransformerCache.contains(wst) =>
updateStageId(wst.child)
wst.transformStageId = transformStageCounter.incrementAndGet()
wholeStageTransformerCache.add(wst)
case plan =>
plan.subqueries.foreach(updateStageId)
plan.children.foreach(updateStageId)
}
}
}
@@ -95,10 +95,7 @@ object GlutenImplicits {
}

private def isFinalAdaptivePlan(p: AdaptiveSparkPlanExec): Boolean = {
val args = p.argString(Int.MaxValue)
val index = args.indexOf("isFinalPlan=")
assert(index >= 0)
args.substring(index + "isFinalPlan=".length).trim.toBoolean
SparkShimLoader.getSparkShims.isFinalAdaptivePlan(p)
}

private def collectFallbackNodes(
@@ -23,7 +23,6 @@ import org.apache.gluten.extension.columnar.heuristic.HeuristicTransform
import org.apache.gluten.extension.columnar.transition.{InsertTransitions, RemoveTransitions}

import org.apache.spark.sql.execution.{ColumnarCollapseTransformStages, SparkPlan}
import org.apache.spark.sql.execution.ColumnarCollapseTransformStages.transformStageCounter

trait GlutenFormatWriterInjectsBase extends GlutenFormatWriterInjects {
private lazy val transform = HeuristicTransform.static()
@@ -66,7 +65,7 @@ trait GlutenFormatWriterInjectsBase extends GlutenFormatWriterInjects {
// and cannot provide const-ness.
val transformedWithAdapter = injectAdapter(transformed)
val wst = WholeStageTransformer(transformedWithAdapter, materializeInput = true)(
transformStageCounter.incrementAndGet())
ColumnarCollapseTransformStages.getTransformStageCounter(transformed).incrementAndGet())
val wstWithTransitions = BackendsApiManager.getSparkPlanExecApiInstance.genColumnarToCarrierRow(
InsertTransitions.create(outputsColumnar = true, wst.batchType()).apply(wst))
wstWithTransitions
@@ -858,7 +858,7 @@ class VeloxTestSettings extends BackendTestSettings {
// ORC related
.exclude("SPARK-37965: Spark support read/write orc file with invalid char in field name")
.exclude("SPARK-38173: Quoted column cannot be recognized correctly when quotedRegexColumnNames is true")
// TODO: fix in Spark-4.0
// Rewrite with Gluten's explained result.
.exclude("SPARK-47939: Explain should work with parameterized queries")
enableSuite[GlutenSQLQueryTestSuite]
enableSuite[GlutenStatisticsCollectionSuite]
@@ -154,4 +154,121 @@ class GlutenSQLQuerySuite extends SQLQuerySuite with GlutenSQLTestsTrait {
assert(inputOutputPairs.map(_._1).sum == numRows)
}
}

testGluten("SPARK-47939: Explain should work with parameterized queries") {
def checkQueryPlan(df: DataFrame, plan: String): Unit = assert(
df.collect()
.map(_.getString(0))
.map(_.replaceAll("#[0-9]+", "#N"))
// Remove the backend keyword in c2r/r2c.
.map(_.replaceAll("[A-Za-z]*ColumnarToRow", "ColumnarToRow"))
.map(_.replaceAll("RowTo[A-Za-z]*Columnar", "RowToColumnar"))
=== Array(plan.stripMargin)
)

checkQueryPlan(
spark.sql("explain select ?", Array(1)),
"""== Physical Plan ==
|ColumnarToRow
|+- ^(1) ProjectExecTransformer [1 AS 1#N]
| +- ^(1) InputIteratorTransformer[]
| +- RowToColumnar
| +- *(1) Scan OneRowRelation[]
|
|"""
)
checkQueryPlan(
spark.sql("explain select :first", Map("first" -> 1)),
"""== Physical Plan ==
|ColumnarToRow
|+- ^(1) ProjectExecTransformer [1 AS 1#N]
| +- ^(1) InputIteratorTransformer[]
| +- RowToColumnar
| +- *(1) Scan OneRowRelation[]
|
|"""
)

checkQueryPlan(
spark.sql("explain explain explain select ?", Array(1)),
"""== Physical Plan ==
|Execute ExplainCommand
| +- ExplainCommand ExplainCommand 'PosParameterizedQuery [1], SimpleMode, SimpleMode

|"""
)
checkQueryPlan(
spark.sql("explain explain explain select :first", Map("first" -> 1)),
// scalastyle:off
"""== Physical Plan ==
|Execute ExplainCommand
| +- ExplainCommand ExplainCommand 'NameParameterizedQuery [first], [1], SimpleMode, SimpleMode

|"""
// scalastyle:on
)

checkQueryPlan(
spark.sql("explain describe select ?", Array(1)),
"""== Physical Plan ==
|Execute DescribeQueryCommand
| +- DescribeQueryCommand select ?

|"""
)
checkQueryPlan(
spark.sql("explain describe select :first", Map("first" -> 1)),
"""== Physical Plan ==
|Execute DescribeQueryCommand
| +- DescribeQueryCommand select :first

|"""
)

checkQueryPlan(
spark.sql("explain extended select * from values (?, ?) t(x, y)", Array(1, "a")),
"""== Parsed Logical Plan ==
|'PosParameterizedQuery [1, a]
|+- 'Project [*]
| +- 'SubqueryAlias t
| +- 'UnresolvedInlineTable [x, y], [[posparameter(39), posparameter(42)]]

|== Analyzed Logical Plan ==
|x: int, y: string
|Project [x#N, y#N]
|+- SubqueryAlias t
| +- LocalRelation [x#N, y#N]

|== Optimized Logical Plan ==
|LocalRelation [x#N, y#N]

|== Physical Plan ==
|LocalTableScan [x#N, y#N]
|"""
)
checkQueryPlan(
spark.sql(
"explain extended select * from values (:first, :second) t(x, y)",
Map("first" -> 1, "second" -> "a")
),
"""== Parsed Logical Plan ==
|'NameParameterizedQuery [first, second], [1, a]
|+- 'Project [*]
| +- 'SubqueryAlias t
| +- 'UnresolvedInlineTable [x, y], [[namedparameter(first), namedparameter(second)]]

|== Analyzed Logical Plan ==
|x: int, y: string
|Project [x#N, y#N]
|+- SubqueryAlias t
| +- LocalRelation [x#N, y#N]

|== Optimized Logical Plan ==
|LocalRelation [x#N, y#N]

|== Physical Plan ==
|LocalTableScan [x#N, y#N]
|"""
)
}
}
@@ -39,6 +39,7 @@ import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.connector.read.{InputPartition, Scan}
import org.apache.spark.sql.connector.read.streaming.SparkDataStream
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.parquet.ParquetFilters
import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2ScanExecBase}
@@ -365,4 +366,6 @@ trait SparkShims {
sparkSession: SparkSession,
planner: SparkPlanner,
plan: LogicalPlan): SparkPlan

def isFinalAdaptivePlan(p: AdaptiveSparkPlanExec): Boolean
}
@@ -40,6 +40,7 @@ import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec
import org.apache.spark.sql.connector.catalog.Table
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.execution.{FileSourceScanExec, PartitionedFileUtil, QueryExecution, SparkPlan, SparkPlanner}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.FileFormatWriter.Empty2Null
import org.apache.spark.sql.execution.datasources.parquet.ParquetFilters
@@ -315,4 +316,11 @@ class Spark32Shims extends SparkShims {
planner: SparkPlanner,
plan: LogicalPlan): SparkPlan =
QueryExecution.createSparkPlan(sparkSession, planner, plan)

override def isFinalAdaptivePlan(p: AdaptiveSparkPlanExec): Boolean = {
val args = p.argString(Int.MaxValue)
val index = args.indexOf("isFinalPlan=")
assert(index >= 0)
args.substring(index + "isFinalPlan=".length).trim.toBoolean
}
}