diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index f9dfa3c92f1e..374af48b820c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -206,7 +206,8 @@ case class Sort( object ExistingRdd { def convertToCatalyst(a: Any): Any = a match { case o: Option[_] => o.orNull - case s: Seq[Any] => s.map(convertToCatalyst) + case s: Seq[_] => s.map(convertToCatalyst) + case m: Map[_, _] => m.map { case (k, v) => convertToCatalyst(k) -> convertToCatalyst(v) } case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray) case other => other } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala index 5b84c658db94..e24c521d24c7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala @@ -21,6 +21,7 @@ import java.sql.Timestamp import org.scalatest.FunSuite +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.test.TestSQLContext._ case class ReflectData( @@ -56,6 +57,22 @@ case class OptionalReflectData( case class ReflectBinary(data: Array[Byte]) +case class Nested(i: Option[Int], s: String) + +case class Data( + array: Seq[Int], + arrayContainsNull: Seq[Option[Int]], + map: Map[Int, Long], + mapContainsNul: Map[Int, Option[Long]], + nested: Nested) + +case class ComplexReflectData( + arrayField: Seq[Int], + arrayFieldContainsNull: Seq[Option[Int]], + mapField: Map[Int, Long], + mapFieldContainsNull: Map[Int, Option[Long]], + dataField: Data) + class ScalaReflectionRelationSuite extends FunSuite { test("query case class RDD") { val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, @@ -90,4 +107,33 @@ class ScalaReflectionRelationSuite extends FunSuite { val result = sql("SELECT data FROM reflectBinary").collect().head(0).asInstanceOf[Array[Byte]] assert(result.toSeq === Seq[Byte](1)) } + + test("query complex data") { + val data = ComplexReflectData( + Seq(1, 2, 3), + Seq(Some(1), Some(2), None), + Map(1 -> 10L, 2 -> 20L), + Map(1 -> Some(10L), 2 -> Some(20L), 3 -> None), + Data( + Seq(10, 20, 30), + Seq(Some(10), Some(20), None), + Map(10 -> 100L, 20 -> 200L), + Map(10 -> Some(100L), 20 -> Some(200L), 30 -> None), + Nested(None, "abc"))) + val rdd = sparkContext.parallelize(data :: Nil) + rdd.registerTempTable("reflectComplexData") + + assert(sql("SELECT * FROM reflectComplexData").collect().head === + new GenericRow(Array[Any]( + Seq(1, 2, 3), + Seq(1, 2, null), + Map(1 -> 10L, 2 -> 20L), + Map(1 -> 10L, 2 -> 20L, 3 -> null), + new GenericRow(Array[Any]( + Seq(10, 20, 30), + Seq(10, 20, null), + Map(10 -> 100L, 20 -> 200L), + Map(10 -> 100L, 20 -> 200L, 30 -> null), + new GenericRow(Array[Any](null, "abc"))))))) + } }