Skip to content
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,12 @@ abstract class UnaryExpression extends Expression {
}
}


/**
 * Extractor companion for [[UnaryExpression]], enabling patterns of the form
 * `UnaryExpression(child)`.
 */
object UnaryExpression {
  /** Always matches, yielding the `child` of the given unary expression. */
  def unapply(e: UnaryExpression): Option[Expression] = {
    val onlyChild = e.child
    Some(onlyChild)
  }
}


/**
* An expression with two inputs and one output. The output is by default evaluated to null
* if any input is evaluated to null.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -542,29 +542,41 @@ object PushFoldableIntoBranches extends Rule[LogicalPlan] with PredicateHelper {

/**
 * Walks every node of `plan` and rewrites its expressions with `transformExpressionsUp`.
 *
 * Each case matches a unary or binary expression whose operand is an `If` or a `CaseWhen` and
 * returns a copy of that conditional in which the outer expression has been rebuilt around each
 * branch value (and around the else value, when a `CaseWhen` has one). Unary cases skip `Alias`;
 * binary cases require the other operand to be `foldable`. Every case is additionally guarded by
 * `atMostOneUnfoldable` applied to the branch values.
 *
 * NOTE(review): in the binary cases the replacement lines appear twice, once using `makeCopy`
 * and once using `withNewChildren`. This looks like the removed/added lines of a rendered diff;
 * confirm which variant is current before compiling.
 */
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressionsUp {
// Unary expression (other than Alias) directly over an If: wrap each branch value in `u`.
case u @ UnaryExpression(i @ If(_, trueValue, falseValue))
if !u.isInstanceOf[Alias] && atMostOneUnfoldable(Seq(trueValue, falseValue)) =>
i.copy(
trueValue = u.withNewChildren(Array(trueValue)),
falseValue = u.withNewChildren(Array(falseValue)))

// Unary expression (other than Alias) directly over a CaseWhen: wrap every branch value and
// the optional else value in `u`; branch conditions (`_1`) are left untouched.
case u @ UnaryExpression(c @ CaseWhen(branches, elseValue))
if !u.isInstanceOf[Alias] && atMostOneUnfoldable(branches.map(_._2) ++ elseValue) =>
c.copy(
branches.map(e => e.copy(_2 = u.withNewChildren(Array(e._2)))),
elseValue.map(e => u.withNewChildren(Array(e))))

// Binary expression with an If on the left and a foldable right operand.
case b @ BinaryExpression(i @ If(_, trueValue, falseValue), right)
if right.foldable && atMostOneUnfoldable(Seq(trueValue, falseValue)) =>
i.copy(
trueValue = b.makeCopy(Array(trueValue, right)),
falseValue = b.makeCopy(Array(falseValue, right)))
trueValue = b.withNewChildren(Array(trueValue, right)),
falseValue = b.withNewChildren(Array(falseValue, right)))

// Binary expression with a foldable left operand and an If on the right; operand order is kept.
case b @ BinaryExpression(left, i @ If(_, trueValue, falseValue))
if left.foldable && atMostOneUnfoldable(Seq(trueValue, falseValue)) =>
i.copy(
trueValue = b.makeCopy(Array(left, trueValue)),
falseValue = b.makeCopy(Array(left, falseValue)))
trueValue = b.withNewChildren(Array(left, trueValue)),
falseValue = b.withNewChildren(Array(left, falseValue)))

// Binary expression with a CaseWhen on the left and a foldable right operand.
case b @ BinaryExpression(c @ CaseWhen(branches, elseValue), right)
if right.foldable && atMostOneUnfoldable(branches.map(_._2) ++ elseValue) =>
c.copy(
branches.map(e => e.copy(_2 = b.makeCopy(Array(e._2, right)))),
elseValue.map(e => b.makeCopy(Array(e, right))))
branches.map(e => e.copy(_2 = b.withNewChildren(Array(e._2, right)))),
elseValue.map(e => b.withNewChildren(Array(e, right))))

// Binary expression with a foldable left operand and a CaseWhen on the right.
case b @ BinaryExpression(left, c @ CaseWhen(branches, elseValue))
if left.foldable && atMostOneUnfoldable(branches.map(_._2) ++ elseValue) =>
c.copy(
branches.map(e => e.copy(_2 = b.makeCopy(Array(left, e._2)))),
elseValue.map(e => b.makeCopy(Array(left, e))))
branches.map(e => e.copy(_2 = b.withNewChildren(Array(left, e._2)))),
elseValue.map(e => b.withNewChildren(Array(left, e))))
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,14 @@ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLite
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.types.{BooleanType, IntegerType}
import org.apache.spark.sql.types.{BooleanType, IntegerType, StringType}


class PushFoldableIntoBranchesSuite
extends PlanTest with ExpressionEvalHelper with PredicateHelper {

// Rule executor with a single batch, run with FixedPoint(50), that applies
// BooleanSimplification, ConstantFolding, SimplifyConditionals and PushFoldableIntoBranches.
// NOTE(review): `val batches` appears twice below with different batch names; this looks like
// the removed/added pair of a rendered diff — confirm which line is current.
object Optimize extends RuleExecutor[LogicalPlan] {
val batches = Batch("PushFoldableIntoBranches", FixedPoint(50),
val batches = Batch("PushFoldableIntoBranchesSuite", FixedPoint(50),
BooleanSimplification, ConstantFolding, SimplifyConditionals, PushFoldableIntoBranches) :: Nil
}

Expand Down Expand Up @@ -222,4 +222,41 @@ class PushFoldableIntoBranchesSuite
assertEquivalent(EqualTo(Literal(4), ifExp), FalseLiteral)
assertEquivalent(EqualTo(Literal(4), caseWhen), FalseLiteral)
}

test("Push down cast through If/CaseWhen") {
  // If: the first two expectations carry the cast on the individual branch values; in the
  // third the expected expression is the same as the input.
  val ifBothLiterals = If(a, Literal(2), Literal(3))
  assertEquivalent(ifBothLiterals.cast(StringType), If(a, Literal("2"), Literal("3")))

  val ifOneLiteral = If(a, b, Literal(3))
  assertEquivalent(ifOneLiteral.cast(StringType), If(a, b.cast(StringType), Literal("3")))

  // A `def` so that input and expectation are built as two separate instances.
  def ifNoLiteral = If(a, b, b + 1).cast(StringType)
  assertEquivalent(ifNoLiteral, ifNoLiteral)

  // CaseWhen: same three shapes as above, with the else value playing the second branch.
  val caseBothLiterals = CaseWhen(Seq((a, Literal(1))), Some(Literal(3)))
  assertEquivalent(
    caseBothLiterals.cast(StringType),
    CaseWhen(Seq((a, Literal("1"))), Some(Literal("3"))))

  val caseOneLiteral = CaseWhen(Seq((a, Literal(1))), Some(b))
  assertEquivalent(
    caseOneLiteral.cast(StringType),
    CaseWhen(Seq((a, Literal("1"))), Some(b.cast(StringType))))

  def caseNoLiteral = CaseWhen(Seq((a, b)), Some(b + 1)).cast(StringType)
  assertEquivalent(caseNoLiteral, caseNoLiteral)
}

test("Push down abs through If/CaseWhen") {
  // Abs over an If with two negative literals is expected to become an If over the positives.
  val negativeIf = If(a, Literal(-2), Literal(-3))
  val positiveIf = If(a, Literal(2), Literal(3))
  assertEquivalent(Abs(negativeIf), positiveIf)

  // Same expectation for a CaseWhen with one branch and an else value.
  val negativeCase = CaseWhen(Seq((a, Literal(-1))), Some(Literal(-3)))
  val positiveCase = CaseWhen(Seq((a, Literal(1))), Some(Literal(3)))
  assertEquivalent(Abs(negativeCase), positiveCase)
}

test("Push down cast with binary expression through If/CaseWhen") {
  // Right-hand side of every comparison; a `def` so each use gets its own literal instance.
  def four = Literal("4")

  // With an else value (or an If), the whole comparison is expected to reduce to false.
  val castedIf = If(a, Literal(2), Literal(3)).cast(StringType)
  assertEquivalent(EqualTo(castedIf, four), FalseLiteral)

  val castedCaseWithElse = CaseWhen(Seq((a, Literal(1))), Some(Literal(3))).cast(StringType)
  assertEquivalent(EqualTo(castedCaseWithElse, four), FalseLiteral)

  // Without an else value, the expectation keeps the CaseWhen with each branch set to false.
  val castedCaseNoElse = CaseWhen(Seq((a, Literal(1)), (c, Literal(2))), None).cast(StringType)
  assertEquivalent(
    EqualTo(castedCaseNoElse, four),
    CaseWhen(Seq((a, FalseLiteral), (c, FalseLiteral)), None))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -130,24 +130,24 @@ Input [6]: [inv_warehouse_sk#3, inv_quantity_on_hand#4, i_item_id#6, d_date#10,
(23) HashAggregate [codegen id : 4]
Input [4]: [inv_quantity_on_hand#4, w_warehouse_name#13, i_item_id#6, d_date#10]
Keys [2]: [w_warehouse_name#13, i_item_id#6]
Functions [2]: [partial_sum(cast(CASE WHEN (d_date#10 < 11027) THEN inv_quantity_on_hand#4 ELSE 0 END as bigint)), partial_sum(cast(CASE WHEN (d_date#10 >= 11027) THEN inv_quantity_on_hand#4 ELSE 0 END as bigint))]
Functions [2]: [partial_sum(CASE WHEN (d_date#10 < 11027) THEN cast(inv_quantity_on_hand#4 as bigint) ELSE 0 END), partial_sum(CASE WHEN (d_date#10 >= 11027) THEN cast(inv_quantity_on_hand#4 as bigint) ELSE 0 END)]
Aggregate Attributes [2]: [sum#15, sum#16]
Results [4]: [w_warehouse_name#13, i_item_id#6, sum#17, sum#18]

(24) Exchange
Input [4]: [w_warehouse_name#13, i_item_id#6, sum#17, sum#18]
Arguments: hashpartitioning(w_warehouse_name#13, i_item_id#6, 5), true, [id=#19]
Arguments: hashpartitioning(w_warehouse_name#13, i_item_id#6, 5), ENSURE_REQUIREMENTS, [id=#19]

(25) HashAggregate [codegen id : 5]
Input [4]: [w_warehouse_name#13, i_item_id#6, sum#17, sum#18]
Keys [2]: [w_warehouse_name#13, i_item_id#6]
Functions [2]: [sum(cast(CASE WHEN (d_date#10 < 11027) THEN inv_quantity_on_hand#4 ELSE 0 END as bigint)), sum(cast(CASE WHEN (d_date#10 >= 11027) THEN inv_quantity_on_hand#4 ELSE 0 END as bigint))]
Aggregate Attributes [2]: [sum(cast(CASE WHEN (d_date#10 < 11027) THEN inv_quantity_on_hand#4 ELSE 0 END as bigint))#20, sum(cast(CASE WHEN (d_date#10 >= 11027) THEN inv_quantity_on_hand#4 ELSE 0 END as bigint))#21]
Results [4]: [w_warehouse_name#13, i_item_id#6, sum(cast(CASE WHEN (d_date#10 < 11027) THEN inv_quantity_on_hand#4 ELSE 0 END as bigint))#20 AS inv_before#22, sum(cast(CASE WHEN (d_date#10 >= 11027) THEN inv_quantity_on_hand#4 ELSE 0 END as bigint))#21 AS inv_after#23]
Functions [2]: [sum(CASE WHEN (d_date#10 < 11027) THEN cast(inv_quantity_on_hand#4 as bigint) ELSE 0 END), sum(CASE WHEN (d_date#10 >= 11027) THEN cast(inv_quantity_on_hand#4 as bigint) ELSE 0 END)]
Aggregate Attributes [2]: [sum(CASE WHEN (d_date#10 < 11027) THEN cast(inv_quantity_on_hand#4 as bigint) ELSE 0 END)#20, sum(CASE WHEN (d_date#10 >= 11027) THEN cast(inv_quantity_on_hand#4 as bigint) ELSE 0 END)#21]
Results [4]: [w_warehouse_name#13, i_item_id#6, sum(CASE WHEN (d_date#10 < 11027) THEN cast(inv_quantity_on_hand#4 as bigint) ELSE 0 END)#20 AS inv_before#22, sum(CASE WHEN (d_date#10 >= 11027) THEN cast(inv_quantity_on_hand#4 as bigint) ELSE 0 END)#21 AS inv_after#23]

(26) Filter [codegen id : 5]
Input [4]: [w_warehouse_name#13, i_item_id#6, inv_before#22, inv_after#23]
Condition : ((CASE WHEN (inv_before#22 > 0) THEN (cast(inv_after#23 as double) / cast(inv_before#22 as double)) ELSE null END >= 0.666667) AND (CASE WHEN (inv_before#22 > 0) THEN (cast(inv_after#23 as double) / cast(inv_before#22 as double)) ELSE null END <= 1.5))
Condition : (CASE WHEN (inv_before#22 > 0) THEN ((cast(inv_after#23 as double) / cast(inv_before#22 as double)) >= 0.666667) ELSE false END AND CASE WHEN (inv_before#22 > 0) THEN ((cast(inv_after#23 as double) / cast(inv_before#22 as double)) <= 1.5) ELSE false END)

(27) TakeOrderedAndProject
Input [4]: [w_warehouse_name#13, i_item_id#6, inv_before#22, inv_after#23]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
TakeOrderedAndProject [w_warehouse_name,i_item_id,inv_before,inv_after]
WholeStageCodegen (5)
Filter [inv_before,inv_after]
HashAggregate [w_warehouse_name,i_item_id,sum,sum] [sum(cast(CASE WHEN (d_date < 11027) THEN inv_quantity_on_hand ELSE 0 END as bigint)),sum(cast(CASE WHEN (d_date >= 11027) THEN inv_quantity_on_hand ELSE 0 END as bigint)),inv_before,inv_after,sum,sum]
HashAggregate [w_warehouse_name,i_item_id,sum,sum] [sum(CASE WHEN (d_date < 11027) THEN cast(inv_quantity_on_hand as bigint) ELSE 0 END),sum(CASE WHEN (d_date >= 11027) THEN cast(inv_quantity_on_hand as bigint) ELSE 0 END),inv_before,inv_after,sum,sum]
InputAdapter
Exchange [w_warehouse_name,i_item_id] #1
WholeStageCodegen (4)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,24 +130,24 @@ Input [6]: [inv_date_sk#1, inv_quantity_on_hand#4, w_warehouse_name#6, i_item_id
(23) HashAggregate [codegen id : 4]
Input [4]: [inv_quantity_on_hand#4, w_warehouse_name#6, i_item_id#9, d_date#13]
Keys [2]: [w_warehouse_name#6, i_item_id#9]
Functions [2]: [partial_sum(cast(CASE WHEN (d_date#13 < 11027) THEN inv_quantity_on_hand#4 ELSE 0 END as bigint)), partial_sum(cast(CASE WHEN (d_date#13 >= 11027) THEN inv_quantity_on_hand#4 ELSE 0 END as bigint))]
Functions [2]: [partial_sum(CASE WHEN (d_date#13 < 11027) THEN cast(inv_quantity_on_hand#4 as bigint) ELSE 0 END), partial_sum(CASE WHEN (d_date#13 >= 11027) THEN cast(inv_quantity_on_hand#4 as bigint) ELSE 0 END)]
Aggregate Attributes [2]: [sum#15, sum#16]
Results [4]: [w_warehouse_name#6, i_item_id#9, sum#17, sum#18]

(24) Exchange
Input [4]: [w_warehouse_name#6, i_item_id#9, sum#17, sum#18]
Arguments: hashpartitioning(w_warehouse_name#6, i_item_id#9, 5), true, [id=#19]
Arguments: hashpartitioning(w_warehouse_name#6, i_item_id#9, 5), ENSURE_REQUIREMENTS, [id=#19]

(25) HashAggregate [codegen id : 5]
Input [4]: [w_warehouse_name#6, i_item_id#9, sum#17, sum#18]
Keys [2]: [w_warehouse_name#6, i_item_id#9]
Functions [2]: [sum(cast(CASE WHEN (d_date#13 < 11027) THEN inv_quantity_on_hand#4 ELSE 0 END as bigint)), sum(cast(CASE WHEN (d_date#13 >= 11027) THEN inv_quantity_on_hand#4 ELSE 0 END as bigint))]
Aggregate Attributes [2]: [sum(cast(CASE WHEN (d_date#13 < 11027) THEN inv_quantity_on_hand#4 ELSE 0 END as bigint))#20, sum(cast(CASE WHEN (d_date#13 >= 11027) THEN inv_quantity_on_hand#4 ELSE 0 END as bigint))#21]
Results [4]: [w_warehouse_name#6, i_item_id#9, sum(cast(CASE WHEN (d_date#13 < 11027) THEN inv_quantity_on_hand#4 ELSE 0 END as bigint))#20 AS inv_before#22, sum(cast(CASE WHEN (d_date#13 >= 11027) THEN inv_quantity_on_hand#4 ELSE 0 END as bigint))#21 AS inv_after#23]
Functions [2]: [sum(CASE WHEN (d_date#13 < 11027) THEN cast(inv_quantity_on_hand#4 as bigint) ELSE 0 END), sum(CASE WHEN (d_date#13 >= 11027) THEN cast(inv_quantity_on_hand#4 as bigint) ELSE 0 END)]
Aggregate Attributes [2]: [sum(CASE WHEN (d_date#13 < 11027) THEN cast(inv_quantity_on_hand#4 as bigint) ELSE 0 END)#20, sum(CASE WHEN (d_date#13 >= 11027) THEN cast(inv_quantity_on_hand#4 as bigint) ELSE 0 END)#21]
Results [4]: [w_warehouse_name#6, i_item_id#9, sum(CASE WHEN (d_date#13 < 11027) THEN cast(inv_quantity_on_hand#4 as bigint) ELSE 0 END)#20 AS inv_before#22, sum(CASE WHEN (d_date#13 >= 11027) THEN cast(inv_quantity_on_hand#4 as bigint) ELSE 0 END)#21 AS inv_after#23]

(26) Filter [codegen id : 5]
Input [4]: [w_warehouse_name#6, i_item_id#9, inv_before#22, inv_after#23]
Condition : ((CASE WHEN (inv_before#22 > 0) THEN (cast(inv_after#23 as double) / cast(inv_before#22 as double)) ELSE null END >= 0.666667) AND (CASE WHEN (inv_before#22 > 0) THEN (cast(inv_after#23 as double) / cast(inv_before#22 as double)) ELSE null END <= 1.5))
Condition : (CASE WHEN (inv_before#22 > 0) THEN ((cast(inv_after#23 as double) / cast(inv_before#22 as double)) >= 0.666667) ELSE false END AND CASE WHEN (inv_before#22 > 0) THEN ((cast(inv_after#23 as double) / cast(inv_before#22 as double)) <= 1.5) ELSE false END)

(27) TakeOrderedAndProject
Input [4]: [w_warehouse_name#6, i_item_id#9, inv_before#22, inv_after#23]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
TakeOrderedAndProject [w_warehouse_name,i_item_id,inv_before,inv_after]
WholeStageCodegen (5)
Filter [inv_before,inv_after]
HashAggregate [w_warehouse_name,i_item_id,sum,sum] [sum(cast(CASE WHEN (d_date < 11027) THEN inv_quantity_on_hand ELSE 0 END as bigint)),sum(cast(CASE WHEN (d_date >= 11027) THEN inv_quantity_on_hand ELSE 0 END as bigint)),inv_before,inv_after,sum,sum]
HashAggregate [w_warehouse_name,i_item_id,sum,sum] [sum(CASE WHEN (d_date < 11027) THEN cast(inv_quantity_on_hand as bigint) ELSE 0 END),sum(CASE WHEN (d_date >= 11027) THEN cast(inv_quantity_on_hand as bigint) ELSE 0 END),inv_before,inv_after,sum,sum]
InputAdapter
Exchange [w_warehouse_name,i_item_id] #1
WholeStageCodegen (4)
Expand Down
Loading