Skip to content

Commit 596f680

Browse files
jackylee-chyaooqinn
andcommitted
[SPARK-48845][SQL] GenericUDF catch exceptions from children
### What changes were proposed in this pull request? This pr is trying to fix the syntax issues with GenericUDF since 3.5.0. The problem arose from DeferredObject currently passing a value instead of a function, which prevented users from catching exceptions in GenericUDF, resulting in semantic differences. Here is an example case we encountered. Originally, the semantics were that udf_exception would throw an exception, while udf_catch_exception could catch the exception and return a null value. However, currently, any exception encountered by udf_exception will cause the program to fail. ``` select udf_catch_exception(udf_exception(col1)) from table ``` ### Why are the changes needed? For before Spark 3.5, we directly made the GenericUDF's DeferredObject lazy and evaluated the children in `function.evaluate(deferredObjects)`. Now, we would run the children's code first. If an exception is thrown, we would make it lazy to GenericUDF's DeferredObject. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Newly added UT. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #47268 from jackylee-ch/generic_udf_catch_exception_from_child_func. Lead-authored-by: jackylee-ch <lijunqing@baidu.com> Co-authored-by: Kent Yao <yao@apache.org> Signed-off-by: Kent Yao <yao@apache.org> (cherry picked from commit 236d957) Signed-off-by: Kent Yao <yao@apache.org>
1 parent b15a872 commit 596f680

5 files changed

Lines changed: 124 additions & 10 deletions

File tree

sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFEvaluators.scala

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,11 @@ class HiveGenericUDFEvaluator(
129129
override def returnType: DataType = inspectorToDataType(returnInspector)
130130

131131
def setArg(index: Int, arg: Any): Unit =
132-
deferredObjects(index).asInstanceOf[DeferredObjectAdapter].set(arg)
132+
deferredObjects(index).asInstanceOf[DeferredObjectAdapter].set(() => arg)
133+
134+
def setException(index: Int, exp: Throwable): Unit = {
135+
deferredObjects(index).asInstanceOf[DeferredObjectAdapter].set(() => throw exp)
136+
}
133137

134138
override def doEvaluate(): Any = unwrapper(function.evaluate(deferredObjects))
135139
}
@@ -139,10 +143,10 @@ private[hive] class DeferredObjectAdapter(oi: ObjectInspector, dataType: DataTyp
139143
extends DeferredObject with HiveInspectors {
140144

141145
private val wrapper = wrapperFor(oi, dataType)
142-
private var func: Any = _
143-
def set(func: Any): Unit = {
146+
private var func: () => Any = _
147+
def set(func: () => Any): Unit = {
144148
this.func = func
145149
}
146150
override def prepare(i: Int): Unit = {}
147-
override def get(): AnyRef = wrapper(func).asInstanceOf[AnyRef]
151+
override def get(): AnyRef = wrapper(func()).asInstanceOf[AnyRef]
148152
}

sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,13 @@ private[hive] case class HiveGenericUDF(
136136

137137
override def eval(input: InternalRow): Any = {
138138
children.zipWithIndex.foreach {
139-
case (child, idx) => evaluator.setArg(idx, child.eval(input))
139+
case (child, idx) =>
140+
try {
141+
evaluator.setArg(idx, child.eval(input))
142+
} catch {
143+
case t: Throwable =>
144+
evaluator.setException(idx, t)
145+
}
140146
}
141147
evaluator.evaluate()
142148
}
@@ -157,10 +163,15 @@ private[hive] case class HiveGenericUDF(
157163
val setValues = evals.zipWithIndex.map {
158164
case (eval, i) =>
159165
s"""
160-
|if (${eval.isNull}) {
161-
| $refEvaluator.setArg($i, null);
162-
|} else {
163-
| $refEvaluator.setArg($i, ${eval.value});
166+
|try {
167+
| ${eval.code}
168+
| if (${eval.isNull}) {
169+
| $refEvaluator.setArg($i, null);
170+
| } else {
171+
| $refEvaluator.setArg($i, ${eval.value});
172+
| }
173+
|} catch (Throwable t) {
174+
| $refEvaluator.setException($i, t);
164175
|}
165176
|""".stripMargin
166177
}
@@ -169,7 +180,6 @@ private[hive] case class HiveGenericUDF(
169180
val resultTerm = ctx.freshName("result")
170181
ev.copy(code =
171182
code"""
172-
|${evals.map(_.code).mkString("\n")}
173183
|${setValues.mkString("\n")}
174184
|$resultType $resultTerm = ($resultType) $refEvaluator.evaluate();
175185
|boolean ${ev.isNull} = $resultTerm == null;
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.hive.execution;
19+
20+
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
21+
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
22+
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
23+
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
24+
25+
public class UDFCatchException extends GenericUDF {
26+
27+
@Override
28+
public ObjectInspector initialize(ObjectInspector[] args) throws UDFArgumentException {
29+
if (args.length != 1) {
30+
throw new UDFArgumentException("Exactly one argument is expected.");
31+
}
32+
return PrimitiveObjectInspectorFactory.javaStringObjectInspector;
33+
}
34+
35+
@Override
36+
public Object evaluate(GenericUDF.DeferredObject[] args) {
37+
if (args == null) {
38+
return null;
39+
}
40+
try {
41+
return args[0].get();
42+
} catch (Exception e) {
43+
return null;
44+
}
45+
}
46+
47+
@Override
48+
public String getDisplayString(String[] children) {
49+
return null;
50+
}
51+
}
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.hive.execution;
19+
20+
import org.apache.hadoop.hive.ql.exec.UDF;
21+
22+
public class UDFThrowException extends UDF {
23+
public String evaluate(String data) {
24+
return Integer.valueOf(data).toString();
25+
}
26+
}

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ import org.apache.hadoop.io.{LongWritable, Writable}
3535

3636
import org.apache.spark.{SparkException, SparkFiles, TestUtils}
3737
import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
38+
import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode
3839
import org.apache.spark.sql.catalyst.plans.logical.Project
3940
import org.apache.spark.sql.execution.WholeStageCodegenExec
4041
import org.apache.spark.sql.functions.{call_function, max}
@@ -791,6 +792,28 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
791792
}
792793
}
793794
}
795+
796+
test("SPARK-48845: GenericUDF catch exceptions from child UDFs") {
797+
withTable("test_catch_exception") {
798+
withUserDefinedFunction("udf_throw_exception" -> true, "udf_catch_exception" -> true) {
799+
Seq("9", "9-1").toDF("a").write.saveAsTable("test_catch_exception")
800+
sql("CREATE TEMPORARY FUNCTION udf_throw_exception AS " +
801+
s"'${classOf[UDFThrowException].getName}'")
802+
sql("CREATE TEMPORARY FUNCTION udf_catch_exception AS " +
803+
s"'${classOf[UDFCatchException].getName}'")
804+
Seq(
805+
CodegenObjectFactoryMode.FALLBACK.toString,
806+
CodegenObjectFactoryMode.NO_CODEGEN.toString
807+
).foreach { codegenMode =>
808+
withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenMode) {
809+
val df = sql(
810+
"SELECT udf_catch_exception(udf_throw_exception(a)) FROM test_catch_exception")
811+
checkAnswer(df, Seq(Row("9"), Row(null)))
812+
}
813+
}
814+
}
815+
}
816+
}
794817
}
795818

796819
class TestPair(x: Int, y: Int) extends Writable with Serializable {

0 commit comments

Comments
 (0)