Skip to content

Commit d68198d

Browse files
committed
[SPARK-23223][SQL] Make stacking dataset transforms more performant
## What changes were proposed in this pull request? It is a common pattern to apply multiple transforms to a `Dataset` (using `Dataset.withColumn` for example. This is currently quite expensive because we run `CheckAnalysis` on the full plan and create an encoder for each intermediate `Dataset`. This PR extends the usage of the `AnalysisBarrier` to include `CheckAnalysis`. By doing this we hide the already analyzed plan from `CheckAnalysis` because barrier is a `LeafNode`. The `AnalysisBarrier` is in the `FinishAnalysis` phase of the optimizer. We also make binding the `Dataset` encoder lazy. The bound encoder is only needed when we materialize the dataset. ## How was this patch tested? Existing test should cover this. Author: Herman van Hovell <hvanhovell@databricks.com> Closes #20402 from hvanhovell/SPARK-23223. (cherry picked from commit 2d903cf) Signed-off-by: Herman van Hovell <hvanhovell@databricks.com>
1 parent 4059454 commit d68198d

6 files changed

Lines changed: 25 additions & 21 deletions

File tree

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,19 @@ class Analyzer(
9898
this(catalog, conf, conf.optimizerMaxIterations)
9999
}
100100

101+
def executeAndCheck(plan: LogicalPlan): LogicalPlan = {
102+
val analyzed = execute(plan)
103+
try {
104+
checkAnalysis(analyzed)
105+
EliminateBarriers(analyzed)
106+
} catch {
107+
case e: AnalysisException =>
108+
val ae = new AnalysisException(e.message, e.line, e.startPosition, Option(analyzed))
109+
ae.setStackTrace(e.getStackTrace)
110+
throw ae
111+
}
112+
}
113+
101114
override def execute(plan: LogicalPlan): LogicalPlan = {
102115
AnalysisContext.reset()
103116
try {
@@ -178,8 +191,7 @@ class Analyzer(
178191
Batch("Subquery", Once,
179192
UpdateOuterReferences),
180193
Batch("Cleanup", fixedPoint,
181-
CleanupAliases,
182-
EliminateBarriers)
194+
CleanupAliases)
183195
)
184196

185197
/**

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,7 @@ trait CheckAnalysis extends PredicateHelper {
348348
}
349349
extendedCheckRules.foreach(_(plan))
350350
plan.foreachUp {
351+
case AnalysisBarrier(child) if !child.resolved => checkAnalysis(child)
351352
case o if !o.resolved => failAnalysis(s"unresolved operator ${o.simpleString}")
352353
case _ =>
353354
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,7 @@ trait AnalysisTest extends PlanTest {
5454
expectedPlan: LogicalPlan,
5555
caseSensitive: Boolean = true): Unit = {
5656
val analyzer = getAnalyzer(caseSensitive)
57-
val actualPlan = analyzer.execute(inputPlan)
58-
analyzer.checkAnalysis(actualPlan)
57+
val actualPlan = analyzer.executeAndCheck(inputPlan)
5958
comparePlans(actualPlan, expectedPlan)
6059
}
6160

sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,11 @@ import org.apache.spark.util.Utils
6262

6363
private[sql] object Dataset {
6464
def apply[T: Encoder](sparkSession: SparkSession, logicalPlan: LogicalPlan): Dataset[T] = {
65-
new Dataset(sparkSession, logicalPlan, implicitly[Encoder[T]])
65+
val dataset = new Dataset(sparkSession, logicalPlan, implicitly[Encoder[T]])
66+
// Eagerly bind the encoder so we verify that the encoder matches the underlying
67+
// schema. The user will get an error if this is not the case.
68+
dataset.deserializer
69+
dataset
6670
}
6771

6872
def ofRows(sparkSession: SparkSession, logicalPlan: LogicalPlan): DataFrame = {
@@ -204,7 +208,7 @@ class Dataset[T] private[sql](
204208

205209
// The deserializer expression which can be used to build a projection and turn rows to objects
206210
// of type T, after collecting rows to the driver side.
207-
private val deserializer =
211+
private lazy val deserializer =
208212
exprEnc.resolveAndBind(logicalPlan.output, sparkSession.sessionState.analyzer).deserializer
209213

210214
private implicit def classTag = exprEnc.clsTag

sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -44,19 +44,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
4444
// TODO: Move the planner an optimizer into here from SessionState.
4545
protected def planner = sparkSession.sessionState.planner
4646

47-
def assertAnalyzed(): Unit = {
48-
// Analyzer is invoked outside the try block to avoid calling it again from within the
49-
// catch block below.
50-
analyzed
51-
try {
52-
sparkSession.sessionState.analyzer.checkAnalysis(analyzed)
53-
} catch {
54-
case e: AnalysisException =>
55-
val ae = new AnalysisException(e.message, e.line, e.startPosition, Option(analyzed))
56-
ae.setStackTrace(e.getStackTrace)
57-
throw ae
58-
}
59-
}
47+
def assertAnalyzed(): Unit = analyzed
6048

6149
def assertSupported(): Unit = {
6250
if (sparkSession.sessionState.conf.isUnsupportedOperationCheckEnabled) {
@@ -66,7 +54,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
6654

6755
lazy val analyzed: LogicalPlan = {
6856
SparkSession.setActiveSession(sparkSession)
69-
sparkSession.sessionState.analyzer.execute(logical)
57+
sparkSession.sessionState.analyzer.executeAndCheck(logical)
7058
}
7159

7260
lazy val withCachedData: LogicalPlan = {

sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -575,7 +575,7 @@ private[hive] class TestHiveQueryExecution(
575575
logDebug(s"Query references test tables: ${referencedTestTables.mkString(", ")}")
576576
referencedTestTables.foreach(sparkSession.loadTestTable)
577577
// Proceed with analysis.
578-
sparkSession.sessionState.analyzer.execute(logical)
578+
sparkSession.sessionState.analyzer.executeAndCheck(logical)
579579
}
580580
}
581581

0 commit comments

Comments
 (0)