DataFrameReader.scala
@@ -325,7 +325,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
.getOrElse(sparkSession.sessionState.conf.columnNameOfCorruptRecord)
val schema = userSpecifiedSchema.getOrElse {
InferSchema.infer(
jsonRDD,
sparkSession.createDataset(jsonRDD)(Encoders.STRING),
columnNameOfCorruptRecord,
parsedOptions)
}
CSVFileFormat.scala
@@ -27,10 +27,11 @@ import org.apache.hadoop.mapreduce._

import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.{Dataset, Encoders, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.CompressionCodecs
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.text.TextFileFormat
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
import org.apache.spark.util.SerializableConfiguration
@@ -56,13 +57,16 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {

// TODO: Move filtering.
val paths = files.filterNot(_.getPath.getName startsWith "_").map(_.getPath.toString)
val rdd = baseRdd(sparkSession, csvOptions, paths)
val firstLine = findFirstLine(csvOptions, rdd)
val lines: Dataset[String] = readText(sparkSession, csvOptions, paths)
val firstLine: String = findFirstLine(csvOptions, lines)
val firstRow = new CsvReader(csvOptions).parseLine(firstLine)
val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
val header = makeSafeHeader(firstRow, csvOptions, caseSensitive)

val parsedRdd = tokenRdd(sparkSession, csvOptions, header, paths)
val parsedRdd: RDD[Array[String]] = CSVRelation.univocityTokenizer(
lines,
firstLine = if (csvOptions.headerFlag) firstLine else null,
params = csvOptions)
val schema = if (csvOptions.inferSchemaFlag) {
CSVInferSchema.infer(parsedRdd, header, csvOptions)
} else {
@@ -173,35 +177,17 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
}
}

private def baseRdd(
sparkSession: SparkSession,
options: CSVOptions,
inputPaths: Seq[String]): RDD[String] = {
readText(sparkSession, options, inputPaths.mkString(","))
}

private def tokenRdd(
sparkSession: SparkSession,
options: CSVOptions,
header: Array[String],
inputPaths: Seq[String]): RDD[Array[String]] = {
val rdd = baseRdd(sparkSession, options, inputPaths)
// Make sure firstLine is materialized before sending to executors
val firstLine = if (options.headerFlag) findFirstLine(options, rdd) else null
CSVRelation.univocityTokenizer(rdd, firstLine, options)
}

/**
* Returns the first line of the first non-empty file in path
*/
private def findFirstLine(options: CSVOptions, rdd: RDD[String]): String = {
private def findFirstLine(options: CSVOptions, lines: Dataset[String]): String = {
if (options.isCommentSet) {
val comment = options.comment.toString
rdd.filter { line =>
lines.filter { line =>
Contributor:
Using untyped filter can be more performant here since we don't need to pay for the extra de/serialization costs:

    lines.filter(length(trim($"value")) > 0 && $"value".startsWith(comment))

line.trim.nonEmpty && !line.startsWith(comment)
}.first()
} else {
rdd.filter { line =>
lines.filter { line =>
Contributor:
Same as above.

line.trim.nonEmpty
}.first()
}
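
For reference, here is a minimal sketch of the untyped-column filter suggested in the review comment above. It is not what this patch does (the patch keeps the typed lambda). It assumes the Dataset's single column is named "value", as it is for the text relation built in readText below, and it adds a negation on startsWith so the predicate matches the typed !line.startsWith(comment) check:

    import org.apache.spark.sql.{Dataset, SparkSession}
    import org.apache.spark.sql.functions.{length, trim}

    // Hypothetical helper: filter with untyped Column expressions so rows are not
    // deserialized into JVM Strings just to be tested, then take the first match.
    def firstUncommentedLine(spark: SparkSession, lines: Dataset[String], comment: String): String = {
      import spark.implicits._
      lines
        .filter(length(trim($"value")) > 0 && !$"value".startsWith(comment))
        .first()
    }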
@@ -210,14 +196,21 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
private def readText(
sparkSession: SparkSession,
options: CSVOptions,
location: String): RDD[String] = {
inputPaths: Seq[String]): Dataset[String] = {
if (Charset.forName(options.charset) == StandardCharsets.UTF_8) {
sparkSession.sparkContext.textFile(location)
sparkSession.baseRelationToDataFrame(
DataSource.apply(
sparkSession,
paths = inputPaths,
className = classOf[TextFileFormat].getName
).resolveRelation(checkFilesExist = false))
.select("value").as[String](Encoders.STRING)
@HyukjinKwon (Member), Nov 9, 2016:
Hi @JoshRosen, I just happened to look at this one and I am just curious. IIUC, the schema from sparkSession.baseRelationToDataFrame will always have only the value column, not including partitioned columns (it is empty, and inputPaths will always be leaf files).

So my question is: is that .select("value") used just to doubly make sure? Just curious.

@JoshRosen (Contributor, author):
I copied this logic from the text method in DataFrameReader, so that's where the value came from.

} else {
val charset = options.charset
sparkSession.sparkContext
.hadoopFile[LongWritable, Text, TextInputFormat](location)
val rdd = sparkSession.sparkContext
Contributor:
@JoshRosen, do you know why the special handling for non-UTF-8 encoding is needed? I would think TextFileFormat itself already supports that, since it reads the data in from Hadoop Text.

@JoshRosen (Contributor, author):
I'm not sure; I think this was a carryover from spark-csv.

Contributor:
cc @falaki
Can you chime in?

Member:
@rxin, I made a PR to address it at #29063 FYI.

.hadoopFile[LongWritable, Text, TextInputFormat](inputPaths.mkString(","))
.mapPartitions(_.map(pair => new String(pair._2.getBytes, 0, pair._2.getLength, charset)))
sparkSession.createDataset(rdd)(Encoders.STRING)
}
}
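
As the author notes in the thread above, the UTF-8 branch of readText mirrors DataFrameReader.text, which is where the "value" column name comes from. At the user-facing API level, a roughly equivalent way to obtain the same Dataset[String] of raw lines is sketched below; the data source itself does not go back through DataFrameReader, this is only to show where select("value") fits:

    import org.apache.spark.sql.{Dataset, SparkSession}

    // spark.read.text yields a DataFrame with a single "value" column; selecting it
    // and applying the String encoder gives a Dataset[String] of raw lines.
    def readLines(spark: SparkSession, paths: Seq[String]): Dataset[String] = {
      import spark.implicits._
      spark.read.text(paths: _*).select("value").as[String]
    }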

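The non-UTF-8 branch questioned in the last thread above is exercised during schema inference when a non-default charset is configured. A usage sketch, assuming an active SparkSession named spark, a hypothetical input path, and that the option key is "charset" (matching the options.charset field read above):

    // A non-UTF-8 charset routes readText through the hadoopFile + manual-decode
    // branch; with the default UTF-8, the TextFileFormat-based branch is used.
    val df = spark.read
      .option("header", "true")
      .option("charset", "ISO-8859-1")   // assumed option key
      .csv("/path/to/latin1-data.csv")   // hypothetical path
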
CSVRelation.scala
@@ -38,12 +38,12 @@ import org.apache.spark.sql.types._
object CSVRelation extends Logging {

def univocityTokenizer(
file: RDD[String],
file: Dataset[String],
firstLine: String,
params: CSVOptions): RDD[Array[String]] = {
// If header is set, make sure firstLine is materialized before sending to executors.
val commentPrefix = params.comment.toString
file.mapPartitions { iter =>
file.rdd.mapPartitions { iter =>
val parser = new CsvReader(params)
val filteredIter = iter.filter { line =>
line.trim.nonEmpty && !line.startsWith(commentPrefix)
InferSchema.scala
@@ -21,7 +21,7 @@ import java.util.Comparator

import com.fasterxml.jackson.core._

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.catalyst.analysis.TypeCoercion
import org.apache.spark.sql.catalyst.json.JacksonUtils.nextUntil
import org.apache.spark.sql.catalyst.json.JSONOptions
@@ -37,7 +37,7 @@ private[sql] object InferSchema {
* 3. Replace any remaining null fields with string, the top type
*/
def infer(
json: RDD[String],
json: Dataset[String],
columnNameOfCorruptRecord: String,
configOptions: JSONOptions): StructType = {
require(configOptions.samplingRatio > 0,
@@ -50,7 +50,7 @@
}

// perform schema inference on each row and merge afterwards
val rootType = schemaData.mapPartitions { iter =>
val rootType = schemaData.rdd.mapPartitions { iter =>
val factory = new JsonFactory()
configOptions.setJacksonOptions(factory)
iter.flatMap { row =>
JsonFileFormat.scala
@@ -21,21 +21,18 @@ import java.io.CharArrayWriter

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.hadoop.io.{LongWritable, NullWritable, Text}
import org.apache.hadoop.mapred.{JobConf, TextInputFormat}
import org.apache.hadoop.io.{NullWritable, Text}
import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext}
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat

import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
import org.apache.spark.sql.{AnalysisException, Encoders, Row, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JacksonParser, JSONOptions}
import org.apache.spark.sql.catalyst.util.CompressionCodecs
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.text.TextOutputWriter
import org.apache.spark.sql.execution.datasources.text.{TextFileFormat, TextOutputWriter}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.SerializableConfiguration
@@ -55,13 +52,21 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
val columnNameOfCorruptRecord =
parsedOptions.columnNameOfCorruptRecord
.getOrElse(sparkSession.sessionState.conf.columnNameOfCorruptRecord)
val jsonFiles = files.filterNot { status =>
val jsonFiles: Seq[String] = files.filterNot { status =>
val name = status.getPath.getName
(name.startsWith("_") && !name.contains("=")) || name.startsWith(".")
}.toArray
}.map(_.getPath.toString)

val lines = sparkSession.baseRelationToDataFrame(
DataSource.apply(
sparkSession,
paths = jsonFiles,
className = classOf[TextFileFormat].getName
).resolveRelation(checkFilesExist = false))
.select("value").as[String](Encoders.STRING)

val jsonSchema = InferSchema.infer(
createBaseRdd(sparkSession, jsonFiles),
lines,
columnNameOfCorruptRecord,
parsedOptions)
checkConstraints(jsonSchema)
@@ -119,25 +124,6 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
}
}

private def createBaseRdd(
sparkSession: SparkSession,
inputPaths: Seq[FileStatus]): RDD[String] = {
val job = Job.getInstance(sparkSession.sessionState.newHadoopConf())
val conf = job.getConfiguration

val paths = inputPaths.map(_.getPath)

if (paths.nonEmpty) {
FileInputFormat.setInputPaths(job, paths: _*)
}

sparkSession.sparkContext.hadoopRDD(
conf.asInstanceOf[JobConf],
classOf[TextInputFormat],
classOf[LongWritable],
classOf[Text]).map(_._2.toString) // get the text line
}

/** Constraints to be imposed on schema to be stored. */
private def checkConstraints(schema: StructType): Unit = {
if (schema.fieldNames.length != schema.fieldNames.distinct.length) {
TestJsonData.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.execution.datasources.json

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.{Dataset, Encoders, SparkSession}

private[json] trait TestJsonData {
protected def spark: SparkSession
@@ -196,14 +196,14 @@ private[json] trait TestJsonData {
"""42""" ::
""" ","ian":"test"}""" :: Nil)

def emptyRecords: RDD[String] =
spark.sparkContext.parallelize(
def emptyRecords: Dataset[String] =
spark.createDataset(
"""{""" ::
"""""" ::
"""{"a": {}}""" ::
"""{"a": {"b": {}}}""" ::
"""{"b": [{"c": {}}]}""" ::
"""]""" :: Nil)
"""]""" :: Nil)(Encoders.STRING)

def timestampAsLong: RDD[String] =
spark.sparkContext.parallelize(
@@ -230,5 +230,5 @@

lazy val singleRow: RDD[String] = spark.sparkContext.parallelize("""{"a":123}""" :: Nil)

def empty: RDD[String] = spark.sparkContext.parallelize(Seq[String]())
def empty: Dataset[String] = spark.createDataset(Seq[String]())(Encoders.STRING)
}