@@ -85,12 +85,9 @@ private[csv] class LineCsvWriter(params: CSVOptions, headers: Seq[String]) exten
   private val buffer = new CharArrayWriter()
   private val writer = new CsvWriter(buffer, writerSettings)
 
-  def writeRow(row: Seq[String], includeHeader: Boolean): Unit = {
-    if (includeHeader) {
-      writer.writeHeaders()
-    }
-    writer.writeRow(row.toArray: _*)
-  }
+  def writeHeader(): Unit = writer.writeHeaders()
+
+  def writeRow(row: Seq[String]): Unit = writer.writeRow(row.toArray: _*)
 
   def flush(): String = {
     writer.flush()
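Reviewer note: the refactor above splits header emission out of writeRow, so a caller can now emit the header without having any rows at all. A minimal, self-contained sketch of that same pattern against uniVocity's CsvWriter directly (assuming the uniVocity-parsers artifact is on the classpath; the column names and values are made up for illustration):

import java.io.CharArrayWriter
import com.univocity.parsers.csv.{CsvWriter, CsvWriterSettings}

val settings = new CsvWriterSettings()
settings.setHeaders("id", "name")         // headers registered up front; LineCsvWriter presumably does the same via its writerSettings (not shown in this hunk)

val buffer = new CharArrayWriter()
val writer = new CsvWriter(buffer, settings)

writer.writeHeaders()                     // emitted once, no longer tied to the first data row
writer.writeRow(Array("1", "alice"): _*)  // rows are written independently of the header decision
writer.flush()
println(buffer.toString)                  // header line followed by the single data row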
@@ -188,6 +188,8 @@ private[csv] class CsvOutputWriter(
   // create the Generator without separator inserted between 2 records
   private[this] val text = new Text()
 
+  private val csvWriter = new LineCsvWriter(params, dataSchema.fieldNames.toSeq)
+
   // A `ValueConverter` is responsible for converting a value of an `InternalRow` to `String`.
   // When the value is null, this converter should not be called.
   private type ValueConverter = (InternalRow, Int) => String
@@ -197,7 +199,7 @@ private[csv] class CsvOutputWriter(
     dataSchema.map(_.dataType).map(makeConverter).toArray
 
   private val recordWriter: RecordWriter[NullWritable, Text] = {
-    new TextOutputFormat[NullWritable, Text]() {
+    val writer = new TextOutputFormat[NullWritable, Text]() {
       override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
         val configuration = context.getConfiguration
         val uniqueWriteJobId = configuration.get(WriterContainer.DATASOURCE_WRITEJOBUUID)
@@ -206,11 +208,17 @@ private[csv] class CsvOutputWriter(
         new Path(path, f"part-r-$split%05d-$uniqueWriteJobId.csv$extension")
       }
     }.getRecordWriter(context)
+    // Write header even if `writeInternal()` is not called.
+    if (params.headerFlag) {
+      csvWriter.writeHeader()
+      text.set(csvWriter.flush())
+      writer.write(NullWritable.get(), text)
+    }
+    writer
   }
 
   private val FLUSH_BATCH_SIZE = 1024L
   private var records: Long = 0L
-  private val csvWriter = new LineCsvWriter(params, dataSchema.fieldNames.toSeq)
 
   private def rowToString(row: InternalRow): Seq[String] = {
     var i = 0
@@ -245,7 +253,7 @@ private[csv] class CsvOutputWriter(
   override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal")
 
   override protected[sql] def writeInternal(row: InternalRow): Unit = {
-    csvWriter.writeRow(rowToString(row), records == 0L && params.headerFlag)
+    csvWriter.writeRow(rowToString(row))
     records += 1
     if (records % FLUSH_BATCH_SIZE == 0) {
       flush()
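Reviewer note: the net effect of the CsvOutputWriter change is that the header is emitted while the record writer is being constructed, rather than lazily from writeInternal(), so a partition with zero rows still produces a file containing the header line. A hypothetical, Spark-free sketch of that idea (class and names invented purely for illustration):

import java.io.{PrintWriter, StringWriter}

// Hypothetical minimal writer: the header goes out during initialization,
// mirroring the `if (params.headerFlag)` block inside `recordWriter` above.
class HeaderFirstWriter(headers: Seq[String], writeHeader: Boolean) {
  private val out = new StringWriter()
  private val pw = new PrintWriter(out)
  if (writeHeader) pw.println(headers.mkString(","))  // written even if no row ever arrives

  def writeRow(row: Seq[String]): Unit = pw.println(row.mkString(","))
  def result(): String = { pw.flush(); out.toString }
}

// Even with zero rows the output is not empty, so a header-aware reader can recover the column names:
println(new HeaderFirstWriter(Seq("id", "name"), writeHeader = true).result())  // prints "id,name"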
@@ -663,6 +663,23 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
     assert(numbers.count() == 8)
   }
 
+  test("Write and read empty data with the header as the schema") {
+    withTempPath { path =>
+      val emptyDf = spark.range(10).limit(0).toDF()
+      emptyDf.write
+        .format("csv")
+        .option("header", "true")
+        .save(path.getCanonicalPath)
+
+      val copyEmptyDf = spark.read
+        .format("csv")
+        .option("header", "true")
+        .load(path.getCanonicalPath)
+
+      checkAnswer(emptyDf, copyEmptyDf)
+    }
+  }
+
   test("error handling for unsupported data types.") {
     withTempDir { dir =>
       val csvDir = new File(dir, "csv").getCanonicalPath