-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-17017][MLLIB][ML] add a chiSquare Selector based on False Positive Rate (FPR) test #14597
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
2adebe8
04053ca
7623563
3d6aecb
026ac85
5305709
1e8d83a
85a17dd
61b71c8
d7b2892
6699396
b8986b5
5c2e44c
0d3967a
1dc6a8e
9908871
bbccac7
c35bcf1
e8f03ed
ec74dde
6398f4c
6cc4c92
1d2f67f
6220dd5
ce3f8fb
88d2143
24f26f2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -54,6 +54,29 @@ private[feature] trait ChiSqSelectorParams extends Params | |
|
|
||
| /** @group getParam */ | ||
| def getNumTopFeatures: Int = $(numTopFeatures) | ||
|
|
||
| final val percentile = new DoubleParam(this, "percentile", | ||
| "Percentile of features that selector will select, ordered by statistics value descending.", | ||
| ParamValidators.gtEq(0)) | ||
| setDefault(percentile -> 10) | ||
|
||
|
|
||
| /** @group getParam */ | ||
| def getPercentile: Double = $(percentile) | ||
|
|
||
| final val alpha = new DoubleParam(this, "alpha", | ||
| "The highest p-value for features to be kept.", | ||
| ParamValidators.gtEq(0)) | ||
|
||
| setDefault(alpha -> 0.05) | ||
|
|
||
| /** @group getParam */ | ||
| def getAlpha: Double = $(alpha) | ||
|
|
||
| final val selectorType = new Param[String](this, "selectorType", | ||
| "ChiSqSelector Type: KBest, Percentile, Fpr") | ||
| setDefault(selectorType -> "KBest") | ||
|
|
||
| /** @group getParam */ | ||
| def getChiSqSelectorType: String = $(selectorType) | ||
| } | ||
|
|
||
| /** | ||
|
|
@@ -67,9 +90,27 @@ final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: Str | |
| @Since("1.6.0") | ||
| def this() = this(Identifiable.randomUID("chiSqSelector")) | ||
|
|
||
| @Since("2.1.0") | ||
| var chiSqSelector: feature.ChiSqSelector = null | ||
|
|
||
| /** @group setParam */ | ||
| @Since("1.6.0") | ||
| def setNumTopFeatures(value: Int): this.type = set(numTopFeatures, value) | ||
| @Since("2.1.0") | ||
|
||
| def setNumTopFeatures(value: Int): this.type = { | ||
| set(selectorType, "KBest") | ||
|
||
| set(numTopFeatures, value) | ||
| } | ||
|
|
||
| @Since("2.1.0") | ||
| def setPercentile(value: Double): this.type = { | ||
| set(selectorType, "Percentile") | ||
| set(percentile, value) | ||
| } | ||
|
|
||
| @Since("2.1.0") | ||
| def setAlpha(value: Double): this.type = { | ||
| set(selectorType, "Fpr") | ||
| set(alpha, value) | ||
| } | ||
|
|
||
| /** @group setParam */ | ||
| @Since("1.6.0") | ||
|
|
@@ -91,8 +132,38 @@ final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: Str | |
| case Row(label: Double, features: Vector) => | ||
| OldLabeledPoint(label, OldVectors.fromML(features)) | ||
| } | ||
| val chiSqSelector = new feature.ChiSqSelector($(numTopFeatures)).fit(input) | ||
| copyValues(new ChiSqSelectorModel(uid, chiSqSelector).setParent(this)) | ||
| $(selectorType) match { | ||
| case "KBest" => | ||
|
||
| chiSqSelector = new feature.ChiSqSelector().setNumTopFeatures($(numTopFeatures)) | ||
| case "Percentile" => | ||
| chiSqSelector = new feature.ChiSqSelector().setPercentile($(percentile)) | ||
| case "Fpr" => | ||
| chiSqSelector = new feature.ChiSqSelector().setAlpha($(alpha)) | ||
| case _ => throw new Exception("Unknown ChiSqSelector Type.") | ||
| } | ||
| val model = chiSqSelector.fit(input) | ||
| copyValues(new ChiSqSelectorModel(uid, model).setParent(this)) | ||
| } | ||
|
|
||
| @Since("2.1.0") | ||
| def selectKBest(value: Int): ChiSqSelectorModel = { | ||
| require(chiSqSelector != null, "ChiSqSelector has not been created.") | ||
| val model = chiSqSelector.selectKBest(value) | ||
| copyValues(new ChiSqSelectorModel(uid, model).setParent(this)) | ||
| } | ||
|
|
||
| @Since("2.1.0") | ||
| def selectPercentile(value: Double): ChiSqSelectorModel = { | ||
| require(chiSqSelector != null, "ChiSqSelector has not been created.") | ||
| val model = chiSqSelector.selectPercentile(value) | ||
| copyValues(new ChiSqSelectorModel(uid, model).setParent(this)) | ||
| } | ||
|
|
||
| @Since("2.1.0") | ||
| def selectFpr(value: Double): ChiSqSelectorModel = { | ||
| require(chiSqSelector != null, "ChiSqSelector has not been created.") | ||
| val model = chiSqSelector.selectFpr(value) | ||
| copyValues(new ChiSqSelectorModel(uid, model).setParent(this)) | ||
| } | ||
|
|
||
| @Since("1.6.0") | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -27,22 +27,27 @@ import org.apache.spark.annotation.Since | |
| import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} | ||
| import org.apache.spark.mllib.regression.LabeledPoint | ||
| import org.apache.spark.mllib.stat.Statistics | ||
| import org.apache.spark.mllib.stat.test.ChiSqTestResult | ||
| import org.apache.spark.mllib.util.{Loader, Saveable} | ||
| import org.apache.spark.rdd.RDD | ||
| import org.apache.spark.SparkContext | ||
| import org.apache.spark.sql.{Row, SparkSession} | ||
|
|
||
| @Since("2.1.0") | ||
| object ChiSqSelectorType extends Enumeration { | ||
|
||
| type SelectorType = Value | ||
| val KBest, Percentile, Fpr = Value | ||
| } | ||
|
|
||
| /** | ||
| * Chi Squared selector model. | ||
| * | ||
| * @param selectedFeatures list of indices to select (filter). Must be ordered asc | ||
| * @param selectedFeatures list of indices to select (filter). | ||
| */ | ||
| @Since("1.3.0") | ||
| class ChiSqSelectorModel @Since("1.3.0") ( | ||
| @Since("1.3.0") val selectedFeatures: Array[Int]) extends VectorTransformer with Saveable { | ||
|
|
||
| require(isSorted(selectedFeatures), "Array has to be sorted asc") | ||
|
|
||
| protected def isSorted(array: Array[Int]): Boolean = { | ||
| var i = 1 | ||
| val len = array.length | ||
|
|
@@ -69,21 +74,22 @@ class ChiSqSelectorModel @Since("1.3.0") ( | |
| * Preserves the order of filtered features the same as their indices are stored. | ||
| * Might be moved to Vector as .slice | ||
| * @param features vector | ||
| * @param filterIndices indices of features to filter, must be ordered asc | ||
| * @param filterIndices indices of features to filter | ||
| */ | ||
| private def compress(features: Vector, filterIndices: Array[Int]): Vector = { | ||
| val orderedIndices = filterIndices.sorted | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The sorted indices can be computed once and stored, rather than storing unsorted indices and re-sorting them. |
||
| features match { | ||
| case SparseVector(size, indices, values) => | ||
| val newSize = filterIndices.length | ||
| val newSize = orderedIndices.length | ||
| val newValues = new ArrayBuilder.ofDouble | ||
| val newIndices = new ArrayBuilder.ofInt | ||
| var i = 0 | ||
| var j = 0 | ||
| var indicesIdx = 0 | ||
| var filterIndicesIdx = 0 | ||
| while (i < indices.length && j < filterIndices.length) { | ||
| while (i < indices.length && j < orderedIndices.length) { | ||
| indicesIdx = indices(i) | ||
| filterIndicesIdx = filterIndices(j) | ||
| filterIndicesIdx = orderedIndices(j) | ||
| if (indicesIdx == filterIndicesIdx) { | ||
| newIndices += j | ||
| newValues += values(i) | ||
|
|
@@ -101,7 +107,7 @@ class ChiSqSelectorModel @Since("1.3.0") ( | |
| Vectors.sparse(newSize, newIndices.result(), newValues.result()) | ||
| case DenseVector(values) => | ||
| val values = features.toArray | ||
| Vectors.dense(filterIndices.map(i => values(i))) | ||
| Vectors.dense(orderedIndices.map(i => values(i))) | ||
| case other => | ||
| throw new UnsupportedOperationException( | ||
| s"Only sparse and dense vectors are supported but got ${other.getClass}.") | ||
|
|
@@ -171,14 +177,47 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] { | |
|
|
||
| /** | ||
| * Creates a ChiSquared feature selector. | ||
| * @param numTopFeatures number of features that selector will select | ||
| * (ordered by statistic value descending) | ||
| * Note that if the number of features is less than numTopFeatures, | ||
| * then this will select all features. | ||
| */ | ||
| @Since("1.3.0") | ||
| class ChiSqSelector @Since("1.3.0") ( | ||
| @Since("1.3.0") val numTopFeatures: Int) extends Serializable { | ||
| @Since("2.1.0") | ||
|
||
| class ChiSqSelector @Since("2.1.0") () extends Serializable { | ||
| private var numTopFeatures: Int = 50 | ||
| private var percentile: Double = 10.0 | ||
| private var alpha: Double = 0.05 | ||
| private var selectorType = ChiSqSelectorType.KBest | ||
| private var chiSqTestResult: Array[ChiSqTestResult] = _ | ||
|
|
||
| @Since("1.3.0") | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The existing constructor should still have Javadoc, perhaps pointing to the `setNumTopFeatures` method to say that that is the effect it has. |
||
| def this(numTopFeatures: Int) { | ||
| this() | ||
| this.numTopFeatures = numTopFeatures | ||
|
Member
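There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This should call `setNumTopFeatures` [inline code reference lost in extraction; reconstructed from the replies that follow].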
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is not necessary, because the default `selectorType` is `KBest`.
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. OK. It seemed to split the logic a bit here, but it's not bad. The default behavior needs to be documented then. Now there is effectively a default `numTopFeatures`. |
||
| } | ||
|
|
||
| @Since("2.1.0") | ||
| def setNumTopFeatures(value: Int): this.type = { | ||
| numTopFeatures = value | ||
| selectorType = ChiSqSelectorType.KBest | ||
| this | ||
| } | ||
|
|
||
| @Since("2.1.0") | ||
| def setPercentile(value: Double): this.type = { | ||
| percentile = value | ||
| selectorType = ChiSqSelectorType.Percentile | ||
| this | ||
| } | ||
|
|
||
| @Since("2.1.0") | ||
| def setAlpha(value: Double): this.type = { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Does it need a `require`?
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. `require` is added, thanks. |
||
| alpha = value | ||
| selectorType = ChiSqSelectorType.Fpr | ||
| this | ||
| } | ||
|
|
||
| @Since("2.1.0") | ||
| def setChiSqSelectorType(value: ChiSqSelectorType.Value): this.type = { | ||
| selectorType = value | ||
| this | ||
| } | ||
|
|
||
| /** | ||
| * Returns a ChiSquared feature selector. | ||
|
|
@@ -189,11 +228,35 @@ class ChiSqSelector @Since("1.3.0") ( | |
| */ | ||
| @Since("1.3.0") | ||
| def fit(data: RDD[LabeledPoint]): ChiSqSelectorModel = { | ||
| val indices = Statistics.chiSqTest(data) | ||
| .zipWithIndex.sortBy { case (res, _) => -res.statistic } | ||
| .take(numTopFeatures) | ||
| .map { case (_, indices) => indices } | ||
| .sorted | ||
| chiSqTestResult = Statistics.chiSqTest(data) | ||
| selectorType match { | ||
| case ChiSqSelectorType.KBest => selectKBest(numTopFeatures) | ||
| case ChiSqSelectorType.Percentile => selectPercentile(percentile) | ||
| case ChiSqSelectorType.Fpr => selectFpr(alpha) | ||
| case _ => throw new Exception("Unknown ChiSqSelector Type") | ||
|
||
| } | ||
| } | ||
|
|
||
| @Since("2.1.0") | ||
| def selectKBest(value: Int): ChiSqSelectorModel = { | ||
|
||
| val indices = chiSqTestResult.zipWithIndex.sortBy { case (res, _) => -res.statistic } | ||
| .take(numTopFeatures) | ||
| .map { case (_, indices) => indices } | ||
| new ChiSqSelectorModel(indices) | ||
| } | ||
|
|
||
| @Since("2.1.0") | ||
| def selectPercentile(value: Double): ChiSqSelectorModel = { | ||
| val indices = chiSqTestResult.zipWithIndex.sortBy { case (res, _) => -res.statistic } | ||
| .take((chiSqTestResult.length * percentile / 100).toInt) | ||
| .map { case (_, indices) => indices } | ||
| new ChiSqSelectorModel(indices) | ||
| } | ||
|
|
||
| @Since("2.1.0") | ||
| def selectFpr(value: Double): ChiSqSelectorModel = { | ||
| val indices = chiSqTestResult.zipWithIndex.filter{ case (res, _) => res.pValue < alpha } | ||
| .map { case (_, indices) => indices } | ||
| new ChiSqSelectorModel(indices) | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use `.inRange(0, 1)` here? It needs to be <= 1 too.