Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions python/pyspark/sql/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,14 @@ def test_udf_in_filter_on_top_of_outer_join(self):
df = df.withColumn('b', udf(lambda x: 'x')(df.a))
self.assertEqual(df.filter('b = "x"').collect(), [Row(a=1, b='x')])

def test_udf_in_filter_on_top_of_join(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should reference jira number

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

from pyspark.sql.functions import udf
left = self.spark.createDataFrame([Row(a=1)])
right = self.spark.createDataFrame([Row(b=1)])
f = udf(lambda a, b: a == b, BooleanType())
df = left.crossJoin(right).filter(f("a", "b"))
self.assertEqual(df.collect(), [Row(a=1, b=1)])

def test_udf_without_arguments(self):
self.spark.catalog.registerFunction("foo", lambda: "bar")
[row] = self.spark.sql("SELECT foo()").collect()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,19 @@ trait PredicateHelper {
*/
protected def canEvaluate(expr: Expression, plan: LogicalPlan): Boolean =
expr.references.subsetOf(plan.outputSet)

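/**
* Returns true iff `expr` could be evaluated as a condition within a join.
*
* The check searches `expr` for anything that cannot be evaluated inside a join (a
* `SubqueryExpression` that still has children, or any `Unevaluable` expression) and
* returns true only when no such expression is found.
*/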
protected def canEvaluateWithinJoin(expr: Expression): Boolean = {
expr.find {
case e: SubqueryExpression =>
// A non-correlated subquery will be replaced by a literal, so only a subquery that
// still has children counts as not evaluable within the join.
e.children.nonEmpty
case e: Unevaluable => true
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need more documentation here on why an expression should be considered evaluable as a join condition.

For example, just looking at this code, I have no idea why Unevaluable is evaluable.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unevaluable is not evaluable. This block tries to find a case that is not evaluable in a join, and then negates it by isEmpty. I have to admit that we should document this.

case _ => false
}.isEmpty
}
}

@ExpressionDescription(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -893,7 +893,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
val newRight = rightFilterConditions.
reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
val (newJoinConditions, others) =
commonFilterCondition.partition(e => !SubqueryExpression.hasCorrelatedSubquery(e))
commonFilterCondition.partition(canEvaluateWithinJoin)
val newJoinCond = (newJoinConditions ++ joinCondition).reduceLeftOption(And)

val join = Join(newLeft, newRight, joinType, newJoinCond)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,7 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {
: LogicalPlan = {
assert(input.size >= 2)
if (input.size == 2) {
val (joinConditions, others) = conditions.partition(
e => !SubqueryExpression.hasCorrelatedSubquery(e))
val (joinConditions, others) = conditions.partition(canEvaluateWithinJoin)
val ((left, leftJoinType), (right, rightJoinType)) = (input(0), input(1))
val innerJoinType = (leftJoinType, rightJoinType) match {
case (Inner, Inner) => Inner
Expand Down Expand Up @@ -75,7 +74,7 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {

val joinedRefs = left.outputSet ++ right.outputSet
val (joinConditions, others) = conditions.partition(
e => e.references.subsetOf(joinedRefs) && !SubqueryExpression.hasCorrelatedSubquery(e))
e => e.references.subsetOf(joinedRefs) && canEvaluateWithinJoin(e))
val joined = Join(left, right, innerJoinType, joinConditions.reduceLeftOption(And))

// should not have reference to same logical plan
Expand Down Expand Up @@ -108,11 +107,10 @@ object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper {
* Returns whether the expression returns null or false when all inputs are nulls.
*/
private def canFilterOutNull(e: Expression): Boolean = {
// Purpose: evaluate `e` against an empty row and report whether the result is null or false.
// Guard: give up (return false) for non-deterministic expressions and correlated subqueries.
if (!e.deterministic || SubqueryExpression.hasCorrelatedSubquery(e)) return false
// Guard: give up when `e` contains any Unevaluable node, since `eval` is called below.
// NOTE(review): this guard repeats the `!e.deterministic` test from the line above; the two
// lines look like the before/after sides of a diff -- confirm which one is meant to remain.
if (!e.deterministic || e.find(_.isInstanceOf[Unevaluable]).isDefined) return false
// Bind `e` against its own referenced attributes and build a row with one slot per attribute.
val attributes = e.references.toSeq
// presumably every slot of this freshly created row is null (matches the "all inputs are
// nulls" wording of the doc comment above) -- TODO confirm
val emptyRow = new GenericInternalRow(attributes.length)
val boundE = BindReferences.bindReference(e, attributes)
// Same Unevaluable check as above, applied to the bound expression.
if (boundE.find(_.isInstanceOf[Unevaluable]).isDefined) return false
val v = boundE.eval(emptyRow)
// True when the evaluated result is null or false.
v == null || v == false
}
Expand Down