@@ -84,7 +84,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
       SimplifyConditionals,
       RemoveDispensableExpressions,
       SimplifyBinaryComparison,
-      ReplaceNullWithFalse,
+      ReplaceNullWithFalseInPredicate,
       PruneFilters,
       EliminateSorts,
       SimplifyCasts,

@@ -755,7 +755,7 @@ object CombineConcats extends Rule[LogicalPlan] {
  *
  * As a result, many unnecessary computations can be removed in the query optimization phase.
  */
-object ReplaceNullWithFalse extends Rule[LogicalPlan] {
+object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] {

   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case f @ Filter(cond, _) => f.copy(condition = replaceNullWithFalse(cond))
@@ -767,6 +767,15 @@ object ReplaceNullWithFalse extends Rule[LogicalPlan] {
           replaceNullWithFalse(cond) -> value
         }
         cw.copy(branches = newBranches)
+      case af @ ArrayFilter(_, lf @ LambdaFunction(func, _, _)) =>
+        val newLambda = lf.copy(function = replaceNullWithFalse(func))
+        af.copy(function = newLambda)
+      case ae @ ArrayExists(_, lf @ LambdaFunction(func, _, _)) =>
+        val newLambda = lf.copy(function = replaceNullWithFalse(func))
+        ae.copy(function = newLambda)
+      case mf @ MapFilter(_, lf @ LambdaFunction(func, _, _)) =>
+        val newLambda = lf.copy(function = replaceNullWithFalse(func))
+        mf.copy(function = newLambda)
     }
   }

Review thread on the new ArrayFilter/ArrayExists/MapFilter cases:

Contributor: Shall we add a withNewFunctions method in HigherOrderFunction? Then we can simplify this rule to

    case f: HigherOrderFunction => f.withNewFunctions(f.functions.map(replaceNullWithFalse))

Contributor (Author): I'm not sure if that's useful or not. First of all, the replaceNullWithFalse handling doesn't apply to all higher-order functions. In fact it only applies to a very narrow set: ones where a lambda function returns BooleanType and is immediately used as a predicate. So a generic utility would certainly make this PR slightly simpler, but I don't know how useful it would be for other cases. I'd prefer waiting for more such transformation cases before introducing a new utility for the pattern. WDYT?

Contributor: Ah, I see. Sorry I missed it. Then it's safer to use a white-list here.
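For reference, a minimal sketch of the withNewFunctions idea from the thread above, using toy types because the method is only a suggestion: it is not part of this PR, and the names and signatures below are not Spark's actual HigherOrderFunction API.

    // Toy sketch of the reviewer's suggestion (hypothetical API, not Spark code): a copy-style
    // hook that lets a rule rewrite every lambda of a higher-order function in one place.
    trait Expr
    case class Lambda(body: Expr, args: Seq[String]) extends Expr

    trait HigherOrderFunc extends Expr {
      def functions: Seq[Lambda]
      def withNewFunctions(newFunctions: Seq[Lambda]): HigherOrderFunc
    }

    case class ToyArrayFilter(input: Expr, function: Lambda) extends HigherOrderFunc {
      def functions: Seq[Lambda] = Seq(function)
      def withNewFunctions(newFunctions: Seq[Lambda]): HigherOrderFunc =
        copy(function = newFunctions.head)
    }

    // With such a hook the rule body could collapse to a single case:
    //   case f: HigherOrderFunction => f.withNewFunctions(f.functions.map(replaceNullWithFalse))
    // The PR keeps the explicit white-list (ArrayFilter, ArrayExists, MapFilter) instead, because
    // the null-to-false rewrite is only safe where the lambda result is consumed as a predicate;
    // applying it to something like ArrayTransform would change query results.
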
@@ -20,14 +20,14 @@ package org.apache.spark.sql.catalyst.optimizer
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{And, CaseWhen, Expression, GreaterThan, If, Literal, Or}
+import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, ArrayTransform, CaseWhen, Expression, GreaterThan, If, LambdaFunction, Literal, MapFilter, NamedExpression, Or}
 import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
 import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
 import org.apache.spark.sql.types.{BooleanType, IntegerType}

-class ReplaceNullWithFalseSuite extends PlanTest {
+class ReplaceNullWithFalseInPredicateSuite extends PlanTest {

   object Optimize extends RuleExecutor[LogicalPlan] {
     val batches =
@@ -36,10 +36,11 @@ class ReplaceNullWithFalseSuite extends PlanTest {
         ConstantFolding,
         BooleanSimplification,
         SimplifyConditionals,
-        ReplaceNullWithFalse) :: Nil
+        ReplaceNullWithFalseInPredicate) :: Nil
   }

-  private val testRelation = LocalRelation('i.int, 'b.boolean)
+  private val testRelation =
+    LocalRelation('i.int, 'b.boolean, 'a.array(IntegerType), 'm.map(IntegerType, IntegerType))
   private val anotherTestRelation = LocalRelation('d.int)

   test("replace null inside filter and join conditions") {
@@ -298,6 +299,26 @@ class ReplaceNullWithFalseSuite extends PlanTest {
     testProjection(originalExpr = column, expectedExpr = column)
   }

+  test("replace nulls in lambda function of ArrayFilter") {
+    testHigherOrderFunc('a, ArrayFilter, Seq('e))
+  }
+
+  test("replace nulls in lambda function of ArrayExists") {
+    testHigherOrderFunc('a, ArrayExists, Seq('e))
+  }
+
+  test("replace nulls in lambda function of MapFilter") {
+    testHigherOrderFunc('m, MapFilter, Seq('k, 'v))
+  }
+
+  test("inability to replace nulls in arbitrary higher-order function") {
+    val lambdaFunc = LambdaFunction(
+      function = If('e > 0, Literal(null, BooleanType), TrueLiteral),
+      arguments = Seq[NamedExpression]('e))
+    val column = ArrayTransform('a, lambdaFunc)
+    testProjection(originalExpr = column, expectedExpr = column)
+  }
+
   private def testFilter(originalCond: Expression, expectedCond: Expression): Unit = {
     test((rel, exp) => rel.where(exp), originalCond, expectedCond)
   }
@@ -310,6 +331,25 @@ class ReplaceNullWithFalseSuite extends PlanTest {
     test((rel, exp) => rel.select(exp), originalExpr, expectedExpr)
   }

+  private def testHigherOrderFunc(
+      argument: Expression,
+      createExpr: (Expression, Expression) => Expression,
+      lambdaArgs: Seq[NamedExpression]): Unit = {
+    val condArg = lambdaArgs.last
+    // the lambda body is: if(arg > 0, null, true)
+    val cond = GreaterThan(condArg, Literal(0))
+    val lambda1 = LambdaFunction(
+      function = If(cond, Literal(null, BooleanType), TrueLiteral),
+      arguments = lambdaArgs)
+    // the optimized lambda body is: if(arg > 0, false, true)
+    val lambda2 = LambdaFunction(
+      function = If(cond, FalseLiteral, TrueLiteral),
+      arguments = lambdaArgs)
+    testProjection(
+      originalExpr = createExpr(argument, lambda1) as 'x,
+      expectedExpr = createExpr(argument, lambda2) as 'x)
+  }
+
   private def test(
       func: (LogicalPlan, Expression) => LogicalPlan,
       originalExpr: Expression,

@@ -17,12 +17,13 @@

 package org.apache.spark.sql

-import org.apache.spark.sql.catalyst.expressions.{CaseWhen, If}
+import org.apache.spark.sql.catalyst.expressions.{CaseWhen, If, Literal}
 import org.apache.spark.sql.execution.LocalTableScanExec
 import org.apache.spark.sql.functions.{lit, when}
 import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.BooleanType

-class ReplaceNullWithFalseEndToEndSuite extends QueryTest with SharedSQLContext {
+class ReplaceNullWithFalseInPredicateEndToEndSuite extends QueryTest with SharedSQLContext {
   import testImplicits._

   test("SPARK-25860: Replace Literal(null, _) with FalseLiteral whenever possible") {
@@ -68,4 +69,44 @@ class ReplaceNullWithFalseEndToEndSuite extends QueryTest with SharedSQLContext
       case p => fail(s"$p is not LocalTableScanExec")
     }
   }
+
+  test("SPARK-26107: Replace Literal(null, _) with FalseLiteral in higher-order functions") {
+    def assertNoLiteralNullInPlan(df: DataFrame): Unit = {
+      df.queryExecution.executedPlan.foreach { p =>
+        assert(p.expressions.forall(_.find {
+          case Literal(null, BooleanType) => true
+          case _ => false
+        }.isEmpty))
+      }
+    }
+
+    withTable("t1", "t2") {
+      // to test ArrayFilter and ArrayExists
+      spark.sql("select array(null, 1, null, 3) as a")
+        .write.saveAsTable("t1")
+      // to test MapFilter
+      spark.sql("""
+        select map_from_entries(arrays_zip(a, transform(a, e -> if(mod(e, 2) = 0, null, e)))) as m
+        from (select array(0, 1, 2, 3) as a)
+        """).write.saveAsTable("t2")
+
+      val df1 = spark.table("t1")
+      val df2 = spark.table("t2")
+
+      // ArrayExists
+      val q1 = df1.selectExpr("EXISTS(a, e -> IF(e is null, null, true))")
+      checkAnswer(q1, Row(true) :: Nil)
+      assertNoLiteralNullInPlan(q1)
+
+      // ArrayFilter
+      val q2 = df1.selectExpr("FILTER(a, e -> IF(e is null, null, true))")
+      checkAnswer(q2, Row(Seq[Any](1, 3)) :: Nil)
+      assertNoLiteralNullInPlan(q2)
+
+      // MapFilter
+      val q3 = df2.selectExpr("MAP_FILTER(m, (k, v) -> IF(v is null, null, true))")
+      checkAnswer(q3, Row(Map[Any, Any](1 -> 1, 3 -> 3)))
+      assertNoLiteralNullInPlan(q3)
+    }
+  }
}
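As a quick manual counterpart to the end-to-end test above, the rewrite can also be observed in the explain output. This is a sketch, assuming an active SparkSession named spark; the exact text of the optimized plan may differ between versions.

    // Sketch: inspect the optimized plan for one of the predicates exercised above. After
    // ReplaceNullWithFalseInPredicate runs, the lambda body should contain false rather than
    // a boolean null literal.
    val df = spark.sql("select array(null, 1, null, 3) as a")
      .selectExpr("FILTER(a, e -> IF(e is null, null, true)) as r")
    df.explain(true) // expect something like if(isnull(e), false, true) inside the lambda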