Skip to content

Commit 86609a9

Browse files
viiryacloud-fan
authored andcommitted
[SPARK-21567][SQL] Dataset should work with type alias
If we create a type alias for a type workable with Dataset, the type alias doesn't work with Dataset. A reproducible case looks like: object C { type TwoInt = (Int, Int) def tupleTypeAlias: TwoInt = (1, 1) } Seq(1).toDS().map(_ => ("", C.tupleTypeAlias)) It throws an exception like: type T1 is not a class scala.ScalaReflectionException: type T1 is not a class at scala.reflect.api.Symbols$SymbolApi$class.asClass(Symbols.scala:275) ... This patch accesses the dealias of type in many places in `ScalaReflection` to fix it. Added test case. Author: Liang-Chi Hsieh <viirya@gmail.com> Closes #18813 from viirya/SPARK-21567. (cherry picked from commit ee13041) Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent a1c1199 commit 86609a9

2 files changed

Lines changed: 38 additions & 13 deletions

File tree

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ object ScalaReflection extends ScalaReflection {
6363
def dataTypeFor[T : TypeTag]: DataType = dataTypeFor(localTypeOf[T])
6464

6565
private def dataTypeFor(tpe: `Type`): DataType = ScalaReflectionLock.synchronized {
66-
tpe match {
66+
tpe.dealias match {
6767
case t if t <:< definitions.IntTpe => IntegerType
6868
case t if t <:< definitions.LongTpe => LongType
6969
case t if t <:< definitions.DoubleTpe => DoubleType
@@ -93,7 +93,7 @@ object ScalaReflection extends ScalaReflection {
9393
* JVM form instead of the Scala Array that handles auto boxing.
9494
*/
9595
private def arrayClassFor(tpe: `Type`): ObjectType = ScalaReflectionLock.synchronized {
96-
val cls = tpe match {
96+
val cls = tpe.dealias match {
9797
case t if t <:< definitions.IntTpe => classOf[Array[Int]]
9898
case t if t <:< definitions.LongTpe => classOf[Array[Long]]
9999
case t if t <:< definitions.DoubleTpe => classOf[Array[Double]]
@@ -192,7 +192,7 @@ object ScalaReflection extends ScalaReflection {
192192
case _ => UpCast(expr, expected, walkedTypePath)
193193
}
194194

195-
tpe match {
195+
tpe.dealias match {
196196
case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath
197197

198198
case t if t <:< localTypeOf[Option[_]] =>
@@ -479,7 +479,7 @@ object ScalaReflection extends ScalaReflection {
479479
}
480480
}
481481

482-
tpe match {
482+
tpe.dealias match {
483483
case _ if !inputObject.dataType.isInstanceOf[ObjectType] => inputObject
484484

485485
case t if t <:< localTypeOf[Option[_]] =>
@@ -633,7 +633,7 @@ object ScalaReflection extends ScalaReflection {
633633
* we also treat [[DefinedByConstructorParams]] as product type.
634634
*/
635635
def optionOfProductType(tpe: `Type`): Boolean = ScalaReflectionLock.synchronized {
636-
tpe match {
636+
tpe.dealias match {
637637
case t if t <:< localTypeOf[Option[_]] =>
638638
val TypeRef(_, _, Seq(optType)) = t
639639
definedByConstructorParams(optType)
@@ -680,7 +680,7 @@ object ScalaReflection extends ScalaReflection {
680680
/*
681681
* Retrieves the runtime class corresponding to the provided type.
682682
*/
683-
def getClassFromType(tpe: Type): Class[_] = mirror.runtimeClass(tpe.typeSymbol.asClass)
683+
def getClassFromType(tpe: Type): Class[_] = mirror.runtimeClass(tpe.dealias.typeSymbol.asClass)
684684

685685
case class Schema(dataType: DataType, nullable: Boolean)
686686

@@ -695,7 +695,7 @@ object ScalaReflection extends ScalaReflection {
695695

696696
/** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
697697
def schemaFor(tpe: `Type`): Schema = ScalaReflectionLock.synchronized {
698-
tpe match {
698+
tpe.dealias match {
699699
case t if t.typeSymbol.annotations.exists(_.tpe =:= typeOf[SQLUserDefinedType]) =>
700700
val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
701701
Schema(udt, nullable = true)
@@ -761,7 +761,7 @@ object ScalaReflection extends ScalaReflection {
761761
* Whether the fields of the given type is defined entirely by its constructor parameters.
762762
*/
763763
def definedByConstructorParams(tpe: Type): Boolean = {
764-
tpe <:< localTypeOf[Product] || tpe <:< localTypeOf[DefinedByConstructorParams]
764+
tpe.dealias <:< localTypeOf[Product] || tpe.dealias <:< localTypeOf[DefinedByConstructorParams]
765765
}
766766

767767
private val javaKeywords = Set("abstract", "assert", "boolean", "break", "byte", "case", "catch",
@@ -816,7 +816,7 @@ trait ScalaReflection {
816816
* synthetic classes, emulating behaviour in Java bytecode.
817817
*/
818818
def getClassNameFromType(tpe: `Type`): String = {
819-
tpe.erasure.typeSymbol.asClass.fullName
819+
tpe.dealias.erasure.typeSymbol.asClass.fullName
820820
}
821821

822822
/**
@@ -835,9 +835,10 @@ trait ScalaReflection {
835835
* support inner class.
836836
*/
837837
def getConstructorParameters(tpe: Type): Seq[(String, Type)] = {
838-
val formalTypeArgs = tpe.typeSymbol.asClass.typeParams
839-
val TypeRef(_, _, actualTypeArgs) = tpe
840-
val params = constructParams(tpe)
838+
val dealiasedTpe = tpe.dealias
839+
val formalTypeArgs = dealiasedTpe.typeSymbol.asClass.typeParams
840+
val TypeRef(_, _, actualTypeArgs) = dealiasedTpe
841+
val params = constructParams(dealiasedTpe)
841842
// if there are type variables to fill in, do the substitution (SomeClass[T] -> SomeClass[Int])
842843
if (actualTypeArgs.nonEmpty) {
843844
params.map { p =>
@@ -851,7 +852,7 @@ trait ScalaReflection {
851852
}
852853

853854
protected def constructParams(tpe: Type): Seq[Symbol] = {
854-
val constructorSymbol = tpe.member(nme.CONSTRUCTOR)
855+
val constructorSymbol = tpe.dealias.member(nme.CONSTRUCTOR)
855856
val params = if (constructorSymbol.isMethod) {
856857
constructorSymbol.asMethod.paramss
857858
} else {

sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,16 @@ import org.apache.spark.sql.types._
3434
case class TestDataPoint(x: Int, y: Double, s: String, t: TestDataPoint2)
3535
case class TestDataPoint2(x: Int, s: String)
3636

37+
object TestForTypeAlias {
38+
type TwoInt = (Int, Int)
39+
type ThreeInt = (TwoInt, Int)
40+
type SeqOfTwoInt = Seq[TwoInt]
41+
42+
def tupleTypeAlias: TwoInt = (1, 1)
43+
def nestedTupleTypeAlias: ThreeInt = ((1, 1), 2)
44+
def seqOfTupleTypeAlias: SeqOfTwoInt = Seq((1, 1), (2, 2))
45+
}
46+
3747
class DatasetSuite extends QueryTest with SharedSQLContext {
3848
import testImplicits._
3949

@@ -1210,6 +1220,20 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
12101220
checkAnswer(df.orderBy($"id"), expected)
12111221
checkAnswer(df.orderBy('id), expected)
12121222
}
1223+
1224+
test("SPARK-21567: Dataset should work with type alias") {
1225+
checkDataset(
1226+
Seq(1).toDS().map(_ => ("", TestForTypeAlias.tupleTypeAlias)),
1227+
("", (1, 1)))
1228+
1229+
checkDataset(
1230+
Seq(1).toDS().map(_ => ("", TestForTypeAlias.nestedTupleTypeAlias)),
1231+
("", ((1, 1), 2)))
1232+
1233+
checkDataset(
1234+
Seq(1).toDS().map(_ => ("", TestForTypeAlias.seqOfTupleTypeAlias)),
1235+
("", Seq((1, 1), (2, 2))))
1236+
}
12131237
}
12141238

12151239
case class WithImmutableMap(id: String, map_test: scala.collection.immutable.Map[Long, String])

0 commit comments

Comments
 (0)