Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -788,12 +788,37 @@ object ScalaReflection extends ScalaReflection {
}

/**
* Finds an accessible constructor with compatible parameters. This is a more flexible search
* than the exact matching algorithm in `Class.getConstructor`. The first assignment-compatible
* matching constructor is returned. Otherwise, it returns `None`.
* Finds an accessible constructor with compatible parameters. This is a more flexible search than
* the exact matching algorithm in `Class.getConstructor`. The first assignment-compatible
* matching constructor is returned if it exists. Otherwise, we check for additional compatible
* constructors defined in the companion object as `apply` methods. Otherwise, it returns `None`.
*/
def findConstructor(cls: Class[_], paramTypes: Seq[Class[_]]): Option[Constructor[_]] = {
Option(ConstructorUtils.getMatchingAccessibleConstructor(cls, paramTypes: _*))
def findConstructor[T](cls: Class[T], paramTypes: Seq[Class[_]]): Option[Seq[AnyRef] => T] = {
Option(ConstructorUtils.getMatchingAccessibleConstructor(cls, paramTypes: _*)) match {
case Some(c) => Some(x => c.newInstance(x: _*).asInstanceOf[T])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this cast isn't needed because Class[T].newInstance returns T already, but it's fine to leave.

case None =>
val companion = mirror.staticClass(cls.getName).companion
val moduleMirror = mirror.reflectModule(companion.asModule)
val applyMethods = companion.asTerm.typeSignature
.member(universe.TermName("apply")).asTerm.alternatives
applyMethods.find { method =>
val params = method.typeSignature.paramLists.head
// Check that the needed params are the same length and of matching types
params.size == paramTypes.tail.size &&
params.zip(paramTypes.tail).forall { case(ps, pc) =>
ps.typeSignature.typeSymbol == mirror.classSymbol(pc)
}
}.map { applyMethodSymbol =>
val expectedArgsCount = applyMethodSymbol.typeSignature.paramLists.head.size
val instanceMirror = mirror.reflect(moduleMirror.instance)
val method = instanceMirror.reflectMethod(applyMethodSymbol.asMethod)
(_args: Seq[AnyRef]) => {
// Drop the "outer" argument if it is provided
val args = if (_args.size == expectedArgsCount) _args else _args.tail
method.apply(args: _*).asInstanceOf[T]
}
}
}
}

/**
Expand Down Expand Up @@ -973,8 +998,19 @@ trait ScalaReflection extends Logging {
}
}

/**
* If our type is a Scala trait it may have a companion object that
* only defines a constructor via `apply` method.
*/
private def getCompanionConstructor(tpe: Type): Symbol = {
tpe.typeSymbol.asClass.companion.asTerm.typeSignature.member(universe.TermName("apply"))
}

protected def constructParams(tpe: Type): Seq[Symbol] = {
val constructorSymbol = tpe.dealias.member(termNames.CONSTRUCTOR)
val constructorSymbol = tpe.member(termNames.CONSTRUCTOR) match {
case NoSymbol => getCompanionConstructor(tpe)
case sym => sym
}
val params = if (constructorSymbol.isMethod) {
constructorSymbol.asMethod.paramLists
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -462,12 +462,12 @@ case class NewInstance(
val d = outerObj.getClass +: paramTypes
val c = getConstructor(outerObj.getClass +: paramTypes)
(args: Seq[AnyRef]) => {
c.newInstance(outerObj +: args: _*)
c(outerObj +: args)
}
}.getOrElse {
val c = getConstructor(paramTypes)
(args: Seq[AnyRef]) => {
c.newInstance(args: _*)
c(args)
}
}
}
Expand All @@ -486,10 +486,16 @@ case class NewInstance(

ev.isNull = resultIsNull

val constructorCall = outer.map { gen =>
s"${gen.value}.new ${cls.getSimpleName}($argString)"
}.getOrElse {
s"new $className($argString)"
val constructorCall = cls.getConstructors.size match {
// If there are no constructors, the `new` method will fail. In
// this case we can try to call the apply method constructor
// that might be defined on the companion object.
case 0 => s"$className$$.MODULE$$.apply($argString)"
case _ => outer.map { gen =>
s"${gen.value}.new ${cls.getSimpleName}($argString)"
}.getOrElse {
s"new $className($argString)"
}
}

val code = code"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,30 @@ object TestingUDT {
}
}

/** An example derived from Twitter/Scrooge codegen for thrift */
object ScroogeLikeExample {
def apply(x: Int): ScroogeLikeExample = new Immutable(x)

def unapply(_item: ScroogeLikeExample): Option[Int] = Some(_item.x)

class Immutable(val x: Int) extends ScroogeLikeExample
}

trait ScroogeLikeExample extends Product1[Int] with Serializable {
import ScroogeLikeExample._

def x: Int

def _1: Int = x

override def canEqual(other: Any): Boolean = other.isInstanceOf[ScroogeLikeExample]

override def equals(other: Any): Boolean =
canEqual(other) &&
this.x == other.asInstanceOf[ScroogeLikeExample].x

override def hashCode: Int = x
}

class ScalaReflectionSuite extends SparkFunSuite {
import org.apache.spark.sql.catalyst.ScalaReflection._
Expand Down Expand Up @@ -362,4 +386,11 @@ class ScalaReflectionSuite extends SparkFunSuite {
assert(numberOfCheckedArguments(deserializerFor[(java.lang.Double, Int)]) == 1)
assert(numberOfCheckedArguments(deserializerFor[(java.lang.Integer, java.lang.Integer)]) == 0)
}

test("SPARK-8288: schemaFor works for a class with only a companion object constructor") {
val schema = schemaFor[ScroogeLikeExample]
assert(schema === Schema(
StructType(Seq(
StructField("x", IntegerType, nullable = false))), nullable = true))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
import org.apache.spark.sql.{RandomDataGenerator, Row}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, JavaTypeInference, ScalaReflection}
import org.apache.spark.sql.catalyst.ScroogeLikeExample
import org.apache.spark.sql.catalyst.analysis.{ResolveTimeZone, SimpleAnalyzer, UnresolvedDeserializer}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.encoders._
Expand Down Expand Up @@ -410,6 +411,15 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
dataType = ObjectType(classOf[outerObj.Inner]),
outerPointer = Some(() => outerObj))
checkObjectExprEvaluation(newInst2, new outerObj.Inner(1))

// SPARK-8288: A class with only a companion object constructor
val newInst3 = NewInstance(
cls = classOf[ScroogeLikeExample],
arguments = Literal(1) :: Nil,
propagateNull = false,
dataType = ObjectType(classOf[ScroogeLikeExample]),
outerPointer = Some(() => outerObj))
checkObjectExprEvaluation(newInst3, ScroogeLikeExample(1))
}

test("LambdaVariable should support interpreted execution") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import java.io.{Externalizable, ObjectInput, ObjectOutput}
import java.sql.{Date, Timestamp}

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.ScroogeLikeExample
import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder}
import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi}
import org.apache.spark.sql.catalyst.util.sideBySide
Expand Down Expand Up @@ -1570,6 +1571,13 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
val agg = ds.groupByKey(x => x).agg(sum("_1").as[Long], sum($"_2" + 1).as[Long])
checkDatasetUnorderly(agg, ((1, 2), 1L, 3L), ((2, 3), 2L, 4L), ((3, 4), 3L, 5L))
}

test("SPARK-8288: class with only a companion object constructor") {
val data = Seq(ScroogeLikeExample(1), ScroogeLikeExample(2))
val ds = data.toDS
checkDataset(ds, data: _*)
checkAnswer(ds.select("x"), Seq(Row(1), Row(2)))
}
}

case class TestDataUnion(x: Int, y: Int, z: Int)
Expand Down