Skip to content

Commit 4f1e8b9

Browse files
huaxingao authored and jkbradley committed
[SPARK-23871][ML][PYTHON] add python api for VectorAssembler handleInvalid
## What changes were proposed in this pull request? add python api for VectorAssembler handleInvalid ## How was this patch tested? Add doctest Author: Huaxin Gao <huaxing@us.ibm.com> Closes #21003 from huaxingao/spark-23871.
1 parent adb222b commit 4f1e8b9

2 files changed

Lines changed: 43 additions & 11 deletions

File tree

mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala

Lines changed: 6 additions & 6 deletions
Original file line number · Diff line number · Diff line change
@@ -71,12 +71,12 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
7171
*/
7272
@Since("2.4.0")
7373
override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
74-
"""Param for how to handle invalid data (NULL values). Options are 'skip' (filter out rows with
75-
|invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN in the
76-
|output). Column lengths are taken from the size of ML Attribute Group, which can be set using
77-
|`VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also be inferred
78-
|from first rows of the data since it is safe to do so but only in case of 'error' or 'skip'.
79-
|""".stripMargin.replaceAll("\n", " "),
74+
"""Param for how to handle invalid data (NULL and NaN values). Options are 'skip' (filter out
75+
|rows with invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN
76+
|in the output). Column lengths are taken from the size of ML Attribute Group, which can be
77+
|set using `VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also
78+
|be inferred from first rows of the data since it is safe to do so but only in case of 'error'
79+
|or 'skip'.""".stripMargin.replaceAll("\n", " "),
8080
ParamValidators.inArray(VectorAssembler.supportedHandleInvalids))
8181

8282
setDefault(handleInvalid, VectorAssembler.ERROR_INVALID)

python/pyspark/ml/feature.py

Lines changed: 37 additions & 5 deletions
Original file line number · Diff line number · Diff line change
@@ -2701,7 +2701,8 @@ def setParams(self, inputCol=None, outputCol=None):
27012701

27022702

27032703
@inherit_doc
2704-
class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, JavaMLReadable, JavaMLWritable):
2704+
class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, HasHandleInvalid, JavaMLReadable,
2705+
JavaMLWritable):
27052706
"""
27062707
A feature transformer that merges multiple columns into a vector column.
27072708
@@ -2719,25 +2720,56 @@ class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, JavaMLReadabl
27192720
>>> loadedAssembler = VectorAssembler.load(vectorAssemblerPath)
27202721
>>> loadedAssembler.transform(df).head().freqs == vecAssembler.transform(df).head().freqs
27212722
True
2723+
>>> dfWithNullsAndNaNs = spark.createDataFrame(
2724+
... [(1.0, 2.0, None), (3.0, float("nan"), 4.0), (5.0, 6.0, 7.0)], ["a", "b", "c"])
2725+
>>> vecAssembler2 = VectorAssembler(inputCols=["a", "b", "c"], outputCol="features",
2726+
... handleInvalid="keep")
2727+
>>> vecAssembler2.transform(dfWithNullsAndNaNs).show()
2728+
+---+---+----+-------------+
2729+
| a| b| c| features|
2730+
+---+---+----+-------------+
2731+
|1.0|2.0|null|[1.0,2.0,NaN]|
2732+
|3.0|NaN| 4.0|[3.0,NaN,4.0]|
2733+
|5.0|6.0| 7.0|[5.0,6.0,7.0]|
2734+
+---+---+----+-------------+
2735+
...
2736+
>>> vecAssembler2.setParams(handleInvalid="skip").transform(dfWithNullsAndNaNs).show()
2737+
+---+---+---+-------------+
2738+
| a| b| c| features|
2739+
+---+---+---+-------------+
2740+
|5.0|6.0|7.0|[5.0,6.0,7.0]|
2741+
+---+---+---+-------------+
2742+
...
27222743
27232744
.. versionadded:: 1.4.0
27242745
"""
27252746

2747+
handleInvalid = Param(Params._dummy(), "handleInvalid", "How to handle invalid data (NULL " +
2748+
"and NaN values). Options are 'skip' (filter out rows with invalid " +
2749+
"data), 'error' (throw an error), or 'keep' (return relevant number " +
2750+
"of NaN in the output). Column lengths are taken from the size of ML " +
2751+
"Attribute Group, which can be set using `VectorSizeHint` in a " +
2752+
"pipeline before `VectorAssembler`. Column lengths can also be " +
2753+
"inferred from first rows of the data since it is safe to do so but " +
2754+
"only in case of 'error' or 'skip').",
2755+
typeConverter=TypeConverters.toString)
2756+
27262757
@keyword_only
2727-
def __init__(self, inputCols=None, outputCol=None):
2758+
def __init__(self, inputCols=None, outputCol=None, handleInvalid="error"):
27282759
"""
2729-
__init__(self, inputCols=None, outputCol=None)
2760+
__init__(self, inputCols=None, outputCol=None, handleInvalid="error")
27302761
"""
27312762
super(VectorAssembler, self).__init__()
27322763
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorAssembler", self.uid)
2764+
self._setDefault(handleInvalid="error")
27332765
kwargs = self._input_kwargs
27342766
self.setParams(**kwargs)
27352767

27362768
@keyword_only
27372769
@since("1.4.0")
2738-
def setParams(self, inputCols=None, outputCol=None):
2770+
def setParams(self, inputCols=None, outputCol=None, handleInvalid="error"):
27392771
"""
2740-
setParams(self, inputCols=None, outputCol=None)
2772+
setParams(self, inputCols=None, outputCol=None, handleInvalid="error")
27412773
Sets params for this VectorAssembler.
27422774
"""
27432775
kwargs = self._input_kwargs

0 commit comments

Comments (0)