[SPARK-17498][ML] StringIndexer enhancement for handling unseen labels #16883
Changes from 7 commits

docs/ml-features.md
@@ -502,7 +502,7 @@ for more details on the API.

  ## StringIndexer

  `StringIndexer` encodes a string column of labels to a column of label indices.
- The indices are in `[0, numLabels)`, ordered by label frequencies, so the most frequent label gets index `0`.
+ The indices are in `[0, numLabels]`, ordered by label frequencies, so the most frequent label gets index `0`.
  If the input column is numeric, we cast it to string and index the string
  values. When downstream pipeline components such as `Estimator` or
  `Transformer` make use of this string-indexed label, you must set the input
@@ -542,12 +542,13 @@ column, we should get the following:

  "a" gets index `0` because it is the most frequent, followed by "c" with index `1` and "b" with
  index `2`.

- Additionally, there are two strategies regarding how `StringIndexer` will handle
+ Additionally, there are three strategies regarding how `StringIndexer` will handle
  unseen labels when you have fit a `StringIndexer` on one dataset and then use it
  to transform another:

  - throw an exception (which is the default)
  - skip the row containing the unseen label entirely
+ - map unseen labels to a special additional index, `numLabels`

  **Examples**
@@ -561,6 +562,7 @@ Let's go back to our previous example but this time reuse our previously defined

   1 | b
   2 | c
   3 | d
+  4 | e
  ~~~~

  If you've not set how `StringIndexer` handles unseen labels or set it to
@@ -576,7 +578,22 @@ will be generated:

   2  | c        | 1.0
  ~~~~

- Notice that the row containing "d" does not appear.
+ Notice that the rows containing "d" or "e" do not appear.

+ If you had called `setHandleInvalid("keep")`, the following dataset
+ will be generated:
+
+ ~~~~
+  id | category | categoryIndex
+ ----|----------|---------------
+  0  | a        | 0.0
+  1  | b        | 2.0
+  2  | c        | 1.0
+  3  | d        | 3.0
+  4  | e        | 3.0
+ ~~~~
+
+ Notice that the rows containing "d" or "e" are mapped to index `3.0`.

  <div class="codetabs">
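For orientation, here is a minimal Scala sketch of the three strategies described in the docs change above. It is a hypothetical usage example, assuming a Spark build that includes this PR's `keep` option; the data mirrors the docs example, and names such as `train` and `test` are illustrative.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.feature.StringIndexer

val spark = SparkSession.builder().appName("StringIndexerUnseen").getOrCreate()
import spark.implicits._

// Fit on labels {a, b, c}; "a" is most frequent, so a -> 0.0, c -> 1.0, b -> 2.0.
val train = Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c"))
  .toDF("id", "category")
val model = new StringIndexer()
  .setInputCol("category")
  .setOutputCol("categoryIndex")
  .fit(train)

// Transform a dataset that also contains the unseen labels "d" and "e".
val test = Seq((0, "a"), (1, "b"), (2, "c"), (3, "d"), (4, "e"))
  .toDF("id", "category")

// Default ("error"): model.transform(test) throws SparkException on "d".
model.setHandleInvalid("skip")
model.transform(test).show()   // rows containing "d" and "e" are dropped

model.setHandleInvalid("keep")
model.transform(test).show()   // "d" and "e" both get index 3.0 (numLabels)
```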
mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -17,14 +17,16 @@

  package org.apache.spark.ml.feature

+ import scala.language.existentials
+

> **Member:** Is this needed?
>
> **Author:** Local build & test are fine, but Jenkins hits a compilation error on line 193 without it.

  import org.apache.hadoop.fs.Path

  import org.apache.spark.SparkException
  import org.apache.spark.annotation.Since
  import org.apache.spark.ml.{Estimator, Model, Transformer}
  import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
  import org.apache.spark.ml.param._
- import org.apache.spark.ml.param.shared._
+ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
  import org.apache.spark.ml.util._
  import org.apache.spark.sql.{DataFrame, Dataset}
  import org.apache.spark.sql.functions._
@@ -34,8 +36,25 @@ import org.apache.spark.util.collection.OpenHashMap

  /**
   * Base trait for [[StringIndexer]] and [[StringIndexerModel]].
   */
- private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol
-   with HasHandleInvalid {
+ private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol {
+   val SKIP_UNSEEN_LABEL: String = "skip"
+   val ERROR_UNSEEN_LABEL: String = "error"
+   val KEEP_UNSEEN_LABEL: String = "keep"
+   val supportedHandleInvalids: Array[String] =
+     Array(SKIP_UNSEEN_LABEL, ERROR_UNSEEN_LABEL, KEEP_UNSEEN_LABEL)
+
+   /**
+    * Param for how to handle unseen labels. Options are 'skip' (filter out rows with
+    * unseen labels), 'error' (throw an error), or 'keep' (map unseen labels to
+    * index numLabels).
+    * Default: "error"
+    * @group param
+    */
+   @Since("2.1.0")
+   val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle " +
+     "unseen labels. Options are 'skip' (filter out rows with unseen labels), " +
+     "'error' (throw an error), or 'keep' (map unseen labels to index numLabels).",
+     ParamValidators.inArray(supportedHandleInvalids))

   /** Validates and transforms the input schema. */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
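As a quick, hedged illustration of the guard above: `ParamValidators.inArray` rejects any value outside `supportedHandleInvalids` at the time the param is set (the value `"drop"` below is a deliberately unsupported example, not part of the PR):

```scala
val indexer = new StringIndexer()
indexer.setHandleInvalid("keep")      // accepted: "skip", "error", or "keep"

// Any other value fails validation when set, e.g.:
// indexer.setHandleInvalid("drop")   // throws IllegalArgumentException
```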
@@ -70,11 +89,6 @@ class StringIndexer @Since("1.4.0") (

   @Since("1.4.0")
   def this() = this(Identifiable.randomUID("strIdx"))

-  /** @group setParam */
-  @Since("1.6.0")
-  def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
-  setDefault(handleInvalid, "error")
-
   /** @group setParam */
   @Since("1.4.0")
   def setInputCol(value: String): this.type = set(inputCol, value)
@@ -83,6 +97,15 @@ class StringIndexer @Since("1.4.0") (

   @Since("1.4.0")
   def setOutputCol(value: String): this.type = set(outputCol, value)

+  /** @group getParam */
+  @Since("2.1.0")
+  def getHandleInvalid: String = $(handleInvalid)
+
+  /** @group setParam */
+  @Since("2.1.0")
+  def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
+  setDefault(handleInvalid, ERROR_UNSEEN_LABEL)
+
   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): StringIndexerModel = {
     transformSchema(dataset.schema, logging = true)
@@ -141,11 +164,6 @@ class StringIndexerModel (

     map
   }

-  /** @group setParam */
-  @Since("1.6.0")
-  def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
-  setDefault(handleInvalid, "error")
-
   /** @group setParam */
   @Since("1.4.0")
   def setInputCol(value: String): this.type = set(inputCol, value)
@@ -154,6 +172,15 @@ class StringIndexerModel (

   @Since("1.4.0")
   def setOutputCol(value: String): this.type = set(outputCol, value)

+  /** @group getParam */
+  @Since("2.1.0")
+  def getHandleInvalid: String = $(handleInvalid)
+
+  /** @group setParam */
+  @Since("2.1.0")
+  def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
+  setDefault(handleInvalid, ERROR_UNSEEN_LABEL)
+
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
     if (!dataset.schema.fieldNames.contains($(inputCol))) {
@@ -163,25 +190,28 @@ class StringIndexerModel (

     }
     transformSchema(dataset.schema, logging = true)

+    val metadata = NominalAttribute.defaultAttr
+      .withName($(outputCol)).withValues(labels).toMetadata()
+    // If we are skipping invalid records, filter them out.
+    val (filteredDataset, keepInvalid) = getHandleInvalid match {
> **Contributor:** Minor style comment: instead of `keepInvalid`, do you think `indexInvalid` might be a better name?
>
> **Contributor:** Actually, I think returning a tuple here just makes things more confusing. Maybe you can move the check outside of the match.
>
> **Member:** I'm OK with returning a tuple; that's a common pattern. Do you mean that it makes the code inside the match statement confusing?
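For reference, one possible reading of the second suggestion (compute the `keep` check outside the match so the match returns only the dataset); this is a sketch of the reviewer's idea, not code from the PR:

```scala
// keepInvalid is computed once, up front, instead of inside the match.
val keepInvalid = getHandleInvalid == KEEP_UNSEEN_LABEL
val filteredDataset = getHandleInvalid match {
  case SKIP_UNSEEN_LABEL =>
    val filterer = udf { label: String => labelToIndex.contains(label) }
    dataset.where(filterer(dataset($(inputCol))))
  case _ => dataset
}
```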
+      case SKIP_UNSEEN_LABEL =>
+        val filterer = udf { label: String =>
+          labelToIndex.contains(label)
+        }
+        (dataset.where(filterer(dataset($(inputCol)))), false)
+      case _ => (dataset, getHandleInvalid == KEEP_UNSEEN_LABEL)
+    }
+
     val indexer = udf { label: String =>
       if (labelToIndex.contains(label)) {
         labelToIndex(label)
+      } else if (keepInvalid) {
+        labels.length
       } else {
         throw new SparkException(s"Unseen label: $label.")
       }
     }

-    val metadata = NominalAttribute.defaultAttr
-      .withName($(outputCol)).withValues(labels).toMetadata()
-    // If we are skipping invalid records, filter them out.
-    val filteredDataset = getHandleInvalid match {
-      case "skip" =>
-        val filterer = udf { label: String =>
-          labelToIndex.contains(label)
-        }
-        dataset.where(filterer(dataset($(inputCol))))
-      case _ => dataset
-    }
     filteredDataset.select(col("*"),
       indexer(dataset($(inputCol)).cast(StringType)).as($(outputCol), metadata))
   }
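To make the per-label semantics of the new `transform` concrete, here is a self-contained plain-Scala mirror of the dispatch. It is an illustrative sketch using an ordinary `Map` rather than the `OpenHashMap`-based Spark code path, and `indexLabel` is a hypothetical helper, not part of the PR:

```scala
// Returns Some(index) for rows that survive, None for rows that "skip" filters out.
def indexLabel(label: String,
    labelToIndex: Map[String, Double],
    handleInvalid: String): Option[Double] = {
  labelToIndex.get(label) match {
    case Some(index) => Some(index)                     // seen label: fitted index
    case None => handleInvalid match {
      case "skip"  => None                              // drop the row
      case "keep"  => Some(labelToIndex.size.toDouble)  // extra bucket at numLabels
      case "error" => throw new RuntimeException(s"Unseen label: $label.")
    }
  }
}

// Labels fit as a -> 0.0, b -> 1.0; "c" is unseen, so "keep" maps it to 2.0.
val labelToIndex = Map("a" -> 0.0, "b" -> 1.0)
assert(indexLabel("c", labelToIndex, "keep") == Some(2.0))
assert(indexLabel("c", labelToIndex, "skip") == None)
```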
mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -64,7 +64,7 @@ class StringIndexerSuite

   test("StringIndexerUnseen") {
     val data = Seq((0, "a"), (1, "b"), (4, "b"))
-    val data2 = Seq((0, "a"), (1, "b"), (2, "c"))
+    val data2 = Seq((0, "a"), (1, "b"), (2, "c"), (3, "d"))
     val df = data.toDF("id", "label")
     val df2 = data2.toDF("id", "label")
     val indexer = new StringIndexer()
@@ -75,22 +75,32 @@ class StringIndexerSuite

     intercept[SparkException] {
       indexer.transform(df2).collect()
     }
-    val indexerSkipInvalid = new StringIndexer()
-      .setInputCol("label")
-      .setOutputCol("labelIndex")
-      .setHandleInvalid("skip")
-      .fit(df)

+    indexer.setHandleInvalid("skip")
     // Verify that we skip the c record
-    val transformed = indexerSkipInvalid.transform(df2)
-    val attr = Attribute.fromStructField(transformed.schema("labelIndex"))
+    var transformed = indexer.transform(df2)
+    var attr = Attribute.fromStructField(transformed.schema("labelIndex"))
       .asInstanceOf[NominalAttribute]
     assert(attr.values.get === Array("b", "a"))
-    val output = transformed.select("id", "labelIndex").rdd.map { r =>
+    val outputSkip = transformed.select("id", "labelIndex").rdd.map { r =>
       (r.getInt(0), r.getDouble(1))
     }.collect().toSet
     // a -> 1, b -> 0
-    val expected = Set((0, 1.0), (1, 0.0))
-    assert(output === expected)
+    val expectedSkip = Set((0, 1.0), (1, 0.0))
+    assert(outputSkip === expectedSkip)
+
+    indexer.setHandleInvalid("keep")
+    // Verify that we keep the unseen records
+    transformed = indexer.transform(df2)
+    attr = Attribute.fromStructField(transformed.schema("labelIndex"))
+      .asInstanceOf[NominalAttribute]
+    assert(attr.values.get === Array("b", "a"))
+    val outputKeep = transformed.select("id", "labelIndex").rdd.map { r =>
+      (r.getInt(0), r.getDouble(1))
+    }.collect().toSet
+    // a -> 1, b -> 0; the unseen labels c and d both map to 2 (numLabels)
+    val expectedKeep = Set((0, 1.0), (1, 0.0), (2, 2.0), (3, 2.0))
+    assert(outputKeep === expectedKeep)
   }

   test("StringIndexer with a numeric input column") {
> This change is not correct, except when keeping invalid ones.