diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index 933a6dbeb705..39132139237c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -362,7 +362,12 @@ object JavaTypeInference {
     def toCatalystArray(input: Expression, elementType: TypeToken[_]): Expression = {
       val (dataType, nullable) = inferDataType(elementType)
       if (ScalaReflection.isNativeType(dataType)) {
-        createSerializerForGenericArray(input, dataType, nullable = nullable)
+        val cls = input.dataType.asInstanceOf[ObjectType].cls
+        if (cls.isArray && cls.getComponentType.isPrimitive) {
+          createSerializerForPrimitiveArray(input, dataType)
+        } else {
+          createSerializerForGenericArray(input, dataType, nullable = nullable)
+        }
       } else {
         createSerializerForMapObjects(input, ObjectType(elementType.getRawType),
           serializerFor(_, elementType))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 5b3109af6a53..d8d268a77ca1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -394,8 +394,8 @@ object ScalaReflection extends ScalaReflection {
           createSerializerForMapObjects(input, dt,
             serializerFor(_, elementType, newPath, seenTypeSet))

-        case dt @ (BooleanType | ByteType | ShortType | IntegerType | LongType |
-          FloatType | DoubleType) =>
+        case dt @ (BooleanType | ByteType | ShortType | IntegerType | LongType |
+            FloatType | DoubleType) =>
          val cls = input.dataType.asInstanceOf[ObjectType].cls
          if (cls.isArray && cls.getComponentType.isPrimitive) {
            createSerializerForPrimitiveArray(input, dt)
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java
index 49ff522cee8e..f59afef36a5a 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java
@@ -47,25 +47,28 @@ public void tearDown() {

   static {
     ARRAY_RECORDS.add(
-      new ArrayRecord(1, Arrays.asList(new Interval(111, 211), new Interval(121, 221)))
+      new ArrayRecord(1, Arrays.asList(new Interval(111, 211), new Interval(121, 221)),
+        new int[] { 11, 12, 13, 14 })
     );
     ARRAY_RECORDS.add(
-      new ArrayRecord(2, Arrays.asList(new Interval(112, 212), new Interval(122, 222)))
+      new ArrayRecord(2, Arrays.asList(new Interval(112, 212), new Interval(122, 222)),
+        new int[] { 21, 22, 23, 24 })
     );
     ARRAY_RECORDS.add(
-      new ArrayRecord(3, Arrays.asList(new Interval(113, 213), new Interval(123, 223)))
+      new ArrayRecord(3, Arrays.asList(new Interval(113, 213), new Interval(123, 223)),
+        new int[] { 31, 32, 33, 34 })
     );
   }

   @Test
   public void testBeanWithArrayFieldDeserialization() {
-
     Encoder<ArrayRecord> encoder = Encoders.bean(ArrayRecord.class);

     Dataset<ArrayRecord> dataset = spark
       .read()
       .format("json")
-      .schema("id int, intervals array<struct<startTime: long, endTime: long>>")
+      .schema("id int, intervals array<struct<startTime: long, endTime: long>>, " +
+        "ints array<int>")
       .load("src/test/resources/test-data/with-array-fields.json")
       .as(encoder);

@@ -223,12 +226,14 @@ public static class ArrayRecord {

     private int id;
     private List<Interval> intervals;
+    private int[] ints;

     public ArrayRecord() { }

-    ArrayRecord(int id, List<Interval> intervals) {
+    ArrayRecord(int id, List<Interval> intervals, int[] ints) {
       this.id = id;
       this.intervals = intervals;
+      this.ints = ints;
     }

     public int getId() {
@@ -247,21 +252,31 @@ public void setIntervals(List<Interval> intervals) {
       this.intervals = intervals;
     }

+    public int[] getInts() {
+      return ints;
+    }
+
+    public void setInts(int[] ints) {
+      this.ints = ints;
+    }
+
     @Override
     public int hashCode() {
-      return id ^ Objects.hashCode(intervals);
+      return id ^ Objects.hashCode(intervals) ^ Objects.hashCode(ints);
     }

     @Override
     public boolean equals(Object obj) {
       if (!(obj instanceof ArrayRecord)) return false;

       ArrayRecord other = (ArrayRecord) obj;
-      return (other.id == this.id) && Objects.equals(other.intervals, this.intervals);
+      return (other.id == this.id) && Objects.equals(other.intervals, this.intervals) &&
+        Arrays.equals(other.ints, ints);
     }

     @Override
     public String toString() {
-      return String.format("{ id: %d, intervals: %s }", id, intervals);
+      return String.format("{ id: %d, intervals: %s, ints: %s }", id, intervals,
+        Arrays.toString(ints));
     }
   }
diff --git a/sql/core/src/test/resources/test-data/with-array-fields.json b/sql/core/src/test/resources/test-data/with-array-fields.json
index ff3674af2fbc..09022ec02895 100644
--- a/sql/core/src/test/resources/test-data/with-array-fields.json
+++ b/sql/core/src/test/resources/test-data/with-array-fields.json
@@ -1,3 +1,3 @@
-{ "id": 1, "intervals": [{ "startTime": 111, "endTime": 211 }, { "startTime": 121, "endTime": 221 }]}
-{ "id": 2, "intervals": [{ "startTime": 112, "endTime": 212 }, { "startTime": 122, "endTime": 222 }]}
-{ "id": 3, "intervals": [{ "startTime": 113, "endTime": 213 }, { "startTime": 123, "endTime": 223 }]}
\ No newline at end of file
+{ "id": 1, "intervals": [{ "startTime": 111, "endTime": 211 }, { "startTime": 121, "endTime": 221 }], "ints": [11, 12, 13, 14]}
+{ "id": 2, "intervals": [{ "startTime": 112, "endTime": 212 }, { "startTime": 122, "endTime": 222 }], "ints": [21, 22, 23, 24]}
+{ "id": 3, "intervals": [{ "startTime": 113, "endTime": 213 }, { "startTime": 123, "endTime": 223 }], "ints": [31, 32, 33, 34]}
\ No newline at end of file
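
For reviewers, a minimal standalone sketch of the scenario this patch covers (not part of the patch itself; the class, field, and app names are illustrative assumptions). It round-trips a Java bean with a primitive int[] property through Encoders.bean(): serialization goes through JavaTypeInference.toCatalystArray, which with this change routes primitive component types to createSerializerForPrimitiveArray instead of always picking the generic-array serializer for native element types.

import java.util.Arrays;
import java.util.Collections;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.SparkSession;

public class PrimitiveArrayBeanExample {

  // Hypothetical bean: a single primitive-array property, with the public
  // no-arg constructor and getter/setter shape Encoders.bean() expects.
  public static class Record {
    private int[] ints;

    public Record() { }

    public int[] getInts() { return ints; }

    public void setInts(int[] ints) { this.ints = ints; }
  }

  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder()
      .master("local[1]")
      .appName("primitive-array-bean-example")
      .getOrCreate();

    Record record = new Record();
    record.setInts(new int[] { 1, 2, 3 });

    // Encoders.bean() infers the serializer via JavaTypeInference; with this
    // patch the int[] property takes the createSerializerForPrimitiveArray path.
    Dataset<Record> ds = spark.createDataset(
      Collections.singletonList(record), Encoders.bean(Record.class));

    // Deserialize back to beans and print the array field: [1, 2, 3]
    ds.collectAsList().forEach(r -> System.out.println(Arrays.toString(r.getInts())));

    spark.stop();
  }
}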