Changes from 10 commits
37 changes: 13 additions & 24 deletions src/main/scala/com/databricks/spark/csv/CsvParser.scala
@@ -15,7 +15,6 @@
*/
package com.databricks.spark.csv


import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.types.StructType
@@ -117,12 +116,12 @@ class CsvParser extends Serializable {
this
}

/** Returns a Schema RDD for the given CSV path. */
@throws[RuntimeException]
def csvFile(sqlContext: SQLContext, path: String): DataFrame = {
val relation: CsvRelation = CsvRelation(
() => TextFile.withCharset(sqlContext.sparkContext, path, charset),
Some(path),
/** Returns a CsvRelation instance based on the current state of this CSV parser. */
private[csv] def csvRelation(sqlContext: SQLContext, csvRDD: RDD[String],
Member:
This is just my personal thought. Could we maybe do this refactoring in a separate PR?

Contributor Author:
Okay. Would make more sense.
path: Option[String]): CsvRelation = {
CsvRelation(
() => csvRDD,
path,
useHeader,
delimiter,
quote,
@@ -137,27 +136,17 @@
inferSchema,
codec,
nullValue)(sqlContext)
}
/** Returns a Schema RDD for the given CSV path. */
@throws[RuntimeException]
def csvFile(sqlContext: SQLContext, path: String): DataFrame = {
val relation: CsvRelation = csvRelation(sqlContext,
TextFile.withCharset(sqlContext.sparkContext, path, charset), Some(path))
sqlContext.baseRelationToDataFrame(relation)
}

def csvRdd(sqlContext: SQLContext, csvRDD: RDD[String]): DataFrame = {
val relation: CsvRelation = CsvRelation(
() => csvRDD,
None,
useHeader,
delimiter,
quote,
escape,
comment,
parseMode,
parserLib,
ignoreLeadingWhiteSpace,
ignoreTrailingWhiteSpace,
treatEmptyValuesAsNulls,
schema,
inferSchema,
codec,
nullValue)(sqlContext)
val relation: CsvRelation = csvRelation(sqlContext, csvRDD, None)
sqlContext.baseRelationToDataFrame(relation)
}
}
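
For context, both public entry points now funnel through the shared csvRelation helper. A minimal usage sketch (sqlContext is assumed to be an existing SQLContext; the path is illustrative):

import com.databricks.spark.csv.CsvParser

val parser = new CsvParser().withUseHeader(true).withInferSchema(true)

// Read from a path; the relation is built via the shared csvRelation helper.
val df = parser.csvFile(sqlContext, "src/test/resources/simple.csv")

// Or hand in an RDD[String] directly; same relation, just with no path attached.
val rdd = sqlContext.sparkContext.textFile("src/test/resources/simple.csv")
val df2 = parser.csvRdd(sqlContext, rdd)
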
15 changes: 8 additions & 7 deletions src/main/scala/com/databricks/spark/csv/util/InferSchema.scala
@@ -42,7 +42,11 @@ private[csv] object InferSchema {
mergeRowTypes)

val structFields = header.zip(rootTypes).map { case (thisHeader, rootType) =>
StructField(thisHeader, rootType, nullable = true)
val dType = rootType match {
case z: NullType => StringType
case other => other
}
StructField(thisHeader, dType, nullable = true)
}

StructType(structFields)
@@ -62,11 +66,7 @@
first: Array[DataType],
second: Array[DataType]): Array[DataType] = {
first.zipAll(second, NullType, NullType).map { case ((a, b)) =>
val tpe = findTightestCommonType(a, b).getOrElse(StringType)
tpe match {
case _: NullType => StringType
case other => other
}
findTightestCommonType(a, b).getOrElse(NullType)
}
}

@@ -93,7 +93,6 @@
}
}


private def tryParseInteger(field: String): DataType = if ((allCatch opt field.toInt).isDefined) {
IntegerType
} else {
@@ -152,6 +151,8 @@
case (t1, t2) if t1 == t2 => Some(t1)
case (NullType, t1) => Some(t1)
case (t1, NullType) => Some(t1)
case (StringType, t2) => Some(StringType)
case (t1, StringType) => Some(StringType)

// Promote numeric types to the highest of the two and all numeric types to unlimited decimal
case (t1, t2) if Seq(t1, t2).forall(numericPrecedence.contains) =>
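
Taken together, the two hunks above move the NullType-to-StringType fallback out of per-row merging and into final schema construction, while findTightestCommonType gains explicit StringType cases. Below is a self-contained sketch of the resulting merge semantics; the method names mirror the diff, but numericPrecedence is an assumed ordering modeled on Spark's numeric widening, not copied from this file.

import org.apache.spark.sql.types._

object InferSchemaSketch {
  // Assumed widening order; the real object may define this differently.
  private val numericPrecedence: IndexedSeq[DataType] =
    IndexedSeq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType)

  def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] =
    (t1, t2) match {
      case (a, b) if a == b => Some(a)
      case (NullType, t) => Some(t) // NullType survives merging instead of forcing StringType
      case (t, NullType) => Some(t)
      case (StringType, _) => Some(StringType) // anything merged with String stays String
      case (_, StringType) => Some(StringType)
      case (a, b) if Seq(a, b).forall(numericPrecedence.contains) =>
        Some(numericPrecedence(Seq(a, b).map(t => numericPrecedence.indexOf(t)).max))
      case _ => None
    }

  def mergeRowTypes(first: Array[DataType], second: Array[DataType]): Array[DataType] =
    first.zipAll(second, NullType, NullType).map { case (a, b) =>
      findTightestCommonType(a, b).getOrElse(NullType)
    }

  def main(args: Array[String]): Unit = {
    // Each column is IntegerType in one row and NullType (empty cell) in the other.
    val merged = mergeRowTypes(Array(IntegerType, NullType), Array(NullType, IntegerType))
    println(merged.mkString(", ")) // IntegerType, IntegerType
  }
}

Only a column that is still NullType after all rows are merged falls back to StringType, in the StructField construction shown in the first hunk.
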
5 changes: 5 additions & 0 deletions src/test/resources/simple.csv
@@ -0,0 +1,5 @@
A,B,C,D
1,,,
,1,,
,,1,
,,,1
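
Each column of this fixture holds a single integer and three empty cells. Under the old merge rules, two empty cells (NullType) merged to StringType, which then pulled the whole column to StringType; with NullType preserved, inference should now yield the schema sketched below (mirroring the assertion added to InferSchemaSuite; nullable = true matches the StructField construction in InferSchema):

import org.apache.spark.sql.types._

val expected = StructType(Seq(
  StructField("A", IntegerType, nullable = true),
  StructField("B", IntegerType, nullable = true),
  StructField("C", IntegerType, nullable = true),
  StructField("D", IntegerType, nullable = true)))
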
src/test/scala/com/databricks/spark/csv/util/InferSchemaSuite.scala
@@ -2,8 +2,31 @@ package com.databricks.spark.csv.util

import org.apache.spark.sql.types._
import org.scalatest.FunSuite
import org.scalatest.BeforeAndAfterAll
import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext
import com.databricks.spark.csv.CsvParser
import com.databricks.spark.csv.CsvRelation

class InferSchemaSuite extends FunSuite {
class InferSchemaSuite extends FunSuite with BeforeAndAfterAll {

private val simpleDatasetFile = "src/test/resources/simple.csv"
private val utf8Charset = "utf-8"
private var sqlContext: SQLContext = _

override def beforeAll(): Unit =
Member:
These can be removed once you move your end-to-end test to CsvSuite.scala. This suite is a unit test and does not need to create a SQLContext.
{
super.beforeAll()
sqlContext = new SQLContext(new SparkContext("local[2]", "InferSchemaSuite"))
}

override def afterAll(): Unit = {
try {
sqlContext.sparkContext.stop()
} finally {
super.afterAll()
}
}

test("String fields types are inferred correctly from null types") {
assert(InferSchema.inferField(NullType, "") == NullType)
Expand Down Expand Up @@ -40,6 +63,14 @@ class InferSchemaSuite extends FunSuite {
assert(InferSchema.inferField(LongType, "2015-08 14:49:00") == StringType)
}

test("Merging Nulltypes should yeild Nulltype.")
{
Member:
Ditto.
assert(
Member:
Nit: the indent is off:

assert(
  InferSchema.mergeRowTypes(Array(NullType),
  Array(NullType)).deep == Array(NullType).deep)
InferSchema.mergeRowTypes(Array(NullType),
Array(NullType)).deep == Array(NullType).deep)

}

test("Type arrays are merged to highest common type") {
assert(
InferSchema.mergeRowTypes(Array(StringType),
Expand All @@ -52,4 +83,14 @@ class InferSchemaSuite extends FunSuite {
Array(LongType)).deep == Array(DoubleType).deep)
}

test("Type/Schema inference works as expected for the simple parse dataset.")
Member:
Hm... shouldn't this go to CsvSuite, and remove BeforeAndAfterAll, beforeAll() and afterAll(), since this is an end-to-end test?

Contributor Author:
I agree that CsvSuite performs all end-to-end tests, but since there was a dedicated suite for schema inference, I preferred to put the schema tests there. Do you see any issues with that?

Member:
No, I think it is okay. I just said this because I have seen suites added in this way. For example, #235 and #224.

Contributor Author:
Reverted the refactoring part; the tests I kept as is. If more people feel CsvSuite would be the right place for these tests, I will make that change.
{
val df = new CsvParser().withUseHeader(true).withInferSchema(true)
.csvFile(sqlContext, simpleDatasetFile)
assert(
df.schema.fields.map{field => field.dataType}.deep ==
Member:
Nit: use two spaces of indent for this line. The next line's indent is fine.

Array(IntegerType, IntegerType, IntegerType, IntegerType).deep
)

}
}