@@ -82,21 +82,31 @@ class EquivalentExpressions {
/**
* Adds only expressions which are common in each of given expressions, in a recursive way.
* For example, given two expressions `(a + (b + (c + 1)))` and `(d + (e + (c + 1)))`,
- * the common expression `(c + 1)` will be added into `equivalenceMap`.
* the common expression `(c + 1)` will be added into `equivalenceMap`. Note that if an
* expression and its child expressions are all commonly occurred in each of given expressions,
* we filter out the child expressions. For example, if `((a + b) + c)` and `(a + b)` are
* common expressions, we only add `((a + b) + c)`.
Member

If redundant child expressions are also counted as common expressions, they will be evaluated redundantly and we miss the subexpression elimination opportunity.

Could you leave a comment here about why we need to filter out these exprs?

Member

Just a question: even if we filter out the redundant expr here (e.g., (a + b) in this case), can the suboptimal case this PR points out still happen if (a + b) is added as a common expression elsewhere? I was thinking of a query like this: Seq((1, 1, 1)).toDF("a", "b", "c").select(when($"a" + $"b" + $"c" > 0, $"a" + $"b" + $"c").when($"a" + $"b" + $"c" <= 0, $"a" + $"b")).

Member Author

The so-called common expressions must occur in all branches/values. So in the above case, (a + b) is actually the only common expression between the two values $"a" + $"b" + $"c" and $"a" + $"b".

Member Author

Updated the comment.

*/
private def addCommonExprs(
exprs: Seq[Expression],
addFunc: Expression => Boolean = addExpr): Unit = {
val exprSetForAll = mutable.Set[Expr]()
Contributor @Kimahriman, May 20, 2021

One potentially unrelated thing I just noticed: do we need to keep track of all of the Expressions here as well (as in an Expr -> Seq[Expression] map)? This basically keeps only the first Expression found, but the codegen looks like it uses the Expression hash (versus the semantic hash) to look up subexpressions. Very much an edge case; I'm just wondering if I'm understanding things correctly.

Member Author

You mean equivalenceMap?

Contributor

I don't mean add it directly to that here. I'm just thinking of a really stupid example, when((col + 1) > 0, col + 1).otherwise(1 + col). Wouldn't col + 1 and 1 + col resolve as a common expression because they're semantically equal, but only col + 1 is added to equivalenceMap, so during codegen 1 + col wouldn't be resolved to the subexpression?

Member Author

col + 1 and 1 + col will both be recognized as subexpressions.

Contributor

Yeah, but won't the codegen stage fail to replace 1 + col, since only col + 1 will be added to the equivalenceMap entry for Expr(col + 1)? For non-commonExprs cases, both would be in equivalenceMap, so the codegen stage maps both of those expressions to the resulting subexpression. Again, not super related to this PR, but this was the easiest place to ask.

Member Author

Both 1 + col and col + 1 will be replaced with the extracted subexpression during codegen. We don't just look up the key in equivalenceMap when replacing with the subexpression.
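
For readers following this thread, the idea of keying equivalent expressions by a semantic form rather than by their written form can be sketched outside Spark. This is only a toy model under stated assumptions: `Term`, `Col`, `Lit`, `Sum` and `SemanticKeySketch` are hypothetical names, and the operand-ordering rule below is a simplification that is not claimed to match how Spark canonicalizes expressions.

```scala
// Hypothetical toy model; these are NOT Spark's Expression classes.
sealed trait Term
case class Col(name: String) extends Term
case class Lit(value: Int) extends Term
case class Sum(left: Term, right: Term) extends Term

object SemanticKeySketch {
  // Canonical form: canonicalize both operands, then order the operands of the
  // commutative Sum by their string rendering so that operand order is ignored.
  def canonical(t: Term): Term = t match {
    case Sum(l, r) =>
      val cl = canonical(l)
      val cr = canonical(r)
      if (cl.toString.compareTo(cr.toString) <= 0) Sum(cl, cr) else Sum(cr, cl)
    case other => other
  }

  // Two terms are treated as semantically equal when their canonical forms match.
  def semanticEquals(a: Term, b: Term): Boolean = canonical(a) == canonical(b)

  // Group terms by canonical form, keeping every originally written form in the group.
  def groupEquivalent(terms: Seq[Term]): Map[Term, Seq[Term]] =
    terms.groupBy(canonical)
}
```

With this toy model, `Sum(Col("col"), Lit(1))` and `Sum(Lit(1), Col("col"))` fall into one group that still remembers both written forms, which is the shape of the `Expr -> Seq[Expression]` mapping asked about above.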

addExprTree(exprs.head, addExprToSet(_, exprSetForAll))

- val commonExprSet = exprs.tail.foldLeft(exprSetForAll) { (exprSet, expr) =>
val candidateExprs = exprs.tail.foldLeft(exprSetForAll) { (exprSet, expr) =>
val otherExprSet = mutable.Set[Expr]()
addExprTree(expr, addExprToSet(_, otherExprSet))
exprSet.intersect(otherExprSet)
}

- commonExprSet.foreach(expr => addFunc(expr.e))
// Not all expressions in the set should be added. We should filter out the subexprs.
Member

Do we need to revise line 83 consistently?

Member Author

Yea, revised the method comment. Thanks.

val commonExprSet = candidateExprs.filter { candidateExpr =>
candidateExprs.forall { expr =>
expr == candidateExpr || expr.e.find(_.semanticEquals(candidateExpr.e)).isEmpty
}
Member

Isn't this loop expensive? It seems the time complexity is O((the total number of expr nodes in candidateExprs) x (candidateExprs.size)^2)?

Contributor

+1, but I don't have a better idea now...

Member Author

Yea, I considered this part but couldn't come up with a better one.

Member @maropu, May 17, 2021

Yea, okay. I don't have an idea either... That was just a question.

}

commonExprSet.foreach(expr => addExprTree(expr.e, addFunc))
}
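
The two steps in the method above (intersect the candidate sets of all inputs, then drop any candidate nested inside another candidate) can be sketched on a toy tree. This is a minimal illustration under assumptions, not the Spark implementation: `Node`, `Leaf`, `Plus` and `CommonExprSketch` are hypothetical names, plain structural equality stands in for `semanticEquals`, and leaves are excluded from the candidates for brevity.

```scala
// Hypothetical toy stand-ins for Catalyst expressions.
sealed trait Node {
  def children: Seq[Node]
  // Every node in this subtree, including the node itself.
  def subtrees: Seq[Node] = this +: children.flatMap(_.subtrees)
}
case class Leaf(name: String) extends Node {
  val children: Seq[Node] = Nil
}
case class Plus(left: Node, right: Node) extends Node {
  val children: Seq[Node] = Seq(left, right)
}

object CommonExprSketch {
  // Non-leaf subtrees of `root`, deduplicated by structural equality.
  private def candidates(root: Node): Set[Node] =
    root.subtrees.filter(_.children.nonEmpty).toSet

  def commonExprs(exprs: Seq[Node]): Set[Node] = {
    require(exprs.nonEmpty, "need at least one expression")
    // Step 1: keep only subtrees that occur in every input expression.
    val inAll = exprs.map(candidates).reduce(_ intersect _)
    // Step 2: drop a candidate if some other candidate contains it.
    inAll.filter { c =>
      inAll.forall(other => other == c || !other.subtrees.contains(c))
    }
  }
}
```

For inputs `(((a + b) + c) + x)` and `(y + ((a + b) + c))`, step 1 leaves `((a + b) + c)` and `(a + b)`, and step 2 keeps only `((a + b) + c)`. Step 2 walks every candidate's subtree for every candidate, so the sketch has the same cost profile that is questioned in the review comments.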

// There are some special expressions that we should not recurse into all of its children.
@@ -309,6 +309,22 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
CodeGenerator.compile(code)
}
}

test("SPARK-35410: SubExpr elimination should not include redundant child exprs " +
"for conditional expressions") {
Contributor

Is this only a problem for conditional expressions?

Member Author

So far, that's the only one I can think of.

Contributor

Found a non-conditional example that still is an issue even with this update (a bit contrived, but I'm sure there's a real use case)

val myUdf = udf(() => {
  println("In UDF")
  1
}).withName("myUdf")
spark.range(1).withColumn("a", myUdf()).select(($"a" + $"a") / ($"a" + $"a")).show()

This generates subexpressions myUdf() and (myUdf() + myUdf()), even though only the second one is used.

Member Author

Thanks @Kimahriman. I see. Let me also look at it. It is a non-conditional case, but it looks similar. Let me see if it can be solved similarly.

Member Author

Oh, I figured it out. This might be an issue since we have sub-expr elimination. We also need to remove redundant child exprs for non-conditional cases.

Member Author

But the fix might be different. I will work on it locally and submit another fix for it.

Contributor

Any more thoughts on this? Was the subexpr sorting supposed to address this?

Member Author

It might need another fix. I'm working on it and will submit it after these PRs are merged.

val add1 = Add(Literal(1), Literal(2))
val add2 = Add(Literal(2), Literal(3))
val add3 = Add(add1, add2)
val condition = (GreaterThan(add3, Literal(3)), add3) :: Nil

val caseWhenExpr = CaseWhen(condition, None)
val equivalence = new EquivalentExpressions
equivalence.addExprTree(caseWhenExpr)

val commonExprs = equivalence.getAllEquivalentExprs.filter(_.size > 1)
assert(commonExprs.size == 1)
assert(commonExprs.head === Seq(add3, add3))
}
}

case class CodegenFallbackExpression(child: Expression)
19 changes: 19 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -2861,6 +2861,25 @@ class DataFrameSuite extends QueryTest
checkAnswer(result, Row(0, 0, 0, 0, 100))
}
}

test("SPARK-35410: SubExpr elimination should not include redundant child exprs " +
"for conditional expressions") {
val accum = sparkContext.longAccumulator("call")
val simpleUDF = udf((s: String) => {
accum.add(1)
s
})
val df1 = spark.range(5).select(when(functions.length(simpleUDF($"id")) > 0,
Contributor

I think the fix for https://issues.apache.org/jira/browse/SPARK-35449 will break this, since it's really a "bug" that the case value is included in subexpression resolution without an else value. Not a huge deal, I can try to fix in my follow up once this is merged

functions.length(simpleUDF($"id"))))
df1.collect()
assert(accum.value == 5)

val nondeterministicUDF = simpleUDF.asNondeterministic()
val df2 = spark.range(5).select(when(functions.length(nondeterministicUDF($"id")) > 0,
functions.length(nondeterministicUDF($"id"))))
df2.collect()
assert(accum.value == 15)
}
}

case class GroupByKey(a: Int, b: Int)