Skip to content

Commit ee232db

Browse files
committed
Support List as a return type in Hive UDF
1 parent 96c5eee commit ee232db

4 files changed

Lines changed: 110 additions & 1 deletion

File tree

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ import org.apache.hadoop.hive.serde2.objectinspector.{StructField => HiveStructF
2323
import org.apache.hadoop.hive.serde2.typeinfo.{DecimalTypeInfo, TypeInfoFactory}
2424
import org.apache.hadoop.hive.serde2.{io => hiveIo}
2525
import org.apache.hadoop.{io => hadoopIo}
26+
import org.apache.spark.Logging
27+
import org.apache.spark.annotation.DeveloperApi
2628

2729
import org.apache.spark.sql.catalyst.expressions._
2830
import org.apache.spark.sql.catalyst.util.DateTimeUtils
@@ -172,7 +174,7 @@ import scala.collection.JavaConversions._
172174
* e.g. date_add(printf("%s-%s-%s", a,b,c), 3)
173175
* We don't need to unwrap the data returned by printf and wrap it again before passing it to date_add
174176
*/
175-
private[hive] trait HiveInspectors {
177+
private[hive] trait HiveInspectors extends Logging {
176178

177179
def javaClassToDataType(clz: Class[_]): DataType = clz match {
178180
// writable
@@ -216,8 +218,16 @@ private[hive] trait HiveInspectors {
216218

217219
case c: Class[_] if c.isArray => ArrayType(javaClassToDataType(c.getComponentType))
218220

221+
// list type
222+
case c: Class[_] if c == classOf[java.util.List[java.lang.Object]] =>
223+
logWarning("Failed to catch a correct component type in List<> because of type erasure," +
224+
" so you need to handle it correctly by yourself")
225+
ArrayType(ErasedType)
226+
219227
// Hive seems to return this for struct types?
220228
case c: Class[_] if c == classOf[java.lang.Object] => NullType
229+
230+
case c => throw new HiveDataTypeException("Unknown java type: " + c)
221231
}
222232

223233
/**
@@ -828,3 +838,18 @@ private[hive] trait HiveInspectors {
828838
}
829839
}
830840
}
841+
842+
/**
 * :: DeveloperApi ::
 * Marker data type standing in for a type that cannot be determined because of
 * type erasure in the JVM.
 *
 * The constructor is private, so instances cannot be created directly by callers;
 * the companion case object `ErasedType` is the instance to refer to.
 */
@DeveloperApi
class ErasedType private() extends DataType {

  /** Default size reported for a value of this type: always 1. */
  override def defaultSize: Int = 1

  /** Returns this same instance unchanged. */
  private[spark] override def asNullable: ErasedType = this
}

/** Companion case object extending the [[ErasedType]] class. */
case object ErasedType extends ErasedType
853+
854+
/**
 * The exception thrown from the [[HiveInspectors]], e.g. by `javaClassToDataType`
 * when it is given a Java class that matches none of its known cases.
 *
 * @param message human-readable description of the failure
 */
private[hive] class HiveDataTypeException(message: String)
  extends Exception(message)
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.hive.execution;
19+
20+
import org.apache.hadoop.hive.ql.exec.UDF;
21+
22+
import java.util.Arrays;
23+
import java.util.List;
24+
25+
public class UDFToListInt extends UDF {
26+
public List<Integer> evaluate(Object o) {
27+
return Arrays.asList(1, 2, 3);
28+
}
29+
}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.hive.execution;
19+
20+
import org.apache.hadoop.hive.ql.exec.UDF;
21+
22+
import java.util.Arrays;
23+
import java.util.List;
24+
25+
public class UDFToListString extends UDF {
26+
public List<String> evaluate(Object o) {
27+
return Arrays.asList("data1", "data2", "data3");
28+
}
29+
}

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,32 @@ class HiveUDFSuite extends QueryTest {
133133
TestHive.reset()
134134
}
135135

136+
test("UDFToListString") {
  // A single-row input table; the row's value is irrelevant because the UDF under test
  // returns a fixed list regardless of its argument.
  val testData = TestHive.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF()
  testData.registerTempTable("inputTable")

  try {
    sql(s"CREATE TEMPORARY FUNCTION testUDFToListString AS '${classOf[UDFToListString].getName}'")
    // The java.util.List<String> returned by the UDF should surface as a sequence column.
    checkAnswer(
      sql("SELECT testUDFToListString(s) FROM inputTable"),
      Seq(Row("data1" :: "data2" :: "data3" :: Nil)))
  } finally {
    // Run the cleanup even when the assertion above fails, so the temporary function and
    // the temp table do not leak into subsequent tests. The nested try guarantees that
    // TestHive.reset() still runs if dropping the function itself throws.
    try {
      sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFToListString")
    } finally {
      TestHive.reset()
    }
  }
}
148+
149+
test("UDFToListInt") {
  // A single-row input table; the row's value is irrelevant because the UDF under test
  // returns a fixed list regardless of its argument.
  val testData = TestHive.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF()
  testData.registerTempTable("inputTable")

  try {
    sql(s"CREATE TEMPORARY FUNCTION testUDFToListInt AS '${classOf[UDFToListInt].getName}'")
    // The java.util.List<Integer> returned by the UDF should surface as a sequence column.
    checkAnswer(
      sql("SELECT testUDFToListInt(s) FROM inputTable"),
      Seq(Row(1 :: 2 :: 3 :: Nil)))
  } finally {
    // Run the cleanup even when the assertion above fails, so the temporary function and
    // the temp table do not leak into subsequent tests. The nested try guarantees that
    // TestHive.reset() still runs if dropping the function itself throws.
    try {
      sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFToListInt")
    } finally {
      TestHive.reset()
    }
  }
}
161+
136162
test("UDFListListInt") {
137163
val testData = TestHive.sparkContext.parallelize(
138164
ListListIntCaseClass(Nil) ::

0 commit comments

Comments
 (0)