-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-17017][Follow-up][ML] Refactor of ChiSqSelector and add ML Python API. #15214
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -27,7 +27,7 @@ import org.apache.spark.ml.param._ | |
| import org.apache.spark.ml.param.shared._ | ||
| import org.apache.spark.ml.util._ | ||
| import org.apache.spark.mllib.feature | ||
| import org.apache.spark.mllib.feature.ChiSqSelectorType | ||
| import org.apache.spark.mllib.feature.{ChiSqSelector => OldChiSqSelector} | ||
| import org.apache.spark.mllib.linalg.{Vectors => OldVectors} | ||
| import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} | ||
| import org.apache.spark.rdd.RDD | ||
|
|
@@ -44,7 +44,9 @@ private[feature] trait ChiSqSelectorParams extends Params | |
| /** | ||
| * Number of features that selector will select (ordered by statistic value descending). If the | ||
| * number of features is less than numTopFeatures, then this will select all features. | ||
| * Only applicable when selectorType = "kbest". | ||
| * The default value of numTopFeatures is 50. | ||
| * | ||
| * @group param | ||
| */ | ||
| final val numTopFeatures = new IntParam(this, "numTopFeatures", | ||
|
|
@@ -56,6 +58,11 @@ private[feature] trait ChiSqSelectorParams extends Params | |
| /** @group getParam */ | ||
| def getNumTopFeatures: Int = $(numTopFeatures) | ||
|
|
||
| /** | ||
| * Percentile of features that selector will select, ordered by statistics value descending. | ||
| * Only applicable when selectorType = "percentile". | ||
| * Default value is 0.1. | ||
| */ | ||
| final val percentile = new DoubleParam(this, "percentile", | ||
| "Percentile of features that selector will select, ordered by statistics value descending.", | ||
| ParamValidators.inRange(0, 1)) | ||
|
|
@@ -64,38 +71,40 @@ private[feature] trait ChiSqSelectorParams extends Params | |
| /** @group getParam */ | ||
| def getPercentile: Double = $(percentile) | ||
|
|
||
| final val alpha = new DoubleParam(this, "alpha", | ||
| "The highest p-value for features to be kept.", | ||
| /** | ||
| * The highest p-value for features to be kept. | ||
| * Only applicable when selectorType = "fpr". | ||
| * Default value is 0.05. | ||
| */ | ||
| final val alpha = new DoubleParam(this, "alpha", "The highest p-value for features to be kept.", | ||
| ParamValidators.inRange(0, 1)) | ||
| setDefault(alpha -> 0.05) | ||
|
|
||
| /** @group getParam */ | ||
| def getAlpha: Double = $(alpha) | ||
|
|
||
| /** | ||
| * The ChiSqSelector supports KBest, Percentile, FPR selection, | ||
| * which is the same as ChiSqSelectorType defined in MLLIB. | ||
| * when call setNumTopFeatures, the selectorType is set to KBest | ||
| * when call setPercentile, the selectorType is set to Percentile | ||
| * when call setAlpha, the selectorType is set to FPR | ||
| * The selector type of the ChiSqSelector. | ||
| * Supported options: "kbest" (default), "percentile" and "fpr". | ||
| */ | ||
| final val selectorType = new Param[String](this, "selectorType", | ||
| "ChiSqSelector Type: KBest, Percentile, FPR") | ||
| setDefault(selectorType -> ChiSqSelectorType.KBest.toString) | ||
| "The selector type of the ChisqSelector. " + | ||
| "Supported options: kbest (default), percentile and fpr.", | ||
| ParamValidators.inArray[String](OldChiSqSelector.supportedSelectorTypes.toArray)) | ||
| setDefault(selectorType -> OldChiSqSelector.KBest) | ||
|
|
||
| /** @group getParam */ | ||
| def getChiSqSelectorType: String = $(selectorType) | ||
| def getSelectorType: String = $(selectorType) | ||
| } | ||
|
|
||
| /** | ||
| * Chi-Squared feature selection, which selects categorical features to use for predicting a | ||
| * categorical label. | ||
| * The selector supports three selection methods: `KBest`, `Percentile` and `FPR`. | ||
| * `KBest` chooses the `k` top features according to a chi-squared test. | ||
| * `Percentile` is similar but chooses a fraction of all features instead of a fixed number. | ||
| * `FPR` chooses all features whose false positive rate meets some threshold. | ||
| * By default, the selection method is `KBest`, the default number of top features is 50. | ||
| * User can use setNumTopFeatures, setPercentile and setAlpha to set different selection methods. | ||
| * The selector supports three selection methods: `kbest`, `percentile` and `fpr`. | ||
| * `kbest` chooses the `k` top features according to a chi-squared test. | ||
| * `percentile` is similar but chooses a fraction of all features instead of a fixed number. | ||
| * `fpr` chooses all features whose false positive rate meets some threshold. | ||
| * By default, the selection method is `kbest`, the default number of top features is 50. | ||
| */ | ||
| @Since("1.6.0") | ||
| final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: String) | ||
|
|
@@ -104,24 +113,21 @@ final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: Str | |
| @Since("1.6.0") | ||
| def this() = this(Identifiable.randomUID("chiSqSelector")) | ||
|
|
||
| /** @group setParam */ | ||
| @Since("2.1.0") | ||
| def setSelectorType(value: String): this.type = set(selectorType, value) | ||
|
|
||
| /** @group setParam */ | ||
| @Since("1.6.0") | ||
| def setNumTopFeatures(value: Int): this.type = { | ||
| set(selectorType, ChiSqSelectorType.KBest.toString) | ||
| set(numTopFeatures, value) | ||
| } | ||
| def setNumTopFeatures(value: Int): this.type = set(numTopFeatures, value) | ||
|
|
||
| /** @group setParam */ | ||
| @Since("2.1.0") | ||
| def setPercentile(value: Double): this.type = { | ||
| set(selectorType, ChiSqSelectorType.Percentile.toString) | ||
| set(percentile, value) | ||
| } | ||
| def setPercentile(value: Double): this.type = set(percentile, value) | ||
|
|
||
| /** @group setParam */ | ||
| @Since("2.1.0") | ||
| def setAlpha(value: Double): this.type = { | ||
| set(selectorType, ChiSqSelectorType.FPR.toString) | ||
| set(alpha, value) | ||
| } | ||
| def setAlpha(value: Double): this.type = set(alpha, value) | ||
|
|
||
| /** @group setParam */ | ||
| @Since("1.6.0") | ||
|
|
@@ -143,13 +149,13 @@ final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: Str | |
| case Row(label: Double, features: Vector) => | ||
| OldLabeledPoint(label, OldVectors.fromML(features)) | ||
| } | ||
| var selector = new feature.ChiSqSelector() | ||
| ChiSqSelectorType.withName($(selectorType)) match { | ||
| case ChiSqSelectorType.KBest => | ||
| val selector = new feature.ChiSqSelector() | ||
| $(selectorType) match { | ||
| case OldChiSqSelector.KBest => | ||
| selector.setNumTopFeatures($(numTopFeatures)) | ||
| case ChiSqSelectorType.Percentile => | ||
| case OldChiSqSelector.Percentile => | ||
| selector.setPercentile($(percentile)) | ||
| case ChiSqSelectorType.FPR => | ||
| case OldChiSqSelector.FPR => | ||
| selector.setAlpha($(alpha)) | ||
|
||
| case errorType => | ||
| throw new IllegalStateException(s"Unknown ChiSqSelector Type: $errorType") | ||
|
|
@@ -160,6 +166,12 @@ final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: Str | |
|
|
||
| @Since("1.6.0") | ||
| override def transformSchema(schema: StructType): StructType = { | ||
| val otherPairs = OldChiSqSelector.supportedTypeAndParamPairs.filter(_._1 == $(selectorType)) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. == or != ?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Sorry, typo. |
||
| otherPairs.foreach { case (_, paramName: String) => | ||
| if (isSet(getParam(paramName))) { | ||
| logWarning(s"Param $paramName will take no effect when selector type = ${$(selectorType)}.") | ||
| } | ||
| } | ||
| SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) | ||
| SchemaUtils.checkNumericType(schema, $(labelCol)) | ||
| SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -76,7 +76,7 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext { | |
| LabeledPoint(1.0, Vectors.dense(Array(4.0))), | ||
| LabeledPoint(1.0, Vectors.dense(Array(4.0))), | ||
| LabeledPoint(2.0, Vectors.dense(Array(9.0)))) | ||
| val model = new ChiSqSelector().setAlpha(0.1).fit(labeledDiscreteData) | ||
| val model = new ChiSqSelector().setSelectorType("fpr").setAlpha(0.1).fit(labeledDiscreteData) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. You should also do the same thing for https://github.com/apache/spark/blob/master/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Added an ML test case. |
||
| val filteredData = labeledDiscreteData.map { lp => | ||
| LabeledPoint(lp.label, model.transform(lp.features)) | ||
| }.collect().toSet | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you need to set SelectorType here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, updated.