Skip to content

Commit 6b22fea

Browse files
committed
Address comments
1 parent 2bc906d commit 6b22fea

1 file changed

Lines changed: 16 additions & 49 deletions

File tree

sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala

Lines changed: 16 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
9595
*/
9696
object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
9797

98-
private case class LazyEvalType(var evalType: Int = -1) {
98+
private case class EvalTypeHolder(var evalType: Int = -1) {
9999

100100
def isSet: Boolean = evalType >= 0
101101

@@ -120,57 +120,24 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
120120
e.find(PythonUDF.isScalarPythonUDF).isDefined
121121
}
122122

123-
/**
124-
* Check whether a PythonUDF expression can be evaluated in Python.
125-
*
126-
* If the lazy eval type is not set, this method checks for either Batched Python UDF or Scalar
127-
* Pandas UDF. If the lazy eval type is set, this method checks for the expression of the
128-
* specified eval type.
129-
*
130-
* This method will also set the lazy eval type to be the type of the first evaluable expression,
131-
* i.e., if lazy eval type is not set and we find an evaluable Python UDF expression, lazy eval
132-
* type will be set to the eval type of the expression.
133-
*
134-
*/
135-
private def canEvaluateInPython(e: PythonUDF, lazyEvalType: LazyEvalType): Boolean = {
136-
if (!lazyEvalType.isSet) {
137-
e.children match {
138-
// single PythonUDF child could be chained and evaluated in Python if eval type is the same
139-
case Seq(u: PythonUDF) =>
140-
// Need to recheck the eval type because lazy eval type will be set if child Python UDF is
141-
// evaluable
142-
canEvaluateInPython(u, lazyEvalType) && lazyEvalType.get == e.evalType
143-
// Python UDF can't be evaluated directly in JVM
144-
case children => if (!children.exists(hasScalarPythonUDF)) {
145-
// We found the first evaluable expression, set lazy eval type to its eval type.
146-
lazyEvalType.set(e.evalType)
147-
true
148-
} else {
149-
false
150-
}
151-
}
152-
} else {
153-
if (e.evalType != lazyEvalType.get) {
154-
false
155-
} else {
156-
e.children match {
157-
case Seq(u: PythonUDF) => canEvaluateInPython(u, lazyEvalType)
158-
case children => !children.exists(hasScalarPythonUDF)
159-
}
160-
}
123+
private def canEvaluateInPython(e: PythonUDF): Boolean = {
124+
e.children match {
125+
// single PythonUDF child could be chained and evaluated in Python
126+
case Seq(u: PythonUDF) => e.evalType == u.evalType && canEvaluateInPython(u)
127+
// Python UDF can't be evaluated directly in JVM
128+
case children => !children.exists(hasScalarPythonUDF)
161129
}
162130
}
163131

164132
private def collectEvaluableUDFs(
165133
expr: Expression,
166-
evalType: LazyEvalType
167-
): Seq[PythonUDF] = {
168-
expr match {
169-
case udf: PythonUDF if
170-
PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf, evalType) =>
171-
Seq(udf)
172-
case e => e.children.flatMap(collectEvaluableUDFs(_, evalType))
173-
}
134+
firstEvalType: EvalTypeHolder): Seq[PythonUDF] = expr match {
135+
case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf)
136+
&& (!firstEvalType.isSet || firstEvalType.get == udf.evalType)
137+
&& canEvaluateInPython(udf) =>
138+
firstEvalType.evalType = udf.evalType
139+
Seq(udf)
140+
case e => e.children.flatMap(collectEvaluableUDFs(_, firstEvalType))
174141
}
175142

176143
def apply(plan: SparkPlan): SparkPlan = plan transformUp {
@@ -181,8 +148,8 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
181148
* Extract all the PythonUDFs from the current operator and evaluate them before the operator.
182149
*/
183150
private def extract(plan: SparkPlan): SparkPlan = {
184-
val lazyEvalType = new LazyEvalType
185-
val udfs = plan.expressions.flatMap(collectEvaluableUDFs(_, lazyEvalType))
151+
val firstEvalType = new EvalTypeHolder
152+
val udfs = plan.expressions.flatMap(collectEvaluableUDFs(_, firstEvalType))
186153
// ignore the PythonUDF that come from second/third aggregate, which is not used
187154
.filter(udf => udf.references.subsetOf(plan.inputSet))
188155
if (udfs.isEmpty) {

0 commit comments

Comments
 (0)