diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index a8397aa5e5c2..becfc4f91e0b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -120,17 +120,21 @@ object RowEncoder {
 
     case StructType(fields) =>
       val convertedFields = fields.zipWithIndex.map { case (f, i) =>
-        val method = if (f.dataType.isInstanceOf[StructType]) {
-          "getStruct"
+        val fieldSerializer = serializerFor(
+          GetExternalRowField(inputObject, i, externalDataTypeForInput(f.dataType), f.nullable),
+          f.dataType
+        )
+        // Guard with a null check only when the schema allows nulls for this field; a
+        // non-nullable field is serialized directly.
+        if (f.nullable) {
+          If(
+            Invoke(inputObject, "isNullAt", BooleanType, Literal(i) :: Nil),
+            Literal.create(null, f.dataType),
+            fieldSerializer
+          )
         } else {
-          "get"
+          fieldSerializer
         }
-        If(
-          Invoke(inputObject, "isNullAt", BooleanType, Literal(i) :: Nil),
-          Literal.create(null, f.dataType),
-          serializerFor(
-            Invoke(inputObject, method, externalDataTypeForInput(f.dataType), Literal(i) :: Nil),
-            f.dataType))
       }
       If(IsNull(inputObject),
         Literal.create(null, inputType),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index 07b67a0240f0..13b93cfe24c0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -680,3 +680,75 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String])
     """
   }
 }
+
+/**
+ * Reads the field at ordinal `index` out of the external row produced by `targetObject`.
+ *
+ * Primitive types go through the typed accessors (`getInt`, `getLong`, ...), a nested `Row`
+ * through `getStruct`, and every other type through `get` followed by a cast to the Java type
+ * of `dataType`. When `nullable` is true the generated code checks `isNullAt(index)` first and
+ * yields the default value for `dataType` with the null flag set.
+ *
+ * Only code-generated evaluation is supported; `eval` always throws.
+ *
+ * NOTE: the generated code dereferences the row without consulting its null flag, so a caller
+ * has to rule out a null row before this expression is evaluated.
+ *
+ * @param targetObject expression yielding the external row to read from
+ * @param index ordinal of the field inside the row
+ * @param dataType external type of the field; selects the accessor that is called
+ * @param nullable whether the field may hold null
+ */
+case class GetExternalRowField(
+    targetObject: Expression,
+    index: Int,
+    dataType: DataType,
+    nullable: Boolean) extends Expression with NonSQLExpression {
+
+  override def children: Seq[Expression] = Seq(targetObject)
+
+  override def eval(input: InternalRow): Any =
+    throw new UnsupportedOperationException("Only code-generated evaluation is supported")
+
+  override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
+    val javaType = ctx.javaType(dataType)
+    val row = targetObject.gen(ctx)
+
+    // Java snippet that calls `accessor` on the row for this field's ordinal.
+    def call(accessor: String): String = s"""${row.value}.$accessor($index)"""
+
+    val readField = dataType match {
+      case IntegerType => call("getInt")
+      case LongType => call("getLong")
+      case FloatType => call("getFloat")
+      case ShortType => call("getShort")
+      case ByteType => call("getByte")
+      case DoubleType => call("getDouble")
+      case BooleanType => call("getBoolean")
+      case ObjectType(cls) if cls == classOf[Row] => call("getStruct")
+      // The generic `get` accessor needs a cast to the expected Java type.
+      case _ => s"""((${javaType}) ${call("get")})"""
+    }
+
+    if (nullable) {
+      s"""
+        ${row.code}
+        final ${javaType} ${ev.value};
+        final boolean ${ev.isNull};
+        if (${row.value}.isNullAt(${index})) {
+          ${ev.value} = ${ctx.defaultValue(dataType)};
+          ${ev.isNull} = true;
+        } else {
+          ${ev.value} = ${readField};
+          ${ev.isNull} = false;
+        }
+      """
+    } else {
+      s"""
+        ${row.code}
+        final ${javaType} ${ev.value} = ${readField};
+        final boolean ${ev.isNull} = false;
+      """
+    }
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 86c640552236..953b3e122223 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -26,6 +26,7 @@ import scala.util.Random
 import org.scalatest.Matchers._
 
 import org.apache.spark.SparkException
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Union}
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.execution.aggregate.TungstenAggregate
@@ -1432,4 +1433,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       getMessage()
     assert(e1.startsWith("Path does not exist"))
   }
+
+  test("SPARK-14139: map on row and preserve schema nullability") {
+    val df1 = Seq(1, 2, 3).toDF
+    assert(df1.map(row => Row(row.getInt(0) + 1))(RowEncoder(df1.schema)).schema === df1.schema)
+  }
 }