diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala
index 64bdd6f4643dc..e6918f48aa838 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala
@@ -85,12 +85,9 @@ private[csv] class LineCsvWriter(params: CSVOptions, headers: Seq[String]) exten
   private val buffer = new CharArrayWriter()
   private val writer = new CsvWriter(buffer, writerSettings)
 
-  def writeRow(row: Seq[String], includeHeader: Boolean): Unit = {
-    if (includeHeader) {
-      writer.writeHeaders()
-    }
-    writer.writeRow(row.toArray: _*)
-  }
+  def writeHeader(): Unit = writer.writeHeaders()
+
+  def writeRow(row: Seq[String]): Unit = writer.writeRow(row.toArray: _*)
 
   def flush(): String = {
     writer.flush()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
index 33b170bc31f62..e3cc454214bd9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
@@ -188,6 +188,8 @@ private[csv] class CsvOutputWriter(
   // create the Generator without separator inserted between 2 records
   private[this] val text = new Text()
 
+  private val csvWriter = new LineCsvWriter(params, dataSchema.fieldNames.toSeq)
+
   // A `ValueConverter` is responsible for converting a value of an `InternalRow` to `String`.
   // When the value is null, this converter should not be called.
   private type ValueConverter = (InternalRow, Int) => String
@@ -197,7 +199,7 @@ private[csv] class CsvOutputWriter(
     dataSchema.map(_.dataType).map(makeConverter).toArray
 
   private val recordWriter: RecordWriter[NullWritable, Text] = {
-    new TextOutputFormat[NullWritable, Text]() {
+    val writer = new TextOutputFormat[NullWritable, Text]() {
       override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
         val configuration = context.getConfiguration
         val uniqueWriteJobId = configuration.get(WriterContainer.DATASOURCE_WRITEJOBUUID)
@@ -206,11 +208,17 @@ private[csv] class CsvOutputWriter(
         new Path(path, f"part-r-$split%05d-$uniqueWriteJobId.csv$extension")
       }
     }.getRecordWriter(context)
+    // Write header even if `writeInternal()` is not called.
+    if (params.headerFlag) {
+      csvWriter.writeHeader()
+      text.set(csvWriter.flush())
+      writer.write(NullWritable.get(), text)
+    }
+    writer
   }
 
   private val FLUSH_BATCH_SIZE = 1024L
   private var records: Long = 0L
-  private val csvWriter = new LineCsvWriter(params, dataSchema.fieldNames.toSeq)
 
   private def rowToString(row: InternalRow): Seq[String] = {
     var i = 0
@@ -245,7 +253,7 @@
   override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal")
 
   override protected[sql] def writeInternal(row: InternalRow): Unit = {
-    csvWriter.writeRow(rowToString(row), records == 0L && params.headerFlag)
+    csvWriter.writeRow(rowToString(row))
     records += 1
     if (records % FLUSH_BATCH_SIZE == 0) {
       flush()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index 1930862118e9b..e38d81858f74a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -663,6 +663,23 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
     assert(numbers.count() == 8)
   }
 
+  test("Write and read empty data with the header as the schema") {
+    withTempPath { path =>
+      val emptyDf = spark.range(10).limit(0).toDF()
+      emptyDf.write
+        .format("csv")
+        .option("header", "true")
+        .save(path.getCanonicalPath)
+
+      val copyEmptyDf = spark.read
+        .format("csv")
+        .option("header", "true")
+        .load(path.getCanonicalPath)
+
+      checkAnswer(emptyDf, copyEmptyDf)
+    }
+  }
+
   test("error handling for unsupported data types.") {
     withTempDir { dir =>
       val csvDir = new File(dir, "csv").getCanonicalPath