Skip to content

Commit 0acb157

Browse files
committed
Improve the code according to the comments.
1 parent 37bfa88 commit 0acb157

2 files changed

Lines changed: 29 additions & 32 deletions

File tree

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 25 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -742,42 +742,35 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelpe
742742
* Reorder associative integral-type operators and fold all constants into one.
743743
*/
744744
object ReorderAssociativeOperator extends Rule[LogicalPlan] {
745-
private def isAssociativelyFoldable(e: Expression): Boolean =
746-
e.deterministic && e.isInstanceOf[BinaryArithmetic] && e.dataType.isInstanceOf[IntegralType] &&
747-
isSingleOperatorExpr(e)
748-
749-
private def isSingleOperatorExpr(e: Expression): Boolean = e.find {
750-
case a: Add if a.getClass == e.getClass => false
751-
case m: Multiply if m.getClass == e.getClass => false
752-
case _: BinaryArithmetic => true
753-
case _ => false
754-
}.isEmpty
745+
private def flattenAdd(e: Expression): Seq[Expression] = e match {
746+
case Add(l, r) => flattenAdd(l) ++ flattenAdd(r)
747+
case other => other :: Nil
748+
}
755749

756-
private def getOperandList(e: Expression): Seq[Expression] = e match {
757-
case BinaryArithmetic(a, b) => getOperandList(a) ++ getOperandList(b)
750+
private def flattenMultiply(e: Expression): Seq[Expression] = e match {
751+
case Multiply(l, r) => flattenMultiply(l) ++ flattenMultiply(r)
758752
case other => other :: Nil
759753
}
760754

761-
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
762-
case q: LogicalPlan => q transformExpressionsDown {
763-
case e if isAssociativelyFoldable(e) =>
764-
val (foldables, others) = getOperandList(e).partition(_.foldable)
765-
if (foldables.size > 1) {
766-
e match {
767-
case a: Add =>
768-
val foldableExpr = foldables.reduce((x, y) => Add(x, y))
769-
val c = Literal.create(foldableExpr.eval(EmptyRow), e.dataType)
770-
if (others.isEmpty) c else Add(others.reduce((x, y) => Add(x, y)), c)
771-
case m: Multiply =>
772-
val foldableExpr = foldables.reduce((x, y) => Multiply(x, y))
773-
val c = Literal.create(foldableExpr.eval(EmptyRow), e.dataType)
774-
if (others.isEmpty) c else Multiply(others.reduce((x, y) => Multiply(x, y)), c)
775-
case _ => e
776-
}
777-
} else {
778-
e
779-
}
780-
}
755+
def apply(plan: LogicalPlan): LogicalPlan = plan transformExpressionsDown {
756+
case a: Add if a.deterministic && a.dataType.isInstanceOf[IntegralType] =>
757+
val (foldables, others) = flattenAdd(a).partition(_.foldable)
758+
if (foldables.size > 1) {
759+
val foldableExpr = foldables.reduce((x, y) => Add(x, y))
760+
val c = Literal.create(foldableExpr.eval(EmptyRow), a.dataType)
761+
if (others.isEmpty) c else Add(others.reduce((x, y) => Add(x, y)), c)
762+
} else {
763+
a
764+
}
765+
case m: Multiply if m.deterministic && m.dataType.isInstanceOf[IntegralType] =>
766+
val (foldables, others) = flattenMultiply(m).partition(_.foldable)
767+
if (foldables.size > 1) {
768+
val foldableExpr = foldables.reduce((x, y) => Multiply(x, y))
769+
val c = Literal.create(foldableExpr.eval(EmptyRow), m.dataType)
770+
if (others.isEmpty) c else Multiply(others.reduce((x, y) => Multiply(x, y)), c)
771+
} else {
772+
m
773+
}
781774
}
782775
}
783776

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReorderAssociativeOperatorSuite.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,9 @@ class ReorderAssociativeOperatorSuite extends PlanTest {
4040
.select(
4141
(Literal(3) + ((Literal(1) + 'a) + 2)) + 4,
4242
'b * 1 * 2 * 3 * 4,
43+
('b + 1) * 2 * 3 * 4,
4344
'a + 1 + 'b + 2 + 'c + 3,
45+
'a + 1 + 'b * 2 + 'c + 3,
4446
Rand(0) * 1 * 2 * 3 * 4)
4547

4648
val optimized = Optimize.execute(originalQuery.analyze)
@@ -50,7 +52,9 @@ class ReorderAssociativeOperatorSuite extends PlanTest {
5052
.select(
5153
('a + 10).as("((3 + ((1 + a) + 2)) + 4)"),
5254
('b * 24).as("((((b * 1) * 2) * 3) * 4)"),
55+
(('b + 1) * 24).as("((((b + 1) * 2) * 3) * 4)"),
5356
('a + 'b + 'c + 6).as("(((((a + 1) + b) + 2) + c) + 3)"),
57+
('a + 'b * 2 + 'c + 4).as("((((a + 1) + (b * 2)) + c) + 3)"),
5458
Rand(0) * 1 * 2 * 3 * 4)
5559
.analyze
5660

0 commit comments

Comments
 (0)