[SPARK-25313][SQL]Fix regression in FileFormatWriter output names #22320
QueryPlan.scala

```diff
@@ -38,6 +38,20 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanType] {
    */
   def outputSet: AttributeSet = AttributeSet(output)
 
+  /**
+   * Returns output attributes with provided names.
+   * The length of provided names should be the same of the length of [[output]].
+   */
+  def outputWithNames(names: Seq[String]): Seq[Attribute] = {
+    // Save the output attributes to a variable to avoid duplicated function calls.
+    val outputAttributes = output
+    assert(outputAttributes.length == names.length,
+      "The length of provided names doesn't match the length of output attributes.")
+    outputAttributes.zipWithIndex.map { case (element, index) =>
+      element.withName(names(index))
+    }
+  }
+
   /**
    * All Attributes that appear in expressions from this operator. Note that this set does not
    * include attributes that are implicitly referenced by being passed through to the output tuple.
```
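To see why renaming is sufficient, note that `withName` only swaps the user-visible name and keeps the attribute's `exprId` and data type, so the rest of the plan still resolves against the same attribute. A minimal sketch of this behavior (not part of the PR, using Catalyst's `AttributeReference` directly):

```scala
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.types.LongType

// The optimizer may hand back an attribute named "id" even though the user wrote "ID".
val optimized = AttributeReference("id", LongType)()

// Re-applying the original name keeps the exprId and type, so plan resolution is unaffected,
// but the name that ends up in the written file schema is the user-specified "ID".
val restored = optimized.withName("ID")
assert(restored.name == "ID" && restored.exprId == optimized.exprId)
```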
DataSource.scala

```diff
@@ -450,7 +450,7 @@ case class DataSource(
       mode = mode,
       catalogTable = catalogTable,
       fileIndex = fileIndex,
-      outputColumns = data.output)
+      outputColumnNames = data.output.map(_.name))
   }
 
   /**
@@ -460,9 +460,9 @@ case class DataSource(
   * @param mode The save mode for this writing.
   * @param data The input query plan that produces the data to be written. Note that this plan
   *             is analyzed and optimized.
-  * @param outputColumns The original output columns of the input query plan. The optimizer may not
-  *                      preserve the output column's names' case, so we need this parameter
-  *                      instead of `data.output`.
+  * @param outputColumnNames The original output column names of the input query plan. The
+  *                          optimizer may not preserve the output column's names' case, so we need
+  *                          this parameter instead of `data.output`.
   * @param physicalPlan The physical plan of the input query plan. We should run the writing
   *                     command with this physical plan instead of creating a new physical plan,
   *                     so that the metrics can be correctly linked to the given physical plan and
@@ -471,8 +471,9 @@ case class DataSource(
   def writeAndRead(
       mode: SaveMode,
       data: LogicalPlan,
-      outputColumns: Seq[Attribute],
+      outputColumnNames: Seq[String],
       physicalPlan: SparkPlan): BaseRelation = {
+    val outputColumns = data.outputWithNames(names = outputColumnNames)
     if (outputColumns.map(_.dataType).exists(_.isInstanceOf[CalendarIntervalType])) {
       throw new AnalysisException("Cannot save interval data type into external storage.")
     }
@@ -495,7 +496,9 @@ case class DataSource(
           s"Unable to resolve $name given [${data.output.map(_.name).mkString(", ")}]")
       }
     }
-    val resolved = cmd.copy(partitionColumns = resolvedPartCols, outputColumns = outputColumns)
+    val resolved = cmd.copy(
+      partitionColumns = resolvedPartCols,
+      outputColumnNames = outputColumns.map(_.name))
     resolved.run(sparkSession, physicalPlan)
     // Replace the schema with that of the DataFrame we just wrote out to avoid re-inferring
     copy(userSpecifiedSchema = Some(outputColumns.toStructType.asNullable)).resolveRelation()
```
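To make the casing issue concrete, this is the scenario the renamed parameter guards against (a condensed version of the new tests added further down; assumes an active `SparkSession` named `spark`):

```scala
// The INSERT uses "ID", but the plan is resolved against a view whose column is "id".
// After optimization, data.output can carry the lower-case "id", so the write path keeps the
// analyzed plan's original names (outputColumnNames) and re-applies them when writing files.
spark.range(10).toDF("id").write.format("parquet").saveAsTable("tbl")
spark.sql("CREATE VIEW view1 AS SELECT id FROM tbl")
spark.sql("CREATE TABLE tbl2(ID long) USING parquet")
spark.sql("INSERT OVERWRITE TABLE tbl2 SELECT ID FROM view1")
// Before this fix, the parquet files under tbl2's location could end up with a column named "id".
```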
InsertIntoHadoopFsRelationCommand.scala

```diff
@@ -56,7 +56,7 @@ case class InsertIntoHadoopFsRelationCommand(
     mode: SaveMode,
     catalogTable: Option[CatalogTable],
     fileIndex: Option[FileIndex],
-    outputColumns: Seq[Attribute])
+    outputColumnNames: Seq[String])
   extends DataWritingCommand {
   import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName
```
Member:
Line 66:

Author:
Oh, then we can use this method instead.
```diff
@@ -155,6 +155,7 @@ case class InsertIntoHadoopFsRelationCommand(
       }
     }
 
+    val outputColumns = query.outputWithNames(outputColumnNames)
     val updatedPartitionPaths =
       FileFormatWriter.write(
         sparkSession = sparkSession,
```
SQLQuerySuite.scala

```diff
@@ -24,6 +24,7 @@ import java.util.concurrent.atomic.AtomicBoolean
 
 import org.apache.spark.{AccumulatorSuite, SparkException}
 import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
+import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.util.StringUtils
 import org.apache.spark.sql.execution.aggregate
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec}
```
```diff
@@ -2853,6 +2854,81 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
     }
   }
 
+  test("Insert overwrite table command should output correct schema: basic") {
+    withTable("tbl", "tbl2") {
+      withView("view1") {
+        val df = spark.range(10).toDF("id")
+        df.write.format("parquet").saveAsTable("tbl")
+        spark.sql("CREATE VIEW view1 AS SELECT id FROM tbl")
+        spark.sql("CREATE TABLE tbl2(ID long) USING parquet")
+        spark.sql("INSERT OVERWRITE TABLE tbl2 SELECT ID FROM view1")
+        val identifier = TableIdentifier("tbl2", Some("default"))
+        val location = spark.sessionState.catalog.getTableMetadata(identifier).location.toString
+        val expectedSchema = StructType(Seq(StructField("ID", LongType, true)))
+        assert(spark.read.parquet(location).schema == expectedSchema)
+        checkAnswer(spark.table("tbl2"), df)
+      }
+    }
+  }
+
+  test("Insert overwrite table command should output correct schema: complex") {
+    withTable("tbl", "tbl2") {
+      withView("view1") {
+        val df = spark.range(10).map(x => (x, x.toInt, x.toInt)).toDF("col1", "col2", "col3")
+        df.write.format("parquet").saveAsTable("tbl")
+        spark.sql("CREATE VIEW view1 AS SELECT * FROM tbl")
+        spark.sql("CREATE TABLE tbl2(COL1 long, COL2 int, COL3 int) USING parquet PARTITIONED " +
+          "BY (COL2) CLUSTERED BY (COL3) INTO 3 BUCKETS")
+        spark.sql("INSERT OVERWRITE TABLE tbl2 SELECT COL1, COL2, COL3 " +
+          "FROM view1 CLUSTER BY COL3")
+        val identifier = TableIdentifier("tbl2", Some("default"))
+        val location = spark.sessionState.catalog.getTableMetadata(identifier).location.toString
+        val expectedSchema = StructType(Seq(
+          StructField("COL1", LongType, true),
+          StructField("COL3", IntegerType, true),
+          StructField("COL2", IntegerType, true)))
+        assert(spark.read.parquet(location).schema == expectedSchema)
+        checkAnswer(spark.table("tbl2"), df)
+      }
+    }
+  }
+
+  test("Create table as select command should output correct schema: basic") {
+    withTable("tbl", "tbl2") {
+      withView("view1") {
+        val df = spark.range(10).toDF("id")
+        df.write.format("parquet").saveAsTable("tbl")
+        spark.sql("CREATE VIEW view1 AS SELECT id FROM tbl")
+        spark.sql("CREATE TABLE tbl2 USING parquet AS SELECT ID FROM view1")
+        val identifier = TableIdentifier("tbl2", Some("default"))
+        val location = spark.sessionState.catalog.getTableMetadata(identifier).location.toString
+        val expectedSchema = StructType(Seq(StructField("ID", LongType, true)))
+        assert(spark.read.parquet(location).schema == expectedSchema)
+        checkAnswer(spark.table("tbl2"), df)
+      }
+    }
+  }
+
+  test("Create table as select command should output correct schema: complex") {
+    withTable("tbl", "tbl2") {
+      withView("view1") {
+        val df = spark.range(10).map(x => (x, x.toInt, x.toInt)).toDF("col1", "col2", "col3")
+        df.write.format("parquet").saveAsTable("tbl")
+        spark.sql("CREATE VIEW view1 AS SELECT * FROM tbl")
+        spark.sql("CREATE TABLE tbl2 USING parquet PARTITIONED BY (COL2) " +
+          "CLUSTERED BY (COL3) INTO 3 BUCKETS AS SELECT COL1, COL2, COL3 FROM view1")
+        val identifier = TableIdentifier("tbl2", Some("default"))
+        val location = spark.sessionState.catalog.getTableMetadata(identifier).location.toString
+        val expectedSchema = StructType(Seq(
+          StructField("COL1", LongType, true),
+          StructField("COL3", IntegerType, true),
+          StructField("COL2", IntegerType, true)))
+        assert(spark.read.parquet(location).schema == expectedSchema)
+        checkAnswer(spark.table("tbl2"), df)
+      }
+    }
+  }
+
   test("SPARK-25144 'distinct' causes memory leak") {
     val ds = List(Foo(Some("bar"))).toDS
     val result = ds.flatMap(_.bar).distinct
```
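One non-obvious detail in the "complex" tests above: COL2 appears last in the expected schema because it is the partition column, and partition columns are laid out as directories and appended after the data columns when the files are read back. A hypothetical snippet (not from the PR) illustrating the ordering:

```scala
// Write a small dataset partitioned by COL2 and inspect the schema of the files on read-back.
spark.range(3)
  .selectExpr("id AS COL1", "CAST(id AS INT) AS COL2", "CAST(id AS INT) AS COL3")
  .write.partitionBy("COL2").parquet("/tmp/partition_order_demo")

// The partition column shows up after the data columns: COL1, COL3, COL2.
spark.read.parquet("/tmp/partition_order_demo").schema.fieldNames
```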
CreateHiveTableAsSelectCommand.scala

```diff
@@ -37,7 +37,7 @@ import org.apache.spark.sql.execution.command.DataWritingCommand
 case class CreateHiveTableAsSelectCommand(
     tableDesc: CatalogTable,
     query: LogicalPlan,
-    outputColumns: Seq[Attribute],
+    outputColumnNames: Seq[String],
     mode: SaveMode)
   extends DataWritingCommand {
```
```diff
@@ -63,7 +63,7 @@ case class CreateHiveTableAsSelectCommand(
         query,
         overwrite = false,
         ifPartitionNotExists = false,
-        outputColumns = outputColumns).run(sparkSession, child)
+        outputColumnNames = outputColumnNames).run(sparkSession, child)
```

Contributor:
Can you remove one

```diff
       } else {
         // TODO ideally, we should get the output data ready first and then
         // add the relation into catalog, just in case of failure occurs while data
```
```diff
@@ -82,7 +82,7 @@ case class CreateHiveTableAsSelectCommand(
         query,
         overwrite = true,
         ifPartitionNotExists = false,
-        outputColumns = outputColumns).run(sparkSession, child)
+        outputColumnNames = outputColumnNames).run(sparkSession, child)
```

Contributor:
Why is this duplication needed here?

Contributor:
What's the duplication?

Contributor:
I feel it's better to specify parameters by name if the previous parameter is already specified by name, e.g.

```diff
       } catch {
         case NonFatal(e) =>
           // drop the created table.
```
InsertIntoHiveTable.scala

```diff
@@ -69,7 +69,7 @@ case class InsertIntoHiveTable(
     query: LogicalPlan,
     overwrite: Boolean,
     ifPartitionNotExists: Boolean,
-    outputColumns: Seq[Attribute]) extends SaveAsHiveFile {
+    outputColumnNames: Seq[String]) extends SaveAsHiveFile {
 
   /**
    * Inserts all the rows in the table into Hive. Row objects are properly serialized with the
```

Member:
For better test coverage, can you add tests for hive tables?

Author:
No problem 👍

Member:
thanks!
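A rough sketch of the Hive-table coverage requested above (hypothetical, not part of this commit; it assumes a Hive-enabled test suite, e.g. one mixing in `TestHiveSingleton`, and mirrors the parquet tests earlier in this diff):

```scala
test("Insert overwrite Hive table should output correct schema") {
  withTable("tbl", "tbl2") {
    withView("view1") {
      val df = spark.range(10).toDF("id")
      df.write.format("parquet").saveAsTable("tbl")
      spark.sql("CREATE VIEW view1 AS SELECT id FROM tbl")
      // A Hive serde table backed by parquet, so the written files can be inspected directly.
      spark.sql("CREATE TABLE tbl2(ID long) USING hive OPTIONS(fileFormat 'parquet')")
      spark.sql("INSERT OVERWRITE TABLE tbl2 SELECT ID FROM view1")
      val identifier = TableIdentifier("tbl2", Some("default"))
      val location = spark.sessionState.catalog.getTableMetadata(identifier).location.toString
      // Same assertion shape as the data source tests: the written schema keeps the declared "ID".
      val expectedSchema = StructType(Seq(StructField("ID", LongType, true)))
      assert(spark.read.parquet(location).schema == expectedSchema)
      checkAnswer(spark.table("tbl2"), df)
    }
  }
}
```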
```diff
@@ -198,7 +198,7 @@ case class InsertIntoHiveTable(
       hadoopConf = hadoopConf,
       fileSinkConf = fileSinkConf,
       outputLocation = tmpLocation.toString,
-      allColumns = outputColumns,
+      allColumns = query.outputWithNames(outputColumnNames),
       partitionAttributes = partitionAttributes)
 
     if (partition.nonEmpty) {
```
outputAttributes.zip(names).map { case (attr, outputName) => attr.withName(outputName) }?

@gengliangwang In what situations would outputAttributes.length != names.length? Could you give me an example?
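For reference, the zip-based form suggested in the comment above would be a drop-in replacement for the body of `outputWithNames`; a sketch (same behavior as the version in the hunk at the top of this diff):

```scala
def outputWithNames(names: Seq[String]): Seq[Attribute] = {
  // Cache `output` so it is computed only once.
  val outputAttributes = output
  assert(outputAttributes.length == names.length,
    "The length of provided names doesn't match the length of output attributes.")
  // zip pairs each attribute with its provided name, avoiding the explicit index lookup.
  outputAttributes.zip(names).map { case (attr, outputName) => attr.withName(outputName) }
}
```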