From f5963a202f10273693329241b0f011c7f6dbf185 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 28 Feb 2019 13:05:55 +0800 Subject: [PATCH] Add comment and do simple cleanup. --- .../catalyst/DeserializerBuildHelper.scala | 25 ++++++++++++++----- .../sql/catalyst/JavaTypeInference.scala | 1 - .../spark/sql/catalyst/ScalaReflection.scala | 3 +-- 3 files changed, 20 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala index 3a2f38622d00d..cbf3bb02026c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala @@ -44,25 +44,38 @@ object DeserializerBuildHelper { upCastToExpectedType(newPath, dataType, walkedTypePath) } + /** + * Returns an expression that can be used to deserialize input expression. + * + * @param expr The input expression that can be used to extract serialized value. + * @param nullable Whether deserialized expression evalutes to null value. + * @param walkedTypePath The paths from top to bottom to access current field when deserializing. + * @param funcForCreatingDeserializer Given input expression and typed path, this function + * returns deserializer expression. + */ def deserializerForWithNullSafety( expr: Expression, - dataType: DataType, nullable: Boolean, walkedTypePath: Seq[String], - funcForCreatingNewExpr: (Expression, Seq[String]) => Expression): Expression = { - val newExpr = funcForCreatingNewExpr(expr, walkedTypePath) + funcForCreatingDeserializer: (Expression, Seq[String]) => Expression): Expression = { + val newExpr = funcForCreatingDeserializer(expr, walkedTypePath) expressionWithNullSafety(newExpr, nullable, walkedTypePath) } + /** + * This returns deserializer expression as `deserializerForWithNullSafety` does. The only + * difference is this method adds `UpCast` to input expression to avoid possible runtime + * error caused by type mimatch between serialized column data type and deserializing type. + */ def deserializerForWithNullSafetyAndUpcast( expr: Expression, dataType: DataType, nullable: Boolean, walkedTypePath: Seq[String], - funcForCreatingNewExpr: (Expression, Seq[String]) => Expression): Expression = { + funcForCreatingDeserializer: (Expression, Seq[String]) => Expression): Expression = { val casted = upCastToExpectedType(expr, dataType, walkedTypePath) - deserializerForWithNullSafety(casted, dataType, nullable, walkedTypePath, - funcForCreatingNewExpr) + deserializerForWithNullSafety(casted, nullable, walkedTypePath, + funcForCreatingDeserializer) } private def expressionWithNullSafety( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 1822f9b036f72..3e4ca7f6b01cc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -329,7 +329,6 @@ object JavaTypeInference { s""", name: "$fieldName")""") +: walkedTypePath val setter = deserializerForWithNullSafety( path, - dataType, nullable = nullable, newTypePath, (expr, typePath) => deserializerFor(fieldType, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 26cc7b4d7ad80..88c321daf6846 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -322,13 +322,12 @@ object ScalaReflection extends ScalaReflection { val newTypePath = (s"""- field (class: "$clsName", """ + s"""name: "$fieldName")""") +: walkedTypePath - // For tuples, we based grab the inner fields by ordinal instead of name. deserializerForWithNullSafety( path, - dataType, nullable = nullable, newTypePath, (expr, typePath) => { + // For tuples, we based grab the inner fields by ordinal instead of name. if (cls.getName startsWith "scala.Tuple") { deserializerFor( fieldType,