diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 7b67985f2b320..b57bb97e4e84a 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1336,8 +1336,8 @@ def replace(self, to_replace, value=None, subset=None): Value to be replaced. If the value is a dict, then `value` is ignored and `to_replace` must be a mapping between a value and a replacement. - :param value: int, long, float, string, or list. - The replacement value must be an int, long, float, or string. If `value` is a + :param value: int, long, float, string, list or None. + The replacement value must be an int, long, float, string or None. If `value` is a list, `value` should be of the same length and type as `to_replace`. If `value` is a scalar and `to_replace` is a sequence, then `value` is used as a replacement for each item in `to_replace`. @@ -1356,6 +1356,16 @@ def replace(self, to_replace, value=None, subset=None): |null| null| null| +----+------+-----+ + >>> df4.na.replace('Alice', None).show() + +----+------+----+ + | age|height|name| + +----+------+----+ + | 10| 80|null| + | 5| null| Bob| + |null| null| Tom| + |null| null|null| + +----+------+----+ + >>> df4.na.replace(['Alice', 'Bob'], ['A', 'B'], 'name').show() +----+------+----+ | age|height|name| @@ -1391,9 +1401,10 @@ def all_of_(xs): "to_replace should be a float, int, long, string, list, tuple, or dict. " "Got {0}".format(type(to_replace))) - if not isinstance(value, valid_types) and not isinstance(to_replace, dict): + if not isinstance(value, valid_types) and value is not None \ + and not isinstance(to_replace, dict): raise ValueError("If to_replace is not a dict, value should be " - "a float, int, long, string, list, or tuple. " + "a float, int, long, string, list, tuple or None. 
" "Got {0}".format(type(value))) if isinstance(to_replace, (list, tuple)) and isinstance(value, (list, tuple)): @@ -1409,7 +1420,7 @@ def all_of_(xs): if isinstance(to_replace, (float, int, long, basestring)): to_replace = [to_replace] - if isinstance(value, (float, int, long, basestring)): + if isinstance(value, (float, int, long, basestring)) or value is None: value = [value for _ in range(len(to_replace))] if isinstance(to_replace, dict): @@ -1423,7 +1434,9 @@ def all_of_(xs): subset = [subset] # Verify we were not passed in mixed type generics." - if not any(all_of_type(rep_dict.keys()) and all_of_type(rep_dict.values()) + if not any(all_of_type(rep_dict.keys()) + and (all_of_type(rep_dict.values()) + or list(rep_dict.values()).count(None) == len(rep_dict)) for all_of_type in [all_of_bool, all_of_str, all_of_numeric]): raise ValueError("Mixed type replacements are not supported") diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index acea9113ee858..509463837a7a1 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1851,6 +1851,11 @@ def test_replace(self): .replace(False, True).first()) self.assertTupleEqual(row, (True, True)) + # replace with None + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.0)], schema).replace((10, 80), None).first() + self.assertTupleEqual(row, (u'Alice', None, None)) + # should fail if subset is not list, tuple or None with self.assertRaises(ValueError): self.spark.createDataFrame( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index 052d85ad33bd6..8e2b01417a6ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -319,8 +319,10 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { /** * (Scala-specific) Replaces values matching keys in 
`replacement` map. - * Key and value of `replacement` map must have the same type, and - * can only be doubles , strings or booleans. + * Key and value of `replacement` map must satisfy one of: + * 1. keys are String, values are a mix of String and null + * 2. keys are Boolean, values are a mix of Boolean and null + * 3. keys are Double, values are either all Double or all null * * {{{ * // Replaces all occurrences of 1.0 with 2.0 in column "height" and "weight". @@ -342,8 +344,10 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { return df } - // replacementMap is either Map[String, String] or Map[Double, Double] or Map[Boolean,Boolean] + // replacementMap is either Map[String, String], Map[Double, Double], Map[Boolean, Boolean] + // or one of these with null values val replacementMap: Map[_, _] = replacement.head._2 match { + case null => replacement case v: String => replacement case v: Boolean => replacement case _ => replacement.map { case (k, v) => (convertToDouble(k), convertToDouble(v)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index aa237d0619ac3..f1f498b2b9e6d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -222,16 +222,16 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext { assert(out(4) === Row("Amy", null, null)) assert(out(5) === Row(null, null, null)) - // Replace only the age column - val out1 = input.na.replace("age", Map( - 16 -> 61, - 60 -> 6, - 164.3 -> 461.3 // Alice is really tall + // Replace only the name column + val out1 = input.na.replace("name", Map( - "Bob" -> "Bravo", - "Alice" -> "Jessie", - "David" -> null )).collect() - assert(out1(0) === Row("Bob", 61, 176.5)) - assert(out1(1) === Row("Alice", null, 164.3)) - assert(out1(2) === Row("David", 6, null)) + 
assert(out1(0) === Row("Bravo", 16, 176.5)) + assert(out1(1) === Row("Jessie", null, 164.3)) + assert(out1(2) === Row(null, 60, null)) assert(out1(3).get(2).asInstanceOf[Double].isNaN) assert(out1(4) === Row("Amy", null, null)) assert(out1(5) === Row(null, null, null))