
Commit 59afac5

Ngone51 authored and Seongjin Cho committed
[SPARK-31391][SQL][TEST] Add AdaptiveTestUtils to ease the test of AQE
### What changes were proposed in this pull request?

This PR adds `AdaptiveTestUtils` to make AQE tests simpler. It includes:

`DisableAdaptiveExecution` - a test tag to skip a single test case if AQE is enabled.
`EnableAdaptiveExecutionSuite` - a helper trait to enable AQE for all tests except those tagged with `DisableAdaptiveExecution`.
`DisableAdaptiveExecutionSuite` - a helper trait to disable AQE for all tests.
`assertExceptionMessage` - a method to handle the message of a normal or AQE exception in a consistent way.
`assertExceptionCause` - a method to handle the cause of a normal or AQE exception in a consistent way.

### Why are the changes needed?

With these utils, we can:

- remove a lot of duplicated code;
- handle normal and AQE exceptions in a consistent way;
- improve the stability of AQE tests.

### Does this PR introduce any user-facing change?

No.

### How was this patch tested?

Updated tests with the util.

Closes apache#28162 from Ngone51/add_aqe_test_utils.

Authored-by: yi.wu <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent 2734e4a commit 59afac5

22 files changed

Lines changed: 217 additions & 169 deletions
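Taken together, the description above implies the following usage pattern: a shared base suite defines each test once, and two thin subclasses run them with AQE forced on and off. This is a minimal, hypothetical sketch (the suite names and test bodies are invented for illustration, not taken from this patch):

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.execution.adaptive.{DisableAdaptiveExecution, DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite}
import org.apache.spark.sql.test.SharedSparkSession

// Hypothetical base suite: test bodies are written once and shared.
abstract class MyQuerySuiteBase extends QueryTest with SharedSparkSession {

  test("result is the same with and without AQE") {
    checkAnswer(spark.range(3).toDF(), Seq(Row(0L), Row(1L), Row(2L)))
  }

  // Plan-shape assertions only hold when AQE is off, so the test is tagged;
  // EnableAdaptiveExecutionSuite registers it as ignored instead of running it.
  test("executed plan has no adaptive root",
    DisableAdaptiveExecution("plan shape differs when AQE is on")) {
    val plan = spark.range(3).queryExecution.executedPlan
    assert(!plan.toString.contains("AdaptiveSparkPlan"))
  }
}

// Runs every untagged test with AQE forced on.
class MyQuerySuiteWithAQE extends MyQuerySuiteBase with EnableAdaptiveExecutionSuite

// Runs every test, including the tagged one, with AQE off.
class MyQuerySuiteWithoutAQE extends MyQuerySuiteBase with DisableAdaptiveExecutionSuite

The same pattern shows up in the diffs below: suites such as ExplainSuite and WholeStageCodegenSuite simply mix in DisableAdaptiveExecutionSuite instead of flipping the config in beforeAll/afterAll.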

sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala

Lines changed: 4 additions & 5 deletions
@@ -19,6 +19,7 @@ package org.apache.spark.sql
 
 import org.apache.commons.math3.stat.inference.ChiSquareTest
 
+import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecution
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 
@@ -27,7 +28,8 @@ class ConfigBehaviorSuite extends QueryTest with SharedSparkSession {
 
   import testImplicits._
 
-  test("SPARK-22160 spark.sql.execution.rangeExchange.sampleSizePerPartition") {
+  test("SPARK-22160 spark.sql.execution.rangeExchange.sampleSizePerPartition",
+    DisableAdaptiveExecution("Post shuffle partition number can be different")) {
     // In this test, we run a sort and compute the histogram for partition size post shuffle.
     // With a high sample count, the partition size should be more evenly distributed, and has a
     // low chi-sq test value.
@@ -53,11 +55,8 @@ class ConfigBehaviorSuite extends QueryTest with SharedSparkSession {
         dist)
     }
 
-    // When enable AQE, the post partition number is changed.
     // And the ChiSquareTest result is also need updated. So disable AQE.
-    withSQLConf(
-      SQLConf.SHUFFLE_PARTITIONS.key -> numPartitions.toString,
-      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") {
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> numPartitions.toString) {
       // The default chi-sq value should be low
       assert(computeChiSquareTest() < 100)
 
sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala

Lines changed: 2 additions & 13 deletions
@@ -18,26 +18,15 @@
 package org.apache.spark.sql
 
 import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.StructType
 
-class ExplainSuite extends QueryTest with SharedSparkSession {
+class ExplainSuite extends QueryTest with SharedSparkSession with DisableAdaptiveExecutionSuite {
   import testImplicits._
 
-  var originalValue: String = _
-  protected override def beforeAll(): Unit = {
-    super.beforeAll()
-    originalValue = spark.conf.get(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key)
-    spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "false")
-  }
-
-  protected override def afterAll(): Unit = {
-    spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, originalValue)
-    super.afterAll()
-  }
-
   private def getNormalizedExplain(df: DataFrame, mode: ExplainMode): String = {
     val output = new java.io.ByteArrayOutputStream()
     Console.withOut(output) {

sql/core/src/test/scala/org/apache/spark/sql/MetadataCacheSuite.scala

Lines changed: 5 additions & 4 deletions
@@ -20,6 +20,7 @@ package org.apache.spark.sql
 import java.io.File
 
 import org.apache.spark.{SparkConf, SparkException}
+import org.apache.spark.sql.execution.adaptive.AdaptiveTestUtils.assertExceptionMessage
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 
@@ -55,8 +56,8 @@ abstract class MetadataCacheSuite extends QueryTest with SharedSparkSession {
       val e = intercept[SparkException] {
         df.count()
       }
-      assert(e.getMessage.contains("FileNotFoundException"))
-      assert(e.getMessage.contains("recreating the Dataset/DataFrame involved"))
+      assertExceptionMessage(e, "FileNotFoundException")
+      assertExceptionMessage(e, "recreating the Dataset/DataFrame involved")
     }
   }
 }
@@ -84,8 +85,8 @@ class MetadataCacheV1Suite extends MetadataCacheSuite {
       val e = intercept[SparkException] {
        sql("select count(*) from view_refresh").first()
      }
-      assert(e.getMessage.contains("FileNotFoundException"))
-      assert(e.getMessage.contains("REFRESH"))
+      assertExceptionMessage(e, "FileNotFoundException")
+      assertExceptionMessage(e, "REFRESH")
 
      // Refresh and we should be able to read it again.
      spark.catalog.refreshTable("view_refresh")

sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala

Lines changed: 3 additions & 5 deletions
@@ -22,7 +22,7 @@ import scala.collection.mutable.ArrayBuffer
 import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
 import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan, Sort}
 import org.apache.spark.sql.execution.{ColumnarToRowExec, ExecSubqueryExpression, FileSourceScanExec, InputAdapter, ReusedSubqueryExec, ScalarSubquery, SubqueryExec, WholeStageCodegenExec}
-import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecution}
 import org.apache.spark.sql.execution.datasources.FileScanRDD
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
@@ -1357,11 +1357,9 @@ class SubquerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
     }
   }
 
-  test("SPARK-27279: Reuse Subquery") {
+  test("SPARK-27279: Reuse Subquery", DisableAdaptiveExecution("reuse is dynamic in AQE")) {
     Seq(true, false).foreach { reuse =>
-      withSQLConf(SQLConf.SUBQUERY_REUSE_ENABLED.key -> reuse.toString,
-        SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") {
-        // when enable AQE, the reusedExchange is inserted when executed.
+      withSQLConf(SQLConf.SUBQUERY_REUSE_ENABLED.key -> reuse.toString) {
         val df = sql(
           """
             |SELECT (SELECT avg(key) FROM testData) + (SELECT avg(key) FROM testData)

sql/core/src/test/scala/org/apache/spark/sql/execution/DeprecatedWholeStageCodegenSuite.scala

Lines changed: 11 additions & 14 deletions
@@ -18,30 +18,27 @@
 package org.apache.spark.sql.execution
 
 import org.apache.spark.sql.QueryTest
-import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.expressions.scalalang.typed
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 
+// Disable AQE because the WholeStageCodegenExec is added when running QueryStageExec
 @deprecated("This test suite will be removed.", "3.0.0")
 class DeprecatedWholeStageCodegenSuite extends QueryTest
   with SharedSparkSession
-  with AdaptiveSparkPlanHelper {
+  with DisableAdaptiveExecutionSuite {
 
   test("simple typed UDAF should be included in WholeStageCodegen") {
-    withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") {
-      // With enable AQE, the WholeStageCodegenExec rule is applied when running QueryStageExec.
-      import testImplicits._
+    import testImplicits._
 
-      val ds = Seq(("a", 10), ("b", 1), ("b", 2), ("c", 1)).toDS()
-        .groupByKey(_._1).agg(typed.sum(_._2))
+    val ds = Seq(("a", 10), ("b", 1), ("b", 2), ("c", 1)).toDS()
+      .groupByKey(_._1).agg(typed.sum(_._2))
 
-      val plan = ds.queryExecution.executedPlan
-      assert(find(plan)(p =>
-        p.isInstanceOf[WholeStageCodegenExec] &&
-        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]).isDefined)
-      assert(ds.collect() === Array(("a", 10.0), ("b", 3.0), ("c", 1.0)))
-    }
+    val plan = ds.queryExecution.executedPlan
+    assert(plan.find(p =>
+      p.isInstanceOf[WholeStageCodegenExec] &&
+      p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]).isDefined)
+    assert(ds.collect() === Array(("a", 10.0), ("b", 3.0), ("c", 1.0)))
   }
 }

sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala

Lines changed: 3 additions & 15 deletions
@@ -22,29 +22,17 @@ import scala.reflect.ClassTag
 import org.apache.spark.sql.TPCDSQuerySuite
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Final}
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Generate, Join, LocalRelation, LogicalPlan, Range, Sample, Union, Window}
+import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
 import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
 import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2Relation}
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins._
 import org.apache.spark.sql.execution.window.WindowExec
-import org.apache.spark.sql.internal.SQLConf
 
-class LogicalPlanTagInSparkPlanSuite extends TPCDSQuerySuite {
-
-  var originalValue: String = _
-  // when enable AQE, the 'AdaptiveSparkPlanExec' node does not have a logical plan link
-  override def beforeAll(): Unit = {
-    super.beforeAll()
-    originalValue = spark.conf.get(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key)
-    spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "false")
-  }
-
-  override def afterAll(): Unit = {
-    spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, originalValue)
-    super.afterAll()
-  }
+// Disable AQE because AdaptiveSparkPlanExec does not have a logical plan link
+class LogicalPlanTagInSparkPlanSuite extends TPCDSQuerySuite with DisableAdaptiveExecutionSuite {
 
   override protected def checkGeneratedCode(
       plan: SparkPlan, checkMethodCodeSize: Boolean = true): Unit = {

sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala

Lines changed: 4 additions & 4 deletions
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Range, Repartition, Sort, Union}
 import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecution}
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
 import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
 import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchangeExec}
@@ -752,7 +752,8 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
   }
 
   test("SPARK-24556: always rewrite output partitioning in ReusedExchangeExec " +
-    "and InMemoryTableScanExec") {
+    "and InMemoryTableScanExec",
+    DisableAdaptiveExecution("Reuse is dynamic in AQE")) {
     def checkOutputPartitioningRewrite(
         plans: Seq[SparkPlan],
         expectedPartitioningClass: Class[_]): Unit = {
@@ -782,8 +783,7 @@
       checkOutputPartitioningRewrite(inMemoryScan, expectedPartitioningClass)
     }
     // when enable AQE, the reusedExchange is inserted when executed.
-    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
-      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
       // ReusedExchange is HashPartitioning
       val df1 = Seq(1 -> "a").toDF("i", "j").repartition($"i")
       val df2 = Seq(1 -> "a").toDF("i", "j").repartition($"i")

sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala

Lines changed: 4 additions & 14 deletions
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution
 
 import org.apache.spark.sql.{Dataset, QueryTest, Row, SaveMode}
 import org.apache.spark.sql.catalyst.expressions.codegen.{ByteCodeStats, CodeAndComment, CodeGenerator}
+import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
 import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
@@ -28,23 +29,12 @@ import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
 
-class WholeStageCodegenSuite extends QueryTest with SharedSparkSession {
+// Disable AQE because the WholeStageCodegenExec is added when running QueryStageExec
+class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
+  with DisableAdaptiveExecutionSuite {
 
   import testImplicits._
 
-  var originalValue: String = _
-  // With on AQE, the WholeStageCodegenExec is added when running QueryStageExec.
-  override def beforeAll(): Unit = {
-    super.beforeAll()
-    originalValue = spark.conf.get(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key)
-    spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "false")
-  }
-
-  override def afterAll(): Unit = {
-    spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, originalValue)
-    super.afterAll()
-  }
-
   test("range/filter should be combined") {
     val df = spark.range(10).filter("id = 1").selectExpr("id + 1")
     val plan = df.queryExecution.executedPlan
sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveTestUtils.scala

Lines changed: 93 additions & 0 deletions

@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.adaptive
+
+import java.io.{PrintWriter, StringWriter}
+
+import org.scalactic.source.Position
+import org.scalatest.Tag
+
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SQLTestUtils
+
+/**
+ * Test with this tag will be ignored if the test suite extends `EnableAdaptiveExecutionSuite`.
+ * Otherwise, it will be executed with adaptive execution disabled.
+ */
+case class DisableAdaptiveExecution(reason: String) extends Tag("DisableAdaptiveExecution")
+
+/**
+ * Helper trait that enables AQE for all tests regardless of default config values, except that
+ * tests tagged with [[DisableAdaptiveExecution]] will be skipped.
+ */
+trait EnableAdaptiveExecutionSuite extends SQLTestUtils {
+  protected val forceApply = true
+
+  override protected def test(testName: String, testTags: Tag*)(testFun: => Any)
+      (implicit pos: Position): Unit = {
+    if (testTags.exists(_.isInstanceOf[DisableAdaptiveExecution])) {
+      // we ignore the test here but assume that another test suite which extends
+      // `DisableAdaptiveExecutionSuite` will test it anyway to ensure test coverage
+      ignore(testName + " (disabled when AQE is on)", testTags: _*)(testFun)
+    } else {
+      super.test(testName, testTags: _*) {
+        withSQLConf(
+          SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+          SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> forceApply.toString) {
+          testFun
+        }
+      }
+    }
+  }
+}
+
+/**
+ * Helper trait that disables AQE for all tests regardless of default config values.
+ */
+trait DisableAdaptiveExecutionSuite extends SQLTestUtils {
+  override protected def test(testName: String, testTags: Tag*)(testFun: => Any)
+      (implicit pos: Position): Unit = {
+    super.test(testName, testTags: _*) {
+      withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") {
+        testFun
+      }
+    }
+  }
+}
+
+object AdaptiveTestUtils {
+  def assertExceptionMessage(e: Exception, expected: String): Unit = {
+    val stringWriter = new StringWriter()
+    e.printStackTrace(new PrintWriter(stringWriter))
+    val errorMsg = stringWriter.toString
+    assert(errorMsg.contains(expected))
+  }
+
+  def assertExceptionCause(t: Throwable, causeClass: Class[_]): Unit = {
+    var c = t.getCause
+    var foundCause = false
+    while (c != null && !foundCause) {
+      if (causeClass.isAssignableFrom(c.getClass)) {
+        foundCause = true
+      } else {
+        c = c.getCause
+      }
+    }
+    assert(foundCause, s"Can not find cause: $causeClass")
+  }
+}
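The two assertion helpers in AdaptiveTestUtils exist because a failure may surface wrapped differently when AQE is on, so checking e.getMessage directly can behave inconsistently across the two modes. A small self-contained sketch of how they behave (the suite name and the hand-built exceptions are fabricated for illustration, not part of the patch):

import java.io.FileNotFoundException

import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.sql.execution.adaptive.AdaptiveTestUtils.{assertExceptionCause, assertExceptionMessage}

class AdaptiveTestUtilsExampleSuite extends SparkFunSuite {
  test("helpers see through wrapped exceptions") {
    // Simulate a wrapped failure: the interesting error sits two levels down.
    val original = new FileNotFoundException("part-00000 does not exist")
    val wrapped = new SparkException("Job aborted.", new RuntimeException(original))

    // Checks the full printed stack trace, so nested "Caused by" messages are matched too.
    assertExceptionMessage(wrapped, "part-00000 does not exist")

    // Walks the getCause chain until an instance of the given class is found.
    assertExceptionCause(wrapped, classOf[FileNotFoundException])
  }
}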

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala

Lines changed: 2 additions & 2 deletions
@@ -35,6 +35,7 @@ import org.apache.spark.sql.{functions => F, _}
 import org.apache.spark.sql.catalyst.json._
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.execution.ExternalRDD
+import org.apache.spark.sql.execution.adaptive.AdaptiveTestUtils.assertExceptionMessage
 import org.apache.spark.sql.execution.datasources.DataSource
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
@@ -2192,9 +2193,8 @@ abstract class JsonSuite extends QueryTest with SharedSparkSession with TestJson
           .json(testFile(fileName))
           .count()
       }
-      val errMsg = exception.getMessage
 
-      assert(errMsg.contains("Malformed records are detected in record parsing"))
+      assertExceptionMessage(exception, "Malformed records are detected in record parsing")
   }
 
   def checkEncoding(expectedEncoding: String, pathToJsonFiles: String,
