diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index b2625bddeecf4..5212468840cd5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -548,41 +548,68 @@ object PushFoldableIntoBranches extends Rule[LogicalPlan] with PredicateHelper { foldables.nonEmpty && others.length < 2 } + // Allow-list: only these UnaryExpressions are pushed into (if / case) branches (Alias is not). + private def supportedUnaryExpression(e: UnaryExpression): Boolean = e match { + case _: IsNull | _: IsNotNull => true + case _: UnaryMathExpression | _: Abs | _: Bin | _: Factorial | _: Hex => true + case _: String2StringExpression | _: Ascii | _: Base64 | _: BitLength | _: Chr | _: Length => + true + case _: CastBase => true + case _: GetDateField | _: LastDay => true + case _: ExtractIntervalPart => true + case _: ArraySetLike => true + case _: ExtractValue => true + case _ => false + } + + // Allow-list: only these BinaryExpressions are pushed into (if / case) branches. + private def supportedBinaryExpression(e: BinaryExpression): Boolean = e match { + case _: BinaryComparison | _: StringPredicate | _: StringRegexExpression => true + case _: BinaryArithmetic => true + case _: BinaryMathExpression => true + case _: AddMonths | _: DateAdd | _: DateAddInterval | _: DateDiff | _: DateSub => true + case _: FindInSet | _: RoundBase => true + case _ => false + } + def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsUp { - case a: Alias => a // Skip an alias. 
case u @ UnaryExpression(i @ If(_, trueValue, falseValue)) - if atMostOneUnfoldable(Seq(trueValue, falseValue)) => + if supportedUnaryExpression(u) && atMostOneUnfoldable(Seq(trueValue, falseValue)) => i.copy( trueValue = u.withNewChildren(Array(trueValue)), falseValue = u.withNewChildren(Array(falseValue))) case u @ UnaryExpression(c @ CaseWhen(branches, elseValue)) - if atMostOneUnfoldable(branches.map(_._2) ++ elseValue) => + if supportedUnaryExpression(u) && atMostOneUnfoldable(branches.map(_._2) ++ elseValue) => c.copy( branches.map(e => e.copy(_2 = u.withNewChildren(Array(e._2)))), elseValue.map(e => u.withNewChildren(Array(e)))) case b @ BinaryExpression(i @ If(_, trueValue, falseValue), right) - if right.foldable && atMostOneUnfoldable(Seq(trueValue, falseValue)) => + if supportedBinaryExpression(b) && right.foldable && + atMostOneUnfoldable(Seq(trueValue, falseValue)) => i.copy( trueValue = b.withNewChildren(Array(trueValue, right)), falseValue = b.withNewChildren(Array(falseValue, right))) case b @ BinaryExpression(left, i @ If(_, trueValue, falseValue)) - if left.foldable && atMostOneUnfoldable(Seq(trueValue, falseValue)) => + if supportedBinaryExpression(b) && left.foldable && + atMostOneUnfoldable(Seq(trueValue, falseValue)) => i.copy( trueValue = b.withNewChildren(Array(left, trueValue)), falseValue = b.withNewChildren(Array(left, falseValue))) case b @ BinaryExpression(c @ CaseWhen(branches, elseValue), right) - if right.foldable && atMostOneUnfoldable(branches.map(_._2) ++ elseValue) => + if supportedBinaryExpression(b) && right.foldable && + atMostOneUnfoldable(branches.map(_._2) ++ elseValue) => c.copy( branches.map(e => e.copy(_2 = b.withNewChildren(Array(e._2, right)))), elseValue.map(e => b.withNewChildren(Array(e, right)))) case b @ BinaryExpression(left, c @ CaseWhen(branches, elseValue)) - if left.foldable && atMostOneUnfoldable(branches.map(_._2) ++ elseValue) => + if supportedBinaryExpression(b) && left.foldable && + 
atMostOneUnfoldable(branches.map(_._2) ++ elseValue) => c.copy( branches.map(e => e.copy(_2 = b.withNewChildren(Array(left, e._2)))), elseValue.map(e => b.withNewChildren(Array(left, e))))