@@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst
1919
2020import java .sql .{Date , Timestamp }
2121
22+ import scala .reflect .runtime .universe .TypeTag
23+
2224import org .apache .spark .SparkFunSuite
2325import org .apache .spark .sql .catalyst .analysis .UnresolvedExtractValue
2426import org .apache .spark .sql .catalyst .expressions .{CreateNamedStruct , Expression , If , SpecificInternalRow , UpCast }
@@ -111,6 +113,10 @@ object TestingUDT {
111113class ScalaReflectionSuite extends SparkFunSuite {
112114 import org .apache .spark .sql .catalyst .ScalaReflection ._
113115
116+ // A helper method used to test `ScalaReflection.deserializerForType`.
117+ private def deserializerFor [T : TypeTag ]: Expression =
118+ deserializerForType(ScalaReflection .localTypeOf[T ])
119+
114120 test(" SQLUserDefinedType annotation on Scala structure" ) {
115121 val schema = schemaFor[TestingUDT .NestedStruct ]
116122 assert(schema === Schema (
@@ -269,7 +275,7 @@ class ScalaReflectionSuite extends SparkFunSuite {
269275 }
270276
271277 test(" SPARK 16792: Get correct deserializer for List[_]" ) {
272- val listDeserializer = deserializerForType( ScalaReflection .localTypeOf [List [Int ]])
278+ val listDeserializer = deserializerFor [List [Int ]]
273279 assert(listDeserializer.dataType == ObjectType (classOf [List [_]]))
274280 }
275281
@@ -278,38 +284,37 @@ class ScalaReflectionSuite extends SparkFunSuite {
278284 val queueSerializer = serializerForType(ScalaReflection .localTypeOf[Queue [Int ]])
279285 assert(queueSerializer.dataType ==
280286 ArrayType (IntegerType , containsNull = false ))
281- val queueDeserializer = deserializerForType( ScalaReflection .localTypeOf [Queue [Int ]])
287+ val queueDeserializer = deserializerFor [Queue [Int ]]
282288 assert(queueDeserializer.dataType == ObjectType (classOf [Queue [_]]))
283289
284290 import scala .collection .mutable .ArrayBuffer
285291 val arrayBufferSerializer = serializerForType(ScalaReflection .localTypeOf[ArrayBuffer [Int ]])
286292 assert(arrayBufferSerializer.dataType ==
287293 ArrayType (IntegerType , containsNull = false ))
288- val arrayBufferDeserializer = deserializerForType( ScalaReflection .localTypeOf [ArrayBuffer [Int ]])
294+ val arrayBufferDeserializer = deserializerFor [ArrayBuffer [Int ]]
289295 assert(arrayBufferDeserializer.dataType == ObjectType (classOf [ArrayBuffer [_]]))
290296 }
291297
292298 test(" serialize and deserialize arbitrary map types" ) {
293299 val mapSerializer = serializerForType(ScalaReflection .localTypeOf[Map [Int , Int ]])
294300 assert(mapSerializer.dataType ==
295301 MapType (IntegerType , IntegerType , valueContainsNull = false ))
296- val mapDeserializer = deserializerForType( ScalaReflection .localTypeOf [Map [Int , Int ]])
302+ val mapDeserializer = deserializerFor [Map [Int , Int ]]
297303 assert(mapDeserializer.dataType == ObjectType (classOf [Map [_, _]]))
298304
299305 import scala .collection .immutable .HashMap
300306 val hashMapSerializer = serializerForType(ScalaReflection .localTypeOf[HashMap [Int , Int ]])
301307 assert(hashMapSerializer.dataType ==
302308 MapType (IntegerType , IntegerType , valueContainsNull = false ))
303- val hashMapDeserializer = deserializerForType( ScalaReflection .localTypeOf [HashMap [Int , Int ]])
309+ val hashMapDeserializer = deserializerFor [HashMap [Int , Int ]]
304310 assert(hashMapDeserializer.dataType == ObjectType (classOf [HashMap [_, _]]))
305311
306312 import scala .collection .mutable .{LinkedHashMap => LHMap }
307313 val linkedHashMapSerializer = serializerForType(
308314 ScalaReflection .localTypeOf[LHMap [Long , String ]])
309315 assert(linkedHashMapSerializer.dataType ==
310316 MapType (LongType , StringType , valueContainsNull = true ))
311- val linkedHashMapDeserializer = deserializerForType(
312- ScalaReflection .localTypeOf[LHMap [Long , String ]])
317+ val linkedHashMapDeserializer = deserializerFor[LHMap [Long , String ]]
313318 assert(linkedHashMapDeserializer.dataType == ObjectType (classOf [LHMap [_, _]]))
314319 }
315320
@@ -318,7 +323,7 @@ class ScalaReflectionSuite extends SparkFunSuite {
318323 .collect {
319324 case If (_, _, s : CreateNamedStruct ) => s
320325 }.head
321- val deserializer = deserializerForType( ScalaReflection .localTypeOf [SpecialCharAsFieldData ])
326+ val deserializer = deserializerFor [SpecialCharAsFieldData ]
322327 assert(serializer.dataType(0 ).name == " field.1" )
323328 assert(serializer.dataType(1 ).name == " field 2" )
324329
@@ -332,8 +337,8 @@ class ScalaReflectionSuite extends SparkFunSuite {
332337 }
333338
334339 test(" SPARK-22472: add null check for top-level primitive values" ) {
335- assert(deserializerForType( ScalaReflection .localTypeOf [Int ]) .isInstanceOf [AssertNotNull ])
336- assert(! deserializerForType( ScalaReflection .localTypeOf [String ]) .isInstanceOf [AssertNotNull ])
340+ assert(deserializerFor [Int ].isInstanceOf [AssertNotNull ])
341+ assert(! deserializerFor [String ].isInstanceOf [AssertNotNull ])
337342 }
338343
339344 test(" SPARK-23025: schemaFor should support Null type" ) {
@@ -350,12 +355,8 @@ class ScalaReflectionSuite extends SparkFunSuite {
350355 val newInstance = deserializer.collect { case n : NewInstance => n}.head
351356 newInstance.arguments.count(_.isInstanceOf [AssertNotNull ])
352357 }
353- assert(numberOfCheckedArguments(
354- deserializerForType(ScalaReflection .localTypeOf[(Double , Double )])) == 2 )
355- assert(numberOfCheckedArguments(
356- deserializerForType(ScalaReflection .localTypeOf[(java.lang.Double , Int )])) == 1 )
357- assert(numberOfCheckedArguments(
358- deserializerForType(
359- ScalaReflection .localTypeOf[(java.lang.Integer , java.lang.Integer )])) == 0 )
358+ assert(numberOfCheckedArguments(deserializerFor[(Double , Double )]) == 2 )
359+ assert(numberOfCheckedArguments(deserializerFor[(java.lang.Double , Int )]) == 1 )
360+ assert(numberOfCheckedArguments(deserializerFor[(java.lang.Integer , java.lang.Integer )]) == 0 )
360361 }
361362}
0 commit comments