@@ -26,6 +26,7 @@ import org.apache.spark.rdd.RDD
2626import org .apache .spark .sql .catalyst .InternalRow
2727import org .apache .spark .sql .catalyst .expressions .SpecificMutableRow
2828import org .apache .spark .sql .catalyst .util .DateTimeUtils
29+ import org .apache .spark .sql .catalyst .util .DateTimeUtils .SQLTimestamp
2930import org .apache .spark .sql .jdbc .JdbcDialects
3031import org .apache .spark .sql .sources ._
3132import org .apache .spark .sql .types ._
@@ -324,12 +325,13 @@ private[sql] class JDBCRDD(
324325 case object StringConversion extends JDBCConversion
325326 case object TimestampConversion extends JDBCConversion
326327 case object BinaryConversion extends JDBCConversion
328+ case class ArrayConversion (elementConversion : JDBCConversion ) extends JDBCConversion
327329
328330 /**
329- * Maps a StructType to a type tag list .
331+ * Maps a StructField and its associated DataType to a type tag.
330332 */
331- def getConversions ( schema : StructType ): Array [ JDBCConversion ] = {
332- schema.fields.map(sf => sf. dataType match {
333+ def getConversion ( sf : StructField , dataType : DataType ): JDBCConversion = {
334+ dataType match {
333335 case BooleanType => BooleanConversion
334336 case DateType => DateConversion
335337 case DecimalType .Fixed (p, s) => DecimalConversion (p, s)
@@ -341,8 +343,16 @@ private[sql] class JDBCRDD(
341343 case StringType => StringConversion
342344 case TimestampType => TimestampConversion
343345 case BinaryType => BinaryConversion
346+ case ArrayType (d, x) => ArrayConversion (getConversion(sf, d))
344347 case _ => throw new IllegalArgumentException (s " Unsupported field $sf" )
345- }).toArray
348+ }
349+ }
350+
351+ /**
352+ * Maps a StructType to a type tag list.
353+ */
354+ def getConversions (schema : StructType ): Array [JDBCConversion ] = {
355+ schema.fields.map(sf => getConversion(sf, sf.dataType))
346356 }
347357
348358 /**
@@ -375,6 +385,10 @@ private[sql] class JDBCRDD(
375385 val conversions = getConversions(schema)
376386 val mutableRow = new SpecificMutableRow (schema.fields.map(x => x.dataType))
377387
388+ def convert_date (dateVal : java.sql.Date ): Int = DateTimeUtils .fromJavaDate(dateVal)
389+ def convert_decimal (decimal : java.math.BigDecimal , p : Int , s : Int ): Decimal = Decimal (decimal, p, s)
390+ def convert_timestamp (ts : java.sql.Timestamp ): SQLTimestamp = DateTimeUtils .fromJavaTimestamp(ts)
391+
378392 def getNext (): InternalRow = {
379393 if (rs.next()) {
380394 var i = 0
@@ -386,7 +400,7 @@ private[sql] class JDBCRDD(
386400 // DateTimeUtils.fromJavaDate does not handle null value, so we need to check it.
387401 val dateVal = rs.getDate(pos)
388402 if (dateVal != null ) {
389- mutableRow.setInt(i, DateTimeUtils .fromJavaDate (dateVal))
403+ mutableRow.setInt(i, convert_date (dateVal))
390404 } else {
391405 mutableRow.update(i, null )
392406 }
@@ -403,7 +417,7 @@ private[sql] class JDBCRDD(
403417 if (decimalVal == null ) {
404418 mutableRow.update(i, null )
405419 } else {
406- mutableRow.update(i, Decimal (decimalVal, p, s))
420+ mutableRow.update(i, convert_decimal (decimalVal, p, s))
407421 }
408422 case DoubleConversion => mutableRow.setDouble(i, rs.getDouble(pos))
409423 case FloatConversion => mutableRow.setFloat(i, rs.getFloat(pos))
@@ -414,21 +428,42 @@ private[sql] class JDBCRDD(
414428 case TimestampConversion =>
415429 val t = rs.getTimestamp(pos)
416430 if (t != null ) {
417- mutableRow.setLong(i, DateTimeUtils .fromJavaTimestamp (t))
431+ mutableRow.setLong(i, convert_timestamp (t))
418432 } else {
419433 mutableRow.update(i, null )
420434 }
421435 case BinaryConversion => mutableRow.update(i, rs.getBytes(pos))
422- case BinaryLongConversion => {
436+ case BinaryLongConversion =>
423437 val bytes = rs.getBytes(pos)
424438 var ans = 0L
425439 var j = 0
426440 while (j < bytes.size) {
427441 ans = 256 * ans + (255 & bytes(j))
428- j = j + 1 ;
442+ j = j + 1
429443 }
430444 mutableRow.setLong(i, ans)
431- }
445+
446+ case ArrayConversion (BinaryLongConversion ) => throw new IllegalArgumentException (s " Unsupported array element conversion $i" )
447+ case ArrayConversion (subConvert) =>
448+ val a = rs.getArray(pos)
449+ if (a != null ) {
450+ val genericArrayData = a.getArray match {
451+ case x : Array [java.math.BigDecimal ] =>
452+ subConvert match {
453+ case DecimalConversion (p, s) => new GenericArrayData (x.map(convert_decimal(_, p, s)))
454+ case _ => throw new IllegalArgumentException (" Incompatible decimal conversions" )
455+ }
456+ case x : Array [java.sql.Timestamp ] => new GenericArrayData (x.map(convert_timestamp))
457+ case x : Array [java.lang.String ] => new GenericArrayData (x.map(UTF8String .fromString))
458+ case x : Array [java.sql.Date ] => new GenericArrayData (x.map(convert_date))
459+ case x : Array [Any ] => new GenericArrayData (x)
460+ case _ => throw new IllegalArgumentException (s " Unsupported arraytype $a" )
461+ }
462+ mutableRow.update(i, genericArrayData)
463+ } else {
464+ mutableRow.update(i, null )
465+ }
466+
432467 }
433468 if (rs.wasNull) mutableRow.setNullAt(i)
434469 i = i + 1
0 commit comments