-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-20619][ML] StringIndexer supports multiple ways to order label #17879
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
ffd0cfc
97e020f
ba34043
ff9b1d6
07198d9
6bbe7df
53381ea
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -59,6 +59,28 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha | |
| @Since("1.6.0") | ||
| def getHandleInvalid: String = $(handleInvalid) | ||
|
|
||
| /** | ||
| * Param for how to order labels of string column. The first label after ordering is assigned | ||
| * an index of 0. | ||
| * Options are: | ||
| * - 'freq_desc': descending order by label frequency (most frequent label assigned 0) | ||
| * - 'freq_asc': ascending order by label frequency (least frequent label assigned 0) | ||
| * - 'alphabet_desc': descending alphabetical order | ||
| * - 'alphabet_asc': ascending alphabetical order | ||
| * Default is 'freq_desc'. | ||
| * | ||
| * @group param | ||
| */ | ||
| @Since("2.2.0") | ||
| final val stringOrderType: Param[String] = new Param(this, "stringOrderType", | ||
| "The method used to order values of input column. " + | ||
| s"Supported options: ${StringIndexer.supportedStringOrderType.mkString(", ")}.", | ||
| (value: String) => StringIndexer.supportedStringOrderType.contains(value.toLowerCase)) | ||
|
|
||
| /** @group getParam */ | ||
| @Since("2.2.0") | ||
| def getStringOrderType: String = $(stringOrderType) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I checked other ML classes. Looks like for a case-insensitive setting, we may do toLowerCase in its public API: And you can use
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
| /** Validates and transforms the input schema. */ | ||
| protected def validateAndTransformSchema(schema: StructType): StructType = { | ||
| val inputColName = $(inputCol) | ||
|
|
@@ -79,8 +101,9 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha | |
| /** | ||
| * A label indexer that maps a string column of labels to an ML column of label indices. | ||
| * If the input column is numeric, we cast it to string and index the string values. | ||
| * The indices are in [0, numLabels), ordered by label frequencies. | ||
| * So the most frequent label gets index 0. | ||
| * The indices are in [0, numLabels). By default, this is ordered by label frequencies | ||
| * so the most frequent label gets index 0. The ordering behavior is controlled by | ||
| * setting stringOrderType. | ||
| * | ||
| * @see `IndexToString` for the inverse transformation | ||
| */ | ||
|
|
@@ -96,6 +119,11 @@ class StringIndexer @Since("1.4.0") ( | |
| @Since("1.6.0") | ||
| def setHandleInvalid(value: String): this.type = set(handleInvalid, value) | ||
|
|
||
| /** @group setParam */ | ||
| @Since("2.2.0") | ||
| def setStringOrderType(value: String): this.type = set(stringOrderType, value) | ||
| setDefault(stringOrderType, StringIndexer.FREQ_DESC) | ||
|
|
||
| /** @group setParam */ | ||
| @Since("1.4.0") | ||
| def setInputCol(value: String): this.type = set(inputCol, value) | ||
|
|
@@ -107,11 +135,15 @@ class StringIndexer @Since("1.4.0") ( | |
| @Since("2.0.0") | ||
| override def fit(dataset: Dataset[_]): StringIndexerModel = { | ||
| transformSchema(dataset.schema, logging = true) | ||
| val counts = dataset.na.drop(Array($(inputCol))).select(col($(inputCol)).cast(StringType)) | ||
| .rdd | ||
| .map(_.getString(0)) | ||
| .countByValue() | ||
| val labels = counts.toSeq.sortBy(-_._2).map(_._1).toArray | ||
| val values = dataset.na.drop(Array($(inputCol))) | ||
| .select(col($(inputCol)).cast(StringType)) | ||
| .rdd.map(_.getString(0)) | ||
| val labels = $(stringOrderType).toLowerCase match { | ||
| case StringIndexer.FREQ_DESC => values.countByValue().toSeq.sortBy(-_._2).map(_._1).toArray | ||
| case StringIndexer.FREQ_ASC => values.countByValue().toSeq.sortBy(_._2).map(_._1).toArray | ||
| case StringIndexer.ALPHABET_DESC => values.distinct.collect.sortWith(_ > _) | ||
| case StringIndexer.ALPHABET_ASC => values.distinct.collect.sortWith(_ < _) | ||
| } | ||
| copyValues(new StringIndexerModel(uid, labels).setParent(this)) | ||
| } | ||
|
|
||
|
|
@@ -131,6 +163,12 @@ object StringIndexer extends DefaultParamsReadable[StringIndexer] { | |
| private[feature] val KEEP_INVALID: String = "keep" | ||
| private[feature] val supportedHandleInvalids: Array[String] = | ||
| Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID) | ||
| private[feature] val FREQ_DESC: String = "freq_desc" | ||
|
||
| private[feature] val FREQ_ASC: String = "freq_asc" | ||
| private[feature] val ALPHABET_DESC: String = "alphabet_desc" | ||
| private[feature] val ALPHABET_ASC: String = "alphabet_asc" | ||
|
||
| private[feature] val supportedStringOrderType: Array[String] = | ||
| Array(FREQ_DESC, FREQ_ASC, ALPHABET_DESC, ALPHABET_ASC) | ||
|
|
||
| @Since("1.6.0") | ||
| override def load(path: String): StringIndexer = super.load(path) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use
ParamValidators.inArray?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@viirya
ParamValidators.inArraydoes not allow case-insensitive validation, does it?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, I originally thought you'd change to case-sensitive. It looks good to me.