@@ -650,7 +650,9 @@ object ColumnPruning extends Rule[LogicalPlan] {
*/
private def removeProjectBeforeFilter(plan: LogicalPlan): LogicalPlan = plan transformUp {
case p1 @ Project(_, f @ Filter(_, p2 @ Project(_, child)))
if p2.outputSet.subsetOf(child.outputSet) =>
if p2.outputSet.subsetOf(child.outputSet) &&
// We only remove attribute-only project.
p2.projectList.forall(_.isInstanceOf[AttributeReference]) =>
Contributor

I am not sure about this change. It may cause a serious performance regression.

Contributor

How can we remove a project that is not attribute-only?

Contributor

I'd say it was wrong previously, but if a project's output has the same expr IDs as its child, it is usually attribute-only.

Contributor

Mmmmh... I may be missing something, but I'd imagine a case like this:

select a, b from
(select a, b, very_expensive_operation as c from ... where a = 1)

Before this change, it would be optimized as:

select a, b from
(select a, b from ... where a = 1)

whereas after this change it is not. Am I wrong?

Member Author

@viirya, Jul 25, 2019

In the above case, there is an Alias in the project list, so it is not an attribute-only project. I think it also creates a new attribute c, so p2.outputSet.subsetOf(child.outputSet) is not met either.

I think the rules in ColumnPruning will trim very_expensive_operation in the end.
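
To make the two guards concrete, here is a minimal sketch in plain Scala. It does not use Catalyst: `Attr`, `AliasOf`, `oldGuard`, and `newGuard` are hypothetical stand-ins that model only expression ids, and the sample project lists are made up for illustration. The third sample, an alias that reuses an expr id already present in the child, is the situation in which the old guard alone would accept a project that is not attribute-only.

```scala
object ProjectPruningSketch {
  // Toy stand-ins for Catalyst named expressions; only the expression id matters here.
  sealed trait NamedExpr { def exprId: Int }
  final case class Attr(name: String, exprId: Int) extends NamedExpr
  // An alias computes something and exposes it under an id, possibly a reused one.
  final case class AliasOf(computation: String, name: String, exprId: Int) extends NamedExpr

  // Old guard: every output id of the inner project already exists in its child.
  def oldGuard(projectList: Seq[NamedExpr], childOutput: Seq[Attr]): Boolean =
    projectList.map(_.exprId).toSet.subsetOf(childOutput.map(_.exprId).toSet)

  // New guard: additionally require that the project list is attribute-only.
  def newGuard(projectList: Seq[NamedExpr], childOutput: Seq[Attr]): Boolean =
    oldGuard(projectList, childOutput) && projectList.forall(_.isInstanceOf[Attr])

  // Hypothetical child output with columns a (id 1) and b (id 2).
  val childOutput: Seq[Attr] = Seq(Attr("a", 1), Attr("b", 2))

  // 1) Attribute-only project: safe to remove under both guards.
  val attributeOnly: Seq[NamedExpr] = Seq(Attr("a", 1), Attr("b", 2))

  // 2) Alias that introduces a fresh id (like `very_expensive_operation as c`).
  val freshAlias: Seq[NamedExpr] =
    Seq(Attr("a", 1), Attr("b", 2), AliasOf("very_expensive_operation", "c", 3))

  // 3) Alias that reuses an existing id: passes the subset check, yet computes something.
  val reusedAlias: Seq[NamedExpr] =
    Seq(Attr("a", 1), AliasOf("some computation over b", "b", 2))
}
```

In this model, `freshAlias` is rejected by both guards because id 3 is not in the child output, while `reusedAlias` passes the old guard and is rejected only by the new one.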

Contributor

I see now, sorry. Why do we need this? It seems unrelated to the fix in this PR, doesn't it?

Member Author

Oh, the issue was seen in the previous commit 33441a3. It has been overwritten now.

We added a column for the COUNT bug. The column checks an always-true leading column, alwaysTrueExpr, and returns a special value if alwaysTrueExpr is null, to simulate the empty-input case.

This column reuses the expr id of the original output in the subquery. In the non-foldable expression case, the added column in a potential Project-Filter-Project will be trimmed by removeProjectBeforeFilter, because the second project meets p2.outputSet.subsetOf(child.outputSet).

My original fix was to create a new expr id and replace the original expr id with it in the subquery, which looked complicated. This seems a simpler fix and looks reasonable.
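
The role of the always-true column can be illustrated with a small simulation in plain Scala collections. This is only a sketch under stated assumptions: the tables `l` and `r` below are made-up toy data (not the suite's test tables), `None` models SQL NULL, and the names `aggregated`, `joined`, `naive`, and `fixed` are hypothetical. It models `count(*)` only, whose value over zero input rows is 0.

```scala
object CountBugSketch {
  // Toy tables: l has column a, r has column c. None models SQL NULL.
  val l: Seq[Option[Int]] = Seq(Some(1), Some(2), Some(2), Some(3), None)
  val r: Seq[Option[Int]] = Seq(Some(2), Some(2), Some(3))

  // Decorrelated subquery: SELECT c, count(*) AS cnt, true AS alwaysTrue FROM r GROUP BY c.
  // NULL keys are dropped because `l.a = r.c` can never be true for them.
  val aggregated: Map[Int, (Long, Boolean)] =
    r.flatten.groupBy(identity).map { case (c, rows) => c -> (rows.size.toLong, true) }

  // Left outer join of l with the aggregate; unmatched rows get NULLs on the right side.
  def joined: Seq[(Option[Int], Option[(Long, Boolean)])] =
    l.map(a => (a, a.flatMap(aggregated.get)))

  // Naive rewrite: read the aggregate value directly, so unmatched rows see NULL.
  def naive: Seq[(Option[Int], Option[Long])] =
    joined.map { case (a, right) => (a, right.map(_._1)) }

  // Fixed rewrite: IF(alwaysTrue IS NULL, resultOnZeroTups, aggVal).
  def fixed: Seq[(Option[Int], Option[Long])] = {
    val resultOnZeroTups = 0L // count(*) over an empty input
    joined.map {
      case (a, None)           => (a, Some(resultOnZeroTups)) // alwaysTrue IS NULL
      case (a, Some((cnt, _))) => (a, Some(cnt))              // matched: use the aggregate
    }
  }
}
```

In this toy model the rows with `a = 1` and `a = NULL` get NULL from `naive` but 0 from `fixed`, which is the value a correlated `count(*)` subquery should produce for them.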

p1.copy(child = f.copy(child = child))
}
}
@@ -316,25 +316,46 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
newExpression.asInstanceOf[E]
}

private def removeAlias(expr: Expression): Expression = expr match {
Contributor

What if there are several aliases? Shall we use CleanupAliases instead?

Member Author

We track expressions starting from the aggregate expressions as roots, so I think the aliases should be contiguous at the top. Using CleanupAliases is also fine; at least we would not need to add a new method.

Contributor

Yes, sorry, this is recursive too, but I think it is good to avoid a new method. Thanks.
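
A small stand-alone sketch of the recursion being discussed, assuming a toy expression type rather than Catalyst's classes: `Expr`, `Lit`, `Rand`, and `Alias` here are hypothetical stand-ins, and `foldable` is modelled as false for `Alias`, matching the code comment in this diff that says Alias is not foldable. It shows that several aliases stacked on top of each other are all stripped before foldability is checked.

```scala
object RemoveAliasSketch {
  // Toy expression tree; not Catalyst's Expression hierarchy.
  sealed trait Expr { def foldable: Boolean }
  final case class Lit(value: Any) extends Expr { val foldable = true }
  final case class Rand() extends Expr { val foldable = false }
  // Modelled as non-foldable, as the diff's comment states for Alias.
  final case class Alias(child: Expr, name: String) extends Expr { val foldable = false }

  // Strips every alias stacked at the top of the expression.
  @scala.annotation.tailrec
  def removeAlias(expr: Expr): Expr = expr match {
    case Alias(child, _) => removeAlias(child)
    case other           => other
  }

  // Two nested aliases over a literal and over a non-deterministic expression.
  val aliasedLiteral: Expr = Alias(Alias(Lit(1), "x"), "y")
  val aliasedRand: Expr = Alias(Alias(Rand(), "x"), "y")
}
```

Here `removeAlias(aliasedLiteral)` yields the foldable `Lit(1)`, while `removeAlias(aliasedRand)` yields `Rand()`, which stays non-foldable and would be left for evaluation at query runtime.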

case Alias(c, _) => removeAlias(c)
case _ => expr
}

/**
* Checks if given expression is foldable. Evaluates it and returns it as literal, if yes.
* If not, returns the original expression without evaluation.
*/
private def tryEvalExpr(expr: Expression): Expression = {
// Removes Alias over given expression, because Alias is not foldable.
if (!removeAlias(expr).foldable) {
// SPARK-28441: Some expressions, like PythonUDF, can't be statically evaluated.
// Needs to evaluate them on query runtime.
expr
} else {
Literal.create(expr.eval(), expr.dataType)
}
}

/**
* Statically evaluate an expression containing zero or more placeholders, given a set
* of bindings for placeholder values.
* of bindings for placeholder values, if the expression is evaluable. If it is not,
* bind statically evaluated expression results to an expression.
*/
private def evalExpr(expr: Expression, bindings: Map[ExprId, Option[Any]]) : Option[Any] = {
private def bindingExpr(
expr: Expression,
bindings: Map[ExprId, Expression]): Expression = {
val rewrittenExpr = expr transform {
case r: AttributeReference =>
bindings(r.exprId) match {
case Some(v) => Literal.create(v, r.dataType)
case None => Literal.default(NullType)
}
bindings.getOrElse(r.exprId, Literal.default(NullType))
}
Option(rewrittenExpr.eval())

tryEvalExpr(rewrittenExpr)
}

/**
* Statically evaluate an expression containing one or more aggregates on an empty input.
*/
private def evalAggOnZeroTups(expr: Expression) : Option[Any] = {
private def evalAggOnZeroTups(expr: Expression) : Expression = {
// AggregateExpressions are Unevaluable, so we need to replace all aggregates
// in the expression with the value they would return for zero input tuples.
// Also replace attribute refs (for example, for grouping columns) with NULL.
@@ -344,7 +365,8 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {

case _: AttributeReference => Literal.default(NullType)
}
Option(rewrittenExpr.eval())

tryEvalExpr(rewrittenExpr)
}

/**
@@ -354,36 +376,51 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
* [[org.apache.spark.sql.catalyst.analysis.CheckAnalysis]]. If the checks in
* CheckAnalysis become less restrictive, this method will need to change.
*/
private def evalSubqueryOnZeroTups(plan: LogicalPlan) : Option[Any] = {
private def evalSubqueryOnZeroTups(plan: LogicalPlan) : Option[Expression] = {
// Inputs to this method will start with a chain of zero or more SubqueryAlias
// and Project operators, followed by an optional Filter, followed by an
// Aggregate. Traverse the operators recursively.
def evalPlan(lp : LogicalPlan) : Map[ExprId, Option[Any]] = lp match {
def evalPlan(lp : LogicalPlan) : Map[ExprId, Expression] = lp match {
case SubqueryAlias(_, child) => evalPlan(child)
case Filter(condition, child) =>
val bindings = evalPlan(child)
if (bindings.isEmpty) bindings
else {
val exprResult = evalExpr(condition, bindings).getOrElse(false)
.asInstanceOf[Boolean]
if (exprResult) bindings else Map.empty
if (bindings.isEmpty) {
bindings
} else {
val bindExpr = bindingExpr(condition, bindings)
Contributor

nit: bindCondition looks better.


if (!bindExpr.foldable) {
// We can't evaluate the condition. Evaluate it in query runtime.
bindings.map { case (id, expr) =>
val newExpr = If(bindExpr, expr, Literal.create(null, expr.dataType))
(id, newExpr)
}
} else {
// The bound condition can be evaluated.
bindExpr.eval() match {
// For filter condition, null is the same as false.
case null | false => Map.empty
case true => bindings
}
}
}

case Project(projectList, child) =>
val bindings = evalPlan(child)
if (bindings.isEmpty) {
bindings
} else {
projectList.map(ne => (ne.exprId, evalExpr(ne, bindings))).toMap
projectList.map(ne => (ne.exprId, bindingExpr(ne, bindings))).toMap
}

case Aggregate(_, aggExprs, _) =>
// Some of the expressions under the Aggregate node are the join columns
// for joining with the outer query block. Fill those expressions in with
// nulls and statically evaluate the remainder.
aggExprs.map {
case ref: AttributeReference => (ref.exprId, None)
case alias @ Alias(_: AttributeReference, _) => (alias.exprId, None)
case ref: AttributeReference => (ref.exprId, Literal.create(null, ref.dataType))
case alias @ Alias(_: AttributeReference, _) =>
(alias.exprId, Literal.create(null, alias.dataType))
case ne => (ne.exprId, evalAggOnZeroTups(ne))
}.toMap

@@ -394,7 +431,10 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
val resultMap = evalPlan(plan)

// By convention, the scalar subquery result is the leftmost field.
resultMap.getOrElse(plan.output.head.exprId, None)
resultMap.get(plan.output.head.exprId) match {
case Some(Literal(null, _)) | None => None
case o => o
}
}

/**
@@ -432,6 +472,18 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
sys.error("This line should be unreachable")
}

/**
* This replaces original expression id used in attributes and aliases in expression.
*/
private def replaceOldExprId(
Contributor

We can remove this.

oldExprId: ExprId,
newExprId: ExprId): PartialFunction[Expression, Expression] = {
case a: AttributeReference if a.exprId == oldExprId =>
a.withExprId(newExprId)
case a: Alias if a.exprId == oldExprId =>
Alias(child = a.child, name = a.name)(exprId = newExprId)
}

// Name of generated column used in rewrite below
val ALWAYS_TRUE_COLNAME = "alwaysTrue"

@@ -469,11 +521,12 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {

if (havingNode.isEmpty) {
// CASE 2: Subquery with no HAVING clause

Project(
currentChild.output :+
Alias(
If(IsNull(alwaysTrueRef),
Literal.create(resultWithZeroTups.get, origOutput.dataType),
resultWithZeroTups.get,
aggValRef), origOutput.name)(exprId = origOutput.exprId),
Join(currentChild,
Project(query.output :+ alwaysTrueExpr, query),
@@ -494,11 +547,11 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
case op => sys.error(s"Unexpected operator $op in corelated subquery")
}

// CASE WHEN alwayTrue IS NULL THEN resultOnZeroTups
// CASE WHEN alwaysTrue IS NULL THEN resultOnZeroTups
// WHEN NOT (original HAVING clause expr) THEN CAST(null AS <type of aggVal>)
// ELSE (aggregate value) END AS (original column name)
val caseExpr = Alias(CaseWhen(Seq(
(IsNull(alwaysTrueRef), Literal.create(resultWithZeroTups.get, origOutput.dataType)),
(IsNull(alwaysTrueRef), resultWithZeroTups.get),
(Not(havingNode.get.condition), Literal.create(null, aggValRef.dataType))),
aggValRef),
origOutput.name)(exprId = origOutput.exprId)
@@ -508,7 +561,6 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
Join(currentChild,
Project(subqueryRoot.output :+ alwaysTrueExpr, subqueryRoot),
LeftOuter, conditions.reduceOption(And), JoinHint.NONE))

}
}
}
183 changes: 183 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -1384,4 +1384,187 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
assert(subqueryExecs.forall(_.name.startsWith("scalar-subquery#")),
"SubqueryExec name should start with scalar-subquery#")
}

test("SPARK-28441: COUNT bug in WHERE clause (Filter) with PythonUDF") {
import IntegratedUDFTestUtils._

val pythonTestUDF = TestPythonUDF(name = "udf")
registerTestUDF(pythonTestUDF, spark)

// Case 1: Canonical example of the COUNT bug
checkAnswer(
sql("select l.a from l where (select udf(count(*)) from r where l.a = r.c) < l.a"),
Row(1) :: Row(1) :: Row(3) :: Row(6) :: Nil)
// Case 2: count(*) = 0; could be rewritten to NOT EXISTS but currently uses
// a rewrite that is vulnerable to the COUNT bug
checkAnswer(
sql("select l.a from l where (select udf(count(*)) from r where l.a = r.c) = 0"),
Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil)
// Case 3: COUNT bug without a COUNT aggregate
checkAnswer(
sql("select l.a from l where (select udf(sum(r.d)) is null from r where l.a = r.c)"),
Row(1) :: Row(1) ::Row(null) :: Row(null) :: Row(6) :: Nil)
}

test("SPARK-28441: COUNT bug in SELECT clause (Project) with PythonUDF") {
import IntegratedUDFTestUtils._

val pythonTestUDF = TestPythonUDF(name = "udf")
registerTestUDF(pythonTestUDF, spark)

checkAnswer(
sql("select a, (select udf(count(*)) from r where l.a = r.c) as cnt from l"),
Row(1, 0) :: Row(1, 0) :: Row(2, 2) :: Row(2, 2) :: Row(3, 1) :: Row(null, 0)
:: Row(null, 0) :: Row(6, 1) :: Nil)
}

test("SPARK-28441: COUNT bug in HAVING clause (Filter) with PythonUDF") {
import IntegratedUDFTestUtils._

val pythonTestUDF = TestPythonUDF(name = "udf")
registerTestUDF(pythonTestUDF, spark)

checkAnswer(
sql("select l.a as grp_a from l group by l.a " +
"having (select udf(count(*)) from r where grp_a = r.c) = 0 " +
"order by grp_a"),
Row(null) :: Row(1) :: Nil)
}

test("SPARK-28441: COUNT bug in Aggregate with PythonUDF") {
import IntegratedUDFTestUtils._

val pythonTestUDF = TestPythonUDF(name = "udf")
registerTestUDF(pythonTestUDF, spark)

checkAnswer(
sql("select l.a as aval, sum((select udf(count(*)) from r where l.a = r.c)) as cnt " +
"from l group by l.a order by aval"),
Row(null, 0) :: Row(1, 0) :: Row(2, 4) :: Row(3, 1) :: Row(6, 1) :: Nil)
}

test("SPARK-28441: COUNT bug negative examples with PythonUDF") {
import IntegratedUDFTestUtils._

val pythonTestUDF = TestPythonUDF(name = "udf")
registerTestUDF(pythonTestUDF, spark)

// Case 1: Potential COUNT bug case that was working correctly prior to the fix
checkAnswer(
sql("select l.a from l where (select udf(sum(r.d)) from r where l.a = r.c) is null"),
Row(1) :: Row(1) :: Row(null) :: Row(null) :: Row(6) :: Nil)
// Case 2: COUNT aggregate but no COUNT bug due to > 0 test.
checkAnswer(
sql("select l.a from l where (select udf(count(*)) from r where l.a = r.c) > 0"),
Row(2) :: Row(2) :: Row(3) :: Row(6) :: Nil)
// Case 3: COUNT inside aggregate expression but no COUNT bug.
checkAnswer(
sql("select l.a from l where (select udf(count(*)) + udf(sum(r.d)) " +
"from r where l.a = r.c) = 0"),
Nil)
}

test("SPARK-28441: COUNT bug in subquery in subquery in subquery with PythonUDF") {
Contributor

nit: maybe just say "in nested subquery".

Member Author

OK. It was copied from an old test.

import IntegratedUDFTestUtils._

val pythonTestUDF = TestPythonUDF(name = "udf")
registerTestUDF(pythonTestUDF, spark)

checkAnswer(
sql("""select l.a from l
|where (
| select cntPlusOne + 1 as cntPlusTwo from (
| select cnt + 1 as cntPlusOne from (
| select udf(sum(r.c)) s, udf(count(*)) cnt from r where l.a = r.c
| having cnt = 0
| )
| )
|) = 2""".stripMargin),
Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil)
}

test("SPARK-28441: COUNT bug with nasty predicate expr with PythonUDF") {
import IntegratedUDFTestUtils._

val pythonTestUDF = TestPythonUDF(name = "udf")
registerTestUDF(pythonTestUDF, spark)

checkAnswer(
sql("select l.a from l where " +
"(select case when udf(count(*)) = 1 then null else udf(count(*)) end as cnt " +
"from r where l.a = r.c) = 0"),
Contributor

Can we use a multi-line string to write the long SQL? Let's also uppercase the keywords.
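
As a rough illustration of this suggestion, the sketch below writes the query from this test once as concatenated fragments and once as a multi-line string with `stripMargin` and uppercase keywords. The object name and the `normalize` helper are hypothetical and exist only to compare the two spellings; the exact layout of the multi-line version is one possible choice, not a prescribed format.

```scala
object SqlFormattingSketch {
  // The query as written in the test: concatenated single-line fragments.
  val concatenated: String =
    "select l.a from l where " +
      "(select case when udf(count(*)) = 1 then null else udf(count(*)) end as cnt " +
      "from r where l.a = r.c) = 0"

  // The same query as a multi-line string with uppercase keywords.
  val multiLine: String =
    """
      |SELECT l.a FROM l WHERE
      |  (SELECT CASE WHEN udf(count(*)) = 1 THEN NULL ELSE udf(count(*)) END AS cnt
      |   FROM r WHERE l.a = r.c) = 0
    """.stripMargin

  // Hypothetical helper: collapse whitespace and lowercase so both spellings compare equal.
  def normalize(sql: String): String =
    sql.trim.split("\\s+").mkString(" ").toLowerCase
}
```

Both strings normalize to the same text, so the reformatting changes only how the SQL reads in the test source.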

Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil)
}

test("SPARK-28441: COUNT bug with attribute ref in subquery input and output with PythonUDF") {
import IntegratedUDFTestUtils._

val pythonTestUDF = TestPythonUDF(name = "udf")
registerTestUDF(pythonTestUDF, spark)
Member

BTW, we should add assume(shouldTestPythonUDFs). It may not be a big deal in general, but it can matter in other vendors' test environments. For instance, if somebody runs the tests in a minimal Docker image, the tests might suddenly fail.

Skipping tests like this isn't completely new in our test base. See TestUtils.testCommandAvailable, for instance.


checkAnswer(
sql(
"""
|select l.b, (select (r.c + udf(count(*))) is null
|from r
|where l.a = r.c group by r.c) from l
Contributor

Let's format the SQL in a more readable way. For this particular example:

select
  l.b,
  (
    select (r.c + udf(count(*))) is null
    from r
    where l.a = r.c
    group by r.c
  )
from l

""".stripMargin),
Row(1.0, false) :: Row(1.0, false) :: Row(2.0, true) :: Row(2.0, true) ::
Row(3.0, false) :: Row(5.0, true) :: Row(null, false) :: Row(null, true) :: Nil)
}

test("SPARK-28441: COUNT bug with non-foldable expression") {
// Case 1: Canonical example of the COUNT bug
checkAnswer(
sql("select l.a from l where (select count(*) + cast(rand() as int) from r " +
"where l.a = r.c) < l.a"),
Row(1) :: Row(1) :: Row(3) :: Row(6) :: Nil)
// Case 2: count(*) = 0; could be rewritten to NOT EXISTS but currently uses
// a rewrite that is vulnerable to the COUNT bug
checkAnswer(
sql("select l.a from l where (select count(*) + cast(rand() as int) from r " +
"where l.a = r.c) = 0"),
Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil)
// Case 3: COUNT bug without a COUNT aggregate
checkAnswer(
sql("select l.a from l where (select sum(r.d) is null from r " +
"where l.a = r.c)"),
Row(1) :: Row(1) ::Row(null) :: Row(null) :: Row(6) :: Nil)
}

test("SPARK-28441: COUNT bug in subquery in subquery in subquery with non-foldable expr") {
Contributor

ditto

checkAnswer(
sql("""select l.a from l
|where (
| select cntPlusOne + 1 as cntPlusTwo from (
| select cnt + 1 as cntPlusOne from (
| select sum(r.c) s, (count(*) + cast(rand() as int)) cnt from r
| where l.a = r.c having cnt = 0
| )
| )
|) = 2""".stripMargin),
Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil)
}

test("SPARK-28441: COUNT bug with non-foldable expression in Filter condition") {
val df = sql("""select l.a from l
|where (
| select cntPlusOne + 1 as cntPlusTwo from (
| select cnt + 1 as cntPlusOne from (
| select sum(r.c) s, count(*) cnt from r
| where l.a = r.c having cnt > 0
| )
| )
|) = 2""".stripMargin)
val df2 = sql("""select l.a from l
|where (
| select cntPlusOne + 1 as cntPlusTwo from (
| select cnt + 1 as cntPlusOne from (
| select sum(r.c) s, count(*) cnt from r
| where l.a = r.c having (cnt + cast(rand() as int)) > 0
| )
| )
|) = 2""".stripMargin)
checkAnswer(df, df2)
}
}