Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
54 commits
Select commit Hold shift + click to select a range
2999b26
initial commit for Imputer
hhbyyh Feb 29, 2016
8335cf2
adjust mean and most
hhbyyh Feb 29, 2016
7be5e9b
Merge remote-tracking branch 'upstream/master' into imputer
hhbyyh Mar 2, 2016
131f7d5
Merge remote-tracking branch 'upstream/master' into imputer
hhbyyh Mar 3, 2016
a72a3ea
Merge remote-tracking branch 'upstream/master' into imputer
hhbyyh Mar 5, 2016
78df589
Merge remote-tracking branch 'upstream/master' into imputer
hhbyyh Mar 7, 2016
b949be5
refine code and add ut
hhbyyh Mar 9, 2016
79b1c62
Merge remote-tracking branch 'upstream/master' into imputer
hhbyyh Mar 9, 2016
c3d5d55
minor change
hhbyyh Mar 9, 2016
1b39668
add object Imputer and ut refine
hhbyyh Mar 9, 2016
7f87ffb
Merge remote-tracking branch 'upstream/master' into imputer
hhbyyh Mar 10, 2016
4e45f81
add options validate and some small changes
hhbyyh Mar 10, 2016
e1dd0d2
Merge remote-tracking branch 'upstream/master' into imputer
hhbyyh Mar 22, 2016
12220eb
Merge remote-tracking branch 'upstream/master' into imputer
hhbyyh Mar 23, 2016
1b36deb
optimize mean for vectors
hhbyyh Mar 23, 2016
72d104d
style fix
hhbyyh Mar 23, 2016
c311b2e
Merge remote-tracking branch 'upstream/master' into imputer
hhbyyh Apr 10, 2016
d6b9421
Merge remote-tracking branch 'upstream/master' into imputer
hhbyyh Apr 11, 2016
d181b12
Merge remote-tracking branch 'upstream/master' into imputer
hhbyyh Apr 11, 2016
e211481
Merge remote-tracking branch 'upstream/master' into imputer
hhbyyh Apr 11, 2016
791533b
Merge remote-tracking branch 'upstream/master' into imputer
hhbyyh Apr 12, 2016
fdd6f94
refactor to support numeric only
hhbyyh Apr 12, 2016
8042cfb
Merge branch 'imputer' of https://github.com/hhbyyh/spark into imputer
hhbyyh Apr 12, 2016
4bdf595
change most to mode
hhbyyh Apr 12, 2016
e6ad69c
Merge remote-tracking branch 'upstream/master' into imputer
hhbyyh Apr 17, 2016
1718422
move filter to NaN
hhbyyh Apr 17, 2016
594c501
add transformSchema
hhbyyh Apr 20, 2016
3043e7d
Merge remote-tracking branch 'upstream/master' into imputer
hhbyyh Apr 27, 2016
b3633e8
remove mode and change input type
hhbyyh Apr 27, 2016
053d489
remove print
hhbyyh Apr 27, 2016
63e7032
Merge remote-tracking branch 'upstream/master' into imputer
hhbyyh Apr 28, 2016
4e1c34a
update document and remove a ut
hhbyyh Apr 28, 2016
051aec6
Merge remote-tracking branch 'upstream/master' into imputer
hhbyyh Apr 29, 2016
aef094b
fix ut
hhbyyh Apr 29, 2016
335ded7
Merge remote-tracking branch 'upstream/master' into imputer
hhbyyh Apr 29, 2016
949ed79
Merge remote-tracking branch 'upstream/master' into imputer
hhbyyh Apr 30, 2016
93bba63
Merge branch 'imputer' of https://github.com/hhbyyh/spark into imputer
hhbyyh Apr 30, 2016
cca8dd4
rename ut
hhbyyh May 1, 2016
eea8947
Merge remote-tracking branch 'upstream/master' into imputer
hhbyyh May 3, 2016
4e07431
update parameter doc
hhbyyh May 3, 2016
31556e6
Merge remote-tracking branch 'upstream/master' into imputer
hhbyyh Sep 7, 2016
d4f92e4
Merge branch 'imputer' of https://github.com/hhbyyh/spark into imputer
hhbyyh Sep 7, 2016
544a65c
update version
hhbyyh Sep 7, 2016
910685e
Merge remote-tracking branch 'upstream/master' into imputer
YY-OnCall Oct 6, 2016
91d4cee
throw exception
YY-OnCall Oct 7, 2016
8744524
change data format
YY-OnCall Oct 7, 2016
ca45c33
Merge remote-tracking branch 'upstream/master' into imputer
YY-OnCall Feb 22, 2017
e86d919
add multi column support
YY-OnCall Feb 22, 2017
4f17c54
Merge remote-tracking branch 'upstream/master' into imputer
YY-OnCall Mar 2, 2017
ce59a5b
Merge remote-tracking branch 'upstream/master' into imputer
YY-OnCall Mar 3, 2017
41d91b9
change surrogateDF format and add ut for multi-columns
YY-OnCall Mar 3, 2017
9f6bd57
Merge remote-tracking branch 'upstream/master' into imputer
YY-OnCall Mar 6, 2017
e378db5
unit test refine and comments update
YY-OnCall Mar 6, 2017
c67afc1
fix exception message
YY-OnCall Mar 8, 2017
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 9 additions & 10 deletions mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@ private[feature] trait ImputerParams extends Params with HasInputCols {
/**
* The imputation strategy.
* If "mean", then replace missing values using the mean value of the feature.
* If "median", then replace missing values using the approximate median value of the
* feature (relative error less than 0.001).
* If "median", then replace missing values using the approximate median value of the feature.
* Default: mean
*
* @group param
Expand Down Expand Up @@ -76,10 +75,10 @@ private[feature] trait ImputerParams extends Params with HasInputCols {

/** Validates and transforms the input schema. */
protected def validateAndTransformSchema(schema: StructType): StructType = {
require($(inputCols).length == $(inputCols).distinct.length, s"inputCols duplicates:" +
s" (${$(inputCols).mkString(", ")})")
require($(outputCols).length == $(outputCols).distinct.length, s"outputCols duplicates:" +
s" (${$(outputCols).mkString(", ")})")
require($(inputCols).length == $(inputCols).distinct.length, s"inputCols contains" +
s" duplicates: (${$(inputCols).mkString(", ")})")
require($(outputCols).length == $(outputCols).distinct.length, s"outputCols contains" +
s" duplicates: (${$(outputCols).mkString(", ")})")
require($(inputCols).length == $(outputCols).length, s"inputCols(${$(inputCols).length})" +
s" and outputCols(${$(outputCols).length}) should have the same length")
val outputFields = $(inputCols).zip($(outputCols)).map { case (inputCol, outputCol) =>
Expand All @@ -99,7 +98,8 @@ private[feature] trait ImputerParams extends Params with HasInputCols {
* (SPARK-15041) and possibly creates incorrect values for a categorical feature.
*
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we document that we only support "Float" and "Double" types for now?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool. Thanks.

* Note that the mean/median value is computed after filtering out missing values.
* All Null values in the input column are treated as missing, and so are also imputed.
* All Null values in the input column are treated as missing, and so are also imputed. For
* computing median, DataFrameStatFunctions.approxQuantile is used with a relative error of 0.001.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I see it is here - nevermind

*/
@Experimental
class Imputer @Since("2.2.0")(override val uid: String)
Expand Down Expand Up @@ -127,8 +127,7 @@ class Imputer @Since("2.2.0")(override val uid: String)
@Since("2.2.0")
def setMissingValue(value: Double): this.type = set(missingValue, value)

import org.apache.spark.ml.feature.Imputer._
setDefault(strategy -> mean, missingValue -> Double.NaN)
setDefault(strategy -> Imputer.mean, missingValue -> Double.NaN)

override def fit(dataset: Dataset[_]): ImputerModel = {
transformSchema(dataset.schema, logging = true)
Expand Down Expand Up @@ -197,7 +196,7 @@ class ImputerModel private[ml](
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
var outputDF = dataset
val surrogates = surrogateDF.select($(inputCols).head, $(inputCols).tail: _*).head().toSeq
val surrogates = surrogateDF.select($(inputCols).map(col): _*).head().toSeq

$(inputCols).zip($(outputCols)).zip(surrogates).foreach {
case ((inputCol, outputCol), surrogate) =>
Expand Down
49 changes: 34 additions & 15 deletions mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -85,33 +85,51 @@ class ImputerSuite extends SparkFunSuite with MLlibTestSparkContext with Default
Seq("mean", "median").foreach { strategy =>
val imputer = new Imputer().setInputCols(Array("value")).setOutputCols(Array("out"))
.setStrategy(strategy)
intercept[SparkException] {
val model = imputer.fit(df)
withClue("Imputer should fail all the values are invalid") {
val e: SparkException = intercept[SparkException] {
val model = imputer.fit(df)
}
assert(e.getMessage.contains("surrogate cannot be computed"))
}
}
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should also have a test for a non-NaN missing value, but with NaN in the dataset, to check that "mean" and "median" behave as we expect.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added

test("Imputer throws exception when inputCols does not match outputCols") {
test("Imputer input & output column validation") {
val df = spark.createDataFrame( Seq(
(0, 1.0, 1.0, 1.0),
(1, Double.NaN, 3.0, 3.0),
(2, Double.NaN, Double.NaN, Double.NaN)
)).toDF("id", "value1", "value2", "value3")
Seq("mean", "median").foreach { strategy =>
// inputCols and outCols length different
val imputer = new Imputer()
.setInputCols(Array("value1", "value2"))
.setOutputCols(Array("out1"))
.setStrategy(strategy)
intercept[IllegalArgumentException] {
val model = imputer.fit(df)
withClue("Imputer should fail if inputCols and outputCols are different length") {
val e: IllegalArgumentException = intercept[IllegalArgumentException] {
val imputer = new Imputer().setStrategy(strategy)
.setInputCols(Array("value1", "value2"))
.setOutputCols(Array("out1"))
val model = imputer.fit(df)
}
assert(e.getMessage.contains("should have the same length"))
}
// duplicate name in inputCols
imputer.setInputCols(Array("value1", "value1")).setOutputCols(Array("out1, out2"))
intercept[IllegalArgumentException] {
val model = imputer.fit(df)

withClue("Imputer should fail if inputCols contains duplicates") {
val e: IllegalArgumentException = intercept[IllegalArgumentException] {
val imputer = new Imputer().setStrategy(strategy)
.setInputCols(Array("value1", "value1"))
.setOutputCols(Array("out1", "out2"))
val model = imputer.fit(df)
}
assert(e.getMessage.contains("inputCols contains duplicates"))
}

withClue("Imputer should fail if outputCols contains duplicates") {
val e: IllegalArgumentException = intercept[IllegalArgumentException] {
val imputer = new Imputer().setStrategy(strategy)
.setInputCols(Array("value1", "value2"))
.setOutputCols(Array("out1", "out1"))
val model = imputer.fit(df)
}
assert(e.getMessage.contains("outputCols contains duplicates"))
}
}
}

Expand All @@ -133,12 +151,13 @@ class ImputerSuite extends SparkFunSuite with MLlibTestSparkContext with Default
.setInputCols(Array("myInputCol"))
.setOutputCols(Array("myOutputCol"))
val newInstance = testDefaultReadWrite(instance)
assert(newInstance.surrogateDF.columns === instance.surrogateDF.columns)
assert(newInstance.surrogateDF.collect() === instance.surrogateDF.collect())
}

}

object ImputerSuite{
object ImputerSuite {

/**
* Imputation strategy. Available options are ["mean", "median"].
Expand Down