diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index d58fb7cb20bfc..ea06bc0f074bc 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -319,7 +319,7 @@ ctes ; namedQuery - : name=identifier AS? '(' query ')' + : name=identifier (columnAliases=identifierList)? AS? '(' query ')' ; tableProvider diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 0048088f3ede0..59d0aa49262dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -185,7 +185,11 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging * This is only used for Common Table Expressions. */ override def visitNamedQuery(ctx: NamedQueryContext): SubqueryAlias = withOrigin(ctx) { - SubqueryAlias(ctx.name.getText, plan(ctx.query)) + val subQuery: LogicalPlan = plan(ctx.query).optionalMap(ctx.columnAliases)( + (columnAliases, plan) => + UnresolvedSubqueryColumnAliases(visitIdentifierList(columnAliases), plan) + ) + SubqueryAlias(ctx.name.getText, subQuery) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 80387336e1645..b69f25b8a5d18 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parsePlan import org.apache.spark.sql.catalyst.plans.{Cross, Inner} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, @@ -633,4 +634,15 @@ class AnalysisSuite extends AnalysisTest with Matchers { val res = ViewAnalyzer.execute(view) comparePlans(res, expected) } + + test("CTE with non-existing column alias") { + assertAnalysisError(parsePlan("WITH t(x) AS (SELECT 1) SELECT * FROM t WHERE y = 1"), + Seq("cannot resolve '`y`' given input columns: [x]")) + } + + test("CTE with non-matching column alias") { + assertAnalysisError(parsePlan("WITH t(x, y) AS (SELECT 1) SELECT * FROM t WHERE x = 1"), + Seq("Number of column aliases does not match number of columns. Number of column aliases: " + + "2; number of columns: 1.")) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index fba2a28c3fc38..1d63f1b6ca83c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -42,10 +42,15 @@ class PlanParserSuite extends AnalysisTest { private def intercept(sqlCommand: String, messages: String*): Unit = interceptParseException(parsePlan)(sqlCommand, messages: _*) - private def cte(plan: LogicalPlan, namedPlans: (String, LogicalPlan)*): With = { + private def cte(plan: LogicalPlan, namedPlans: (String, (LogicalPlan, Seq[String]))*): With = { val ctes = namedPlans.map { - case (name, cte) => - name -> SubqueryAlias(name, cte) + case (name, (cte, columnAliases)) => + val subquery = if (columnAliases.isEmpty) { + cte + } else { + UnresolvedSubqueryColumnAliases(columnAliases, cte) + } + name -> SubqueryAlias(name, subquery) } With(plan, ctes) } @@ -84,15 +89,15 @@ class PlanParserSuite extends AnalysisTest { test("common table expressions") { assertEqual( "with cte1 as (select * from a) select * from cte1", - cte(table("cte1").select(star()), "cte1" -> table("a").select(star()))) + cte(table("cte1").select(star()), "cte1" -> ((table("a").select(star()), Seq.empty)))) assertEqual( "with cte1 (select 1) select * from cte1", - cte(table("cte1").select(star()), "cte1" -> OneRowRelation().select(1))) + cte(table("cte1").select(star()), "cte1" -> ((OneRowRelation().select(1), Seq.empty)))) assertEqual( "with cte1 (select 1), cte2 as (select * from cte1) select * from cte2", cte(table("cte2").select(star()), - "cte1" -> OneRowRelation().select(1), - "cte2" -> table("cte1").select(star()))) + "cte1" -> ((OneRowRelation().select(1), Seq.empty)), + "cte2" -> ((table("cte1").select(star()), Seq.empty)))) intercept( "with cte1 (select 1), cte1 as (select 1 from cte1) select * from cte1", "Found duplicate keys 'cte1'") @@ -812,10 +817,17 @@ class PlanParserSuite extends AnalysisTest { |WITH cte1 AS (SELECT * FROM testcat.db.tab) |SELECT * FROM cte1 """.stripMargin, - cte(table("cte1").select(star()), "cte1" -> table("testcat", "db", "tab").select(star()))) + cte(table("cte1").select(star()), + "cte1" -> ((table("testcat", "db", "tab").select(star()), Seq.empty)))) assertEqual( "SELECT /*+ BROADCAST(tab) */ * FROM testcat.db.tab", table("testcat", "db", "tab").select(star()).hint("BROADCAST", $"tab")) } + + test("CTE with column alias") { + assertEqual( + "WITH t(x) AS (SELECT c FROM a) SELECT * FROM t", + cte(table("t").select(star()), "t" -> ((table("a").select('c), Seq("x"))))) + } } diff --git a/sql/core/src/test/resources/sql-tests/inputs/cte.sql b/sql/core/src/test/resources/sql-tests/inputs/cte.sql index d34d89f23575a..822c5c4660e3b 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/cte.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/cte.sql @@ -24,6 +24,10 @@ SELECT t1.id AS c1, FROM CTE1 t1 CROSS JOIN CTE1 t2; +-- CTE with column alias +WITH t(x) AS (SELECT 1) +SELECT * FROM t WHERE x = 1; + -- Clean up DROP VIEW IF EXISTS t; DROP VIEW IF EXISTS t2; diff --git a/sql/core/src/test/resources/sql-tests/results/cte.sql.out b/sql/core/src/test/resources/sql-tests/results/cte.sql.out index a446c2cd183da..f8ccecbc46f46 100644 --- a/sql/core/src/test/resources/sql-tests/results/cte.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/cte.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 9 +-- Number of queries: 10 -- !query 0 @@ -89,16 +89,25 @@ struct -- !query 7 -DROP VIEW IF EXISTS t +WITH t(x) AS (SELECT 1) +SELECT * FROM t WHERE x = 1 -- !query 7 schema -struct<> +struct -- !query 7 output - +1 -- !query 8 -DROP VIEW IF EXISTS t2 +DROP VIEW IF EXISTS t -- !query 8 schema struct<> -- !query 8 output + + +-- !query 9 +DROP VIEW IF EXISTS t2 +-- !query 9 schema +struct<> +-- !query 9 output +