From ffd0cfc755586402c0f22e4458149eed16f98010 Mon Sep 17 00:00:00 2001 From: Wayne Zhang Date: Fri, 5 May 2017 23:54:41 -0700 Subject: [PATCH 1/7] StringIndexer supports multiple ways of label ordering --- .../spark/ml/feature/StringIndexer.scala | 48 ++++++++++++++++--- .../spark/ml/feature/StringIndexerSuite.scala | 23 +++++++++ 2 files changed, 64 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 99321bcc7cf9..d67e69cdbd39 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -59,6 +59,28 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha @Since("1.6.0") def getHandleInvalid: String = $(handleInvalid) + /** + * Param for how to order labels of string column. The first label after ordering is assigned + * an index of 0. + * Options are: + * - 'freq_desc': descending order by label frequency (most frequent label assigned 0) + * - 'freq_asc': ascending order by label frequency (least frequent label assigned 0) + * - 'alphabet_desc': descending alphabetical order + * - 'alphabet_asc': ascending alphabetical order + * Default is 'freq_desc'. + * + * @group param + */ + @Since("2.2.0") + final val stringOrderType: Param[String] = new Param(this, "stringOrderType", + "The method used to order values of input column. " + + s"Supported options: ${StringIndexer.supportedStringOrderType.mkString(", ")}.", + (value: String) => StringIndexer.supportedStringOrderType.contains(value.toLowerCase)) + + /** @group getParam */ + @Since("2.2.0") + def getStringOrderType: String = $(stringOrderType) + /** Validates and transforms the input schema. */ protected def validateAndTransformSchema(schema: StructType): StructType = { val inputColName = $(inputCol) @@ -79,8 +101,9 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha /** * A label indexer that maps a string column of labels to an ML column of label indices. * If the input column is numeric, we cast it to string and index the string values. - * The indices are in [0, numLabels), ordered by label frequencies. - * So the most frequent label gets index 0. + * The indices are in [0, numLabels). By default, this is ordered by label frequencies + * so the most frequent label gets index 0. The ordering behavior is controlled by + * setting stringOrderType. * * @see `IndexToString` for the inverse transformation */ @@ -96,6 +119,11 @@ class StringIndexer @Since("1.4.0") ( @Since("1.6.0") def setHandleInvalid(value: String): this.type = set(handleInvalid, value) + /** @group setParam */ + @Since("2.2.0") + def setStringOrderType(value: String): this.type = set(stringOrderType, value) + setDefault(stringOrderType, "freq_desc") + /** @group setParam */ @Since("1.4.0") def setInputCol(value: String): this.type = set(inputCol, value) @@ -107,11 +135,15 @@ class StringIndexer @Since("1.4.0") ( @Since("2.0.0") override def fit(dataset: Dataset[_]): StringIndexerModel = { transformSchema(dataset.schema, logging = true) - val counts = dataset.na.drop(Array($(inputCol))).select(col($(inputCol)).cast(StringType)) - .rdd - .map(_.getString(0)) - .countByValue() - val labels = counts.toSeq.sortBy(-_._2).map(_._1).toArray + val values = dataset.na.drop(Array($(inputCol))) + .select(col($(inputCol)).cast(StringType)) + .rdd.map(_.getString(0)) + val labels = $(stringOrderType) match { + case "freq_desc" => values.countByValue().toSeq.sortBy(-_._2).map(_._1).toArray + case "freq_asc" => values.countByValue().toSeq.sortBy(_._2).map(_._1).toArray + case "alphabet_desc" => values.distinct.collect.sortWith(_ > _) + case "alphabet_asc" => values.distinct.collect.sortWith(_ < _) + } copyValues(new StringIndexerModel(uid, labels).setParent(this)) } @@ -131,6 +163,8 @@ object StringIndexer extends DefaultParamsReadable[StringIndexer] { private[feature] val KEEP_INVALID: String = "keep" private[feature] val supportedHandleInvalids: Array[String] = Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID) + private[feature] val supportedStringOrderType: Array[String] = + Array("freq_desc", "freq_asc", "alphabet_desc", "alphabet_asc") @Since("1.6.0") override def load(path: String): StringIndexer = super.load(path) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 5634d4210f47..806a92760c8b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -291,4 +291,27 @@ class StringIndexerSuite NominalAttribute.decodeStructField(transformed.schema("labelIndex"), preserveName = true) assert(attrs.name.nonEmpty && attrs.name.get === "labelIndex") } + + test("StringIndexer order types") { + val data = Seq((0, "b"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "b")) + val df = data.toDF("id", "label") + val indexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("labelIndex") + + val expected = Seq(Set((0, 0.0), (1, 0.0), (2, 2.0), (3, 1.0), (4, 1.0), (5, 0.0)), + Set((0, 2.0), (1, 2.0), (2, 0.0), (3, 1.0), (4, 1.0), (5, 2.0)), + Set((0, 1.0), (1, 1.0), (2, 0.0), (3, 2.0), (4, 2.0), (5, 1.0)), + Set((0, 1.0), (1, 1.0), (2, 2.0), (3, 0.0), (4, 0.0), (5, 1.0))) + + var idx = 0 + for (orderType <- StringIndexer.supportedStringOrderType) { + val transformed = indexer.setStringOrderType(orderType).fit(df).transform(df) + val output = transformed.select("id", "labelIndex").rdd.map { r => + (r.getInt(0), r.getDouble(1)) + }.collect().toSet + assert(output === expected(idx)) + idx += 1 + } + } } From 97e020f4aba1afcf45c1b10de09d6fbba1551918 Mon Sep 17 00:00:00 2001 From: Wayne Zhang Date: Sat, 6 May 2017 01:49:05 -0700 Subject: [PATCH 2/7] address review comments and fix style --- .../spark/ml/feature/StringIndexer.scala | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index d67e69cdbd39..c1275c159f5f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -122,7 +122,7 @@ class StringIndexer @Since("1.4.0") ( /** @group setParam */ @Since("2.2.0") def setStringOrderType(value: String): this.type = set(stringOrderType, value) - setDefault(stringOrderType, "freq_desc") + setDefault(stringOrderType, StringIndexer.FREQ_DESC) /** @group setParam */ @Since("1.4.0") @@ -138,11 +138,11 @@ class StringIndexer @Since("1.4.0") ( val values = dataset.na.drop(Array($(inputCol))) .select(col($(inputCol)).cast(StringType)) .rdd.map(_.getString(0)) - val labels = $(stringOrderType) match { - case "freq_desc" => values.countByValue().toSeq.sortBy(-_._2).map(_._1).toArray - case "freq_asc" => values.countByValue().toSeq.sortBy(_._2).map(_._1).toArray - case "alphabet_desc" => values.distinct.collect.sortWith(_ > _) - case "alphabet_asc" => values.distinct.collect.sortWith(_ < _) + val labels = $(stringOrderType).toLowerCase match { + case StringIndexer.FREQ_DESC => values.countByValue().toSeq.sortBy(-_._2).map(_._1).toArray + case StringIndexer.FREQ_ASC => values.countByValue().toSeq.sortBy(_._2).map(_._1).toArray + case StringIndexer.ALPHABET_DESC => values.distinct.collect.sortWith(_ > _) + case StringIndexer.ALPHABET_ASC => values.distinct.collect.sortWith(_ < _) } copyValues(new StringIndexerModel(uid, labels).setParent(this)) } @@ -163,8 +163,12 @@ object StringIndexer extends DefaultParamsReadable[StringIndexer] { private[feature] val KEEP_INVALID: String = "keep" private[feature] val supportedHandleInvalids: Array[String] = Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID) + private[feature] val FREQ_DESC: String = "freq_desc" + private[feature] val FREQ_ASC: String = "freq_asc" + private[feature] val ALPHABET_DESC: String = "alphabet_desc" + private[feature] val ALPHABET_ASC: String = "alphabet_asc" private[feature] val supportedStringOrderType: Array[String] = - Array("freq_desc", "freq_asc", "alphabet_desc", "alphabet_asc") + Array(FREQ_DESC, FREQ_ASC, ALPHABET_DESC, ALPHABET_ASC) @Since("1.6.0") override def load(path: String): StringIndexer = super.load(path) From ba340437fee99f848dfa88eab2e10d87651eab0a Mon Sep 17 00:00:00 2001 From: Wayne Zhang Date: Mon, 8 May 2017 21:42:01 -0700 Subject: [PATCH 3/7] address comments- spell out freq and update annotation and toLowerCase --- .../spark/ml/feature/StringIndexer.scala | 29 ++++++++++--------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index c1275c159f5f..9840ca608fa6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -17,8 +17,9 @@ package org.apache.spark.ml.feature -import scala.language.existentials +import java.util.Locale +import scala.language.existentials import org.apache.hadoop.fs.Path import org.apache.spark.SparkException @@ -63,23 +64,25 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha * Param for how to order labels of string column. The first label after ordering is assigned * an index of 0. * Options are: - * - 'freq_desc': descending order by label frequency (most frequent label assigned 0) - * - 'freq_asc': ascending order by label frequency (least frequent label assigned 0) + * - 'frequency_desc': descending order by label frequency (most frequent label assigned 0) + * - 'frequency_asc': ascending order by label frequency (least frequent label assigned 0) * - 'alphabet_desc': descending alphabetical order * - 'alphabet_asc': ascending alphabetical order - * Default is 'freq_desc'. + * Default is 'frequency_desc'. * * @group param */ - @Since("2.2.0") + @Since("2.3.0") final val stringOrderType: Param[String] = new Param(this, "stringOrderType", - "The method used to order values of input column. " + - s"Supported options: ${StringIndexer.supportedStringOrderType.mkString(", ")}.", - (value: String) => StringIndexer.supportedStringOrderType.contains(value.toLowerCase)) + "how to order labels of string column. " + + "The first label after ordering is assigned an index of 0. " + + s"Supported options: ${StringIndexer.supportedStringOrderType.mkString(", ")}.", + (value: String) => StringIndexer.supportedStringOrderType + .contains(value.toLowerCase(Locale.ROOT))) /** @group getParam */ - @Since("2.2.0") - def getStringOrderType: String = $(stringOrderType) + @Since("2.3.0") + def getStringOrderType: String = $(stringOrderType).toLowerCase(Locale.ROOT) /** Validates and transforms the input schema. */ protected def validateAndTransformSchema(schema: StructType): StructType = { @@ -138,7 +141,7 @@ class StringIndexer @Since("1.4.0") ( val values = dataset.na.drop(Array($(inputCol))) .select(col($(inputCol)).cast(StringType)) .rdd.map(_.getString(0)) - val labels = $(stringOrderType).toLowerCase match { + val labels = this.getStringOrderType match { case StringIndexer.FREQ_DESC => values.countByValue().toSeq.sortBy(-_._2).map(_._1).toArray case StringIndexer.FREQ_ASC => values.countByValue().toSeq.sortBy(_._2).map(_._1).toArray case StringIndexer.ALPHABET_DESC => values.distinct.collect.sortWith(_ > _) @@ -163,8 +166,8 @@ object StringIndexer extends DefaultParamsReadable[StringIndexer] { private[feature] val KEEP_INVALID: String = "keep" private[feature] val supportedHandleInvalids: Array[String] = Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID) - private[feature] val FREQ_DESC: String = "freq_desc" - private[feature] val FREQ_ASC: String = "freq_asc" + private[feature] val FREQ_DESC: String = "frequency_desc" + private[feature] val FREQ_ASC: String = "frequency_asc" private[feature] val ALPHABET_DESC: String = "alphabet_desc" private[feature] val ALPHABET_ASC: String = "alphabet_asc" private[feature] val supportedStringOrderType: Array[String] = From ff9b1d66873eb8cad1a4a13f323555da2706a849 Mon Sep 17 00:00:00 2001 From: Wayne Zhang Date: Mon, 8 May 2017 21:52:20 -0700 Subject: [PATCH 4/7] fix style --- .../main/scala/org/apache/spark/ml/feature/StringIndexer.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 9840ca608fa6..504642d907b2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -20,6 +20,7 @@ package org.apache.spark.ml.feature import java.util.Locale import scala.language.existentials + import org.apache.hadoop.fs.Path import org.apache.spark.SparkException From 07198d9bb45a54d3c257ad37e772cc31154ffcb6 Mon Sep 17 00:00:00 2001 From: Wayne Zhang Date: Mon, 8 May 2017 21:59:56 -0700 Subject: [PATCH 5/7] fix annotation --- .../main/scala/org/apache/spark/ml/feature/StringIndexer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 504642d907b2..324c9d83b8c5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -124,7 +124,7 @@ class StringIndexer @Since("1.4.0") ( def setHandleInvalid(value: String): this.type = set(handleInvalid, value) /** @group setParam */ - @Since("2.2.0") + @Since("2.3.0") def setStringOrderType(value: String): this.type = set(stringOrderType, value) setDefault(stringOrderType, StringIndexer.FREQ_DESC) From 6bbe7df8d53ef461146cbb3a786ba78436f54af1 Mon Sep 17 00:00:00 2001 From: Wayne Zhang Date: Mon, 8 May 2017 22:39:02 -0700 Subject: [PATCH 6/7] use camel case --- .../spark/ml/feature/StringIndexer.scala | 41 ++++++++++--------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 324c9d83b8c5..446f3027cbcb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -65,11 +65,11 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha * Param for how to order labels of string column. The first label after ordering is assigned * an index of 0. * Options are: - * - 'frequency_desc': descending order by label frequency (most frequent label assigned 0) - * - 'frequency_asc': ascending order by label frequency (least frequent label assigned 0) - * - 'alphabet_desc': descending alphabetical order - * - 'alphabet_asc': ascending alphabetical order - * Default is 'frequency_desc'. + * - 'frequencyDesc': descending order by label frequency (most frequent label assigned 0) + * - 'frequencyAsc': ascending order by label frequency (least frequent label assigned 0) + * - 'alphabetDesc': descending alphabetical order + * - 'alphabetAsc': ascending alphabetical order + * Default is 'frequencyDesc'. * * @group param */ @@ -78,12 +78,11 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha "how to order labels of string column. " + "The first label after ordering is assigned an index of 0. " + s"Supported options: ${StringIndexer.supportedStringOrderType.mkString(", ")}.", - (value: String) => StringIndexer.supportedStringOrderType - .contains(value.toLowerCase(Locale.ROOT))) + ParamValidators.inArray(StringIndexer.supportedStringOrderType)) /** @group getParam */ @Since("2.3.0") - def getStringOrderType: String = $(stringOrderType).toLowerCase(Locale.ROOT) + def getStringOrderType: String = $(stringOrderType) /** Validates and transforms the input schema. */ protected def validateAndTransformSchema(schema: StructType): StructType = { @@ -107,7 +106,7 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha * If the input column is numeric, we cast it to string and index the string values. * The indices are in [0, numLabels). By default, this is ordered by label frequencies * so the most frequent label gets index 0. The ordering behavior is controlled by - * setting stringOrderType. + * setting `stringOrderType`. * * @see `IndexToString` for the inverse transformation */ @@ -126,7 +125,7 @@ class StringIndexer @Since("1.4.0") ( /** @group setParam */ @Since("2.3.0") def setStringOrderType(value: String): this.type = set(stringOrderType, value) - setDefault(stringOrderType, StringIndexer.FREQ_DESC) + setDefault(stringOrderType, StringIndexer.frequencyDesc) /** @group setParam */ @Since("1.4.0") @@ -142,11 +141,13 @@ class StringIndexer @Since("1.4.0") ( val values = dataset.na.drop(Array($(inputCol))) .select(col($(inputCol)).cast(StringType)) .rdd.map(_.getString(0)) - val labels = this.getStringOrderType match { - case StringIndexer.FREQ_DESC => values.countByValue().toSeq.sortBy(-_._2).map(_._1).toArray - case StringIndexer.FREQ_ASC => values.countByValue().toSeq.sortBy(_._2).map(_._1).toArray - case StringIndexer.ALPHABET_DESC => values.distinct.collect.sortWith(_ > _) - case StringIndexer.ALPHABET_ASC => values.distinct.collect.sortWith(_ < _) + val labels = $(stringOrderType) match { + case StringIndexer.frequencyDesc => values.countByValue().toSeq.sortBy(-_._2) + .map(_._1).toArray + case StringIndexer.frequencyAsc => values.countByValue().toSeq.sortBy(_._2) + .map(_._1).toArray + case StringIndexer.alphabetDesc => values.distinct.collect.sortWith(_ > _) + case StringIndexer.alphabetAsc => values.distinct.collect.sortWith(_ < _) } copyValues(new StringIndexerModel(uid, labels).setParent(this)) } @@ -167,12 +168,12 @@ object StringIndexer extends DefaultParamsReadable[StringIndexer] { private[feature] val KEEP_INVALID: String = "keep" private[feature] val supportedHandleInvalids: Array[String] = Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID) - private[feature] val FREQ_DESC: String = "frequency_desc" - private[feature] val FREQ_ASC: String = "frequency_asc" - private[feature] val ALPHABET_DESC: String = "alphabet_desc" - private[feature] val ALPHABET_ASC: String = "alphabet_asc" + private[feature] val frequencyDesc: String = "frequencyDesc" + private[feature] val frequencyAsc: String = "frequencyAsc" + private[feature] val alphabetDesc: String = "alphabetDesc" + private[feature] val alphabetAsc: String = "alphabetAsc" private[feature] val supportedStringOrderType: Array[String] = - Array(FREQ_DESC, FREQ_ASC, ALPHABET_DESC, ALPHABET_ASC) + Array(frequencyDesc, frequencyAsc, alphabetDesc, alphabetAsc) @Since("1.6.0") override def load(path: String): StringIndexer = super.load(path) From 53381ea6ba41cc26ed89a6fc42252f7126198d9f Mon Sep 17 00:00:00 2001 From: Wayne Zhang Date: Mon, 8 May 2017 22:40:10 -0700 Subject: [PATCH 7/7] remove extra import --- .../main/scala/org/apache/spark/ml/feature/StringIndexer.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 446f3027cbcb..b2dc4fcb6196 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -17,8 +17,6 @@ package org.apache.spark.ml.feature -import java.util.Locale - import scala.language.existentials import org.apache.hadoop.fs.Path