Skip to content

Commit 81e4828

Browse files
committed
Use separate rule and add more tests
1 parent 550ff99 commit 81e4828

7 files changed

Lines changed: 133 additions & 27 deletions

File tree

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
140140
operatorOptimizationBatch) :+
141141
Batch("Join Reorder", Once,
142142
CostBasedJoinReorder) :+
143+
Batch("Remove Redundant Sorts", Once,
144+
RemoveRedundantSorts) :+
143145
Batch("Decimal Optimizations", fixedPoint,
144146
DecimalAggregates) :+
145147
Batch("Object Expressions Optimization", fixedPoint,
@@ -730,8 +732,16 @@ object EliminateSorts extends Rule[LogicalPlan] {
730732
case s @ Sort(orders, _, child) if orders.isEmpty || orders.exists(_.child.foldable) =>
731733
val newOrders = orders.filterNot(_.child.foldable)
732734
if (newOrders.isEmpty) child else s.copy(order = newOrders)
733-
case Sort(orders, true, child) if child.isSorted && child.sortedOrder.get.zip(orders).forall {
734-
case (s1, s2) => s1.satisfies(s2) } =>
735+
}
736+
}
737+
738+
/**
739+
* Removes Sort operations on already sorted data
740+
*/
741+
object RemoveRedundantSorts extends Rule[LogicalPlan] {
742+
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
743+
case Sort(orders, true, child) if child.sortedOrder.nonEmpty
744+
&& child.sortedOrder.zip(orders).forall { case (s1, s2) => s1.satisfies(s2) } =>
735745
child
736746
}
737747
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -223,9 +223,7 @@ abstract class LogicalPlan
223223
/**
224224
* If the current plan contains sorted data, it contains the sorted order.
225225
*/
226-
def sortedOrder: Option[Seq[SortOrder]] = None
227-
228-
final def isSorted: Boolean = sortedOrder.isDefined
226+
def sortedOrder: Seq[SortOrder] = Nil
229227
}
230228

231229
/**
@@ -283,5 +281,5 @@ abstract class BinaryNode extends LogicalPlan {
283281
}
284282

285283
abstract class KeepOrderUnaryNode extends UnaryNode {
286-
override final def sortedOrder: Option[Seq[SortOrder]] = child.sortedOrder
284+
override final def sortedOrder: Seq[SortOrder] = child.sortedOrder
287285
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -470,7 +470,7 @@ case class Sort(
470470
child: LogicalPlan) extends UnaryNode {
471471
override def output: Seq[Attribute] = child.output
472472
override def maxRows: Option[Long] = child.maxRows
473-
override def sortedOrder: Option[Seq[SortOrder]] = Some(order)
473+
override def sortedOrder: Seq[SortOrder] = order
474474
}
475475

476476
/** Factory for constructing new `Range` nodes. */
@@ -524,6 +524,8 @@ case class Range(
524524
override def computeStats(): Statistics = {
525525
Statistics(sizeInBytes = LongType.defaultSize * numElements)
526526
}
527+
528+
override def sortedOrder: Seq[SortOrder] = output.map(a => SortOrder(a, Descending))
527529
}
528530

529531
case class Aggregate(
@@ -746,7 +748,7 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends KeepOr
746748
*
747749
* See [[Limit]] for more information.
748750
*/
749-
case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
751+
case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends KeepOrderUnaryNode {
750752
override def output: Seq[Attribute] = child.output
751753

752754
override def maxRowsPerPartition: Option[Long] = {
@@ -870,9 +872,9 @@ case class RepartitionByExpression(
870872
override def maxRows: Option[Long] = child.maxRows
871873
override def shuffle: Boolean = true
872874

873-
override def sortedOrder: Option[Seq[SortOrder]] = partitioning match {
874-
case RangePartitioning(sortedOrder, _) => Some(sortedOrder)
875-
case _ => None
875+
override def sortedOrder: Seq[SortOrder] = partitioning match {
876+
case RangePartitioning(sortedOrder, _) => sortedOrder
877+
case _ => Nil
876878
}
877879
}
878880

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,7 @@ class EliminateSortsSuite extends PlanTest {
3737
val batches =
3838
Batch("Eliminate Sorts", FixedPoint(10),
3939
FoldablePropagation,
40-
EliminateSorts,
41-
CollapseProject) :: Nil
40+
EliminateSorts) :: Nil
4241
}
4342

4443
val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
@@ -84,16 +83,4 @@ class EliminateSortsSuite extends PlanTest {
8483

8584
comparePlans(optimized, correctAnswer)
8685
}
87-
88-
test("remove redundant order by") {
89-
val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc_nullsFirst)
90-
val unnecessaryReordered = orderedPlan.select('a).orderBy('a.asc, 'b.desc_nullsFirst)
91-
val optimized = Optimize.execute(analyzer.execute(unnecessaryReordered))
92-
val correctAnswer = analyzer.execute(orderedPlan.select('a))
93-
comparePlans(Optimize.execute(optimized), correctAnswer)
94-
val reorderedDifferently = orderedPlan.select('a).orderBy('a.asc, 'b.desc)
95-
val nonOptimized = Optimize.execute(analyzer.execute(reorderedDifferently))
96-
val correctAnswerNonOptimized = analyzer.execute(reorderedDifferently)
97-
comparePlans(nonOptimized, correctAnswerNonOptimized)
98-
}
9986
}
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.optimizer
19+
20+
import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
21+
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
22+
import org.apache.spark.sql.catalyst.dsl.expressions._
23+
import org.apache.spark.sql.catalyst.dsl.plans._
24+
import org.apache.spark.sql.catalyst.expressions._
25+
import org.apache.spark.sql.catalyst.plans._
26+
import org.apache.spark.sql.catalyst.plans.logical._
27+
import org.apache.spark.sql.catalyst.rules._
28+
import org.apache.spark.sql.internal.SQLConf
29+
import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, ORDER_BY_ORDINAL}
30+
31+
class RemoveRedundantSortsSuite extends PlanTest {
32+
override val conf = new SQLConf().copy(CASE_SENSITIVE -> true, ORDER_BY_ORDINAL -> false)
33+
val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
34+
val analyzer = new Analyzer(catalog, conf)
35+
36+
object Optimize extends RuleExecutor[LogicalPlan] {
37+
val batches =
38+
Batch("Remove Redundant Sorts", Once,
39+
RemoveRedundantSorts) ::
40+
Batch("Collapse Project", Once,
41+
CollapseProject) :: Nil
42+
}
43+
44+
val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
45+
46+
test("remove redundant order by") {
47+
val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc_nullsFirst)
48+
val unnecessaryReordered = orderedPlan.select('a).orderBy('a.asc, 'b.desc_nullsFirst)
49+
val optimized = Optimize.execute(analyzer.execute(unnecessaryReordered))
50+
val correctAnswer = analyzer.execute(orderedPlan.select('a))
51+
comparePlans(Optimize.execute(optimized), correctAnswer)
52+
}
53+
54+
test("do not remove sort if the order is different") {
55+
val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc_nullsFirst)
56+
val reorderedDifferently = orderedPlan.select('a).orderBy('a.asc, 'b.desc)
57+
val optimized = Optimize.execute(analyzer.execute(reorderedDifferently))
58+
val correctAnswer = analyzer.execute(reorderedDifferently)
59+
comparePlans(optimized, correctAnswer)
60+
}
61+
62+
test("filters don't affect order") {
63+
val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc)
64+
val filteredAndReordered = orderedPlan.where('a > Literal(10)).orderBy('a.asc, 'b.desc)
65+
val optimized = Optimize.execute(analyzer.execute(filteredAndReordered))
66+
val correctAnswer = analyzer.execute(orderedPlan.where('a > Literal(10)))
67+
comparePlans(optimized, correctAnswer)
68+
}
69+
70+
test("limits don't affect order") {
71+
val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc)
72+
val filteredAndReordered = orderedPlan.limit(Literal(10)).orderBy('a.asc, 'b.desc)
73+
val optimized = Optimize.execute(analyzer.execute(filteredAndReordered))
74+
val correctAnswer = analyzer.execute(orderedPlan.limit(Literal(10)))
75+
comparePlans(optimized, correctAnswer)
76+
}
77+
78+
test("range is already sorted") {
79+
val inputPlan = Range(1L, 1000L, 1, 10)
80+
val orderedPlan = inputPlan.orderBy('id.desc)
81+
val optimized = Optimize.execute(analyzer.execute(orderedPlan))
82+
val correctAnswer = analyzer.execute(inputPlan)
83+
comparePlans(optimized, correctAnswer)
84+
}
85+
86+
test("sort should not be removed when there is a node which doesn't guarantee any order") {
87+
val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc)
88+
val groupedAndResorted = orderedPlan.groupBy('a)(sum('a)).orderBy('a.asc)
89+
val optimized = Optimize.execute(analyzer.execute(groupedAndResorted))
90+
val correctAnswer = analyzer.execute(groupedAndResorted)
91+
comparePlans(optimized, correctAnswer)
92+
}
93+
}

sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,4 +169,6 @@ case class InMemoryRelation(
169169

170170
override protected def otherCopyArgs: Seq[AnyRef] =
171171
Seq(_cachedColumnBuffers, sizeInBytesStats, statsOfPlanToCache)
172+
173+
override def sortedOrder: Seq[SortOrder] = child.outputOrdering
172174
}

sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,11 @@ import org.apache.spark.sql.{execution, Row}
2222
import org.apache.spark.sql.catalyst.InternalRow
2323
import org.apache.spark.sql.catalyst.expressions._
2424
import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, LeftOuter, RightOuter}
25-
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition}
25+
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition, Sort}
2626
import org.apache.spark.sql.catalyst.plans.physical._
2727
import org.apache.spark.sql.execution.columnar.InMemoryRelation
28-
import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchangeExec}
28+
import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange,
29+
ShuffleExchangeExec}
2930
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
3031
import org.apache.spark.sql.functions._
3132
import org.apache.spark.sql.internal.SQLConf
@@ -197,6 +198,19 @@ class PlannerSuite extends SharedSQLContext {
197198
assert(planned.child.isInstanceOf[CollectLimitExec])
198199
}
199200

201+
test("SPARK-23375: Cached sorted data doesn't need to be re-sorted") {
202+
val query = testData.select('key, 'value).sort('key.desc).cache()
203+
assert(query.queryExecution.optimizedPlan.isInstanceOf[InMemoryRelation])
204+
val resorted = query.sort('key.desc)
205+
assert(resorted.queryExecution.optimizedPlan.collect { case s: Sort => s}.isEmpty)
206+
assert(resorted.select('key).collect().map(_.getInt(0)).toSeq ==
207+
(1 to 100).sorted(Ordering[Int].reverse))
208+
// with a different order, the sort is needed
209+
val sortedAsc = query.sort('key)
210+
assert(sortedAsc.queryExecution.optimizedPlan.collect { case s: Sort => s}.nonEmpty)
211+
assert(sortedAsc.select('key).collect().map(_.getInt(0)).toSeq == (1 to 100))
212+
}
213+
200214
test("PartitioningCollection") {
201215
withTempView("normal", "small", "tiny") {
202216
testData.createOrReplaceTempView("normal")

0 commit comments

Comments
 (0)