From ae95a19169ccab1bad4b6a2585b75767f5be7e7f Mon Sep 17 00:00:00 2001
From: "Jungtaek Lim (HeartSaVioR)"
Date: Fri, 8 Mar 2019 06:42:55 +0900
Subject: [PATCH 1/2] [SPARK-27001][SQL][FOLLOWUP] Address primitive array type for serializer

---
 .../sql/catalyst/JavaTypeInference.scala      |  7 +++-
 .../spark/sql/catalyst/ScalaReflection.scala  |  4 +--
 .../sql/JavaBeanDeserializationSuite.java     | 32 +++++++++++++------
 .../test-data/with-array-fields.json          |  6 ++--
 4 files changed, 34 insertions(+), 15 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index 933a6dbeb705..39132139237c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -362,7 +362,12 @@ object JavaTypeInference {
   def toCatalystArray(input: Expression, elementType: TypeToken[_]): Expression = {
     val (dataType, nullable) = inferDataType(elementType)
     if (ScalaReflection.isNativeType(dataType)) {
-      createSerializerForGenericArray(input, dataType, nullable = nullable)
+      val cls = input.dataType.asInstanceOf[ObjectType].cls
+      if (cls.isArray && cls.getComponentType.isPrimitive) {
+        createSerializerForPrimitiveArray(input, dataType)
+      } else {
+        createSerializerForGenericArray(input, dataType, nullable = nullable)
+      }
     } else {
       createSerializerForMapObjects(input, ObjectType(elementType.getRawType),
         serializerFor(_, elementType))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 5b3109af6a53..d8d268a77ca1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -394,8 +394,8 @@
         createSerializerForMapObjects(input, dt,
           serializerFor(_, elementType, newPath, seenTypeSet))
 
-      case dt @ (BooleanType | ByteType | ShortType | IntegerType | LongType |
-        FloatType | DoubleType) =>
+      case dt @ (BooleanType | ByteType | ShortType | IntegerType | LongType |
+          FloatType | DoubleType) =>
         val cls = input.dataType.asInstanceOf[ObjectType].cls
         if (cls.isArray && cls.getComponentType.isPrimitive) {
           createSerializerForPrimitiveArray(input, dt)
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java
index 49ff522cee8e..3b9b3c4b2a49 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java
@@ -47,25 +47,27 @@ public void tearDown() {
 
   static {
     ARRAY_RECORDS.add(
-        new ArrayRecord(1, Arrays.asList(new Interval(111, 211), new Interval(121, 221)))
+        new ArrayRecord(1, Arrays.asList(new Interval(111, 211), new Interval(121, 221)),
+            new int[] { 11, 12, 13, 14 })
     );
     ARRAY_RECORDS.add(
-        new ArrayRecord(2, Arrays.asList(new Interval(112, 212), new Interval(122, 222)))
+        new ArrayRecord(2, Arrays.asList(new Interval(112, 212), new Interval(122, 222)),
+            new int[] { 21, 22, 23, 24 })
     );
     ARRAY_RECORDS.add(
-        new ArrayRecord(3, Arrays.asList(new Interval(113, 213), new Interval(123, 223)))
+        new ArrayRecord(3, Arrays.asList(new Interval(113, 213), new Interval(123, 223)),
+            new int[] { 31, 32, 33, 34 })
     );
   }
 
   @Test
   public void testBeanWithArrayFieldDeserialization() {
-
     Encoder<ArrayRecord> encoder = Encoders.bean(ArrayRecord.class);
 
     Dataset<ArrayRecord> dataset = spark
         .read()
         .format("json")
-        .schema("id int, intervals array<struct<startTime: bigint, endTime: bigint>>")
+        .schema("id int, intervals array<struct<startTime: bigint, endTime: bigint>>, ints array<int>")
         .load("src/test/resources/test-data/with-array-fields.json")
         .as(encoder);
 
@@ -223,12 +225,14 @@ public static class ArrayRecord {
 
     private int id;
     private List<Interval> intervals;
+    private int[] ints;
 
     public ArrayRecord() { }
 
-    ArrayRecord(int id, List<Interval> intervals) {
+    ArrayRecord(int id, List<Interval> intervals, int[] ints) {
       this.id = id;
       this.intervals = intervals;
+      this.ints = ints;
     }
 
     public int getId() {
@@ -247,21 +251,31 @@ public void setIntervals(List<Interval> intervals) {
       this.intervals = intervals;
     }
 
+    public int[] getInts() {
+      return ints;
+    }
+
+    public void setInts(int[] ints) {
+      this.ints = ints;
+    }
+
     @Override
     public int hashCode() {
-      return id ^ Objects.hashCode(intervals);
+      return id ^ Objects.hashCode(intervals) ^ Objects.hashCode(ints);
     }
 
     @Override
     public boolean equals(Object obj) {
       if (!(obj instanceof ArrayRecord)) return false;
       ArrayRecord other = (ArrayRecord) obj;
-      return (other.id == this.id) && Objects.equals(other.intervals, this.intervals);
+      return (other.id == this.id) && Objects.equals(other.intervals, this.intervals) &&
+          Arrays.equals(other.ints, ints);
     }
 
     @Override
     public String toString() {
-      return String.format("{ id: %d, intervals: %s }", id, intervals);
+      return String.format("{ id: %d, intervals: %s, ints: %s }", id, intervals,
+          Arrays.toString(ints));
     }
   }
 
diff --git a/sql/core/src/test/resources/test-data/with-array-fields.json b/sql/core/src/test/resources/test-data/with-array-fields.json
index ff3674af2fbc..09022ec02895 100644
--- a/sql/core/src/test/resources/test-data/with-array-fields.json
+++ b/sql/core/src/test/resources/test-data/with-array-fields.json
@@ -1,3 +1,3 @@
-{ "id": 1, "intervals": [{ "startTime": 111, "endTime": 211 }, { "startTime": 121, "endTime": 221 }]}
-{ "id": 2, "intervals": [{ "startTime": 112, "endTime": 212 }, { "startTime": 122, "endTime": 222 }]}
-{ "id": 3, "intervals": [{ "startTime": 113, "endTime": 213 }, { "startTime": 123, "endTime": 223 }]}
\ No newline at end of file
+{ "id": 1, "intervals": [{ "startTime": 111, "endTime": 211 }, { "startTime": 121, "endTime": 221 }], "ints": [11, 12, 13, 14]}
+{ "id": 2, "intervals": [{ "startTime": 112, "endTime": 212 }, { "startTime": 122, "endTime": 222 }], "ints": [21, 22, 23, 24]}
+{ "id": 3, "intervals": [{ "startTime": 113, "endTime": 213 }, { "startTime": 123, "endTime": 223 }], "ints": [31, 32, 33, 34]}
\ No newline at end of file
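Note: patch 1 makes JavaTypeInference.toCatalystArray check whether the incoming value's runtime class is a primitive array (e.g. int[]) and, if so, build the serializer with createSerializerForPrimitiveArray instead of the generic-array path, mirroring the check already present in ScalaReflection (visible in the context lines above). A minimal sketch of the user-facing scenario this serves is below; the IntsBean class, session setup, and local master are illustrative assumptions, not code from the patch:

  // Sketch: a Java bean with a primitive int[] property, encoded with
  // Encoders.bean. Building this Dataset drives the bean serializer through
  // JavaTypeInference.toCatalystArray with a primitive-array input, i.e. the
  // code path patch 1 changes.
  import java.io.Serializable;
  import java.util.Arrays;

  import org.apache.spark.sql.Dataset;
  import org.apache.spark.sql.Encoders;
  import org.apache.spark.sql.SparkSession;

  public class PrimitiveArrayBeanExample {
    public static class IntsBean implements Serializable {
      private int[] values;
      public int[] getValues() { return values; }
      public void setValues(int[] values) { this.values = values; }
    }

    public static void main(String[] args) {
      SparkSession spark = SparkSession.builder()
          .master("local[1]")
          .appName("PrimitiveArrayBeanExample")
          .getOrCreate();

      IntsBean bean = new IntsBean();
      bean.setValues(new int[] { 1, 2, 3 });

      // Encoding the bean serializes the int[] field into a Catalyst array.
      Dataset<IntsBean> ds = spark.createDataset(
          Arrays.asList(bean), Encoders.bean(IntsBean.class));
      ds.show();

      spark.stop();
    }
  }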
From 8f856ea25af518392e647f9e9acfe7e0108b4f42 Mon Sep 17 00:00:00 2001
From: "Jungtaek Lim (HeartSaVioR)"
Date: Fri, 8 Mar 2019 08:04:03 +0900
Subject: [PATCH 2/2] Fix java lint

---
 .../org/apache/spark/sql/JavaBeanDeserializationSuite.java | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java
index 3b9b3c4b2a49..f59afef36a5a 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java
@@ -67,7 +67,8 @@ public void testBeanWithArrayFieldDeserialization() {
     Dataset<ArrayRecord> dataset = spark
         .read()
         .format("json")
-        .schema("id int, intervals array<struct<startTime: bigint, endTime: bigint>>, ints array<int>")
+        .schema("id int, intervals array<struct<startTime: bigint, endTime: bigint>>, " +
+            "ints array<int>")
         .load("src/test/resources/test-data/with-array-fields.json")
         .as(encoder);