Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -788,12 +788,37 @@ object ScalaReflection extends ScalaReflection {
}

/**
* Finds an accessible constructor with compatible parameters. This is a more flexible search
* than the exact matching algorithm in `Class.getConstructor`. The first assignment-compatible
* matching constructor is returned. Otherwise, it returns `None`.
* Finds an accessible constructor with compatible parameters. This is a more flexible search than
* the exact matching algorithm in `Class.getConstructor`. The first assignment-compatible
* matching constructor is returned if it exists. Otherwise, we check for additional compatible
* constructors defined in the companion object as `apply` methods. Otherwise, it returns `None`.
*/
def findConstructor(cls: Class[_], paramTypes: Seq[Class[_]]): Option[Constructor[_]] = {
Option(ConstructorUtils.getMatchingAccessibleConstructor(cls, paramTypes: _*))
def findConstructor(cls: Class[_], paramTypes: Seq[Class[_]]): Option[Seq[AnyRef] => Any] = {
Option(ConstructorUtils.getMatchingAccessibleConstructor(cls, paramTypes: _*)) match {
case Some(c) => Some((x: Seq[AnyRef]) => c.newInstance(x: _*))
case None =>
val companion = mirror.staticClass(cls.getName).companion
val moduleMirror = mirror.reflectModule(companion.asModule)
val applyMethods = companion.asTerm.typeSignature
.member(universe.TermName("apply")).asTerm.alternatives
applyMethods.filter{ method =>
val params = method.typeSignature.paramLists.head
// Check that the needed params are the same length and of matching types
params.size == paramTypes.tail.size &&
params.zip(paramTypes.tail).map{case(ps, pc) =>
ps.typeSignature.typeSymbol == mirror.classSymbol(pc)
}.reduce(_&&_)
}.headOption.map{ applyMethodSymbol =>
val expectedArgsCount = applyMethodSymbol.typeSignature.paramLists.head.size
val instanceMirror = mirror.reflect(moduleMirror.instance)
val method = instanceMirror.reflectMethod(applyMethodSymbol.asMethod)
(_args: Seq[AnyRef]) => {
// Drop the "outer" argument if it is provided
val args = if (_args.size == expectedArgsCount) _args else _args.tail
method.apply(args: _*)
}
}
}
}

/**
Expand Down Expand Up @@ -973,8 +998,19 @@ trait ScalaReflection extends Logging {
}
}

/**
* If our type is a Scala trait it may have a companion object that
* only defines a constructor via `apply` method.
*/
def getCompanionConstructor(tpe: Type): Symbol = {
tpe.typeSymbol.asClass.companion.asTerm.typeSignature.member(universe.TermName("apply"))
}

protected def constructParams(tpe: Type): Seq[Symbol] = {
val constructorSymbol = tpe.dealias.member(termNames.CONSTRUCTOR)
val constructorSymbol = tpe.member(termNames.CONSTRUCTOR) match {
case NoSymbol => getCompanionConstructor(tpe)
case sym => sym
}
val params = if (constructorSymbol.isMethod) {
constructorSymbol.asMethod.paramLists
} else {
Expand All @@ -989,5 +1025,4 @@ trait ScalaReflection extends Logging {
}
params.flatten
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,7 @@ case class NewInstance(
@transient private lazy val constructor: (Seq[AnyRef]) => Any = {
val paramTypes = ScalaReflection.expressionJavaClasses(arguments)
val getConstructor = (paramClazz: Seq[Class[_]]) => {
ScalaReflection.findConstructor(cls, paramClazz).getOrElse {
ScalaReflection.findConstructor(cls, paramClazz).getOrElse{
sys.error(s"Couldn't find a valid constructor on $cls")
}
}
Expand All @@ -462,12 +462,12 @@ case class NewInstance(
val d = outerObj.getClass +: paramTypes
val c = getConstructor(outerObj.getClass +: paramTypes)
(args: Seq[AnyRef]) => {
c.newInstance(outerObj +: args: _*)
c(Seq(outerObj +: args: _*))
}
}.getOrElse {
val c = getConstructor(paramTypes)
(args: Seq[AnyRef]) => {
c.newInstance(args: _*)
c(Seq(args: _*))
}
}
}
Expand All @@ -486,10 +486,16 @@ case class NewInstance(

ev.isNull = resultIsNull

val constructorCall = outer.map { gen =>
s"${gen.value}.new ${cls.getSimpleName}($argString)"
}.getOrElse {
s"new $className($argString)"
val constructorCall = cls.getConstructors.size match {
// If there are no constructors, the `new` method will fail. In
// this case we can try to call the apply method constructor
// that might be defined on the companion object.
case 0 => s"$className$$.MODULE$$.apply($argString)"
case _ => outer.map { gen =>
s"${gen.value}.new ${cls.getSimpleName}($argString)"
}.getOrElse {
s"new $className($argString)"
}
}

val code = code"""
Expand All @@ -498,6 +504,7 @@ case class NewInstance(
final $javaType ${ev.value} = ${ev.isNull} ?
${CodeGenerator.defaultValue(dataType)} : $constructorCall;
"""

ev.copy(code = code)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,36 @@ object TestingUDT {
}
}

trait ScroogeLikeExample extends Product1[Int] with java.io.Serializable {
import ScroogeLikeExample._

def _1: Int

def canEqual(other: Any): Boolean = other.isInstanceOf[ScroogeLikeExample]

private def _equals(x: ScroogeLikeExample, y: ScroogeLikeExample): Boolean =
x.productArity == y.productArity &&
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, but you've already established it's a ScroogeLikeExample here. Why must it be Product1 just to check whether it's also Product1? seems like it's not adding anything. In fact, why not just compare the one element that this trait knows about? to the extent it can implement equals() meaningfully, that's all it is doing already.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My previous answer was not complete. Product1 is also necessary so that the implicit Encoders.product[T <: Product : TypeTag] will work with this class, if omitted the DatasetSuite test will not compile:

[error] /home/drew/spark/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala:1577: value toDS is not a member of Seq[org.apache.spark.sql.catalyst.ScroogeLikeExample]
[error]     val ds = data.toDS

I could add some new encoder, but I think that might be worse as the goal of this PR is for Scrooge classes to work with the provided implicit encoders.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just use SparkSession.createDataset?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, actually that probably won't work any more or less. OK, it's because there is an Encoder for Product. You can still simplify the equals() and so on I think, but looks like that's easier than a new Encoder. Or is it sufficient to test a Seq of a concrete subtype of ScroogeLikeExample?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm worried about changing the tests to use a concrete subtype, because the reflection calls might behave differently in that case either now or later on. I simplified a little more. canEqual is necessary to implement product. equals is necessary or tests will not pass (it will check object pointer equality), and hashCode is needed for scalastyle to pass since equals is necessary.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's OK, leave in Product, if it's actually testing the case you have in mind. Yes I know equals() is needed. The new implementation looks good.

x.productIterator.sameElements(y.productIterator)

override def equals(other: Any): Boolean =
canEqual(other) &&
_equals(this, other.asInstanceOf[ScroogeLikeExample])

override def hashCode: Int = {
var hash = _root_.scala.runtime.ScalaRunTime._hashCode(this)
hash
}
override def toString: String = s"ScroogeLikeExample(${_1})"
}


object ScroogeLikeExample {
def apply(x: Int): ScroogeLikeExample = new Immutable(x)

class Immutable(x: Int) extends ScroogeLikeExample {
def _1: Int = x
}
}

class ScalaReflectionSuite extends SparkFunSuite {
import org.apache.spark.sql.catalyst.ScalaReflection._
Expand Down Expand Up @@ -362,4 +392,11 @@ class ScalaReflectionSuite extends SparkFunSuite {
assert(numberOfCheckedArguments(deserializerFor[(java.lang.Double, Int)]) == 1)
assert(numberOfCheckedArguments(deserializerFor[(java.lang.Integer, java.lang.Integer)]) == 0)
}

test("SPARK-8288") {
val schema = schemaFor[ScroogeLikeExample]
assert(schema === Schema(
StructType(Seq(
StructField("x", IntegerType, nullable = false))), nullable = true))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,16 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
dataType = ObjectType(classOf[outerObj.Inner]),
outerPointer = Some(() => outerObj))
checkObjectExprEvaluation(newInst2, new outerObj.Inner(1))

// SPARK-8288 Test
import org.apache.spark.sql.catalyst.ScroogeLikeExample
val newInst3 = NewInstance(
cls = classOf[ScroogeLikeExample],
arguments = Literal(1) :: Nil,
propagateNull = false,
dataType = ObjectType(classOf[ScroogeLikeExample]),
outerPointer = Some(() => outerObj))
checkObjectExprEvaluation(newInst3, ScroogeLikeExample(1))
}

test("LambdaVariable should support interpreted execution") {
Expand Down