@@ -38,6 +38,20 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
*/
def outputSet: AttributeSet = AttributeSet(output)

/**
* Returns output attributes with provided names.
* The length of the provided names should be the same as the length of [[output]].
*/
def outputWithNames(names: Seq[String]): Seq[Attribute] = {
// Save the output attributes to a variable to avoid duplicated function calls.
val outputAttributes = output
assert(outputAttributes.length == names.length,
"The length of provided names doesn't match the length of output attributes.")
outputAttributes.zipWithIndex.map { case (element, index) =>
element.withName(names(index))
Member:
outputAttributes.zip(names).map { case (attr, outputName) => attr.withName(outputName) }?

Reply:
@gengliangwang In what situations would outputAttributes.length != names.length? Could you give me an example?

}
}

Member:
If #22311 is merged, we won't need this function anymore, right? If so, IMHO it'd be better to fix this issue on the FileFormatWriter side as a workaround.

Contributor:
+1

Contributor:
or make it a util function

Member Author:
It seems overkill to add a function here, but in FileFormatWriter we can't access the LogicalPlan to get the attributes.
Another way is to put this method in a util object.
Do you have a good suggestion?

Member:
I was thinking...

object FileFormatWriter {
  ...

  // workaround: a helper function...
  def outputWithNames(outputAttributes: Seq[Attribute], names: Seq[String]): Seq[Attribute] = {
    assert(outputAttributes.length == names.length,
      "The length of provided names doesn't match the length of output attributes.")
    outputAttributes.zipWithIndex.map { case (element, index) =>
      element.withName(names(index))
    }
  }
}

Then, in each call site, just say FileFormatWriter.outputWithNames(logicalPlan.output, names)?

Member Author:
@maropu Thanks! I have created the object DataWritingCommand for this.
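
For reference, a rough sketch of what such a companion-object helper could look like (the method name and exact placement are assumptions based on this thread, not quoted from the diff shown in this commit; LogicalPlan and Attribute are the usual catalyst types):

object DataWritingCommand {
  /**
   * Returns the output attributes of `query`, renamed with the provided column names.
   * Assumes `names` has the same length as `query.output`.
   */
  def logicalPlanOutputWithNames(
      query: LogicalPlan,
      names: Seq[String]): Seq[Attribute] = {
    // Save the output attributes to a variable to avoid duplicated function calls.
    val outputAttributes = query.output
    assert(outputAttributes.length == names.length,
      "The length of provided names doesn't match the length of output attributes.")
    outputAttributes.zip(names).map { case (attr, outputName) =>
      attr.withName(outputName)
    }
  }
}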

/**
* All Attributes that appear in expressions from this operator. Note that this set does not
* include attributes that are implicitly referenced by being passed through to the output tuple.
@@ -42,7 +42,7 @@ trait DataWritingCommand extends Command {
override final def children: Seq[LogicalPlan] = query :: Nil

// Output columns of the analyzed input query plan
def outputColumns: Seq[Attribute]
def outputColumnNames: Seq[String]

lazy val metrics: Map[String, SQLMetric] = BasicWriteJobStatsTracker.metrics

@@ -139,7 +139,7 @@ case class CreateDataSourceTableAsSelectCommand(
table: CatalogTable,
mode: SaveMode,
query: LogicalPlan,
outputColumns: Seq[Attribute])
outputColumnNames: Seq[String])
extends DataWritingCommand {

override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = {
@@ -214,7 +214,7 @@ case class CreateDataSourceTableAsSelectCommand(
catalogTable = if (tableExists) Some(table) else None)

try {
dataSource.writeAndRead(mode, query, outputColumns, physicalPlan)
dataSource.writeAndRead(mode, query, outputColumnNames, physicalPlan)
} catch {
case ex: AnalysisException =>
logError(s"Failed to write to table ${table.identifier.unquotedString}", ex)
@@ -450,7 +450,7 @@ case class DataSource(
mode = mode,
catalogTable = catalogTable,
fileIndex = fileIndex,
outputColumns = data.output)
outputColumnNames = data.output.map(_.name))
}

/**
@@ -460,9 +460,9 @@
* @param mode The save mode for this writing.
* @param data The input query plan that produces the data to be written. Note that this plan
* is analyzed and optimized.
* @param outputColumns The original output columns of the input query plan. The optimizer may not
* preserve the output column's names' case, so we need this parameter
* instead of `data.output`.
* @param outputColumnNames The original output column names of the input query plan. The
* optimizer may not preserve the output column's names' case, so we need
* this parameter instead of `data.output`.
Member:
nit:

   * @param outputColumnNames The original output column names of the input query plan. The
   *                          optimizer may not preserve the output column's names' case, so we need
   *                          this parameter instead of `data.output`.

* @param physicalPlan The physical plan of the input query plan. We should run the writing
* command with this physical plan instead of creating a new physical plan,
* so that the metrics can be correctly linked to the given physical plan and
@@ -471,8 +471,9 @@
def writeAndRead(
mode: SaveMode,
data: LogicalPlan,
outputColumns: Seq[Attribute],
outputColumnNames: Seq[String],
physicalPlan: SparkPlan): BaseRelation = {
val outputColumns = data.outputWithNames(names = outputColumnNames)
if (outputColumns.map(_.dataType).exists(_.isInstanceOf[CalendarIntervalType])) {
throw new AnalysisException("Cannot save interval data type into external storage.")
}
Expand All @@ -495,7 +496,9 @@ case class DataSource(
s"Unable to resolve $name given [${data.output.map(_.name).mkString(", ")}]")
}
}
val resolved = cmd.copy(partitionColumns = resolvedPartCols, outputColumns = outputColumns)
val resolved = cmd.copy(
partitionColumns = resolvedPartCols,
outputColumnNames = outputColumns.map(_.name))
Contributor:
why can't we use outputColumnNames directly here?

resolved.run(sparkSession, physicalPlan)
// Replace the schema with that of the DataFrame we just wrote out to avoid re-inferring
copy(userSpecifiedSchema = Some(outputColumns.toStructType.asNullable)).resolveRelation()
@@ -139,7 +139,7 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] with Cast
case CreateTable(tableDesc, mode, Some(query))
if query.resolved && DDLUtils.isDatasourceTable(tableDesc) =>
DDLUtils.checkDataColNames(tableDesc.copy(schema = query.schema))
CreateDataSourceTableAsSelectCommand(tableDesc, mode, query, query.output)
CreateDataSourceTableAsSelectCommand(tableDesc, mode, query, query.output.map(_.name))

case InsertIntoTable(l @ LogicalRelation(_: InsertableRelation, _, _, _),
parts, query, overwrite, false) if parts.isEmpty =>
@@ -209,7 +209,7 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] with Cast
mode,
table,
Some(t.location),
actualQuery.output)
actualQuery.output.map(_.name))
}
}

@@ -56,7 +56,7 @@ case class InsertIntoHadoopFsRelationCommand(
mode: SaveMode,
catalogTable: Option[CatalogTable],
fileIndex: Option[FileIndex],
outputColumns: Seq[Attribute])
outputColumnNames: Seq[String])
extends DataWritingCommand {
import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName

Member:
Line 66: query.schema should be DataWritingCommand.logicalPlanSchemaWithNames(query, outputColumnNames).

Member Author:
Oh, then we can use this method instead.

def checkColumnNameDuplication(
      columnNames: Seq[String], colType: String, caseSensitiveAnalysis: Boolean): Unit
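
For context, a rough sketch of how that utility could be invoked on the write path, assuming the quoted signature is SchemaUtils.checkColumnNameDuplication and using illustrative surrounding identifiers:

SchemaUtils.checkColumnNameDuplication(
  outputColumnNames,
  s"when inserting into $outputPath",
  sparkSession.sessionState.conf.caseSensitiveAnalysis)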

@@ -155,6 +155,7 @@ case class InsertIntoHadoopFsRelationCommand(
}
}

val outputColumns = query.outputWithNames(outputColumnNames)
val updatedPartitionPaths =
FileFormatWriter.write(
sparkSession = sparkSession,
76 changes: 76 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -24,6 +24,7 @@ import java.util.concurrent.atomic.AtomicBoolean

import org.apache.spark.{AccumulatorSuite, SparkException}
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
import org.apache.spark.sql.catalyst.TableIdentifier
Contributor:
unnecessary change

import org.apache.spark.sql.catalyst.util.StringUtils
import org.apache.spark.sql.execution.aggregate
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec}
@@ -2853,6 +2854,81 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
}

test("Insert overwrite table command should output correct schema: basic") {
withTable("tbl", "tbl2") {
withView("view1") {
val df = spark.range(10).toDF("id")
df.write.format("parquet").saveAsTable("tbl")
spark.sql("CREATE VIEW view1 AS SELECT id FROM tbl")
spark.sql("CREATE TABLE tbl2(ID long) USING parquet")
spark.sql("INSERT OVERWRITE TABLE tbl2 SELECT ID FROM view1")
val identifier = TableIdentifier("tbl2", Some("default"))
val location = spark.sessionState.catalog.getTableMetadata(identifier).location.toString
val expectedSchema = StructType(Seq(StructField("ID", LongType, true)))
assert(spark.read.parquet(location).schema == expectedSchema)
checkAnswer(spark.table("tbl2"), df)
}
}
}

test("Insert overwrite table command should output correct schema: complex") {
withTable("tbl", "tbl2") {
withView("view1") {
val df = spark.range(10).map(x => (x, x.toInt, x.toInt)).toDF("col1", "col2", "col3")
df.write.format("parquet").saveAsTable("tbl")
spark.sql("CREATE VIEW view1 AS SELECT * FROM tbl")
spark.sql("CREATE TABLE tbl2(COL1 long, COL2 int, COL3 int) USING parquet PARTITIONED " +
"BY (COL2) CLUSTERED BY (COL3) INTO 3 BUCKETS")
spark.sql("INSERT OVERWRITE TABLE tbl2 SELECT COL1, COL2, COL3 " +
"FROM view1 CLUSTER BY COL3")
val identifier = TableIdentifier("tbl2", Some("default"))
val location = spark.sessionState.catalog.getTableMetadata(identifier).location.toString
val expectedSchema = StructType(Seq(
StructField("COL1", LongType, true),
StructField("COL3", IntegerType, true),
StructField("COL2", IntegerType, true)))
assert(spark.read.parquet(location).schema == expectedSchema)
checkAnswer(spark.table("tbl2"), df)
}
}
}

test("Create table as select command should output correct schema: basic") {
withTable("tbl", "tbl2") {
withView("view1") {
val df = spark.range(10).toDF("id")
df.write.format("parquet").saveAsTable("tbl")
spark.sql("CREATE VIEW view1 AS SELECT id FROM tbl")
spark.sql("CREATE TABLE tbl2 USING parquet AS SELECT ID FROM view1")
val identifier = TableIdentifier("tbl2", Some("default"))
val location = spark.sessionState.catalog.getTableMetadata(identifier).location.toString
val expectedSchema = StructType(Seq(StructField("ID", LongType, true)))
assert(spark.read.parquet(location).schema == expectedSchema)
checkAnswer(spark.table("tbl2"), df)
}
}
}

test("Create table as select command should output correct schema: complex") {
withTable("tbl", "tbl2") {
withView("view1") {
val df = spark.range(10).map(x => (x, x.toInt, x.toInt)).toDF("col1", "col2", "col3")
df.write.format("parquet").saveAsTable("tbl")
spark.sql("CREATE VIEW view1 AS SELECT * FROM tbl")
spark.sql("CREATE TABLE tbl2 USING parquet PARTITIONED BY (COL2) " +
"CLUSTERED BY (COL3) INTO 3 BUCKETS AS SELECT COL1, COL2, COL3 FROM view1")
val identifier = TableIdentifier("tbl2", Some("default"))
val location = spark.sessionState.catalog.getTableMetadata(identifier).location.toString
val expectedSchema = StructType(Seq(
StructField("COL1", LongType, true),
StructField("COL3", IntegerType, true),
StructField("COL2", IntegerType, true)))
assert(spark.read.parquet(location).schema == expectedSchema)
checkAnswer(spark.table("tbl2"), df)
}
}
}

Member:
better to move these tests into DataFrameReaderWriterSuite?

test("SPARK-25144 'distinct' causes memory leak") {
val ds = List(Foo(Some("bar"))).toDS
val result = ds.flatMap(_.bar).distinct
@@ -149,22 +149,22 @@ object HiveAnalysis extends Rule[LogicalPlan] {
case InsertIntoTable(r: HiveTableRelation, partSpec, query, overwrite, ifPartitionNotExists)
if DDLUtils.isHiveTable(r.tableMeta) =>
InsertIntoHiveTable(r.tableMeta, partSpec, query, overwrite,
ifPartitionNotExists, query.output)
ifPartitionNotExists, query.output.map(_.name))

case CreateTable(tableDesc, mode, None) if DDLUtils.isHiveTable(tableDesc) =>
DDLUtils.checkDataColNames(tableDesc)
CreateTableCommand(tableDesc, ignoreIfExists = mode == SaveMode.Ignore)

case CreateTable(tableDesc, mode, Some(query)) if DDLUtils.isHiveTable(tableDesc) =>
DDLUtils.checkDataColNames(tableDesc)
CreateHiveTableAsSelectCommand(tableDesc, query, query.output, mode)
CreateHiveTableAsSelectCommand(tableDesc, query, query.output.map(_.name), mode)

case InsertIntoDir(isLocal, storage, provider, child, overwrite)
if DDLUtils.isHiveTable(provider) =>
val outputPath = new Path(storage.locationUri.get)
if (overwrite) DDLUtils.verifyNotReadPath(child, outputPath)

InsertIntoHiveDirCommand(isLocal, storage, child, overwrite, child.output)
InsertIntoHiveDirCommand(isLocal, storage, child, overwrite, child.output.map(_.name))
}
}

@@ -37,7 +37,7 @@ import org.apache.spark.sql.execution.command.DataWritingCommand
case class CreateHiveTableAsSelectCommand(
tableDesc: CatalogTable,
query: LogicalPlan,
outputColumns: Seq[Attribute],
outputColumnNames: Seq[String],
mode: SaveMode)
extends DataWritingCommand {

@@ -63,7 +63,7 @@ case class CreateHiveTableAsSelectCommand(
query,
overwrite = false,
ifPartitionNotExists = false,
outputColumns = outputColumns).run(sparkSession, child)
outputColumnNames = outputColumnNames).run(sparkSession, child)
Contributor:
Can you remove one outputColumnNames?

} else {
// TODO ideally, we should get the output data ready first and then
// add the relation into catalog, just in case of failure occurs while data
@@ -82,7 +82,7 @@ case class CreateHiveTableAsSelectCommand(
query,
overwrite = true,
ifPartitionNotExists = false,
outputColumns = outputColumns).run(sparkSession, child)
outputColumnNames = outputColumnNames).run(sparkSession, child)
Contributor:
Why is this duplication needed here?

Contributor:
what's the duplication?

Contributor:
The outputColumnNames themselves. Specifying outputColumnNames as the name of the parameter to set with outputColumnNames does nothing but introduce duplication. If you removed one outputColumnNames, readability wouldn't suffer at all, would it?

Contributor:
I feel it's better to specify parameters by name if the previous parameter is already specified by name, e.g. ifPartitionNotExists = false

} catch {
case NonFatal(e) =>
// drop the created table.
@@ -57,7 +57,7 @@ case class InsertIntoHiveDirCommand(
storage: CatalogStorageFormat,
query: LogicalPlan,
overwrite: Boolean,
outputColumns: Seq[Attribute]) extends SaveAsHiveFile {
outputColumnNames: Seq[String]) extends SaveAsHiveFile {

override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = {
assert(storage.locationUri.nonEmpty)
@@ -105,7 +105,7 @@ case class InsertIntoHiveDirCommand(
hadoopConf = hadoopConf,
fileSinkConf = fileSinkConf,
outputLocation = tmpPath.toString,
allColumns = outputColumns)
allColumns = query.outputWithNames(outputColumnNames))

val fs = writeToPath.getFileSystem(hadoopConf)
if (overwrite && fs.exists(writeToPath)) {
@@ -69,7 +69,7 @@ case class InsertIntoHiveTable(
query: LogicalPlan,
overwrite: Boolean,
ifPartitionNotExists: Boolean,
outputColumns: Seq[Attribute]) extends SaveAsHiveFile {
outputColumnNames: Seq[String]) extends SaveAsHiveFile {
Member:
For better test coverage, can you add tests for hive tables?

Member Author:
No problem 👍

Member:
thanks!
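
A minimal sketch of what such a Hive-table test could look like, modeled on the Parquet tests added to SQLQuerySuite above (the test name, target suite, and the config toggle are assumptions; disabling Parquet conversion keeps the write on the Hive path):

test("Insert overwrite Hive table should output correct schema") {
  withSQLConf("spark.sql.hive.convertMetastoreParquet" -> "false") {
    withTable("tbl", "tbl2") {
      withView("view1") {
        val df = spark.range(10).toDF("id")
        df.write.format("parquet").saveAsTable("tbl")
        spark.sql("CREATE VIEW view1 AS SELECT id FROM tbl")
        spark.sql("CREATE TABLE tbl2(ID long) STORED AS PARQUET")
        spark.sql("INSERT OVERWRITE TABLE tbl2 SELECT ID FROM view1")
        val identifier = TableIdentifier("tbl2", Some("default"))
        val location = spark.sessionState.catalog.getTableMetadata(identifier).location.toString
        val expectedSchema = StructType(Seq(StructField("ID", LongType, true)))
        assert(spark.read.parquet(location).schema == expectedSchema)
        checkAnswer(spark.table("tbl2"), df)
      }
    }
  }
}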


/**
* Inserts all the rows in the table into Hive. Row objects are properly serialized with the
@@ -198,7 +198,7 @@ case class InsertIntoHiveTable(
hadoopConf = hadoopConf,
fileSinkConf = fileSinkConf,
outputLocation = tmpLocation.toString,
allColumns = outputColumns,
allColumns = query.outputWithNames(outputColumnNames),
partitionAttributes = partitionAttributes)

if (partition.nonEmpty) {