Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -39,20 +39,21 @@ import org.apache.spark.util.collection.OpenHashMap
private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol {

/**
* Param for how to handle unseen labels. Options are 'skip' (filter out rows with
* unseen labels), 'error' (throw an error), or 'keep' (put unseen labels in a special additional
* bucket, at index numLabels.
* Param for how to handle invalid data (unseen labels or NULL values).
* Options are 'skip' (filter out rows with invalid data),
* 'error' (throw an error), or 'keep' (put invalid data in a special additional
* bucket, at index numLabels).
* Default: "error"
* @group param
*/
@Since("1.6.0")
val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle " +
"unseen labels. Options are 'skip' (filter out rows with unseen labels), " +
"error (throw an error), or 'keep' (put unseen labels in a special additional bucket, " +
"at index numLabels).",
"invalid data (unseen labels or NULL values). " +
"Options are 'skip' (filter out rows with invalid data), error (throw an error), " +
"or 'keep' (put invalid data in a special additional bucket, at index numLabels).",
ParamValidators.inArray(StringIndexer.supportedHandleInvalids))

setDefault(handleInvalid, StringIndexer.ERROR_UNSEEN_LABEL)
setDefault(handleInvalid, StringIndexer.ERROR_INVALID)

/** @group getParam */
@Since("1.6.0")
Expand Down Expand Up @@ -106,7 +107,7 @@ class StringIndexer @Since("1.4.0") (
@Since("2.0.0")
override def fit(dataset: Dataset[_]): StringIndexerModel = {
transformSchema(dataset.schema, logging = true)
val counts = dataset.select(col($(inputCol)).cast(StringType))
val counts = dataset.na.drop(Array($(inputCol))).select(col($(inputCol)).cast(StringType))
.rdd
.map(_.getString(0))
.countByValue()
Expand All @@ -125,11 +126,11 @@ class StringIndexer @Since("1.4.0") (

@Since("1.6.0")
object StringIndexer extends DefaultParamsReadable[StringIndexer] {
private[feature] val SKIP_UNSEEN_LABEL: String = "skip"
private[feature] val ERROR_UNSEEN_LABEL: String = "error"
private[feature] val KEEP_UNSEEN_LABEL: String = "keep"
private[feature] val SKIP_INVALID: String = "skip"
private[feature] val ERROR_INVALID: String = "error"
private[feature] val KEEP_INVALID: String = "keep"
private[feature] val supportedHandleInvalids: Array[String] =
Array(SKIP_UNSEEN_LABEL, ERROR_UNSEEN_LABEL, KEEP_UNSEEN_LABEL)
Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID)

@Since("1.6.0")
override def load(path: String): StringIndexer = super.load(path)
Expand Down Expand Up @@ -188,30 +189,39 @@ class StringIndexerModel (
transformSchema(dataset.schema, logging = true)

val filteredLabels = getHandleInvalid match {
case StringIndexer.KEEP_UNSEEN_LABEL => labels :+ "__unknown"
case StringIndexer.KEEP_INVALID => labels :+ "__unknown"
case _ => labels
}

val metadata = NominalAttribute.defaultAttr
.withName($(outputCol)).withValues(filteredLabels).toMetadata()
// If we are skipping invalid records, filter them out.
val (filteredDataset, keepInvalid) = getHandleInvalid match {
case StringIndexer.SKIP_UNSEEN_LABEL =>
case StringIndexer.SKIP_INVALID =>
val filterer = udf { label: String =>
labelToIndex.contains(label)
}
(dataset.where(filterer(dataset($(inputCol)))), false)
case _ => (dataset, getHandleInvalid == StringIndexer.KEEP_UNSEEN_LABEL)
(dataset.na.drop(Array($(inputCol))).where(filterer(dataset($(inputCol)))), false)
case _ => (dataset, getHandleInvalid == StringIndexer.KEEP_INVALID)
}

val indexer = udf { label: String =>
if (labelToIndex.contains(label)) {
labelToIndex(label)
} else if (keepInvalid) {
labels.length
if (label == null) {
if (keepInvalid) {
labels.length
} else {
throw new SparkException("StringIndexer encountered NULL value. To handle or skip " +
"NULLS, try setting StringIndexer.handleInvalid.")
}
} else {
throw new SparkException(s"Unseen label: $label. To handle unseen labels, " +
s"set Param handleInvalid to ${StringIndexer.KEEP_UNSEEN_LABEL}.")
if (labelToIndex.contains(label)) {
labelToIndex(label)
} else if (keepInvalid) {
labels.length
} else {
throw new SparkException(s"Unseen label: $label. To handle unseen labels, " +
s"set Param handleInvalid to ${StringIndexer.KEEP_INVALID}.")
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,51 @@ class StringIndexerSuite
assert(output === expected)
}

test("StringIndexer with NULLs") {
val data: Seq[(Int, String)] = Seq((0, "a"), (1, "b"), (2, "b"), (3, null))
val data2: Seq[(Int, String)] = Seq((0, "a"), (1, "b"), (3, null))
val df = data.toDF("id", "label")
val df2 = data2.toDF("id", "label")

val indexer = new StringIndexer()
.setInputCol("label")
.setOutputCol("labelIndex")

withClue("StringIndexer should throw error when setHandleInvalid=error " +
"when given NULL values") {
intercept[SparkException] {
indexer.setHandleInvalid("error")
indexer.fit(df).transform(df2).collect()
}
}

indexer.setHandleInvalid("skip")
val transformedSkip = indexer.fit(df).transform(df2)
val attrSkip = Attribute
.fromStructField(transformedSkip.schema("labelIndex"))
.asInstanceOf[NominalAttribute]
assert(attrSkip.values.get === Array("b", "a"))
val outputSkip = transformedSkip.select("id", "labelIndex").rdd.map { r =>
(r.getInt(0), r.getDouble(1))
}.collect().toSet
// a -> 1, b -> 0
val expectedSkip = Set((0, 1.0), (1, 0.0))
assert(outputSkip === expectedSkip)

indexer.setHandleInvalid("keep")
val transformedKeep = indexer.fit(df).transform(df2)
val attrKeep = Attribute
.fromStructField(transformedKeep.schema("labelIndex"))
.asInstanceOf[NominalAttribute]
assert(attrKeep.values.get === Array("b", "a", "__unknown"))
val outputKeep = transformedKeep.select("id", "labelIndex").rdd.map { r =>
(r.getInt(0), r.getDouble(1))
}.collect().toSet
// a -> 1, b -> 0, null -> 2
val expectedKeep = Set((0, 1.0), (1, 0.0), (3, 2.0))
assert(outputKeep === expectedKeep)
}

test("StringIndexerModel should keep silent if the input column does not exist.") {
val indexerModel = new StringIndexerModel("indexer", Array("a", "b", "c"))
.setInputCol("label")
Expand Down