Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -120,17 +120,19 @@ object RowEncoder {

case StructType(fields) =>
val convertedFields = fields.zipWithIndex.map { case (f, i) =>
val method = if (f.dataType.isInstanceOf[StructType]) {
"getStruct"
val x = serializerFor(
GetExternalRowField(inputObject, i, externalDataTypeForInput(f.dataType), f.nullable),
f.dataType
)
if (f.nullable) {
If(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can do the null check inside GetExternalRowField

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i had that before (basically without the If, since GetExternalRowField already does null checks inside), but the issue is that extractorsFor code path then also runs for nulls , and that causes some weird runtime errors. i will try again.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if we do the null check inside GetExternalRowField then the code for serializerFor also needs to be pushed into it (to be inside the null check), and i can not figure out how to do that...

Invoke(inputObject, "isNullAt", BooleanType, Literal(i) :: Nil),
Literal.create(null, f.dataType),
x
)
} else {
"get"
x
}
If(
Invoke(inputObject, "isNullAt", BooleanType, Literal(i) :: Nil),
Literal.create(null, f.dataType),
serializerFor(
Invoke(inputObject, method, externalDataTypeForInput(f.dataType), Literal(i) :: Nil),
f.dataType))
}
If(IsNull(inputObject),
Literal.create(null, inputType),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -680,3 +680,53 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String])
"""
}
}

case class GetExternalRowField(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can take a look at GetStructFiled, which gets field from internal row, and is similar to this one.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok will do

targetObject: Expression,
index: Int,
dataType: DataType,
nullable: Boolean) extends Expression with NonSQLExpression {

override def children: Seq[Expression] = Seq(targetObject)

override def eval(input: InternalRow): Any =
throw new UnsupportedOperationException("Only code-generated evaluation is supported")

override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
val javaType = ctx.javaType(dataType)
val obj = targetObject.gen(ctx)

val get = dataType match {
case IntegerType => s"""${obj.value}.getInt($index)"""
case LongType => s"""${obj.value}.getLong($index)"""
case FloatType => s"""${obj.value}.getFloat($index)"""
case ShortType => s"""${obj.value}.getShort($index)"""
case ByteType => s"""${obj.value}.getByte($index)"""
case DoubleType => s"""${obj.value}.getDouble($index)"""
case BooleanType => s"""${obj.value}.getBoolean($index)"""
case ObjectType(x) if x == classOf[Row] => s"""${obj.value}.getStruct($index)"""
case _ => s"""((${javaType}) ${obj.value}.get($index))"""
}

if (nullable) {
s"""
${obj.code}
final ${javaType} ${ev.value};
final boolean ${ev.isNull};
if (${obj.value}.isNullAt(${index})) {
${ev.value} = ${ctx.defaultValue(dataType)};
${ev.isNull} = true;
} else {
${ev.value} = ${get};
${ev.isNull} = false;
}
"""
} else {
s"""
${obj.code}
final ${javaType} ${ev.value} = ${get};
final boolean ${ev.isNull} = false;
"""
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import scala.util.Random
import org.scalatest.Matchers._

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Union}
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.aggregate.TungstenAggregate
Expand Down Expand Up @@ -1432,4 +1433,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
getMessage()
assert(e1.startsWith("Path does not exist"))
}

test("SPARK-14139: map on row and preserve schema nullability") {
val df1 = Seq(1, 2, 3).toDF
assert(df1.map(row => Row(row.getInt(0) + 1))(RowEncoder(df1.schema)).schema === df1.schema)
}
}