@@ -339,11 +339,34 @@ case class DataSource(
dataSource.createRelation(sparkSession.sqlContext, caseInsensitiveOptions)
case (_: SchemaRelationProvider, None) =>
throw new AnalysisException(s"A schema needs to be specified when using $className.")
case (dataSource: RelationProvider, Some(schema)) =>
case (dataSource: RelationProvider, Some(specifiedSchema)) =>
val baseRelation =
dataSource.createRelation(sparkSession.sqlContext, caseInsensitiveOptions)
if (baseRelation.schema != schema) {
throw new AnalysisException(s"$className does not allow user-specified schemas.")
val persistentSchema = baseRelation.schema
val persistentSize = persistentSchema.size
val specifiedSize = specifiedSchema.size
if (persistentSize == specifiedSize) {
val (persistentFields, specifiedFields) = persistentSchema.zip(specifiedSchema)
@HyukjinKwon (Member, Dec 7, 2019):
If we're going to improve this kind of error message across the codebase, we might also think about having a common method (maybe something called assertEquality in StructType?) that checks each type recursively and shows a better message. Can we at least have a private method here for this case in the future?
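As a rough sketch of the kind of helper suggested here (the name mismatches and its exact behaviour are hypothetical, not an existing StructType API; it simply walks nested structs recursively and collects the paths that differ):

    import org.apache.spark.sql.types._

    object SchemaDiff {
      // Hypothetical recursive comparison: returns human-readable descriptions of the
      // fields whose names or types differ, including nested leaves such as "b.c".
      def mismatches(expected: StructType, actual: StructType, path: String = ""): Seq[String] = {
        if (expected.length != actual.length) {
          Seq(s"$path: expected ${expected.length} fields, got ${actual.length}")
        } else {
          expected.zip(actual).flatMap {
            case (e, a) if e.name != a.name =>
              Seq(s"$path${e.name}: field name differs (got ${a.name})")
            case (e, a) => (e.dataType, a.dataType) match {
              case (es: StructType, as: StructType) =>
                mismatches(es, as, s"$path${e.name}.")
              case (ed, ad) if ed != ad =>
                Seq(s"$path${e.name}: expected ${ed.simpleString}, got ${ad.simpleString}")
              case _ => Nil
            }
          }
        }
      }
    }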

Member Author:
Not sure whether we'd need similar functionality elsewhere in the future, but maybe we could still give it a try.

@HyukjinKwon (Member, Dec 8, 2019):
Yeah, I think it won't handle nested cases. There are other external data sources that support nested schemas, and the current code only reports the root columns.
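To illustrate with made-up schemas: StructField equality is deep, so a nested difference is still detected, but a root-level comparison can only name the whole column rather than the leaf that actually changed.

    import org.apache.spark.sql.types._

    val persistent = new StructType()
      .add("a", StringType)
      .add("b", new StructType().add("c", IntegerType))
    val specified = new StructType()
      .add("a", StringType)
      .add("b", new StructType().add("c", StringType)) // only the nested field `c` differs

    // The root-level zip-and-compare flags `b` as a whole...
    val differing = persistent.zip(specified).collect { case (p, s) if p != s => p.name }
    // differing == Seq("b"), so the message prints the full struct for `b` and leaves
    // the user to spot that `b`.`c` is the field that actually changed.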

Member:
Also, there are many other places that could show better error messages like this, e.g. StructType.merge or _merge_type in Python's schema inference (https://github.com/apache/spark/blob/master/python/pyspark/sql/types.py#L1097-L1111).

Member:
See #19792 or #18521 as examples.

Member Author:
Hi @HyukjinKwon, after discussing with wenchen offline, we decided not to make it too complicated here. If the schemas don't match, we simply show the whole schema to the user rather than only the mismatched fields, as the previous revision did. Please see de036b6.

.filter { case (existedField, userField) => existedField != userField }
.unzip
if (persistentFields.nonEmpty) {
val errorMsg =
s"Mismatched fields detected between persistent schema and user specified schema: " +
Member:
nit: seems like we can remove ss.

Member Author:
Do you mean: fields -> field?

Member:
Nope, I meant the s for string interpolation (s"...).
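For reference, a tiny illustration of the nit: in Scala only the fragments that actually interpolate something need the s prefix, so the purely literal pieces of a concatenated message can drop it.

    val className = "org.example.MyProvider" // illustrative value only
    // Interpolates $className, so the `s` prefix is needed:
    val tail = s"a specified schema for $className is not necessary."
    // Pure literal: an `s` prefix here would be a no-op and can be dropped:
    val msg = "Please either correct the schema, or do not specify one since " + tail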

s"persistentFields: ${persistentFields.map(_.toDDL).mkString(", ")}, " +
s"specifiedFields: ${specifiedFields.map(_.toDDL).mkString(", ")}. " +
s"This happens either you make a mistake in schema or type mapping between Spark " +
s"and external data sources have been updated while your specified schema still " +
s"using the old schema. Please either correct the schema or just do not specify " +
s"the schema since a specified schema for $className is not necessary."
throw new AnalysisException(errorMsg)
}
} else {
val errorMsg =
s"The number of fields between persistent schema and user specified schema " +
s"mismatched: expect $persistentSize fields, but got $specifiedSize fields. " +
s"Please either correct the schema or just do not specify the schema since " +
s"a specified schema for $className is not necessary."
throw new AnalysisException(errorMsg)
Member:
nit: format like this?

          throw new AnalysisException("The user-specified schema doesn't match the actual schema: " +
            s"user-specified: ${schema.toDDL}, actual: ${baseRelation.schema.toDDL}. If " +
            "you're using DataFrameReader.schema API or creating a table, please do not " +
            "specify the schema. Or if you're scanning an existed table, please drop " +
            "it and re-create it.")

Member Author:
updated, thanks!

}
baseRelation

@@ -17,12 +17,27 @@

package org.apache.spark.sql.execution.datasources.jdbc

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.{AnalysisException, SQLContext}
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.sources.{BaseRelation, RelationProvider}
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types._

class JdbcUtilsSuite extends SparkFunSuite {
/* A test RelationProvider that returns a fixed persistent schema (a STRING, b INT) */
class TestJdbcRelationProvider extends RelationProvider {
override def createRelation(sqlCtx: SQLContext, parameters: Map[String, String])
: BaseRelation = {
new BaseRelation {
override def sqlContext: SQLContext = sqlCtx
override def schema: StructType = {
new StructType().add(StructField("a", StringType)).add(StructField("b", IntegerType))
}
}
}
}

class JdbcUtilsSuite extends SharedSparkSession {

val tableSchema = StructType(Seq(
StructField("C1", StringType, false), StructField("C2", IntegerType, false)))
@@ -65,4 +80,129 @@ class JdbcUtilsSuite extends SparkFunSuite {
}
assert(mismatchedInput.getMessage.contains("mismatched input '.' expecting"))
}

test("SPARK-30151: user-specified schema not match relation schema - number mismatch") {
// persistent: (a STRING, b INT)
val persistentSchema =
DataSource(spark, classOf[TestJdbcRelationProvider].getCanonicalName)
.resolveRelation()
.schema
// specified: (a STRING)
val specifiedSchema = new StructType()
.add(StructField("a", StringType))
val msg = intercept[AnalysisException] {
DataSource(
spark,
classOf[TestJdbcRelationProvider].getCanonicalName,
userSpecifiedSchema = Some(specifiedSchema))
.resolveRelation()
}.getMessage
assert(msg.contains(
"The number of fields between persistent schema and user specified schema mismatch"))
}

test("SPARK-30151: user-specified schema not match relation schema - wrong name") {
// persistent: (a STRING, b INT)
val persistentSchema =
DataSource(spark, classOf[TestJdbcRelationProvider].getCanonicalName)
.resolveRelation()
.schema
// specified: (a STRING, c INT)
val specifiedSchema = new StructType()
.add(StructField("a", StringType))
.add(StructField("c", IntegerType)) // wrong field name
val msg = intercept[AnalysisException] {
DataSource(
spark,
classOf[TestJdbcRelationProvider].getCanonicalName,
userSpecifiedSchema = Some(specifiedSchema))
.resolveRelation()
}.getMessage
assert(msg.contains(s"persistentFields: ${persistentSchema("b").toDDL}"))
assert(msg.contains(s"specifiedFields: ${specifiedSchema("c").toDDL}"))
}

test("SPARK-30151: user-specified schema not match relation schema - wrong type") {
// persistent: (a STRING, b INT)
val persistentSchema =
DataSource(spark, classOf[TestJdbcRelationProvider].getCanonicalName)
.resolveRelation()
.schema
// specified: (a STRING, b STRING)
val specifiedSchema = new StructType()
.add(StructField("a", StringType))
.add(StructField("b", StringType)) // wrong filed type
val msg = intercept[AnalysisException] {
DataSource(
spark,
classOf[TestJdbcRelationProvider].getCanonicalName,
userSpecifiedSchema = Some(specifiedSchema))
.resolveRelation()
}.getMessage
assert(msg.contains(s"persistentFields: ${persistentSchema("b").toDDL}"))
assert(msg.contains(s"specifiedFields: ${specifiedSchema("b").toDDL}"))
}

test("SPARK-30151: user-specified schema not match relation schema - wrong name & type") {
// persistent: (a STRING, b INT)
val persistentSchema =
DataSource(spark, classOf[TestJdbcRelationProvider].getCanonicalName)
.resolveRelation()
.schema
// specified: (a STRING, c STRING)
val specifiedSchema = new StructType()
.add(StructField("a", StringType))
.add(StructField("c", StringType)) // wrong filed name and type
val msg = intercept[AnalysisException] {
DataSource(
spark,
classOf[TestJdbcRelationProvider].getCanonicalName,
userSpecifiedSchema = Some(specifiedSchema))
.resolveRelation()
}.getMessage
assert(msg.contains(s"persistentFields: ${persistentSchema("b").toDDL}"))
assert(msg.contains(s"specifiedFields: ${specifiedSchema("c").toDDL}"))
}

test("SPARK-30151: user-specified schema not match relation schema - wrong order") {
// persistent: (a STRING, b INT)
val persistentSchema =
DataSource(spark, classOf[TestJdbcRelationProvider].getCanonicalName)
.resolveRelation()
.schema
// specified: (b INT, a STRING)
val specifiedSchema = new StructType() // wrong order
.add(StructField("b", IntegerType))
.add(StructField("a", StringType))
val msg = intercept[AnalysisException] {
DataSource(
spark,
classOf[TestJdbcRelationProvider].getCanonicalName,
userSpecifiedSchema = Some(specifiedSchema))
.resolveRelation()
}.getMessage
assert(msg.contains(s"persistentFields: ${persistentSchema.map(_.toDDL).mkString(", ")}"))
assert(msg.contains(s"specifiedFields: ${specifiedSchema.map(_.toDDL).mkString(", ")}"))
}

test("SPARK-30151: user-specified schema not match relation schema - complex type") {
// persistent: (a STRING, b INT)
val persistentSchema =
DataSource(spark, classOf[TestJdbcRelationProvider].getCanonicalName)
.resolveRelation()
.schema
// specified: (a STRING, b STRUCT<c INT>)
val specifiedSchema = new StructType()
.add(StructField("a", StringType))
.add(StructField("b", StructType(StructField("c", IntegerType) :: Nil))) // complex type
val msg = intercept[AnalysisException] {
DataSource(
spark,
classOf[TestJdbcRelationProvider].getCanonicalName,
userSpecifiedSchema = Some(specifiedSchema))
.resolveRelation()
}.getMessage
assert(msg.contains(s"persistentFields: ${persistentSchema("b").toDDL}"))
assert(msg.contains(s"specifiedFields: ${specifiedSchema("b").toDDL}"))
}
}
@@ -369,7 +369,8 @@ class TableScanSuite extends DataSourceTest with SharedSparkSession {
|)
""".stripMargin)
}
assert(schemaNotAllowed.getMessage.contains("does not allow user-specified schemas"))
assert(schemaNotAllowed.getMessage.contains(
"a specified schema for org.apache.spark.sql.sources.SimpleScanSource is not necessary"))

val schemaNeeded = intercept[Exception] {
sql(
@@ -40,7 +40,8 @@ import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression}
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.datasources.DataSourceUtils
import org.apache.spark.sql.execution.datasources.{DataSource, DataSourceUtils}
import org.apache.spark.sql.execution.datasources.jdbc.TestJdbcRelationProvider
import org.apache.spark.sql.execution.datasources.noop.NoopDataSource
import org.apache.spark.sql.execution.datasources.parquet.SpecificParquetRecordReaderBase
import org.apache.spark.sql.internal.SQLConf
@@ -447,6 +448,24 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSparkSession with
}
}

test("SPARK-30151: user-specified schema not match relation schema") {
// persistent: (a STRING, b INT)
val persistentSchema =
DataSource(spark, classOf[TestJdbcRelationProvider].getCanonicalName)
.resolveRelation()
.schema
// specified: (a STRING, c INT)
val specifiedSchema = new StructType()
.add(StructField("a", StringType))
.add(StructField("c", IntegerType))
val msg = intercept[AnalysisException] {
spark.read.format(classOf[TestJdbcRelationProvider].getCanonicalName)
.schema(specifiedSchema).load()
}.getMessage
assert(msg.contains(s"persistentFields: ${persistentSchema("b").toDDL}"))
assert(msg.contains(s"specifiedFields: ${specifiedSchema("c").toDDL}"))
}

test("prevent all column partitioning") {
withTempDir { dir =>
val path = dir.getCanonicalPath
@@ -493,7 +512,7 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSparkSession with
val inputSchema = new StructType().add("s", IntegerType, nullable = false)
val e = intercept[AnalysisException] { dfReader.schema(inputSchema).load() }
assert(e.getMessage.contains(
"org.apache.spark.sql.sources.SimpleScanSource does not allow user-specified schemas"))
"a specified schema for org.apache.spark.sql.sources.SimpleScanSource is not necessary"))
}

test("read a data source that does not extend RelationProvider") {