
Commit 30c4774

cloud-fan authored and liancheng committed
[SPARK-15657][SQL] RowEncoder should validate the data type of input object
## What changes were proposed in this pull request?

This PR improves the error handling of `RowEncoder`. When we create a `RowEncoder` with a given schema, we should validate the data type of the input object, e.g. throw an exception when a field holds a boolean but is declared as a string column.

This PR also removes support for using `Product` as a valid external type of struct type. That support was added in #9712 but is incomplete: neither nested Products nor Products inside arrays work. However, we never officially supported this feature, and I think it's OK to ban it.

## How was this patch tested?

New tests in `RowEncoderSuite`.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #13401 from cloud-fan/bug.
1 parent 8a91105 · commit 30c4774

4 files changed: 95 additions & 40 deletions
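Before the per-file diffs, a minimal sketch of the new behavior, mirroring the tests added in `RowEncoderSuite` below (the schema and values here are illustrative):

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types._

// A Boolean fed into a field declared as StringType now fails fast with a
// descriptive error instead of producing a corrupt encoded row.
val encoder = RowEncoder(new StructType().add("s", StringType))
encoder.toRow(Row(true))
// => java.lang.RuntimeException: java.lang.Boolean is not a valid external
//    type for schema of string
```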


sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala

Lines changed: 1 addition & 9 deletions
@@ -304,15 +304,7 @@ trait Row extends Serializable {
    *
    * @throws ClassCastException when data type does not match.
    */
-  def getStruct(i: Int): Row = {
-    // Product and Row both are recognized as StructType in a Row
-    val t = get(i)
-    if (t.isInstanceOf[Product]) {
-      Row.fromTuple(t.asInstanceOf[Product])
-    } else {
-      t.asInstanceOf[Row]
-    }
-  }
+  def getStruct(i: Int): Row = getAs[Row](i)
 
   /**
    * Returns the value at position i.
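With `Product` no longer accepted as an external struct type, `getStruct` reduces to a typed accessor. A small illustrative sketch of the resulting behavior:

```scala
import org.apache.spark.sql.Row

val nested = Row(Row(1, "a"))
val struct: Row = nested.getStruct(0)  // now simply getAs[Row](0)

// A tuple in that slot is no longer converted via Row.fromTuple; per the
// @throws contract above, this would now throw ClassCastException:
// Row((1, "a")).getStruct(0)
```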

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala

Lines changed: 11 additions & 6 deletions
@@ -51,7 +51,7 @@ import org.apache.spark.unsafe.types.UTF8String
  *   BinaryType -> byte array
  *   ArrayType -> scala.collection.Seq or Array
  *   MapType -> scala.collection.Map
- *   StructType -> org.apache.spark.sql.Row or Product
+ *   StructType -> org.apache.spark.sql.Row
  * }}}
  */
 object RowEncoder {
@@ -121,11 +121,15 @@ object RowEncoder {
 
     case t @ ArrayType(et, _) => et match {
       case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType =>
+        // TODO: validate input type for primitive array.
         NewInstance(
           classOf[GenericArrayData],
           inputObject :: Nil,
           dataType = t)
-      case _ => MapObjects(serializerFor(_, et), inputObject, externalDataTypeForInput(et))
+      case _ => MapObjects(
+        element => serializerFor(ValidateExternalType(element, et), et),
+        inputObject,
+        ObjectType(classOf[Object]))
     }
 
     case t @ MapType(kt, vt, valueNullable) =>
@@ -151,8 +155,9 @@ object RowEncoder {
     case StructType(fields) =>
       val nonNullOutput = CreateNamedStruct(fields.zipWithIndex.flatMap { case (field, index) =>
         val fieldValue = serializerFor(
-          GetExternalRowField(
-            inputObject, index, field.name, externalDataTypeForInput(field.dataType)),
+          ValidateExternalType(
+            GetExternalRowField(inputObject, index, field.name),
+            field.dataType),
           field.dataType)
         val convertedField = if (field.nullable) {
           If(
@@ -183,7 +188,7 @@ object RowEncoder {
    * can be `scala.math.BigDecimal`, `java.math.BigDecimal`, or
    * `org.apache.spark.sql.types.Decimal`.
    */
-  private def externalDataTypeForInput(dt: DataType): DataType = dt match {
+  def externalDataTypeForInput(dt: DataType): DataType = dt match {
     // In order to support both Decimal and java/scala BigDecimal in external row, we make this
     // as java.lang.Object.
     case _: DecimalType => ObjectType(classOf[java.lang.Object])
@@ -192,7 +197,7 @@ object RowEncoder {
     case _ => externalDataTypeFor(dt)
   }
 
-  private def externalDataTypeFor(dt: DataType): DataType = dt match {
+  def externalDataTypeFor(dt: DataType): DataType = dt match {
     case _ if ScalaReflection.isNativeType(dt) => dt
     case TimestampType => ObjectType(classOf[java.sql.Timestamp])
     case DateType => ObjectType(classOf[java.sql.Date])
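`externalDataTypeForInput` and `externalDataTypeFor` lose their `private` modifier so the new `ValidateExternalType` expression can derive its result type from them. A short sketch of the mapping, with expected results shown as comments:

```scala
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types._

RowEncoder.externalDataTypeFor(TimestampType)
// => ObjectType(classOf[java.sql.Timestamp])
RowEncoder.externalDataTypeFor(DateType)
// => ObjectType(classOf[java.sql.Date])
RowEncoder.externalDataTypeForInput(DecimalType(10, 0))
// => ObjectType(classOf[java.lang.Object]); several decimal classes are
//    accepted on input, so the concrete class is checked only at runtime
```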

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala

Lines changed: 52 additions & 9 deletions
@@ -26,6 +26,7 @@ import org.apache.spark.SparkConf
 import org.apache.spark.serializer._
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.util.GenericArrayData
@@ -692,22 +693,17 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String])
 case class GetExternalRowField(
     child: Expression,
     index: Int,
-    fieldName: String,
-    dataType: DataType) extends UnaryExpression with NonSQLExpression {
+    fieldName: String) extends UnaryExpression with NonSQLExpression {
 
   override def nullable: Boolean = false
 
+  override def dataType: DataType = ObjectType(classOf[Object])
+
   override def eval(input: InternalRow): Any =
     throw new UnsupportedOperationException("Only code-generated evaluation is supported")
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val row = child.genCode(ctx)
-
-    val getField = dataType match {
-      case ObjectType(x) if x == classOf[Row] => s"""${row.value}.getStruct($index)"""
-      case _ => s"""(${ctx.boxedType(dataType)}) ${row.value}.get($index)"""
-    }
-
     val code = s"""
       ${row.code}
 
@@ -720,8 +716,55 @@ case class GetExternalRowField(
           "cannot be null.");
       }
 
-      final ${ctx.javaType(dataType)} ${ev.value} = $getField;
+      final Object ${ev.value} = ${row.value}.get($index);
     """
     ev.copy(code = code, isNull = "false")
   }
 }
+
+/**
+ * Validates the actual data type of input expression at runtime.  If it doesn't match the
+ * expectation, throw an exception.
+ */
+case class ValidateExternalType(child: Expression, expected: DataType)
+  extends UnaryExpression with NonSQLExpression with ExpectsInputTypes {
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(ObjectType(classOf[Object]))
+
+  override def nullable: Boolean = child.nullable
+
+  override def dataType: DataType = RowEncoder.externalDataTypeForInput(expected)
+
+  override def eval(input: InternalRow): Any =
+    throw new UnsupportedOperationException("Only code-generated evaluation is supported")
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val input = child.genCode(ctx)
+    val obj = input.value
+
+    val typeCheck = expected match {
+      case _: DecimalType =>
+        Seq(classOf[java.math.BigDecimal], classOf[scala.math.BigDecimal], classOf[Decimal])
+          .map(cls => s"$obj instanceof ${cls.getName}").mkString(" || ")
+      case _: ArrayType =>
+        s"$obj instanceof ${classOf[Seq[_]].getName} || $obj.getClass().isArray()"
+      case _ =>
+        s"$obj instanceof ${ctx.boxedType(dataType)}"
+    }
+
+    val code = s"""
+      ${input.code}
+      ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+      if (!${input.isNull}) {
+        if ($typeCheck) {
+          ${ev.value} = (${ctx.boxedType(dataType)}) $obj;
+        } else {
+          throw new RuntimeException($obj.getClass().getName() + " is not a valid " +
+            "external type for schema of ${expected.simpleString}");
+        }
+      }
+    """
+    ev.copy(code = code, isNull = input.isNull)
+  }
+}
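The generated Java above amounts to a per-branch `instanceof` test. A plain-Scala analogue, illustrative only (the real check is code-generated, and the default branch tests against `ctx.boxedType(dataType)`, modeled here by an explicit class argument):

```scala
import org.apache.spark.sql.types._

def passesTypeCheck(obj: Any, expected: DataType, boxed: Class[_]): Boolean =
  expected match {
    case _: DecimalType =>  // several decimal representations are accepted
      obj.isInstanceOf[java.math.BigDecimal] ||
        obj.isInstanceOf[scala.math.BigDecimal] ||
        obj.isInstanceOf[Decimal]
    case _: ArrayType =>    // a Scala Seq or any Java array
      obj.isInstanceOf[Seq[_]] || obj.getClass.isArray
    case _ =>               // everything else: the boxed external type
      boxed.isInstance(obj)
  }

// passesTypeCheck(1, IntegerType, classOf[java.lang.Integer])         // true
// passesTypeCheck(1.toShort, IntegerType, classOf[java.lang.Integer]) // false
```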

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala

Lines changed: 31 additions & 16 deletions
@@ -127,22 +127,6 @@ class RowEncoderSuite extends SparkFunSuite {
         new StructType().add("array", arrayOfString).add("map", mapOfString))
       .add("structOfUDT", structOfUDT))
 
-  test(s"encode/decode: Product") {
-    val schema = new StructType()
-      .add("structAsProduct",
-        new StructType()
-          .add("int", IntegerType)
-          .add("string", StringType)
-          .add("double", DoubleType))
-
-    val encoder = RowEncoder(schema).resolveAndBind()
-
-    val input: Row = Row((100, "test", 0.123))
-    val row = encoder.toRow(input)
-    val convertedBack = encoder.fromRow(row)
-    assert(input.getStruct(0) == convertedBack.getStruct(0))
-  }
-
   test("encode/decode decimal type") {
     val schema = new StructType()
       .add("int", IntegerType)
@@ -232,6 +216,37 @@ class RowEncoderSuite extends SparkFunSuite {
     assert(e.getMessage.contains("top level row object"))
   }
 
+  test("RowEncoder should validate external type") {
+    val e1 = intercept[RuntimeException] {
+      val schema = new StructType().add("a", IntegerType)
+      val encoder = RowEncoder(schema)
+      encoder.toRow(Row(1.toShort))
+    }
+    assert(e1.getMessage.contains("java.lang.Short is not a valid external type"))
+
+    val e2 = intercept[RuntimeException] {
+      val schema = new StructType().add("a", StringType)
+      val encoder = RowEncoder(schema)
+      encoder.toRow(Row(1))
+    }
+    assert(e2.getMessage.contains("java.lang.Integer is not a valid external type"))
+
+    val e3 = intercept[RuntimeException] {
+      val schema = new StructType().add("a",
+        new StructType().add("b", IntegerType).add("c", StringType))
+      val encoder = RowEncoder(schema)
+      encoder.toRow(Row(1 -> "a"))
+    }
+    assert(e3.getMessage.contains("scala.Tuple2 is not a valid external type"))
+
+    val e4 = intercept[RuntimeException] {
+      val schema = new StructType().add("a", ArrayType(TimestampType))
+      val encoder = RowEncoder(schema)
+      encoder.toRow(Row(Array("a")))
+    }
+    assert(e4.getMessage.contains("java.lang.String is not a valid external type"))
+  }
+
   private def encodeDecodeTest(schema: StructType): Unit = {
     test(s"encode/decode: ${schema.simpleString}") {
       val encoder = RowEncoder(schema).resolveAndBind()
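For contrast with the negative cases above, a sketch of the happy path following the suite's `encodeDecodeTest` helper: a well-typed row still round-trips unchanged.

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types._

val schema = new StructType().add("a", IntegerType).add("s", StringType)
val encoder = RowEncoder(schema).resolveAndBind()
val roundTripped = encoder.fromRow(encoder.toRow(Row(1, "test")))
assert(roundTripped == Row(1, "test"))
```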
