python/pyspark/sql/tests.py (157 additions, 11 deletions)
@@ -4763,17 +4763,6 @@ def test_vectorized_udf_invalid_length(self):
'Result vector from pandas_udf was not the required length'):
df.select(raise_exception(col('id'))).collect()

def test_vectorized_udf_mix_udf(self):
from pyspark.sql.functions import pandas_udf, udf, col
df = self.spark.range(10)
row_by_row_udf = udf(lambda x: x, LongType())
pd_udf = pandas_udf(lambda x: x, LongType())
with QuietTest(self.sc):
with self.assertRaisesRegexp(
Exception,
'Can not mix vectorized and non-vectorized UDFs'):
df.select(row_by_row_udf(col('id')), pd_udf(col('id'))).collect()

def test_vectorized_udf_chained(self):
from pyspark.sql.functions import pandas_udf, col
df = self.spark.range(10)
@@ -5060,6 +5049,147 @@ def test_type_annotation(self):
df = self.spark.range(1).select(pandas_udf(f=_locals['noop'], returnType='bigint')('id'))
self.assertEqual(df.first()[0], 0)

def test_mixed_udf(self):
import pandas as pd
from pyspark.sql.functions import col, udf, pandas_udf

df = self.spark.range(0, 1).toDF('v')

# Test mixture of multiple UDFs and Pandas UDFs

@udf('int')
def f1(x):
assert type(x) == int
return x + 1

@pandas_udf('int')
def f2(x):
assert type(x) == pd.Series
return x + 10

@udf('int')
def f3(x):
assert type(x) == int
return x + 100

@pandas_udf('int')
def f4(x):
assert type(x) == pd.Series
return x + 1000

# Test mixed udfs in a single projection
df1 = df \
.withColumn('f1', f1(col('v'))) \
.withColumn('f2', f2(col('v'))) \
.withColumn('f3', f3(col('v'))) \
.withColumn('f4', f4(col('v'))) \
.withColumn('f2_f1', f2(col('f1'))) \
.withColumn('f3_f1', f3(col('f1'))) \
Member: This looks like it's testing udf + udf.

Contributor Author: Yeah, the way the test is written, I'm trying to test many combinations, so some of them might not be mixed UDFs. Do you prefer that I remove these cases?

.withColumn('f4_f1', f4(col('f1'))) \
.withColumn('f3_f2', f3(col('f2'))) \
.withColumn('f4_f2', f4(col('f2'))) \
.withColumn('f4_f3', f4(col('f3'))) \
.withColumn('f3_f2_f1', f3(col('f2_f1'))) \
.withColumn('f4_f2_f1', f4(col('f2_f1'))) \
.withColumn('f4_f3_f1', f4(col('f3_f1'))) \
.withColumn('f4_f3_f2', f4(col('f3_f2'))) \
.withColumn('f4_f3_f2_f1', f4(col('f3_f2_f1')))

# Test mixed udfs in a single expression
df2 = df \
.withColumn('f1', f1(col('v'))) \
.withColumn('f2', f2(col('v'))) \
.withColumn('f3', f3(col('v'))) \
.withColumn('f4', f4(col('v'))) \
.withColumn('f2_f1', f2(f1(col('v')))) \
.withColumn('f3_f1', f3(f1(col('v')))) \
.withColumn('f4_f1', f4(f1(col('v')))) \
.withColumn('f3_f2', f3(f2(col('v')))) \
.withColumn('f4_f2', f4(f2(col('v')))) \
.withColumn('f4_f3', f4(f3(col('v')))) \
.withColumn('f3_f2_f1', f3(f2(f1(col('v'))))) \
.withColumn('f4_f2_f1', f4(f2(f1(col('v'))))) \
.withColumn('f4_f3_f1', f4(f3(f1(col('v'))))) \
.withColumn('f4_f3_f2', f4(f3(f2(col('v'))))) \
.withColumn('f4_f3_f2_f1', f4(f3(f2(f1(col('v'))))))

# expected result
df3 = df \
.withColumn('f1', df['v'] + 1) \
.withColumn('f2', df['v'] + 10) \
.withColumn('f3', df['v'] + 100) \
.withColumn('f4', df['v'] + 1000) \
.withColumn('f2_f1', df['v'] + 11) \
.withColumn('f3_f1', df['v'] + 101) \
.withColumn('f4_f1', df['v'] + 1001) \
.withColumn('f3_f2', df['v'] + 110) \
.withColumn('f4_f2', df['v'] + 1010) \
.withColumn('f4_f3', df['v'] + 1100) \
.withColumn('f3_f2_f1', df['v'] + 111) \
.withColumn('f4_f2_f1', df['v'] + 1011) \
.withColumn('f4_f3_f1', df['v'] + 1101) \
.withColumn('f4_f3_f2', df['v'] + 1110) \
.withColumn('f4_f3_f2_f1', df['v'] + 1111)

self.assertEquals(df3.collect(), df1.collect())
self.assertEquals(df3.collect(), df2.collect())

def test_mixed_udf_and_sql(self):
import pandas as pd
from pyspark.sql.functions import udf, pandas_udf

df = self.spark.range(0, 1).toDF('v')

# Test mixture of UDFs, Pandas UDFs and SQL expression.

@udf('int')
def f1(x):
assert type(x) == int
return x + 1

def f2(x):
Member: Seems like this is neither @udf nor @pandas_udf; is that on purpose? If so, could you add a comment to explain why?

Contributor Author: Yes, the purpose is to test mixing udf, pandas_udf, and SQL expressions. I will add comments to make it clearer.

Contributor Author: Added comments in the test.

Member: Ah, I see why it looks confusing. Can we add an assert here too (check that it's a Column)?

Contributor Author: Added.

return x + 10

@pandas_udf('int')
def f3(x):
assert type(x) == pd.Series
return x + 100

df1 = df.withColumn('f1', f1(df['v'])) \
.withColumn('f2', f2(df['v'])) \
.withColumn('f3', f3(df['v'])) \
.withColumn('f1_f2', f1(f2(df['v']))) \
.withColumn('f1_f3', f1(f3(df['v']))) \
.withColumn('f2_f1', f2(f1(df['v']))) \
.withColumn('f2_f3', f2(f3(df['v']))) \
.withColumn('f3_f1', f3(f1(df['v']))) \
Member: The combinations of f1 and f3 look like they duplicate a few tests in test_mixed_udf, for instance f4_f3.

Contributor Author: Yeah, the way the test is written, I'm trying to test many combinations, so there are some duplicate cases. Do you prefer that I remove these?

Member: Yeah.. I know it's still minor since the elapsed time will be virtually the same, but the build/test time was an issue recently, and I wonder if there's a better way than avoiding duplicated tests for now..

Member: It was discussed here: #21845

Contributor Author: I see. I don't think it's necessary (we would only remove a few cases and, like you said, the test time is virtually the same), and keeping them helps the readability of the tests (so it doesn't look like some cases were missed). But if that's the preferred practice, I can remove the duplicate cases in the next commit.

Member: I am okay with leaving it here too, since it's clear they are virtually the same, but let's remove duplicated tests or keep tests orthogonal next time.

Contributor Author: Gotcha. I will keep that in mind next time.

.withColumn('f3_f2', f3(f2(df['v']))) \
.withColumn('f1_f2_f3', f1(f2(f3(df['v'])))) \
.withColumn('f1_f3_f2', f1(f3(f2(df['v'])))) \
.withColumn('f2_f1_f3', f2(f1(f3(df['v'])))) \
.withColumn('f2_f3_f1', f2(f3(f1(df['v'])))) \
.withColumn('f3_f1_f2', f3(f1(f2(df['v'])))) \
.withColumn('f3_f2_f1', f3(f2(f1(df['v']))))

# expected result
df2 = df.withColumn('f1', df['v'] + 1) \
.withColumn('f2', df['v'] + 10) \
.withColumn('f3', df['v'] + 100) \
.withColumn('f1_f2', df['v'] + 11) \
.withColumn('f1_f3', df['v'] + 101) \
.withColumn('f2_f1', df['v'] + 11) \
.withColumn('f2_f3', df['v'] + 110) \
.withColumn('f3_f1', df['v'] + 101) \
.withColumn('f3_f2', df['v'] + 110) \
.withColumn('f1_f2_f3', df['v'] + 111) \
.withColumn('f1_f3_f2', df['v'] + 111) \
.withColumn('f2_f1_f3', df['v'] + 111) \
.withColumn('f2_f3_f1', df['v'] + 111) \
.withColumn('f3_f1_f2', df['v'] + 111) \
.withColumn('f3_f2_f1', df['v'] + 111)

self.assertEquals(df2.collect(), df1.collect())
Member: I think it would be better to combine this test with the one above and construct it as a list of cases that you can loop over, instead of so many blocks of withColumn. Something like

class TestCase():
    def __init__(self, col_name, col_expected, col_projection, col_udf_expression, col_sql_expression):
        ...

cases = [
    TestCase('f4_f3_f2_f1', df['v'] + 1111, f4(df1['f3_f2_f1']), f4(f3(f2(f1(df['v'])))), f4(f3(f1(df['v']) + 10))),
    ...]

expected_df = df

for case in cases:
    expected_df = expected_df.withColumn(case.col_name, case.col_expected)
    ...

self.assertEquals(expected_df.collect(), projection_df.collect())

Contributor Author: Sorry, could you please elaborate a bit? e.g. in

TestCase('f4_f3_f2_f1', df['v'] + 1111, f4(df1['f3_f2_f1']), f4(f3(f2(f1(df['v'])))), f4(f3(f1(df['v']) + 10)))

how is df1['f3_f2_f1'] defined in this test case?

Contributor Author: I chained withColumn together instead of reassigning DataFrames. How does it look now?



@unittest.skipIf(
not _have_pandas or not _have_pyarrow,
@@ -5487,6 +5617,22 @@ def dummy_pandas_udf(df):
F.col('temp0.key') == F.col('temp1.key'))
self.assertEquals(res.count(), 5)

def test_mixed_scalar_udfs_followed_by_grouby_apply(self):
# Test Pandas UDF and scalar Python UDF followed by groupby apply
from pyspark.sql.functions import udf, pandas_udf, PandasUDFType
import pandas as pd
Member: Not a big deal at all really.. but I would swap the import order (third-party, then pyspark).


df = self.spark.range(0, 10).toDF('v1')
df = df.withColumn('v2', udf(lambda x: x + 1, 'int')(df['v1'])) \
.withColumn('v3', pandas_udf(lambda x: x + 2, 'int')(df['v1']))

result = df.groupby() \
.apply(pandas_udf(lambda x: pd.DataFrame([x.sum().sum()]),
'sum int',
PandasUDFType.GROUPED_MAP))

self.assertEquals(result.collect()[0]['sum'], 165)


@unittest.skipIf(
not _have_pandas or not _have_pyarrow,
ExtractPythonUDFs.scala
@@ -21,6 +21,7 @@ import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.api.python.PythonEvalType
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
@@ -94,36 +95,59 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
*/
object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {

private def hasPythonUDF(e: Expression): Boolean = {
private def hasScalarPythonUDF(e: Expression): Boolean = {
e.find(PythonUDF.isScalarPythonUDF).isDefined
}

private def canEvaluateInPython(e: PythonUDF): Boolean = {
e.children match {
// single PythonUDF child could be chained and evaluated in Python
case Seq(u: PythonUDF) => canEvaluateInPython(u)
// Python UDF can't be evaluated directly in JVM
case children => !children.exists(hasPythonUDF)
private def canEvaluateInPython(e: PythonUDF, evalType: Int): Boolean = {
if (e.evalType != evalType) {
Member: Can we rename this function or write a comment? Both scalar vectorized UDFs and normal UDFs can each be evaluated in Python, yet this returns false when the eval type doesn't match.

false
} else {
e.children match {
// single PythonUDF child could be chained and evaluated in Python
case Seq(u: PythonUDF) => canEvaluateInPython(u, evalType)
// Python UDF can't be evaluated directly in JVM
case children => !children.exists(hasScalarPythonUDF)
}
}
}

private def collectEvaluatableUDF(expr: Expression): Seq[PythonUDF] = expr match {
case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf) => Seq(udf)
case e => e.children.flatMap(collectEvaluatableUDF)
private def collectEvaluableUDF(expr: Expression, evalType: Int): Seq[PythonUDF] = expr match {
Member: It's a little confusing to have this function named so similarly to the one below; maybe you can combine them by just doing a single loop (see the other comment).

case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf, evalType) =>
Seq(udf)
case e => e.children.flatMap(collectEvaluableUDF(_, evalType))
}

/**
* Collect evaluable UDFs from the current node.
*
* This function collects Python UDFs or Scalar Python UDFs from expressions of the input node,
* and returns a list of UDFs of the same eval type.
Member: What happens if the user tries to mix in a non-scalar UDF?

Contributor Author: Hmm.. it currently throws an exception in the codegen stage (because non-scalar UDFs are not extracted by this rule). We should probably throw a better exception, but I need to think a bit about how to do it.

Contributor Author: I tried this on master and got the same exception:

>>> foo = pandas_udf(lambda x: x, 'v int', PandasUDFType.GROUPED_MAP)
>>> df.select(foo(df['v'])).show()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/icexelloss/workspace/upstream/spark/python/pyspark/sql/dataframe.py", line 353, in show
    print(self._jdf.showString(n, 20, vertical))
  File "/Users/icexelloss/workspace/upstream/spark/python/lib/py4j-0.10.7-src.zip/py4j/java_gateway.py", line 1257, in __call__
  File "/Users/icexelloss/workspace/upstream/spark/python/pyspark/sql/utils.py", line 63, in deco
    return f(*a, **kw)
  File "/Users/icexelloss/workspace/upstream/spark/python/lib/py4j-0.10.7-src.zip/py4j/protocol.py", line 328, in get_return_value
py4j.protocol.Py4JJavaError: An error occurred while calling o257.showString.
: java.lang.UnsupportedOperationException: Cannot evaluate expression: <lambda>(input[0, bigint, false])
	at org.apache.spark.sql.catalyst.expressions.Unevaluable$class.doGenCode(Expression.scala:261)
	at org.apache.spark.sql.catalyst.expressions.PythonUDF.doGenCode(PythonUDF.scala:50)
	at org.apache.spark.sql.catalyst.expressions.Expression$$anonfun$genCode$2.apply(Expression.scala:108)
	at org.apache.spark.sql.catalyst.expressions.Expression$$anonfun$genCode$2.apply(Expression.scala:105)
	at scala.Option.getOrElse(Option.scala:121)
        ...

So this PR doesn't change that behavior: neither master nor this PR extracts non-scalar UDFs from the expression.

Member: Yeah, that's not a very informative exception, but we can fix that later. I made https://issues.apache.org/jira/browse/SPARK-24735 to track it.

*
* If expressions contain both UDFs eval types, this function will only return Python UDFs.
*
* The caller should call this function multiple times until all evaluable UDFs are collected.
Member: So this will pipeline UDFs of the same eval type so that they can be processed together in the same call to the Python worker? For example, if we have pandas_udf, pandas_udf, udf, udf, then both pandas_udfs will be sent to the worker together, then both udfs together, and the Python runner gets executed twice. On the other hand, if we have pandas_udf, udf, pandas_udf, udf, then each one has to be executed on its own, and the Python runner gets executed 4 times. Is that right?

Contributor Author: That's correct.

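To make the two cases above concrete, here is a hedged sketch in the style of the ExtractPythonUDFsSuite added further down in this PR. It assumes that suite's batchedPythonUDF / scalarPandasUDF dummy UDFs and its collectBatchExec / collectArrowExec helpers; the column names are illustrative only.

// Sketch only: assumes the test fixtures from ExtractPythonUDFsSuite below.
val df = Seq(("Hello", 4)).toDF("a", "b")

// pandas_udf, pandas_udf, udf, udf: the two Pandas UDFs are chained into a single
// ArrowEvalPythonExec and the two plain UDFs into a single BatchEvalPythonExec,
// so the Python runner is invoked twice.
val grouped = df.withColumn("p1", scalarPandasUDF(col("a")))
  .withColumn("p2", scalarPandasUDF(col("p1")))
  .withColumn("b1", batchedPythonUDF(col("p2")))
  .withColumn("b2", batchedPythonUDF(col("b1")))
assert(collectArrowExec(grouped.queryExecution.executedPlan).size == 1)
assert(collectBatchExec(grouped.queryExecution.executedPlan).size == 1)

// pandas_udf, udf, pandas_udf, udf, each depending on the previous one: the eval type
// changes at every step, so each UDF gets its own eval node and the Python runner is
// invoked four times.
val interleaved = df.withColumn("p1", scalarPandasUDF(col("a")))
  .withColumn("b1", batchedPythonUDF(col("p1")))
  .withColumn("p2", scalarPandasUDF(col("b1")))
  .withColumn("b2", batchedPythonUDF(col("p2")))
assert(collectArrowExec(interleaved.queryExecution.executedPlan).size == 2)
assert(collectBatchExec(interleaved.queryExecution.executedPlan).size == 2)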
*/
private def collectEvaluableUDFs(plan: SparkPlan): Seq[PythonUDF] = {
val pythonUDFs =
plan.expressions.flatMap(collectEvaluableUDF(_, PythonEvalType.SQL_BATCHED_UDF))

if (pythonUDFs.isEmpty) {
plan.expressions.flatMap(collectEvaluableUDF(_, PythonEvalType.SQL_SCALAR_PANDAS_UDF))
} else {
pythonUDFs
Member: I think it would be better to loop through the expressions, find the first scalar Python UDF (either SQL_BATCHED_UDF or SQL_SCALAR_PANDAS_UDF), and then collect the rest of that type. That is really what is happening here, so I think it would be more straightforward to do it in a single loop instead of two flatMaps.

Contributor Author (icexelloss, Jul 23, 2018): What you said makes sense, and that was actually my first attempt, but it ended up being pretty complicated. The issue is that it is hard to find the UDFs in a single traversal of the expression tree, because we need to pass the evalType down to every subtree, and the result of one subtree can affect the result of another (i.e., if we find one type of UDF in one subtree, we need to pass that type to all other subtrees, because they must agree on the evalType). Because the code is recursive by nature, this makes it pretty complicated to pass the correct evalType everywhere. Another way is to do two traversals: in the first we look for the eval type, and in the second we collect UDFs of that eval type. But that isn't much different from what I have now in terms of efficiency, and I find the current logic simpler and less likely to have bugs. I actually tried these approaches and found the current way the easiest to implement and the least likely to have bugs. WDYT?

}
}
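For reference, here is a rough sketch (not the PR's implementation) of the two-traversal alternative mentioned above: the first traversal picks an eval type, the second collects UDFs of that type. The helper name firstScalarUDFEvalType is hypothetical, and unlike the PR's version this picks whichever eval type it encounters first rather than always trying plain Python UDFs first.

// Hedged sketch of the alternative discussed above; not the code in this PR.
private def firstScalarUDFEvalType(plan: SparkPlan): Option[Int] = {
  plan.expressions.flatMap(_.collect {
    case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) => udf.evalType
  }).headOption
}

private def collectEvaluableUDFsAlternative(plan: SparkPlan): Seq[PythonUDF] = {
  firstScalarUDFEvalType(plan) match {
    // Collect only UDFs of the eval type found first. Note the subtlety the author
    // mentions: if the first UDF found is, say, a Pandas UDF wrapping a plain UDF,
    // this pass collects nothing, whereas trying SQL_BATCHED_UDF first (as the PR
    // does) would still make progress.
    case Some(evalType) => plan.expressions.flatMap(collectEvaluableUDF(_, evalType))
    case None => Nil
  }
}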

def apply(plan: SparkPlan): SparkPlan = plan transformUp {
// AggregateInPandasExec and FlatMapGroupsInPandas can be evaluated directly in python worker
// Therefore we don't need to extract the UDFs
case plan: FlatMapGroupsInPandasExec => plan
Contributor Author: This is no longer needed, because this rule now only extracts Python UDFs and Scalar Pandas UDFs and ignores other UDF types.

case plan: SparkPlan => extract(plan)
}

/**
* Extract all the PythonUDFs from the current operator and evaluate them before the operator.
*/
private def extract(plan: SparkPlan): SparkPlan = {
val udfs = plan.expressions.flatMap(collectEvaluatableUDF)
val udfs = collectEvaluableUDFs(plan)
// ignore the PythonUDF that come from second/third aggregate, which is not used
.filter(udf => udf.references.subsetOf(plan.inputSet))
if (udfs.isEmpty) {
Expand Down Expand Up @@ -167,7 +191,8 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
case (vectorizedUdfs, plainUdfs) if vectorizedUdfs.isEmpty =>
BatchEvalPythonExec(plainUdfs, child.output ++ resultAttrs, child)
case _ =>
throw new IllegalArgumentException("Can not mix vectorized and non-vectorized UDFs")
throw new AnalysisException(
Member: Why change the exception type? Can you add a test that triggers this?

Contributor Author: This is because we shouldn't reach here (otherwise it's a bug). I don't know what the best exception type is here, though.

"Expected either Scalar Pandas UDFs or Batched UDFs but got both")
}

attributeMap ++= validUdfs.zip(resultAttrs)
@@ -205,7 +230,7 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
case filter: FilterExec =>
val (candidates, nonDeterministic) =
splitConjunctivePredicates(filter.condition).partition(_.deterministic)
val (pushDown, rest) = candidates.partition(!hasPythonUDF(_))
val (pushDown, rest) = candidates.partition(!hasScalarPythonUDF(_))
if (pushDown.nonEmpty) {
val newChild = FilterExec(pushDown.reduceLeft(And), filter.child)
FilterExec((rest ++ nonDeterministic).reduceLeft(And), newChild)
@@ -115,3 +115,10 @@ class MyDummyPythonUDF extends UserDefinedPythonFunction(
dataType = BooleanType,
pythonEvalType = PythonEvalType.SQL_BATCHED_UDF,
udfDeterministic = true)

class MyDummyScalarPandasUDF extends UserDefinedPythonFunction(
name = "dummyScalarPandasUDF",
func = new DummyUDF,
dataType = BooleanType,
pythonEvalType = PythonEvalType.SQL_SCALAR_PANDAS_UDF,
udfDeterministic = true)
ExtractPythonUDFsSuite.scala (new file)
@@ -0,0 +1,93 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.python

import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.test.SharedSQLContext

class ExtractPythonUDFsSuite extends SparkPlanTest with SharedSQLContext {
import testImplicits.newProductEncoder
import testImplicits.localSeqToDatasetHolder

val batchedPythonUDF = new MyDummyPythonUDF
val scalarPandasUDF = new MyDummyScalarPandasUDF

private def collectBatchExec(plan: SparkPlan): Seq[BatchEvalPythonExec] = plan.collect {
case b: BatchEvalPythonExec => b
}

private def collectArrowExec(plan: SparkPlan): Seq[ArrowEvalPythonExec] = plan.collect {
case b: ArrowEvalPythonExec => b
}

test("Chained Batched Python UDFs should be combined to a single physical node") {
val df = Seq(("Hello", 4)).toDF("a", "b")
val df2 = df.withColumn("c", batchedPythonUDF(col("a")))
.withColumn("d", batchedPythonUDF(col("c")))
val pythonEvalNodes = collectBatchExec(df2.queryExecution.executedPlan)
assert(pythonEvalNodes.size == 1)
}

test("Chained Scalar Pandas UDFs should be combined to a single physical node") {
val df = Seq(("Hello", 4)).toDF("a", "b")
val df2 = df.withColumn("c", scalarPandasUDF(col("a")))
.withColumn("d", scalarPandasUDF(col("c")))
val arrowEvalNodes = collectArrowExec(df2.queryExecution.executedPlan)
assert(arrowEvalNodes.size == 1)
}

test("Mixed Batched Python UDFs and Pandas UDF should be separate physical node") {
val df = Seq(("Hello", 4)).toDF("a", "b")
val df2 = df.withColumn("c", batchedPythonUDF(col("a")))
.withColumn("d", scalarPandasUDF(col("b")))

val pythonEvalNodes = collectBatchExec(df2.queryExecution.executedPlan)
val arrowEvalNodes = collectArrowExec(df2.queryExecution.executedPlan)
assert(pythonEvalNodes.size == 1)
assert(arrowEvalNodes.size == 1)
}

test("Independent Batched Python UDFs and Scalar Pandas UDFs should be combined separately") {
val df = Seq(("Hello", 4)).toDF("a", "b")
val df2 = df.withColumn("c1", batchedPythonUDF(col("a")))
.withColumn("c2", batchedPythonUDF(col("c1")))
.withColumn("d1", scalarPandasUDF(col("a")))
.withColumn("d2", scalarPandasUDF(col("d1")))

val pythonEvalNodes = collectBatchExec(df2.queryExecution.executedPlan)
val arrowEvalNodes = collectArrowExec(df2.queryExecution.executedPlan)
assert(pythonEvalNodes.size == 1)
assert(arrowEvalNodes.size == 1)
}

test("Dependent Batched Python UDFs and Scalar Pandas UDFs should not be combined") {
val df = Seq(("Hello", 4)).toDF("a", "b")
val df2 = df.withColumn("c1", batchedPythonUDF(col("a")))
.withColumn("d1", scalarPandasUDF(col("c1")))
.withColumn("c2", batchedPythonUDF(col("d1")))
.withColumn("d2", scalarPandasUDF(col("c2")))

val pythonEvalNodes = collectBatchExec(df2.queryExecution.executedPlan)
val arrowEvalNodes = collectArrowExec(df2.queryExecution.executedPlan)
assert(pythonEvalNodes.size == 2)
assert(arrowEvalNodes.size == 2)
}
}