-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-14659][ML] RFormula consistent with R when handling strings #17967
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
4d27123
6841c33
77fe864
a1be94c
698588e
147311b
5f31d31
341949c
24818a7
1a1e06c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -26,7 +26,7 @@ import org.apache.spark.annotation.{Experimental, Since} | |
| import org.apache.spark.ml.{Estimator, Model, Pipeline, PipelineModel, PipelineStage, Transformer} | ||
| import org.apache.spark.ml.attribute.AttributeGroup | ||
| import org.apache.spark.ml.linalg.VectorUDT | ||
| import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap} | ||
| import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, ParamValidators} | ||
| import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol} | ||
| import org.apache.spark.ml.util._ | ||
| import org.apache.spark.sql.{DataFrame, Dataset} | ||
|
|
@@ -37,6 +37,29 @@ import org.apache.spark.sql.types._ | |
| */ | ||
| private[feature] trait RFormulaBase extends HasFeaturesCol with HasLabelCol { | ||
|
|
||
| /** | ||
| * Param for how to order labels of a string column. The first label after ordering is assigned | ||
|
||
| * an index of 0. | ||
| * Options are: | ||
| * - 'frequencyDesc': descending order by label frequency (most frequent label assigned 0) | ||
| * - 'frequencyAsc': ascending order by label frequency (least frequent label assigned 0) | ||
| * - 'alphabetDesc': descending alphabetical order | ||
| * - 'alphabetAsc': ascending alphabetical order | ||
| * Default is 'frequencyDesc'. | ||
| * | ||
| * @group param | ||
| */ | ||
| @Since("2.3.0") | ||
| final val stringOrderType: Param[String] = new Param(this, "stringOrderType", | ||
|
||
| "How to order labels of string column. " + | ||
| "The first label after ordering is assigned an index of 0. " + | ||
| s"Supported options: ${StringIndexer.supportedStringOrderType.mkString(", ")}.", | ||
| ParamValidators.inArray(StringIndexer.supportedStringOrderType)) | ||
|
|
||
| /** @group getParam */ | ||
| @Since("2.3.0") | ||
| def getStringOrderType: String = $(stringOrderType) | ||
|
|
||
| protected def hasLabelCol(schema: StructType): Boolean = { | ||
| schema.map(_.name).contains($(labelCol)) | ||
| } | ||
|
|
@@ -125,6 +148,11 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) | |
| @Since("2.1.0") | ||
| def setForceIndexLabel(value: Boolean): this.type = set(forceIndexLabel, value) | ||
|
|
||
| /** @group setParam */ | ||
| @Since("2.3.0") | ||
| def setStringOrderType(value: String): this.type = set(stringOrderType, value) | ||
| setDefault(stringOrderType, StringIndexer.frequencyDesc) | ||
|
|
||
| /** Whether the formula specifies fitting an intercept. */ | ||
| private[ml] def hasIntercept: Boolean = { | ||
| require(isDefined(formula), "Formula must be defined first.") | ||
|
|
@@ -155,6 +183,7 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) | |
| encoderStages += new StringIndexer() | ||
| .setInputCol(term) | ||
| .setOutputCol(indexCol) | ||
| .setStringOrderType($(stringOrderType)) | ||
| prefixesToRewrite(indexCol + "_") = term + "_" | ||
| (term, indexCol) | ||
| case _ => | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -129,6 +129,90 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul | |
| assert(result.collect() === expected.collect()) | ||
| } | ||
|
|
||
| test("encodes string terms with string order type") { | ||
| val formula = new RFormula().setFormula("id ~ a + b") | ||
| val original = Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "aaz", 5)) | ||
| .toDF("id", "a", "b") | ||
|
|
||
| val expected = Seq( | ||
| Seq( | ||
| (1, "foo", 4, Vectors.dense(0.0, 0.0, 4.0), 1.0), | ||
| (2, "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 2.0), | ||
| (3, "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 3.0), | ||
| (4, "aaz", 5, Vectors.dense(0.0, 1.0, 5.0), 4.0) | ||
| ).toDF("id", "a", "b", "features", "label"), | ||
| Seq( | ||
| (1, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0), | ||
| (2, "bar", 4, Vectors.dense(0.0, 0.0, 4.0), 2.0), | ||
| (3, "bar", 5, Vectors.dense(0.0, 0.0, 5.0), 3.0), | ||
| (4, "aaz", 5, Vectors.dense(1.0, 0.0, 5.0), 4.0) | ||
| ).toDF("id", "a", "b", "features", "label"), | ||
| Seq( | ||
| (1, "foo", 4, Vectors.dense(1.0, 0.0, 4.0), 1.0), | ||
| (2, "bar", 4, Vectors.dense(0.0, 1.0, 4.0), 2.0), | ||
| (3, "bar", 5, Vectors.dense(0.0, 1.0, 5.0), 3.0), | ||
| (4, "aaz", 5, Vectors.dense(0.0, 0.0, 5.0), 4.0) | ||
| ).toDF("id", "a", "b", "features", "label"), | ||
| Seq( | ||
| (1, "foo", 4, Vectors.dense(0.0, 0.0, 4.0), 1.0), | ||
| (2, "bar", 4, Vectors.dense(0.0, 1.0, 4.0), 2.0), | ||
| (3, "bar", 5, Vectors.dense(0.0, 1.0, 5.0), 3.0), | ||
| (4, "aaz", 5, Vectors.dense(1.0, 0.0, 5.0), 4.0) | ||
| ).toDF("id", "a", "b", "features", "label") | ||
| ) | ||
|
|
||
| var idx = 0 | ||
| for (orderType <- StringIndexer.supportedStringOrderType) { | ||
| val model = formula.setStringOrderType(orderType).fit(original) | ||
| val result = model.transform(original) | ||
| val resultSchema = model.transformSchema(original.schema) | ||
| assert(result.schema.toString == resultSchema.toString) | ||
| assert(result.collect() === expected(idx).collect()) | ||
| idx += 1 | ||
| } | ||
| } | ||
|
|
||
| test("test consistency with R when encoding string terms") { | ||
| /* | ||
| R code: | ||
|
|
||
| df <- data.frame(id = c(1, 2, 3, 4), | ||
| a = c("foo", "bar", "bar", "aaz"), | ||
| b = c(4, 4, 5, 5)) | ||
| model.matrix(id ~ a + b, df)[, -1] | ||
|
|
||
| abar afoo b | ||
| 0 1 4 | ||
| 1 0 4 | ||
| 1 0 5 | ||
| 0 0 5 | ||
| */ | ||
| val original = Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "aaz", 5)) | ||
| .toDF("id", "a", "b") | ||
| val formula = new RFormula().setFormula("id ~ a + b") | ||
| .setStringOrderType(StringIndexer.alphabetDesc) | ||
|
|
||
| /* | ||
| Note that the category dropped after encoding is the same between R and Spark | ||
| (i.e., "aaz" is treated as the reference level). | ||
| However, the column order is still different: | ||
| R renders the columns in ascending alphabetical order ("bar", "foo"), while | ||
| RFormula renders the columns in descending alphabetical order ("foo", "bar"). | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. R and RFormula should behave consistently once you fix the issue I mentioned above. |
||
| */ | ||
| val expected = Seq( | ||
| (1, "foo", 4, Vectors.dense(1.0, 0.0, 4.0), 1.0), | ||
| (2, "bar", 4, Vectors.dense(0.0, 1.0, 4.0), 2.0), | ||
| (3, "bar", 5, Vectors.dense(0.0, 1.0, 5.0), 3.0), | ||
| (4, "aaz", 5, Vectors.dense(0.0, 0.0, 5.0), 4.0) | ||
| ).toDF("id", "a", "b", "features", "label") | ||
|
|
||
| val model = formula.fit(original) | ||
| val result = model.transform(original) | ||
| val resultSchema = model.transformSchema(original.schema) | ||
| assert(result.schema.toString == resultSchema.toString) | ||
| assert(result.collect() === expected.collect()) | ||
| } | ||
|
|
||
| test("index string label") { | ||
| val formula = new RFormula().setFormula("id ~ a + b") | ||
| val original = | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we add a comment explaining which option is consistent with R?