Skip to content
Closed
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
OptimizeWindowFunctions,
CollapseWindow,
CombineFilters,
CombineLimits,
EliminateLimits,
CombineUnions,
// Constant folding and strength reduction
TransposeWindow,
Expand Down Expand Up @@ -1452,11 +1452,23 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
}

/**
* Combines two adjacent [[Limit]] operators into one, merging the
* expressions into one single expression.
* 1. Eliminates [[Limit]] operators if their child's max row count <= limit.
* 2. Combines two adjacent [[Limit]] operators into one, merging the
* expressions into one single expression.
*/
object CombineLimits extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
object EliminateLimits extends Rule[LogicalPlan] {
/**
 * Decides whether a limit is redundant for its child.
 *
 * @param limitExpr   the limit expression of the limit operator being inspected
 * @param childMaxRow the child's max row count, or None when it is not known
 * @return true when `limitExpr` is foldable and the known max row count of the child
 *         does not exceed the evaluated limit; false otherwise
 */
private def canEliminate(limitExpr: Expression, childMaxRow: Option[Long]): Boolean = {
  // `exists` is false for None, so the limit is only evaluated when it is foldable
  // and a max row count is actually known.
  // NOTE(review): assumes the evaluated limit is a non-null value whose string form
  // parses as an Int - confirm against the analyzer checks on limit expressions.
  limitExpr.foldable && childMaxRow.exists(_ <= limitExpr.eval().toString.toInt)
}

def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
case GlobalLimit(l, child) if canEliminate(l, child.maxRows) =>
child
case LocalLimit(l, child) if canEliminate(l, child.maxRows) =>
child

case GlobalLimit(le, GlobalLimit(ne, grandChild)) =>
GlobalLimit(Least(Seq(ne, le)), grandChild)
case LocalLimit(le, LocalLimit(ne, grandChild)) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ class CombiningLimitsSuite extends PlanTest {
Batch("Column Pruning", FixedPoint(100),
ColumnPruning,
RemoveNoopOperators) ::
Batch("Combine Limit", FixedPoint(10),
CombineLimits) ::
Batch("Eliminate Limit", FixedPoint(10),
EliminateLimits) ::
Batch("Constant Folding", FixedPoint(10),
NullPropagation,
ConstantFolding,
Expand Down Expand Up @@ -90,4 +90,22 @@ class CombiningLimitsSuite extends PlanTest {

comparePlans(optimized, correctAnswer)
}

test("SPARK-33442: Change Combine Limit to Eliminate limit using max row") {
  // Builds a fresh aggregate without grouping keys every time it is referenced,
  // so each plan below is constructed independently.
  def globalCount = testRelation.select().groupBy()(count(1))

  // Case 1: child max row <= limit, the limit is expected to be removed.
  val redundantLimit = globalCount.limit(1).analyze
  val afterRedundant = Optimize.execute(redundantLimit)
  val withoutLimit = globalCount.analyze
  comparePlans(afterRedundant, withoutLimit)

  // Case 2: child max row > limit, the plan is expected to stay unchanged.
  val effectiveLimit = globalCount.limit(0).analyze
  val afterEffective = Optimize.execute(effectiveLimit)
  comparePlans(afterEffective, effectiveLimit)

  // Case 3: child max row is none, the plan is expected to stay unchanged.
  val unknownMaxRow = testRelation.select(Symbol("a")).limit(1).analyze
  val afterUnknown = Optimize.execute(unknownMaxRow)
  comparePlans(afterUnknown, unknownMaxRow)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class LimitPushdownSuite extends PlanTest {
EliminateSubqueryAliases) ::
Batch("Limit pushdown", FixedPoint(100),
LimitPushDown,
CombineLimits,
EliminateLimits,
ConstantFolding,
BooleanSimplification) :: Nil
}
Expand Down