Skip to content
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class EquivalentExpressions {
}

// For each expression, the set of equivalent expressions.
// NOTE(review): this is a rendered diff, so the two declarations below appear to be the
// before (HashMap) and after (LinkedHashMap) versions of the same field — confirm against
// the real file. LinkedHashMap iterates in insertion order, whereas HashMap gives no ordering
// guarantee.
private val equivalenceMap = mutable.HashMap.empty[Expr, mutable.ArrayBuffer[Expression]]
private val equivalenceMap = mutable.LinkedHashMap.empty[Expr, mutable.ArrayBuffer[Expression]]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

after the new change, does it still need to be LinkedHashMap?

Copy link
Member Author

@viirya viirya May 19, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, it can be reverted back to HashMap if we are going to sort it anyway.


/**
* Adds each expression to this data structure, grouping them with existing equivalent
Expand Down Expand Up @@ -167,7 +167,25 @@ class EquivalentExpressions {
* Returns all the equivalent sets of expressions.
*/
def getAllEquivalentExprs: Seq[Seq[Expression]] = {
// NOTE(review): this is a rendered diff hunk, so the next two statements appear to be the
// before/after versions of the same line — confirm against the real file.
equivalenceMap.values.map(_.toSeq).toSeq
// Converts every group of equivalent expressions to a Seq and sorts the groups by their first
// member using ExpressionOrdering. `_.head` presumes every group is non-empty — TODO confirm.
equivalenceMap.values.map(_.toSeq).toSeq.sortBy(_.head)(new ExpressionOrdering)
}

/**
* Orders [[Expression]]s by parent/child relations: a child expression is smaller than
* its parent expression. If there are child-parent relationships among the subexpressions,
* we want the child expressions to come before their parent expressions, so that we can
* replace child expressions in parent expressions with subexpression evaluation.
*/
class ExpressionOrdering extends Ordering[Expression] {
// Result of compare(x, y), as written below:
//   0  when x and y are semantically equal,
//   1  when x.find locates a node that is semantically equal to y (x is treated as the parent),
//  -1  in every other case.
// NOTE(review): the final branch returns -1 both when x is a child of y and when x and y are
// unrelated, so compare(x, y) and compare(y, x) are both -1 for unrelated pairs. That is not
// antisymmetric, which the Ordering contract expects — confirm the sort is still reliable, or
// add an explicit `y.find(_.semanticEquals(x))` check before returning -1.
override def compare(x: Expression, y: Expression): Int = {
if (x.semanticEquals(y)) {
0
} else if (x.find(_.semanticEquals(y)).isDefined) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we run TPCDSQuerySuite and see the time of the query compilation phase? This looks like a very expensive sort.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok. Let me compare before/after this PR.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BTW, I think better approach is to sort after filter (e.g. size > 1 in most use-case), because the number of sub-exprs should be smaller.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed the call usage of getAllEquivalentExprs. So we filter it first and then do sorting.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ran TPCDSQuerySuite.

Before (master):

23.233160578 seconds 
22.501728011 seconds
23.547332524 seconds

After:

23.995751468 seconds 
22.262832936 seconds
21.503776059 seconds  

I don't see significant difference there.

1
} else {
-1
}
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,35 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
CodeGenerator.compile(code)
}
}

test("SPARK-35439: Children subexpr should come first than parent subexpr") {
  // Shared child expression `1 + 2`.
  val add = Add(Literal(1), Literal(2))
  // Builds a fresh parent expression `3 + (1 + 2)` on every use, as the inline
  // constructor calls did.
  def parent = Add(Literal(3), add)

  // Scenario 1: the child is registered before the parent.
  val childFirst = new EquivalentExpressions

  childFirst.addExprTree(add)
  assert(childFirst.getAllEquivalentExprs.head === Seq(add))

  childFirst.addExprTree(parent)
  assert(childFirst.getAllEquivalentExprs ===
    Seq(Seq(add, add), Seq(parent)))

  childFirst.addExprTree(parent)
  assert(childFirst.getAllEquivalentExprs ===
    Seq(Seq(add, add), Seq(parent, parent)))

  // Scenario 2: the parent is registered before the child; the child group must
  // still be listed first.
  val parentFirst = new EquivalentExpressions

  parentFirst.addExprTree(parent)
  assert(parentFirst.getAllEquivalentExprs === Seq(Seq(add), Seq(parent)))

  parentFirst.addExprTree(add)
  assert(parentFirst.getAllEquivalentExprs === Seq(Seq(add, add), Seq(parent)))

  parentFirst.addExprTree(parent)
  assert(parentFirst.getAllEquivalentExprs ===
    Seq(Seq(add, add), Seq(parent, parent)))
}
}

case class CodegenFallbackExpression(child: Expression)
Expand Down