Skip to content

Commit ac55879

Browse files
committed
Temporary unregister H2Dialect in tests
1 parent 90fcaf3 commit ac55879

2 files changed

Lines changed: 29 additions & 19 deletions

File tree

sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -770,9 +770,14 @@ class JDBCSuite extends QueryTest
770770
}
771771

772772
test("Dialect unregister") {
773-
JdbcDialects.registerDialect(testH2Dialect)
774-
JdbcDialects.unregisterDialect(testH2Dialect)
775-
assert(JdbcDialects.get(urlWithUserAndPass) == NoopDialect)
773+
JdbcDialects.unregisterDialect(H2Dialect)
774+
try {
775+
JdbcDialects.registerDialect(testH2Dialect)
776+
JdbcDialects.unregisterDialect(testH2Dialect)
777+
assert(JdbcDialects.get(urlWithUserAndPass) == NoopDialect)
778+
} finally {
779+
JdbcDialects.registerDialect(H2Dialect)
780+
}
776781
}
777782

778783
test("Aggregated dialects") {

sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -194,24 +194,29 @@ class JDBCWriteSuite extends SharedSparkSession with BeforeAndAfter {
194194
}
195195

196196
test("Truncate") {
197-
JdbcDialects.registerDialect(testH2Dialect)
198-
val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
199-
val df2 = spark.createDataFrame(sparkContext.parallelize(arr1x2), schema2)
200-
val df3 = spark.createDataFrame(sparkContext.parallelize(arr2x3), schema3)
201-
202-
df.write.jdbc(url1, "TEST.TRUNCATETEST", properties)
203-
df2.write.mode(SaveMode.Overwrite).option("truncate", true)
204-
.jdbc(url1, "TEST.TRUNCATETEST", properties)
205-
assert(1 === spark.read.jdbc(url1, "TEST.TRUNCATETEST", properties).count())
206-
assert(2 === spark.read.jdbc(url1, "TEST.TRUNCATETEST", properties).collect()(0).length)
197+
JdbcDialects.unregisterDialect(H2Dialect)
198+
try {
199+
JdbcDialects.registerDialect(testH2Dialect)
200+
val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
201+
val df2 = spark.createDataFrame(sparkContext.parallelize(arr1x2), schema2)
202+
val df3 = spark.createDataFrame(sparkContext.parallelize(arr2x3), schema3)
207203

208-
val m = intercept[AnalysisException] {
209-
df3.write.mode(SaveMode.Overwrite).option("truncate", true)
204+
df.write.jdbc(url1, "TEST.TRUNCATETEST", properties)
205+
df2.write.mode(SaveMode.Overwrite).option("truncate", true)
210206
.jdbc(url1, "TEST.TRUNCATETEST", properties)
211-
}.getMessage
212-
assert(m.contains("Column \"seq\" not found"))
213-
assert(0 === spark.read.jdbc(url1, "TEST.TRUNCATETEST", properties).count())
214-
JdbcDialects.unregisterDialect(testH2Dialect)
207+
assert(1 === spark.read.jdbc(url1, "TEST.TRUNCATETEST", properties).count())
208+
assert(2 === spark.read.jdbc(url1, "TEST.TRUNCATETEST", properties).collect()(0).length)
209+
210+
val m = intercept[AnalysisException] {
211+
df3.write.mode(SaveMode.Overwrite).option("truncate", true)
212+
.jdbc(url1, "TEST.TRUNCATETEST", properties)
213+
}.getMessage
214+
assert(m.contains("Column \"seq\" not found"))
215+
assert(0 === spark.read.jdbc(url1, "TEST.TRUNCATETEST", properties).count())
216+
} finally {
217+
JdbcDialects.unregisterDialect(testH2Dialect)
218+
JdbcDialects.registerDialect(H2Dialect)
219+
}
215220
}
216221

217222
test("createTableOptions") {

0 commit comments

Comments
 (0)