diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 64ea236532839..c8542d0f2f7de 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -788,12 +788,37 @@ object ScalaReflection extends ScalaReflection {
   }
 
   /**
-   * Finds an accessible constructor with compatible parameters. This is a more flexible search
-   * than the exact matching algorithm in `Class.getConstructor`. The first assignment-compatible
-   * matching constructor is returned. Otherwise, it returns `None`.
+   * Finds an accessible constructor with compatible parameters. This is a more flexible search than
+   * the exact matching algorithm in `Class.getConstructor`. The first assignment-compatible
+   * matching constructor is returned if it exists. Otherwise, we check for compatible constructors
+   * defined in the companion object as `apply` methods; if none of those match, `None` is returned.
    */
-  def findConstructor(cls: Class[_], paramTypes: Seq[Class[_]]): Option[Constructor[_]] = {
-    Option(ConstructorUtils.getMatchingAccessibleConstructor(cls, paramTypes: _*))
+  def findConstructor[T](cls: Class[T], paramTypes: Seq[Class[_]]): Option[Seq[AnyRef] => T] = {
+    Option(ConstructorUtils.getMatchingAccessibleConstructor(cls, paramTypes: _*)) match {
+      case Some(c) => Some(x => c.newInstance(x: _*).asInstanceOf[T])
+      case None =>
+        val companion = mirror.staticClass(cls.getName).companion
+        val moduleMirror = mirror.reflectModule(companion.asModule)
+        val applyMethods = companion.asTerm.typeSignature
+          .member(universe.TermName("apply")).asTerm.alternatives
+        applyMethods.find { method =>
+          val params = method.typeSignature.paramLists.head
+          // Check that the needed params are the same length and of matching types
+          params.size == paramTypes.tail.size &&
+            params.zip(paramTypes.tail).forall { case (ps, pc) =>
+              ps.typeSignature.typeSymbol == mirror.classSymbol(pc)
+            }
+        }.map { applyMethodSymbol =>
+          val expectedArgsCount = applyMethodSymbol.typeSignature.paramLists.head.size
+          val instanceMirror = mirror.reflect(moduleMirror.instance)
+          val method = instanceMirror.reflectMethod(applyMethodSymbol.asMethod)
+          (_args: Seq[AnyRef]) => {
+            // Drop the "outer" argument if it is provided
+            val args = if (_args.size == expectedArgsCount) _args else _args.tail
+            method.apply(args: _*).asInstanceOf[T]
+          }
+        }
+    }
   }
 
   /**
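The interpreted fallback above drives Scala runtime reflection by hand. The sketch below shows the same companion-object `apply` lookup in isolation, outside of Spark; the `Point` trait, its companion, and the demo object are hypothetical stand-ins, and overloaded `apply` methods and outer pointers are deliberately not handled.

```scala
// Standalone sketch of the companion-object fallback used above (not Spark
// code). `Point` is a hypothetical trait whose only "constructor" is the
// companion's apply method.
import scala.reflect.runtime.{universe => ru}

trait Point { def x: Int }
object Point {
  def apply(v: Int): Point = new Point { def x: Int = v }
}

object CompanionApplyLookup {
  private val mirror = ru.runtimeMirror(getClass.getClassLoader)

  /** Build a `Seq[AnyRef] => T` from the companion's apply, like findConstructor. */
  def fromCompanion[T](cls: Class[T]): Seq[AnyRef] => T = {
    val companion = mirror.staticClass(cls.getName).companion
    val moduleMirror = mirror.reflectModule(companion.asModule)
    val applySymbol = companion.asTerm.typeSignature
      .member(ru.TermName("apply")).asMethod
    val method = mirror.reflect(moduleMirror.instance).reflectMethod(applySymbol)
    (args: Seq[AnyRef]) => method(args: _*).asInstanceOf[T]
  }

  def main(args: Array[String]): Unit = {
    val ctor = fromCompanion(classOf[Point])
    println(ctor(Seq(Int.box(3))).x) // prints 3
  }
}
```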
@@ -973,8 +998,19 @@ trait ScalaReflection extends Logging {
     }
   }
 
+  /**
+   * If our type is a Scala trait, it may have a companion object that
+   * only defines a constructor via an `apply` method.
+   */
+  private def getCompanionConstructor(tpe: Type): Symbol = {
+    tpe.typeSymbol.asClass.companion.asTerm.typeSignature.member(universe.TermName("apply"))
+  }
+
   protected def constructParams(tpe: Type): Seq[Symbol] = {
-    val constructorSymbol = tpe.dealias.member(termNames.CONSTRUCTOR)
+    val constructorSymbol = tpe.member(termNames.CONSTRUCTOR) match {
+      case NoSymbol => getCompanionConstructor(tpe)
+      case sym => sym
+    }
     val params = if (constructorSymbol.isMethod) {
       constructorSymbol.asMethod.paramLists
     } else {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 4fd36a47cef52..59c897b6a53ce 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -462,12 +462,12 @@ case class NewInstance(
       val d = outerObj.getClass +: paramTypes
       val c = getConstructor(outerObj.getClass +: paramTypes)
       (args: Seq[AnyRef]) => {
-        c.newInstance(outerObj +: args: _*)
+        c(outerObj +: args)
       }
     }.getOrElse {
       val c = getConstructor(paramTypes)
       (args: Seq[AnyRef]) => {
-        c.newInstance(args: _*)
+        c(args)
       }
     }
   }
@@ -486,10 +486,16 @@ case class NewInstance(
 
     ev.isNull = resultIsNull
 
-    val constructorCall = outer.map { gen =>
-      s"${gen.value}.new ${cls.getSimpleName}($argString)"
-    }.getOrElse {
-      s"new $className($argString)"
+    val constructorCall = cls.getConstructors.size match {
+      // If there are no constructors, generating a `new` call would fail.
+      // In that case, fall back to an `apply` method that may be defined
+      // on the companion object.
+      case 0 => s"$className$$.MODULE$$.apply($argString)"
+      case _ => outer.map { gen =>
+        s"${gen.value}.new ${cls.getSimpleName}($argString)"
+      }.getOrElse {
+        s"new $className($argString)"
+      }
     }
 
     val code = code"""
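The generated `$className$$.MODULE$$.apply($argString)` string leans on how Scala compiles companion objects to the JVM: a trait becomes an interface with zero constructors, while its companion becomes a `Foo$` class holding a static `MODULE$` singleton. A small sketch spelling that out with plain Java reflection; the demo object and assertions are ours, not part of the patch, and it assumes the `ScroogeLikeExample` definitions added in the test file below.

```scala
// Sketch of why the generated `Foo$.MODULE$.apply(...)` expression resolves
// at runtime (demo code, not part of the patch). Assumes the catalyst test
// sources, which define ScroogeLikeExample, are on the classpath.
import org.apache.spark.sql.catalyst.ScroogeLikeExample

object ModuleApplyDemo {
  def main(args: Array[String]): Unit = {
    val cls = classOf[ScroogeLikeExample]
    // A Scala trait compiles to a JVM interface: no constructors, so the
    // codegen above cannot emit `new ScroogeLikeExample(...)`.
    assert(cls.getConstructors.isEmpty)

    // What `ScroogeLikeExample$.MODULE$.apply(1)` does, spelled out:
    val moduleCls = Class.forName(cls.getName + "$")
    val module = moduleCls.getField("MODULE$").get(null)
    val built = moduleCls.getMethod("apply", classOf[Int]).invoke(module, Int.box(1))
    assert(built == ScroogeLikeExample(1))
  }
}
```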
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index d98589db323cc..80824cc2a7f21 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -109,6 +109,30 @@ object TestingUDT {
   }
 }
 
+/** An example derived from Twitter/Scrooge codegen for Thrift. */
+object ScroogeLikeExample {
+  def apply(x: Int): ScroogeLikeExample = new Immutable(x)
+
+  def unapply(_item: ScroogeLikeExample): Option[Int] = Some(_item.x)
+
+  class Immutable(val x: Int) extends ScroogeLikeExample
+}
+
+trait ScroogeLikeExample extends Product1[Int] with Serializable {
+  import ScroogeLikeExample._
+
+  def x: Int
+
+  def _1: Int = x
+
+  override def canEqual(other: Any): Boolean = other.isInstanceOf[ScroogeLikeExample]
+
+  override def equals(other: Any): Boolean =
+    canEqual(other) &&
+      this.x == other.asInstanceOf[ScroogeLikeExample].x
+
+  override def hashCode: Int = x
+}
 class ScalaReflectionSuite extends SparkFunSuite {
   import org.apache.spark.sql.catalyst.ScalaReflection._
 
@@ -362,4 +386,11 @@ class ScalaReflectionSuite extends SparkFunSuite {
     assert(numberOfCheckedArguments(deserializerFor[(java.lang.Double, Int)]) == 1)
     assert(numberOfCheckedArguments(deserializerFor[(java.lang.Integer, java.lang.Integer)]) == 0)
   }
+
+  test("SPARK-8288: schemaFor works for a class with only a companion object constructor") {
+    val schema = schemaFor[ScroogeLikeExample]
+    assert(schema === Schema(
+      StructType(Seq(
+        StructField("x", IntegerType, nullable = false))), nullable = true))
+  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
index 16842c1bcc8cb..436675bf50353 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
@@ -29,6 +29,7 @@ import org.apache.spark.{SparkConf, SparkFunSuite}
 import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
 import org.apache.spark.sql.{RandomDataGenerator, Row}
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, JavaTypeInference, ScalaReflection}
+import org.apache.spark.sql.catalyst.ScroogeLikeExample
 import org.apache.spark.sql.catalyst.analysis.{ResolveTimeZone, SimpleAnalyzer, UnresolvedDeserializer}
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.encoders._
@@ -410,6 +411,15 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
       dataType = ObjectType(classOf[outerObj.Inner]),
       outerPointer = Some(() => outerObj))
     checkObjectExprEvaluation(newInst2, new outerObj.Inner(1))
+
+    // SPARK-8288: A class with only a companion object constructor
+    val newInst3 = NewInstance(
+      cls = classOf[ScroogeLikeExample],
+      arguments = Literal(1) :: Nil,
+      propagateNull = false,
+      dataType = ObjectType(classOf[ScroogeLikeExample]),
+      outerPointer = Some(() => outerObj))
+    checkObjectExprEvaluation(newInst3, ScroogeLikeExample(1))
   }
 
   test("LambdaVariable should support interpreted execution") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index ac677e8ec6bc2..540fbff6a3a63 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -21,6 +21,7 @@ import java.io.{Externalizable, ObjectInput, ObjectOutput}
 import java.sql.{Date, Timestamp}
 
 import org.apache.spark.SparkException
+import org.apache.spark.sql.catalyst.ScroogeLikeExample
 import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder}
 import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi}
 import org.apache.spark.sql.catalyst.util.sideBySide
@@ -1570,6 +1571,13 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     val agg = ds.groupByKey(x => x).agg(sum("_1").as[Long], sum($"_2" + 1).as[Long])
     checkDatasetUnorderly(agg, ((1, 2), 1L, 3L), ((2, 3), 2L, 4L), ((3, 4), 3L, 5L))
   }
+
+  test("SPARK-8288: class with only a companion object constructor") {
+    val data = Seq(ScroogeLikeExample(1), ScroogeLikeExample(2))
+    val ds = data.toDS
+    checkDataset(ds, data: _*)
+    checkAnswer(ds.select("x"), Seq(Row(1), Row(2)))
+  }
 }
 
 case class TestDataUnion(x: Int, y: Int, z: Int)
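Taken together, the change lets Datasets round-trip types whose only constructor is a companion `apply`. A minimal end-to-end sketch mirroring the new DatasetSuite test; it assumes a local SparkSession and the test-source `ScroogeLikeExample` on the classpath.

```scala
// End-to-end sketch of what this change enables (demo code, not part of
// the patch; session setup is assumed).
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.ScroogeLikeExample

object CompanionConstructorRoundTrip {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]").appName("SPARK-8288 demo").getOrCreate()
    import spark.implicits._

    // Serialization uses the Product1 getters; deserialization now resolves
    // the companion's apply(Int) instead of a nonexistent constructor.
    val data = Seq(ScroogeLikeExample(1), ScroogeLikeExample(2))
    val ds = data.toDS()
    assert(ds.collect().toSeq == data)
    spark.stop()
  }
}
```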