@@ -95,7 +95,7 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
9595 */
9696object ExtractPythonUDFs extends Rule [SparkPlan ] with PredicateHelper {
9797
98- private case class LazyEvalType (var evalType : Int = - 1 ) {
98+ private case class EvalTypeHolder (var evalType : Int = - 1 ) {
9999
100100 def isSet : Boolean = evalType >= 0
101101
@@ -120,57 +120,24 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
120120 e.find(PythonUDF .isScalarPythonUDF).isDefined
121121 }
122122
123- /**
124- * Check whether a PythonUDF expression can be evaluated in Python.
125- *
126- * If the lazy eval type is not set, this method checks for either Batched Python UDF or Scalar
127- * Pandas UDF. If the lazy eval type is set, this method checks for the expression of the
128- * specified eval type.
129- *
130- * This method will also set the lazy eval type to be the type of the first evaluable expression,
131- * i.e., if lazy eval type is not set and we find an evaluable Python UDF expression, lazy eval
132- * type will be set to the eval type of the expression.
133- *
134- */
135- private def canEvaluateInPython (e : PythonUDF , lazyEvalType : LazyEvalType ): Boolean = {
136- if (! lazyEvalType.isSet) {
137- e.children match {
138- // single PythonUDF child could be chained and evaluated in Python if eval type is the same
139- case Seq (u : PythonUDF ) =>
140- // Need to recheck the eval type because lazy eval type will be set if child Python UDF is
141- // evaluable
142- canEvaluateInPython(u, lazyEvalType) && lazyEvalType.get == e.evalType
143- // Python UDF can't be evaluated directly in JVM
144- case children => if (! children.exists(hasScalarPythonUDF)) {
145- // We found the first evaluable expression, set lazy eval type to its eval type.
146- lazyEvalType.set(e.evalType)
147- true
148- } else {
149- false
150- }
151- }
152- } else {
153- if (e.evalType != lazyEvalType.get) {
154- false
155- } else {
156- e.children match {
157- case Seq (u : PythonUDF ) => canEvaluateInPython(u, lazyEvalType)
158- case children => ! children.exists(hasScalarPythonUDF)
159- }
160- }
123+ private def canEvaluateInPython (e : PythonUDF ): Boolean = {
124+ e.children match {
125+ // single PythonUDF child could be chained and evaluated in Python
126+ case Seq (u : PythonUDF ) => e.evalType == u.evalType && canEvaluateInPython(u)
127+ // Python UDF can't be evaluated directly in JVM
128+ case children => ! children.exists(hasScalarPythonUDF)
161129 }
162130 }
163131
164132 private def collectEvaluableUDFs (
165133 expr : Expression ,
166- evalType : LazyEvalType
167- ): Seq [PythonUDF ] = {
168- expr match {
169- case udf : PythonUDF if
170- PythonUDF .isScalarPythonUDF(udf) && canEvaluateInPython(udf, evalType) =>
171- Seq (udf)
172- case e => e.children.flatMap(collectEvaluableUDFs(_, evalType))
173- }
134+ firstEvalType : EvalTypeHolder ): Seq [PythonUDF ] = expr match {
135+ case udf : PythonUDF if PythonUDF .isScalarPythonUDF(udf)
136+ && (! firstEvalType.isSet || firstEvalType.get == udf.evalType)
137+ && canEvaluateInPython(udf) =>
138+ firstEvalType.evalType = udf.evalType
139+ Seq (udf)
140+ case e => e.children.flatMap(collectEvaluableUDFs(_, firstEvalType))
174141 }
175142
176143 def apply (plan : SparkPlan ): SparkPlan = plan transformUp {
@@ -181,8 +148,8 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
181148 * Extract all the PythonUDFs from the current operator and evaluate them before the operator.
182149 */
183150 private def extract (plan : SparkPlan ): SparkPlan = {
184- val lazyEvalType = new LazyEvalType
185- val udfs = plan.expressions.flatMap(collectEvaluableUDFs(_, lazyEvalType ))
151+ val firstEvalType = new EvalTypeHolder
152+ val udfs = plan.expressions.flatMap(collectEvaluableUDFs(_, firstEvalType ))
186153 // ignore the PythonUDF that come from second/third aggregate, which is not used
187154 .filter(udf => udf.references.subsetOf(plan.inputSet))
188155 if (udfs.isEmpty) {
0 commit comments