Skip to content

Commit f4ed3bc

Browse files
committed
address comments
1 parent 808a5fa commit f4ed3bc

3 files changed

Lines changed: 66 additions & 38 deletions

File tree

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -733,6 +733,24 @@ class Analyzer(
733733
}
734734
}
735735

736+
/**
 * Resolves an [[UnresolvedFunction]] by looking it up in the catalog and, when the lookup
 * yields an aggregate function, wrapping it in an [[AggregateExpression]].
 *
 * @param func the unresolved function call; its name and children are handed to the catalog
 *             lookup and its `isDistinct` flag is carried over to the aggregate wrapper.
 * @return the expression produced by the catalog lookup, wrapped in a `Complete`-mode
 *         [[AggregateExpression]] when it is an aggregate function other than an
 *         [[AggregateWindowFunction]]; otherwise the looked-up expression itself.
 */
protected[sql] def resolveFunction(func: UnresolvedFunction): Expression = {
  catalog.lookupFunction(func.name, func.children) match {
    // DISTINCT is not meaningful for a Max or a Min, so the flag is dropped for them.
    case max: Max if func.isDistinct =>
      AggregateExpression(max, Complete, isDistinct = false)
    case min: Min if func.isDistinct =>
      AggregateExpression(min, Complete, isDistinct = false)
    // AggregateWindowFunctions are AggregateFunctions that can only be evaluated within
    // the context of a Window clause. They do not need to be wrapped in an
    // AggregateExpression. This case must stay ahead of the generic AggregateFunction case.
    case wf: AggregateWindowFunction => wf
    // Any other aggregate function has to be wrapped in an AggregateExpression.
    case agg: AggregateFunction => AggregateExpression(agg, Complete, func.isDistinct)
    // Not an aggregate function: return the resolved expression unchanged.
    case other => other
  }
}
753+
736754
/**
737755
* In many dialects of SQL it is valid to use ordinal positions in order/sort by and group by
738756
* clauses. This rule is to convert ordinal positions to the corresponding expressions in the
@@ -916,21 +934,7 @@ class Analyzer(
916934
}
917935
case u @ UnresolvedFunction(funcId, children, isDistinct) =>
918936
withPosition(u) {
919-
catalog.lookupFunction(funcId, children) match {
920-
// DISTINCT is not meaningful for a Max or a Min.
921-
case max: Max if isDistinct =>
922-
AggregateExpression(max, Complete, isDistinct = false)
923-
case min: Min if isDistinct =>
924-
AggregateExpression(min, Complete, isDistinct = false)
925-
// AggregateWindowFunctions are AggregateFunctions that can only be evaluated within
926-
// the context of a Window clause. They do not need to be wrapped in an
927-
// AggregateExpression.
928-
case wf: AggregateWindowFunction => wf
929-
// We get an aggregate function, we need to wrap it in an AggregateExpression.
930-
case agg: AggregateFunction => AggregateExpression(agg, Complete, isDistinct)
931-
// This function is not an aggregate function, just return the resolved one.
932-
case other => other
933-
}
937+
resolveFunction(u)
934938
}
935939
}
936940
}

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala

Lines changed: 3 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ import org.antlr.v4.runtime.tree.TerminalNode
2525

2626
import org.apache.spark.sql.SaveMode
2727
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
28-
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
2928
import org.apache.spark.sql.catalyst.catalog._
3029
import org.apache.spark.sql.catalyst.expressions._
3130
import org.apache.spark.sql.catalyst.parser._
@@ -601,27 +600,11 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder {
601600
*/
602601
override def visitCreateMacro(ctx: CreateMacroContext): LogicalPlan = withOrigin(ctx) {
603602
val arguments = Option(ctx.colTypeList).map(visitColTypeList(_))
604-
.getOrElse(Seq.empty[StructField]).map { col =>
605-
AttributeReference(col.name, col.dataType, col.nullable, col.metadata)() }
606-
val colToIndex: Map[String, Int] = arguments.map(_.name).zipWithIndex.toMap
607-
if (colToIndex.size != arguments.size) {
608-
throw operationNotAllowed(
609-
s"Cannot support duplicate colNames for CREATE TEMPORARY MACRO ", ctx)
610-
}
611-
val macroFunction = expression(ctx.expression).transformUp {
612-
case u: UnresolvedAttribute =>
613-
val index = colToIndex.get(u.name).getOrElse(
614-
throw new ParseException(
615-
s"Cannot find colName: [${u}] for CREATE TEMPORARY MACRO", ctx))
616-
BoundReference(index, arguments(index).dataType, arguments(index).nullable)
617-
case _: SubqueryExpression =>
618-
throw operationNotAllowed(s"Cannot support Subquery for CREATE TEMPORARY MACRO", ctx)
619-
}
620-
603+
.getOrElse(Seq.empty[StructField])
604+
val e = expression(ctx.expression)
621605
CreateMacroCommand(
622606
ctx.macroName.getText,
623-
arguments,
624-
macroFunction)
607+
MacroFunctionWrapper(arguments, e))
625608
}
626609

627610
/**

sql/core/src/main/scala/org/apache/spark/sql/execution/command/macros.scala

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,14 @@
1818
package org.apache.spark.sql.execution.command
1919

2020
import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
21+
import org.apache.spark.sql.catalyst.analysis._
2122
import org.apache.spark.sql.catalyst.expressions._
23+
import org.apache.spark.sql.types.StructField
24+
25+
/**
26+
* Holds the argument columns and the body expression of a macro function.
27+
*/
28+
case class MacroFunctionWrapper(columns: Seq[StructField], macroFunction: Expression)
2229

2330
/**
2431
* The DDL command that creates a macro.
@@ -29,20 +36,54 @@ import org.apache.spark.sql.catalyst.expressions._
2936
*/
3037
case class CreateMacroCommand(
3138
macroName: String,
32-
columns: Seq[AttributeReference],
33-
macroFunction: Expression)
39+
funcWrapper: MacroFunctionWrapper)
3440
extends RunnableCommand {
3541

3642
override def run(sparkSession: SparkSession): Seq[Row] = {
3743
val catalog = sparkSession.sessionState.catalog
38-
val macroInfo = columns.mkString(",") + " -> " + macroFunction.toString
44+
val columns = funcWrapper.columns.map { col =>
45+
AttributeReference(col.name, col.dataType, col.nullable, col.metadata)() }
46+
val colToIndex: Map[String, Int] = columns.map(_.name).zipWithIndex.toMap
47+
if (colToIndex.size != columns.size) {
48+
throw new AnalysisException(s"Cannot support duplicate colNames " +
49+
s"for CREATE TEMPORARY MACRO $macroName, actual columns: ${columns.mkString(",")}")
50+
}
51+
val macroFunction = funcWrapper.macroFunction.transformDown {
52+
case u: UnresolvedAttribute =>
53+
val index = colToIndex.get(u.name).getOrElse(
54+
throw new AnalysisException(s"Cannot find colName: ${u} " +
55+
s"for CREATE TEMPORARY MACRO $macroName, actual columns: ${columns.mkString(",")}"))
56+
BoundReference(index, columns(index).dataType, columns(index).nullable)
57+
case u: UnresolvedFunction =>
58+
sparkSession.sessionState.analyzer.resolveFunction(u)
59+
case s: SubqueryExpression =>
60+
throw new AnalysisException(s"Cannot support Subquery: ${s} " +
61+
s"for CREATE TEMPORARY MACRO $macroName")
62+
case u: UnresolvedGenerator =>
63+
throw new AnalysisException(s"Cannot support Generator: ${u} " +
64+
s"for CREATE TEMPORARY MACRO $macroName")
65+
}
66+
if (!macroFunction.resolved) {
67+
if (macroFunction.checkInputDataTypes().isFailure) {
68+
macroFunction.checkInputDataTypes() match {
69+
case TypeCheckResult.TypeCheckFailure(message) =>
70+
throw new AnalysisException(s"Cannot resolve '${macroFunction.sql}' " +
71+
s"for CREATE TEMPORARY MACRO $macroName, due to data type mismatch: $message")
72+
}
73+
} else {
74+
throw new AnalysisException(s"Cannot resolve '${macroFunction.sql}' " +
75+
s"for CREATE TEMPORARY MACRO $macroName")
76+
}
77+
}
78+
val macroInfo = columns.mkString(",") + " -> " + funcWrapper.macroFunction.toString
3979
val info = new ExpressionInfo(macroInfo, macroName)
4080
val builder = (children: Seq[Expression]) => {
4181
if (children.size != columns.size) {
4282
throw new AnalysisException(s"Actual number of columns: ${children.size} != " +
4383
s"expected number of columns: ${columns.size} for Macro $macroName")
4484
}
4585
macroFunction.transformUp {
86+
// Skip input type validation here; the Analyzer checks it after ResolveFunctions.
4687
case b: BoundReference => children(b.ordinal)
4788
}
4889
}

0 commit comments

Comments
 (0)