Skip to content

Commit 078a071

Browse files
committed
Simplify test change.
1 parent 682fa4b commit 078a071

1 file changed

Lines changed: 18 additions & 17 deletions

File tree

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst
1919

2020
import java.sql.{Date, Timestamp}
2121

22+
import scala.reflect.runtime.universe.TypeTag
23+
2224
import org.apache.spark.SparkFunSuite
2325
import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue
2426
import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, Expression, If, SpecificInternalRow, UpCast}
@@ -111,6 +113,10 @@ object TestingUDT {
111113
class ScalaReflectionSuite extends SparkFunSuite {
112114
import org.apache.spark.sql.catalyst.ScalaReflection._
113115

116+
// A helper method used to test `ScalaReflection.deserializerForType`.
117+
private def deserializerFor[T: TypeTag]: Expression =
118+
deserializerForType(ScalaReflection.localTypeOf[T])
119+
114120
test("SQLUserDefinedType annotation on Scala structure") {
115121
val schema = schemaFor[TestingUDT.NestedStruct]
116122
assert(schema === Schema(
@@ -269,7 +275,7 @@ class ScalaReflectionSuite extends SparkFunSuite {
269275
}
270276

271277
test("SPARK 16792: Get correct deserializer for List[_]") {
272-
val listDeserializer = deserializerForType(ScalaReflection.localTypeOf[List[Int]])
278+
val listDeserializer = deserializerFor[List[Int]]
273279
assert(listDeserializer.dataType == ObjectType(classOf[List[_]]))
274280
}
275281

@@ -278,38 +284,37 @@ class ScalaReflectionSuite extends SparkFunSuite {
278284
val queueSerializer = serializerForType(ScalaReflection.localTypeOf[Queue[Int]])
279285
assert(queueSerializer.dataType ==
280286
ArrayType(IntegerType, containsNull = false))
281-
val queueDeserializer = deserializerForType(ScalaReflection.localTypeOf[Queue[Int]])
287+
val queueDeserializer = deserializerFor[Queue[Int]]
282288
assert(queueDeserializer.dataType == ObjectType(classOf[Queue[_]]))
283289

284290
import scala.collection.mutable.ArrayBuffer
285291
val arrayBufferSerializer = serializerForType(ScalaReflection.localTypeOf[ArrayBuffer[Int]])
286292
assert(arrayBufferSerializer.dataType ==
287293
ArrayType(IntegerType, containsNull = false))
288-
val arrayBufferDeserializer = deserializerForType(ScalaReflection.localTypeOf[ArrayBuffer[Int]])
294+
val arrayBufferDeserializer = deserializerFor[ArrayBuffer[Int]]
289295
assert(arrayBufferDeserializer.dataType == ObjectType(classOf[ArrayBuffer[_]]))
290296
}
291297

292298
test("serialize and deserialize arbitrary map types") {
293299
val mapSerializer = serializerForType(ScalaReflection.localTypeOf[Map[Int, Int]])
294300
assert(mapSerializer.dataType ==
295301
MapType(IntegerType, IntegerType, valueContainsNull = false))
296-
val mapDeserializer = deserializerForType(ScalaReflection.localTypeOf[Map[Int, Int]])
302+
val mapDeserializer = deserializerFor[Map[Int, Int]]
297303
assert(mapDeserializer.dataType == ObjectType(classOf[Map[_, _]]))
298304

299305
import scala.collection.immutable.HashMap
300306
val hashMapSerializer = serializerForType(ScalaReflection.localTypeOf[HashMap[Int, Int]])
301307
assert(hashMapSerializer.dataType ==
302308
MapType(IntegerType, IntegerType, valueContainsNull = false))
303-
val hashMapDeserializer = deserializerForType(ScalaReflection.localTypeOf[HashMap[Int, Int]])
309+
val hashMapDeserializer = deserializerFor[HashMap[Int, Int]]
304310
assert(hashMapDeserializer.dataType == ObjectType(classOf[HashMap[_, _]]))
305311

306312
import scala.collection.mutable.{LinkedHashMap => LHMap}
307313
val linkedHashMapSerializer = serializerForType(
308314
ScalaReflection.localTypeOf[LHMap[Long, String]])
309315
assert(linkedHashMapSerializer.dataType ==
310316
MapType(LongType, StringType, valueContainsNull = true))
311-
val linkedHashMapDeserializer = deserializerForType(
312-
ScalaReflection.localTypeOf[LHMap[Long, String]])
317+
val linkedHashMapDeserializer = deserializerFor[LHMap[Long, String]]
313318
assert(linkedHashMapDeserializer.dataType == ObjectType(classOf[LHMap[_, _]]))
314319
}
315320

@@ -318,7 +323,7 @@ class ScalaReflectionSuite extends SparkFunSuite {
318323
.collect {
319324
case If(_, _, s: CreateNamedStruct) => s
320325
}.head
321-
val deserializer = deserializerForType(ScalaReflection.localTypeOf[SpecialCharAsFieldData])
326+
val deserializer = deserializerFor[SpecialCharAsFieldData]
322327
assert(serializer.dataType(0).name == "field.1")
323328
assert(serializer.dataType(1).name == "field 2")
324329

@@ -332,8 +337,8 @@ class ScalaReflectionSuite extends SparkFunSuite {
332337
}
333338

334339
test("SPARK-22472: add null check for top-level primitive values") {
335-
assert(deserializerForType(ScalaReflection.localTypeOf[Int]).isInstanceOf[AssertNotNull])
336-
assert(!deserializerForType(ScalaReflection.localTypeOf[String]).isInstanceOf[AssertNotNull])
340+
assert(deserializerFor[Int].isInstanceOf[AssertNotNull])
341+
assert(!deserializerFor[String].isInstanceOf[AssertNotNull])
337342
}
338343

339344
test("SPARK-23025: schemaFor should support Null type") {
@@ -350,12 +355,8 @@ class ScalaReflectionSuite extends SparkFunSuite {
350355
val newInstance = deserializer.collect { case n: NewInstance => n}.head
351356
newInstance.arguments.count(_.isInstanceOf[AssertNotNull])
352357
}
353-
assert(numberOfCheckedArguments(
354-
deserializerForType(ScalaReflection.localTypeOf[(Double, Double)])) == 2)
355-
assert(numberOfCheckedArguments(
356-
deserializerForType(ScalaReflection.localTypeOf[(java.lang.Double, Int)])) == 1)
357-
assert(numberOfCheckedArguments(
358-
deserializerForType(
359-
ScalaReflection.localTypeOf[(java.lang.Integer, java.lang.Integer)])) == 0)
358+
assert(numberOfCheckedArguments(deserializerFor[(Double, Double)]) == 2)
359+
assert(numberOfCheckedArguments(deserializerFor[(java.lang.Double, Int)]) == 1)
360+
assert(numberOfCheckedArguments(deserializerFor[(java.lang.Integer, java.lang.Integer)]) == 0)
360361
}
361362
}

0 commit comments

Comments
 (0)