Merged
Changes from 3 commits
3 changes: 2 additions & 1 deletion src/main/scala/com/databricks/spark/csv/CsvRelation.scala
@@ -108,7 +108,8 @@ case class CsvRelation protected[spark] (
    try {
      index = 0
      while (index < schemaFields.length) {
-        rowArray(index) = TypeCast.castTo(tokens(index), schemaFields(index).dataType)
+        val field = schemaFields(index)
+        rowArray(index) = TypeCast.castTo(tokens(index), field.dataType, field.nullable)
        index = index + 1
      }
      Some(Row.fromSeq(rowArray))
40 changes: 24 additions & 16 deletions src/main/scala/com/databricks/spark/csv/util/TypeCast.scala
@@ -29,25 +29,33 @@ object TypeCast {
   * Casts given string datum to specified type.
   * Currently we do not support complex types (ArrayType, MapType, StructType).
   *
   * For string types, this is simply the datum.
+  * For other nullable types, this is null if the string datum is empty.
   *
   * @param datum string value
   * @param castType SparkSQL type
   */
-  private[csv] def castTo(datum: String, castType: DataType): Any = {
-    castType match {
-      case _: ByteType => datum.toByte
-      case _: ShortType => datum.toShort
-      case _: IntegerType => datum.toInt
-      case _: LongType => datum.toLong
-      case _: FloatType => datum.toFloat
-      case _: DoubleType => datum.toDouble
-      case _: BooleanType => datum.toBoolean
-      case _: DecimalType => new BigDecimal(datum.replaceAll(",", ""))
-      // TODO(hossein): would be good to support other common timestamp formats
-      case _: TimestampType => Timestamp.valueOf(datum)
-      // TODO(hossein): would be good to support other common date formats
-      case _: DateType => Date.valueOf(datum)
-      case _: StringType => datum
-      case _ => throw new RuntimeException(s"Unsupported type: ${castType.typeName}")
+  private[csv] def castTo(datum: String, castType: DataType, nullable: Boolean = true): Any = {
+    if (castType.isInstanceOf[StringType]){
Member:
I think this would be simpler:

if (datum == "" && nullable && !castType.isInstanceOf[StringType]) {
  null
} else {
  castType match {
      case _: ByteType => datum.toByte
      case _: ShortType => datum.toShort
      case _: IntegerType => datum.toInt
      case _: LongType => datum.toLong
      case _: FloatType => datum.toFloat
      case _: DoubleType => datum.toDouble
      case _: BooleanType => datum.toBoolean
      case _: DecimalType => new BigDecimal(datum.replaceAll(",", ""))
      // TODO(hossein): would be good to support other common timestamp formats
      case _: TimestampType => Timestamp.valueOf(datum)
      // TODO(hossein): would be good to support other common date formats
      case _: DateType => Date.valueOf(datum)
      case _: StringType => datum
      case _ => throw new RuntimeException(s"Unsupported type: ${castType.typeName}")
  }
}

+      datum
+    } else if (nullable && datum == ""){
+      null
+    } else {
+      castType match {
+        case _: ByteType => datum.toByte
+        case _: ShortType => datum.toShort
+        case _: IntegerType => datum.toInt
+        case _: LongType => datum.toLong
+        case _: FloatType => datum.toFloat
+        case _: DoubleType => datum.toDouble
+        case _: BooleanType => datum.toBoolean
+        case _: DecimalType => new BigDecimal(datum.replaceAll(",", ""))
+        // TODO(hossein): would be good to support other common timestamp formats
+        case _: TimestampType => Timestamp.valueOf(datum)
+        // TODO(hossein): would be good to support other common date formats
+        case _: DateType => Date.valueOf(datum)
+        case _ => throw new RuntimeException(s"Unsupported type: ${castType.typeName}")
+      }
+    }
+  }
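Taken together, the new flag changes exactly one behavior: an empty string in a nullable, non-string column now becomes null; string columns and non-empty values cast as before. A minimal sketch of the resulting behavior (hypothetical calls; castTo is private[csv], so in practice they would run from inside the com.databricks.spark.csv package):

import org.apache.spark.sql.types._

TypeCast.castTo("", IntegerType)                    // null, since nullable defaults to true
TypeCast.castTo("", IntegerType, nullable = false)  // throws NumberFormatException from "".toInt
TypeCast.castTo("", StringType)                     // "", string columns keep the empty string
TypeCast.castTo("42", IntegerType)                  // 42, non-empty values cast as before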

4 changes: 4 additions & 0 deletions src/test/resources/null-numbers.csv
@@ -0,0 +1,4 @@
+name,age
+alice,35
+bob,
+,24
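Under the schema the new tests use below (name as a non-nullable StringType, age as a nullable IntegerType), this fixture should parse into the following rows; a sketch matching the asserts in the tests that follow:

Row("alice", 35)  // both fields present
Row("bob", null)  // empty age becomes null: IntegerType and nullable
Row("", 24)       // empty name stays "": StringType returns the datum verbatim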
16 changes: 16 additions & 0 deletions src/test/scala/com/databricks/spark/csv/CsvFastSuite.scala
@@ -32,6 +32,7 @@ class CsvFastSuite extends FunSuite {
  val carsFile8859 = "src/test/resources/cars_iso-8859-1.csv"
  val carsTsvFile = "src/test/resources/cars.tsv"
  val carsAltFile = "src/test/resources/cars-alternative.csv"
+  val nullNumbersFile = "src/test/resources/null-numbers.csv"
  val emptyFile = "src/test/resources/empty.csv"
  val escapeFile = "src/test/resources/escape.csv"
  val tempEmptyDir = "target/test/empty2/"
@@ -387,4 +388,19 @@
    assert(results.first().getInt(0) === 1997)

  }
+
+  test("DSL test nullable fields"){
+
+    val results = new CsvParser()
+      .withSchema(StructType(List(StructField("name", StringType, false), StructField("age", IntegerType, true))))
+      .withUseHeader(true)
+      .withParserLib("univocity")
+      .csvFile(TestSQLContext, nullNumbersFile)
+      .collect()
+
+    assert(results.head.toSeq == Seq("alice", 35))
+    assert(results(1).toSeq == Seq("bob", null))
+    assert(results(2).toSeq == Seq("", 24))
+
+  }
}
15 changes: 15 additions & 0 deletions src/test/scala/com/databricks/spark/csv/CsvSuite.scala
@@ -32,6 +32,7 @@ class CsvSuite extends FunSuite {
  val carsFile8859 = "src/test/resources/cars_iso-8859-1.csv"
  val carsTsvFile = "src/test/resources/cars.tsv"
  val carsAltFile = "src/test/resources/cars-alternative.csv"
+  val nullNumbersFile = "src/test/resources/null-numbers.csv"
  val emptyFile = "src/test/resources/empty.csv"
  val escapeFile = "src/test/resources/escape.csv"
  val tempEmptyDir = "target/test/empty/"
@@ -392,4 +393,18 @@
    assert(results.first().getInt(0) === 1997)

  }
+
+  test("DSL test nullable fields"){
+
+    val results = new CsvParser()
+      .withSchema(StructType(List(StructField("name", StringType, false), StructField("age", IntegerType, true))))
+      .withUseHeader(true)
+      .csvFile(TestSQLContext, nullNumbersFile)
+      .collect()
+
+    assert(results.head.toSeq == Seq("alice", 35))
+    assert(results(1).toSeq == Seq("bob", null))
+    assert(results(2).toSeq == Seq("", 24))
+
+  }
}
32 changes: 31 additions & 1 deletion src/test/scala/com/databricks/spark/csv/util/TypeCastSuite.scala
@@ -16,10 +16,11 @@
package com.databricks.spark.csv.util

import java.math.BigDecimal
+import java.sql.{Date, Timestamp}

import org.scalatest.FunSuite

-import org.apache.spark.sql.types.DecimalType
+import org.apache.spark.sql.types._

class TypeCastSuite extends FunSuite {

@@ -56,4 +57,33 @@
    }
    assert(exception.getMessage.contains("Unsupported special character for delimiter"))
  }
+
+  test("Nullable types are handled"){
+    assert(TypeCast.castTo("", IntegerType, nullable = true) == null)
+  }
+
+  test("String type should always return the same as the input"){
+    assert(TypeCast.castTo("", StringType, nullable = true) == "")
+    assert(TypeCast.castTo("", StringType, nullable = false) == "")
+  }
+
+  test("Throws exception for empty string with non null type"){
+    val exception = intercept[NumberFormatException]{
+      TypeCast.castTo("", IntegerType, nullable = false)
+    }
+    assert(exception.getMessage.contains("For input string: \"\""))
+  }
+
+  test("Types are cast correctly"){
Member:
Thanks for adding these!

assert(TypeCast.castTo("10", ByteType) == 10)
assert(TypeCast.castTo("10", ShortType) == 10)
assert(TypeCast.castTo("10", IntegerType) == 10)
assert(TypeCast.castTo("10", LongType) == 10)
assert(TypeCast.castTo("1.00", FloatType) == 1.0)
assert(TypeCast.castTo("1.00", DoubleType) == 1.0)
assert(TypeCast.castTo("true", BooleanType) == true)
val timestamp = "2015-01-01 00:00:00"
assert(TypeCast.castTo(timestamp, TimestampType) == Timestamp.valueOf(timestamp))
assert(TypeCast.castTo("2015-01-01", DateType) == Date.valueOf("2015-01-01"))
}
}
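For reference, the DecimalType branch in castTo strips thousands separators before parsing (the replaceAll(",", "") above). A possible extra assertion covering that path, assuming a Spark 1.3-era DecimalType.Unlimited alongside the java.math.BigDecimal import already present in this suite:

assert(TypeCast.castTo("1,000.01", DecimalType.Unlimited) == new BigDecimal("1000.01"))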