Skip to content
Closed
Original file line number Diff line number Diff line change
Expand Up @@ -456,11 +456,23 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
val keyExpr = df.col(col.name).expr
def buildExpr(v: Any) = Cast(Literal(v), keyExpr.dataType)
val branches = replacementMap.flatMap { case (source, target) =>
Seq(buildExpr(source), buildExpr(target))
if (isNaN(source) || isNaN(target)) {
col.dataType match {
case IntegerType | LongType | ShortType | ByteType => Seq.empty
case _ => Seq(buildExpr(source), buildExpr(target))
}
} else {
Seq(buildExpr(source), buildExpr(target))
}
}.toSeq
new Column(CaseKeyWhen(keyExpr, branches :+ keyExpr)).as(col.name)
}

/**
 * Returns true when `value` is a floating-point NaN.
 *
 * Only `Double` and `Float` values can be NaN; any other value, including `null`,
 * yields false.
 *
 * @param value the value to inspect
 * @return true if `value` is a `Double` or `Float` NaN, false otherwise
 */
private[this] def isNaN(value: Any): Boolean = value match {
  case d: Double => d.isNaN
  case f: Float => f.isNaN
  // Non-floating-point values (and null) can never be NaN.
  case _ => false
}

private def convertToDouble(v: Any): Double = v match {
case v: Float => v.toDouble
case v: Double => v
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,16 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession {
).toDF("name", "age", "height")
}

/**
 * Builds a two-row DataFrame with one column per numeric type under test:
 * "int", "long", "short", "byte", "float" and "double".
 *
 * The first row holds 1 in every column. The second row holds 0 in the integral
 * columns and NaN in the "float" and "double" columns.
 */
def createNaNDF(): DataFrame = {
  // valueOf is used instead of the boxed-primitive constructors, which are
  // deprecated since Java 9.
  Seq[(java.lang.Integer, java.lang.Long, java.lang.Short,
    java.lang.Byte, java.lang.Float, java.lang.Double)](
    (1, java.lang.Long.valueOf(1L), java.lang.Short.valueOf("1"),
      java.lang.Byte.valueOf("1"), java.lang.Float.valueOf(1.0f), 1.0),
    (0, java.lang.Long.valueOf(0L), java.lang.Short.valueOf("0"),
      java.lang.Byte.valueOf("0"), java.lang.Float.NaN, java.lang.Double.NaN)
  ).toDF("int", "long", "short", "byte", "float", "double")
}

test("drop") {
val input = createDF()
val rows = input.collect()
Expand Down Expand Up @@ -404,4 +414,48 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession {
df.na.drop("any"),
Row("5", "6", "6") :: Nil)
}

test("replace nan with float") {
  // Replacing Float NaN with 10 across all columns: the NaN entries in the "float" and
  // "double" columns of the second row are expected to become 10, while every other value
  // is expected to come back unchanged.
  // Scala literals are boxed when passed to Row, so the deprecated java.lang
  // boxed-primitive constructors are not needed.
  checkAnswer(
    createNaNDF().na.replace("*", Map(
      Float.NaN -> 10f
    )),
    Row(1, 1L, 1.toShort, 1.toByte, 1.0f, 1.0) ::
      Row(0, 0L, 0.toShort, 0.toByte, 10.0f, 10.0) :: Nil)
}

test("replace nan with double") {
  // Replacing Double NaN with 10 across all columns: the NaN entries in the "float" and
  // "double" columns of the second row are expected to become 10, while every other value
  // is expected to come back unchanged.
  // Scala literals are boxed when passed to Row, so the deprecated java.lang
  // boxed-primitive constructors are not needed.
  checkAnswer(
    createNaNDF().na.replace("*", Map(
      Double.NaN -> 10.0
    )),
    Row(1, 1L, 1.toShort, 1.toByte, 1.0f, 1.0) ::
      Row(0, 0L, 0.toShort, 0.toByte, 10.0f, 10.0) :: Nil)
}

test("replace float with nan") {
  // Replacing 1.0f with Float NaN across all columns: the 1.0 entries in the "float" and
  // "double" columns of the first row are expected to become NaN, while the integral
  // columns are expected to keep their values (NaN cannot be stored in them).
  // Scala literals are boxed when passed to Row, so the deprecated java.lang
  // boxed-primitive constructors are not needed.
  checkAnswer(
    createNaNDF().na.replace("*", Map(
      1.0f -> Float.NaN
    )),
    Row(1, 1L, 1.toShort, 1.toByte, Float.NaN, Double.NaN) ::
      Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) :: Nil)
}

test("replace double with nan") {
  // Replacing 1.0 with Double NaN across all columns: the 1.0 entries in the "float" and
  // "double" columns of the first row are expected to become NaN, while the integral
  // columns are expected to keep their values (NaN cannot be stored in them).
  // Scala literals are boxed when passed to Row, so the deprecated java.lang
  // boxed-primitive constructors are not needed.
  checkAnswer(
    createNaNDF().na.replace("*", Map(
      1.0 -> Double.NaN
    )),
    Row(1, 1L, 1.toShort, 1.toByte, Float.NaN, Double.NaN) ::
      Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) :: Nil)
}
}