Skip to content

Commit cdc390a

Browse files
committed
code review
1 parent 7f34e08 commit cdc390a

4 files changed

Lines changed: 62 additions & 36 deletions

File tree

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2432,6 +2432,10 @@ class Analyzer(
24322432
nondeterToAttr.get(e).map(_.toAttribute).getOrElse(e)
24332433
}.copy(child = newChild)
24342434

2435+
// Don't touch collect metrics. Top-level metrics are not supported (check analysis will fail)
2436+
// and we want to retain them inside the aggregate functions.
2437+
case m: CollectMetrics => m
2438+
24352439
// todo: It's hard to write a general rule to pull out nondeterministic expressions
24362440
// from LogicalPlan, currently we only do it for UnaryNode which has same output
24372441
// schema with its child.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,41 @@ trait CheckAnalysis extends PredicateHelper {
281281
groupingExprs.foreach(checkValidGroupingExprs)
282282
aggregateExprs.foreach(checkValidAggregateExpression)
283283

284+
case CollectMetrics(name, metrics, _) =>
285+
if (name == null || name.isEmpty) {
286+
operator.failAnalysis(s"observed metrics should be named: $operator")
287+
}
288+
// Check if an expression is a valid metric. A metric must meet the following criteria:
289+
// - Is not a window function;
290+
// - Is not a nested aggregate function;
291+
// - Is not a distinct aggregate function;
292+
// - Uses non-deterministic functions only when they are nested inside an aggregate function;
293+
// - References attributes only when they are nested inside an aggregate function.
294+
def checkMetric(s: Expression, e: Expression, seenAggregate: Boolean = false): Unit = {
295+
e match {
296+
case _: WindowExpression =>
297+
e.failAnalysis(
298+
"window expressions are not allowed in observed metrics, but found: " + s.sql)
299+
case _ if !e.deterministic && !seenAggregate =>
300+
e.failAnalysis(s"non-deterministic expression ${s.sql} can only be used " +
301+
"as an argument to an aggregate function.")
302+
case a: AggregateExpression if seenAggregate =>
303+
e.failAnalysis(
304+
"nested aggregates are not allowed in observed metrics, but found: " + s.sql)
305+
case a: AggregateExpression if a.isDistinct =>
306+
e.failAnalysis(
307+
"distinct aggregates are not allowed in observed metrics, but found: " + s.sql)
308+
case _: Attribute if !seenAggregate =>
309+
e.failAnalysis (s"attribute ${s.sql} can only be used as an argument to an " +
310+
"aggregate function.")
311+
case _: AggregateExpression =>
312+
e.children.foreach(checkMetric (s, _, seenAggregate = true))
313+
case _ =>
314+
e.children.foreach(checkMetric (s, _, seenAggregate))
315+
}
316+
}
317+
metrics.foreach(m => checkMetric(m, m))
318+
284319
case Sort(orders, _, _) =>
285320
orders.foreach { order =>
286321
if (!RowOrdering.isOrderable(order.dataType)) {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala

Lines changed: 1 addition & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -979,34 +979,8 @@ case class CollectMetrics(
979979
child: LogicalPlan)
980980
extends UnaryNode {
981981

982-
/**
983-
* Check if an expression is a valid metric. A metric must meet the following criteria:
984-
* - Is not a window function;
985-
* - Is not nested aggregate function;
986-
* - Is not a distinct aggregate function;
987-
* - Has only non-deterministic functions that are nested inside an aggregate function;
988-
* - Has only attributes that are nested inside an aggregate function.
989-
*
990-
* @param e expression to check.
991-
* @param seenAggregate `true` iff one of the parents on the expression is an aggregate function.
992-
* @return `true` if the metric is valid, `false` otherwise.
993-
*/
994-
private def isValidMetric(e: Expression, seenAggregate: Boolean = false): Boolean = {
995-
e match {
996-
case _: WindowExpression => false
997-
case a: AggregateExpression if seenAggregate || a.isDistinct => false
998-
case _: AggregateExpression => e.children.forall(isValidMetric(_, seenAggregate = true))
999-
case _: Nondeterministic if !seenAggregate => false
1000-
case _: Attribute if !seenAggregate => false
1001-
case _ => e.children.forall(isValidMetric(_, seenAggregate))
1002-
}
1003-
}
1004-
1005982
override lazy val resolved: Boolean = {
1006-
def metricsResolved: Boolean = metrics.forall { e =>
1007-
e.resolved && isValidMetric(e)
1008-
}
1009-
name.nonEmpty && metrics.nonEmpty && metricsResolved && childrenResolved
983+
name.nonEmpty && metrics.nonEmpty && metrics.forall(_.resolved) && childrenResolved
1010984
}
1011985

1012986
override def output: Seq[Attribute] = child.output

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -662,31 +662,44 @@ class AnalysisSuite extends AnalysisTest with Matchers {
662662

663663
// Bad name
664664
assert(!CollectMetrics("", sum :: Nil, testRelation).resolved)
665+
assertAnalysisError(CollectMetrics("", sum :: Nil, testRelation),
666+
"observed metrics should be named" :: Nil)
665667

666-
def checkUnresolved(exprs: NamedExpression*): Unit = {
667-
assert(!CollectMetrics("event", exprs, testRelation).resolved)
668-
}
669668
// No columns
670-
checkUnresolved()
669+
assert(!CollectMetrics("evt", Nil, testRelation).resolved)
670+
671+
def checkAnalysisError(exprs: Seq[NamedExpression], errors: String*): Unit = {
672+
assertAnalysisError(CollectMetrics("event", exprs, testRelation), errors)
673+
}
671674

672675
// Unwrapped attribute
673-
checkUnresolved(a)
676+
checkAnalysisError(
677+
a :: Nil,
678+
"Attribute", "can only be used as an argument to an aggregate function")
674679

675680
// Unwrapped non-deterministic expression
676-
checkUnresolved(Rand(10).as("rnd"))
681+
checkAnalysisError(
682+
Rand(10).as("rnd") :: Nil,
683+
"non-deterministic expression", "can only be used as an argument to an aggregate function")
677684

678685
// Distinct aggregate
679-
checkUnresolved(Sum(a).toAggregateExpression(isDistinct = true).as("sum"))
686+
checkAnalysisError(
687+
Sum(a).toAggregateExpression(isDistinct = true).as("sum") :: Nil,
688+
"distinct aggregates are not allowed in observed metrics, but found")
680689

681690
// Nested aggregate
682-
checkUnresolved(Sum(Sum(a).toAggregateExpression()).toAggregateExpression().as("sum"))
691+
checkAnalysisError(
692+
Sum(Sum(a).toAggregateExpression()).toAggregateExpression().as("sum") :: Nil,
693+
"nested aggregates are not allowed in observed metrics, but found")
683694

684695
// Windowed aggregate
685696
val windowExpr = WindowExpression(
686697
RowNumber(),
687698
WindowSpecDefinition(Nil, a.asc :: Nil,
688699
SpecifiedWindowFrame(RowFrame, UnboundedPreceding, CurrentRow)))
689-
checkUnresolved(windowExpr.as("rn"))
700+
checkAnalysisError(
701+
windowExpr.as("rn") :: Nil,
702+
"window expressions are not allowed in observed metrics, but found")
690703
}
691704

692705
test("check CollectMetrics duplicates") {

0 commit comments

Comments
 (0)