SPARK-28345: PythonUDF predicate should be able to pushdown to join
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

@@ -116,6 +116,10 @@ trait PredicateHelper {
       // non-correlated subquery will be replaced as literal
       e.children.isEmpty
     case a: AttributeReference => true
+    // A PythonUDF will be executed by a dedicated physical operator later.
+    // PythonUDFs that can't be evaluated in a join condition are pulled out
+    // later by `PullOutPythonUDFInJoinCondition`.
+    case _: PythonUDF => true
     case e: Unevaluable => false
     case e => e.children.forall(canEvaluateWithinJoin)
   }
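
For context, the new `case _: PythonUDF => true` classifies predicates containing Python UDFs as evaluable within a join. A minimal sketch of such a predicate, built the same way as in the FilterPushdownSuite test below (the `null` stands in for the wrapped Python function object; the import locations are assumptions based on the imports visible in this diff):

import org.apache.spark.api.python.PythonEvalType
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.PythonUDF
import org.apache.spark.sql.types.IntegerType

// A deterministic batched Python UDF over column 'a, compared with column 'd.
val pythonUDF = PythonUDF("pythonUDF", null,
  IntegerType,
  Seq('a.int),
  PythonEvalType.SQL_BATCHED_UDF,
  udfDeterministic = true)
val joinCond = pythonUDF === 'd.int

// canEvaluateWithinJoin(joinCond) now returns true, so filter pushdown may
// move this predicate into a Join condition; any PythonUDF that still can't
// run inside the join is extracted again by PullOutPythonUDFInJoinCondition.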
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala

@@ -17,14 +17,15 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
+import org.apache.spark.api.python.PythonEvalType
 import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
-import org.apache.spark.sql.types.IntegerType
+import org.apache.spark.sql.types.{BooleanType, IntegerType}
 import org.apache.spark.unsafe.types.CalendarInterval
 
 class FilterPushdownSuite extends PlanTest {
@@ -41,9 +42,14 @@ class FilterPushdownSuite extends PlanTest {
       CollapseProject) :: Nil
   }
 
-  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
+  val attrA = 'a.int
+  val attrB = 'b.int
+  val attrC = 'c.int
+  val attrD = 'd.int
 
-  val testRelation1 = LocalRelation('d.int)
+  val testRelation = LocalRelation(attrA, attrB, attrC)
+
+  val testRelation1 = LocalRelation(attrD)
 
   // This test already passes.
   test("eliminate subqueries") {
@@ -1202,4 +1208,26 @@ class FilterPushdownSuite extends PlanTest {
     comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze,
       checkAnalysis = false)
   }
+
+  test("SPARK-28345: PythonUDF predicate should be able to pushdown to join") {
+    val pythonUDFJoinCond = {
+      val pythonUDF = PythonUDF("pythonUDF", null,
+        IntegerType,
+        Seq(attrA),
+        PythonEvalType.SQL_BATCHED_UDF,
+        udfDeterministic = true)
+      pythonUDF === attrD
+    }
+
+    val query = testRelation.join(
+      testRelation1,
+      joinType = Cross).where(pythonUDFJoinCond)
+
+    val expected = testRelation.join(
+      testRelation1,
+      joinType = Cross,
+      condition = Some(pythonUDFJoinCond)).analyze
+
+    comparePlans(Optimize.execute(query.analyze), expected)
+  }
 }
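
In plan terms, the test above asserts that the optimizer folds the Filter over the cross join into the join's condition. Roughly (a hand-written sketch of the two logical plans, not captured optimizer output; attribute ids are illustrative):

// query.analyze, before optimization:
//   Filter (pythonUDF(a#0) = d#3)
//   +- Join Cross
//      :- LocalRelation [a#0, b#1, c#2]
//      +- LocalRelation [d#3]
//
// Optimize.execute(query.analyze), matching `expected`:
//   Join Cross, (pythonUDF(a#0) = d#3)
//   :- LocalRelation [a#0, b#1, c#2]
//   +- LocalRelation [d#3]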
sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -26,7 +26,8 @@ import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled}
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
 import org.apache.spark.sql.catalyst.expressions.{Ascending, SortOrder}
-import org.apache.spark.sql.execution.{BinaryExecNode, SortExec}
+import org.apache.spark.sql.catalyst.plans.logical.Filter
+import org.apache.spark.sql.execution.{BinaryExecNode, FilterExec, SortExec}
 import org.apache.spark.sql.execution.joins._
 import org.apache.spark.sql.execution.python.BatchEvalPythonExec
 import org.apache.spark.sql.internal.SQLConf
@@ -994,4 +995,26 @@ class JoinSuite extends QueryTest with SharedSQLContext {
 
     checkAnswer(df, Row(1, 2, 1, 2) :: Nil)
   }
+
+  test("SPARK-28345: PythonUDF predicate should be able to pushdown to join") {
+    import IntegratedUDFTestUtils._
+
+    assume(shouldTestPythonUDFs)
+
+    val pythonTestUDF = TestPythonUDF(name = "udf")
+
+    val left = Seq((1, 2), (2, 3)).toDF("a", "b")
+    val right = Seq((1, 2), (3, 4)).toDF("c", "d")
+    val df = left.crossJoin(right).where(pythonTestUDF($"a") === pythonTestUDF($"c"))
+
+    // Before optimization, there is a logical Filter operator.
+    val filterInAnalysis = df.queryExecution.analyzed.find(_.isInstanceOf[Filter])
+    assert(filterInAnalysis.isDefined)
+
+    // The filter predicate was pushed down as the join condition, so there is no FilterExec operator.
+    val filterExec = df.queryExecution.executedPlan.find(_.isInstanceOf[FilterExec])
+    assert(filterExec.isEmpty)
+
+    checkAnswer(df, Row(1, 2, 1, 2) :: Nil)
+  }
 }
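
A natural follow-up assertion, not part of this change, would pin down where the UDFs end up physically. A hedged sketch in the same style as the test above, assuming the batched Python UDFs are materialized by BatchEvalPythonExec operators (an import this suite already has) rather than evaluated in a FilterExec:

// Hypothetical extra check inside the test above: the Python UDFs should be
// evaluated by dedicated BatchEvalPythonExec operators feeding the join.
val pythonEvals = df.queryExecution.executedPlan.collect {
  case b: BatchEvalPythonExec => b
}
assert(pythonEvals.nonEmpty)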