Closed

Changes from 43 commits (54 commits total):
2999b26  initial commit for Imputer (hhbyyh, Feb 29, 2016)
8335cf2  adjust mean and most (hhbyyh, Feb 29, 2016)
7be5e9b  Merge remote-tracking branch 'upstream/master' into imputer (hhbyyh, Mar 2, 2016)
131f7d5  Merge remote-tracking branch 'upstream/master' into imputer (hhbyyh, Mar 3, 2016)
a72a3ea  Merge remote-tracking branch 'upstream/master' into imputer (hhbyyh, Mar 5, 2016)
78df589  Merge remote-tracking branch 'upstream/master' into imputer (hhbyyh, Mar 7, 2016)
b949be5  refine code and add ut (hhbyyh, Mar 9, 2016)
79b1c62  Merge remote-tracking branch 'upstream/master' into imputer (hhbyyh, Mar 9, 2016)
c3d5d55  minor change (hhbyyh, Mar 9, 2016)
1b39668  add object Imputer and ut refine (hhbyyh, Mar 9, 2016)
7f87ffb  Merge remote-tracking branch 'upstream/master' into imputer (hhbyyh, Mar 10, 2016)
4e45f81  add options validate and some small changes (hhbyyh, Mar 10, 2016)
e1dd0d2  Merge remote-tracking branch 'upstream/master' into imputer (hhbyyh, Mar 22, 2016)
12220eb  Merge remote-tracking branch 'upstream/master' into imputer (hhbyyh, Mar 23, 2016)
1b36deb  optimize mean for vectors (hhbyyh, Mar 23, 2016)
72d104d  style fix (hhbyyh, Mar 23, 2016)
c311b2e  Merge remote-tracking branch 'upstream/master' into imputer (hhbyyh, Apr 10, 2016)
d6b9421  Merge remote-tracking branch 'upstream/master' into imputer (hhbyyh, Apr 11, 2016)
d181b12  Merge remote-tracking branch 'upstream/master' into imputer (hhbyyh, Apr 11, 2016)
e211481  Merge remote-tracking branch 'upstream/master' into imputer (hhbyyh, Apr 11, 2016)
791533b  Merge remote-tracking branch 'upstream/master' into imputer (hhbyyh, Apr 12, 2016)
fdd6f94  refactor to support numeric only (hhbyyh, Apr 12, 2016)
8042cfb  Merge branch 'imputer' of https://github.com/hhbyyh/spark into imputer (hhbyyh, Apr 12, 2016)
4bdf595  change most to mode (hhbyyh, Apr 12, 2016)
e6ad69c  Merge remote-tracking branch 'upstream/master' into imputer (hhbyyh, Apr 17, 2016)
1718422  move filter to NaN (hhbyyh, Apr 17, 2016)
594c501  add transformSchema (hhbyyh, Apr 20, 2016)
3043e7d  Merge remote-tracking branch 'upstream/master' into imputer (hhbyyh, Apr 27, 2016)
b3633e8  remove mode and change input type (hhbyyh, Apr 27, 2016)
053d489  remove print (hhbyyh, Apr 27, 2016)
63e7032  Merge remote-tracking branch 'upstream/master' into imputer (hhbyyh, Apr 28, 2016)
4e1c34a  update document and remove a ut (hhbyyh, Apr 28, 2016)
051aec6  Merge remote-tracking branch 'upstream/master' into imputer (hhbyyh, Apr 29, 2016)
aef094b  fix ut (hhbyyh, Apr 29, 2016)
335ded7  Merge remote-tracking branch 'upstream/master' into imputer (hhbyyh, Apr 29, 2016)
949ed79  Merge remote-tracking branch 'upstream/master' into imputer (hhbyyh, Apr 30, 2016)
93bba63  Merge branch 'imputer' of https://github.com/hhbyyh/spark into imputer (hhbyyh, Apr 30, 2016)
cca8dd4  rename ut (hhbyyh, May 1, 2016)
eea8947  Merge remote-tracking branch 'upstream/master' into imputer (hhbyyh, May 3, 2016)
4e07431  update parameter doc (hhbyyh, May 3, 2016)
31556e6  Merge remote-tracking branch 'upstream/master' into imputer (hhbyyh, Sep 7, 2016)
d4f92e4  Merge branch 'imputer' of https://github.com/hhbyyh/spark into imputer (hhbyyh, Sep 7, 2016)
544a65c  update version (hhbyyh, Sep 7, 2016)
910685e  Merge remote-tracking branch 'upstream/master' into imputer (YY-OnCall, Oct 6, 2016)
91d4cee  throw exception (YY-OnCall, Oct 7, 2016)
8744524  change data format (YY-OnCall, Oct 7, 2016)
ca45c33  Merge remote-tracking branch 'upstream/master' into imputer (YY-OnCall, Feb 22, 2017)
e86d919  add multi column support (YY-OnCall, Feb 22, 2017)
4f17c54  Merge remote-tracking branch 'upstream/master' into imputer (YY-OnCall, Mar 2, 2017)
ce59a5b  Merge remote-tracking branch 'upstream/master' into imputer (YY-OnCall, Mar 3, 2017)
41d91b9  change surrogateDF format and add ut for multi-columns (YY-OnCall, Mar 3, 2017)
9f6bd57  Merge remote-tracking branch 'upstream/master' into imputer (YY-OnCall, Mar 6, 2017)
e378db5  unit test refine and comments update (YY-OnCall, Mar 6, 2017)
c67afc1  fix exception message (YY-OnCall, Mar 8, 2017)
219 changes: 219 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
@@ -0,0 +1,219 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.feature

import org.apache.hadoop.fs.Path

import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._

/**
* Params for [[Imputer]] and [[ImputerModel]].
*/
private[feature] trait ImputerParams extends Params with HasInputCol with HasOutputCol {

/**
* The imputation strategy.
* If "mean", then replace missing values using the mean value of the feature.
* If "median", then replace missing values using the approximate median value of the feature.
* Default: mean
*
* @group param
*/
final val strategy: Param[String] = new Param(this, "strategy", "strategy for imputation. " +
"If mean, then replace missing values using the mean value of the feature. " +
"If median, then replace missing values using the median value of the feature.",
ParamValidators.inArray[String](Imputer.supportedStrategyNames.toArray))

/** @group getParam */
def getStrategy: String = $(strategy)

/**
* The placeholder for the missing values. All occurrences of missingValue will be imputed.
Comment (Member): Doc: Note that null values are always treated as missing.

* Default: Double.NaN
*
* @group param
*/
final val missingValue: DoubleParam = new DoubleParam(this, "missingValue",
"The placeholder for the missing values. All occurrences of missingValue will be imputed")

/** @group getParam */
def getMissingValue: Double = $(missingValue)

/** Validates and transforms the input schema. */
protected def validateAndTransformSchema(schema: StructType): StructType = {
val inputType = schema($(inputCol)).dataType
SchemaUtils.checkColumnTypes(schema, $(inputCol), Seq(DoubleType, FloatType))
require(!schema.fieldNames.contains($(outputCol)),
Comment (Member): This is already checked in appendColumn.

s"Output column ${$(outputCol)} already exists.")
SchemaUtils.appendColumn(schema, $(outputCol), inputType)
}
}
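Because strategy is declared with ParamValidators.inArray, an unsupported value is rejected as soon as it is set rather than at fit time. A minimal sketch of that behavior (illustrative only; intercept is ScalaTest's, as used in the suite further below):

    // Illustrative: invalid strategies fail fast via ParamValidators.inArray.
    val imputer = new Imputer()
    imputer.setStrategy("median") // accepted: in supportedStrategyNames
    intercept[IllegalArgumentException] {
      imputer.setStrategy("mode") // rejected: not a supported strategy
    }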

/**
* :: Experimental ::
* Imputation estimator for completing missing values, either using the mean or the median
* of the column in which the missing values are located. The input column should be of
Comment (Contributor): As mentioned above at https://github.com/apache/spark/pull/11601/files#r104403880, you can add the note about relative error here. Something like "For computing median, approxQuantile is used with a relative error of X" (provide a ScalaDoc link to approxQuantile).

Reply (Contributor Author): I didn't add the link as it may break javadoc generation.

Reply (Contributor): Ah right - perhaps just mention using approxQuantile?

* DoubleType or FloatType.
*
Comment (Contributor): Can we document that we only support "Float" and "Double" types for now?

Reply (Contributor Author): Cool. Thanks.

* Note that the mean/median value is computed after filtering out missing values.
* All Null values in the input column are treated as missing, and so are also imputed.
Comment (Member): Say here that this does not support categorical features yet and will transform them, possibly creating incorrect values for a categorical feature. Also add the JIRA number for supporting them.

*/
@Experimental
class Imputer @Since("2.1.0")(override val uid: String)
Comment (Contributor): All @Since annotations -> 2.2.0

extends Estimator[ImputerModel] with ImputerParams with DefaultParamsWritable {

@Since("2.1.0")
def this() = this(Identifiable.randomUID("imputer"))

/** @group setParam */
Comment (Member): Shall we add Since annotations for the setters?

Reply (Contributor): I've heard an argument that everything in the class is implicitly since 2.1.0, since the class itself is - unless otherwise stated. Which does make sense. But I slightly favour being explicit about it (even if it is a bit pedantic), so yes, let's add the annotation to all the setters.

def setInputCol(value: String): this.type = set(inputCol, value)

/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)

/**
* Imputation strategy. Available options are ["mean", "median"].
* @group setParam
*/
def setStrategy(value: String): this.type = set(strategy, value)

/** @group setParam */
def setMissingValue(value: Double): this.type = set(missingValue, value)

setDefault(strategy -> "mean", missingValue -> Double.NaN)

override def fit(dataset: Dataset[_]): ImputerModel = {
transformSchema(dataset.schema, logging = true)
val ic = col($(inputCol))
val filtered = dataset.select(ic.cast(DoubleType))
.filter(ic.isNotNull && ic =!= $(missingValue))
val surrogate = $(strategy) match {
case "mean" => filtered.filter(!ic.isNaN).select(avg($(inputCol))).first().getDouble(0)
case "median" => filtered.stat.approxQuantile($(inputCol), Array(0.5), 0.001)(0)
}
copyValues(new ImputerModel(uid, surrogate).setParent(this))
}

override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema)
}

override def copy(extra: ParamMap): Imputer = {
val copied = new Imputer(uid)
Comment (Member): Use defaultCopy.

copyValues(copied, extra)
}
}
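For orientation, a minimal usage sketch of the estimator defined above (assuming a SparkSession named spark; the column names are hypothetical):

    import org.apache.spark.ml.feature.Imputer

    val df = spark.createDataFrame(Seq(
      (0, 1.0),
      (1, 2.0),
      (2, Double.NaN) // missing under the default missingValue of Double.NaN
    )).toDF("id", "value")

    val model = new Imputer()
      .setInputCol("value")
      .setOutputCol("value_imputed")
      .setStrategy("median") // or "mean", the default
      .fit(df)               // computes the surrogate from the non-missing rows

    model.transform(df).show() // the NaN cell is replaced by the surrogate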

@Since("2.1.0")
object Imputer extends DefaultParamsReadable[Imputer] {

/** Set of strategy names that Imputer currently supports. */
private[ml] val supportedStrategyNames = Set("mean", "median")
Comment (Contributor): Could we factor out the mean and median names into a private[ml] val, to be used instead of the raw strings throughout?

Reply (Contributor Author): Yes, that's better.


@Since("2.1.0")
override def load(path: String): Imputer = super.load(path)
}

/**
* :: Experimental ::
* Model fitted by [[Imputer]].
*
* @param surrogate Value by which missing values in the input column will be replaced.
*/
@Experimental
class ImputerModel private[ml](
override val uid: String,
val surrogate: Double)
extends Model[ImputerModel] with ImputerParams with MLWritable {

import ImputerModel._

/** @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)

/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)

override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
val inputType = dataset.select($(inputCol)).schema.fields(0).dataType
Comment (Member): Simplify: dataset.schema($(inputCol)).dataType

val ic = col($(inputCol))
dataset.withColumn($(outputCol), when(ic.isNull, surrogate)
.when(ic === $(missingValue), surrogate)
.otherwise(ic)
.cast(inputType))
}

override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema)
}

override def copy(extra: ParamMap): ImputerModel = {
val copied = new ImputerModel(uid, surrogate)
copyValues(copied, extra).setParent(parent)
}

@Since("2.1.0")
override def write: MLWriter = new ImputerModelWriter(this)
}
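Note that the when(ic === $(missingValue), ...) branch in transform() above catches the default NaN placeholder because Spark SQL's NaN semantics treat NaN = NaN as true (unlike Java's primitive ==). A standalone illustration of the replacement rule (column name and surrogate value are hypothetical):

    import org.apache.spark.sql.functions.{col, when}

    val surrogate = 2.25
    val ic = col("value")
    // Same rule as transform(): nulls and missingValue matches receive the
    // surrogate; every other value passes through unchanged.
    val imputed = when(ic.isNull, surrogate)
      .when(ic === Double.NaN, surrogate) // true for NaN cells in Spark SQL
      .otherwise(ic)
    // usage: df.withColumn("out", imputed.cast("double"))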


@Since("2.1.0")
object ImputerModel extends MLReadable[ImputerModel] {

private[ImputerModel] class ImputerModelWriter(instance: ImputerModel) extends MLWriter {

private case class Data(surrogate: Double)
Comment (Member, @jkbradley, Sep 26, 2016): Should we save an ArrayType[VectorUDT] or ArrayType[ArrayType[Double]] here? That would make this extensible to multiple columns, including mixed NumericType and Vector columns, in the future, without us having to change the persistence format.

Reply (Contributor): I would think that if we support multiple columns, we need to match up the column name to the surrogate, correct? So I'd think we would want to save a DataFrame with the same columns as inputCol(s), and then yes, either double or vector type. Is this what you mean here?


override protected def saveImpl(path: String): Unit = {
DefaultParamsWriter.saveMetadata(instance, path, sc)
val data = new Data(instance.surrogate)
val dataPath = new Path(path, "data").toString
sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
}

private class ImputerReader extends MLReader[ImputerModel] {

private val className = classOf[ImputerModel].getName

override def load(path: String): ImputerModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
val Row(surrogate: Double) = sqlContext.read.parquet(dataPath)
.select("surrogate")
.head()
val model = new ImputerModel(metadata.uid, surrogate)
DefaultParamsReader.getAndSetParams(model, metadata)
model
}
}

@Since("2.1.0")
override def read: MLReader[ImputerModel] = new ImputerReader

@Since("2.1.0")
override def load(path: String): ImputerModel = super.load(path)
}
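The persistence thread above suggests keying surrogates by input column so multiple columns can be supported later without changing the on-disk format. A hedged sketch of that alternative Data layout (purely illustrative; not what this revision persists):

    // Hypothetical layout: one row per input column, so the reader can
    // match each surrogate back to its column by name.
    case class ColumnSurrogate(inputCol: String, surrogate: Double)

    // e.g. sqlContext.createDataFrame(Seq(
    //   ColumnSurrogate("age", 29.5),
    //   ColumnSurrogate("income", 41000.0)
    // )).repartition(1).write.parquet(dataPath)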
122 changes: 122 additions & 0 deletions mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala
@@ -0,0 +1,122 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.feature

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.Row

class ImputerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
Comment (Contributor): We need tests for multiple columns too.


test("Imputer for Double with default missing Value NaN") {
Comment (Contributor): The way these tests are written was pretty confusing to me. It seems like the "mean" column should be the mean of some values, but really it is the expected output when the "mean" strategy is used. This is minor since it just affects readability, and might be more of a personal preference. I think the following is clearer:

    val df = sqlContext.createDataFrame(Seq(
      (0, 1.0),
      (1, 1.0),
      (2, 3.0),
      (3, 4.0),
      (4, Double.NaN)
    )).toDF("id", "value")
    val expectedOutput = Map(
      "mean" -> Array(1.0, 1.0, 3.0, 4.0, 2.25),
      "median" -> Array(1.0, 1.0, 3.0, 4.0, 1.0),
      "mode" -> Array(1.0, 1.0, 3.0, 4.0, 1.0))
    Seq("mean", "median", "mode").foreach { strategy =>
      val imputer = new Imputer().setInputCol("value").setOutputCol("out").setStrategy(strategy)
      val model = imputer.fit(df)
      val result = model.transform(df).select("out").collect().map(_.getDouble(0))
      result.zip(expectedOutput(strategy)).foreach { case (actual, expected) =>
        assert(actual ~== expected absTol 1e-5)
      }
    }

Really, just any way of indicating that the extra columns are expected outputs would be clearer to me.

Reply (Contributor Author): I get the point, yet data in the format of

    val df = sqlContext.createDataFrame(Seq(
      (0, 1.0, 1.0, 1.0),
      (1, 1.0, 1.0, 1.0),
      (2, 3.0, 3.0, 3.0),
      (3, 4.0, 4.0, 4.0),
      (4, Double.NaN, 2.25, 1.0)
    )).toDF("id", "value", "exp_mean", "exp_median")

provides direct correspondence. I've updated the column names.

Reply (Contributor): Sorry to nitpick, but "expected_mean" seems significantly clearer than "exp_mean" at the expense of only a few extra characters. Only change it if you push more commits, since it's minor :)

val df = spark.createDataFrame( Seq(
(0, 1.0, 1.0, 1.0),
(1, 1.0, 1.0, 1.0),
(2, 3.0, 3.0, 3.0),
(3, 4.0, 4.0, 4.0),
(4, Double.NaN, 2.25, 1.0)
)).toDF("id", "value", "expected_mean", "expected_median")
Seq("mean", "median").foreach { strategy =>
val imputer = new Imputer().setInputCol("value").setOutputCol("out").setStrategy(strategy)
val model = imputer.fit(df)
model.transform(df).select("expected_" + strategy, "out").collect().foreach {
case Row(exp: Double, out: Double) =>
assert(exp ~== out absTol 1e-5, s"Imputed values differ. Expected: $exp, actual: $out")
}
}
}

Comment (Contributor): We need to add tests for the case where the entire column is null or NaN. I just checked the NaN case and it will throw an NPE in the fit method.

Reply (Contributor): Good catch, yes - obviously the imputer can't actually do anything useful in that case - but it should either throw a useful error or return the dataset unchanged. I would favor an error here: if a user explicitly wants to impute missing data and all their data is missing, better to blow up now than later in the pipeline.

Reply (Contributor): Yeah, actually this also fails if the entire input column is the missing value. We need to beef up the test suite :)
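A sketch of the kind of test being requested in this thread (hypothetical; this revision does not include it, and the exact exception type is an assumption - the thread converges on wanting a clear error rather than an NPE):

    test("Imputer should fail with a useful error when all values are missing") {
      val df = spark.createDataFrame(Seq(
        (0, Double.NaN),
        (1, Double.NaN)
      )).toDF("id", "value")
      val imputer = new Imputer().setInputCol("value").setOutputCol("out")
      intercept[Exception] { // ideally a descriptive IllegalArgumentException, not an NPE
        imputer.fit(df)
      }
    }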

test("Imputer should handle NaNs when computing surrogate value, if missingValue is not NaN") {
val df = spark.createDataFrame( Seq(
(0, 1.0, 1.0, 1.0),
(1, 3.0, 3.0, 3.0),
(2, Double.NaN, Double.NaN, Double.NaN),
(3, -1.0, 2.0, 3.0)
Comment (Contributor): Does approxQuantile always choose the greater of the two middle values as the median? If so, can we add a comment noting that? NumPy computes the median of [1.0, 3.0] exactly as 2.0. Future developers might think it's a mistake.

Reply (Contributor Author): Actually, it always chooses the smaller of the two middle values, as I saw in some tests. In this test case, the median is computed from [1, 3, Double.NaN], and Double.NaN is treated as greater than Double.MaxValue.

)).toDF("id", "value", "expected_mean", "expected_median")
Seq("mean", "median").foreach { strategy =>
Comment (Member): This basic logic could be reused across the unit tests comparing actual and expected results. I'd recommend extracting this foreach into a method which can be called for each of the tests in this suite.

val imputer = new Imputer().setInputCol("value").setOutputCol("out").setStrategy(strategy)
.setMissingValue(-1.0)
val model = imputer.fit(df)
model.transform(df).select("expected_" + strategy, "out").collect().foreach {
case Row(exp: Double, out: Double) =>
assert((exp.isNaN && out.isNaN) || (exp ~== out absTol 1e-5),
s"Imputed values differ. Expected: $exp, actual: $out")
}
}
}
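Per the thread above, approxQuantile does not interpolate between the two middle values: with relative error 0.001 it returns an existing data point, and the author observed it picks the smaller middle value. A standalone check of that behavior (illustrative only, under the same assumptions):

    // Illustrative: approxQuantile returns an actual element, not NumPy's 2.0.
    val pair = spark.createDataFrame(Seq(Tuple1(1.0), Tuple1(3.0))).toDF("v")
    val Array(median) = pair.stat.approxQuantile("v", Array(0.5), 0.001)
    assert(median === 1.0) // the smaller of the two middle values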

test("Imputer for Float with missing Value -1.0") {
val df = spark.createDataFrame( Seq(
(0, 1.0F, 1.0F, 1.0F),
(1, 3.0F, 3.0F, 3.0F),
(2, 10.0F, 10.0F, 10.0F),
(3, 10.0F, 10.0F, 10.0F),
(4, -1.0F, 6.0F, 3.0F)
)).toDF("id", "value", "expected_mean", "expected_median")

Seq("mean", "median").foreach { strategy =>
val imputer = new Imputer().setInputCol("value").setOutputCol("out").setStrategy(strategy)
.setMissingValue(-1)
val model = imputer.fit(df)
val result = model.transform(df)
Comment (Contributor): This is never used.

model.transform(df).select("expected_" + strategy, "out").collect().foreach {
case Row(exp: Float, out: Float) =>
assert(exp == out, s"Imputed values differ. Expected: $exp, actual: $out")
}
}
}
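One thread above recommends extracting the repeated compare loop into a shared helper. A hedged sketch of what that could look like (name and signature are hypothetical; it additionally requires importing org.apache.spark.sql.DataFrame):

    // Hypothetical helper consolidating the expected-vs-actual assertion loop.
    private def checkImputed(df: DataFrame, inputCol: String, strategy: String): Unit = {
      val imputer = new Imputer().setInputCol(inputCol).setOutputCol("out")
        .setStrategy(strategy)
      val model = imputer.fit(df)
      model.transform(df).select("expected_" + strategy, "out").collect().foreach {
        case Row(exp: Double, out: Double) =>
          assert(exp ~== out absTol 1e-5,
            s"Imputed values differ. Expected: $exp, actual: $out")
      }
    }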

Comment (Contributor): We should also have a test for a non-NaN missing value, but with NaN in the dataset, to check that "mean" and "median" behave as we expect.

Reply (Contributor Author): Added.

test("Imputer should impute null as well as 'missingValue'") {
val df = spark.createDataFrame( Seq(
(0, 4.0, 4.0, 4.0),
(1, 10.0, 10.0, 10.0),
(2, 10.0, 10.0, 10.0),
(3, Double.NaN, 8.0, 10.0),
(4, -1.0, 8.0, 10.0)
)).toDF("id", "value", "expected_mean", "expected_median")
val df2 = df.selectExpr("*", "IF(value=-1.0, null, value) as nullable_value")
Seq("mean", "median").foreach { strategy =>
val imputer = new Imputer().setInputCol("nullable_value").setOutputCol("out")
.setStrategy(strategy)
val model = imputer.fit(df2)
model.transform(df2).select("expected_" + strategy, "out").collect().foreach {
case Row(exp: Double, out: Double) =>
assert(exp ~== out absTol 1e-5, s"Imputed values differ. Expected: $exp, actual: $out")
}
}
}

test("Imputer read/write") {
val t = new Imputer()
.setInputCol("myInputCol")
.setOutputCol("myOutputCol")
.setMissingValue(-1.0)
testDefaultReadWrite(t)
}

test("ImputerModel read/write") {
val instance = new ImputerModel(
"myImputer", 1.234)
.setInputCol("myInputCol")
.setOutputCol("myOutputCol")
val newInstance = testDefaultReadWrite(instance)
assert(newInstance.surrogate === instance.surrogate)
}

}