@@ -879,13 +879,15 @@ object EliminateSorts extends Rule[LogicalPlan] {

/**
* Removes redundant Sort operation. This can happen:
* 1) if the child is already sorted
* 1) if the Sort operator is a local sort and the child is already sorted
* 2) if there is another Sort operator separated by 0...n Project/Filter operators
*/
object RemoveRedundantSorts extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
case Sort(orders, true, child) if SortOrder.orderingSatisfies(child.outputOrdering, orders) =>
child
def apply(plan: LogicalPlan): LogicalPlan = plan transform applyLocally

private val applyLocally: PartialFunction[LogicalPlan, LogicalPlan] = {
case Sort(orders, false, child) if SortOrder.orderingSatisfies(child.outputOrdering, orders) =>
applyLocally.lift(child).getOrElse(child)
case s @ Sort(_, _, child) => s.copy(child = recursiveRemoveSort(child))
}

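For readers of this hunk, recursiveRemoveSort is referenced but not shown. A minimal sketch of what such a helper could look like, assuming it only descends through order-agnostic operators; the canEliminateSort guard and its body are illustrative, not lifted from the PR:

  // Hypothetical helper: recursively drop Sort nodes below operators that do not
  // change the ordering of their input (the "0...n Project/Filter operators" case).
  private def recursiveRemoveSort(plan: LogicalPlan): LogicalPlan = plan match {
    case Sort(_, _, child) => recursiveRemoveSort(child)
    case other if canEliminateSort(other) =>
      other.withNewChildren(other.children.map(recursiveRemoveSort))
    case _ => plan
  }

  // Only deterministic Project/Filter operators are safe to look through.
  private def canEliminateSort(plan: LogicalPlan): Boolean = plan match {
    case p: Project => p.projectList.forall(_.deterministic)
    case f: Filter => f.condition.deterministic
    case _ => false
  }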
@@ -850,6 +850,12 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

val REMOVE_REDUNDANT_SORTS_ENABLED = buildConf("spark.sql.execution.removeRedundantSorts")
.internal()
.doc("Whether to remove redundant physical sort node")
.booleanConf
.createWithDefault(true)

val STATE_STORE_PROVIDER_CLASS =
buildConf("spark.sql.streaming.stateStore.providerClass")
.internal()
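Although the new flag is marked internal, it can still be toggled per session (the planner suite added below does this through withSQLConf). A minimal sketch, assuming a running SparkSession named spark:

  // Turn the physical-plan rule off for this session, run the queries of interest,
  // then restore the default (true).
  spark.conf.set("spark.sql.execution.removeRedundantSorts", "false")
  // ... redundant SortExec nodes are now kept in executed plans ...
  spark.conf.unset("spark.sql.execution.removeRedundantSorts")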
@@ -92,4 +92,11 @@ class EliminateSortsSuite extends PlanTest {
val correctAnswer = distributedPlan.analyze
comparePlans(optimized, correctAnswer)
}

test("SPARK-33183: remove consecutive no-op sorts") {
val plan = testRelation.orderBy().orderBy().orderBy()
val optimized = Optimize.execute(plan.analyze)
val correctAnswer = testRelation.analyze
comparePlans(optimized, correctAnswer)
}
}
@@ -36,12 +36,27 @@ class RemoveRedundantSortsSuite extends PlanTest {

val testRelation = LocalRelation('a.int, 'b.int, 'c.int)

test("remove redundant order by") {
test("SPARK-33183: remove redundant sort by") {
val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc_nullsFirst)
val unnecessaryReordered = orderedPlan.limit(2).select('a).orderBy('a.asc, 'b.desc_nullsFirst)
val unnecessaryReordered = orderedPlan.limit(2).select('a).sortBy('a.asc, 'b.desc_nullsFirst)
val optimized = Optimize.execute(unnecessaryReordered.analyze)
val correctAnswer = orderedPlan.limit(2).select('a).analyze
comparePlans(Optimize.execute(optimized), correctAnswer)
comparePlans(optimized, correctAnswer)
}

test("SPARK-33183: remove all redundant local sorts") {
val orderedPlan = testRelation.sortBy('a.asc).orderBy('a.asc).sortBy('a.asc)
val optimized = Optimize.execute(orderedPlan.analyze)
val correctAnswer = testRelation.orderBy('a.asc).analyze
comparePlans(optimized, correctAnswer)
}

test("SPARK-33183: should not remove global sort") {
val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc_nullsFirst)
val reordered = orderedPlan.limit(2).select('a).orderBy('a.asc, 'b.desc_nullsFirst)
val optimized = Optimize.execute(reordered.analyze)
val correctAnswer = reordered.analyze
comparePlans(optimized, correctAnswer)
}

test("do not remove sort if the order is different") {
@@ -52,22 +67,39 @@ class RemoveRedundantSortsSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}

test("filters don't affect order") {
test("SPARK-33183: remove top level local sort with filter operators") {
val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc)
val filteredAndReordered = orderedPlan.where('a > Literal(10)).orderBy('a.asc, 'b.desc)
val filteredAndReordered = orderedPlan.where('a > Literal(10)).sortBy('a.asc, 'b.desc)
val optimized = Optimize.execute(filteredAndReordered.analyze)
val correctAnswer = orderedPlan.where('a > Literal(10)).analyze
comparePlans(optimized, correctAnswer)
}

test("limits don't affect order") {
test("SPARK-33183: keep top level global sort with filter operators") {
val projectPlan = testRelation.select('a, 'b)
val orderedPlan = projectPlan.orderBy('a.asc, 'b.desc)
val filteredAndReordered = orderedPlan.where('a > Literal(10)).orderBy('a.asc, 'b.desc)
val optimized = Optimize.execute(filteredAndReordered.analyze)
val correctAnswer = projectPlan.where('a > Literal(10)).orderBy('a.asc, 'b.desc).analyze
comparePlans(optimized, correctAnswer)
}

test("SPARK-33183: limits should not affect order for local sort") {
val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc)
val filteredAndReordered = orderedPlan.limit(Literal(10)).orderBy('a.asc, 'b.desc)
val filteredAndReordered = orderedPlan.limit(Literal(10)).sortBy('a.asc, 'b.desc)
val optimized = Optimize.execute(filteredAndReordered.analyze)
val correctAnswer = orderedPlan.limit(Literal(10)).analyze
comparePlans(optimized, correctAnswer)
}

test("SPARK-33183: should not remove global sort with limit operators") {
val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc)
val filteredAndReordered = orderedPlan.limit(Literal(10)).orderBy('a.asc, 'b.desc)
val optimized = Optimize.execute(filteredAndReordered.analyze)
val correctAnswer = filteredAndReordered.analyze
comparePlans(optimized, correctAnswer)
}

test("different sorts are not simplified if limit is in between") {
val orderedPlan = testRelation.select('a, 'b).orderBy('b.desc).limit(Literal(10))
.orderBy('a.asc)
@@ -76,11 +108,11 @@ class RemoveRedundantSortsSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}

test("range is already sorted") {
test("SPARK-33183: should not remove global sort with range operator") {
val inputPlan = Range(1L, 1000L, 1, 10)
val orderedPlan = inputPlan.orderBy('id.asc)
val optimized = Optimize.execute(orderedPlan.analyze)
val correctAnswer = inputPlan.analyze
val correctAnswer = orderedPlan.analyze
comparePlans(optimized, correctAnswer)

val reversedPlan = inputPlan.orderBy('id.desc)
@@ -91,10 +123,18 @@ class RemoveRedundantSortsSuite extends PlanTest {
val negativeStepInputPlan = Range(10L, 1L, -1, 10)
val negativeStepOrderedPlan = negativeStepInputPlan.orderBy('id.desc)
val negativeStepOptimized = Optimize.execute(negativeStepOrderedPlan.analyze)
val negativeStepCorrectAnswer = negativeStepInputPlan.analyze
val negativeStepCorrectAnswer = negativeStepOrderedPlan.analyze
comparePlans(negativeStepOptimized, negativeStepCorrectAnswer)
}

test("SPARK-33183: remove local sort with range operator") {
val inputPlan = Range(1L, 1000L, 1, 10)
val orderedPlan = inputPlan.sortBy('id.asc)
val optimized = Optimize.execute(orderedPlan.analyze)
val correctAnswer = inputPlan.analyze
comparePlans(optimized, correctAnswer)
}

test("sort should not be removed when there is a node which doesn't guarantee any order") {
val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc)
val groupedAndResorted = orderedPlan.groupBy('a)(sum('a)).orderBy('a.asc)
@@ -135,4 +175,39 @@ class RemoveRedundantSortsSuite extends PlanTest {
.select(('b + 1).as('c)).orderBy('c.asc).analyze
comparePlans(optimizedThrice, correctAnswerThrice)
}

test("SPARK-33183: remove consecutive global sorts with the same ordering") {
Seq(
(testRelation.orderBy('a.asc).orderBy('a.asc), testRelation.orderBy('a.asc)),
(testRelation.orderBy('a.asc, 'b.desc).orderBy('a.asc), testRelation.orderBy('a.asc))
).foreach { case (ordered, answer) =>
val optimized = Optimize.execute(ordered.analyze)
comparePlans(optimized, answer.analyze)
}
}

test("SPARK-33183: remove consecutive local sorts with the same ordering") {
val orderedPlan = testRelation.sortBy('a.asc).sortBy('a.asc).sortBy('a.asc)
val optimized = Optimize.execute(orderedPlan.analyze)
val correctAnswer = testRelation.sortBy('a.asc).analyze
comparePlans(optimized, correctAnswer)
}

test("SPARK-33183: remove consecutive local sorts with different ordering") {
val orderedPlan = testRelation.sortBy('b.asc).sortBy('a.desc).sortBy('a.asc)
val optimized = Optimize.execute(orderedPlan.analyze)
val correctAnswer = testRelation.sortBy('a.asc).analyze
comparePlans(optimized, correctAnswer)
}

test("SPARK-33183: should keep global sort when child is a local sort with the same ordering") {
val correctAnswer = testRelation.orderBy('a.asc).analyze
Seq(
testRelation.sortBy('a.asc).orderBy('a.asc),
testRelation.orderBy('a.asc).sortBy('a.asc).orderBy('a.asc)
).foreach { ordered =>
val optimized = Optimize.execute(ordered.analyze)
comparePlans(optimized, correctAnswer)
}
}
}
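For reference, the sortBy/orderBy split that the new test cases lean on comes from the Catalyst plan DSL: orderBy builds a global Sort while sortBy builds a local one. Approximately (paraphrased, not part of this diff):

  // org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan (approximate signatures):
  def orderBy(sortExprs: SortOrder*): LogicalPlan = Sort(sortExprs, global = true, logicalPlan)
  def sortBy(sortExprs: SortOrder*): LogicalPlan = Sort(sortExprs, global = false, logicalPlan)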
@@ -97,6 +97,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
/** A sequence of rules that will be applied in order to the physical plan before execution. */
protected def preparations: Seq[Rule[SparkPlan]] = Seq(
PlanSubqueries(sparkSession),
RemoveRedundantSorts(sparkSession.sessionState.conf),
EnsureRequirements(sparkSession.sessionState.conf),
CollapseCodegenStages(sparkSession.sessionState.conf),
ReuseExchange(sparkSession.sessionState.conf),
@@ -0,0 +1,46 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution

import org.apache.spark.sql.catalyst.expressions.SortOrder
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf

/**
* Remove redundant SortExec node from the spark plan. A sort node is redundant when
* its child satisfies both its sort orders and its required child distribution. Note
* this rule differs from the Optimizer rule EliminateSorts in that this rule also checks
* if the child satisfies the required distribution so that it is safe to remove not only a
* local sort but also a global sort when its child already satisfies required sort orders.
*/
case class RemoveRedundantSorts(conf: SQLConf) extends Rule[SparkPlan] {
def apply(plan: SparkPlan): SparkPlan = {
if (!conf.getConf(SQLConf.REMOVE_REDUNDANT_SORTS_ENABLED)) {
plan
} else {
removeSorts(plan)
}
}

private def removeSorts(plan: SparkPlan): SparkPlan = plan transform {
case s @ SortExec(orders, _, child, _)
if SortOrder.orderingSatisfies(child.outputOrdering, orders) &&
child.outputPartitioning.satisfies(s.requiredChildDistribution.head) =>
child
}
}
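To see the rule in isolation, it can also be applied by hand to the planned physical tree before QueryExecution's preparation rules run it, mirroring the cached-data test added in the new RemoveRedundantSortsSuite further down. A sketch only, assuming it runs inside the org.apache.spark.sql package (sessionState is package-private) with a SparkSession named spark:

  import org.apache.spark.sql.functions.col

  // Re-sorting already-sorted cached data: the extra global SortExec now survives the
  // logical optimizer (which only drops local sorts) and is handled here instead.
  val cached = spark.range(100).select(col("id").as("key")).sort(col("key").desc).cache()
  val resorted = cached.sort(col("key").desc)
  val planned = resorted.queryExecution.sparkPlan   // physical plan before preparations
  val cleaned = RemoveRedundantSorts(spark.sessionState.conf).apply(planned)
  // `cleaned` should no longer contain a SortExec for the repeated ordering.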
@@ -230,19 +230,6 @@ class PlannerSuite extends SharedSQLContext {
}
}

test("SPARK-23375: Cached sorted data doesn't need to be re-sorted") {
val query = testData.select('key, 'value).sort('key.desc).cache()
assert(query.queryExecution.optimizedPlan.isInstanceOf[InMemoryRelation])
val resorted = query.sort('key.desc)
assert(resorted.queryExecution.optimizedPlan.collect { case s: Sort => s}.isEmpty)
assert(resorted.select('key).collect().map(_.getInt(0)).toSeq ==
(1 to 100).reverse)
// with a different order, the sort is needed
val sortedAsc = query.sort('key)
assert(sortedAsc.queryExecution.optimizedPlan.collect { case s: Sort => s}.size == 1)
assert(sortedAsc.select('key).collect().map(_.getInt(0)).toSeq == (1 to 100))
}

test("PartitioningCollection") {
withTempView("normal", "small", "tiny") {
testData.createOrReplaceTempView("normal")
@@ -0,0 +1,102 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution

import org.apache.spark.sql.{DataFrame, QueryTest}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession


class RemoveRedundantSortsSuite
extends QueryTest
with SharedSparkSession {
import testImplicits._

private def checkNumSorts(df: DataFrame, count: Int): Unit = {
val plan = df.queryExecution.executedPlan
assert(plan.collect { case s: SortExec => s }.length == count)
}

private def checkSorts(query: String, enabledCount: Int, disabledCount: Int): Unit = {
withSQLConf(SQLConf.REMOVE_REDUNDANT_SORTS_ENABLED.key -> "true") {
val df = sql(query)
checkNumSorts(df, enabledCount)
val result = df.collect()
withSQLConf(SQLConf.REMOVE_REDUNDANT_SORTS_ENABLED.key -> "false") {
val df = sql(query)
checkNumSorts(df, disabledCount)
checkAnswer(df, result)
}
}
}

test("remove redundant sorts with limit") {
withTempView("t") {
spark.range(100).select('id as "key").createOrReplaceTempView("t")
val query =
"""
|SELECT key FROM
| (SELECT key FROM t WHERE key > 10 ORDER BY key DESC LIMIT 10)
|ORDER BY key DESC
|""".stripMargin
checkSorts(query, 0, 1)
}
}

test("remove redundant sorts with sort merge join") {
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
withTempView("t1", "t2") {
spark.range(1000).select('id as "key").createOrReplaceTempView("t1")
spark.range(1000).select('id as "key").createOrReplaceTempView("t2")
val query = """
|SELECT t1.key FROM
| (SELECT key FROM t1 WHERE key > 10 ORDER BY key DESC LIMIT 10) t1
|JOIN
| (SELECT key FROM t2 WHERE key > 50 ORDER BY key DESC LIMIT 100) t2
|ON t1.key = t2.key
|ORDER BY t1.key
""".stripMargin

val queryAsc = query + " ASC"
checkSorts(queryAsc, 2, 3)

// The top level sort should not be removed since the child output ordering is ASC and
// the required ordering is DESC.
val queryDesc = query + " DESC"
checkSorts(queryDesc, 3, 3)
}
}
}

test("cached sorted data doesn't need to be re-sorted") {
withSQLConf(SQLConf.REMOVE_REDUNDANT_SORTS_ENABLED.key -> "true") {
val df = spark.range(1000).select('id as "key").sort('key.desc).cache()
val resorted = df.sort('key.desc)
val sortedAsc = df.sort('key.asc)
checkNumSorts(df, 0)
checkNumSorts(resorted, 0)
checkNumSorts(sortedAsc, 1)
val result = resorted.collect()
withSQLConf(SQLConf.REMOVE_REDUNDANT_SORTS_ENABLED.key -> "false") {
val resorted = df.sort('key.desc)
checkNumSorts(resorted, 1)
checkAnswer(resorted, result)
}
}
}
}