[SPARK-2550][MLLIB][APACHE SPARK] Support regularization and intercept in pyspark's linear methods. #1624
Changes from 4 commits
3ac8874
78853ec
b962744
ec50ee9
638be47
8eba9c5
44e6ff0
fed8eaa
8dcb888
c02e5f5
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -24,6 +24,7 @@ import org.apache.spark.api.java.{JavaSparkContext, JavaRDD} | |
| import org.apache.spark.mllib.classification._ | ||
| import org.apache.spark.mllib.clustering._ | ||
| import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors} | ||
| import org.apache.spark.mllib.optimization._ | ||
| import org.apache.spark.mllib.recommendation._ | ||
| import org.apache.spark.mllib.regression._ | ||
| import org.apache.spark.mllib.util.MLUtils | ||
|
|
@@ -247,16 +248,24 @@ class PythonMLLibAPI extends Serializable { | |
| dataBytesJRDD: JavaRDD[Array[Byte]], | ||
| numIterations: Int, | ||
| stepSize: Double, | ||
| regParam: Double, | ||
| regType: String, | ||
| intercept: Boolean, | ||
| miniBatchFraction: Double, | ||
| initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { | ||
| val lrAlg = new LinearRegressionWithSGD() | ||
| lrAlg.setIntercept(intercept) | ||
| lrAlg.optimizer. | ||
| setNumIterations(numIterations). | ||
| setRegParam(regParam). | ||
| setStepSize(stepSize) | ||
| if (regType == "SquaredUpdater") | ||
| lrAlg.optimizer.setUpdater(new SquaredL2Updater) | ||
| else if (regType == "L1Updater") | ||
| lrAlg.optimizer.setUpdater(new L1Updater) | ||
|
Contributor
Author
Not using enumerations for
Contributor
It is safer to add
Contributor
Author
By adding the exception to the Scala code, I am going to remove the
||
| trainRegressionModel( | ||
| (data, initialWeights) => | ||
| LinearRegressionWithSGD.train( | ||
| data, | ||
| numIterations, | ||
| stepSize, | ||
| miniBatchFraction, | ||
| initialWeights), | ||
| lrAlg.run(data, initialWeights), | ||
| dataBytesJRDD, | ||
| initialWeightsBA) | ||
| } | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -109,18 +109,27 @@ class LinearRegressionModel(LinearRegressionModelBase): | |
| True | ||
| """ | ||
|
|
||
|
|
||
| class LinearRegressionWithSGD(object): | ||
|
Contributor
We use two empty lines to separate methods in pyspark. (I don't know the exact reason ...)
||
| @classmethod | ||
| def train(cls, data, iterations=100, step=1.0, | ||
| miniBatchFraction=1.0, initialWeights=None): | ||
| """Train a linear regression model on the given data.""" | ||
| def train(cls, data, iterations=100, step=1.0, regParam=1.0, regType=None, | ||
|
Contributor
To be safe, we shouldn't change the order of arguments; we can append new arguments
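A minimal sketch of why the argument order matters. The three functions below are hypothetical stand-ins that only echo their arguments; they are not the real pyspark `train` method. An existing caller that passes `miniBatchFraction` positionally keeps its meaning only when the new parameters are appended.

```python
# Hypothetical stand-ins for the train() signature; they only echo arguments.

def train_old(data, iterations=100, step=1.0,
              miniBatchFraction=1.0, initialWeights=None):
    return {"miniBatchFraction": miniBatchFraction}


def train_reordered(data, iterations=100, step=1.0, regParam=1.0,
                    regType=None, intercept=False,
                    miniBatchFraction=1.0, initialWeights=None):
    return {"regParam": regParam, "miniBatchFraction": miniBatchFraction}


def train_appended(data, iterations=100, step=1.0,
                   miniBatchFraction=1.0, initialWeights=None,
                   regParam=1.0, regType=None, intercept=False):
    return {"regParam": regParam, "miniBatchFraction": miniBatchFraction}


# An existing caller passes miniBatchFraction=0.5 as the fourth positional
# argument, which was valid against the old signature.
existing_call = ([], 100, 1.0, 0.5)

old = train_old(*existing_call)
reordered = train_reordered(*existing_call)  # 0.5 silently becomes regParam
appended = train_appended(*existing_call)    # 0.5 is still miniBatchFraction
```

With the reordered signature the call still runs, but the value lands in `regParam`, so the change in behavior would go unnoticed.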
||
| intercept=False, miniBatchFraction=1.0, initialWeights=None): | ||
| """Train a linear regression model on the given data. The 'regType' parameter can take | ||
| one from the following string values: "L1Updater" for invoking the lasso regularizer, | ||
|
Contributor
In Python, the line width for docs should be less than 80 (or 78 to be safe).
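One way the docstring could be rewrapped to stay under that width, shown on a hypothetical stub with an empty body. The wording is a paraphrase of the docstring in this diff, not the text that was eventually merged, and `max_line_width` is a helper introduced only for this illustration.

```python
# Hypothetical stub: only the docstring layout matters; the body is empty.
def train(data, iterations=100, step=1.0, regParam=1.0, regType=None,
          intercept=False, miniBatchFraction=1.0, initialWeights=None):
    """Train a linear regression model on the given data.

    The 'regType' parameter takes one of the following string values:
    "L1Updater" for the lasso regularizer, "SquaredUpdater" for the
    ridge regularizer, or "NONE" for no regularizer. The regularization
    strength is set through 'regParam' (default 1.0).
    """


def max_line_width(text):
    """Return the length of the longest line in `text`."""
    return max(len(line) for line in text.splitlines())


widest = max_line_width(train.__doc__)
```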
||
| "SquaredUpdater" for invoking the ridge regularizer or "NONE" for not using a | ||
| regularizer at all. The user can determine the regularizer parameter by setting the | ||
| appropriate value to variable 'regParam' (by default is set to 1.0).""" | ||
| sc = data.context | ||
| train_f = lambda d, i: sc._jvm.PythonMLLibAPI().trainLinearRegressionModelWithSGD( | ||
| d._jrdd, iterations, step, miniBatchFraction, i) | ||
| if regType is None: | ||
| train_f = lambda d, i: sc._jvm.PythonMLLibAPI().trainLinearRegressionModelWithSGD( | ||
|
Contributor
To avoid having the long command twice, you can use
Contributor
ditto: please keep this empty line
||
| d._jrdd, iterations, step, regParam, "NONE", intercept, miniBatchFraction, i) | ||
| elif regType == "SquaredUpdater" or regType == "L1Updater" or regType == "NONE": | ||
|
Contributor
You can change it to
||
| train_f = lambda d, i: sc._jvm.PythonMLLibAPI().trainLinearRegressionModelWithSGD( | ||
| d._jrdd, iterations, step, regParam, regType, intercept, miniBatchFraction, i) | ||
| else: | ||
| raise ValueError("Invalid value for 'regType' parameter. Can only be initialized " + | ||
| "using the following string values [L1Updater, SquaredUpdater, NONE].") | ||
|
Contributor
Author
Not using enumerations for
Contributor
@miccagiann It may be easier if you send the string directly to
Contributor
In the current version, all branches in the if-else block are essentially the same.
Contributor
Author
Yes! I fixed it in the
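A rough sketch of how the duplicated branches could collapse into a single code path: validate `regType` once, then build the closure once. `normalize_reg_type`, `make_train_f`, and `FakeAPI` are hypothetical names introduced for illustration; `FakeAPI` stands in for `sc._jvm.PythonMLLibAPI()`, and the sketch passes `d` directly where the diff passes `d._jrdd`. This is not necessarily the fix the author applied.

```python
# Values named in the PR's docstring and error message.
VALID_REG_TYPES = ("L1Updater", "SquaredUpdater", "NONE")


def normalize_reg_type(regType):
    """Map None to "NONE" and reject unknown values (hypothetical helper)."""
    if regType is None:
        return "NONE"
    if regType not in VALID_REG_TYPES:
        raise ValueError(
            "Invalid value for 'regType' parameter. Can only be initialized "
            "using the following string values "
            "[L1Updater, SquaredUpdater, NONE].")
    return regType


def make_train_f(api, iterations, step, regParam, regType, intercept,
                 miniBatchFraction):
    """Build the training closure once; `api` stands in for PythonMLLibAPI()."""
    reg_type = normalize_reg_type(regType)
    return lambda d, i: api.trainLinearRegressionModelWithSGD(
        d, iterations, step, regParam, reg_type, intercept,
        miniBatchFraction, i)


class FakeAPI(object):
    """Test double that returns the arguments it receives."""

    def trainLinearRegressionModelWithSGD(self, *args):
        return args


train_f = make_train_f(FakeAPI(), 100, 1.0, 0.1, None, True, 1.0)
call_args = train_f("rdd", "weights")
```

Keeping the check in one place means the long JVM call appears only once, and an unknown string fails before any job is submitted.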
||
| return _regression_train_wrapper(sc, train_f, LinearRegressionModel, data, initialWeights) | ||
|
|
||
|
|
||
| class LassoModel(LinearRegressionModelBase): | ||
| """A linear regression model derived from a least-squares fit with an | ||
| l_1 penalty term. | ||
|
|
||
We usually put `.` at the beginning of the line: