sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -719,6 +719,12 @@ trait ScalaReflection {
}
}

def getParameterTypes(func: AnyRef): Seq[Class[_]] = {
Review comment (Contributor): Scaladoc please

val methods = func.getClass.getMethods.filter(m => m.getName == "apply" && !m.isBridge)
assert(methods.length == 1)
methods.head.getParameterTypes
}
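A minimal Scaladoc along the lines the reviewer requested might read as follows (an editor's sketch; the wording that was actually merged may differ):

/**
 * Returns the classes of the input parameters of the given Scala function object,
 * obtained by reflecting on its single non-bridge `apply` method.
 */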

def typeOfObject: PartialFunction[Any, DataType] = {
// The data type can be determined without ambiguity.
case obj: Boolean => BooleanType
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.trees.TreeNodeRef
-import org.apache.spark.sql.catalyst.{SimpleCatalystConf, CatalystConf}
+import org.apache.spark.sql.catalyst.{ScalaReflection, SimpleCatalystConf, CatalystConf}
import org.apache.spark.sql.types._

/**
@@ -85,6 +85,8 @@ class Analyzer(
extendedResolutionRules : _*),
Batch("Nondeterministic", Once,
PullOutNondeterministic),
Batch("UDF", Once,
HandleNullInputsForUDF),
Batch("Cleanup", fixedPoint,
CleanupAliases)
)
@@ -1063,6 +1065,29 @@ class Analyzer(
Project(p.output, newPlan.withNewChildren(newChild :: Nil))
}
}

/**
 * Correctly handles null primitive inputs for a UDF by adding an extra [[If]] expression that
 * performs the null check. When a user defines a UDF with primitive parameters, there is no way
 * to tell whether a primitive parameter is null, so we assume primitive inputs propagate nulls:
 * if any such input is null, the UDF returns null without being invoked.
 */
object HandleNullInputsForUDF extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case p if !p.resolved => p // Skip unresolved nodes.

case plan => plan transformExpressionsUp {

case udf @ ScalaUDF(func, _, inputs, _) =>
val parameterTypes = ScalaReflection.getParameterTypes(func)
assert(parameterTypes.length == inputs.length)

parameterTypes.zip(inputs).filter(_._1.isPrimitive).map(_._2).foldLeft(udf: Expression) {
case (result, input) => If(IsNull(input), Literal.create(null, udf.dataType), result)
Review comment (Contributor): I think this would be a lot easier to read in the query plan if you created a single If with Ors. (A sketch of that alternative follows the rule below.)

}
}
}
}
}
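To make the reviewer's suggestion above concrete, here is a hedged sketch of the same case written with a single If whose predicate Ors the null checks together. This is an editor's illustration using the names from the surrounding rule, not the code this PR merged:

case udf @ ScalaUDF(func, _, inputs, _) =>
  val parameterTypes = ScalaReflection.getParameterTypes(func)
  assert(parameterTypes.length == inputs.length)
  // Keep only the inputs bound to primitive parameters; boxed inputs need no check.
  val primitiveInputs = parameterTypes.zip(inputs).filter(_._1.isPrimitive).map(_._2)
  if (primitiveInputs.isEmpty) {
    udf
  } else {
    // One flat predicate, IsNull(a) || IsNull(b) || ..., instead of nested Ifs.
    val checks = primitiveInputs.map(e => IsNull(e): Expression)
    If(checks.reduce((a, b) => Or(a, b)), Literal.create(null, udf.dataType), udf)
  }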

/**
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
@@ -24,7 +24,13 @@ import org.apache.spark.sql.types.DataType

/**
 * User-defined function.
 * @param function The user-defined Scala function to run.
 *                 Note that if you use primitive parameters, you cannot check whether an input
 *                 is null, and the UDF will return null for you whenever a primitive input is
 *                 null. Use a boxed type or [[Option]] if you want to handle nulls yourself.
 * @param dataType Return type of the function.
 * @param children The input expressions of this UDF.
 * @param inputTypes The expected input types of this UDF.
 */
case class ScalaUDF(
function: AnyRef,
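(The diff truncates the rest of the signature here.) As a hedged usage sketch of the null contract documented above (an editor's illustration; it assumes a SQLContext with its implicits in scope, as in the test suites further down):

import org.apache.spark.sql.functions.udf

// Boxed parameter: a null input reaches the function body, which handles it itself.
val boxedUDF = udf { (i: java.lang.Integer) => if (i == null) "missing" else i.toString }

// Primitive parameter: the analyzer wraps the call in a null check, so the function
// never runs on a null input and the result is simply null.
val primitiveUDF = udf { (i: Int) => i.toString }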
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -280,4 +280,21 @@ class ScalaReflectionSuite extends SparkFunSuite {
assert(s.fields.map(_.dataType) === Seq(IntegerType, StringType, DoubleType))
}
}

test("get parameter type from a function object") {
val primitiveFunc = (i: Int, j: Long) => "x"
val primitiveTypes = getParameterTypes(primitiveFunc)
assert(primitiveTypes.forall(_.isPrimitive))
assert(primitiveTypes === Seq(classOf[Int], classOf[Long]))

val boxedFunc = (i: java.lang.Integer, j: java.lang.Long) => "x"
val boxedTypes = getParameterTypes(boxedFunc)
assert(boxedTypes.forall(!_.isPrimitive))
assert(boxedTypes === Seq(classOf[java.lang.Integer], classOf[java.lang.Long]))

val anyFunc = (i: Any, j: AnyRef) => "x"
val anyTypes = getParameterTypes(anyFunc)
assert(anyTypes.forall(!_.isPrimitive))
assert(anyTypes === Seq(classOf[java.lang.Object], classOf[java.lang.Object]))
}
}
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -174,4 +174,33 @@ class AnalysisSuite extends AnalysisTest {
)
assertAnalysisError(plan, Seq("data type mismatch: Arguments must be same type"))
}

test("SPARK-11725: correctly handle null inputs for ScalaUDF") {
val string = testRelation2.output(0)
val double = testRelation2.output(2)
val short = testRelation2.output(4)
val nullResult = Literal.create(null, StringType)

def checkUDF(udf: Expression, transformed: Expression): Unit = {
checkAnalysis(
Project(Alias(udf, "")() :: Nil, testRelation2),
Project(Alias(transformed, "")() :: Nil, testRelation2)
)
}

val udf1 = ScalaUDF((s: String) => "x", StringType, string :: Nil)
val expected1 = udf1
checkUDF(udf1, expected1)

val udf2 = ScalaUDF((s: String, d: Double) => "x", StringType, string :: double :: Nil)
val expected2 = If(IsNull(double), nullResult, udf2)
checkUDF(udf2, expected2)

val udf3 = ScalaUDF((s: Short, d: Double) => "x", StringType, short :: double :: Nil)
val expected3 = If(
IsNull(double),
nullResult,
If(IsNull(short), nullResult, udf3))
checkUDF(udf3, expected3)
}
}
sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala (14 additions, 0 deletions)
@@ -1115,4 +1115,18 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
checkAnswer(df.select(df("*")), Row(1, "a"))
checkAnswer(df.withColumnRenamed("d^'a.", "a"), Row(1, "a"))
}

test("SPARK-11725: correctly handle null inputs for ScalaUDF") {
val df = Seq(
new java.lang.Integer(22) -> "John",
null.asInstanceOf[java.lang.Integer] -> "Lucy").toDF("age", "name")

val boxedUDF = udf[java.lang.Integer, java.lang.Integer] {
(i: java.lang.Integer) => if (i == null) null else i * 2
}
checkAnswer(df.select(boxedUDF($"age")), Row(44) :: Row(null) :: Nil)

val primitiveUDF = udf((i: Int) => i * 2)
checkAnswer(df.select(primitiveUDF($"age")), Row(44) :: Row(null) :: Nil)
}
}