
Commit b1dbd0a

append to a bucketed table using DataFrameWriter with mismatched bucketing should fail
1 parent 7d858bc commit b1dbd0a
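
In DataFrameWriter terms, the behavior this commit enforces looks roughly like the sketch below. This is a hedged illustration, not code from the patch: it assumes an active SparkSession named spark with spark.implicits._ imported, the table and column names are made up, and the error text is abbreviated (the table name may be qualified with its database).

Seq(1 -> "a", 2 -> "b").toDF("i", "j").write.bucketBy(2, "i").saveAsTable("t")

// Appending with a different bucket count now fails analysis up front:
Seq(3 -> "c").toDF("i", "j").write.bucketBy(3, "i").mode("append").saveAsTable("t")
// org.apache.spark.sql.AnalysisException:
//   Specified bucketing does not match the existing table t.
//   Specified bucketing: 3 buckets, bucket columns: [i]
//   Existing bucketing: 2 buckets, bucket columns: [i]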

File tree: 8 files changed (+180, -90 lines)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala

Lines changed: 37 additions & 0 deletions
@@ -20,6 +20,8 @@ package org.apache.spark.sql.catalyst.catalog
 import org.apache.hadoop.fs.Path
 import org.apache.hadoop.util.Shell
 
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.analysis.Resolver
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 
 object ExternalCatalogUtils {
@@ -133,4 +135,39 @@ object CatalogUtils {
       case o => o
     }
   }
+
+  def normalizePartCols(
+      tableName: String,
+      tableCols: Seq[String],
+      partCols: Seq[String],
+      resolver: Resolver): Seq[String] = {
+    partCols.map(normalizeColumnName(tableName, tableCols, _, "partition", resolver))
+  }
+
+  def normalizeBucketSpec(
+      tableName: String,
+      tableCols: Seq[String],
+      bucketSpec: BucketSpec,
+      resolver: Resolver): BucketSpec = {
+    val BucketSpec(numBuckets, bucketColumnNames, sortColumnNames) = bucketSpec
+    val normalizedBucketCols = bucketColumnNames.map { colName =>
+      normalizeColumnName(tableName, tableCols, colName, "bucket", resolver)
+    }
+    val normalizedSortCols = sortColumnNames.map { colName =>
+      normalizeColumnName(tableName, tableCols, colName, "sort", resolver)
+    }
+    BucketSpec(numBuckets, normalizedBucketCols, normalizedSortCols)
+  }
+
+  private def normalizeColumnName(
+      tableName: String,
+      tableCols: Seq[String],
+      colName: String,
+      colType: String,
+      resolver: Resolver): String = {
+    tableCols.find(resolver(_, colName)).getOrElse {
+      throw new AnalysisException(s"$colType column $colName is not defined in table $tableName, " +
+        s"defined table columns are: ${tableCols.mkString(", ")}")
+    }
+  }
 }
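
A rough usage sketch of the new helpers (not part of the patch): they map user-supplied column names onto the table schema's spelling via the session's Resolver and reject unknown columns with an AnalysisException. The example assumes the case-insensitive resolver exposed by the catalyst analysis package; the table and column names are illustrative.

import org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogUtils}

// "I" and "J" resolve to the table columns "i" and "j".
val normalized = CatalogUtils.normalizeBucketSpec(
  tableName = "t",
  tableCols = Seq("i", "j"),
  bucketSpec = BucketSpec(2, Seq("I"), Seq("J")),
  resolver = caseInsensitiveResolution)
// normalized == BucketSpec(2, Seq("i"), Seq("j"))

// A column that is not in the table fails analysis; e.g. bucketing by "k" would throw:
//   bucket column k is not defined in table t, defined table columns are: i, j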

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala

Lines changed: 10 additions & 0 deletions
@@ -133,6 +133,16 @@ case class BucketSpec(
   if (numBuckets <= 0) {
     throw new AnalysisException(s"Expected positive number of buckets, but got `$numBuckets`.")
   }
+
+  override def toString: String = {
+    val bucketString = s"bucket columns: [${bucketColumnNames.mkString(", ")}]"
+    val sortString = if (sortColumnNames.nonEmpty) {
+      s", sort columns: [${sortColumnNames.mkString(", ")}]"
+    } else {
+      ""
+    }
+    s"$numBuckets buckets, $bucketString$sortString"
+  }
 }
 
 /**
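
For illustration, the new toString (which the CTAS append path below uses when it reports a bucketing mismatch) renders a spec like this:

BucketSpec(2, Seq("i"), Nil).toString
// "2 buckets, bucket columns: [i]"

BucketSpec(4, Seq("i"), Seq("j")).toString
// "4 buckets, bucket columns: [i], sort columns: [j]"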

sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala

Lines changed: 71 additions & 37 deletions
@@ -18,13 +18,11 @@
 package org.apache.spark.sql.execution.command
 
 import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
 import org.apache.spark.sql.catalyst.catalog._
-import org.apache.spark.sql.catalyst.plans.QueryPlan
+import org.apache.spark.sql.catalyst.expressions.NamedExpression
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.datasources._
-import org.apache.spark.sql.sources.{BaseRelation, InsertableRelation}
-import org.apache.spark.sql.types._
+import org.apache.spark.sql.sources.BaseRelation
 
 /**
  * A command used to create a data source table.
@@ -143,8 +141,9 @@ case class CreateDataSourceTableAsSelectCommand(
     val tableName = tableIdentWithDB.unquotedString
 
     var createMetastoreTable = false
-    var existingSchema = Option.empty[StructType]
-    if (sparkSession.sessionState.catalog.tableExists(tableIdentWithDB)) {
+    // We may need to reorder the columns of the query to match the existing table.
+    var reorderedColumns = Option.empty[Seq[NamedExpression]]
+    if (sessionState.catalog.tableExists(tableIdentWithDB)) {
       // Check if we need to throw an exception or just return.
       mode match {
         case SaveMode.ErrorIfExists =>
@@ -157,39 +156,74 @@ case class CreateDataSourceTableAsSelectCommand(
           // Since the table already exists and the save mode is Ignore, we will just return.
          return Seq.empty[Row]
         case SaveMode.Append =>
+          val existingTable = sessionState.catalog.getTableMetadata(tableIdentWithDB)
+          if (existingTable.tableType == CatalogTableType.VIEW) {
+            throw new AnalysisException("Saving data into a view is not allowed.")
+          }
+
+          if (existingTable.provider.get == DDLUtils.HIVE_PROVIDER) {
+            throw new AnalysisException(s"Saving data in the Hive serde table $tableName is " +
+              s"not supported yet. Please use the insertInto() API as an alternative.")
+          }
+
           // Check if the specified data source match the data source of the existing table.
-          val existingProvider = DataSource.lookupDataSource(provider)
+          val existingProvider = DataSource.lookupDataSource(existingTable.provider.get)
+          val specifiedProvider = DataSource.lookupDataSource(table.provider.get)
           // TODO: Check that options from the resolved relation match the relation that we are
           // inserting into (i.e. using the same compression).
+          if (existingProvider != specifiedProvider) {
+            throw new AnalysisException(s"The format of the existing table $tableName is " +
+              s"`${existingProvider.getSimpleName}`. It doesn't match the specified format " +
+              s"`${specifiedProvider.getSimpleName}`.")
+          }
 
-          // Pass a table identifier with database part, so that `lookupRelation` won't get temp
-          // views unexpectedly.
-          EliminateSubqueryAliases(sessionState.catalog.lookupRelation(tableIdentWithDB)) match {
-            case l @ LogicalRelation(_: InsertableRelation | _: HadoopFsRelation, _, _) =>
-              // check if the file formats match
-              l.relation match {
-                case r: HadoopFsRelation if r.fileFormat.getClass != existingProvider =>
-                  throw new AnalysisException(
-                    s"The file format of the existing table $tableName is " +
-                      s"`${r.fileFormat.getClass.getName}`. It doesn't match the specified " +
-                      s"format `$provider`")
-                case _ =>
-              }
-              if (query.schema.size != l.schema.size) {
-                throw new AnalysisException(
-                  s"The column number of the existing schema[${l.schema}] " +
-                    s"doesn't match the data schema[${query.schema}]'s")
-              }
-              existingSchema = Some(l.schema)
-            case s: SimpleCatalogRelation if DDLUtils.isDatasourceTable(s.metadata) =>
-              existingSchema = Some(s.metadata.schema)
-            case c: CatalogRelation if c.catalogTable.provider == Some(DDLUtils.HIVE_PROVIDER) =>
-              throw new AnalysisException("Saving data in the Hive serde table " +
-                s"${c.catalogTable.identifier} is not supported yet. Please use the " +
-                "insertInto() API as an alternative..")
-            case o =>
-              throw new AnalysisException(s"Saving data in ${o.toString} is not supported.")
+          if (query.schema.length != existingTable.schema.length) {
+            throw new AnalysisException(
+              s"The column number of the existing table $tableName" +
+                s"(${existingTable.schema.catalogString}) doesn't match the data schema" +
+                s"(${query.schema.catalogString})")
           }
+
+          val resolver = sessionState.conf.resolver
+          val tableCols = existingTable.schema.map(_.name)
+
+          reorderedColumns = Some(existingTable.schema.map { f =>
+            query.resolve(Seq(f.name), resolver).getOrElse {
+              val inputColumns = query.schema.map(_.name).mkString(", ")
+              throw new AnalysisException(
+                s"cannot resolve '${f.name}' given input columns: [$inputColumns]")
+            }
+          })
+
+          // Check if the specified partition columns match the existing table.
+          val specifiedPartCols = CatalogUtils.normalizePartCols(
+            tableName, tableCols, table.partitionColumnNames, resolver)
+          if (specifiedPartCols != existingTable.partitionColumnNames) {
+            throw new AnalysisException(
+              s"""
+                |Specified partitioning does not match the existing table $tableName.
+                |Specified partition columns: [${specifiedPartCols.mkString(", ")}]
+                |Existing partition columns: [${existingTable.partitionColumnNames.mkString(", ")}]
+              """.stripMargin)
+          }
+
+          // Check if the specified bucketing match the existing table.
+          val specifiedBucketSpec = table.bucketSpec.map { bucketSpec =>
+            CatalogUtils.normalizeBucketSpec(tableName, tableCols, bucketSpec, resolver)
+          }
+          if (specifiedBucketSpec != existingTable.bucketSpec) {
+            val specifiedBucketString =
+              specifiedBucketSpec.map(_.toString).getOrElse("not bucketed")
+            val existingBucketString =
+              existingTable.bucketSpec.map(_.toString).getOrElse("not bucketed")
+            throw new AnalysisException(
+              s"""
+                |Specified bucketing does not match the existing table $tableName.
+                |Specified bucketing: $specifiedBucketString
+                |Existing bucketing: $existingBucketString
+              """.stripMargin)
+          }
+
         case SaveMode.Overwrite =>
           sessionState.catalog.dropTable(tableIdentWithDB, ignoreIfNotExists = true, purge = false)
           // Need to create the table again.
@@ -201,9 +235,9 @@ case class CreateDataSourceTableAsSelectCommand(
     }
 
     val data = Dataset.ofRows(sparkSession, query)
-    val df = existingSchema match {
-      // If we are inserting into an existing table, just use the existing schema.
-      case Some(s) => data.selectExpr(s.fieldNames: _*)
+    val df = reorderedColumns match {
+      // Reorder the columns of the query to match the existing table.
+      case Some(cols) => data.select(cols.map(Column(_)): _*)
       case None => data
     }
 
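
Taken together, the Append branch now validates the data source format, column count, column names, partitioning and bucketing against the existing table before any data is written. A hedged sketch of two of these checks firing (assuming a SparkSession named spark with its implicits imported; the table and column names, and the exact wording of the messages, are illustrative):

Seq(1 -> "a").toDF("i", "j").write.format("parquet").partitionBy("j").saveAsTable("t")

// A different format than the existing table's:
Seq(2 -> "b").toDF("i", "j").write.format("json").mode("append").saveAsTable("t")
// AnalysisException, roughly: The format of the existing table t is `ParquetFileFormat`.
// It doesn't match the specified format `JsonFileFormat`.

// The same format but no partitionBy, while the table is partitioned by j:
Seq(2 -> "b").toDF("i", "j").write.format("parquet").mode("append").saveAsTable("t")
// AnalysisException, roughly: Specified partitioning does not match the existing table t.
// Specified partition columns: []
// Existing partition columns: [j]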

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala

Lines changed: 18 additions & 35 deletions
@@ -17,14 +17,11 @@
 
 package org.apache.spark.sql.execution.datasources
 
-import java.util.regex.Pattern
-
 import scala.util.control.NonFatal
 
 import org.apache.spark.sql.{AnalysisException, SaveMode, SparkSession}
-import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogRelation, CatalogTable, SessionCatalog}
+import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogTable, CatalogUtils, SessionCatalog}
 import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, RowOrdering}
 import org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.catalyst.plans.logical._
@@ -122,9 +119,12 @@ case class AnalyzeCreateTable(sparkSession: SparkSession) extends Rule[LogicalPl
   }
 
   private def checkPartitionColumns(schema: StructType, tableDesc: CatalogTable): CatalogTable = {
-    val normalizedPartitionCols = tableDesc.partitionColumnNames.map { colName =>
-      normalizeColumnName(tableDesc.identifier, schema, colName, "partition")
-    }
+    val normalizedPartitionCols = CatalogUtils.normalizePartCols(
+      tableName = tableDesc.identifier.unquotedString,
+      tableCols = schema.map(_.name),
+      partCols = tableDesc.partitionColumnNames,
+      resolver = sparkSession.sessionState.conf.resolver)
+
     checkDuplication(normalizedPartitionCols, "partition")
 
     if (schema.nonEmpty && normalizedPartitionCols.length == schema.length) {
@@ -149,25 +149,21 @@ case class AnalyzeCreateTable(sparkSession: SparkSession) extends Rule[LogicalPl
 
   private def checkBucketColumns(schema: StructType, tableDesc: CatalogTable): CatalogTable = {
     tableDesc.bucketSpec match {
-      case Some(BucketSpec(numBuckets, bucketColumnNames, sortColumnNames)) =>
-        val normalizedBucketCols = bucketColumnNames.map { colName =>
-          normalizeColumnName(tableDesc.identifier, schema, colName, "bucket")
-        }
-        checkDuplication(normalizedBucketCols, "bucket")
-
-        val normalizedSortCols = sortColumnNames.map { colName =>
-          normalizeColumnName(tableDesc.identifier, schema, colName, "sort")
-        }
-        checkDuplication(normalizedSortCols, "sort")
-
-        schema.filter(f => normalizedSortCols.contains(f.name)).map(_.dataType).foreach {
+      case Some(bucketSpec) =>
+        val normalizedBucketing = CatalogUtils.normalizeBucketSpec(
+          tableName = tableDesc.identifier.unquotedString,
+          tableCols = schema.map(_.name),
+          bucketSpec = bucketSpec,
+          resolver = sparkSession.sessionState.conf.resolver)
+        checkDuplication(normalizedBucketing.bucketColumnNames, "bucket")
+        checkDuplication(normalizedBucketing.sortColumnNames, "sort")
+
+        normalizedBucketing.sortColumnNames.map(schema(_)).map(_.dataType).foreach {
          case dt if RowOrdering.isOrderable(dt) => // OK
          case other => failAnalysis(s"Cannot use ${other.simpleString} for sorting column")
        }
 
-        tableDesc.copy(
-          bucketSpec = Some(BucketSpec(numBuckets, normalizedBucketCols, normalizedSortCols))
-        )
+        tableDesc.copy(bucketSpec = Some(normalizedBucketing))
 
       case None => tableDesc
     }
@@ -182,19 +178,6 @@ case class AnalyzeCreateTable(sparkSession: SparkSession) extends Rule[LogicalPl
     }
   }
 
-  private def normalizeColumnName(
-      tableIdent: TableIdentifier,
-      schema: StructType,
-      colName: String,
-      colType: String): String = {
-    val tableCols = schema.map(_.name)
-    val resolver = sparkSession.sessionState.conf.resolver
-    tableCols.find(resolver(_, colName)).getOrElse {
-      failAnalysis(s"$colType column $colName is not defined in table $tableIdent, " +
-        s"defined table columns are: ${tableCols.mkString(", ")}")
-    }
-  }
-
   private def failAnalysis(msg: String) = throw new AnalysisException(msg)
 }

sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala

Lines changed: 2 additions & 2 deletions
@@ -342,15 +342,15 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
     val e = intercept[AnalysisException] {
       sql("CREATE TABLE tbl(a int, b string) USING json PARTITIONED BY (c)")
     }
-    assert(e.message == "partition column c is not defined in table `tbl`, " +
+    assert(e.message == "partition column c is not defined in table tbl, " +
       "defined table columns are: a, b")
   }
 
   test("create table - bucket column names not in table definition") {
     val e = intercept[AnalysisException] {
      sql("CREATE TABLE tbl(a int, b string) USING json CLUSTERED BY (c) INTO 4 BUCKETS")
    }
-    assert(e.message == "bucket column c is not defined in table `tbl`, " +
+    assert(e.message == "bucket column c is not defined in table tbl, " +
      "defined table columns are: a, b")
  }

sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala

Lines changed: 32 additions & 3 deletions
@@ -108,16 +108,14 @@ class DefaultSourceWithoutUserSpecifiedSchema
 }
 
 class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with BeforeAndAfter {
-
+  import testImplicits._
 
   private val userSchema = new StructType().add("s", StringType)
   private val textSchema = new StructType().add("value", StringType)
   private val data = Seq("1", "2", "3")
   private val dir = Utils.createTempDir(namePrefix = "input").getCanonicalPath
-  private implicit var enc: Encoder[String] = _
 
   before {
-    enc = spark.implicits.newStringEncoder
     Utils.deleteRecursively(new File(dir))
   }
 
@@ -609,4 +607,35 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be
       )
     }
   }
+
+  test("SPARK-18899: append to a bucketed table using DataFrameWriter with mismatched bucketing") {
+    withTable("t") {
+      Seq(1 -> "a", 2 -> "b").toDF("i", "j").write.bucketBy(2, "i").saveAsTable("t")
+      val e = intercept[AnalysisException] {
+        Seq(3 -> "c").toDF("i", "j").write.bucketBy(3, "i").mode("append").saveAsTable("t")
+      }
+      assert(e.message.contains("Specified bucketing does not match the existing table"))
+    }
+  }
+
+  test("SPARK-18912: number of columns mismatch for non-file-based data source table") {
+    withTable("t") {
+      sql("CREATE TABLE t USING org.apache.spark.sql.test.DefaultSource")
+
+      val e = intercept[AnalysisException] {
+        Seq(1 -> "a").toDF("a", "b").write
+          .format("org.apache.spark.sql.test.DefaultSource")
+          .mode("append").saveAsTable("t")
+      }
+      assert(e.message.contains("The column number of the existing table"))
+    }
+  }
+
+  test("SPARK-18913: append to a table with special column names") {
+    withTable("t") {
+      Seq(1 -> "a").toDF("x.x", "y.y").write.saveAsTable("t")
+      Seq(2 -> "b").toDF("x.x", "y.y").write.mode("append").saveAsTable("t")
+      checkAnswer(spark.table("t"), Row(1, "a") :: Row(2, "b") :: Nil)
+    }
+  }
+
 }
