Skip to content

Commit 3554d68

Browse files
committed
Force user-specified JDBC driver to take precedence.
1 parent be86268 commit 3554d68

7 files changed

Lines changed: 34 additions & 50 deletions

File tree

docs/sql-programming-guide.md

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1895,9 +1895,7 @@ the Data Sources API. The following options are supported:
18951895
<tr>
18961896
<td><code>driver</code></td>
18971897
<td>
1898-
The class name of the JDBC driver needed to connect to this URL. This class will be loaded
1899-
on the master and workers before running an JDBC commands to allow the driver to
1900-
register itself with the JDBC subsystem.
1898+
The class name of the JDBC driver to use to connect to this URL.
19011899
</td>
19021900
</tr>
19031901

sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
275275
}
276276
// connectionProperties should override settings in extraOptions
277277
props.putAll(connectionProperties)
278-
val conn = JdbcUtils.createConnection(url, props)
278+
val conn = JdbcUtils.createConnectionFactory(url, props)()
279279

280280
try {
281281
var tableExists = JdbcUtils.tableExists(conn, url, table)

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,12 @@ class DefaultSource extends RelationProvider with DataSourceRegister {
3131
sqlContext: SQLContext,
3232
parameters: Map[String, String]): BaseRelation = {
3333
val url = parameters.getOrElse("url", sys.error("Option 'url' not specified"))
34-
val driver = parameters.getOrElse("driver", null)
3534
val table = parameters.getOrElse("dbtable", sys.error("Option 'dbtable' not specified"))
3635
val partitionColumn = parameters.getOrElse("partitionColumn", null)
3736
val lowerBound = parameters.getOrElse("lowerBound", null)
3837
val upperBound = parameters.getOrElse("upperBound", null)
3938
val numPartitions = parameters.getOrElse("numPartitions", null)
4039

41-
if (driver != null) DriverRegistry.register(driver)
42-
4340
if (partitionColumn != null
4441
&& (lowerBound == null || upperBound == null || numPartitions == null)) {
4542
sys.error("Partitioning incompletely specified")

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,5 @@ object DriverRegistry extends Logging {
5151
}
5252
}
5353
}
54-
55-
def getDriverClassName(url: String): String = DriverManager.getDriver(url) match {
56-
case wrapper: DriverWrapper => wrapper.wrapped.getClass.getCanonicalName
57-
case driver => driver.getClass.getCanonicalName
58-
}
5954
}
6055

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala

Lines changed: 4 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
package org.apache.spark.sql.execution.datasources.jdbc
1919

20-
import java.sql.{Connection, Date, DriverManager, ResultSet, ResultSetMetaData, SQLException, Timestamp}
20+
import java.sql.{Connection, Date, ResultSet, ResultSetMetaData, SQLException, Timestamp}
2121
import java.util.Properties
2222

2323
import scala.util.control.NonFatal
@@ -41,7 +41,6 @@ private[sql] case class JDBCPartition(whereClause: String, idx: Int) extends Par
4141
override def index: Int = idx
4242
}
4343

44-
4544
private[sql] object JDBCRDD extends Logging {
4645

4746
/**
@@ -120,7 +119,7 @@ private[sql] object JDBCRDD extends Logging {
120119
*/
121120
def resolveTable(url: String, table: String, properties: Properties): StructType = {
122121
val dialect = JdbcDialects.get(url)
123-
val conn: Connection = getConnector(properties.getProperty("driver"), url, properties)()
122+
val conn: Connection = JdbcUtils.createConnectionFactory(url, properties)()
124123
try {
125124
val statement = conn.prepareStatement(s"SELECT * FROM $table WHERE 1=0")
126125
try {
@@ -201,36 +200,13 @@ private[sql] object JDBCRDD extends Logging {
201200
case _ => null
202201
}
203202

204-
/**
205-
* Given a driver string and an url, return a function that loads the
206-
* specified driver string then returns a connection to the JDBC url.
207-
* getConnector is run on the driver code, while the function it returns
208-
* is run on the executor.
209-
*
210-
* @param driver - The class name of the JDBC driver for the given url, or null if the class name
211-
* is not necessary.
212-
* @param url - The JDBC url to connect to.
213-
*
214-
* @return A function that loads the driver and connects to the url.
215-
*/
216-
def getConnector(driver: String, url: String, properties: Properties): () => Connection = {
217-
() => {
218-
try {
219-
if (driver != null) DriverRegistry.register(driver)
220-
} catch {
221-
case e: ClassNotFoundException =>
222-
logWarning(s"Couldn't find class $driver", e)
223-
}
224-
DriverManager.getConnection(url, properties)
225-
}
226-
}
203+
227204

228205
/**
229206
* Build and return JDBCRDD from the given information.
230207
*
231208
* @param sc - Your SparkContext.
232209
* @param schema - The Catalyst schema of the underlying database table.
233-
* @param driver - The class name of the JDBC driver for the given url.
234210
* @param url - The JDBC url to connect to.
235211
* @param fqTable - The fully-qualified table name (or paren'd SQL query) to use.
236212
* @param requiredColumns - The names of the columns to SELECT.
@@ -243,7 +219,6 @@ private[sql] object JDBCRDD extends Logging {
243219
def scanTable(
244220
sc: SparkContext,
245221
schema: StructType,
246-
driver: String,
247222
url: String,
248223
properties: Properties,
249224
fqTable: String,
@@ -254,7 +229,7 @@ private[sql] object JDBCRDD extends Logging {
254229
val quotedColumns = requiredColumns.map(colName => dialect.quoteIdentifier(colName))
255230
new JDBCRDD(
256231
sc,
257-
getConnector(driver, url, properties),
232+
JdbcUtils.createConnectionFactory(url, properties),
258233
pruneSchema(schema, requiredColumns),
259234
fqTable,
260235
quotedColumns,

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,12 +91,10 @@ private[sql] case class JDBCRelation(
9191
override val schema: StructType = JDBCRDD.resolveTable(url, table, properties)
9292

9393
override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
94-
val driver: String = DriverRegistry.getDriverClassName(url)
9594
// Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row]
9695
JDBCRDD.scanTable(
9796
sqlContext.sparkContext,
9897
schema,
99-
driver,
10098
url,
10199
properties,
102100
table,

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,10 @@
1717

1818
package org.apache.spark.sql.execution.datasources.jdbc
1919

20-
import java.sql.{Connection, PreparedStatement}
20+
import java.sql.{Connection, Driver, DriverManager, PreparedStatement}
2121
import java.util.Properties
2222

23+
import scala.collection.JavaConverters._
2324
import scala.util.Try
2425
import scala.util.control.NonFatal
2526

@@ -34,10 +35,31 @@ import org.apache.spark.sql.{DataFrame, Row}
3435
object JdbcUtils extends Logging {
3536

3637
/**
37-
* Establishes a JDBC connection.
38+
* Returns a factory for creating connections to the given JDBC URL.
39+
*
40+
* @param url the JDBC url to connect to.
41+
* @param properties JDBC connection properties.
3842
*/
39-
def createConnection(url: String, connectionProperties: Properties): Connection = {
40-
JDBCRDD.getConnector(connectionProperties.getProperty("driver"), url, connectionProperties)()
43+
def createConnectionFactory(url: String, properties: Properties): () => Connection = {
44+
val userSpecifiedDriverClass = Option(properties.getProperty("driver"))
45+
userSpecifiedDriverClass.foreach(DriverRegistry.register)
46+
// Resolving the driver class here, on the Spark driver, guards against the corner-case where
47+
// the JDBC driver returned for a URL is different on the Spark driver and the executors due
48+
// to classpath differences.
49+
val driverClass: String = userSpecifiedDriverClass.getOrElse {
50+
DriverManager.getDriver(url).getClass.getCanonicalName
51+
}
52+
() => {
53+
userSpecifiedDriverClass.foreach(DriverRegistry.register)
54+
val driver: Driver = DriverManager.getDrivers.asScala.collectFirst {
55+
case d: DriverWrapper if d.wrapped.getClass.getCanonicalName == driverClass => d
56+
case d if d.getClass.getCanonicalName == driverClass => d
57+
}.getOrElse {
58+
throw new IllegalStateException(
59+
s"Did not find registered driver with class $driverClass")
60+
}
61+
driver.connect(url, properties)
62+
}
4163
}
4264

4365
/**
@@ -242,15 +264,14 @@ object JdbcUtils extends Logging {
242264
df: DataFrame,
243265
url: String,
244266
table: String,
245-
properties: Properties = new Properties()) {
267+
properties: Properties) {
246268
val dialect = JdbcDialects.get(url)
247269
val nullTypes: Array[Int] = df.schema.fields.map { field =>
248270
getJdbcType(field.dataType, dialect).jdbcNullType
249271
}
250272

251273
val rddSchema = df.schema
252-
val driver: String = DriverRegistry.getDriverClassName(url)
253-
val getConnection: () => Connection = JDBCRDD.getConnector(driver, url, properties)
274+
val getConnection: () => Connection = createConnectionFactory(url, properties)
254275
val batchSize = properties.getProperty("batchsize", "1000").toInt
255276
df.foreachPartition { iterator =>
256277
savePartition(getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect)

0 commit comments

Comments
 (0)