Skip to content

Commit 72beea6

Browse files
[SPARK-10186] [SQL] Add support for array types using JDBCRDD and postgres
This change allows reading JDBC array column types with the PostgreSQL dialect. It also lays the groundwork for supporting array types in other JDBC backends.
1 parent d45a0d3 commit 72beea6

3 files changed

Lines changed: 91 additions & 41 deletions

File tree

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala

Lines changed: 45 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import org.apache.spark.rdd.RDD
2626
import org.apache.spark.sql.catalyst.InternalRow
2727
import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow
2828
import org.apache.spark.sql.catalyst.util.DateTimeUtils
29+
import org.apache.spark.sql.catalyst.util.DateTimeUtils.SQLTimestamp
2930
import org.apache.spark.sql.jdbc.JdbcDialects
3031
import org.apache.spark.sql.sources._
3132
import org.apache.spark.sql.types._
@@ -324,12 +325,13 @@ private[sql] class JDBCRDD(
324325
case object StringConversion extends JDBCConversion
325326
case object TimestampConversion extends JDBCConversion
326327
case object BinaryConversion extends JDBCConversion
328+
case class ArrayConversion(elementConversion: JDBCConversion) extends JDBCConversion
327329

328330
/**
329-
* Maps a StructType to a type tag list.
331+
* Maps a StructField and its associated DataType to a type tag.
330332
*/
331-
def getConversions(schema: StructType): Array[JDBCConversion] = {
332-
schema.fields.map(sf => sf.dataType match {
333+
def getConversion(sf: StructField, dataType: DataType): JDBCConversion = {
334+
dataType match {
333335
case BooleanType => BooleanConversion
334336
case DateType => DateConversion
335337
case DecimalType.Fixed(p, s) => DecimalConversion(p, s)
@@ -341,8 +343,16 @@ private[sql] class JDBCRDD(
341343
case StringType => StringConversion
342344
case TimestampType => TimestampConversion
343345
case BinaryType => BinaryConversion
346+
case ArrayType(d, x) => ArrayConversion(getConversion(sf, d))
344347
case _ => throw new IllegalArgumentException(s"Unsupported field $sf")
345-
}).toArray
348+
}
349+
}
350+
351+
/**
352+
* Maps a StructType to a type tag list.
353+
*/
354+
def getConversions(schema: StructType): Array[JDBCConversion] = {
355+
schema.fields.map(sf => getConversion(sf, sf.dataType))
346356
}
347357

348358
/**
@@ -375,6 +385,10 @@ private[sql] class JDBCRDD(
375385
val conversions = getConversions(schema)
376386
val mutableRow = new SpecificMutableRow(schema.fields.map(x => x.dataType))
377387

388+
def convert_date(dateVal: java.sql.Date): Int = DateTimeUtils.fromJavaDate(dateVal)
389+
def convert_decimal(decimal: java.math.BigDecimal, p: Int, s: Int): Decimal = Decimal(decimal, p, s)
390+
def convert_timestamp(ts: java.sql.Timestamp): SQLTimestamp = DateTimeUtils.fromJavaTimestamp(ts)
391+
378392
def getNext(): InternalRow = {
379393
if (rs.next()) {
380394
var i = 0
@@ -386,7 +400,7 @@ private[sql] class JDBCRDD(
386400
// DateTimeUtils.fromJavaDate does not handle null value, so we need to check it.
387401
val dateVal = rs.getDate(pos)
388402
if (dateVal != null) {
389-
mutableRow.setInt(i, DateTimeUtils.fromJavaDate(dateVal))
403+
mutableRow.setInt(i, convert_date(dateVal))
390404
} else {
391405
mutableRow.update(i, null)
392406
}
@@ -403,7 +417,7 @@ private[sql] class JDBCRDD(
403417
if (decimalVal == null) {
404418
mutableRow.update(i, null)
405419
} else {
406-
mutableRow.update(i, Decimal(decimalVal, p, s))
420+
mutableRow.update(i, convert_decimal(decimalVal, p, s))
407421
}
408422
case DoubleConversion => mutableRow.setDouble(i, rs.getDouble(pos))
409423
case FloatConversion => mutableRow.setFloat(i, rs.getFloat(pos))
@@ -414,21 +428,42 @@ private[sql] class JDBCRDD(
414428
case TimestampConversion =>
415429
val t = rs.getTimestamp(pos)
416430
if (t != null) {
417-
mutableRow.setLong(i, DateTimeUtils.fromJavaTimestamp(t))
431+
mutableRow.setLong(i, convert_timestamp(t))
418432
} else {
419433
mutableRow.update(i, null)
420434
}
421435
case BinaryConversion => mutableRow.update(i, rs.getBytes(pos))
422-
case BinaryLongConversion => {
436+
case BinaryLongConversion =>
423437
val bytes = rs.getBytes(pos)
424438
var ans = 0L
425439
var j = 0
426440
while (j < bytes.size) {
427441
ans = 256 * ans + (255 & bytes(j))
428-
j = j + 1;
442+
j = j + 1
429443
}
430444
mutableRow.setLong(i, ans)
431-
}
445+
446+
case ArrayConversion(BinaryLongConversion) => throw new IllegalArgumentException(s"Unsupported array element conversion $i")
447+
case ArrayConversion(subConvert) =>
448+
val a = rs.getArray(pos)
449+
if (a != null) {
450+
val genericArrayData = a.getArray match {
451+
case x: Array[java.math.BigDecimal] =>
452+
subConvert match {
453+
case DecimalConversion(p, s) => new GenericArrayData(x.map(convert_decimal(_, p, s)))
454+
case _ => throw new IllegalArgumentException("Incompatible decimal conversions")
455+
}
456+
case x: Array[java.sql.Timestamp] => new GenericArrayData(x.map(convert_timestamp))
457+
case x: Array[java.lang.String] => new GenericArrayData(x.map(UTF8String.fromString))
458+
case x: Array[java.sql.Date] => new GenericArrayData(x.map(convert_date))
459+
case x: Array[Any] => new GenericArrayData(x)
460+
case _ => throw new IllegalArgumentException(s"Unsupported arraytype $a")
461+
}
462+
mutableRow.update(i, genericArrayData)
463+
} else {
464+
mutableRow.update(i, null)
465+
}
466+
432467
}
433468
if (rs.wasNull) mutableRow.setNullAt(i)
434469
i = i + 1

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala

Lines changed: 7 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -171,21 +171,9 @@ object JdbcUtils extends Logging {
171171
val name = field.name
172172
val typ: String =
173173
dialect.getJDBCType(field.dataType).map(_.databaseTypeDefinition).getOrElse(
174-
field.dataType match {
175-
case IntegerType => "INTEGER"
176-
case LongType => "BIGINT"
177-
case DoubleType => "DOUBLE PRECISION"
178-
case FloatType => "REAL"
179-
case ShortType => "INTEGER"
180-
case ByteType => "BYTE"
181-
case BooleanType => "BIT(1)"
182-
case StringType => "TEXT"
183-
case BinaryType => "BLOB"
184-
case TimestampType => "TIMESTAMP"
185-
case DateType => "DATE"
186-
case t: DecimalType => s"DECIMAL(${t.precision},${t.scale})"
187-
case _ => throw new IllegalArgumentException(s"Don't know how to save $field to JDBC")
188-
})
174+
dialect.getCommonJDBCType(field.dataType).map(_.databaseTypeDefinition).getOrElse(
175+
throw new IllegalArgumentException(s"Don't know how to save $field to JDBC")
176+
))
189177
val nullable = if (field.nullable) "" else "NOT NULL"
190178
sb.append(s", $name $typ $nullable")
191179
}}
@@ -203,23 +191,11 @@ object JdbcUtils extends Logging {
203191
val dialect = JdbcDialects.get(url)
204192
val nullTypes: Array[Int] = df.schema.fields.map { field =>
205193
dialect.getJDBCType(field.dataType).map(_.jdbcNullType).getOrElse(
206-
field.dataType match {
207-
case IntegerType => java.sql.Types.INTEGER
208-
case LongType => java.sql.Types.BIGINT
209-
case DoubleType => java.sql.Types.DOUBLE
210-
case FloatType => java.sql.Types.REAL
211-
case ShortType => java.sql.Types.INTEGER
212-
case ByteType => java.sql.Types.INTEGER
213-
case BooleanType => java.sql.Types.BIT
214-
case StringType => java.sql.Types.CLOB
215-
case BinaryType => java.sql.Types.BLOB
216-
case TimestampType => java.sql.Types.TIMESTAMP
217-
case DateType => java.sql.Types.DATE
218-
case t: DecimalType => java.sql.Types.DECIMAL
219-
case _ => throw new IllegalArgumentException(
194+
dialect.getCommonJDBCType(field.dataType).map(_.jdbcNullType).getOrElse(
195+
throw new IllegalArgumentException(
220196
s"Can't translate null value for field $field")
221-
})
222-
}
197+
))
198+
}
223199

224200
val rddSchema = df.schema
225201
val driver: String = DriverRegistry.getDriverClassName(url)

sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,24 @@ abstract class JdbcDialect {
8181
*/
8282
def getJDBCType(dt: DataType): Option[JdbcType] = None
8383

84+
def getCommonJDBCType(dataType: DataType): Option[JdbcType] = {
85+
dataType match {
86+
case IntegerType => Some(JdbcType("INTEGER", java.sql.Types.INTEGER))
87+
case LongType => Some(JdbcType("BIGINT", java.sql.Types.BIGINT))
88+
case DoubleType => Some(JdbcType("DOUBLE PRECISION", java.sql.Types.DOUBLE))
89+
case FloatType => Some(JdbcType("REAL", java.sql.Types.FLOAT))
90+
case ShortType => Some(JdbcType("INTEGER", java.sql.Types.SMALLINT))
91+
case ByteType => Some(JdbcType("BYTE", java.sql.Types.TINYINT))
92+
case BooleanType => Some(JdbcType("BIT(1)", java.sql.Types.BIT))
93+
case StringType => Some(JdbcType("TEXT", java.sql.Types.CLOB))
94+
case BinaryType => Some(JdbcType("BLOB", java.sql.Types.BLOB))
95+
case TimestampType => Some(JdbcType("TIMESTAMP", java.sql.Types.TIMESTAMP))
96+
case DateType => Some(JdbcType("DATE", java.sql.Types.DATE))
97+
case DecimalType(p, s) => Some(JdbcType(s"DECIMAL($p,$s)", java.sql.Types.DECIMAL))
98+
case _ => None
99+
}
100+
}
101+
84102
/**
85103
* Quotes the identifier. This is used to put quotes around the identifier in case the column
86104
* name is a reserved keyword, or in case it contains characters that require quotes (e.g. space).
@@ -207,13 +225,34 @@ case object PostgresDialect extends JdbcDialect {
207225
Some(StringType)
208226
} else if (sqlType == Types.OTHER && typeName.equals("jsonb")) {
209227
Some(StringType)
228+
} else if (sqlType == Types.ARRAY) {
229+
typeName match {
230+
case "_bit" => Some(ArrayType(BinaryType))
231+
case "_int1" => Some(ArrayType(ByteType))
232+
case "_int2" => Some(ArrayType(ShortType))
233+
case "_int4" => Some(ArrayType(IntegerType))
234+
case "_int8" => Some(ArrayType(LongType))
235+
case "_float4" => Some(ArrayType(FloatType))
236+
case "_float8" => Some(ArrayType(DoubleType))
237+
case "_text" | "_char" | "_varchar" => Some(ArrayType(StringType))
238+
case "_timestamp" | "timestamptz" => Some(ArrayType(TimestampType))
239+
case "_date" => Some(ArrayType(DateType))
240+
case _ => throw new IllegalArgumentException(s"Unhandled postgres array type $typeName")
241+
}
210242
} else None
211243
}
212244

213245
override def getJDBCType(dt: DataType): Option[JdbcType] = dt match {
214246
case StringType => Some(JdbcType("TEXT", java.sql.Types.CHAR))
215247
case BinaryType => Some(JdbcType("BYTEA", java.sql.Types.BINARY))
216248
case BooleanType => Some(JdbcType("BOOLEAN", java.sql.Types.BOOLEAN))
249+
case ArrayType(t) =>
250+
val subtype = getJDBCType(t).map(_.databaseTypeDefinition).getOrElse(
251+
getCommonJDBCType(t).map(_.databaseTypeDefinition).getOrElse(
252+
throw new IllegalArgumentException(s"Unexpected JDBC array subtype $t")
253+
)
254+
)
255+
Some(JdbcType(s"$subtype[]", java.sql.Types.ARRAY))
217256
case _ => None
218257
}
219258

0 commit comments

Comments (0)