Skip to content

Commit d440cbf

Browse files
committed
Move ExtractPythonUDFs to end of optimize stage
1 parent 7076820 commit d440cbf

8 files changed

Lines changed: 56 additions & 33 deletions

File tree

python/pyspark/sql/tests.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3388,15 +3388,14 @@ def test_datasource_with_udf_filter_lit_input(self):
33883388
datasource_df = self.spark.read \
33893389
.format("org.apache.spark.sql.sources.SimpleScanSource") \
33903390
.option('from', 0).option('to', 1).load()
3391-
# TODO: Enable data source v2 after SPARK-25213 is fixed
3392-
# datasource_v2_df = self.spark.read \
3393-
# .format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2") \
3394-
# .load()
3391+
datasource_v2_df = self.spark.read \
3392+
.format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2") \
3393+
.load()
33953394

33963395
filter1 = udf(lambda: False, 'boolean')()
33973396
filter2 = udf(lambda x: False, 'boolean')(lit(1))
33983397

3399-
for df in [filesource_df, datasource_df]:
3398+
for df in [filesource_df, datasource_df, datasource_v2_df]:
34003399
for f in [filter1, filter2]:
34013400
result = df.filter(f)
34023401
self.assertEquals(0, result.count())
@@ -5309,7 +5308,7 @@ def f3(x):
53095308
# SPARK-24721
53105309
@unittest.skipIf(not _test_compiled, _test_not_compiled_message)
53115310
def test_datasource_with_udf_filter_lit_input(self):
5312-
# Same as SQLTests.test_datasource_with_udf_filter_lit_input, but with Pandas UDF
5311+
# Same as SQLTests.test_datasource_with_udf_filter_lit_input, but with Pantestdas UDF
53135312
# This needs to a separate test because Arrow dependency is optional
53145313
import pandas as pd
53155314
import numpy as np
@@ -5323,14 +5322,13 @@ def test_datasource_with_udf_filter_lit_input(self):
53235322
datasource_df = self.spark.read \
53245323
.format("org.apache.spark.sql.sources.SimpleScanSource") \
53255324
.option('from', 0).option('to', 1).load()
5326-
# TODO: Enable data source v2 after SPARK-25213 is fixed
5327-
# datasource_v2_df = self.spark.read \
5328-
# .format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2") \
5329-
# .load()
5325+
datasource_v2_df = self.spark.read \
5326+
.format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2") \
5327+
.load()
53305328

53315329
f = pandas_udf(lambda x: pd.Series(np.repeat(False, len(x))), 'boolean')(lit(1))
53325330

5333-
for df in [filesource_df, datasource_df]:
5331+
for df in [filesource_df, datasource_df, datasource_v2_df]:
53345332
result = df.filter(f)
53355333
self.assertEquals(0, result.count())
53365334

sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,6 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
8989

9090
/** A sequence of rules that will be applied in order to the physical plan before execution. */
9191
protected def preparations: Seq[Rule[SparkPlan]] = Seq(
92-
python.ExtractPythonUDFs,
9392
PlanSubqueries(sparkSession),
9493
EnsureRequirements(sparkSession.sessionState.conf),
9594
CollapseCodegenStages(sparkSession.sessionState.conf),

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.catalog.SessionCatalog
2222
import org.apache.spark.sql.catalyst.optimizer.Optimizer
2323
import org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions
2424
import org.apache.spark.sql.execution.datasources.parquet.ParquetSchemaPruning
25-
import org.apache.spark.sql.execution.python.ExtractPythonUDFFromAggregate
25+
import org.apache.spark.sql.execution.python.{ExtractPythonUDFFromAggregate, ExtractPythonUDFs}
2626

2727
class SparkOptimizer(
2828
catalog: SessionCatalog,
@@ -31,7 +31,8 @@ class SparkOptimizer(
3131

3232
override def defaultBatches: Seq[Batch] = (preOptimizationBatches ++ super.defaultBatches :+
3333
Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+
34-
Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate) :+
34+
Batch("Extract Python UDFs", Once,
35+
Seq(ExtractPythonUDFFromAggregate, ExtractPythonUDFs): _*) :+
3536
Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions) :+
3637
Batch("Parquet Schema Pruning", Once, ParquetSchemaPruning)) ++
3738
postHocOptimizationBatches :+

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ class SparkPlanner(
3636
override def strategies: Seq[Strategy] =
3737
experimentalMethods.extraStrategies ++
3838
extraPlanningStrategies ++ (
39+
PythonEvals ::
3940
DataSourceV2Strategy ::
4041
FileSourceStrategy ::
4142
DataSourceStrategy(conf) ::

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableS
3232
import org.apache.spark.sql.execution.command._
3333
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
3434
import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide}
35+
import org.apache.spark.sql.execution.python._
3536
import org.apache.spark.sql.execution.streaming._
3637
import org.apache.spark.sql.execution.streaming.sources.MemoryPlanV2
3738
import org.apache.spark.sql.internal.SQLConf
@@ -517,6 +518,20 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
517518
}
518519
}
519520

521+
/**
522+
* Strategy to convert EvalPython logical operator to physical operator.
523+
*/
524+
object PythonEvals extends Strategy {
525+
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
526+
case ArrowEvalPython(udfs, output, child) =>
527+
ArrowEvalPythonExec(udfs, output, planLater(child)) :: Nil
528+
case BatchEvalPython(udfs, output, child) =>
529+
BatchEvalPythonExec(udfs, output, planLater(child)) :: Nil
530+
case _ =>
531+
Nil
532+
}
533+
}
534+
520535
object BasicOperators extends Strategy {
521536
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
522537
case d: DataWritingCommand => DataWritingCommandExec(d, planLater(d.query)) :: Nil

sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import org.apache.spark.TaskContext
2323
import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
2424
import org.apache.spark.sql.catalyst.InternalRow
2525
import org.apache.spark.sql.catalyst.expressions._
26+
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode}
2627
import org.apache.spark.sql.execution.SparkPlan
2728
import org.apache.spark.sql.execution.arrow.ArrowUtils
2829
import org.apache.spark.sql.types.StructType
@@ -57,7 +58,13 @@ private class BatchIterator[T](iter: Iterator[T], batchSize: Int)
5758
}
5859

5960
/**
60-
* A physical plan that evaluates a [[PythonUDF]],
61+
* A logical plan that evaluates a [[PythonUDF]].
62+
*/
63+
case class ArrowEvalPython(udfs: Seq[PythonUDF], output: Seq[Attribute], child: LogicalPlan)
64+
extends UnaryNode
65+
66+
/**
67+
* A physical plan that evaluates a [[PythonUDF]].
6168
*/
6269
case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)
6370
extends EvalPythonExec(udfs, output, child) {

sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,16 @@ import org.apache.spark.TaskContext
2525
import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
2626
import org.apache.spark.sql.catalyst.InternalRow
2727
import org.apache.spark.sql.catalyst.expressions._
28+
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode}
2829
import org.apache.spark.sql.execution.SparkPlan
2930
import org.apache.spark.sql.types.{StructField, StructType}
3031

32+
/**
33+
* A logical plan that evaluates a [[PythonUDF]]
34+
*/
35+
case class BatchEvalPython(udfs: Seq[PythonUDF], output: Seq[Attribute], child: LogicalPlan)
36+
extends UnaryNode
37+
3138
/**
3239
* A physical plan that evaluates a [[PythonUDF]]
3340
*/

sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,8 @@ import org.apache.spark.api.python.PythonEvalType
2424
import org.apache.spark.sql.AnalysisException
2525
import org.apache.spark.sql.catalyst.expressions._
2626
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
27-
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
27+
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Project}
2828
import org.apache.spark.sql.catalyst.rules.Rule
29-
import org.apache.spark.sql.execution._
30-
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec
3129

3230

3331
/**
@@ -94,7 +92,7 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
9492
* This has the limitation that the input to the Python UDF is not allowed include attributes from
9593
* multiple child operators.
9694
*/
97-
object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
95+
object ExtractPythonUDFs extends Rule[LogicalPlan] with PredicateHelper {
9896

9997
private type EvalType = Int
10098
private type EvalTypeChecker = EvalType => Boolean
@@ -133,17 +131,14 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
133131
expressions.flatMap(collectEvaluableUDFs)
134132
}
135133

136-
def apply(plan: SparkPlan): SparkPlan = plan transformUp {
137-
// SPARK-24721: Ignore Python UDFs in DataSourceScan and DataSourceV2Scan
138-
case plan: DataSourceScanExec => plan
139-
case plan: DataSourceV2ScanExec => plan
140-
case plan: SparkPlan => extract(plan)
134+
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
135+
case plan: LogicalPlan => extract(plan)
141136
}
142137

143138
/**
144139
* Extract all the PythonUDFs from the current operator and evaluate them before the operator.
145140
*/
146-
private def extract(plan: SparkPlan): SparkPlan = {
141+
private def extract(plan: LogicalPlan): LogicalPlan = {
147142
val udfs = collectEvaluableUDFsFromExpressions(plan.expressions)
148143
// ignore the PythonUDF that come from second/third aggregate, which is not used
149144
.filter(udf => udf.references.subsetOf(plan.inputSet))
@@ -155,7 +150,7 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
155150
val prunedChildren = plan.children.map { child =>
156151
val allNeededOutput = inputsForPlan.intersect(child.outputSet).toSeq
157152
if (allNeededOutput.length != child.output.length) {
158-
ProjectExec(allNeededOutput, child)
153+
Project(allNeededOutput, child)
159154
} else {
160155
child
161156
}
@@ -184,9 +179,9 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
184179
_.evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF
185180
) match {
186181
case (vectorizedUdfs, plainUdfs) if plainUdfs.isEmpty =>
187-
ArrowEvalPythonExec(vectorizedUdfs, child.output ++ resultAttrs, child)
182+
ArrowEvalPython(vectorizedUdfs, child.output ++ resultAttrs, child)
188183
case (vectorizedUdfs, plainUdfs) if vectorizedUdfs.isEmpty =>
189-
BatchEvalPythonExec(plainUdfs, child.output ++ resultAttrs, child)
184+
BatchEvalPython(plainUdfs, child.output ++ resultAttrs, child)
190185
case _ =>
191186
throw new AnalysisException(
192187
"Expected either Scalar Pandas UDFs or Batched UDFs but got both")
@@ -213,7 +208,7 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
213208
val newPlan = extract(rewritten)
214209
if (newPlan.output != plan.output) {
215210
// Trim away the new UDF value if it was only used for filtering or something.
216-
ProjectExec(plan.output, newPlan)
211+
Project(plan.output, newPlan)
217212
} else {
218213
newPlan
219214
}
@@ -222,15 +217,15 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
222217

223218
// Split the original FilterExec to two FilterExecs. Only push down the first few predicates
224219
// that are all deterministic.
225-
private def trySplitFilter(plan: SparkPlan): SparkPlan = {
220+
private def trySplitFilter(plan: LogicalPlan): LogicalPlan = {
226221
plan match {
227-
case filter: FilterExec =>
222+
case filter: Filter =>
228223
val (candidates, nonDeterministic) =
229224
splitConjunctivePredicates(filter.condition).partition(_.deterministic)
230225
val (pushDown, rest) = candidates.partition(!hasScalarPythonUDF(_))
231226
if (pushDown.nonEmpty) {
232-
val newChild = FilterExec(pushDown.reduceLeft(And), filter.child)
233-
FilterExec((rest ++ nonDeterministic).reduceLeft(And), newChild)
227+
val newChild = Filter(pushDown.reduceLeft(And), filter.child)
228+
Filter((rest ++ nonDeterministic).reduceLeft(And), newChild)
234229
} else {
235230
filter
236231
}

0 commit comments

Comments
 (0)