Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions python/pyspark/sql/tests/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,18 @@ def test_udf_in_join_condition(self):
with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
self.assertEqual(df.collect(), [Row(a=1, b=1)])

def test_udf_in_left_outer_join_condition(self):
# regression test for SPARK-26147
from pyspark.sql.functions import udf, col
left = self.spark.createDataFrame([Row(a=1)])
right = self.spark.createDataFrame([Row(b=1)])
f = udf(lambda a: str(a), StringType())
# The join condition can't be pushed down, as it refers to attributes from both sides.
# The Python UDF only refer to attributes from one side, so it's evaluable.
df = left.join(right, f("a") == col("b").cast("string"), how = "left_outer")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style nit: how="left_outer"

with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
self.assertEqual(df.collect(), [Row(a=1, b=1)])

def test_udf_in_left_semi_join_condition(self):
# regression test for SPARK-25314
from pyspark.sql.functions import udf
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,19 +155,20 @@ object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper {
}

/**
* PythonUDF in join condition can not be evaluated, this rule will detect the PythonUDF
* and pull them out from join condition. For python udf accessing attributes from only one side,
* they are pushed down by operation push down rules. If not (e.g. user disables filter push
* down rules), we need to pull them out in this rule too.
* PythonUDF in join condition can't be evaluated if it refers to attributes from both join sides.
* See `ExtractPythonUDFs` for details. This rule will detect un-evaluable PythonUDF and pull them
* out from join condition.
*/
object PullOutPythonUDFInJoinCondition extends Rule[LogicalPlan] with PredicateHelper {
def hasPythonUDF(expression: Expression): Boolean = {
expression.collectFirst { case udf: PythonUDF => udf }.isDefined

private def hasUnevaluablePythonUDF(expr: Expression, j: Join): Boolean = {
expr.find { e =>
PythonUDF.isScalarPythonUDF(e) && !canEvaluate(e, j.left) && !canEvaluate(e, j.right)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We might need a comment to explain why we only pull out the Scalar PythonUDF.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's only possible to have scalar UDF in join condition, so changing it to e.isInstanceOf[PythonUDF] is same.

}.isDefined
}

override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case j @ Join(_, _, joinType, condition)
if condition.isDefined && hasPythonUDF(condition.get) =>
case j @ Join(_, _, joinType, Some(cond)) if hasUnevaluablePythonUDF(cond, j) =>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Followed by the rule changes, we need modify the suites in PullOutPythonUDFInJoinConditionSuite, the suites should also construct the dummy python udf from both side.

if (!joinType.isInstanceOf[InnerLike] && joinType != LeftSemi) {
// The current strategy only support InnerLike and LeftSemi join because for other type,
// it breaks SQL semantic if we run the join condition as a filter after join. If we pass
Expand All @@ -179,10 +180,9 @@ object PullOutPythonUDFInJoinCondition extends Rule[LogicalPlan] with PredicateH
}
// If condition expression contains python udf, it will be moved out from
// the new join conditions.
val (udf, rest) =
splitConjunctivePredicates(condition.get).partition(hasPythonUDF)
val (udf, rest) = splitConjunctivePredicates(cond).partition(hasUnevaluablePythonUDF(_, j))
val newCondition = if (rest.isEmpty) {
logWarning(s"The join condition:$condition of the join plan contains PythonUDF only," +
logWarning(s"The join condition:$cond of the join plan contains PythonUDF only," +
s" it will be moved out and the join plan will be turned to cross join.")
None
} else {
Expand Down