Skip to content
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -572,98 +572,64 @@ class Analyzer(
// Skip sort with aggregate. This will be handled in ResolveAggregateFunctions
case sa @ Sort(_, _, child: Aggregate) => sa

case s @ Sort(_, _, child) if !s.resolved && child.resolved =>
val (newOrdering, missingResolvableAttrs) = collectResolvableMissingAttrs(s.order, child)

if (missingResolvableAttrs.isEmpty) {
val unresolvableAttrs = s.order.filterNot(_.resolved)
logDebug(s"Failed to find $unresolvableAttrs in ${child.output.mkString(", ")}")
s // Nothing we can do here. Return original plan.
} else {
// Add the missing attributes into projectList of Project/Window or
// aggregateExpressions of Aggregate, if they are in the inputSet
// but not in the outputSet of the plan.
val newChild = child transformUp {
case p: Project =>
p.copy(projectList = p.projectList ++
missingResolvableAttrs.filter((p.inputSet -- p.outputSet).contains))
case w: Window =>
w.copy(projectList = w.projectList ++
missingResolvableAttrs.filter((w.inputSet -- w.outputSet).contains))
case a: Aggregate =>
val resolvableAttrs = missingResolvableAttrs.filter(a.groupingExpressions.contains)
val notResolvedAttrs = resolvableAttrs.filterNot(a.aggregateExpressions.contains)
val newAggregateExpressions = a.aggregateExpressions ++ notResolvedAttrs
a.copy(aggregateExpressions = newAggregateExpressions)
case o => o
}

case s @ Sort(order, _, child) if !s.resolved && child.resolved =>
val newOrder = order.map(resolveExpressionRecursively(_, child).asInstanceOf[SortOrder])
val requiredAttrs = AttributeSet(newOrder).filter(_.resolved)
val missingAttrs = requiredAttrs -- child.outputSet
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this solution could skip Subquery, we might need to add/change qualifiers for the missingAttrs, if necessary.

if (missingAttrs.nonEmpty) {
// Add missing attributes and then project them away after the sort.
Project(child.output,
Sort(newOrdering, s.global, newChild))
Sort(newOrder, s.global, addMissingAttr(child, missingAttrs)))
} else if (newOrder != order) {
s.copy(order = newOrder)
} else {
s
}
}

/**
* Traverse the tree until resolving the sorting attributes
* Return all the resolvable missing sorting attributes
*/
@tailrec
private def collectResolvableMissingAttrs(
ordering: Seq[SortOrder],
plan: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = {
* Add the missing attributes into projectList of Project/Window or aggregateExpressions of
* Aggregate.
*/
private def addMissingAttr(plan: LogicalPlan, missingAttrs: AttributeSet): LogicalPlan = {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not using tail recursion. When the tree is large, we might hit stack overflow. I am fine if this is not a concern anymore.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because it can not go over JOIN, it's very uncommon to have thousands of unary nodes in practice.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It makes sense to me. Thanks!

if (missingAttrs.isEmpty) {
return plan
}
plan match {
// Only Windows and Project have projectList-like attribute.
case un: UnaryNode if un.isInstanceOf[Project] || un.isInstanceOf[Window] =>
val (newOrdering, missingAttrs) = resolveAndFindMissing(ordering, un, un.child)
// If missingAttrs is non empty, that means we got it and return it;
// Otherwise, continue to traverse the tree.
if (missingAttrs.nonEmpty) {
(newOrdering, missingAttrs)
} else {
collectResolvableMissingAttrs(ordering, un.child)
}
case p: Project =>
val missing = missingAttrs -- p.child.outputSet
Project(p.projectList ++ missingAttrs, addMissingAttr(p.child, missing))
case w: Window =>
val missing = missingAttrs -- w.child.outputSet
w.copy(projectList = w.projectList ++ missingAttrs,
child = addMissingAttr(w.child, missing))
case a: Aggregate =>
val (newOrdering, missingAttrs) = resolveAndFindMissing(ordering, a, a.child)
// For Aggregate, all the order by columns must be specified in group by clauses
if (missingAttrs.nonEmpty &&
missingAttrs.forall(ar => a.groupingExpressions.exists(_.semanticEquals(ar)))) {
(newOrdering, missingAttrs)
} else {
// If missingAttrs is empty, we are unable to resolve any unresolved missing attributes
(Seq.empty[SortOrder], Seq.empty[Attribute])
// all the missing attributes should be grouping expressions
// TODO: push down AggregateExpression
missingAttrs.foreach { attr =>
if (!a.groupingExpressions.contains(attr)) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

contains will call equals instead of semanticEquals. Thus, it might not work if we cross Subquery.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should use a.groupingExprs.exists(_.semanticEquals(attr)) here, to be consistent with other places that do the similar checking like CheckAnalysis

throw new AnalysisException(s"Can't add $attr to ${a.simpleString}")
}
}
// Jump over the following UnaryNode types
// The output of these types is the same as their child's output
case _: Distinct |
_: Filter |
_: RepartitionByExpression =>
collectResolvableMissingAttrs(ordering, plan.asInstanceOf[UnaryNode].child)
// If hitting the other unsupported operators, we are unable to resolve it.
case other => (Seq.empty[SortOrder], Seq.empty[Attribute])
val newAggregateExpressions = a.aggregateExpressions ++ missingAttrs
a.copy(aggregateExpressions = newAggregateExpressions)
case u: UnaryNode =>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It sounds like we want to cover all the UnaryNode here? This is different from what we discussed. I am fine if you want to add more supports (e.g., crossing the boundary of subquery), but we might need to add more test cases to ensure it does not break anything.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If needed, I can add more test cases after this is merged, since it could be time-consuming. Please feel free to let me know.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will exclude Subquery, feel free to add more tests.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, will do.

u.withNewChildren(addMissingAttr(u.child, missingAttrs) :: Nil)
case other =>
throw new AnalysisException(s"Can't add $missingAttrs to $other")
}
}

/**
* Try to resolve the sort ordering and returns it with a list of attributes that are missing
* from the plan but are present in the child.
*/
private def resolveAndFindMissing(
ordering: Seq[SortOrder],
plan: LogicalPlan,
child: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = {
val newOrdering =
ordering.map(order => resolveExpression(order, child).asInstanceOf[SortOrder])
// Construct a set that contains all of the attributes that we need to evaluate the
// ordering.
val requiredAttributes = AttributeSet(newOrdering).filter(_.resolved)
// Figure out which ones are missing from the projection, so that we can add them and
// remove them after the sort.
val missingInProject = requiredAttributes -- plan.outputSet
// It is important to return the new SortOrders here, instead of waiting for the standard
// resolving process as adding attributes to the project below can actually introduce
// ambiguity that was not present before.
(newOrdering, missingInProject.toSeq)
private def resolveExpressionRecursively(expr: Expression, plan: LogicalPlan): Expression = {
val resolved = resolveExpression(expr, plan)
if (resolved.resolved) {
resolved
} else {
plan match {
case u: UnaryNode => resolveExpressionRecursively(resolved, u.child)
case other => resolved
}
}
}
}

Expand Down Expand Up @@ -753,8 +719,7 @@ class Analyzer(
filter
}

case sort @ Sort(sortOrder, global, aggregate: Aggregate)
if aggregate.resolved =>
case sort @ Sort(sortOrder, global, aggregate: Aggregate) if aggregate.resolved =>

// Try resolving the ordering as though it is in the aggregate clause.
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class AnalysisSuite extends AnalysisTest {
.where(a > "str").select(a, b, c)
.where(b > "str").select(a, b, c)
.sortBy(b.asc, c.desc)
.select(a, b).select(a)
.select(a)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Based on this test case, it sounds like the previous PR can cover the case of two missing attributes. Do you know why Q98 still has an issue? Thanks!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It depends on what the missing attributes are, checkout the added regression test.

checkAnalysis(plan1, expected1)

// Case 2: all the missing attributes are in the leaf node
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -978,6 +978,21 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
("d", 1),
("c", 2)
).map(i => Row(i._1, i._2)))

checkAnswer(
sql(
"""
|select area, sum(product) / sum(sum(product)) over (partition by area) as c1
|from windowData group by area, month order by month, c1
""".stripMargin),
Seq(
("d", 1.0),
("a", 1.0),
("b", 0.4666666666666667),
("b", 0.5333333333333333),
("c", 0.45),
("c", 0.55)
).map(i => Row(i._1, i._2)))
}

// todo: fix this test case by reimplementing the function ResolveAggregateFunctions
Expand Down