Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,28 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha
@Since("1.6.0")
def getHandleInvalid: String = $(handleInvalid)

/**
* Param for how to order labels of string column. The first label after ordering is assigned
* an index of 0.
* Options are:
* - 'freq_desc': descending order by label frequency (most frequent label assigned 0)
* - 'freq_asc': ascending order by label frequency (least frequent label assigned 0)
* - 'alphabet_desc': descending alphabetical order
* - 'alphabet_asc': ascending alphabetical order
* Default is 'freq_desc'.
*
* @group param
*/
@Since("2.2.0")
final val stringOrderType: Param[String] = new Param(this, "stringOrderType",
"The method used to order values of input column. " +
s"Supported options: ${StringIndexer.supportedStringOrderType.mkString(", ")}.",
(value: String) => StringIndexer.supportedStringOrderType.contains(value.toLowerCase))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use ParamValidators.inArray?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@viirya ParamValidators.inArray does not allow case-insensitive validation, does it?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I originally thought you'd change to case-sensitive. It looks good to me.


/** @group getParam */
@Since("2.2.0")
def getStringOrderType: String = $(stringOrderType)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I checked other ML classes. Looks like for a case-insensitive setting, we may do toLowerCase in its public API:

def getStringOrderType: String = $(stringOrderType).toLowerCase

And you can use getStringOrderType below instead of $(stringOrderType).toLowerCase in fit.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@viirya Which ML classes were you referring to? I was told not to change the raw values in the getters in other PRs #16675.


/** Validates and transforms the input schema. */
protected def validateAndTransformSchema(schema: StructType): StructType = {
val inputColName = $(inputCol)
Expand All @@ -79,8 +101,9 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha
/**
* A label indexer that maps a string column of labels to an ML column of label indices.
* If the input column is numeric, we cast it to string and index the string values.
* The indices are in [0, numLabels), ordered by label frequencies.
* So the most frequent label gets index 0.
* The indices are in [0, numLabels). By default, this is ordered by label frequencies
* so the most frequent label gets index 0. The ordering behavior is controlled by
* setting stringOrderType.
*
* @see `IndexToString` for the inverse transformation
*/
Expand All @@ -96,6 +119,11 @@ class StringIndexer @Since("1.4.0") (
@Since("1.6.0")
def setHandleInvalid(value: String): this.type = set(handleInvalid, value)

/** @group setParam */
@Since("2.2.0")
def setStringOrderType(value: String): this.type = set(stringOrderType, value)
setDefault(stringOrderType, StringIndexer.FREQ_DESC)

/** @group setParam */
@Since("1.4.0")
def setInputCol(value: String): this.type = set(inputCol, value)
Expand All @@ -107,11 +135,15 @@ class StringIndexer @Since("1.4.0") (
@Since("2.0.0")
override def fit(dataset: Dataset[_]): StringIndexerModel = {
transformSchema(dataset.schema, logging = true)
val counts = dataset.na.drop(Array($(inputCol))).select(col($(inputCol)).cast(StringType))
.rdd
.map(_.getString(0))
.countByValue()
val labels = counts.toSeq.sortBy(-_._2).map(_._1).toArray
val values = dataset.na.drop(Array($(inputCol)))
.select(col($(inputCol)).cast(StringType))
.rdd.map(_.getString(0))
val labels = $(stringOrderType).toLowerCase match {
case StringIndexer.FREQ_DESC => values.countByValue().toSeq.sortBy(-_._2).map(_._1).toArray
case StringIndexer.FREQ_ASC => values.countByValue().toSeq.sortBy(_._2).map(_._1).toArray
case StringIndexer.ALPHABET_DESC => values.distinct.collect.sortWith(_ > _)
case StringIndexer.ALPHABET_ASC => values.distinct.collect.sortWith(_ < _)
}
copyValues(new StringIndexerModel(uid, labels).setParent(this))
}

Expand All @@ -131,6 +163,12 @@ object StringIndexer extends DefaultParamsReadable[StringIndexer] {
private[feature] val KEEP_INVALID: String = "keep"
private[feature] val supportedHandleInvalids: Array[String] =
Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID)
private[feature] val FREQ_DESC: String = "freq_desc"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there any prior standard for these names like freq_desc?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@felixcheung I did not find any prior standard, and am open to suggestion for better names.
Maybe better use frequency_desc or count_desc?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gatorsmile thought?

private[feature] val FREQ_ASC: String = "freq_asc"
private[feature] val ALPHABET_DESC: String = "alphabet_desc"
private[feature] val ALPHABET_ASC: String = "alphabet_asc"
Copy link
Member

@gatorsmile gatorsmile May 9, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Normally, we do not use underscore in the names. lowerCamelCase is our rules for naming.

Thanks for ping me, @felixcheung

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gatorsmile Thanks much for the suggestion. Changed them to lowerCamelCase.
@felixcheung Any additional suggestions?

private[feature] val supportedStringOrderType: Array[String] =
Array(FREQ_DESC, FREQ_ASC, ALPHABET_DESC, ALPHABET_ASC)

@Since("1.6.0")
override def load(path: String): StringIndexer = super.load(path)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -291,4 +291,27 @@ class StringIndexerSuite
NominalAttribute.decodeStructField(transformed.schema("labelIndex"), preserveName = true)
assert(attrs.name.nonEmpty && attrs.name.get === "labelIndex")
}

test("StringIndexer order types") {
val data = Seq((0, "b"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "b"))
val df = data.toDF("id", "label")
val indexer = new StringIndexer()
.setInputCol("label")
.setOutputCol("labelIndex")

val expected = Seq(Set((0, 0.0), (1, 0.0), (2, 2.0), (3, 1.0), (4, 1.0), (5, 0.0)),
Set((0, 2.0), (1, 2.0), (2, 0.0), (3, 1.0), (4, 1.0), (5, 2.0)),
Set((0, 1.0), (1, 1.0), (2, 0.0), (3, 2.0), (4, 2.0), (5, 1.0)),
Set((0, 1.0), (1, 1.0), (2, 2.0), (3, 0.0), (4, 0.0), (5, 1.0)))

var idx = 0
for (orderType <- StringIndexer.supportedStringOrderType) {
val transformed = indexer.setStringOrderType(orderType).fit(df).transform(df)
val output = transformed.select("id", "labelIndex").rdd.map { r =>
(r.getInt(0), r.getDouble(1))
}.collect().toSet
assert(output === expected(idx))
idx += 1
}
}
}