Closed
Changes from 8 commits
35 commits
942dad7
Replace catalyst converter with RowEncoder.
viirya Nov 9, 2015
39f6c26
Add UserDefinedType to RowEncoder.
viirya Nov 9, 2015
75ffaeb
Fix scala style.
viirya Nov 9, 2015
1e13ff9
Call serialize on udt instead of user class.
viirya Nov 9, 2015
07ff97a
Add getField for UserDefinedType.
viirya Nov 10, 2015
ecf01bf
Move outputEncoder outside of eval and add calling copy().
viirya Nov 10, 2015
5186777
Merge remote-tracking branch 'upstream/master' into rowencoder-scalaudf
viirya Nov 11, 2015
39c0b7a
Replace catalyst converter with RowEncoder for the generated ScalaUDF.
viirya Nov 11, 2015
c910e6e
Fix scala style.
viirya Nov 11, 2015
1234515
Fix scala style.
viirya Nov 11, 2015
5c18c0c
Merge remote-tracking branch 'upstream/master' into rowencoder-scalaudf
viirya Nov 15, 2015
934f2cc
Merge remote-tracking branch 'upstream/master' into rowencoder-scalaudf
viirya Nov 25, 2015
fc882ca
Merge remote-tracking branch 'upstream/master' into rowencoder-scalaudf
viirya Dec 7, 2015
2c85714
Use reflection to call function for interpreted version. Add more com…
viirya Dec 7, 2015
f806755
Move reflection code outside eval function.
viirya Dec 7, 2015
26b4d85
Process exception thrown in UDF.
viirya Dec 7, 2015
693a6fe
Try to solve failed tests.
viirya Dec 8, 2015
1ca2efc
Try again.
viirya Dec 8, 2015
b8f3cce
Merge remote-tracking branch 'upstream/master' into rowencoder-scalaudf
viirya Apr 1, 2016
2fcbe69
Fix scala style.
viirya Apr 1, 2016
5da1c13
Try it.
viirya Apr 2, 2016
60f4ca0
Try to fix test.
viirya Apr 7, 2016
7a046fa
Make createTransformFunc as val.
viirya Apr 7, 2016
898acfa
Merge remote-tracking branch 'upstream/master' into rowencoder-scalaudf
viirya Apr 7, 2016
dd43918
Make createTransformFunc as lazy val.
viirya Apr 7, 2016
8dbc551
Pass Transformer into UDF to get the updated param values.
viirya Apr 7, 2016
5f987c0
Merge remote-tracking branch 'upstream/master' into rowencoder-scalaudf
viirya Apr 7, 2016
597c971
Fix MiMa problem.
viirya Apr 7, 2016
405e8b0
Fix passing null into ScalaUDF.
viirya Apr 7, 2016
648c7b2
Check PrimitiveType.
viirya Apr 7, 2016
10a9f91
Merge remote-tracking branch 'upstream/master' into rowencoder-scalaudf
viirya Apr 8, 2016
2a0c319
Merge remote-tracking branch 'upstream/master' into rowencoder-scalaudf
viirya Apr 8, 2016
21a2af2
Merge remote-tracking branch 'upstream/master' into rowencoder-scalaudf
viirya Apr 11, 2016
884a176
Remove ScalaUDF non code-generated evaluation support.
viirya Apr 11, 2016
30a867e
Remove unnecessary import.
viirya Apr 11, 2016
10 changes: 9 additions & 1 deletion sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
@@ -306,7 +306,15 @@ trait Row extends Serializable {
*
* @throws ClassCastException when data type does not match.
*/
- def getStruct(i: Int): Row = getAs[Row](i)
+ def getStruct(i: Int): Row = {
+   // Product and Row are both recognized as StructType in a Row
+   val t = get(i)
+   if (t.isInstanceOf[Product]) {
+     Row.fromTuple(t.asInstanceOf[Product])
+   } else {
+     t.asInstanceOf[Row]
+   }
+ }
Member Author

We use schemaFor to get the Catalyst DataType for a UDF's return type, and for a Product type it now returns a StructType. That causes a problem in RowEncoder, because RowEncoder expects a Row, not a Product, as the value of a StructType field. If your UDF returns something like (1, 2), you get a ClassCastException.

The problem is that the value of a StructType field in a Row can be either a Product or a Row. I modified the getStruct method in Row so that it converts a Product into a Row.
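
For illustration, here is a minimal self-contained sketch of that behavior. `MiniRow` and `GetStructSketch` are hypothetical names introduced only for this example; `MiniRow` is a stand-in that models positional access and is not Spark's `Row`. It assumes the intent is simply "pass a row through, convert a Product field by field".

```scala
object GetStructSketch {
  // Hypothetical stand-in for Spark's Row; it only models positional access.
  final class MiniRow(val values: Seq[Any]) {
    def get(i: Int): Any = values(i)

    // Mirrors the idea of the patched getStruct: accept either a row or a Product.
    def getStruct(i: Int): MiniRow = get(i) match {
      case r: MiniRow => r
      case p: Product => new MiniRow(p.productIterator.toSeq)
      case other =>
        throw new ClassCastException(s"Field $i is not a struct value: $other")
    }
  }

  // A struct field stored as a tuple, as a UDF returning (1, 2) would produce.
  val fromTuple: MiniRow = new MiniRow(Seq((1, 2))).getStruct(0)

  // A struct field that is already a row passes through unchanged.
  val inner: MiniRow = new MiniRow(Seq("a", "b"))
  val fromRow: MiniRow = new MiniRow(Seq(inner)).getStruct(0)
}
```

In this sketch the row case is matched before the Product case, so a row type that also happened to be a Product would not be converted a second time.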

Contributor

I think we also need to update the javadoc of Row to say that a Product is also a valid value for a StructType field.

Member Author

ok.


/**
* Returns the value at position i.
@@ -50,6 +50,14 @@ object RowEncoder {
case BooleanType | ByteType | ShortType | IntegerType | LongType |
FloatType | DoubleType | BinaryType => inputObject

+ case udt: UserDefinedType[_] =>
+   val obj = NewInstance(
+     udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
+     Nil,
+     false,
+     dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
+   Invoke(obj, "serialize", udt.sqlType, inputObject :: Nil)

case TimestampType =>
StaticInvoke(
DateTimeUtils,
@@ -109,11 +117,16 @@ object RowEncoder {

case StructType(fields) =>
val convertedFields = fields.zipWithIndex.map { case (f, i) =>
+ val method = if (f.dataType.isInstanceOf[StructType]) {
+   "getStruct"
+ } else {
+   "get"
+ }
  If(
    Invoke(inputObject, "isNullAt", BooleanType, Literal(i) :: Nil),
    Literal.create(null, f.dataType),
    extractorsFor(
-     Invoke(inputObject, "get", externalDataTypeFor(f.dataType), Literal(i) :: Nil),
+     Invoke(inputObject, method, externalDataTypeFor(f.dataType), Literal(i) :: Nil),
Member Author

If a field is a StructType, we explicitly call getStruct so that both Product and Row values are handled.
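
A rough sketch of that accessor selection, using simplified types. `FieldType`, `IntField`, `StringField`, `NestedStruct`, and `accessorFor` are hypothetical names for this example only; Spark's actual `DataType` hierarchy and the `Invoke` expression are not modeled here.

```scala
object AccessorSketch {
  // Hypothetical, simplified field types; Spark's real DataType hierarchy is richer.
  sealed trait FieldType
  case object IntField extends FieldType
  case object StringField extends FieldType
  final case class NestedStruct(children: Seq[FieldType]) extends FieldType

  // Struct-typed fields go through getStruct; every other field uses plain get.
  def accessorFor(t: FieldType): String = t match {
    case _: NestedStruct => "getStruct"
    case _               => "get"
  }

  // An example schema with one nested struct in the middle.
  val schema: Seq[FieldType] =
    Seq(IntField, NestedStruct(Seq(IntField, StringField)), StringField)

  // The accessor name chosen for each field, in schema order.
  val accessors: Seq[String] = schema.map(accessorFor)
}
```

Choosing the accessor by field type keeps the Product handling in one place (getStruct) instead of spreading it across every call site that reads a field.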

f.dataType))
}
CreateStruct(convertedFields)
@@ -129,6 +142,7 @@ object RowEncoder {
case _: ArrayType => ObjectType(classOf[scala.collection.Seq[_]])
case _: MapType => ObjectType(classOf[scala.collection.Map[_, _]])
case _: StructType => ObjectType(classOf[Row])
+ case udt: UserDefinedType[_] => ObjectType(udt.userClass)
}

private def constructorFor(schema: StructType): Expression = {
@@ -147,6 +161,14 @@ object RowEncoder {
case BooleanType | ByteType | ShortType | IntegerType | LongType |
FloatType | DoubleType | BinaryType => input

+ case udt: UserDefinedType[_] =>
+   val obj = NewInstance(
+     udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
+     Nil,
+     false,
+     dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
+   Invoke(obj, "deserialize", ObjectType(udt.userClass), input :: Nil)

case TimestampType =>
StaticInvoke(
DateTimeUtils,
@@ -234,5 +256,7 @@ object RowEncoder {
Invoke(row, "getArray", dataType, Literal(ordinal) :: Nil)
case _: MapType =>
Invoke(row, "getMap", dataType, Literal(ordinal) :: Nil)
+ case udt: UserDefinedType[_] =>
+   getField(row, ordinal, udt.sqlType)
}
}