Closed
54 commits
2999b26
initial commit for Imputer
hhbyyh Feb 29, 2016
8335cf2
adjust mean and most
hhbyyh Feb 29, 2016
7be5e9b
Merge remote-tracking branch 'upstream/master' into imputer
hhbyyh Mar 2, 2016
131f7d5
Merge remote-tracking branch 'upstream/master' into imputer
hhbyyh Mar 3, 2016
a72a3ea
Merge remote-tracking branch 'upstream/master' into imputer
hhbyyh Mar 5, 2016
78df589
Merge remote-tracking branch 'upstream/master' into imputer
hhbyyh Mar 7, 2016
b949be5
refine code and add ut
hhbyyh Mar 9, 2016
79b1c62
Merge remote-tracking branch 'upstream/master' into imputer
hhbyyh Mar 9, 2016
c3d5d55
minor change
hhbyyh Mar 9, 2016
1b39668
add object Imputer and ut refine
hhbyyh Mar 9, 2016
7f87ffb
Merge remote-tracking branch 'upstream/master' into imputer
hhbyyh Mar 10, 2016
4e45f81
add options validate and some small changes
hhbyyh Mar 10, 2016
e1dd0d2
Merge remote-tracking branch 'upstream/master' into imputer
hhbyyh Mar 22, 2016
12220eb
Merge remote-tracking branch 'upstream/master' into imputer
hhbyyh Mar 23, 2016
1b36deb
optimize mean for vectors
hhbyyh Mar 23, 2016
72d104d
style fix
hhbyyh Mar 23, 2016
c311b2e
Merge remote-tracking branch 'upstream/master' into imputer
hhbyyh Apr 10, 2016
d6b9421
Merge remote-tracking branch 'upstream/master' into imputer
hhbyyh Apr 11, 2016
d181b12
Merge remote-tracking branch 'upstream/master' into imputer
hhbyyh Apr 11, 2016
e211481
Merge remote-tracking branch 'upstream/master' into imputer
hhbyyh Apr 11, 2016
791533b
Merge remote-tracking branch 'upstream/master' into imputer
hhbyyh Apr 12, 2016
fdd6f94
refactor to support numeric only
hhbyyh Apr 12, 2016
8042cfb
Merge branch 'imputer' of https://github.com/hhbyyh/spark into imputer
hhbyyh Apr 12, 2016
4bdf595
change most to mode
hhbyyh Apr 12, 2016
e6ad69c
Merge remote-tracking branch 'upstream/master' into imputer
hhbyyh Apr 17, 2016
1718422
move filter to NaN
hhbyyh Apr 17, 2016
594c501
add transformSchema
hhbyyh Apr 20, 2016
3043e7d
Merge remote-tracking branch 'upstream/master' into imputer
hhbyyh Apr 27, 2016
b3633e8
remove mode and change input type
hhbyyh Apr 27, 2016
053d489
remove print
hhbyyh Apr 27, 2016
63e7032
Merge remote-tracking branch 'upstream/master' into imputer
hhbyyh Apr 28, 2016
4e1c34a
update document and remove a ut
hhbyyh Apr 28, 2016
051aec6
Merge remote-tracking branch 'upstream/master' into imputer
hhbyyh Apr 29, 2016
aef094b
fix ut
hhbyyh Apr 29, 2016
335ded7
Merge remote-tracking branch 'upstream/master' into imputer
hhbyyh Apr 29, 2016
949ed79
Merge remote-tracking branch 'upstream/master' into imputer
hhbyyh Apr 30, 2016
93bba63
Merge branch 'imputer' of https://github.com/hhbyyh/spark into imputer
hhbyyh Apr 30, 2016
cca8dd4
rename ut
hhbyyh May 1, 2016
eea8947
Merge remote-tracking branch 'upstream/master' into imputer
hhbyyh May 3, 2016
4e07431
update parameter doc
hhbyyh May 3, 2016
31556e6
Merge remote-tracking branch 'upstream/master' into imputer
hhbyyh Sep 7, 2016
d4f92e4
Merge branch 'imputer' of https://github.com/hhbyyh/spark into imputer
hhbyyh Sep 7, 2016
544a65c
update version
hhbyyh Sep 7, 2016
910685e
Merge remote-tracking branch 'upstream/master' into imputer
YY-OnCall Oct 6, 2016
91d4cee
throw exception
YY-OnCall Oct 7, 2016
8744524
change data format
YY-OnCall Oct 7, 2016
ca45c33
Merge remote-tracking branch 'upstream/master' into imputer
YY-OnCall Feb 22, 2017
e86d919
add multi column support
YY-OnCall Feb 22, 2017
4f17c54
Merge remote-tracking branch 'upstream/master' into imputer
YY-OnCall Mar 2, 2017
ce59a5b
Merge remote-tracking branch 'upstream/master' into imputer
YY-OnCall Mar 3, 2017
41d91b9
change surrogateDF format and add ut for multi-columns
YY-OnCall Mar 3, 2017
9f6bd57
Merge remote-tracking branch 'upstream/master' into imputer
YY-OnCall Mar 6, 2017
e378db5
unit test refine and comments update
YY-OnCall Mar 6, 2017
c67afc1
fix exception message
YY-OnCall Mar 8, 2017
260 changes: 260 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
@@ -0,0 +1,260 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.feature

import org.apache.hadoop.fs.Path

import org.apache.spark.SparkException
Contributor:

unused import

Contributor:

Not applicable anymore as it's used below now.

import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasInputCols, HasOutputCol}
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._

/**
* Params for [[Imputer]] and [[ImputerModel]].
*/
private[feature] trait ImputerParams extends Params with HasInputCols with HasOutputCol {
Contributor:

We don't use HasOutputCol anymore, correct?

Contributor Author:

Sure. However, I didn't get your first comment. Do you mean we should remove the import?


/**
* The imputation strategy.
* If "mean", then replace missing values using the mean value of the feature.
* If "median", then replace missing values using the approximate median value of the feature.
* Default: mean
*
* @group param
*/
final val strategy: Param[String] = new Param(this, "strategy", "strategy for imputation. " +
"If mean, then replace missing values using the mean value of the feature. " +
"If median, then replace missing values using the median value of the feature.",
ParamValidators.inArray[String](Imputer.supportedStrategyNames.toArray))

/** @group getParam */
def getStrategy: String = $(strategy)

/**
* The placeholder for the missing values. All occurrences of missingValue will be imputed.
Member:

Doc: Note that null values are always treated as missing.

* Note that null values are always treated as missing.
* Default: Double.NaN
*
* @group param
*/
final val missingValue: DoubleParam = new DoubleParam(this, "missingValue",
"The placeholder for the missing values. All occurrences of missingValue will be imputed")

/** @group getParam */
def getMissingValue: Double = $(missingValue)

/**
Contributor:

Fix comment indentation here.

* Param for output column names.
* @group param
*/
final val outputCols: StringArrayParam = new StringArrayParam(this, "outputCols",
"output column names")

/** @group getParam */
final def getOutputCols: Array[String] = $(outputCols)

/** Validates and transforms the input schema. */
protected def validateAndTransformSchema(schema: StructType): StructType = {
require($(inputCols).length == $(outputCols).length, "inputCols and outputCols should have " +
"the same length")
val localInputCols = $(inputCols)
val localOutputCols = $(outputCols)
var outputSchema = schema

$(inputCols).indices.foreach { i =>
Contributor:

Can do $(inputCols).zip($(outputCols)).foreach { case (inputCol, outputCol) => ...
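A minimal sketch of that zipped form, using the definitions already in this trait:

$(inputCols).zip($(outputCols)).foreach { case (inputCol, outputCol) =>
  SchemaUtils.checkColumnTypes(schema, inputCol, Seq(DoubleType, FloatType))
  // the output column keeps the input column's type
  outputSchema = SchemaUtils.appendColumn(outputSchema, outputCol, schema(inputCol).dataType)
}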

val inputCol = localInputCols(i)
val outputCol = localOutputCols(i)
val inputType = schema(inputCol).dataType
SchemaUtils.checkColumnTypes(schema, inputCol, Seq(DoubleType, FloatType))
outputSchema = SchemaUtils.appendColumn(outputSchema, outputCol, inputType)
}
outputSchema
}
}

/**
* :: Experimental ::
* Imputation estimator for completing missing values, either using the mean or the median
* of the column in which the missing values are located. The input column should be of
Contributor:

As mentioned above at https://github.com/apache/spark/pull/11601/files#r104403880, you can add the note about relative error here.

Something like "For computing median, approxQuantile is used with a relative error of X" (provide a ScalaDoc link to approxQuantile).

Contributor Author:

I didn't add the link as it may break javadoc generation.

Contributor:

Ah right - perhaps just mention using approxQuantile?
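For context, a minimal sketch of approxQuantile on an illustrative DataFrame df with a numeric column "value" (0.001 is the relative error this PR uses):

// Array(0.5) requests the median; 0.001 is the relative error
val Array(median) = df.stat.approxQuantile("value", Array(0.5), 0.001)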

* DoubleType or FloatType. Imputer does not currently support categorical features
* (SPARK-15041) and possibly creates incorrect values for a categorical feature.
*
Contributor:

Can we document that we only support "Float" and "Double" types for now?

Contributor Author:

Cool. Thanks.

* Note that the mean/median value is computed after filtering out missing values.
* All Null values in the input column are treated as missing, and so are also imputed.
Member:

Say here that this does not support categorical features yet and will transform them, possibly creating incorrect values for a categorical feature. Also add JIRA number for supporting them.

*/
@Experimental
class Imputer @Since("2.1.0")(override val uid: String)
Contributor:

All @Since annotations -> 2.2.0

extends Estimator[ImputerModel] with ImputerParams with DefaultParamsWritable {

@Since("2.1.0")
def this() = this(Identifiable.randomUID("imputer"))

/** @group setParam */
Member:

Shall we add Since annotations for the setters?

Contributor:

I've heard an argument that everything in the class is implicitly since 2.1.0 since the class itself is - unless otherwise stated. Which does make sense. But I do slightly favour being explicit about it (even if it is a bit pedantic) so yeah let's add the annotation to all the setters.

@Since("2.1.0")
def setInputCols(value: Array[String]): this.type = set(inputCols, value)

/** @group setParam */
@Since("2.1.0")
def setOutputCols(value: Array[String]): this.type = set(outputCols, value)

/**
* Imputation strategy. Available options are ["mean", "median"].
* @group setParam
*/
@Since("2.1.0")
def setStrategy(value: String): this.type = set(strategy, value)

/** @group setParam */
@Since("2.1.0")
def setMissingValue(value: Double): this.type = set(missingValue, value)

setDefault(strategy -> "mean", missingValue -> Double.NaN)

override def fit(dataset: Dataset[_]): ImputerModel = {
transformSchema(dataset.schema, logging = true)
val surrogates = $(inputCols).map { inputCol =>
val ic = col(inputCol)
val filtered = dataset.select(ic.cast(DoubleType))
.filter(ic.isNotNull && ic =!= $(missingValue))
Contributor:

Is it possible to just consolidate this into one filter (include the !ic.isNaN)?
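A sketch of the consolidated filter, keeping the same semantics as the two chained filters:

val filtered = dataset.select(ic.cast(DoubleType))
  .filter(ic.isNotNull && ic =!= $(missingValue) && !ic.isNaN)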

.filter(!ic.isNaN)
if (filtered.rdd.isEmpty()) {
Contributor:

Not ideal to have to call rdd here - but I guess unavoidable.

Contributor (@MLnick, Mar 8, 2017):

I think we can do filtered.take(1).size == 0 which should be more efficient
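A sketch of that emptiness check, which fetches at most one row instead of materializing the RDD (the message text is illustrative):

// take(1) stops after the first row is found
if (filtered.take(1).isEmpty) {
  throw new SparkException(s"surrogate cannot be computed for column $inputCol")
}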

throw new SparkException(s"surrogate cannot be computed. " +
s"All the values in $inputCol are null, NaN or missingValue (${$(missingValue)})")
}
val surrogate = $(strategy) match {
case "mean" => filtered.select(avg(inputCol)).first().getDouble(0)
Contributor:

slightly prefer filtered.select(avg(inputCol)).as[Double].first() (or ... head)

case "median" => filtered.stat.approxQuantile(inputCol, Array(0.5), 0.001)(0)
Contributor:

.head
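Putting both suggestions together, the match might read as follows (a sketch, assuming the implicits import already done in fit):

import dataset.sparkSession.implicits._
val surrogate = $(strategy) match {
  case "mean" => filtered.select(avg(inputCol)).as[Double].first()
  case "median" => filtered.stat.approxQuantile(inputCol, Array(0.5), 0.001).head
}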

}
surrogate.asInstanceOf[Double]
Contributor:

is the asInstanceOf[Double] necessary here?

Contributor Author (@hhbyyh, Mar 3, 2017):

no, will remove it.

}

import dataset.sparkSession.implicits._
val surrogateDF = Seq(surrogates).toDF("surrogates")
copyValues(new ImputerModel(uid, surrogateDF).setParent(this))
}

override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema)
}

override def copy(extra: ParamMap): Imputer = defaultCopy(extra)
}

@Since("2.1.0")
object Imputer extends DefaultParamsReadable[Imputer] {

/** Set of strategy names that Imputer currently supports. */
private[ml] val supportedStrategyNames = Set("mean", "median")
Contributor:

Could we factor out the mean and median names in to private[ml] val so to be used instead of the raw strings throughout?

Contributor Author:

Yes, that's better.
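A sketch of the factored-out constants (names are illustrative, not from the PR):

private[ml] val mean = "mean"
private[ml] val median = "median"
private[ml] val supportedStrategyNames = Set(mean, median)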


@Since("2.1.0")
override def load(path: String): Imputer = super.load(path)
}

/**
* :: Experimental ::
* Model fitted by [[Imputer]].
*
* @param surrogateDF Value by which missing values in the input columns will be replaced. This
* is stored as a DataFrame with input column names and the corresponding surrogates.
Contributor:

This is misleading - you're just storing the array of surrogates... did you mean something different? Otherwise the comment must be changed.

Contributor:

It sounds like you had the idea of storing the surrogates something like:

+------+---------+
|column|surrogate|
+------+---------+
|  col1|      1.2|
|  col2|      3.4|
|  col3|      5.4|
+------+---------+

?

Contributor Author:

I refactored it a little for better extensibility.

+----------+----------+
| inputCol1| inputCol2|
+----------+----------+
|surrogate1|surrogate2|
+----------+----------+
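A sketch of how that one-row layout could be built (column names and values are illustrative):

import spark.implicits._
// one row; each column is named after its input column
val surrogateDF = Seq((2.25, 1.0)).toDF("inputCol1", "inputCol2")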

*/
@Experimental
class ImputerModel private[ml](
override val uid: String,
val surrogateDF: DataFrame)
extends Model[ImputerModel] with ImputerParams with MLWritable {

import ImputerModel._

/** @group setParam */
def setInputCols(value: Array[String]): this.type = set(inputCols, value)

/** @group setParam */
def setOutputCols(value: Array[String]): this.type = set(outputCols, value)

override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
val localInputCols = $(inputCols)
val localOutputCols = $(outputCols)
var outputDF = dataset
val surrogates = surrogateDF.head().getSeq[Double](0)
Contributor:

.as[Seq[Double]].head()

Contributor Author:

I change it to
val surrogates = surrogateDF.select($(inputCols).head, $(inputCols).tail: _*).head().toSeq

which can actually handle different datatypes.


$(inputCols).indices.foreach { i =>
Contributor:

You could do $(inputCols).zip($(outputCols)).zip(surrogates).map { case ((inputCol, outputCol), icSurrogate) => ...
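A sketch of that zipped loop, assuming surrogates is the per-column Seq[Double] read above:

$(inputCols).zip($(outputCols)).zip(surrogates).foreach {
  case ((inputCol, outputCol), icSurrogate) =>
    val ic = col(inputCol)
    outputDF = outputDF.withColumn(outputCol,
      when(ic.isNull, icSurrogate)
        .when(ic === $(missingValue), icSurrogate)
        .otherwise(ic)
        .cast(dataset.schema(inputCol).dataType))
}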

val inputCol = localInputCols(i)
val outputCol = localOutputCols(i)
val inputType = dataset.schema(inputCol).dataType
val ic = col(inputCol)
val icSurrogate = surrogates(i)
outputDF = outputDF.withColumn(outputCol, when(ic.isNull, icSurrogate)
.when(ic === $(missingValue), icSurrogate)
.otherwise(ic)
.cast(inputType))
}
outputDF.toDF()
}

override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema)
}

override def copy(extra: ParamMap): ImputerModel = {
val copied = new ImputerModel(uid, surrogateDF)
copyValues(copied, extra).setParent(parent)
}

@Since("2.1.0")
override def write: MLWriter = new ImputerModelWriter(this)
}


@Since("2.1.0")
object ImputerModel extends MLReadable[ImputerModel] {

private[ImputerModel] class ImputerModelWriter(instance: ImputerModel) extends MLWriter {

override protected def saveImpl(path: String): Unit = {
DefaultParamsWriter.saveMetadata(instance, path, sc)
val dataPath = new Path(path, "data").toString
instance.surrogateDF.repartition(1).write.parquet(dataPath)
}
}

private class ImputerReader extends MLReader[ImputerModel] {

private val className = classOf[ImputerModel].getName

override def load(path: String): ImputerModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
val surrogateDF = sqlContext.read.parquet(dataPath)
val model = new ImputerModel(metadata.uid, surrogateDF)
DefaultParamsReader.getAndSetParams(model, metadata)
model
}
}

@Since("2.1.0")
override def read: MLReader[ImputerModel] = new ImputerReader

@Since("2.1.0")
override def load(path: String): ImputerModel = super.load(path)
}
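For reference, a minimal end-to-end sketch of the API this file adds (column names and data are illustrative):

val df = spark.createDataFrame(Seq(
  (1.0, Double.NaN),
  (2.0, 3.0),
  (Double.NaN, 5.0)
)).toDF("a", "b")

val model = new Imputer()
  .setInputCols(Array("a", "b"))
  .setOutputCols(Array("out_a", "out_b"))
  .setStrategy("mean")
  .fit(df)

// NaN in "a" is replaced by mean 1.5; NaN in "b" by mean 4.0
model.transform(df).show()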
136 changes: 136 additions & 0 deletions mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala
@@ -0,0 +1,136 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.feature

import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{DataFrame, Row}

class ImputerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
Contributor:

We need tests for multiple columns too
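A sketch of what such a multi-column test could look like (data is illustrative):

val df = spark.createDataFrame(Seq(
  (0, 1.0, 4.0),
  (1, Double.NaN, 5.0),
  (2, 3.0, Double.NaN)
)).toDF("id", "value1", "value2")
val model = new Imputer()
  .setInputCols(Array("value1", "value2"))
  .setOutputCols(Array("out1", "out2"))
  .fit(df)
// surrogates: mean(value1) = 2.0, mean(value2) = 4.5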


test("Imputer for Double with default missing Value NaN") {
Contributor:

The way these tests are written was pretty confusing to me. It seems like the "mean" column should be the mean of some values but really it is the expected output when "mean" strategy is used. This is minor since it just affects readability and might be more of a personal preference. I think the following is clearer:

val df = sqlContext.createDataFrame( Seq(
      (0, 1.0),
      (1, 1.0),
      (2, 3.0),
      (3, 4.0),
      (4, Double.NaN)
    )).toDF("id", "value")
    val expectedOutput = Map(
      "mean"->Array(1.0, 1.0, 3.0, 4.0, 2.25),
      "median" -> Array(1.0, 1.0, 3.0, 4.0, 1.0),
      "mode" -> Array(1.0, 1.0, 3.0, 4.0, 1.0))
    Seq("mean", "median", "mode").foreach { strategy =>
      val imputer = new Imputer().setInputCol("value").setOutputCol("out").setStrategy(strategy)
      val model = imputer.fit(df)
      val result = model.transform(df).select("out").collect().map(_.getDouble(0))
      result.zip(expectedOutput(strategy)).foreach { case (actual, expected) =>
        assert(actual ~== expected absTol 1e-5)
      }

Really, just any way of indicating that the extra columns are expected outputs would be clearer to me.

Contributor Author:

I get the point, yet data in the format of

    val df = sqlContext.createDataFrame( Seq(
      (0, 1.0, 1.0, 1.0),
      (1, 1.0, 1.0, 1.0),
      (2, 3.0, 3.0, 3.0),
      (3, 4.0, 4.0, 4.0),
      (4, Double.NaN, 2.25, 1.0)
    )).toDF("id", "value", "exp_mean", "exp_median")

provides direct correspondence.

I've updated the columns names.

Contributor:

Sorry to nitpick, but "expected_mean" seems significantly more clear than "exp_mean" at the expense of only a few extra characters. Only change it if you push more commits since its minor :)

val df = spark.createDataFrame( Seq(
(0, 1.0, 1.0, 1.0),
(1, 1.0, 1.0, 1.0),
(2, 3.0, 3.0, 3.0),
(3, 4.0, 4.0, 4.0),
(4, Double.NaN, 2.25, 1.0)
)).toDF("id", "value", "expected_mean", "expected_median")
val imputer = new Imputer().setInputCols(Array("value")).setOutputCols(Array("out"))
ImputerSuite.iterateStrategyTest(imputer, df)
}

Contributor:

We need to add tests for the case where the entire column is null or NaN. I just checked the NaN case and it will throw an NPE in the fit method.

Contributor:

Good catch yes - obviously the imputer can't actually do anything useful in that case - but it should either throw a useful error, or return the dataset unchanged.

I would favor an error in this case as if a user is explicitly wanting to impute missing data and all their data is missing, rather blow up now than later in the pipeline.

Contributor:

Yeah, actually this also fails if the entire input column is the missing value. We need to beef up the test suite :)

test("Imputer should handle NaNs when computing surrogate value, if missingValue is not NaN") {
val df = spark.createDataFrame( Seq(
(0, 1.0, 1.0, 1.0),
(1, 3.0, 3.0, 3.0),
(2, Double.NaN, Double.NaN, Double.NaN),
(3, -1.0, 2.0, 3.0)
Contributor:

Does approx quantile always choose the greater of the two middle values as the median? If so, can we add a comment noting that? NumPy computes the median of [1.0, 3.0] exactly as 2.0. Future developers might think it's a mistake.

Contributor Author:

Actually, it always chooses the smaller of the two middle values, as I saw in some tests.
In this test case, the median is computed from [1, 3, Double.NaN], and Double.NaN is treated as if it were greater than Double.MaxValue.
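A quick sketch that confirms the lower-middle behavior:

import spark.implicits._
val df = Seq(1.0, 3.0).toDF("v")
df.stat.approxQuantile("v", Array(0.5), 0.001)  // Array(1.0), not 2.0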

)).toDF("id", "value", "expected_mean", "expected_median")
val imputer = new Imputer().setInputCols(Array("value")).setOutputCols(Array("out"))
.setMissingValue(-1.0)
ImputerSuite.iterateStrategyTest(imputer, df)
}

test("Imputer for Float with missing Value -1.0") {
val df = spark.createDataFrame( Seq(
(0, 1.0F, 1.0F, 1.0F),
(1, 3.0F, 3.0F, 3.0F),
(2, 10.0F, 10.0F, 10.0F),
(3, 10.0F, 10.0F, 10.0F),
(4, -1.0F, 6.0F, 3.0F)
)).toDF("id", "value", "expected_mean", "expected_median")
val imputer = new Imputer().setInputCols(Array("value")).setOutputCols(Array("out"))
.setMissingValue(-1)
ImputerSuite.iterateStrategyTest(imputer, df)
}

test("Imputer should impute null as well as 'missingValue'") {
val df = spark.createDataFrame( Seq(
(0, 4.0, 4.0, 4.0),
(1, 10.0, 10.0, 10.0),
(2, 10.0, 10.0, 10.0),
(3, Double.NaN, 8.0, 10.0),
(4, -1.0, 8.0, 10.0)
)).toDF("id", "value", "expected_mean", "expected_median")
val df2 = df.selectExpr("*", "IF(value=-1.0, null, value) as nullable_value")
val imputer = new Imputer().setInputCols(Array("nullable_value")).setOutputCols(Array("out"))
ImputerSuite.iterateStrategyTest(imputer, df2)
}

Contributor:

We should also have a test for a non-NaN missing value, but with NaN in the dataset, to check that "mean" and "median" behave as we expect.

Contributor Author:

added


test("Imputer throws exception when surrogate cannot be computed") {
val df = spark.createDataFrame( Seq(
(0, Double.NaN, 1.0, 1.0),
(1, Double.NaN, 3.0, 3.0),
(2, Double.NaN, Double.NaN, Double.NaN)
)).toDF("id", "value", "expected_mean", "expected_median")
Seq("mean", "median").foreach { strategy =>
val imputer = new Imputer().setInputCols(Array("value")).setOutputCols(Array("out"))
.setStrategy(strategy)
intercept[SparkException] {
Contributor:

Check message here also.
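A sketch of asserting on the message:

val e = intercept[SparkException] {
  imputer.fit(df)
}
assert(e.getMessage.contains("surrogate cannot be computed"))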

val model = imputer.fit(df)
}
}
}

test("Imputer read/write") {
val t = new Imputer()
.setInputCols(Array("myInputCol"))
.setOutputCols(Array("myOutputCol"))
.setMissingValue(-1.0)
testDefaultReadWrite(t)
}

test("ImputerModel read/write") {
val spark = this.spark
import spark.implicits._
val surrogateDF = Seq(1.234).toDF("myInputCol")
Contributor:

This should be "surrogate" col name - though I see we don't actually use it in load or transform

Contributor Author:

this happens to be the correct column name for now.

Contributor:

Ok - we should add a test here to check the column names of instance and newInstance match up? (The below check is just for the actual values of the surrogate, correct?)
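A sketch of such an assertion:

// column names should round-trip through save/load
assert(newInstance.surrogateDF.columns === instance.surrogateDF.columns)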


val instance = new ImputerModel(
"myImputer", surrogateDF)
.setInputCols(Array("myInputCol"))
.setOutputCols(Array("myOutputCol"))
val newInstance = testDefaultReadWrite(instance)
assert(newInstance.surrogateDF.collect() === instance.surrogateDF.collect())
}

}

object ImputerSuite{
Contributor:

space before {


/**
* Fits and transforms with each strategy in ["mean", "median"] and verifies the imputed output.
* @param df DataFrame with columns "id", "value", "expected_mean", "expected_median"
*/
def iterateStrategyTest(imputer: Imputer, df: DataFrame): Unit = {
Seq("mean", "median").foreach { strategy =>
imputer.setStrategy(strategy)
val model = imputer.fit(df)
model.transform(df).select("expected_" + strategy, "out").collect().foreach {
case Row(exp: Float, out: Float) =>
assert((exp.isNaN && out.isNaN) || (exp == out),
s"Imputed values differ. Expected: $exp, actual: $out")
case Row(exp: Double, out: Double) =>
assert((exp.isNaN && out.isNaN) || (exp ~== out absTol 1e-5),
s"Imputed values differ. Expected: $exp, actual: $out")
}
}
}
}