22 changes: 20 additions & 2 deletions docs/ml-features.md
@@ -503,6 +503,7 @@ for more details on the API.

`StringIndexer` encodes a string column of labels to a column of label indices.
The indices are in `[0, numLabels)`, ordered by label frequencies, so the most frequent label gets index `0`.
Unseen labels will be put at index `numLabels` if the user chooses to keep them.
If the input column is numeric, we cast it to string and index the string
values. When downstream pipeline components such as `Estimator` or
`Transformer` make use of this string-indexed label, you must set the input
@@ -542,12 +543,13 @@ column, we should get the following:
"a" gets index `0` because it is the most frequent, followed by "c" with index `1` and "b" with
index `2`.

Additionally, there are two strategies regarding how `StringIndexer` will handle
Additionally, there are three strategies regarding how `StringIndexer` will handle
unseen labels when you have fit a `StringIndexer` on one dataset and then use it
to transform another:

- throw an exception (which is the default)
- skip the row containing the unseen label entirely
- map the unseen labels with indices [numLabels]
**Contributor:** doc suggestion: "map the unseen labels to their own index"

**Member:** Or just match the phrasing in the doc param

**Examples**

@@ -561,6 +563,7 @@ Let's go back to our previous example but this time reuse our previously defined
1 | b
2 | c
3 | d
4 | e
~~~~

If you've not set how `StringIndexer` handles unseen labels or set it to
@@ -576,7 +579,22 @@ will be generated:
2 | c | 1.0
~~~~

Notice that the row containing "d" does not appear.
Notice that the rows containing "d" or "e" do not appear.

If you call `setHandleInvalid("keep")`, the following dataset
will be generated:

~~~~
id | category | categoryIndex
----|----------|---------------
0 | a | 0.0
1 | b | 2.0
2 | c | 1.0
3 | d | 3.0
4 | e | 3.0
~~~~

Notice that the rows containing "d" or "e" are mapped with indices "3.0"
**Contributor:** doc suggestion: rows containing "d" or "e" are mapped with indices "3.0" => rows containing "d" and "e" are mapped to index "3.0"
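Putting the three modes together, here is a minimal Scala sketch of the behavior the tables above describe (it assumes a live `SparkSession` named `spark` and the `setHandleInvalid("keep")` option added by this PR):

~~~~
import org.apache.spark.ml.feature.StringIndexer

// Fit on data where "a" is most frequent (0.0), then "c" (1.0), then "b" (2.0).
val train = spark.createDataFrame(Seq(
  (0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")
)).toDF("id", "category")

// The transform input additionally contains the unseen labels "d" and "e".
val test = spark.createDataFrame(Seq(
  (0, "a"), (1, "b"), (2, "c"), (3, "d"), (4, "e")
)).toDF("id", "category")

val model = new StringIndexer()
  .setInputCol("category")
  .setOutputCol("categoryIndex")
  .setHandleInvalid("keep")  // "error" (the default) throws; "skip" drops such rows
  .fit(train)

// With "keep", "d" and "e" both map to the shared bucket at index
// numLabels = 3.0, matching the last table above.
model.transform(test).show()
~~~~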

<div class="codetabs">

@@ -17,14 +17,16 @@

package org.apache.spark.ml.feature

import scala.language.existentials
**Member:** Is this needed?

**Author:** Local build & test are fine, but it will get a compilation error on line 193 on Jenkins.

import org.apache.hadoop.fs.Path

import org.apache.spark.SparkException
import org.apache.spark.annotation.Since
import org.apache.spark.ml.{Estimator, Model, Transformer}
import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
@@ -34,8 +36,27 @@ import org.apache.spark.util.collection.OpenHashMap
/**
* Base trait for [[StringIndexer]] and [[StringIndexerModel]].
*/
private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol
with HasHandleInvalid {
private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol {

/**
* Param for how to handle unseen labels. Options are 'skip' (filter out rows with
* unseen labels), 'error' (throw an error), or 'keep' (put unseen labels in a special additional
* bucket, at index numLabels).
* Default: "error"
* @group param
*/
@Since("2.1.0")
**Member:** I missed this before, but these Since annotations should stay set to 1.6.0 since handleInvalid and the get/set methods were added in 1.6.0
val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle " +
"unseen labels. Options are 'skip' (filter out rows with unseen labels), " +
"error (throw an error), or 'keep' (put unseen labels in a special additional bucket," +
**Member:** need space after comma: "bucket, "
"at index numLabels).",
ParamValidators.inArray(StringIndexer.supportedHandleInvalids))

setDefault(handleInvalid, StringIndexer.ERROR_UNSEEN_LABEL)

/** @group getParam */
@Since("2.2.0")
**Member:** ditto (this Since annotation should also stay 1.6.0)
def getHandleInvalid: String = $(handleInvalid)
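Because the param is declared with `ParamValidators.inArray`, invalid values should fail fast at set time, as with other validated Spark ML params. A hedged sketch of the expected behavior:

~~~~
val indexer = new StringIndexer()
indexer.setHandleInvalid("keep")  // accepted: one of "skip", "error", "keep"
indexer.setHandleInvalid("foo")   // expected: IllegalArgumentException from the validator
~~~~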

/** Validates and transforms the input schema. */
protected def validateAndTransformSchema(schema: StructType): StructType = {
@@ -70,11 +91,6 @@ class StringIndexer @Since("1.4.0") (
@Since("1.4.0")
def this() = this(Identifiable.randomUID("strIdx"))

/** @group setParam */
@Since("1.6.0")
def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
setDefault(handleInvalid, "error")

/** @group setParam */
@Since("1.4.0")
def setInputCol(value: String): this.type = set(inputCol, value)
Expand All @@ -83,6 +99,10 @@ class StringIndexer @Since("1.4.0") (
@Since("1.4.0")
def setOutputCol(value: String): this.type = set(outputCol, value)

/** @group setParam */
@Since("2.2.0")
def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
**Contributor (imatiach-msft, Mar 3, 2017):** can you keep the order of the params the same as before? also, minor style comment -- keep the setDefault(handleInvalid) below the set method.

**Member:** +1 for maintaining order. setDefault will go in the trait (except in cases where it belongs in just one of the Estimator or Model).

@Since("2.0.0")
override def fit(dataset: Dataset[_]): StringIndexerModel = {
transformSchema(dataset.schema, logging = true)
Expand All @@ -105,7 +125,11 @@ class StringIndexer @Since("1.4.0") (

@Since("1.6.0")
object StringIndexer extends DefaultParamsReadable[StringIndexer] {

private[feature] val SKIP_UNSEEN_LABEL: String = "skip"
private[feature] val ERROR_UNSEEN_LABEL: String = "error"
private[feature] val KEEP_UNSEEN_LABEL: String = "keep"
**Contributor:** this is very nice, good use of constants, I really like to see this type of code :)

**Contributor (imatiach-msft, Mar 3, 2017):** It would make me even happier if these were public and could be used by the test code, but I think it's up to the committers (jkbradley)

**Member:** At some point, let's do that, but not yet. I like keeping things private at first in case we find mistakes after release and need to change things.
private[feature] val supportedHandleInvalids: Array[String] =
Array(SKIP_UNSEEN_LABEL, ERROR_UNSEEN_LABEL, KEEP_UNSEEN_LABEL)
@Since("1.6.0")
**Member:** style: add newline here
override def load(path: String): StringIndexer = super.load(path)
}
@@ -141,11 +165,6 @@ class StringIndexerModel (
map
}

/** @group setParam */
@Since("1.6.0")
def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
setDefault(handleInvalid, "error")

/** @group setParam */
@Since("1.4.0")
def setInputCol(value: String): this.type = set(inputCol, value)
Expand All @@ -154,6 +173,11 @@ class StringIndexerModel (
@Since("1.4.0")
def setOutputCol(value: String): this.type = set(outputCol, value)

/** @group setParam */
@Since("2.2.0")
**Member:** ditto (keep this Since annotation at 1.6.0)
def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
setDefault(handleInvalid, StringIndexer.ERROR_UNSEEN_LABEL)
**Member:** No need to set default here since it's set in the trait.

@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
if (!dataset.schema.fieldNames.contains($(inputCol))) {
Expand All @@ -163,25 +187,28 @@ class StringIndexerModel (
}
transformSchema(dataset.schema, logging = true)

val metadata = NominalAttribute.defaultAttr
.withName($(outputCol)).withValues(labels).toMetadata()
**Member:** withValues should include a special field for invalid, if we are keeping invalid labels. How about calling that field "_invalidLabels" (and defining this constant in the StringIndexer object)?

**Author:** sorry, I cannot fully get the point.

**Contributor:** I think he means that "labels" above should also include the invalid bucket. In previous ML frameworks I've worked on we've just called this "unknown".

**Member:** Yep, that's what I meant: in withValues(labels), the labels can be set as:

~~~~
val outputLabels = getHandleInvalid match {
  case StringIndexer.KEEP_UNSEEN_LABEL => labels :+ "__unknown"
  case _ => labels
}
~~~~

I'm adding underscores to the attribute name to make it a little less likely to hit conflicts.
// If we are skipping invalid records, filter them out.
val (filteredDataset, keepInvalid) = getHandleInvalid match {
**Contributor:** minor style comment: instead of keepInvalid, do you think that indexInvalid might be a better name (?)

**Contributor:** actually, I think returning a tuple here just makes things more confusing. Maybe you can move the check outside of the match.

**Member:** I'm OK with returning a tuple; that's a common pattern. Do you mean that it makes the code inside the match statement confusing?
case StringIndexer.SKIP_UNSEEN_LABEL =>
val filterer = udf { label: String =>
labelToIndex.contains(label)
}
(dataset.where(filterer(dataset($(inputCol)))), false)
case _ => (dataset, getHandleInvalid == StringIndexer.KEEP_UNSEEN_LABEL)
}

val indexer = udf { label: String =>
if (labelToIndex.contains(label)) {
labelToIndex(label)
} else if (keepInvalid) {
labels.length
} else {
throw new SparkException(s"Unseen label: $label.")
**Member:** Update with recommendation to set handleInvalid to "keep" to handle unseen labels.

**Member:** Can you improve the error message?

~~~~
throw new SparkException(s"Unseen label: $label.  To handle unseen labels, set Param handleInvalid to ${StringIndexer.KEEP_UNSEEN_LABEL}.")
~~~~
}
}

val metadata = NominalAttribute.defaultAttr
.withName($(outputCol)).withValues(labels).toMetadata()
// If we are skipping invalid records, filter them out.
val filteredDataset = getHandleInvalid match {
case "skip" =>
val filterer = udf { label: String =>
labelToIndex.contains(label)
}
dataset.where(filterer(dataset($(inputCol))))
case _ => dataset
}
filteredDataset.select(col("*"),
indexer(dataset($(inputCol)).cast(StringType)).as($(outputCol), metadata))
}
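Distilled, the per-label decision implemented by the `indexer` udf above is the following (a sketch using a plain `Map` instead of Spark's `OpenHashMap`; the names are illustrative, and it relies on the SparkException import above):

~~~~
// Rows with unseen labels were already filtered out before this point when
// handleInvalid is "skip", so only "keep" vs "error" is decided here.
def indexOf(labelToIndex: Map[String, Double], numLabels: Int,
            keepInvalid: Boolean)(label: String): Double =
  labelToIndex.getOrElse(label,
    if (keepInvalid) numLabels.toDouble  // shared bucket for all unseen labels
    else throw new SparkException(s"Unseen label: $label."))
~~~~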
@@ -64,7 +64,7 @@ class StringIndexerSuite

test("StringIndexerUnseen") {
val data = Seq((0, "a"), (1, "b"), (4, "b"))
val data2 = Seq((0, "a"), (1, "b"), (2, "c"))
val data2 = Seq((0, "a"), (1, "b"), (2, "c"), (3, "d"))
val df = data.toDF("id", "label")
val df2 = data2.toDF("id", "label")
val indexer = new StringIndexer()
Expand All @@ -75,22 +75,32 @@ class StringIndexerSuite
intercept[SparkException] {
indexer.transform(df2).collect()
}
val indexerSkipInvalid = new StringIndexer()
.setInputCol("label")
.setOutputCol("labelIndex")
.setHandleInvalid("skip")
.fit(df)

indexer.setHandleInvalid("skip")
// Verify that we skip the c and d records
val transformed = indexerSkipInvalid.transform(df2)
val attr = Attribute.fromStructField(transformed.schema("labelIndex"))
val transformedSkip = indexer.transform(df2)
val attrSkip = Attribute.fromStructField(transformedSkip.schema("labelIndex"))
.asInstanceOf[NominalAttribute]
assert(attr.values.get === Array("b", "a"))
val output = transformed.select("id", "labelIndex").rdd.map { r =>
assert(attrSkip.values.get === Array("b", "a"))
val outputSkip = transformedSkip.select("id", "labelIndex").rdd.map { r =>
(r.getInt(0), r.getDouble(1))
}.collect().toSet
// a -> 1, b -> 0
val expected = Set((0, 1.0), (1, 0.0))
assert(output === expected)
val expectedSkip = Set((0, 1.0), (1, 0.0))
assert(outputSkip === expectedSkip)

indexer.setHandleInvalid("keep")
// Verify that we keep the unseen records
val transformedKeep = indexer.transform(df2)
val attrKeep = Attribute.fromStructField(transformedKeep.schema("labelIndex"))
.asInstanceOf[NominalAttribute]
assert(attrKeep.values.get === Array("b", "a"))
val outputKeep = transformedKeep.select("id", "labelIndex").rdd.map { r =>
(r.getInt(0), r.getDouble(1))
}.collect().toSet
// a -> 1, b -> 0, c -> 2, d -> 2 (unseen labels share the bucket at index numLabels = 2)
val expectedKeep = Set((0, 1.0), (1, 0.0), (2, 2.0), (3, 2.0))
assert(outputKeep === expectedKeep)
}

test("StringIndexer with a numeric input column") {
4 changes: 4 additions & 0 deletions project/MimaExcludes.scala
@@ -907,6 +907,10 @@ object MimaExcludes {
) ++ Seq(
// [SPARK-17163] Unify logistic regression interface. Private constructor has new signature.
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.this")
) ++ Seq(
// [SPARK-17498] StringIndexer enhancement for handling unseen labels
ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.feature.StringIndexer"),
ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.feature.StringIndexerModel")
) ++ Seq(
// [SPARK-17365][Core] Remove/Kill multiple executors together to reduce RPC call time
ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.SparkContext")