
Commit 6c83d93

JoshRosen authored and yhuai committed
[SPARK-12579][SQL] Force user-specified JDBC driver to take precedence
Spark SQL's JDBC data source allows users to specify an explicit JDBC driver to load (using the `driver` argument), but in the current code it's possible that the user-specified driver will not be used when it comes time to actually create a JDBC connection. In a nutshell, the problem is that you might have multiple JDBC drivers on the classpath that claim to be able to handle the same subprotocol, so simply registering the user-provided driver class with our `DriverRegistry` and JDBC's `DriverManager` is not sufficient to ensure that it's actually used when creating the JDBC connection.

This patch addresses this issue by first registering the user-specified driver with the DriverManager, then iterating over the driver manager's loaded drivers in order to obtain the correct driver and use it to create a connection (previously, we just called `DriverManager.getConnection()` directly).

If a user did not specify a JDBC driver to use, then we call `DriverManager.getDriver` to figure out the class of the driver to use, then pass that class's name to executors; this guards against corner-case bugs in situations where the driver and executor JVMs might have different sets of JDBC drivers on their classpaths (previously, there was the (rare) potential for `DriverManager.getConnection()` to use different drivers on the driver and executors if the user had not explicitly specified a JDBC driver class and the classpaths were different).

This patch is inspired by a similar patch that I made to the `spark-redshift` library (databricks/spark-redshift#143), which contains its own modified fork of some of Spark's JDBC data source code (for cross-Spark-version compatibility reasons).

Author: Josh Rosen <[email protected]>

Closes apache#10519 from JoshRosen/jdbc-driver-precedence.
1 parent 8f65939 commit 6c83d93
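
To make the user-facing contract concrete, here is a hedged usage sketch (not part of the commit; the URL, table, and driver class are hypothetical, and a Spark 1.x-era `sqlContext` is assumed to be in scope). After this patch, the class named in the `driver` option is guaranteed to be the driver that actually opens the connection, even if another registered driver on the classpath also accepts the URL:

// Hypothetical example; sqlContext, URL, table, and driver class are stand-ins.
val df = sqlContext.read
  .format("jdbc")
  .option("url", "jdbc:postgresql://db.example.com/mydb")
  .option("dbtable", "my_table")
  .option("driver", "org.postgresql.Driver") // now guaranteed to take precedence
  .load()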

7 files changed

Lines changed: 34 additions & 50 deletions


docs/sql-programming-guide.md

Lines changed: 1 addition & 3 deletions
@@ -1895,9 +1895,7 @@ the Data Sources API. The following options are supported:
   <tr>
     <td><code>driver</code></td>
     <td>
-      The class name of the JDBC driver needed to connect to this URL. This class will be loaded
-      on the master and workers before running an JDBC commands to allow the driver to
-      register itself with the JDBC subsystem.
+      The class name of the JDBC driver to use to connect to this URL.
     </td>
   </tr>

sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala

Lines changed: 1 addition & 1 deletion
@@ -275,7 +275,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
     }
     // connectionProperties should override settings in extraOptions
     props.putAll(connectionProperties)
-    val conn = JdbcUtils.createConnection(url, props)
+    val conn = JdbcUtils.createConnectionFactory(url, props)()
 
     try {
       var tableExists = JdbcUtils.tableExists(conn, url, table)

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala

Lines changed: 0 additions & 3 deletions
@@ -31,15 +31,12 @@ class DefaultSource extends RelationProvider with DataSourceRegister {
       sqlContext: SQLContext,
       parameters: Map[String, String]): BaseRelation = {
     val url = parameters.getOrElse("url", sys.error("Option 'url' not specified"))
-    val driver = parameters.getOrElse("driver", null)
     val table = parameters.getOrElse("dbtable", sys.error("Option 'dbtable' not specified"))
     val partitionColumn = parameters.getOrElse("partitionColumn", null)
     val lowerBound = parameters.getOrElse("lowerBound", null)
     val upperBound = parameters.getOrElse("upperBound", null)
     val numPartitions = parameters.getOrElse("numPartitions", null)
 
-    if (driver != null) DriverRegistry.register(driver)
-
     if (partitionColumn != null
       && (lowerBound == null || upperBound == null || numPartitions == null)) {
       sys.error("Partitioning incompletely specified")

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala

Lines changed: 0 additions & 5 deletions
@@ -51,10 +51,5 @@ object DriverRegistry extends Logging {
       }
     }
   }
-
-  def getDriverClassName(url: String): String = DriverManager.getDriver(url) match {
-    case wrapper: DriverWrapper => wrapper.wrapped.getClass.getCanonicalName
-    case driver => driver.getClass.getCanonicalName
-  }
 }
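
For context (not part of this diff): `DriverRegistry` registers drivers through a Spark-internal `DriverWrapper` because `DriverManager` only returns drivers whose classes are visible from the calling classloader, so a driver loaded from the user's classpath must be wrapped by a class that is always visible. That is also why lookups such as the removed `getDriverClassName` (and the new factory in `JdbcUtils` below) unwrap the wrapper before comparing class names. A hedged sketch of the wrapper idea (the real `DriverWrapper` differs in its details):

import java.sql.{Connection, Driver, DriverPropertyInfo}
import java.util.Properties
import java.util.logging.Logger

// Sketch only: delegate every java.sql.Driver method to the wrapped instance,
// so a user-classpath driver can be registered with DriverManager.
class WrappedDriver(val wrapped: Driver) extends Driver {
  override def acceptsURL(url: String): Boolean = wrapped.acceptsURL(url)
  override def connect(url: String, info: Properties): Connection = wrapped.connect(url, info)
  override def getPropertyInfo(url: String, info: Properties): Array[DriverPropertyInfo] =
    wrapped.getPropertyInfo(url, info)
  override def getMajorVersion(): Int = wrapped.getMajorVersion
  override def getMinorVersion(): Int = wrapped.getMinorVersion
  override def jdbcCompliant(): Boolean = wrapped.jdbcCompliant()
  override def getParentLogger(): Logger = wrapped.getParentLogger
}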

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala

Lines changed: 4 additions & 29 deletions
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.execution.datasources.jdbc
 
-import java.sql.{Connection, Date, DriverManager, ResultSet, ResultSetMetaData, SQLException, Timestamp}
+import java.sql.{Connection, Date, ResultSet, ResultSetMetaData, SQLException, Timestamp}
 import java.util.Properties
 
 import scala.util.control.NonFatal
@@ -41,7 +41,6 @@ private[sql] case class JDBCPartition(whereClause: String, idx: Int) extends Par
   override def index: Int = idx
 }
 
-
 private[sql] object JDBCRDD extends Logging {
 
   /**
@@ -120,7 +119,7 @@ private[sql] object JDBCRDD extends Logging {
    */
   def resolveTable(url: String, table: String, properties: Properties): StructType = {
     val dialect = JdbcDialects.get(url)
-    val conn: Connection = getConnector(properties.getProperty("driver"), url, properties)()
+    val conn: Connection = JdbcUtils.createConnectionFactory(url, properties)()
     try {
       val statement = conn.prepareStatement(s"SELECT * FROM $table WHERE 1=0")
       try {
@@ -228,36 +227,13 @@ private[sql] object JDBCRDD extends Logging {
     })
   }
 
-  /**
-   * Given a driver string and an url, return a function that loads the
-   * specified driver string then returns a connection to the JDBC url.
-   * getConnector is run on the driver code, while the function it returns
-   * is run on the executor.
-   *
-   * @param driver - The class name of the JDBC driver for the given url, or null if the class name
-   *                 is not necessary.
-   * @param url - The JDBC url to connect to.
-   *
-   * @return A function that loads the driver and connects to the url.
-   */
-  def getConnector(driver: String, url: String, properties: Properties): () => Connection = {
-    () => {
-      try {
-        if (driver != null) DriverRegistry.register(driver)
-      } catch {
-        case e: ClassNotFoundException =>
-          logWarning(s"Couldn't find class $driver", e)
-      }
-      DriverManager.getConnection(url, properties)
-    }
-  }
+
 
   /**
    * Build and return JDBCRDD from the given information.
    *
    * @param sc - Your SparkContext.
    * @param schema - The Catalyst schema of the underlying database table.
-   * @param driver - The class name of the JDBC driver for the given url.
    * @param url - The JDBC url to connect to.
    * @param fqTable - The fully-qualified table name (or paren'd SQL query) to use.
    * @param requiredColumns - The names of the columns to SELECT.
@@ -270,7 +246,6 @@ private[sql] object JDBCRDD extends Logging {
   def scanTable(
       sc: SparkContext,
       schema: StructType,
-      driver: String,
       url: String,
       properties: Properties,
       fqTable: String,
@@ -281,7 +256,7 @@ private[sql] object JDBCRDD extends Logging {
     val quotedColumns = requiredColumns.map(colName => dialect.quoteIdentifier(colName))
     new JDBCRDD(
       sc,
-      getConnector(driver, url, properties),
+      JdbcUtils.createConnectionFactory(url, properties),
       pruneSchema(schema, requiredColumns),
       fqTable,
       quotedColumns,
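
A hedged sketch (not from the commit) of the pattern `JDBCRDD` relies on here: the factory is built once on the Spark driver, where the JDBC driver class name is resolved, and the returned zero-argument function is serialized to executors and invoked there, once per partition. The RDD, URL, and helper name are hypothetical stand-ins:

import java.sql.Connection
import java.util.Properties

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils

// Sketch only: `rows` and `url` are hypothetical.
def runPerPartition(rows: RDD[Int], url: String, properties: Properties): Unit = {
  // Built on the Spark driver; the JDBC driver class name is resolved here.
  val getConnection: () => Connection = JdbcUtils.createConnectionFactory(url, properties)
  rows.foreachPartition { _ =>
    // Invoked on an executor; resolves the same driver class registered above.
    val conn = getConnection()
    try {
      // ... issue statements over conn ...
    } finally {
      conn.close()
    }
  }
}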

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala

Lines changed: 0 additions & 2 deletions
@@ -91,12 +91,10 @@ private[sql] case class JDBCRelation(
   override val schema: StructType = JDBCRDD.resolveTable(url, table, properties)
 
   override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
-    val driver: String = DriverRegistry.getDriverClassName(url)
     // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row]
     JDBCRDD.scanTable(
       sqlContext.sparkContext,
       schema,
-      driver,
       url,
       properties,
       table,

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala

Lines changed: 28 additions & 7 deletions
@@ -17,9 +17,10 @@
 
 package org.apache.spark.sql.execution.datasources.jdbc
 
-import java.sql.{Connection, PreparedStatement}
+import java.sql.{Connection, Driver, DriverManager, PreparedStatement}
 import java.util.Properties
 
+import scala.collection.JavaConverters._
 import scala.util.Try
 import scala.util.control.NonFatal
 
@@ -34,10 +35,31 @@ import org.apache.spark.sql.{DataFrame, Row}
 object JdbcUtils extends Logging {
 
   /**
-   * Establishes a JDBC connection.
+   * Returns a factory for creating connections to the given JDBC URL.
+   *
+   * @param url the JDBC url to connect to.
+   * @param properties JDBC connection properties.
    */
-  def createConnection(url: String, connectionProperties: Properties): Connection = {
-    JDBCRDD.getConnector(connectionProperties.getProperty("driver"), url, connectionProperties)()
+  def createConnectionFactory(url: String, properties: Properties): () => Connection = {
+    val userSpecifiedDriverClass = Option(properties.getProperty("driver"))
+    userSpecifiedDriverClass.foreach(DriverRegistry.register)
+    // Performing this part of the logic on the driver guards against the corner-case where the
+    // driver returned for a URL is different on the driver and executors due to classpath
+    // differences.
+    val driverClass: String = userSpecifiedDriverClass.getOrElse {
+      DriverManager.getDriver(url).getClass.getCanonicalName
+    }
+    () => {
+      userSpecifiedDriverClass.foreach(DriverRegistry.register)
+      val driver: Driver = DriverManager.getDrivers.asScala.collectFirst {
+        case d: DriverWrapper if d.wrapped.getClass.getCanonicalName == driverClass => d
+        case d if d.getClass.getCanonicalName == driverClass => d
+      }.getOrElse {
+        throw new IllegalStateException(
+          s"Did not find registered driver with class $driverClass")
+      }
+      driver.connect(url, properties)
+    }
   }
 
   /**
@@ -242,15 +264,14 @@ object JdbcUtils extends Logging {
       df: DataFrame,
       url: String,
       table: String,
-      properties: Properties = new Properties()) {
+      properties: Properties) {
     val dialect = JdbcDialects.get(url)
     val nullTypes: Array[Int] = df.schema.fields.map { field =>
       getJdbcType(field.dataType, dialect).jdbcNullType
     }
 
     val rddSchema = df.schema
-    val driver: String = DriverRegistry.getDriverClassName(url)
-    val getConnection: () => Connection = JDBCRDD.getConnector(driver, url, properties)
+    val getConnection: () => Connection = createConnectionFactory(url, properties)
     val batchSize = properties.getProperty("batchsize", "1000").toInt
     df.foreachPartition { iterator =>
       savePartition(getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect)
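
To see why the factory iterates over `DriverManager.getDrivers` instead of calling `DriverManager.getConnection(url, ...)`, recall that several registered drivers may accept the same URL and `getConnection` simply uses the first one that does. A minimal hedged sketch of the resolution step outside Spark (it omits the Spark-internal `DriverWrapper` unwrapping; `connectWith` is a hypothetical helper):

import java.sql.{Connection, Driver, DriverManager}
import java.util.Properties
import scala.collection.JavaConverters._

// Sketch only: open a connection through the one registered driver whose class
// name matches `driverClass`, regardless of which other drivers accept the URL.
def connectWith(driverClass: String, url: String, props: Properties): Connection = {
  val driver: Driver = DriverManager.getDrivers.asScala
    .find(_.getClass.getCanonicalName == driverClass)
    .getOrElse(throw new IllegalStateException(
      s"Did not find registered driver with class $driverClass"))
  driver.connect(url, props)
}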
