Changes from 4 commits
JDBCRDD.scala
@@ -26,6 +26,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.catalyst.util.DateTimeUtils.SQLTimestamp
import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
@@ -136,7 +137,7 @@ private[sql] object JDBCRDD extends Logging {
val nullable = rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls
val metadata = new MetadataBuilder().putString("name", columnName)
val columnType =
dialect.getCatalystType(dataType, typeName, fieldSize, metadata).getOrElse(
dialect.getCatalystType(dataType, typeName, fieldSize, fieldScale, metadata).getOrElse(
getCatalystType(dataType, fieldSize, fieldScale, isSigned))
fields(i) = StructField(columnName, columnType, nullable, metadata.build())
i = i + 1
@@ -324,12 +325,13 @@ private[sql] class JDBCRDD(
case object StringConversion extends JDBCConversion
case object TimestampConversion extends JDBCConversion
case object BinaryConversion extends JDBCConversion
case class ArrayConversion(elementConversion: JDBCConversion) extends JDBCConversion

/**
* Maps a StructType to a type tag list.
* Maps a StructField and its associated DataType to a type tag.
*/
def getConversions(schema: StructType): Array[JDBCConversion] = {
schema.fields.map(sf => sf.dataType match {
def getConversion(sf: StructField, dataType: DataType): JDBCConversion = {
dataType match {
case BooleanType => BooleanConversion
case DateType => DateConversion
case DecimalType.Fixed(p, s) => DecimalConversion(p, s)
@@ -341,8 +343,16 @@
case StringType => StringConversion
case TimestampType => TimestampConversion
case BinaryType => BinaryConversion
case ArrayType(d, _) => ArrayConversion(getConversion(sf, d))
case _ => throw new IllegalArgumentException(s"Unsupported field $sf")
}).toArray
}
}

/**
* Maps a StructType to a type tag list.
*/
def getConversions(schema: StructType): Array[JDBCConversion] = {
schema.fields.map(sf => getConversion(sf, sf.dataType))
}
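
For illustration, the recursion above bottoms out at the scalar tags, so an array field maps to an ArrayConversion wrapping its element's tag. A quick sketch of the expected mapping (not code from this patch):

// Hypothetical check of the mapping above:
val sf = StructField("tags", ArrayType(StringType), nullable = true)
getConversion(sf, sf.dataType) // == ArrayConversion(StringConversion)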

/**
@@ -375,6 +385,10 @@ private[sql] class JDBCRDD(
val conversions = getConversions(schema)
val mutableRow = new SpecificMutableRow(schema.fields.map(x => x.dataType))

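// Conversion helpers shared by the scalar and array read paths below; callers must
// check for SQL NULL first (see the date/timestamp branches).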
def convert_date(dateVal: java.sql.Date): Int = DateTimeUtils.fromJavaDate(dateVal)
def convert_decimal(decimal: java.math.BigDecimal, p: Int, s: Int): Decimal = Decimal(decimal, p, s)
def convert_timestamp(ts: java.sql.Timestamp): SQLTimestamp = DateTimeUtils.fromJavaTimestamp(ts)

def getNext(): InternalRow = {
if (rs.next()) {
var i = 0
@@ -386,7 +400,7 @@
// DateTimeUtils.fromJavaDate does not handle null value, so we need to check it.
val dateVal = rs.getDate(pos)
if (dateVal != null) {
mutableRow.setInt(i, DateTimeUtils.fromJavaDate(dateVal))
mutableRow.setInt(i, convert_date(dateVal))
} else {
mutableRow.update(i, null)
}
@@ -403,7 +417,7 @@
if (decimalVal == null) {
mutableRow.update(i, null)
} else {
mutableRow.update(i, Decimal(decimalVal, p, s))
mutableRow.update(i, convert_decimal(decimalVal, p, s))
}
case DoubleConversion => mutableRow.setDouble(i, rs.getDouble(pos))
case FloatConversion => mutableRow.setFloat(i, rs.getFloat(pos))
@@ -414,21 +428,39 @@
case TimestampConversion =>
val t = rs.getTimestamp(pos)
if (t != null) {
mutableRow.setLong(i, DateTimeUtils.fromJavaTimestamp(t))
mutableRow.setLong(i, convert_timestamp(t))
} else {
mutableRow.update(i, null)
}
case BinaryConversion => mutableRow.update(i, rs.getBytes(pos))
case BinaryLongConversion => {
case BinaryLongConversion =>
val bytes = rs.getBytes(pos)
var ans = 0L
var j = 0
while (j < bytes.size) {
ans = 256 * ans + (255 & bytes(j))
j = j + 1;
j = j + 1
}
mutableRow.setLong(i, ans)
}

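// BinaryLongConversion decodes a byte array into a Long (see the branch above);
// java.sql.Array offers no equivalent element decoding, so arrays of it are rejected.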
case ArrayConversion(BinaryLongConversion) => throw new IllegalArgumentException(s"Unsupported array element conversion $i")
case ArrayConversion(subConvert) =>
val a = rs.getArray(pos)
if (a != null) {
val x = a.getArray
val genericArrayData = new GenericArrayData(subConvert match {
case TimestampConversion => x.asInstanceOf[Array[java.sql.Timestamp]].map(convert_timestamp)
case StringConversion => x.asInstanceOf[Array[java.lang.String]].map(UTF8String.fromString)
case DateConversion => x.asInstanceOf[Array[java.sql.Date]].map(convert_date)
case DecimalConversion(p, s) => x.asInstanceOf[Array[java.math.BigDecimal]].map(convert_decimal(_, p, s))
case ArrayConversion(_) => throw new IllegalArgumentException("Nested arrays unsupported")
case _ => x.asInstanceOf[Array[Any]]
})
mutableRow.update(i, genericArrayData)
} else {
mutableRow.update(i, null)
}

}
if (rs.wasNull) mutableRow.setNullAt(i)
i = i + 1
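End-to-end, the read path above lets a database array column surface as a Catalyst ArrayType. A minimal usage sketch against Postgres (connection details and table are illustrative, not part of this diff):

// Assumes a Postgres table: CREATE TABLE docs (id int4, tags text[])
val props = new java.util.Properties
val df = sqlContext.read.jdbc("jdbc:postgresql://localhost/test", "docs", props)
df.printSchema()
// root
//  |-- id: integer (nullable = true)
//  |-- tags: array (nullable = true)
//  |    |-- element: string (containsNull = true)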
JdbcUtils.scala
@@ -23,7 +23,7 @@ import java.util.Properties
import scala.util.Try

import org.apache.spark.Logging
import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row}

@@ -92,7 +92,8 @@ object JdbcUtils extends Logging {
iterator: Iterator[Row],
rddSchema: StructType,
nullTypes: Array[Int],
batchSize: Int): Iterator[Byte] = {
batchSize: Int,
dialect: JdbcDialect): Iterator[Byte] = {
val conn = getConnection()
var committed = false
try {
@@ -121,6 +122,21 @@
case TimestampType => stmt.setTimestamp(i + 1, row.getAs[java.sql.Timestamp](i))
case DateType => stmt.setDate(i + 1, row.getAs[java.sql.Date](i))
case t: DecimalType => stmt.setBigDecimal(i + 1, row.getDecimal(i))
Member: ISTM we need to check whether the input types are valid for the target database in advance, e.g., in JdbcUtils#saveTable. JdbcUtils#savePartition should simply write the input data as the given typed data.

Member Author: If the particular dialect does not support these types, saveTable should throw an exception when building the nullTypes array.

case ArrayType(elemType, _) =>
val elemDataBaseType = dialect.getJDBCType(elemType).map(_.databaseTypeDefinition).getOrElse(
dialect.getCommonJDBCType(elemType).map(_.databaseTypeDefinition).getOrElse(
throw new IllegalArgumentException(
s"Can't determine array element type for $elemType in field $i")
))
val array: Array[AnyRef] = elemType match {
case _: ArrayType =>
throw new IllegalArgumentException(s"Nested array writes to JDBC are not supported for field $i")
case BinaryType => row.getSeq[Array[Byte]](i).toArray
case TimestampType => row.getSeq[java.sql.Timestamp](i).toArray
case DateType => row.getSeq[java.sql.Date](i).toArray
case _ => row.getSeq[AnyRef](i).toArray
}
stmt.setArray(i + 1, conn.createArrayOf(elemDataBaseType, array))
case _ => throw new IllegalArgumentException(
s"Can't translate non-null value for field $i")
}
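
Tying the review thread above to the code: for a dialect with no array mapping (its getJDBCType returns None and getCommonJDBCType has no ArrayType case), saveTable already fails while building the nullTypes array, before any partition is written. A sketch of the expected behavior (names are illustrative):

// df with schema StructType(Seq(StructField("tags", ArrayType(StringType))))
// df.write.jdbc("jdbc:somedb://...", "docs", props)
// => IllegalArgumentException: Can't translate null value for field StructField(tags,...)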
@@ -171,21 +187,9 @@ object JdbcUtils extends Logging {
val name = field.name
val typ: String =
dialect.getJDBCType(field.dataType).map(_.databaseTypeDefinition).getOrElse(
field.dataType match {
case IntegerType => "INTEGER"
case LongType => "BIGINT"
case DoubleType => "DOUBLE PRECISION"
case FloatType => "REAL"
case ShortType => "INTEGER"
case ByteType => "BYTE"
case BooleanType => "BIT(1)"
case StringType => "TEXT"
case BinaryType => "BLOB"
case TimestampType => "TIMESTAMP"
case DateType => "DATE"
case t: DecimalType => s"DECIMAL(${t.precision},${t.scale})"
case _ => throw new IllegalArgumentException(s"Don't know how to save $field to JDBC")
})
dialect.getCommonJDBCType(field.dataType).map(_.databaseTypeDefinition).getOrElse(
throw new IllegalArgumentException(s"Don't know how to save $field to JDBC")
))
val nullable = if (field.nullable) "" else "NOT NULL"
sb.append(s", $name $typ $nullable")
}}
@@ -203,30 +207,18 @@
val dialect = JdbcDialects.get(url)
val nullTypes: Array[Int] = df.schema.fields.map { field =>
dialect.getJDBCType(field.dataType).map(_.jdbcNullType).getOrElse(
field.dataType match {
case IntegerType => java.sql.Types.INTEGER
case LongType => java.sql.Types.BIGINT
case DoubleType => java.sql.Types.DOUBLE
case FloatType => java.sql.Types.REAL
case ShortType => java.sql.Types.INTEGER
case ByteType => java.sql.Types.INTEGER
case BooleanType => java.sql.Types.BIT
case StringType => java.sql.Types.CLOB
case BinaryType => java.sql.Types.BLOB
case TimestampType => java.sql.Types.TIMESTAMP
case DateType => java.sql.Types.DATE
case t: DecimalType => java.sql.Types.DECIMAL
case _ => throw new IllegalArgumentException(
dialect.getCommonJDBCType(field.dataType).map(_.jdbcNullType).getOrElse(
throw new IllegalArgumentException(
s"Can't translate null value for field $field")
})
}
))
}

val rddSchema = df.schema
val driver: String = DriverRegistry.getDriverClassName(url)
val getConnection: () => Connection = JDBCRDD.getConnector(driver, url, properties)
val batchSize = properties.getProperty("batchsize", "1000").toInt
df.foreachPartition { iterator =>
savePartition(getConnection, table, iterator, rddSchema, nullTypes, batchSize)
savePartition(getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect)
}
}

JdbcDialects.scala
@@ -72,7 +72,7 @@ abstract class JdbcDialect {
* or null if the default type mapping should be used.
*/
def getCatalystType(
sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = None
sqlType: Int, typeName: String, size: Int, scale: Int, md: MetadataBuilder): Option[DataType] = None
Member: Add a scaladoc entry for scale.

/**
* Retrieve the jdbc / sql type for a given datatype.
@@ -81,6 +81,24 @@
*/
def getJDBCType(dt: DataType): Option[JdbcType] = None

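/**
* Retrieve the default jdbc / sql type for a given datatype, shared across dialects.
* Used as a fallback when a dialect's getJDBCType returns None.
*/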
def getCommonJDBCType(dataType: DataType): Option[JdbcType] = {
dataType match {
case IntegerType => Some(JdbcType("INTEGER", java.sql.Types.INTEGER))
case LongType => Some(JdbcType("BIGINT", java.sql.Types.BIGINT))
case DoubleType => Some(JdbcType("DOUBLE PRECISION", java.sql.Types.DOUBLE))
case FloatType => Some(JdbcType("REAL", java.sql.Types.FLOAT))
case ShortType => Some(JdbcType("INTEGER", java.sql.Types.SMALLINT))
case ByteType => Some(JdbcType("BYTE", java.sql.Types.TINYINT))
case BooleanType => Some(JdbcType("BIT(1)", java.sql.Types.BIT))
case StringType => Some(JdbcType("TEXT", java.sql.Types.CLOB))
case BinaryType => Some(JdbcType("BLOB", java.sql.Types.BLOB))
case TimestampType => Some(JdbcType("TIMESTAMP", java.sql.Types.TIMESTAMP))
case DateType => Some(JdbcType("DATE", java.sql.Types.DATE))
case t: DecimalType => Some(JdbcType(s"DECIMAL(${t.precision},${t.scale})", java.sql.Types.DECIMAL))
case _ => None
}
}

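With this in place, a dialect can override only the types it cares about and let the getOrElse fallbacks in JdbcUtils reach getCommonJDBCType for everything else. A minimal sketch of a hypothetical dialect:

case object MyDialect extends JdbcDialect {
  override def canHandle(url: String): Boolean = url.startsWith("jdbc:mydb")
  // Special-case StringType only; all other types fall through to getCommonJDBCType.
  override def getJDBCType(dt: DataType): Option[JdbcType] = dt match {
    case StringType => Some(JdbcType("VARCHAR(4000)", java.sql.Types.VARCHAR))
    case _ => None
  }
}
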
/**
* Quotes the identifier. This is used to put quotes around the identifier in case the column
* name is a reserved keyword, or in case it contains characters that require quotes (e.g. space).
@@ -170,8 +188,8 @@ class AggregatedDialect(dialects: List[JdbcDialect]) extends JdbcDialect {
dialects.map(_.canHandle(url)).reduce(_ && _)

override def getCatalystType(
sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
dialects.flatMap(_.getCatalystType(sqlType, typeName, size, md)).headOption
sqlType: Int, typeName: String, size: Int, scale: Int, md: MetadataBuilder): Option[DataType] = {
dialects.flatMap(_.getCatalystType(sqlType, typeName, size, scale, md)).headOption
}

override def getJDBCType(dt: DataType): Option[JdbcType] = {
@@ -196,7 +214,7 @@ case object NoopDialect extends JdbcDialect {
case object PostgresDialect extends JdbcDialect {
override def canHandle(url: String): Boolean = url.startsWith("jdbc:postgresql")
override def getCatalystType(
sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
sqlType: Int, typeName: String, size: Int, scale: Int, md: MetadataBuilder): Option[DataType] = {
if (sqlType == Types.BIT && typeName.equals("bit") && size != 1) {
Some(BinaryType)
} else if (sqlType == Types.OTHER && typeName.equals("cidr")) {
@@ -207,13 +225,39 @@ case object PostgresDialect extends JdbcDialect {
Some(StringType)
} else if (sqlType == Types.OTHER && typeName.equals("jsonb")) {
Some(StringType)
} else if (sqlType == Types.OTHER && typeName.equals("uuid")) {
Some(StringType)
} else if (sqlType == Types.ARRAY) {
typeName match {
Member: Do we need the underscores at the head of typeName? I quickly checked the actual strings returned by ResultSetMetaData#getColumnTypeName in postgresql-jdbc, and found that they have no underscore.

Member Author: The underscores are specifically for the array types. Postgres prepends them to all array type names here: https://github.com/pgjdbc/pgjdbc/blob/REL9_4_1204/org/postgresql/jdbc2/TypeInfoCache.java#L159

Member: Understood.

case "_bit" | "_bool" => Some(ArrayType(BooleanType))
case "_int2" => Some(ArrayType(ShortType))
case "_int4" => Some(ArrayType(IntegerType))
case "_int8" | "_oid" => Some(ArrayType(LongType))
case "_float4" => Some(ArrayType(FloatType))
case "_money" | "_float8" => Some(ArrayType(DoubleType))
case "_text" | "_varchar" | "_char" | "_bpchar" | "_name" => Some(ArrayType(StringType))
case "_bytea" => Some(ArrayType(BinaryType))
case "_timestamp" | "_timestamptz" | "_time" | "_timetz" => Some(ArrayType(TimestampType))
case "_date" => Some(ArrayType(DateType))
case "_numeric"
if size != 0 || scale != 0 => Some(ArrayType(DecimalType(size, scale)))
case "_numeric" => Some(ArrayType(DecimalType.SYSTEM_DEFAULT))
case _ => throw new IllegalArgumentException(s"Unhandled postgres array type $typeName")
}
} else None
}

override def getJDBCType(dt: DataType): Option[JdbcType] = dt match {
case StringType => Some(JdbcType("TEXT", java.sql.Types.CHAR))
case BinaryType => Some(JdbcType("BYTEA", java.sql.Types.BINARY))
case BooleanType => Some(JdbcType("BOOLEAN", java.sql.Types.BOOLEAN))
case ArrayType(t, _) =>
val subtype = getJDBCType(t).map(_.databaseTypeDefinition).getOrElse(
getCommonJDBCType(t).map(_.databaseTypeDefinition).getOrElse(
throw new IllegalArgumentException(s"Unexpected JDBC array subtype $t")
)
)
Some(JdbcType(s"$subtype[]", java.sql.Types.ARRAY))
case _ => None
}

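Following up on the underscore thread above, this is roughly what postgresql-jdbc reports for an array column (a sketch; the exact strings come from pgjdbc's TypeInfoCache linked in the discussion):

// For a column declared as tags text[]:
//   rsmd.getColumnType(i)     == java.sql.Types.ARRAY
//   rsmd.getColumnTypeName(i) == "_text"  // element type name, underscore-prefixed
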
@@ -231,7 +275,7 @@ case object PostgresDialect extends JdbcDialect {
case object MySQLDialect extends JdbcDialect {
override def canHandle(url : String): Boolean = url.startsWith("jdbc:mysql")
override def getCatalystType(
sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
sqlType: Int, typeName: String, size: Int, scale: Int, md: MetadataBuilder): Option[DataType] = {
if (sqlType == Types.VARBINARY && typeName.equals("BIT") && size != 1) {
// This could instead be a BinaryType if we'd rather return bit-vectors of up to 64 bits as
// byte arrays instead of longs.
@@ -276,7 +320,7 @@ case object DB2Dialect extends JdbcDialect {
case object MsSqlServerDialect extends JdbcDialect {
override def canHandle(url: String): Boolean = url.startsWith("jdbc:sqlserver")
override def getCatalystType(
sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
sqlType: Int, typeName: String, size: Int, scale: Int, md: MetadataBuilder): Option[DataType] = {
if (typeName.contains("datetimeoffset")) {
// String is recommended by Microsoft SQL Server for datetimeoffset types in non-MS clients
Some(StringType)
@@ -298,7 +342,7 @@ case object MsSqlServerDialect extends JdbcDialect {
case object DerbyDialect extends JdbcDialect {
override def canHandle(url: String): Boolean = url.startsWith("jdbc:derby")
override def getCatalystType(
sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
sqlType: Int, typeName: String, size: Int, scale: Int, md: MetadataBuilder): Option[DataType] = {
if (sqlType == Types.REAL) Option(FloatType) else None
}

JDBCSuite.scala
@@ -437,7 +437,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext
val agg = new AggregatedDialect(List(new JdbcDialect {
override def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2:")
override def getCatalystType(
sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] =
sqlType: Int, typeName: String, size: Int, scale: Int, md: MetadataBuilder): Option[DataType] =
if (sqlType % 2 == 0) {
Some(LongType)
} else {
@@ -446,8 +446,8 @@
}, testH2Dialect))
assert(agg.canHandle("jdbc:h2:xxx"))
assert(!agg.canHandle("jdbc:h2"))
assert(agg.getCatalystType(0, "", 1, null) === Some(LongType))
assert(agg.getCatalystType(1, "", 1, null) === Some(StringType))
assert(agg.getCatalystType(0, "", 1, 0, null) === Some(LongType))
assert(agg.getCatalystType(1, "", 1, 0, null) === Some(StringType))
}

test("DB2Dialect type mapping") {
@@ -458,8 +458,8 @@

test("PostgresDialect type mapping") {
val Postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db")
assert(Postgres.getCatalystType(java.sql.Types.OTHER, "json", 1, null) === Some(StringType))
assert(Postgres.getCatalystType(java.sql.Types.OTHER, "jsonb", 1, null) === Some(StringType))
assert(Postgres.getCatalystType(java.sql.Types.OTHER, "json", 1, 0, null) === Some(StringType))
assert(Postgres.getCatalystType(java.sql.Types.OTHER, "jsonb", 1, 0, null) === Some(StringType))
}

test("DerbyDialect jdbc type mapping") {
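A natural follow-up test for the new scale argument and the array mapping, in the style of the PostgresDialect test above (a sketch, not part of this diff):

assert(Postgres.getCatalystType(java.sql.Types.ARRAY, "_int4", 0, 0, null) ===
  Some(ArrayType(IntegerType)))
assert(Postgres.getCatalystType(java.sql.Types.ARRAY, "_numeric", 10, 2, null) ===
  Some(ArrayType(DecimalType(10, 2))))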