-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-28441][SQL][Python] Fix error when non-foldable expression is used in correlated scalar subquery #25204
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 7 commits
725304c
7972d7c
33441a3
110a39e
a7803f5
0158d85
2dd29c1
9aea844
1f6b717
d7d023d
fd29677
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -316,25 +316,46 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { | |
| newExpression.asInstanceOf[E] | ||
| } | ||
|
|
||
| private def removeAlias(expr: Expression): Expression = expr match { | ||
|
||
| case Alias(c, _) => removeAlias(c) | ||
| case _ => expr | ||
| } | ||
|
|
||
| /** | ||
| * Checks if given expression is foldable. Evaluates it and returns it as literal, if yes. | ||
| * If not, returns the original expression without evaluation. | ||
| */ | ||
| private def tryEvalExpr(expr: Expression): Expression = { | ||
| // Removes Alias over given expression, because Alias is not foldable. | ||
| if (!removeAlias(expr).foldable) { | ||
| // SPARK-28441: Some expressions, like PythonUDF, can't be statically evaluated. | ||
| // Needs to evaluate them on query runtime. | ||
| expr | ||
| } else { | ||
| Literal.create(expr.eval(), expr.dataType) | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Statically evaluate an expression containing zero or more placeholders, given a set | ||
| * of bindings for placeholder values. | ||
| * of bindings for placeholder values, if the expression is evaluable. If it is not, | ||
| * bind statically evaluated expression results to an expression. | ||
| */ | ||
| private def evalExpr(expr: Expression, bindings: Map[ExprId, Option[Any]]) : Option[Any] = { | ||
| private def bindingExpr( | ||
| expr: Expression, | ||
| bindings: Map[ExprId, Expression]): Expression = { | ||
| val rewrittenExpr = expr transform { | ||
| case r: AttributeReference => | ||
| bindings(r.exprId) match { | ||
| case Some(v) => Literal.create(v, r.dataType) | ||
| case None => Literal.default(NullType) | ||
| } | ||
| bindings.getOrElse(r.exprId, Literal.default(NullType)) | ||
| } | ||
| Option(rewrittenExpr.eval()) | ||
|
|
||
| tryEvalExpr(rewrittenExpr) | ||
| } | ||
|
|
||
| /** | ||
| * Statically evaluate an expression containing one or more aggregates on an empty input. | ||
| */ | ||
| private def evalAggOnZeroTups(expr: Expression) : Option[Any] = { | ||
| private def evalAggOnZeroTups(expr: Expression) : Expression = { | ||
| // AggregateExpressions are Unevaluable, so we need to replace all aggregates | ||
| // in the expression with the value they would return for zero input tuples. | ||
| // Also replace attribute refs (for example, for grouping columns) with NULL. | ||
|
|
@@ -344,7 +365,8 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { | |
|
|
||
| case _: AttributeReference => Literal.default(NullType) | ||
| } | ||
| Option(rewrittenExpr.eval()) | ||
|
|
||
| tryEvalExpr(rewrittenExpr) | ||
| } | ||
|
|
||
| /** | ||
|
|
@@ -354,36 +376,51 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { | |
| * [[org.apache.spark.sql.catalyst.analysis.CheckAnalysis]]. If the checks in | ||
| * CheckAnalysis become less restrictive, this method will need to change. | ||
| */ | ||
| private def evalSubqueryOnZeroTups(plan: LogicalPlan) : Option[Any] = { | ||
| private def evalSubqueryOnZeroTups(plan: LogicalPlan) : Option[Expression] = { | ||
| // Inputs to this method will start with a chain of zero or more SubqueryAlias | ||
| // and Project operators, followed by an optional Filter, followed by an | ||
| // Aggregate. Traverse the operators recursively. | ||
| def evalPlan(lp : LogicalPlan) : Map[ExprId, Option[Any]] = lp match { | ||
| def evalPlan(lp : LogicalPlan) : Map[ExprId, Expression] = lp match { | ||
| case SubqueryAlias(_, child) => evalPlan(child) | ||
| case Filter(condition, child) => | ||
| val bindings = evalPlan(child) | ||
| if (bindings.isEmpty) bindings | ||
| else { | ||
| val exprResult = evalExpr(condition, bindings).getOrElse(false) | ||
| .asInstanceOf[Boolean] | ||
| if (exprResult) bindings else Map.empty | ||
| if (bindings.isEmpty) { | ||
| bindings | ||
| } else { | ||
| val bindExpr = bindingExpr(condition, bindings) | ||
|
||
|
|
||
| if (!bindExpr.foldable) { | ||
| // We can't evaluate the condition. Evaluate it in query runtime. | ||
| bindings.map { case (id, expr) => | ||
| val newExpr = If(bindExpr, expr, Literal.create(null, expr.dataType)) | ||
| (id, newExpr) | ||
| } | ||
| } else { | ||
| // The bound condition can be evaluated. | ||
| bindExpr.eval() match { | ||
| // For filter condition, null is the same as false. | ||
| case null | false => Map.empty | ||
| case true => bindings | ||
| } | ||
| } | ||
| } | ||
|
|
||
| case Project(projectList, child) => | ||
| val bindings = evalPlan(child) | ||
| if (bindings.isEmpty) { | ||
| bindings | ||
| } else { | ||
| projectList.map(ne => (ne.exprId, evalExpr(ne, bindings))).toMap | ||
| projectList.map(ne => (ne.exprId, bindingExpr(ne, bindings))).toMap | ||
| } | ||
|
|
||
| case Aggregate(_, aggExprs, _) => | ||
| // Some of the expressions under the Aggregate node are the join columns | ||
| // for joining with the outer query block. Fill those expressions in with | ||
| // nulls and statically evaluate the remainder. | ||
| aggExprs.map { | ||
| case ref: AttributeReference => (ref.exprId, None) | ||
| case alias @ Alias(_: AttributeReference, _) => (alias.exprId, None) | ||
| case ref: AttributeReference => (ref.exprId, Literal.create(null, ref.dataType)) | ||
| case alias @ Alias(_: AttributeReference, _) => | ||
| (alias.exprId, Literal.create(null, alias.dataType)) | ||
| case ne => (ne.exprId, evalAggOnZeroTups(ne)) | ||
| }.toMap | ||
|
|
||
|
|
@@ -394,7 +431,10 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { | |
| val resultMap = evalPlan(plan) | ||
|
|
||
| // By convention, the scalar subquery result is the leftmost field. | ||
| resultMap.getOrElse(plan.output.head.exprId, None) | ||
| resultMap.get(plan.output.head.exprId) match { | ||
| case Some(Literal(null, _)) | None => None | ||
| case o => o | ||
| } | ||
| } | ||
|
|
||
| /** | ||
|
|
@@ -432,6 +472,18 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { | |
| sys.error("This line should be unreachable") | ||
| } | ||
|
|
||
| /** | ||
| * This replaces original expression id used in attributes and aliases in expression. | ||
| */ | ||
| private def replaceOldExprId( | ||
|
||
| oldExprId: ExprId, | ||
| newExprId: ExprId): PartialFunction[Expression, Expression] = { | ||
| case a: AttributeReference if a.exprId == oldExprId => | ||
| a.withExprId(newExprId) | ||
| case a: Alias if a.exprId == oldExprId => | ||
| Alias(child = a.child, name = a.name)(exprId = newExprId) | ||
| } | ||
|
|
||
| // Name of generated column used in rewrite below | ||
| val ALWAYS_TRUE_COLNAME = "alwaysTrue" | ||
|
|
||
|
|
@@ -469,11 +521,12 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { | |
|
|
||
| if (havingNode.isEmpty) { | ||
| // CASE 2: Subquery with no HAVING clause | ||
|
|
||
| Project( | ||
| currentChild.output :+ | ||
| Alias( | ||
| If(IsNull(alwaysTrueRef), | ||
| Literal.create(resultWithZeroTups.get, origOutput.dataType), | ||
| resultWithZeroTups.get, | ||
| aggValRef), origOutput.name)(exprId = origOutput.exprId), | ||
| Join(currentChild, | ||
| Project(query.output :+ alwaysTrueExpr, query), | ||
|
|
@@ -494,11 +547,11 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { | |
| case op => sys.error(s"Unexpected operator $op in corelated subquery") | ||
| } | ||
|
|
||
| // CASE WHEN alwayTrue IS NULL THEN resultOnZeroTups | ||
| // CASE WHEN alwaysTrue IS NULL THEN resultOnZeroTups | ||
| // WHEN NOT (original HAVING clause expr) THEN CAST(null AS <type of aggVal>) | ||
| // ELSE (aggregate value) END AS (original column name) | ||
| val caseExpr = Alias(CaseWhen(Seq( | ||
| (IsNull(alwaysTrueRef), Literal.create(resultWithZeroTups.get, origOutput.dataType)), | ||
| (IsNull(alwaysTrueRef), resultWithZeroTups.get), | ||
| (Not(havingNode.get.condition), Literal.create(null, aggValRef.dataType))), | ||
| aggValRef), | ||
| origOutput.name)(exprId = origOutput.exprId) | ||
|
|
@@ -508,7 +561,6 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { | |
| Join(currentChild, | ||
| Project(subqueryRoot.output :+ alwaysTrueExpr, subqueryRoot), | ||
| LeftOuter, conditions.reduceOption(And), JoinHint.NONE)) | ||
|
|
||
| } | ||
| } | ||
| } | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1384,4 +1384,187 @@ class SubquerySuite extends QueryTest with SharedSQLContext { | |
| assert(subqueryExecs.forall(_.name.startsWith("scalar-subquery#")), | ||
| "SubqueryExec name should start with scalar-subquery#") | ||
| } | ||
|
|
||
| test("SPARK-28441: COUNT bug in WHERE clause (Filter) with PythonUDF") { | ||
| import IntegratedUDFTestUtils._ | ||
|
|
||
| val pythonTestUDF = TestPythonUDF(name = "udf") | ||
| registerTestUDF(pythonTestUDF, spark) | ||
|
|
||
| // Case 1: Canonical example of the COUNT bug | ||
| checkAnswer( | ||
| sql("select l.a from l where (select udf(count(*)) from r where l.a = r.c) < l.a"), | ||
| Row(1) :: Row(1) :: Row(3) :: Row(6) :: Nil) | ||
| // Case 2: count(*) = 0; could be rewritten to NOT EXISTS but currently uses | ||
| // a rewrite that is vulnerable to the COUNT bug | ||
| checkAnswer( | ||
| sql("select l.a from l where (select udf(count(*)) from r where l.a = r.c) = 0"), | ||
| Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil) | ||
| // Case 3: COUNT bug without a COUNT aggregate | ||
| checkAnswer( | ||
| sql("select l.a from l where (select udf(sum(r.d)) is null from r where l.a = r.c)"), | ||
| Row(1) :: Row(1) ::Row(null) :: Row(null) :: Row(6) :: Nil) | ||
| } | ||
|
|
||
| test("SPARK-28441: COUNT bug in SELECT clause (Project) with PythonUDF") { | ||
| import IntegratedUDFTestUtils._ | ||
|
|
||
| val pythonTestUDF = TestPythonUDF(name = "udf") | ||
| registerTestUDF(pythonTestUDF, spark) | ||
|
|
||
| checkAnswer( | ||
| sql("select a, (select udf(count(*)) from r where l.a = r.c) as cnt from l"), | ||
| Row(1, 0) :: Row(1, 0) :: Row(2, 2) :: Row(2, 2) :: Row(3, 1) :: Row(null, 0) | ||
| :: Row(null, 0) :: Row(6, 1) :: Nil) | ||
| } | ||
|
|
||
| test("SPARK-28441: COUNT bug in HAVING clause (Filter) with PythonUDF") { | ||
| import IntegratedUDFTestUtils._ | ||
|
|
||
| val pythonTestUDF = TestPythonUDF(name = "udf") | ||
| registerTestUDF(pythonTestUDF, spark) | ||
|
|
||
| checkAnswer( | ||
| sql("select l.a as grp_a from l group by l.a " + | ||
| "having (select udf(count(*)) from r where grp_a = r.c) = 0 " + | ||
| "order by grp_a"), | ||
| Row(null) :: Row(1) :: Nil) | ||
| } | ||
|
|
||
| test("SPARK-28441: COUNT bug in Aggregate with PythonUDF") { | ||
| import IntegratedUDFTestUtils._ | ||
|
|
||
| val pythonTestUDF = TestPythonUDF(name = "udf") | ||
| registerTestUDF(pythonTestUDF, spark) | ||
|
|
||
| checkAnswer( | ||
| sql("select l.a as aval, sum((select udf(count(*)) from r where l.a = r.c)) as cnt " + | ||
| "from l group by l.a order by aval"), | ||
| Row(null, 0) :: Row(1, 0) :: Row(2, 4) :: Row(3, 1) :: Row(6, 1) :: Nil) | ||
| } | ||
|
|
||
| test("SPARK-28441: COUNT bug negative examples with PythonUDF") { | ||
| import IntegratedUDFTestUtils._ | ||
|
|
||
| val pythonTestUDF = TestPythonUDF(name = "udf") | ||
| registerTestUDF(pythonTestUDF, spark) | ||
|
|
||
| // Case 1: Potential COUNT bug case that was working correctly prior to the fix | ||
| checkAnswer( | ||
| sql("select l.a from l where (select udf(sum(r.d)) from r where l.a = r.c) is null"), | ||
| Row(1) :: Row(1) :: Row(null) :: Row(null) :: Row(6) :: Nil) | ||
| // Case 2: COUNT aggregate but no COUNT bug due to > 0 test. | ||
| checkAnswer( | ||
| sql("select l.a from l where (select udf(count(*)) from r where l.a = r.c) > 0"), | ||
| Row(2) :: Row(2) :: Row(3) :: Row(6) :: Nil) | ||
| // Case 3: COUNT inside aggregate expression but no COUNT bug. | ||
| checkAnswer( | ||
| sql("select l.a from l where (select udf(count(*)) + udf(sum(r.d)) " + | ||
| "from r where l.a = r.c) = 0"), | ||
| Nil) | ||
| } | ||
|
|
||
| test("SPARK-28441: COUNT bug in subquery in subquery in subquery with PythonUDF") { | ||
|
||
| import IntegratedUDFTestUtils._ | ||
|
|
||
| val pythonTestUDF = TestPythonUDF(name = "udf") | ||
| registerTestUDF(pythonTestUDF, spark) | ||
|
|
||
| checkAnswer( | ||
| sql("""select l.a from l | ||
| |where ( | ||
| | select cntPlusOne + 1 as cntPlusTwo from ( | ||
| | select cnt + 1 as cntPlusOne from ( | ||
| | select udf(sum(r.c)) s, udf(count(*)) cnt from r where l.a = r.c | ||
| | having cnt = 0 | ||
| | ) | ||
| | ) | ||
| |) = 2""".stripMargin), | ||
| Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil) | ||
| } | ||
|
|
||
| test("SPARK-28441: COUNT bug with nasty predicate expr with PythonUDF") { | ||
| import IntegratedUDFTestUtils._ | ||
|
|
||
| val pythonTestUDF = TestPythonUDF(name = "udf") | ||
| registerTestUDF(pythonTestUDF, spark) | ||
|
|
||
| checkAnswer( | ||
| sql("select l.a from l where " + | ||
| "(select case when udf(count(*)) = 1 then null else udf(count(*)) end as cnt " + | ||
| "from r where l.a = r.c) = 0"), | ||
|
||
| Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil) | ||
| } | ||
|
|
||
| test("SPARK-28441: COUNT bug with attribute ref in subquery input and output with PythonUDF") { | ||
| import IntegratedUDFTestUtils._ | ||
|
|
||
| val pythonTestUDF = TestPythonUDF(name = "udf") | ||
| registerTestUDF(pythonTestUDF, spark) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. BTW, we should add This skipping stuff isn't completely new in our test base. See |
||
|
|
||
| checkAnswer( | ||
| sql( | ||
| """ | ||
| |select l.b, (select (r.c + udf(count(*))) is null | ||
| |from r | ||
| |where l.a = r.c group by r.c) from l | ||
|
||
| """.stripMargin), | ||
| Row(1.0, false) :: Row(1.0, false) :: Row(2.0, true) :: Row(2.0, true) :: | ||
| Row(3.0, false) :: Row(5.0, true) :: Row(null, false) :: Row(null, true) :: Nil) | ||
| } | ||
|
|
||
| test("SPARK-28441: COUNT bug with non-foldable expression") { | ||
| // Case 1: Canonical example of the COUNT bug | ||
| checkAnswer( | ||
| sql("select l.a from l where (select count(*) + cast(rand() as int) from r " + | ||
| "where l.a = r.c) < l.a"), | ||
| Row(1) :: Row(1) :: Row(3) :: Row(6) :: Nil) | ||
| // Case 2: count(*) = 0; could be rewritten to NOT EXISTS but currently uses | ||
| // a rewrite that is vulnerable to the COUNT bug | ||
| checkAnswer( | ||
| sql("select l.a from l where (select count(*) + cast(rand() as int) from r " + | ||
| "where l.a = r.c) = 0"), | ||
| Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil) | ||
| // Case 3: COUNT bug without a COUNT aggregate | ||
| checkAnswer( | ||
| sql("select l.a from l where (select sum(r.d) is null from r " + | ||
| "where l.a = r.c)"), | ||
| Row(1) :: Row(1) ::Row(null) :: Row(null) :: Row(6) :: Nil) | ||
| } | ||
|
|
||
| test("SPARK-28441: COUNT bug in subquery in subquery in subquery with non-foldable expr") { | ||
|
||
| checkAnswer( | ||
| sql("""select l.a from l | ||
| |where ( | ||
| | select cntPlusOne + 1 as cntPlusTwo from ( | ||
| | select cnt + 1 as cntPlusOne from ( | ||
| | select sum(r.c) s, (count(*) + cast(rand() as int)) cnt from r | ||
| | where l.a = r.c having cnt = 0 | ||
| | ) | ||
| | ) | ||
| |) = 2""".stripMargin), | ||
| Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil) | ||
| } | ||
|
|
||
| test("SPARK-28441: COUNT bug with non-foldable expression in Filter condition") { | ||
| val df = sql("""select l.a from l | ||
| |where ( | ||
| | select cntPlusOne + 1 as cntPlusTwo from ( | ||
| | select cnt + 1 as cntPlusOne from ( | ||
| | select sum(r.c) s, count(*) cnt from r | ||
| | where l.a = r.c having cnt > 0 | ||
| | ) | ||
| | ) | ||
| |) = 2""".stripMargin) | ||
| val df2 = sql("""select l.a from l | ||
| |where ( | ||
| | select cntPlusOne + 1 as cntPlusTwo from ( | ||
| | select cnt + 1 as cntPlusOne from ( | ||
| | select sum(r.c) s, count(*) cnt from r | ||
| | where l.a = r.c having (cnt + cast(rand() as int)) > 0 | ||
| | ) | ||
| | ) | ||
| |) = 2""".stripMargin) | ||
| checkAnswer(df, df2) | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am not sure about this change. This may cause serious perf regression
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How can we remove project that's not attribute-only?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd say it was wrong previously, but if a project's output has the same expr IDs as its child, it's usually attribute-only.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Mmmmh... I may be missing something, but I'd imagine a case like this:
Before this change, would be optimized as:
while after it is not. Am I wrong?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In the above case, there is an Alias in the project list, so it's not an attribute-only project. I think it also creates a new attr c, so `p2.outputSet.subsetOf(child.outputSet)` is not met either.
I think the rules in `ColumnPruning` will trim `very_expensive_operation` in the end.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see now, sorry. Why do we need this? Seems an unrelated change to the fix in this PR, isn't it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, the issue was seen in a previous commit, 33441a3. It has been overwritten now.
We added a column for the COUNT bug. The column checks an always-true leading column, `alwaysTrueExpr`, and returns a special value if `alwaysTrueExpr` is null, to simulate the empty-input case.
This column reuses the expr ID of the original output in the subquery. In the non-foldable expression case, the added column, in a potential Project-Filter-Project, will be trimmed by `removeProjectBeforeFilter`, because the second project meets `p2.outputSet.subsetOf(child.outputSet)`.
My original fix was to create a new expr ID and replace the original expr ID with the new one in the subquery, which looks complicated. This seems a simple fix, and looks reasonable.