diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
index f60bad180a710..a8439edcb871b 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
@@ -86,14 +86,7 @@ private[sql] class HiveSessionCatalog(
       }
     } catch {
       case NonFatal(e) =>
-        val noHandlerMsg = s"No handler for UDF/UDAF/UDTF '${clazz.getCanonicalName}': $e"
-        val errorMsg =
-          if (classOf[GenericUDTF].isAssignableFrom(clazz)) {
-            s"$noHandlerMsg\nPlease make sure your function overrides " +
-              "`public StructObjectInspector initialize(ObjectInspector[] args)`."
-          } else {
-            noHandlerMsg
-          }
+        val errorMsg = s"No handler for UDF/UDAF/UDTF '${clazz.getCanonicalName}': $e"
         val analysisException = new AnalysisException(errorMsg)
         analysisException.setStackTrace(e.getStackTrace)
         throw analysisException
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index c7002853bed54..7717e6ee207d9 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -213,10 +213,14 @@ private[hive] case class HiveGenericUDTF(
   }
 
   @transient
-  protected lazy val inputInspectors = children.map(toInspector)
+  protected lazy val inputInspector = {
+    val inspectors = children.map(toInspector)
+    val fields = inspectors.indices.map(index => s"_col$index").asJava
+    ObjectInspectorFactory.getStandardStructObjectInspector(fields, inspectors.asJava)
+  }
 
   @transient
-  protected lazy val outputInspector = function.initialize(inputInspectors.toArray)
+  protected lazy val outputInspector = function.initialize(inputInspector)
 
   @transient
   protected lazy val udtInput = new Array[AnyRef](children.length)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index 3370695245fd0..96c5bf7e27279 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -2160,32 +2160,6 @@ abstract class SQLQuerySuiteBase extends QueryTest with SQLTestUtils with TestHi
     }
   }
 
-  test("SPARK-21101 UDTF should override initialize(ObjectInspector[] args)") {
-    withUserDefinedFunction("udtf_stack1" -> true, "udtf_stack2" -> true) {
-      sql(
-        s"""
-           |CREATE TEMPORARY FUNCTION udtf_stack1
-           |AS 'org.apache.spark.sql.hive.execution.UDTFStack'
-           |USING JAR '${hiveContext.getHiveFile("SPARK-21101-1.0.jar").toURI}'
-        """.stripMargin)
-      val cnt =
-        sql("SELECT udtf_stack1(2, 'A', 10, date '2015-01-01', 'B', 20, date '2016-01-01')").count()
-      assert(cnt === 2)
-
-      sql(
-        s"""
-           |CREATE TEMPORARY FUNCTION udtf_stack2
-           |AS 'org.apache.spark.sql.hive.execution.UDTFStack2'
-           |USING JAR '${hiveContext.getHiveFile("SPARK-21101-1.0.jar").toURI}'
-        """.stripMargin)
-      val e = intercept[org.apache.spark.sql.AnalysisException] {
-        sql("SELECT udtf_stack2(2, 'A', 10, date '2015-01-01', 'B', 20, date '2016-01-01')")
-      }
-      assert(
-        e.getMessage.contains("public StructObjectInspector initialize(ObjectInspector[] args)"))
-    }
-  }
-
   test("SPARK-21721: Clear FileSystem deleterOnExit cache if path is successfully removed") {
     val table = "test21721"
     withTable(table) {
@@ -2583,6 +2557,30 @@ abstract class SQLQuerySuiteBase extends QueryTest with SQLTestUtils with TestHi
       }
     }
   }
+
+  test("SPARK-32668: HiveGenericUDTF initialize UDTF should use StructObjectInspector method") {
+    withUserDefinedFunction("udtf_stack1" -> true, "udtf_stack2" -> true) {
+      sql(
+        s"""
+           |CREATE TEMPORARY FUNCTION udtf_stack1
+           |AS 'org.apache.spark.sql.hive.execution.UDTFStack'
+           |USING JAR '${hiveContext.getHiveFile("SPARK-21101-1.0.jar").toURI}'
+        """.stripMargin)
+      sql(
+        s"""
+           |CREATE TEMPORARY FUNCTION udtf_stack2
+           |AS 'org.apache.spark.sql.hive.execution.UDTFStack2'
+           |USING JAR '${hiveContext.getHiveFile("SPARK-21101-1.0.jar").toURI}'
+        """.stripMargin)
+
+      Seq("udtf_stack1", "udtf_stack2").foreach { udf =>
+        checkAnswer(
+          sql(s"SELECT $udf(2, 'A', 10, date '2015-01-01', 'B', 20, date '2016-01-01')"),
+          Seq(Row("A", 10, Date.valueOf("2015-01-01")),
+            Row("B", 20, Date.valueOf("2016-01-01"))))
+      }
+    }
+  }
 }
 
 @SlowHiveTest