@@ -742,42 +742,35 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelpe
742742 * Reorder associative integral-type operators and fold all constants into one.
743743 */
744744object ReorderAssociativeOperator extends Rule [LogicalPlan ] {
745- private def isAssociativelyFoldable (e : Expression ): Boolean =
746- e.deterministic && e.isInstanceOf [BinaryArithmetic ] && e.dataType.isInstanceOf [IntegralType ] &&
747- isSingleOperatorExpr(e)
748-
749- private def isSingleOperatorExpr (e : Expression ): Boolean = e.find {
750- case a : Add if a.getClass == e.getClass => false
751- case m : Multiply if m.getClass == e.getClass => false
752- case _ : BinaryArithmetic => true
753- case _ => false
754- }.isEmpty
745+ private def flattenAdd (e : Expression ): Seq [Expression ] = e match {
746+ case Add (l, r) => flattenAdd(l) ++ flattenAdd(r)
747+ case other => other :: Nil
748+ }
755749
756- private def getOperandList (e : Expression ): Seq [Expression ] = e match {
757- case BinaryArithmetic (a, b ) => getOperandList(a ) ++ getOperandList(b )
750+ private def flattenMultiply (e : Expression ): Seq [Expression ] = e match {
751+ case Multiply (l, r ) => flattenMultiply(l ) ++ flattenMultiply(r )
758752 case other => other :: Nil
759753 }
760754
761- def apply (plan : LogicalPlan ): LogicalPlan = plan transform {
762- case q : LogicalPlan => q transformExpressionsDown {
763- case e if isAssociativelyFoldable(e) =>
764- val (foldables, others) = getOperandList(e).partition(_.foldable)
765- if (foldables.size > 1 ) {
766- e match {
767- case a : Add =>
768- val foldableExpr = foldables.reduce((x, y) => Add (x, y))
769- val c = Literal .create(foldableExpr.eval(EmptyRow ), e.dataType)
770- if (others.isEmpty) c else Add (others.reduce((x, y) => Add (x, y)), c)
771- case m : Multiply =>
772- val foldableExpr = foldables.reduce((x, y) => Multiply (x, y))
773- val c = Literal .create(foldableExpr.eval(EmptyRow ), e.dataType)
774- if (others.isEmpty) c else Multiply (others.reduce((x, y) => Multiply (x, y)), c)
775- case _ => e
776- }
777- } else {
778- e
779- }
780- }
755+ def apply (plan : LogicalPlan ): LogicalPlan = plan transformExpressionsDown {
756+ case a : Add if a.deterministic && a.dataType.isInstanceOf [IntegralType ] =>
757+ val (foldables, others) = flattenAdd(a).partition(_.foldable)
758+ if (foldables.size > 1 ) {
759+ val foldableExpr = foldables.reduce((x, y) => Add (x, y))
760+ val c = Literal .create(foldableExpr.eval(EmptyRow ), a.dataType)
761+ if (others.isEmpty) c else Add (others.reduce((x, y) => Add (x, y)), c)
762+ } else {
763+ a
764+ }
765+ case m : Multiply if m.deterministic && m.dataType.isInstanceOf [IntegralType ] =>
766+ val (foldables, others) = flattenMultiply(m).partition(_.foldable)
767+ if (foldables.size > 1 ) {
768+ val foldableExpr = foldables.reduce((x, y) => Multiply (x, y))
769+ val c = Literal .create(foldableExpr.eval(EmptyRow ), m.dataType)
770+ if (others.isEmpty) c else Multiply (others.reduce((x, y) => Multiply (x, y)), c)
771+ } else {
772+ m
773+ }
781774 }
782775}
783776
0 commit comments