Skip to content
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1995,9 +1995,10 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
case Sort(orders, global, child, hint)
if orders.exists(_.child.isInstanceOf[UnresolvedOrdinal]) =>
val newOrders = orders map {
case s @ SortOrder(UnresolvedOrdinal(index), direction, nullOrdering, _) =>
case s @ SortOrder(UnresolvedOrdinal(index), direction, nullOrdering, _, _) =>
if (index > 0 && index <= child.output.size) {
SortOrder(child.output(index - 1), direction, nullOrdering, Seq.empty)
val resolvedCol = child.output(index - 1)
SortOrder(resolvedCol, direction, nullOrdering, Seq.empty, resolvedCol.foldable)
} else {
throw QueryCompilationErrors.orderByPositionRangeError(index, child.output.size, s)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ class SortResolver(operatorResolver: Resolver, expressionResolver: ExpressionRes

private def canOrderByAll(expressions: Seq[SortOrder]): Boolean = {
val isOrderByAll = expressions match {
case Seq(SortOrder(unresolvedAttribute: UnresolvedAttribute, _, _, _)) =>
case Seq(SortOrder(unresolvedAttribute: UnresolvedAttribute, _, _, _, _)) =>
unresolvedAttribute.equalsIgnoreCase("ALL")
case _ => false
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ package object dsl extends SQLConfHelper {
* `orderByOrdinal` is enabled.
*/
private def replaceOrdinalsInSortOrder(sortOrder: SortOrder): SortOrder = sortOrder match {
case sortOrderByOrdinal @ SortOrder(literal @ Literal(value: Int, IntegerType), _, _, _)
case sortOrderByOrdinal @ SortOrder(literal @ Literal(value: Int, IntegerType), _, _, _, _)
if conf.orderByOrdinal =>
val ordinal = CurrentOrigin.withOrigin(literal.origin) { UnresolvedOrdinal(value) }
sortOrderByOrdinal
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ case class SortOrder(
child: Expression,
direction: SortDirection,
nullOrdering: NullOrdering,
sameOrderExpressions: Seq[Expression])
sameOrderExpressions: Seq[Expression],
isConstant: Boolean)
extends Expression with Unevaluable {

override def children: Seq[Expression] = child +: sameOrderExpressions
Expand All @@ -82,7 +83,7 @@ case class SortOrder(

def satisfies(required: SortOrder): Boolean = {
children.exists(required.child.semanticEquals) &&
direction == required.direction && nullOrdering == required.nullOrdering
(isConstant || direction == required.direction && nullOrdering == required.nullOrdering)
}

override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): SortOrder =
Expand All @@ -94,7 +95,15 @@ object SortOrder {
child: Expression,
direction: SortDirection,
sameOrderExpressions: Seq[Expression] = Seq.empty): SortOrder = {
new SortOrder(child, direction, direction.defaultNullOrdering, sameOrderExpressions)
new SortOrder(child, direction, direction.defaultNullOrdering, sameOrderExpressions, false)
}

def apply(
child: Expression,
direction: SortDirection,
nullOrdering: NullOrdering,
sameOrderExpressions: Seq[Expression]): SortOrder = {
new SortOrder(child, direction, nullOrdering, sameOrderExpressions, false)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,9 @@ case class Mode(
nodeName, 1, orderingWithinGroup.length)
}
orderingWithinGroup.head match {
case SortOrder(child, Ascending, _, _) =>
case SortOrder(child, Ascending, _, _, _) =>
this.copy(child = child, reverseOpt = Some(true))
case SortOrder(child, Descending, _, _) =>
case SortOrder(child, Descending, _, _, _) =>
this.copy(child = child, reverseOpt = Some(false))
}
case _ => this
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -382,8 +382,8 @@ case class PercentileCont(left: Expression, right: Expression, reverse: Boolean
nodeName, 1, orderingWithinGroup.length)
}
orderingWithinGroup.head match {
case SortOrder(child, Ascending, _, _) => this.copy(left = child)
case SortOrder(child, Descending, _, _) => this.copy(left = child, reverse = true)
case SortOrder(child, Ascending, _, _, _) => this.copy(left = child)
case SortOrder(child, Descending, _, _, _) => this.copy(left = child, reverse = true)
}
}

Expand Down Expand Up @@ -440,8 +440,8 @@ case class PercentileDisc(
nodeName, 1, orderingWithinGroup.length)
}
orderingWithinGroup.head match {
case SortOrder(expr, Ascending, _, _) => this.copy(child = expr)
case SortOrder(expr, Descending, _, _) => this.copy(child = expr, reverse = true)
case SortOrder(expr, Ascending, _, _, _) => this.copy(child = expr)
case SortOrder(expr, Descending, _, _, _) => this.copy(child = expr, reverse = true)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1919,7 +1919,7 @@ object CombineFilters extends Rule[LogicalPlan] with PredicateHelper {
object EliminateSorts extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(_.containsPattern(SORT)) {
case s @ Sort(orders, _, child, _) if orders.isEmpty || orders.exists(_.child.foldable) =>
val newOrders = orders.filterNot(_.child.foldable)
val newOrders = orders.filterNot(o => o.child.foldable && !o.isConstant)
if (newOrders.isEmpty) {
child
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans
import scala.collection.mutable

import org.apache.spark.sql.catalyst.SQLConfHelper
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeSet, Empty2Null, Expression, NamedExpression, SortOrder}
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Attribute, AttributeSet, Empty2Null, Expression, NamedExpression, SortOrder}
import org.apache.spark.sql.internal.SQLConf

/**
Expand Down Expand Up @@ -128,6 +128,9 @@ trait AliasAwareQueryOutputOrdering[T <: QueryPlan[T]]
}
}
}
newOrdering.takeWhile(_.isDefined).flatten.toSeq
newOrdering.takeWhile(_.isDefined).flatten.toSeq ++ outputExpressions.filter {
case Alias(child, _) => child.foldable
case expr => expr.foldable
}.map(SortOrder(_, Ascending).copy(isConstant = true))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, do we need to add the whole Alias to SortOrder expression or could adding only the generated attribute work?
Also, I wonder if it would be a breaking change to add Constant as a new SortDirection instead of using a boolean flag?

Copy link
Member Author

@pan3793 pan3793 Oct 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@peter-toth the Constant SortDirection sounds like a good idea, I have updated the code to use it.

while I have tried to add only the generated attribute (I left the code as comments)

    newOrdering.takeWhile(_.isDefined).flatten.toSeq ++ outputExpressions.flatMap {
      case alias @ Alias(child, _) if child.foldable =>
        Some(SortOrder(alias.toAttribute, Constant))
      case expr if expr.foldable =>
        Some(SortOrder(expr, Constant))
      case _ => None
    }

there are two tests fail (haven't figured out the root cause)

[info] CachedTableSuite:
...
[info] - SPARK-36120: Support cache/uncache table with TimestampNTZ type *** FAILED *** (43 milliseconds)
[info]   AttributeSet(TIMESTAMP_NTZ '2021-01-01 00:00:00'#17739) was not empty The optimized logical plan has missing inputs:
[info]   InMemoryRelation [TIMESTAMP_NTZ '2021-01-01 00:00:00'#17776], StorageLevel(disk, memory, deserialized, 1 replicas)
[info]      +- *(1) Project [2021-01-01 00:00:00 AS TIMESTAMP_NTZ '2021-01-01 00:00:00'#17739]
[info]         +- *(1) Scan OneRowRelation[] (QueryTest.scala:241)
...
[info] - SPARK-52692: Support cache/uncache table with Time type *** FAILED *** (58 milliseconds)
[info]   AttributeSet(TIME '22:00:00'#18852) was not empty The optimized logical plan has missing inputs:
[info]   InMemoryRelation [TIME '22:00:00'#18889], StorageLevel(disk, memory, deserialized, 1 replicas)
[info]      +- *(1) Project [22:00:00 AS TIME '22:00:00'#18852]
[info]         +- *(1) Scan OneRowRelation[] (QueryTest.scala:241)
...

Copy link
Contributor

@peter-toth peter-toth Oct 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem seems to be that InMemoryRelation.withOutput() doesn't remap outputOrdering. And because outputOrdering is present in InMemoryRelation as case class argument the unmapped ordering attributes are considered missing inputs.

This seems to be another hidden issue with InMemoryRelation.outputOrdering and got exposed with this change.

Copy link
Contributor

@peter-toth peter-toth Oct 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I opened a small PR into this PR: pan3793#2, hopefully it helps fixing the above tests.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@peter-toth Many thanks for your professionalism and patience! I tested locally, and it did fix the issue. Have educated a lot from your review.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It a pleasure working with you @pan3793!

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -912,7 +912,7 @@ case class Sort(
override def maxRowsPerPartition: Option[Long] = {
if (global) maxRows else child.maxRowsPerPartition
}
override def outputOrdering: Seq[SortOrder] = order
override def outputOrdering: Seq[SortOrder] = order ++ child.outputOrdering.filter(_.isConstant)
final override val nodePatterns: Seq[TreePattern] = Seq(SORT)
override protected def withNewChildInternal(newChild: LogicalPlan): Sort = copy(child = newChild)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2274,7 +2274,8 @@ class Dataset[T] private[sql](
protected def sortInternal(global: Boolean, sortExprs: Seq[Column]): Dataset[T] = {
val sortOrder: Seq[SortOrder] = sortExprs.map { col =>
col.expr match {
case sortOrderWithOrdinal @ SortOrder(literal @ Literal(value: Int, IntegerType), _, _, _)
case sortOrderWithOrdinal @ SortOrder(
literal @ Literal(value: Int, IntegerType), _, _, _, _)
if sparkSession.sessionState.conf.orderByOrdinal =>
// Replace top-level integer literals with UnresolvedOrdinal, if `orderByOrdinal` is
// enabled.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ case class SortExec(

override def output: Seq[Attribute] = child.output

override def outputOrdering: Seq[SortOrder] = sortOrder
override def outputOrdering: Seq[SortOrder] =
sortOrder ++ child.outputOrdering.filter(_.isConstant)

// sort performed is local within a given partition so will retain
// child operator's partitioning
Expand All @@ -73,15 +74,17 @@ case class SortExec(
* should make it public.
*/
def createSorter(): UnsafeExternalRowSorter = {
val effectiveSortOrder = sortOrder.filterNot(_.isConstant)

rowSorter = new ThreadLocal[UnsafeExternalRowSorter]()

val ordering = RowOrdering.create(sortOrder, output)

// The comparator for comparing prefix
val boundSortExpression = BindReferences.bindReference(sortOrder.head, output)
val boundSortExpression = BindReferences.bindReference(effectiveSortOrder.head, output)
val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression)

val canUseRadixSort = enableRadixSort && sortOrder.length == 1 &&
val canUseRadixSort = enableRadixSort && effectiveSortOrder.length == 1 &&
SortPrefixUtils.canSortFullyWithPrefix(boundSortExpression)

// The generator for prefix
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -823,7 +823,7 @@ object DataSourceStrategy

protected[sql] def translateSortOrders(sortOrders: Seq[SortOrder]): Seq[V2SortOrder] = {
def translateSortOrder(sortOrder: SortOrder): Option[V2SortOrder] = sortOrder match {
case SortOrder(PushableExpression(expr), directionV1, nullOrderingV1, _) =>
case SortOrder(PushableExpression(expr), directionV1, nullOrderingV1, _, _) =>
val directionV2 = directionV1 match {
case Ascending => SortDirection.ASCENDING
case Descending => SortDirection.DESCENDING
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -202,10 +202,24 @@ object V1WritesUtils {
def isOrderingMatched(
requiredOrdering: Seq[Expression],
outputOrdering: Seq[SortOrder]): Boolean = {
if (requiredOrdering.length > outputOrdering.length) {
val (constantOutputOrdering, nonConstantOutputOrdering) = outputOrdering.partition {
case SortOrder(child, _, _, _, isConstant) => isConstant || child.foldable
}

val effectiveRequiredOrdering = requiredOrdering.filterNot { requiredOrder =>
constantOutputOrdering.exists {
case s @ SortOrder(alias: Alias, _, _, _, true) =>
val outputOrder = s.copy(child = alias.toAttribute)
outputOrder.satisfies(outputOrder.copy(child = requiredOrder))
case outputOrder =>
outputOrder.satisfies(outputOrder.copy(child = requiredOrder))
}
}

if (effectiveRequiredOrdering.length > nonConstantOutputOrdering.length) {
false
} else {
requiredOrdering.zip(outputOrdering).forall {
effectiveRequiredOrdering.zip(nonConstantOutputOrdering).forall {
case (requiredOrder, outputOrder) =>
outputOrder.satisfies(outputOrder.copy(child = requiredOrder))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,11 @@ class ReplaceIntegerLiteralsWithOrdinalsDataframeSuite extends QueryTest with Sh
val resolvedPlan = query.queryExecution.analyzed

assert(unresolvedPlan.expressions.collect {
case ordinal @ SortOrder(UnresolvedOrdinal(1), _, _, _) => ordinal
case ordinal @ SortOrder(UnresolvedOrdinal(1), _, _, _, _) => ordinal
}.nonEmpty)

assert(resolvedPlan.expressions.collect {
case ordinal @ SortOrder(UnresolvedOrdinal(1), _, _, _) => ordinal
case ordinal @ SortOrder(UnresolvedOrdinal(1), _, _, _, _) => ordinal
}.isEmpty)

checkAnswer(query, Row(1, 2) :: Row(2, 1) :: Nil)
Expand Down Expand Up @@ -100,11 +100,11 @@ class ReplaceIntegerLiteralsWithOrdinalsDataframeSuite extends QueryTest with Sh
val resolvedPlan = query.queryExecution.analyzed

assert(unresolvedPlan.expressions.collect {
case ordinal @ SortOrder(UnresolvedOrdinal(1), _, _, _) => ordinal
case ordinal @ SortOrder(UnresolvedOrdinal(1), _, _, _, _) => ordinal
}.isEmpty)

assert(resolvedPlan.expressions.collect {
case ordinal @ SortOrder(UnresolvedOrdinal(1), _, _, _) => ordinal
case ordinal @ SortOrder(UnresolvedOrdinal(1), _, _, _, _) => ordinal
}.isEmpty)

checkAnswer(query, Row(2, 1) :: Row(1, 2) :: Nil)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,11 @@ class ReplaceIntegerLiteralsWithOrdinalsSqlSuite extends QueryTest with SharedSp
val analyzedPlan = query.queryExecution.analyzed

assert(parsedPlan.expressions.collect {
case ordinal @ SortOrder(UnresolvedOrdinal(1), _, _, _) => ordinal
case ordinal @ SortOrder(UnresolvedOrdinal(1), _, _, _, _) => ordinal
}.nonEmpty)

assert(analyzedPlan.expressions.collect {
case ordinal @ SortOrder(UnresolvedOrdinal(1), _, _, _) => ordinal
case ordinal @ SortOrder(UnresolvedOrdinal(1), _, _, _, _) => ordinal
}.isEmpty)

checkAnswer(query, Row(1) :: Row(2) :: Nil)
Expand All @@ -100,11 +100,11 @@ class ReplaceIntegerLiteralsWithOrdinalsSqlSuite extends QueryTest with SharedSp
val analyzedPlan = query.queryExecution.analyzed

assert(parsedPlan.expressions.collect {
case ordinal @ SortOrder(UnresolvedOrdinal(1), _, _, _) => ordinal
case ordinal @ SortOrder(UnresolvedOrdinal(1), _, _, _, _) => ordinal
}.isEmpty)

assert(analyzedPlan.expressions.collect {
case ordinal @ SortOrder(UnresolvedOrdinal(1), _, _, _) => ordinal
case ordinal @ SortOrder(UnresolvedOrdinal(1), _, _, _, _) => ordinal
}.isEmpty)

checkAnswer(query, Row(2) :: Row(1) :: Nil)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,8 @@ class ColumnNodeToExpressionConverterSuite extends SparkFunSuite {
analysis.UnresolvedAttribute("unsorted"),
catDirection,
catNullOrdering,
Nil))
Nil,
false))
}

test("sortOrder") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1013,7 +1013,7 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {

val projects = collect(planned) { case p: ProjectExec => p }
assert(projects.exists(_.outputPartitioning match {
case RangePartitioning(Seq(SortOrder(ar: AttributeReference, _, _, _)), _) =>
case RangePartitioning(Seq(SortOrder(ar: AttributeReference, _, _, _, _)), _) =>
ar.name == "id1"
case _ => false
}))
Expand Down Expand Up @@ -1121,7 +1121,7 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {

val projects = collect(planned) { case p: ProjectExec => p }
assert(projects.exists(_.outputOrdering match {
case Seq(s @ SortOrder(_, Ascending, NullsFirst, _)) =>
case Seq(s @ SortOrder(_, Ascending, NullsFirst, _, _)) =>
s.children.map(_.asInstanceOf[AttributeReference].name).toSet == Set("t2id", "t3id")
case _ => false
}))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,23 @@ trait V1WriteCommandSuiteBase extends SQLTestUtils with AdaptiveSparkPlanHelper
hasLogicalSort: Boolean,
orderingMatched: Boolean,
hasEmpty2Null: Boolean = false)(query: => Unit): Unit = {
var optimizedPlan: LogicalPlan = null
executeAndCheckOrderingAndCustomValidate(
hasLogicalSort, orderingMatched, hasEmpty2Null)(query)(_ => ())
}

/**
* Execute a write query and check ordering of the plan, then do custom validation
*/
protected def executeAndCheckOrderingAndCustomValidate(
hasLogicalSort: Boolean,
orderingMatched: Boolean,
hasEmpty2Null: Boolean = false)(query: => Unit)(
customValidate: LogicalPlan => Unit): Unit = {
@volatile var optimizedPlan: LogicalPlan = null

val listener = new QueryExecutionListener {
override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
val conf = qe.sparkSession.sessionState.conf
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is a bugfix, the listener runs in another thread, without this change, conf.getConf actually gets conf from the thread local, thus may cause issues on concurrency running tests

qe.optimizedPlan match {
case w: V1WriteCommand =>
if (hasLogicalSort && conf.getConf(SQLConf.PLANNED_WRITE_ENABLED)) {
Expand All @@ -87,7 +100,8 @@ trait V1WriteCommandSuiteBase extends SQLTestUtils with AdaptiveSparkPlanHelper

// Check whether the output ordering is matched before FileFormatWriter executes rdd.
assert(FileFormatWriter.outputOrderingMatched == orderingMatched,
s"Expect: $orderingMatched, Actual: ${FileFormatWriter.outputOrderingMatched}")
s"Expect orderingMatched: $orderingMatched, " +
s"Actual: ${FileFormatWriter.outputOrderingMatched}")

sparkContext.listenerBus.waitUntilEmpty()

Expand All @@ -103,6 +117,8 @@ trait V1WriteCommandSuiteBase extends SQLTestUtils with AdaptiveSparkPlanHelper
assert(empty2nullExpr == hasEmpty2Null,
s"Expect hasEmpty2Null: $hasEmpty2Null, Actual: $empty2nullExpr. Plan:\n$optimizedPlan")

customValidate(optimizedPlan)

spark.listenerManager.unregister(listener)
}
}
Expand Down Expand Up @@ -228,8 +244,8 @@ class V1WriteCommandSuite extends QueryTest with SharedSparkSession with V1Write
case s: SortExec => s
}.exists {
case SortExec(Seq(
SortOrder(AttributeReference("key", IntegerType, _, _), Ascending, NullsFirst, _),
SortOrder(AttributeReference("value", StringType, _, _), Ascending, NullsFirst, _)
SortOrder(AttributeReference("key", IntegerType, _, _), Ascending, NullsFirst, _, _),
SortOrder(AttributeReference("value", StringType, _, _), Ascending, NullsFirst, _, _)
), false, _, _) => true
case _ => false
}, plan)
Expand Down Expand Up @@ -275,8 +291,8 @@ class V1WriteCommandSuite extends QueryTest with SharedSparkSession with V1Write
case s: SortExec => s
}.exists {
case SortExec(Seq(
SortOrder(AttributeReference("value", StringType, _, _), Ascending, NullsFirst, _),
SortOrder(AttributeReference("key", IntegerType, _, _), Ascending, NullsFirst, _)
SortOrder(AttributeReference("value", StringType, _, _), Ascending, NullsFirst, _, _),
SortOrder(AttributeReference("key", IntegerType, _, _), Ascending, NullsFirst, _, _)
), false, _, _) => true
case _ => false
}, plan)
Expand Down Expand Up @@ -391,4 +407,30 @@ class V1WriteCommandSuite extends QueryTest with SharedSparkSession with V1Write
}
}
}

test("v1 write with sort by literal column preserve custom order") {
withPlannedWrite { _ =>
withTable("t") {
sql(
"""
|CREATE TABLE t(i INT, j INT, k STRING) USING PARQUET
|PARTITIONED BY (k)
|""".stripMargin)
executeAndCheckOrderingAndCustomValidate(hasLogicalSort = true, orderingMatched = true) {
sql(
"""
|INSERT OVERWRITE t
|SELECT i, j, '0' as k FROM t0 SORT BY k, i
|""".stripMargin)
} { optimizedPlan =>
assert {
optimizedPlan.outputOrdering.exists {
case SortOrder(attr: AttributeReference, _, _, _, _) => attr.name == "i"
case _ => false
}
}
}
}
}
}
}
Loading