-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-13568] [ML] Create feature transformer to impute missing values #11601
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
2999b26
8335cf2
7be5e9b
131f7d5
a72a3ea
78df589
b949be5
79b1c62
c3d5d55
1b39668
7f87ffb
4e45f81
e1dd0d2
12220eb
1b36deb
72d104d
c311b2e
d6b9421
d181b12
e211481
791533b
fdd6f94
8042cfb
4bdf595
e6ad69c
1718422
594c501
3043e7d
b3633e8
053d489
63e7032
4e1c34a
051aec6
aef094b
335ded7
949ed79
93bba63
cca8dd4
eea8947
4e07431
31556e6
d4f92e4
544a65c
910685e
91d4cee
8744524
ca45c33
e86d919
4f17c54
ce59a5b
41d91b9
9f6bd57
e378db5
c67afc1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -37,8 +37,7 @@ private[feature] trait ImputerParams extends Params with HasInputCols { | |
| /** | ||
| * The imputation strategy. | ||
| * If "mean", then replace missing values using the mean value of the feature. | ||
| * If "median", then replace missing values using the approximate median value of the | ||
| * feature (relative error less than 0.001). | ||
| * If "median", then replace missing values using the approximate median value of the feature. | ||
| * Default: mean | ||
| * | ||
| * @group param | ||
|
|
@@ -76,10 +75,10 @@ private[feature] trait ImputerParams extends Params with HasInputCols { | |
|
|
||
| /** Validates and transforms the input schema. */ | ||
| protected def validateAndTransformSchema(schema: StructType): StructType = { | ||
| require($(inputCols).length == $(inputCols).distinct.length, s"inputCols duplicates:" + | ||
| s" (${$(inputCols).mkString(", ")})") | ||
| require($(outputCols).length == $(outputCols).distinct.length, s"outputCols duplicates:" + | ||
| s" (${$(outputCols).mkString(", ")})") | ||
| require($(inputCols).length == $(inputCols).distinct.length, s"inputCols contains" + | ||
| s" duplicates: (${$(inputCols).mkString(", ")})") | ||
| require($(outputCols).length == $(outputCols).distinct.length, s"outputCols contains" + | ||
| s" duplicates: (${$(outputCols).mkString(", ")})") | ||
| require($(inputCols).length == $(outputCols).length, s"inputCols(${$(inputCols).length})" + | ||
| s" and outputCols(${$(outputCols).length}) should have the same length") | ||
| val outputFields = $(inputCols).zip($(outputCols)).map { case (inputCol, outputCol) => | ||
|
|
@@ -99,7 +98,8 @@ private[feature] trait ImputerParams extends Params with HasInputCols { | |
| * (SPARK-15041) and possibly creates incorrect values for a categorical feature. | ||
| * | ||
| * Note that the mean/median value is computed after filtering out missing values. | ||
| * All Null values in the input column are treated as missing, and so are also imputed. | ||
| * All Null values in the input column are treated as missing, and so are also imputed. For | ||
| * computing median, DataFrameStatFunctions.approxQuantile is used with a relative error of 0.001. | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Ah I see it is here - never mind |
||
| */ | ||
| @Experimental | ||
| class Imputer @Since("2.2.0")(override val uid: String) | ||
|
|
@@ -127,8 +127,7 @@ class Imputer @Since("2.2.0")(override val uid: String) | |
| @Since("2.2.0") | ||
| def setMissingValue(value: Double): this.type = set(missingValue, value) | ||
|
|
||
| import org.apache.spark.ml.feature.Imputer._ | ||
| setDefault(strategy -> mean, missingValue -> Double.NaN) | ||
| setDefault(strategy -> Imputer.mean, missingValue -> Double.NaN) | ||
|
|
||
| override def fit(dataset: Dataset[_]): ImputerModel = { | ||
| transformSchema(dataset.schema, logging = true) | ||
|
|
@@ -197,7 +196,7 @@ class ImputerModel private[ml]( | |
| override def transform(dataset: Dataset[_]): DataFrame = { | ||
| transformSchema(dataset.schema, logging = true) | ||
| var outputDF = dataset | ||
| val surrogates = surrogateDF.select($(inputCols).head, $(inputCols).tail: _*).head().toSeq | ||
| val surrogates = surrogateDF.select($(inputCols).map(col): _*).head().toSeq | ||
|
|
||
| $(inputCols).zip($(outputCols)).zip(surrogates).foreach { | ||
| case ((inputCol, outputCol), surrogate) => | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -85,33 +85,51 @@ class ImputerSuite extends SparkFunSuite with MLlibTestSparkContext with Default | |
| Seq("mean", "median").foreach { strategy => | ||
| val imputer = new Imputer().setInputCols(Array("value")).setOutputCols(Array("out")) | ||
| .setStrategy(strategy) | ||
| intercept[SparkException] { | ||
| val model = imputer.fit(df) | ||
| withClue("Imputer should fail all the values are invalid") { | ||
| val e: SparkException = intercept[SparkException] { | ||
| val model = imputer.fit(df) | ||
| } | ||
| assert(e.getMessage.contains("surrogate cannot be computed")) | ||
| } | ||
| } | ||
| } | ||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we should also have a test for a
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. added |
||
| test("Imputer throws exception when inputCols does not match outputCols") { | ||
| test("Imputer input & output column validation") { | ||
| val df = spark.createDataFrame( Seq( | ||
| (0, 1.0, 1.0, 1.0), | ||
| (1, Double.NaN, 3.0, 3.0), | ||
| (2, Double.NaN, Double.NaN, Double.NaN) | ||
| )).toDF("id", "value1", "value2", "value3") | ||
| Seq("mean", "median").foreach { strategy => | ||
| // inputCols and outCols length different | ||
| val imputer = new Imputer() | ||
| .setInputCols(Array("value1", "value2")) | ||
| .setOutputCols(Array("out1")) | ||
| .setStrategy(strategy) | ||
| intercept[IllegalArgumentException] { | ||
| val model = imputer.fit(df) | ||
| withClue("Imputer should fail if inputCols and outputCols are different length") { | ||
| val e: IllegalArgumentException = intercept[IllegalArgumentException] { | ||
| val imputer = new Imputer().setStrategy(strategy) | ||
| .setInputCols(Array("value1", "value2")) | ||
| .setOutputCols(Array("out1")) | ||
| val model = imputer.fit(df) | ||
| } | ||
| assert(e.getMessage.contains("should have the same length")) | ||
| } | ||
| // duplicate name in inputCols | ||
| imputer.setInputCols(Array("value1", "value1")).setOutputCols(Array("out1, out2")) | ||
| intercept[IllegalArgumentException] { | ||
| val model = imputer.fit(df) | ||
|
|
||
| withClue("Imputer should fail if inputCols contains duplicates") { | ||
| val e: IllegalArgumentException = intercept[IllegalArgumentException] { | ||
| val imputer = new Imputer().setStrategy(strategy) | ||
| .setInputCols(Array("value1", "value1")) | ||
| .setOutputCols(Array("out1", "out2")) | ||
| val model = imputer.fit(df) | ||
| } | ||
| assert(e.getMessage.contains("inputCols contains duplicates")) | ||
| } | ||
|
|
||
| withClue("Imputer should fail if outputCols contains duplicates") { | ||
| val e: IllegalArgumentException = intercept[IllegalArgumentException] { | ||
| val imputer = new Imputer().setStrategy(strategy) | ||
| .setInputCols(Array("value1", "value2")) | ||
| .setOutputCols(Array("out1", "out1")) | ||
| val model = imputer.fit(df) | ||
| } | ||
| assert(e.getMessage.contains("outputCols contains duplicates")) | ||
| } | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -133,12 +151,13 @@ class ImputerSuite extends SparkFunSuite with MLlibTestSparkContext with Default | |
| .setInputCols(Array("myInputCol")) | ||
| .setOutputCols(Array("myOutputCol")) | ||
| val newInstance = testDefaultReadWrite(instance) | ||
| assert(newInstance.surrogateDF.columns === instance.surrogateDF.columns) | ||
| assert(newInstance.surrogateDF.collect() === instance.surrogateDF.collect()) | ||
| } | ||
|
|
||
| } | ||
|
|
||
| object ImputerSuite{ | ||
| object ImputerSuite { | ||
|
|
||
| /** | ||
| * Imputation strategy. Available options are ["mean", "median"]. | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we document that we only support "Float" and "Double" types for now?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Cool. Thanks.