Skip to content

Commit b94fa97

Browse files
viiryaHyukjinKwon
authored andcommitted
[SPARK-28345][SQL][PYTHON] PythonUDF predicate should be able to pushdown to join
## What changes were proposed in this pull request?

A `Filter` predicate using `PythonUDF` currently can't be pushed down into a join condition. A predicate like that should be able to be pushed down to the join condition. For `PythonUDF`s that can't be evaluated in a join condition, `PullOutPythonUDFInJoinCondition` will pull them out later.

An example:

```scala
val pythonTestUDF = TestPythonUDF(name = "udf")
val left = Seq((1, 2), (2, 3)).toDF("a", "b")
val right = Seq((1, 2), (3, 4)).toDF("c", "d")
val df = left.crossJoin(right).where(pythonTestUDF($"a") === pythonTestUDF($"c"))
```

Query plan before the PR:

```
== Physical Plan ==
*(3) Project [a#2121, b#2122, c#2132, d#2133]
+- *(3) Filter (pythonUDF0#2142 = pythonUDF1#2143)
   +- BatchEvalPython [udf(a#2121), udf(c#2132)], [pythonUDF0#2142, pythonUDF1#2143]
      +- BroadcastNestedLoopJoin BuildRight, Cross
         :- *(1) Project [_1#2116 AS a#2121, _2#2117 AS b#2122]
         :  +- LocalTableScan [_1#2116, _2#2117]
         +- BroadcastExchange IdentityBroadcastMode
            +- *(2) Project [_1#2127 AS c#2132, _2#2128 AS d#2133]
               +- LocalTableScan [_1#2127, _2#2128]
```

Query plan after the PR:

```
== Physical Plan ==
*(3) Project [a#2121, b#2122, c#2132, d#2133]
+- *(3) BroadcastHashJoin [pythonUDF0#2142], [pythonUDF0#2143], Cross, BuildRight
   :- BatchEvalPython [udf(a#2121)], [pythonUDF0#2142]
   :  +- *(1) Project [_1#2116 AS a#2121, _2#2117 AS b#2122]
   :     +- LocalTableScan [_1#2116, _2#2117]
   +- BroadcastExchange HashedRelationBroadcastMode(List(input[2, string, true]))
      +- BatchEvalPython [udf(c#2132)], [pythonUDF0#2143]
         +- *(2) Project [_1#2127 AS c#2132, _2#2128 AS d#2133]
            +- LocalTableScan [_1#2127, _2#2128]
```

After this PR, the join can use `BroadcastHashJoin` instead of `BroadcastNestedLoopJoin`.

## How was this patch tested?

Added tests.

Closes #25106 from viirya/pythonudf-join-condition.

Authored-by: Liang-Chi Hsieh <viirya@gmail.com>
Signed-off-by: HyukjinKwon <gurwls223@apache.org>
1 parent 8e26d4d commit b94fa97

3 files changed

Lines changed: 59 additions & 4 deletions

File tree

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,10 @@ trait PredicateHelper {
116116
// non-correlated subquery will be replaced as literal
117117
e.children.isEmpty
118118
case a: AttributeReference => true
119+
// PythonUDF will be executed by a dedicated physical operator later.
120+
// For PythonUDFs that can't be evaluated in join condition, `PullOutPythonUDFInJoinCondition`
121+
// will pull them out later.
122+
case _: PythonUDF => true
119123
case e: Unevaluable => false
120124
case e => e.children.forall(canEvaluateWithinJoin)
121125
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,15 @@
1717

1818
package org.apache.spark.sql.catalyst.optimizer
1919

20+
import org.apache.spark.api.python.PythonEvalType
2021
import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
2122
import org.apache.spark.sql.catalyst.dsl.expressions._
2223
import org.apache.spark.sql.catalyst.dsl.plans._
2324
import org.apache.spark.sql.catalyst.expressions._
2425
import org.apache.spark.sql.catalyst.plans._
2526
import org.apache.spark.sql.catalyst.plans.logical._
2627
import org.apache.spark.sql.catalyst.rules._
27-
import org.apache.spark.sql.types.IntegerType
28+
import org.apache.spark.sql.types.{BooleanType, IntegerType}
2829
import org.apache.spark.unsafe.types.CalendarInterval
2930

3031
class FilterPushdownSuite extends PlanTest {
@@ -41,9 +42,14 @@ class FilterPushdownSuite extends PlanTest {
4142
CollapseProject) :: Nil
4243
}
4344

44-
val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
45+
val attrA = 'a.int
46+
val attrB = 'b.int
47+
val attrC = 'c.int
48+
val attrD = 'd.int
4549

46-
val testRelation1 = LocalRelation('d.int)
50+
val testRelation = LocalRelation(attrA, attrB, attrC)
51+
52+
val testRelation1 = LocalRelation(attrD)
4753

4854
// This test already passes.
4955
test("eliminate subqueries") {
@@ -1202,4 +1208,26 @@ class FilterPushdownSuite extends PlanTest {
12021208
comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze,
12031209
checkAnalysis = false)
12041210
}
1211+
1212+
// Optimizer-level check: a Filter whose predicate contains a PythonUDF, placed on top of a
// cross join, should be folded into the join as its condition.
test("SPARK-28345: PythonUDF predicate should be able to pushdown to join") {
1213+
// Predicate comparing a PythonUDF over attrA (a column of testRelation) with attrD
// (the column of testRelation1), so it references both sides of the join.
val pythonUDFJoinCond = {
1214+
// The function argument is null: this test only builds and compares plans.
val pythonUDF = PythonUDF("pythonUDF", null,
1215+
IntegerType,
1216+
Seq(attrA),
1217+
PythonEvalType.SQL_BATCHED_UDF,
1218+
udfDeterministic = true)
1219+
pythonUDF === attrD
1220+
}
1221+
1222+
// Input plan: cross join with the predicate applied afterwards through `where`.
val query = testRelation.join(
1223+
testRelation1,
1224+
joinType = Cross).where(pythonUDFJoinCond)
1225+
1226+
// Expected plan: the same cross join carrying the predicate as its join condition.
val expected = testRelation.join(
1227+
testRelation1,
1228+
joinType = Cross,
1229+
condition = Some(pythonUDFJoinCond)).analyze
1230+
1231+
comparePlans(Optimize.execute(query.analyze), expected)
1232+
}
12051233
}

sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled}
2626
import org.apache.spark.sql.catalyst.TableIdentifier
2727
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
2828
import org.apache.spark.sql.catalyst.expressions.{Ascending, SortOrder}
29-
import org.apache.spark.sql.execution.{BinaryExecNode, SortExec}
29+
import org.apache.spark.sql.catalyst.plans.logical.Filter
30+
import org.apache.spark.sql.execution.{BinaryExecNode, FilterExec, SortExec}
3031
import org.apache.spark.sql.execution.joins._
3132
import org.apache.spark.sql.execution.python.BatchEvalPythonExec
3233
import org.apache.spark.sql.internal.SQLConf
@@ -994,4 +995,26 @@ class JoinSuite extends QueryTest with SharedSQLContext {
994995

995996
checkAnswer(df, Row(1, 2, 1, 2) :: Nil)
996997
}
998+
999+
// End-to-end check: a PythonUDF equality predicate applied after a cross join leaves a
// logical Filter in the analyzed plan but no FilterExec in the executed plan.
test("SPARK-28345: PythonUDF predicate should be able to pushdown to join") {
1000+
import IntegratedUDFTestUtils._
1001+
1002+
// NOTE(review): presumably cancels the test when Python UDFs cannot be run in this
// environment -- confirm against IntegratedUDFTestUtils.
assume(shouldTestPythonUDFs)
1003+
1004+
val pythonTestUDF = TestPythonUDF(name = "udf")
1005+
1006+
val left = Seq((1, 2), (2, 3)).toDF("a", "b")
1007+
val right = Seq((1, 2), (3, 4)).toDF("c", "d")
1008+
// The predicate applies the UDF to one column from each side of the cross join.
val df = left.crossJoin(right).where(pythonTestUDF($"a") === pythonTestUDF($"c"))
1009+
1010+
// Before optimization, there is a logical Filter operator.
1011+
val filterInAnalysis = df.queryExecution.analyzed.find(_.isInstanceOf[Filter])
1012+
assert(filterInAnalysis.isDefined)
1013+
1014+
// The Filter predicate was pushed down as the join condition, so there is no FilterExec operator.
1015+
val filterExec = df.queryExecution.executedPlan.find(_.isInstanceOf[FilterExec])
1016+
assert(filterExec.isEmpty)
1017+
1018+
// Expect the single row pairing left (1, 2) with right (1, 2).
checkAnswer(df, Row(1, 2, 1, 2) :: Nil)
1019+
}
9971020
}

0 commit comments

Comments
 (0)