Closed
Changes from 4 commits
@@ -24,6 +24,7 @@ import org.apache.spark.api.java.{JavaSparkContext, JavaRDD}
import org.apache.spark.mllib.classification._
import org.apache.spark.mllib.clustering._
import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors}
import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.recommendation._
import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.util.MLUtils
@@ -247,16 +248,24 @@ class PythonMLLibAPI extends Serializable {
dataBytesJRDD: JavaRDD[Array[Byte]],
numIterations: Int,
stepSize: Double,
regParam: Double,
regType: String,
intercept: Boolean,
miniBatchFraction: Double,
initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
val lrAlg = new LinearRegressionWithSGD()
lrAlg.setIntercept(intercept)
lrAlg.optimizer.
Contributor

We usually put . at the beginning of the line:

lrAlg.optimizer
  .setNumIterations(numIterations)
  .setRegParam(regParam)
  .setStepSize(stepSize)

setNumIterations(numIterations).
setRegParam(regParam).
setStepSize(stepSize)
if (regType == "SquaredUpdater")
lrAlg.optimizer.setUpdater(new SquaredL2Updater)
else if (regType == "L1Updater")
lrAlg.optimizer.setUpdater(new L1Updater)
Contributor Author

Not using enumerations for regType parameter anymore. Switched to string values.

Contributor

It is safer to add

    else if (regType != "none")
      throw new IllegalArgumentException("...")

Contributor Author

Since the Scala code will now throw the exception, I am going to remove the ValueError raised in the Python code.

trainRegressionModel(
(data, initialWeights) =>
LinearRegressionWithSGD.train(
data,
numIterations,
stepSize,
miniBatchFraction,
initialWeights),
lrAlg.run(data, initialWeights),
dataBytesJRDD,
initialWeightsBA)
}
23 changes: 16 additions & 7 deletions python/pyspark/mllib/regression.py
@@ -109,18 +109,27 @@ class LinearRegressionModel(LinearRegressionModelBase):
True
"""


class LinearRegressionWithSGD(object):
Contributor

We use two empty lines to separate methods in pyspark. (I don't know the exact reason ...)

@classmethod
def train(cls, data, iterations=100, step=1.0,
miniBatchFraction=1.0, initialWeights=None):
"""Train a linear regression model on the given data."""
def train(cls, data, iterations=100, step=1.0, regParam=1.0, regType=None,
Contributor

To be safe, we shouldn't change the order of the existing arguments; we can append the new arguments regParam, regType, and intercept at the end. That way, if a user called train(data, 100, 1.0, 1.0, np.zeros(n)), it still works in the new version.
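
To illustrate the concern, here is a minimal sketch using stand-in functions (the names `train_old`, `train_appended`, and `train_reordered` are hypothetical and do not exist in pyspark). Each one only reports how its positional arguments were bound, so no Spark context is needed:

```python
def train_old(data, iterations=100, step=1.0,
              miniBatchFraction=1.0, initialWeights=None):
    """Stand-in for the signature before this change."""
    return dict(miniBatchFraction=miniBatchFraction,
                initialWeights=initialWeights)


def train_appended(data, iterations=100, step=1.0,
                   miniBatchFraction=1.0, initialWeights=None,
                   regParam=1.0, regType=None, intercept=False):
    """Stand-in with the new arguments appended at the end."""
    return dict(miniBatchFraction=miniBatchFraction,
                initialWeights=initialWeights,
                regParam=regParam, regType=regType, intercept=intercept)


def train_reordered(data, iterations=100, step=1.0, regParam=1.0,
                    regType=None, intercept=False,
                    miniBatchFraction=1.0, initialWeights=None):
    """Stand-in with the new arguments inserted in the middle."""
    return dict(miniBatchFraction=miniBatchFraction,
                initialWeights=initialWeights,
                regParam=regParam, regType=regType, intercept=intercept)


weights = [0.0, 0.0]  # placeholder for np.zeros(n)

# The same positional call, as an existing user might have written it.
old = train_old("data", 100, 1.0, 1.0, weights)
new = train_appended("data", 100, 1.0, 1.0, weights)
bad = train_reordered("data", 100, 1.0, 1.0, weights)
```

With the appended signature, `weights` is still bound to `initialWeights`. With the reordered signature, the fourth positional value is bound to `regParam` and `weights` silently ends up in `regType`.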

intercept=False, miniBatchFraction=1.0, initialWeights=None):
"""Train a linear regression model on the given data. The 'regType' parameter can take
one from the following string values: "L1Updater" for invoking the lasso regularizer,
Contributor

In Python, the line width for docs should be less than 80 (or 78 to be safe).
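
As a sketch of what that could look like, the docstring from this diff can be rewrapped as below. The function signature is trimmed down for illustration and is an assumption, not the final API; the string values are the ones used in the current patch:

```python
def train(data, regParam=1.0, regType=None):
    """Train a linear regression model on the given data.

    The 'regType' parameter takes one of the following string values:
    "L1Updater" to use the lasso regularizer, "SquaredUpdater" to use
    the ridge regularizer, or "NONE" to use no regularizer. The
    regularization parameter is set through 'regParam' (default 1.0).
    """


# Length of the longest docstring line, indentation included.
longest = max(len(line) for line in train.__doc__.splitlines())
```

Checking `longest` against the limit is a quick way to confirm the wrapping without running a linter.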

"SquaredUpdater" for invoking the ridge regularizer or "NONE" for not using a
regularizer at all. The user can determine the regularizer parameter by setting the
appropriate value to variable 'regParam' (by default is set to 1.0)."""
sc = data.context
train_f = lambda d, i: sc._jvm.PythonMLLibAPI().trainLinearRegressionModelWithSGD(
d._jrdd, iterations, step, miniBatchFraction, i)
if regType is None:
train_f = lambda d, i: sc._jvm.PythonMLLibAPI().trainLinearRegressionModelWithSGD(
Contributor

To avoid having the long command twice, you can use

if regType is None:
    regType = "none"
if regType in {"l2", "l1", "none"}:
    train_f = ...
else:
    raise ...
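
Spelled out, the check suggested above might look like the following sketch. The helper name `normalize_reg_type` and the constant `VALID_REG_TYPES` are hypothetical, and the error message is only a placeholder modeled on the one in this patch:

```python
VALID_REG_TYPES = {"l2", "l1", "none"}


def normalize_reg_type(regType):
    """Map None to "none" and reject any unsupported string value."""
    if regType is None:
        regType = "none"
    if regType not in VALID_REG_TYPES:
        raise ValueError(
            "Invalid value for 'regType' parameter. "
            "Can only be one of: l1, l2, none.")
    return regType
```

The normalized value could then be passed to a single `train_f` lambda, so the long `trainLinearRegressionModelWithSGD` call appears only once.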

Contributor

ditto: please keep this empty line

d._jrdd, iterations, step, regParam, "NONE", intercept, miniBatchFraction, i)
elif regType == "SquaredUpdater" or regType == "L1Updater" or regType == "NONE":
Contributor

You can change it to l2, l1, and none (simpler and lowercase).

train_f = lambda d, i: sc._jvm.PythonMLLibAPI().trainLinearRegressionModelWithSGD(
d._jrdd, iterations, step, regParam, regType, intercept, miniBatchFraction, i)
else:
raise ValueError("Invalid value for 'regType' parameter. Can only be initialized " +
"using the following string values [L1Updater, SquaredUpdater, NONE].")
Contributor Author

Not using enumerations for regType parameter anymore. Switched to string values.

Contributor

@miccagiann It may be easier if you send the string directly to PythonMLLibAPI().trainLinearRegressionModelWithSGD and implement the logic there.

Contributor

In the current version, all branches in the if-else block are essentially the same.

Contributor Author

Yes! I fixed it in regression.py, where I was calling the same function again and again. As for PythonMLLibAPI().trainLinearRegressionModelWithSGD, I implemented the logic there as well. I am building right now and will commit right away.

return _regression_train_wrapper(sc, train_f, LinearRegressionModel, data, initialWeights)


class LassoModel(LinearRegressionModelBase):
"""A linear regression model derived from a least-squares fit with an
l_1 penalty term.