@@ -2701,7 +2701,8 @@ def setParams(self, inputCol=None, outputCol=None):
27012701
27022702
27032703@inherit_doc
2704- class VectorAssembler (JavaTransformer , HasInputCols , HasOutputCol , JavaMLReadable , JavaMLWritable ):
2704+ class VectorAssembler (JavaTransformer , HasInputCols , HasOutputCol , HasHandleInvalid , JavaMLReadable ,
2705+ JavaMLWritable ):
27052706 """
27062707 A feature transformer that merges multiple columns into a vector column.
27072708
@@ -2719,25 +2720,56 @@ class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, JavaMLReadabl
27192720 >>> loadedAssembler = VectorAssembler.load(vectorAssemblerPath)
27202721 >>> loadedAssembler.transform(df).head().freqs == vecAssembler.transform(df).head().freqs
27212722 True
2723+ >>> dfWithNullsAndNaNs = spark.createDataFrame(
2724+ ... [(1.0, 2.0, None), (3.0, float("nan"), 4.0), (5.0, 6.0, 7.0)], ["a", "b", "c"])
2725+ >>> vecAssembler2 = VectorAssembler(inputCols=["a", "b", "c"], outputCol="features",
2726+ ... handleInvalid="keep")
2727+ >>> vecAssembler2.transform(dfWithNullsAndNaNs).show()
2728+ +---+---+----+-------------+
2729+ | a| b| c| features|
2730+ +---+---+----+-------------+
2731+ |1.0|2.0|null|[1.0,2.0,NaN]|
2732+ |3.0|NaN| 4.0|[3.0,NaN,4.0]|
2733+ |5.0|6.0| 7.0|[5.0,6.0,7.0]|
2734+ +---+---+----+-------------+
2735+ ...
2736+ >>> vecAssembler2.setParams(handleInvalid="skip").transform(dfWithNullsAndNaNs).show()
2737+ +---+---+---+-------------+
2738+ | a| b| c| features|
2739+ +---+---+---+-------------+
2740+ |5.0|6.0|7.0|[5.0,6.0,7.0]|
2741+ +---+---+---+-------------+
2742+ ...
27222743
27232744 .. versionadded:: 1.4.0
27242745 """
27252746
    # How to treat rows containing NULL/NaN during transform: 'error' (raise),
    # 'skip' (drop the row), or 'keep' (emit NaN in the output vector).
    # The default ("error") is established via _setDefault in __init__.
    handleInvalid = Param(Params._dummy(), "handleInvalid", "How to handle invalid data (NULL " +
                          "and NaN values). Options are 'skip' (filter out rows with invalid " +
                          "data), 'error' (throw an error), or 'keep' (return relevant number " +
                          "of NaN in the output). Column lengths are taken from the size of ML " +
                          "Attribute Group, which can be set using `VectorSizeHint` in a " +
                          "pipeline before `VectorAssembler`. Column lengths can also be " +
                          "inferred from first rows of the data since it is safe to do so but " +
                          "only in case of 'error' or 'skip').",
                          typeConverter=TypeConverters.toString)
27262757 @keyword_only
2727- def __init__ (self , inputCols = None , outputCol = None ):
2758+ def __init__ (self , inputCols = None , outputCol = None , handleInvalid = "error" ):
27282759 """
2729- __init__(self, inputCols=None, outputCol=None)
2760+ __init__(self, inputCols=None, outputCol=None, handleInvalid="error" )
27302761 """
27312762 super (VectorAssembler , self ).__init__ ()
27322763 self ._java_obj = self ._new_java_obj ("org.apache.spark.ml.feature.VectorAssembler" , self .uid )
2764+ self ._setDefault (handleInvalid = "error" )
27332765 kwargs = self ._input_kwargs
27342766 self .setParams (** kwargs )
27352767
    @keyword_only
    @since("1.4.0")
    def setParams(self, inputCols=None, outputCol=None, handleInvalid="error"):
        """
        setParams(self, inputCols=None, outputCol=None, handleInvalid="error")
        Sets params for this VectorAssembler.
        """
        # NOTE(review): the method body continues past this view; only the
        # capture of the @keyword_only kwargs is visible here.
        kwargs = self._input_kwargs
0 commit comments