
Commit ed4f4c9

Merge remote-tracking branch 'upstream/master' into SPARK-24762-refactor

2 parents 552e8dd + 584e767

5 files changed: 325 additions & 162 deletions

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
Lines changed: 4 additions & 8 deletions
@@ -275,24 +275,20 @@ object JavaTypeInference {

       case _ if mapType.isAssignableFrom(typeToken) =>
         val (keyType, valueType) = mapKeyValueType(typeToken)
-        val keyDataType = inferDataType(keyType)._1
-        val valueDataType = inferDataType(valueType)._1

         val keyData =
           Invoke(
-            MapObjects(
+            UnresolvedMapObjects(
               p => deserializerFor(keyType, p),
-              Invoke(path, "keyArray", ArrayType(keyDataType)),
-              keyDataType),
+              GetKeyArrayFromMap(path)),
             "array",
             ObjectType(classOf[Array[Any]]))

         val valueData =
           Invoke(
-            MapObjects(
+            UnresolvedMapObjects(
               p => deserializerFor(valueType, p),
-              Invoke(path, "valueArray", ArrayType(valueDataType)),
-              valueDataType),
+              GetValueArrayFromMap(path)),
             "array",
             ObjectType(classOf[Array[Any]]))

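The net effect of this hunk: the deserializer no longer calls inferDataType on the key and value types up front. UnresolvedMapObjects defers element-type resolution to the analyzer, and the two new expressions read the key/value arrays straight off the map value. A minimal sketch of the MapData accessors they wrap, assuming only Spark's catalyst utility classes (the expression names in the comments come from this commit; everything else is illustrative):

import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData}

// Build a small MapData by hand: keys ["a", "b"], values [1, 2].
val map: MapData = new ArrayBasedMapData(
  new GenericArrayData(Array[Any]("a", "b")),
  new GenericArrayData(Array[Any](1, 2)))

map.keyArray()   // ArrayData of keys: what GetKeyArrayFromMap(path) evaluates to
map.valueArray() // ArrayData of values: what GetValueArrayFromMap(path) evaluates to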
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
Lines changed: 76 additions & 0 deletions
@@ -30,6 +30,7 @@ import org.apache.spark.serializer._
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection}
 import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen._
@@ -1787,3 +1788,78 @@ case class ValidateExternalType(child: Expression, expected: DataType)
     ev.copy(code = code, isNull = input.isNull)
   }
 }
+
+object GetKeyArrayFromMap {
+
+  /**
+   * Construct an instance of GetArrayFromMap case class
+   * extracting a key array from a Map expression.
+   *
+   * @param child a Map expression to extract a key array from
+   */
+  def apply(child: Expression): Expression = {
+    GetArrayFromMap(
+      child,
+      "keyArray",
+      _.keyArray(),
+      { case MapType(kt, _, _) => kt })
+  }
+}
+
+object GetValueArrayFromMap {
+
+  /**
+   * Construct an instance of GetArrayFromMap case class
+   * extracting a value array from a Map expression.
+   *
+   * @param child a Map expression to extract a value array from
+   */
+  def apply(child: Expression): Expression = {
+    GetArrayFromMap(
+      child,
+      "valueArray",
+      _.valueArray(),
+      { case MapType(_, vt, _) => vt })
+  }
+}
+
+/**
+ * Extracts a key/value array from a Map expression.
+ *
+ * @param child a Map expression to extract an array from
+ * @param functionName name of the function that is invoked to extract an array
+ * @param arrayGetter function extracting `ArrayData` from `MapData`
+ * @param elementTypeGetter function extracting array element `DataType` from `MapType`
+ */
+case class GetArrayFromMap private(
+    child: Expression,
+    functionName: String,
+    arrayGetter: MapData => ArrayData,
+    elementTypeGetter: MapType => DataType) extends UnaryExpression with NonSQLExpression {
+
+  private lazy val encodedFunctionName: String = TermName(functionName).encodedName.toString
+
+  lazy val dataType: DataType = {
+    val mt: MapType = child.dataType.asInstanceOf[MapType]
+    ArrayType(elementTypeGetter(mt))
+  }
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (child.dataType.isInstanceOf[MapType]) {
+      TypeCheckResult.TypeCheckSuccess
+    } else {
+      TypeCheckResult.TypeCheckFailure(
+        s"Can't extract array from $child: need map type but got ${child.dataType.catalogString}")
+    }
+  }
+
+  override def nullSafeEval(input: Any): Any = {
+    arrayGetter(input.asInstanceOf[MapData])
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    defineCodeGen(ctx, ev, childValue => s"$childValue.$encodedFunctionName()")
+  }
+
+  override def toString: String = s"$child.$functionName"
+}
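A hedged usage sketch of the new expression, evaluated in interpreted mode (the Literal setup and assertion below are illustrative, not part of the commit): dataType is derived from the child's MapType via elementTypeGetter, and nullSafeEval simply forwards to the corresponding MapData accessor.

import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.expressions.objects.GetKeyArrayFromMap
import org.apache.spark.sql.types.{ArrayType, IntegerType, MapType, StringType}

// A map-typed child expression: Map("a" -> 1, "b" -> 2).
val mapExpr = Literal.create(Map("a" -> 1, "b" -> 2), MapType(StringType, IntegerType))

val keysExpr = GetKeyArrayFromMap(mapExpr)
assert(keysExpr.dataType == ArrayType(StringType)) // element type taken from the MapType's key side
keysExpr.eval()                                    // ArrayData holding UTF8Strings "a", "b"

// A non-map child would fail checkInputDataTypes with a message like
// "Can't extract array from ...: need map type but got int".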
sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java
Lines changed: 240 additions & 0 deletions
@@ -0,0 +1,240 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package test.org.apache.spark.sql;
+
+import java.io.Serializable;
+import java.util.*;
+
+import org.junit.*;
+
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.Encoders;
+import org.apache.spark.sql.test.TestSparkSession;
+
+public class JavaBeanDeserializationSuite implements Serializable {
+
+  private TestSparkSession spark;
+
+  @Before
+  public void setUp() {
+    spark = new TestSparkSession();
+  }
+
+  @After
+  public void tearDown() {
+    spark.stop();
+    spark = null;
+  }
+
+  private static final List<ArrayRecord> ARRAY_RECORDS = new ArrayList<>();
+
+  static {
+    ARRAY_RECORDS.add(
+      new ArrayRecord(1, Arrays.asList(new Interval(111, 211), new Interval(121, 221)))
+    );
+    ARRAY_RECORDS.add(
+      new ArrayRecord(2, Arrays.asList(new Interval(112, 212), new Interval(122, 222)))
+    );
+    ARRAY_RECORDS.add(
+      new ArrayRecord(3, Arrays.asList(new Interval(113, 213), new Interval(123, 223)))
+    );
+  }
+
+  @Test
+  public void testBeanWithArrayFieldDeserialization() {
+
+    Encoder<ArrayRecord> encoder = Encoders.bean(ArrayRecord.class);
+
+    Dataset<ArrayRecord> dataset = spark
+      .read()
+      .format("json")
+      .schema("id int, intervals array<struct<startTime: bigint, endTime: bigint>>")
+      .load("src/test/resources/test-data/with-array-fields.json")
+      .as(encoder);
+
+    List<ArrayRecord> records = dataset.collectAsList();
+    Assert.assertEquals(records, ARRAY_RECORDS);
+  }
+
+  private static final List<MapRecord> MAP_RECORDS = new ArrayList<>();
+
+  static {
+    MAP_RECORDS.add(new MapRecord(1,
+      toMap(Arrays.asList("a", "b"), Arrays.asList(new Interval(111, 211), new Interval(121, 221)))
+    ));
+    MAP_RECORDS.add(new MapRecord(2,
+      toMap(Arrays.asList("a", "b"), Arrays.asList(new Interval(112, 212), new Interval(122, 222)))
+    ));
+    MAP_RECORDS.add(new MapRecord(3,
+      toMap(Arrays.asList("a", "b"), Arrays.asList(new Interval(113, 213), new Interval(123, 223)))
+    ));
+    MAP_RECORDS.add(new MapRecord(4, new HashMap<>()));
+    MAP_RECORDS.add(new MapRecord(5, null));
+  }
+
+  private static <K, V> Map<K, V> toMap(Collection<K> keys, Collection<V> values) {
+    Map<K, V> map = new HashMap<>();
+    Iterator<K> keyI = keys.iterator();
+    Iterator<V> valueI = values.iterator();
+    while (keyI.hasNext() && valueI.hasNext()) {
+      map.put(keyI.next(), valueI.next());
+    }
+    return map;
+  }
+
+  @Test
+  public void testBeanWithMapFieldsDeserialization() {
+
+    Encoder<MapRecord> encoder = Encoders.bean(MapRecord.class);
+
+    Dataset<MapRecord> dataset = spark
+      .read()
+      .format("json")
+      .schema("id int, intervals map<string, struct<startTime: bigint, endTime: bigint>>")
+      .load("src/test/resources/test-data/with-map-fields.json")
+      .as(encoder);
+
+    List<MapRecord> records = dataset.collectAsList();
+
+    Assert.assertEquals(records, MAP_RECORDS);
+  }
+
+  public static class ArrayRecord {
+
+    private int id;
+    private List<Interval> intervals;
+
+    public ArrayRecord() { }
+
+    ArrayRecord(int id, List<Interval> intervals) {
+      this.id = id;
+      this.intervals = intervals;
+    }
+
+    public int getId() {
+      return id;
+    }
+
+    public void setId(int id) {
+      this.id = id;
+    }
+
+    public List<Interval> getIntervals() {
+      return intervals;
+    }
+
+    public void setIntervals(List<Interval> intervals) {
+      this.intervals = intervals;
+    }
+
+    @Override
+    public boolean equals(Object obj) {
+      if (!(obj instanceof ArrayRecord)) return false;
+      ArrayRecord other = (ArrayRecord) obj;
+      return (other.id == this.id) && other.intervals.equals(this.intervals);
+    }
+
+    @Override
+    public String toString() {
+      return String.format("{ id: %d, intervals: %s }", id, intervals);
+    }
+  }
+
+  public static class MapRecord {
+
+    private int id;
+    private Map<String, Interval> intervals;
+
+    public MapRecord() { }
+
+    MapRecord(int id, Map<String, Interval> intervals) {
+      this.id = id;
+      this.intervals = intervals;
+    }
+
+    public int getId() {
+      return id;
+    }
+
+    public void setId(int id) {
+      this.id = id;
+    }
+
+    public Map<String, Interval> getIntervals() {
+      return intervals;
+    }
+
+    public void setIntervals(Map<String, Interval> intervals) {
+      this.intervals = intervals;
+    }
+
+    @Override
+    public boolean equals(Object obj) {
+      if (!(obj instanceof MapRecord)) return false;
+      MapRecord other = (MapRecord) obj;
+      return (other.id == this.id) && Objects.equals(other.intervals, this.intervals);
+    }
+
+    @Override
+    public String toString() {
+      return String.format("{ id: %d, intervals: %s }", id, intervals);
+    }
+  }
+
+  public static class Interval {
+
+    private long startTime;
+    private long endTime;
+
+    public Interval() { }
+
+    Interval(long startTime, long endTime) {
+      this.startTime = startTime;
+      this.endTime = endTime;
+    }
+
+    public long getStartTime() {
+      return startTime;
+    }
+
+    public void setStartTime(long startTime) {
+      this.startTime = startTime;
+    }
+
+    public long getEndTime() {
+      return endTime;
+    }
+
+    public void setEndTime(long endTime) {
+      this.endTime = endTime;
+    }
+
+    @Override
+    public boolean equals(Object obj) {
+      if (!(obj instanceof Interval)) return false;
+      Interval other = (Interval) obj;
+      return (other.startTime == this.startTime) && (other.endTime == this.endTime);
+    }
+
+    @Override
+    public String toString() {
+      return String.format("[%d,%d]", startTime, endTime);
+    }
+  }
+}
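For reference, a Scala analogue of testBeanWithMapFieldsDeserialization under the same schema. The JSON fixtures the suite loads are not shown in this view; the line format in the comment is an assumption inferred from the schema string and the expected MAP_RECORDS, and the SparkSession setup and case classes are illustrative rather than part of the commit:

import org.apache.spark.sql.SparkSession

case class Interval(startTime: Long, endTime: Long)
case class MapRecord(id: Int, intervals: Map[String, Interval])

val spark = SparkSession.builder().master("local[1]").appName("map-beans").getOrCreate()
import spark.implicits._

// Assumed fixture format, one JSON object per line, e.g.:
// {"id": 1, "intervals": {"a": {"startTime": 111, "endTime": 211}, "b": {"startTime": 121, "endTime": 221}}}
val records = spark.read
  .schema("id int, intervals map<string, struct<startTime: bigint, endTime: bigint>>")
  .json("src/test/resources/test-data/with-map-fields.json")
  .as[MapRecord]
  .collect()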
