@@ -21,6 +21,7 @@ import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.catalog.CatalogTable
 import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan}
 import org.apache.spark.sql.execution.command.RunnableCommand
 import org.apache.spark.sql.types._

@@ -37,7 +38,9 @@ case class CreateTable(tableDesc: CatalogTable, mode: SaveMode, query: Option[LogicalPlan])

   override def output: Seq[Attribute] = Seq.empty[Attribute]

-  override def children: Seq[LogicalPlan] = query.toSeq
+  override def children: Seq[LogicalPlan] = Seq.empty[LogicalPlan]
Contributor: extend LeafNode?

Member (author): Yeah. : )

+  override def innerChildren: Seq[QueryPlan[_]] = query.toSeq
 }

 case class CreateTempViewUsing(
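The reviewer's suggestion is the idiomatic Catalyst pattern here: a leaf node hides the CTAS query from analyzer and optimizer rules entirely, while innerChildren still surfaces it for plan display. A minimal sketch of the idea (the class name and reduced field list are hypothetical; the real CreateTable also carries tableDesc and mode):

import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan}

// Extending LeafNode makes the empty `children` override unnecessary:
// transformation rules no longer recurse into the CTAS query, but EXPLAIN
// still renders it because it is exposed through innerChildren.
case class CreateTableSketch(query: Option[LogicalPlan]) extends LeafNode {
  override def output: Seq[Attribute] = Seq.empty
  override def innerChildren: Seq[QueryPlan[_]] = query.toSeq
}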
@@ -68,7 +68,7 @@ class ResolveDataSource(sparkSession: SparkSession) extends Rule[LogicalPlan] {
 /**
  * Preprocess some DDL plans, e.g. [[CreateTable]], to do some normalization and checking.
Contributor: we should update the comment to say that this rule will also analyze the query (we may also want to update the rule name).

Member (author): Sure, let me do it now. Thanks!

  */
-case class PreprocessDDL(conf: SQLConf) extends Rule[LogicalPlan] {
+case class PreprocessDDL(sparkSession: SparkSession) extends Rule[LogicalPlan] {

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
// When we CREATE TABLE without specifying the table schema, we should fail the query if
@@ -95,9 +95,19 @@
// * can't use all table columns as partition columns.
// * partition columns' type must be AtomicType.
// * sort columns' type must be orderable.
-    case c @ CreateTable(tableDesc, mode, query) if c.childrenResolved =>
-      val schema = if (query.isDefined) query.get.schema else tableDesc.schema
-      val columnNames = if (conf.caseSensitiveAnalysis) {
+    case c @ CreateTable(tableDesc, mode, query) =>
+      val analyzedQuery = query.map { q =>
+        // Analyze the query in CTAS and then we can do the normalization and checking.
+        val qe = sparkSession.sessionState.executePlan(q)
+        qe.assertAnalyzed()
+        qe.analyzed
+      }
+      val schema = if (analyzedQuery.isDefined) {
+        analyzedQuery.get.schema
+      } else {
+        tableDesc.schema
+      }
+      val columnNames = if (sparkSession.sessionState.conf.caseSensitiveAnalysis) {
         schema.map(_.name)
       } else {
         schema.map(_.name.toLowerCase)
@@ -106,7 +116,7 @@ case class PreprocessDDL(conf: SQLConf) extends Rule[LogicalPlan] {

       val partitionColsChecked = checkPartitionColumns(schema, tableDesc)
       val bucketColsChecked = checkBucketColumns(schema, partitionColsChecked)
-      c.copy(tableDesc = bucketColsChecked)
+      c.copy(tableDesc = bucketColsChecked, query = analyzedQuery)
   }

private def checkPartitionColumns(schema: StructType, tableDesc: CatalogTable): CatalogTable = {
@@ -176,6 +186,7 @@ case class PreprocessDDL(conf: SQLConf) extends Rule[LogicalPlan] {
colName: String,
colType: String): String = {
val tableCols = schema.map(_.name)
+    val conf = sparkSession.sessionState.conf
tableCols.find(conf.resolver(_, colName)).getOrElse {
failAnalysis(s"$colType column $colName is not defined in table $tableIdent, " +
s"defined table columns are: ${tableCols.mkString(", ")}")
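The key behavioral change in this hunk: PreprocessDDL no longer waits for c.childrenResolved (a guard that could never fire once the query left children); it analyzes the CTAS query itself, eagerly. A rough sketch of the flow the rule performs by hand, under stated assumptions (code lives inside Spark's own sql package, since sessionState is private[sql] in this version, and a live session named spark exists):

// Parse a CTAS-style query, push it through the analyzer explicitly, and
// read the resolved schema -- the same schema the partition/bucket checks use.
val parsed = spark.sessionState.sqlParser.parsePlan(
  "SELECT id, id * cast('1.0' as decimal(38, 18)) AS num FROM range(10)")
val qe = spark.sessionState.executePlan(parsed)
qe.assertAnalyzed()          // throws AnalysisException if the plan is unresolved
println(qe.analyzed.schema)  // decimal precision is only known after analysis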
@@ -111,7 +111,7 @@ private[sql] class SessionState(sparkSession: SparkSession) {
lazy val analyzer: Analyzer = {
new Analyzer(catalog, conf) {
override val extendedResolutionRules =
-        PreprocessDDL(conf) ::
+        PreprocessDDL(sparkSession) ::
PreprocessTableInsertion(conf) ::
new FindDataSourceTable(sparkSession) ::
DataSourceAnalysis(conf) ::
@@ -236,4 +236,16 @@ class CreateTableAsSelectSuite
assert(e.contains("Expected positive number of buckets, but got `0`"))
}
}

test("CTAS of decimal calculation") {
withTable("tab2") {
withTempView("tab1") {
spark.range(99, 101).createOrReplaceTempView("tab1")
val sqlStmt =
"SELECT id, cast(id as long) * cast('1.0' as decimal(38, 18)) as num FROM tab1"
sql(s"CREATE TABLE tab2 USING PARQUET AS $sqlStmt")
checkAnswer(spark.table("tab2"), sql(sqlStmt))
}
}
}
}
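What the new test pins down: the result type of a bigint multiplied by a decimal(38, 18) is produced by the analyzer's decimal-precision rules, so the CTAS table schema is only trustworthy if the query is analyzed before the table is written. A hedged spark-shell sketch of the same scenario (exact precision/scale depend on this Spark version's promotion rules):

spark.range(99, 101).createOrReplaceTempView("tab1")
spark.sql(
  """CREATE TABLE tab2 USING PARQUET AS
    |SELECT id, cast(id as long) * cast('1.0' as decimal(38, 18)) as num FROM tab1
  """.stripMargin)
// `num` should come out as a proper decimal type (decimal(38, 18) under the
// bounded-precision rules of this era), not fail with an unresolved schema.
spark.table("tab2").printSchema()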
@@ -60,7 +60,7 @@ private[hive] class HiveSessionState(sparkSession: SparkSession)
override val extendedResolutionRules =
catalog.ParquetConversions ::
catalog.OrcConversions ::
-      PreprocessDDL(conf) ::
+      PreprocessDDL(sparkSession) ::
PreprocessTableInsertion(conf) ::
DataSourceAnalysis(conf) ::
(if (conf.runSQLonFile) new ResolveDataSource(sparkSession) :: Nil else Nil)
@@ -77,7 +77,7 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
"src")
}

test("SPARK-6212: The EXPLAIN output of CTAS only shows the analyzed plan") {
test("SPARK-17409: The EXPLAIN output of CTAS only shows the analyzed plan") {
withTempView("jt") {
val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}"""))
spark.read.json(rdd).createOrReplaceTempView("jt")
@@ -98,8 +98,8 @@
}

val physicalIndex = outputs.indexOf("== Physical Plan ==")
-    assert(!outputs.substring(physicalIndex).contains("Subquery"),
-      "Physical Plan should not contain Subquery since it's eliminated by optimizer")
+    assert(outputs.substring(physicalIndex).contains("Subquery"),
+      "Physical Plan should contain SubqueryAlias since the query should not be optimized")
}
}

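The flipped assertion is the observable effect of the whole change: the CTAS query is now stored in analyzed rather than optimized form, so SubqueryAlias nodes that EliminateSubqueryAliases would normally strip are expected to survive into the EXPLAIN output. A hedged sketch of that check (jt is the temp view the test registers above; the table name t is illustrative):

val outputs = spark.sql(
  "EXPLAIN EXTENDED CREATE TABLE t USING parquet AS SELECT a, b FROM jt")
  .collect().map(_.getString(0)).mkString("\n")
// The alias survives because the stored plan never went through the optimizer.
assert(outputs.contains("SubqueryAlias"))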