[SPARK-2550][MLLIB][APACHE SPARK] Support regularization and intercept in pyspark's linear methods. #1624
Changes from 2 commits
3ac8874
78853ec
b962744
ec50ee9
638be47
8eba9c5
44e6ff0
fed8eaa
8dcb888
c02e5f5
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -24,6 +24,7 @@ import org.apache.spark.api.java.{JavaSparkContext, JavaRDD} | |
| import org.apache.spark.mllib.classification._ | ||
| import org.apache.spark.mllib.clustering._ | ||
| import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors} | ||
| import org.apache.spark.mllib.optimization._ | ||
| import org.apache.spark.mllib.recommendation._ | ||
| import org.apache.spark.mllib.regression._ | ||
| import org.apache.spark.mllib.util.MLUtils | ||
|
|
@@ -42,6 +43,16 @@ class PythonMLLibAPI extends Serializable { | |
| private val DENSE_MATRIX_MAGIC: Byte = 3 | ||
| private val LABELED_POINT_MAGIC: Byte = 4 | ||
|
|
||
| /** | ||
| * Enumeration used to define the type of Regularizer | ||
| * used for linear methods. | ||
| */ | ||
| object RegularizerType extends Serializable { | ||
| val L2 : Int = 0 | ||
| val L1 : Int = 1 | ||
| val NONE : Int = 2 | ||
| } | ||
|
|
||
| private[python] def deserializeDoubleVector(bytes: Array[Byte], offset: Int = 0): Vector = { | ||
| require(bytes.length - offset >= 5, "Byte array too short") | ||
| val magic = bytes(offset) | ||
|
|
@@ -247,16 +258,24 @@ class PythonMLLibAPI extends Serializable { | |
| dataBytesJRDD: JavaRDD[Array[Byte]], | ||
| numIterations: Int, | ||
| stepSize: Double, | ||
| regParam: Double, | ||
| regType: Int, | ||
| intercept: Boolean, | ||
| miniBatchFraction: Double, | ||
| initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { | ||
| val lrAlg = new LinearRegressionWithSGD() | ||
| lrAlg.setIntercept(intercept) | ||
| lrAlg.optimizer. | ||
|
Contributor
We usually put |
||
| setNumIterations(numIterations). | ||
| setRegParam(regParam). | ||
| setStepSize(stepSize) | ||
| if (regType == RegularizerType.L2) | ||
| lrAlg.optimizer.setUpdater(new SquaredL2Updater) | ||
| else if (regType == RegularizerType.L1) | ||
| lrAlg.optimizer.setUpdater(new L1Updater) | ||
|
Contributor
Author
Not using enumerations for
Contributor
It is safer to add
Contributor
Author
By adding the exception to the scala code, I am going to remove the |
||
| trainRegressionModel( | ||
| (data, initialWeights) => | ||
| - LinearRegressionWithSGD.train( | ||
| - data, | ||
| - numIterations, | ||
| - stepSize, | ||
| - miniBatchFraction, | ||
| - initialWeights), | ||
| + lrAlg.run(data, initialWeights), | ||
| dataBytesJRDD, | ||
| initialWeightsBA) | ||
| } | ||
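
For readers unfamiliar with the updaters selected in this hunk, the following pure-Python sketch shows roughly what each choice does to a weight vector in a single step. It is an illustration under stated assumptions, not MLlib's code: the function names are hypothetical, and the decaying step size `step_size / sqrt(iteration)` is assumed here.

```python
import math


def simple_update(weights, gradient, step_size, iteration):
    """No regularization: a plain gradient step (assumed decaying step size)."""
    eta = step_size / math.sqrt(iteration)
    return [w - eta * g for w, g in zip(weights, gradient)]


def squared_l2_update(weights, gradient, step_size, iteration, reg_param):
    """L2 penalty (reg_param / 2) * ||w||^2: shrink the weights, then step."""
    eta = step_size / math.sqrt(iteration)
    return [w * (1.0 - eta * reg_param) - eta * g
            for w, g in zip(weights, gradient)]


def l1_update(weights, gradient, step_size, iteration, reg_param):
    """L1 penalty: gradient step followed by soft-thresholding toward zero."""
    eta = step_size / math.sqrt(iteration)
    shrinkage = reg_param * eta
    stepped = [w - eta * g for w, g in zip(weights, gradient)]
    # Weights whose magnitude is below the shrinkage value become zero.
    return [math.copysign(max(0.0, abs(w) - shrinkage), w) for w in stepped]
```

With a zero gradient, the L2 rule scales every weight by the same factor, while the L1 rule moves every weight toward zero by a fixed amount and zeroes out the small ones, which is why L1 tends to produce sparse models.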
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -109,18 +109,35 @@ class LinearRegressionModel(LinearRegressionModelBase): | |
| True | ||
| """ | ||
|
|
||
| class RegularizerType(object): | ||
|
Contributor
We use two empty lines to separate methods in pyspark. (I don't know the exact reason ...) |
||
| L2 = 0 | ||
| L1 = 1 | ||
| NONE = 2 | ||
|
Contributor
Author
The same enumeration is provided and through the |
||
|
|
||
| class LinearRegressionWithSGD(object): | ||
| @classmethod | ||
| - def train(cls, data, iterations=100, step=1.0, | ||
| - miniBatchFraction=1.0, initialWeights=None): | ||
| + def train(cls, data, iterations=100, step=1.0, regParam=1.0, regType=None, | ||
|
Contributor
To be safe, we shouldn't change the order of arguments; we can append new arguments |
||
| + intercept=False, miniBatchFraction=1.0, initialWeights=None): | ||
| """Train a linear regression model on the given data.""" | ||
| sc = data.context | ||
| - train_f = lambda d, i: sc._jvm.PythonMLLibAPI().trainLinearRegressionModelWithSGD( | ||
| - d._jrdd, iterations, step, miniBatchFraction, i) | ||
| + if regType is None: | ||
|
Contributor
ditto: please keep this empty line |
||
| train_f = lambda d, i: sc._jvm.PythonMLLibAPI().trainLinearRegressionModelWithSGD( | ||
|
Contributor
To avoid having the long command twice, you can use |
||
| d._jrdd, iterations, step, regParam, sc._jvm.PythonMLLibAPI().RegularizerType().NONE(), | ||
| intercept, miniBatchFraction, i) | ||
| elif regType == RegularizerType.L2: | ||
| train_f = lambda d, i: sc._jvm.PythonMLLibAPI().trainLinearRegressionModelWithSGD( | ||
| d._jrdd, iterations, step, regParam, sc._jvm.PythonMLLibAPI().RegularizerType().L2(), | ||
| intercept, miniBatchFraction, i) | ||
| elif regType == RegularizerType.L1: | ||
| train_f = lambda d, i: sc._jvm.PythonMLLibAPI().trainLinearRegressionModelWithSGD( | ||
| d._jrdd, iterations, step, regParam, sc._jvm.PythonMLLibAPI().RegularizerType().L1(), | ||
| intercept, miniBatchFraction, i) | ||
| elif regType == RegularizerType.NONE: | ||
| train_f = lambda d, i: sc._jvm.PythonMLLibAPI().trainLinearRegressionModelWithSGD( | ||
| d._jrdd, iterations, step, regParam, sc._jvm.PythonMLLibAPI().RegularizerType().NONE(), | ||
| intercept, miniBatchFraction, i) | ||
| return _regression_train_wrapper(sc, train_f, LinearRegressionModel, data, initialWeights) | ||
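
Two of the review points on this method (keep the existing argument order, and avoid repeating the long call) can be illustrated with a small self-contained sketch. It is only one possible shape, not the code that was merged: `_fake_jvm_train` is a hypothetical stand-in for the JVM-side `trainLinearRegressionModelWithSGD`, and the `RegularizerType` constants are copied from the diff above.

```python
class RegularizerType(object):
    L2 = 0
    L1 = 1
    NONE = 2


def _fake_jvm_train(data, iterations, step, reg_param, reg_type, intercept,
                    mini_batch_fraction, initial_weights):
    """Hypothetical stand-in for the JVM call; it only echoes its arguments."""
    return {
        "iterations": iterations,
        "step": step,
        "reg_param": reg_param,
        "reg_type": reg_type,
        "intercept": intercept,
        "mini_batch_fraction": mini_batch_fraction,
    }


def train(data, iterations=100, step=1.0, miniBatchFraction=1.0,
          initialWeights=None, regParam=1.0, regType=None, intercept=False):
    """New arguments are appended, so old positional calls keep their meaning."""
    # Resolve the regularizer once, so the backend call is written only once.
    reg_type = RegularizerType.NONE if regType is None else regType
    return _fake_jvm_train(data, iterations, step, regParam, reg_type,
                           intercept, miniBatchFraction, initialWeights)
```

A caller written against the old signature, such as `train(data, 50, 0.5, 0.8)`, still binds `0.8` to `miniBatchFraction`. With the argument order shown in the diff, the same call would bind `0.8` to `regParam` instead.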
|
|
||
|
|
||
| class LassoModel(LinearRegressionModelBase): | ||
| """A linear regression model derived from a least-squares fit with an | ||
| l_1 penalty term. | ||
|
|
||
I used an enumeration-like object to distinguish between the update methods (regularizers) the user can choose when training a model. I tried to extend this object from `Enumeration`, but from what I have seen it relies heavily on reflection and does not work well with serialized objects or with py4j.
Using strings with a clear doc should be sufficient. Then you can map the string to `L1Updater` or `SquaredUpdater` inside `PythonMLLibAPI`.
Ok! I will do it with strings both in python and in scala.
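
The string-based approach agreed on here could look roughly like the sketch below on the Python side. The accepted spellings (`"l1"`, `"l2"`, `"none"`), the case-insensitive matching, and the helper name `resolve_updater_name` are assumptions made for illustration; the updater class names are the ones used in the Scala hunk of this diff.

```python
# Assumed string values; None in the mapping means "keep the default updater".
_UPDATER_BY_REG_TYPE = {
    "l2": "SquaredL2Updater",
    "l1": "L1Updater",
    "none": None,
}


def resolve_updater_name(reg_type):
    """Map a user-facing regType string to an updater class name.

    Raises ValueError for unknown values instead of silently ignoring them.
    """
    key = "none" if reg_type is None else str(reg_type).lower()
    if key not in _UPDATER_BY_REG_TYPE:
        raise ValueError(
            "Invalid regType %r; expected one of %s"
            % (reg_type, sorted(_UPDATER_BY_REG_TYPE)))
    return _UPDATER_BY_REG_TYPE[key]
```

Failing loudly on an unknown string keeps a typo such as `"l3"` from quietly training an unregularized model.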