[WIP] [MLLIB-28] An optimized GradientDescent implementation #166

Changes from 2 commits
```diff
@@ -32,10 +32,10 @@ import scala.collection.mutable.ArrayBuffer
 class GradientDescent(var gradient: Gradient, var updater: Updater)
   extends Optimizer with Logging
 {
-  private var stepSize: Double = 1.0
-  private var numIterations: Int = 100
-  private var regParam: Double = 0.0
-  private var miniBatchFraction: Double = 1.0
+    protected var stepSize: Double = 1.0
+    protected var numIterations: Int = 100
+    protected var regParam: Double = 0.0
+    protected var miniBatchFraction: Double = 1.0

   /**
    * Set the initial step size of SGD for the first step. Default 1.0.
```

Contributor: Indentation error, should be 2 spaces.
@@ -0,0 +1,147 @@
```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.optimization

import org.apache.spark.Logging
import org.apache.spark.rdd.RDD

import org.jblas.DoubleMatrix

import scala.collection.mutable.ArrayBuffer
import scala.util.Random
```
Contributor: Please reorder the imports according to Spark coding convention.
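For reference, this is roughly the grouping the Spark style guide asks for (scala packages, then third-party libraries, then org.apache.spark), assuming that convention applies to this file:

```scala
import scala.collection.mutable.ArrayBuffer
import scala.util.Random

import org.jblas.DoubleMatrix

import org.apache.spark.Logging
import org.apache.spark.rdd.RDD
```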
```scala
/**
 * Class used to solve an optimization problem using Gradient Descent.
 * @param gradient Gradient function to be used.
 * @param updater Updater to be used to update weights after every iteration.
 */
class GradientDescentWithLocalUpdate(gradient: Gradient, updater: Updater)
  extends GradientDescent(gradient, updater) with Logging
{
```
Contributor: Move the right brace to the end of the last line.
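In other words, something like this (a minimal sketch of the suggested brace placement):

```scala
class GradientDescentWithLocalUpdate(gradient: Gradient, updater: Updater)
  extends GradientDescent(gradient, updater) with Logging {
```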
```scala
  private var numLocalIterations: Int = 1

  /**
   * Set the number of local iterations. Default 1.
   */
  def setNumLocalIterations(numLocalIter: Int): this.type = {
    this.numLocalIterations = numLocalIter
    this
  }

  override def optimize(data: RDD[(Double, Array[Double])], initialWeights: Array[Double])
    : Array[Double] = {

    val (weights, stochasticLossHistory) = GradientDescentWithLocalUpdate.runMiniBatchSGD(
      data,
      gradient,
      updater,
      stepSize,
      numIterations,
      numLocalIterations,
      regParam,
      miniBatchFraction,
      initialWeights)
    weights
  }

}

// Top-level method to run gradient descent.
object GradientDescentWithLocalUpdate extends Logging {
  /**
   * Run BSP+ gradient descent in parallel using mini batches.
```
Contributor: Hmm... since BSP+ is not a well-known concept and there are no related references (yet), maybe we should not use this term? Any suggestions @mengxr?

Contributor: I agree. We need a reference here, or at least explain what it means.
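For readers unfamiliar with the term, the control flow implemented below can be paraphrased as a bulk-synchronous scheme with extra local SGD steps between synchronization points. The sketch below is illustrative only, not part of the patch, and uses hypothetical helpers (`sgdStepOnLocalMiniBatch`, `average`):

```scala
// Illustrative paraphrase of runMiniBatchSGD below, not part of the patch.
for (i <- 1 to numOuterIterations) {              // one synchronization per outer step
  val perPartitionWeights = data.mapPartitions { iter =>
    var w = weights                               // start from the current global weights
    for (j <- 1 to numInnerIterations) {
      w = sgdStepOnLocalMiniBatch(w, iter)        // hypothetical helper: one local SGD update
    }
    Iterator.single(w)
  }.collect()
  weights = average(perPartitionWeights)          // hypothetical helper: average partition results
}
```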
```scala
   *
   * @param data - Input data for SGD. RDD of form (label, [feature values]).
   * @param gradient - Gradient object that will be used to compute the gradient.
   * @param updater - Updater object that will be used to update the model.
   * @param stepSize - stepSize to be used during update.
   * @param numOuterIterations - number of outer iterations that SGD should be run.
   * @param numInnerIterations - number of inner iterations that SGD should be run.
   * @param regParam - regularization parameter
   * @param miniBatchFraction - fraction of the input data set that should be used for
   *                            one iteration of SGD. Default value 1.0.
   *
   * @return A tuple containing two elements. The first element is a column matrix containing
   *         weights for every feature, and the second element is an array containing the
   *         stochastic loss computed for every iteration.
   */
  def runMiniBatchSGD(
      data: RDD[(Double, Array[Double])],
      gradient: Gradient,
      updater: Updater,
      stepSize: Double,
      numOuterIterations: Int,
      numInnerIterations: Int,
      regParam: Double,
      miniBatchFraction: Double,
      initialWeights: Array[Double]) : (Array[Double], Array[Double]) = {
```
Contributor: Remove the space before the colon.
```scala
    val stochasticLossHistory = new ArrayBuffer[Double](numOuterIterations)

    val numExamples: Long = data.count()
    val numPartition = data.partitions.length
    val miniBatchSize = numExamples * miniBatchFraction / numPartition

    // Initialize weights as a column vector
    var weights = new DoubleMatrix(initialWeights.length, 1, initialWeights: _*)
    var regVal = 0.0

    for (i <- 1 to numOuterIterations) {
      val weightsAndLosses = data.mapPartitions { iter =>
        var iterReserved = iter
        val localLossHistory = new ArrayBuffer[Double](numInnerIterations)

        for (j <- 1 to numInnerIterations) {
          val (iterCurrent, iterNext) = iterReserved.duplicate
```
Contributor: I'm not very familiar with how `Iterator.duplicate` behaves here — what is its memory cost?

Contributor (author): Good question. I looked into it, but I have no idea of the memory cost incurred by the duplicated iterator.

Contributor: We're using […]. If my guess is right, instead of using […].

Contributor: @yinxusen I think this approach will certainly run OOM if data is too big to fit into memory. You can set a small executor memory and test some data without caching.

Contributor (author): @mengxr, I absolutely agree with you. I am trying another way now, and will have a test result tomorrow.
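For context on the memory question: Scala's `Iterator.duplicate` backs the two returned iterators with a shared buffer, so elements consumed by one side stay queued until the other side reads them; duplicating once per inner iteration and draining one side can therefore keep an entire partition in memory. A small self-contained illustration of that buffering behaviour:

```scala
object DuplicateBufferingSketch {
  def main(args: Array[String]): Unit = {
    val iter = (1 to 5).iterator
    val (a, b) = iter.duplicate
    // Draining `a` first forces all five elements into the shared buffer
    // so that `b` can still see them afterwards.
    println(a.toList)  // List(1, 2, 3, 4, 5)
    println(b.toList)  // List(1, 2, 3, 4, 5)
  }
}
```

Materializing the partition once (for example with `iter.toArray`) would make that memory cost explicit, which seems to be the direction the discussion above is heading.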
```scala
          val rand = new Random(42 + i * numOuterIterations + j)
          val sampled = iterCurrent.filter(x => rand.nextDouble() <= miniBatchFraction)
          val (gradientSum, lossSum) = sampled.map { case (y, features) =>
            val featuresCol = new DoubleMatrix(features.length, 1, features: _*)
            val (grad, loss) = gradient.compute(featuresCol, y, weights)
            (grad, loss)
```
Contributor: We can simply return `gradient.compute(featuresCol, y, weights)` here.
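That is, the map body above could be reduced to something like this (a sketch of the reviewer's suggestion, reusing the surrounding names):

```scala
val (gradientSum, lossSum) = sampled.map { case (y, features) =>
  val featuresCol = new DoubleMatrix(features.length, 1, features: _*)
  // gradient.compute already returns a (gradient, loss) pair
  gradient.compute(featuresCol, y, weights)
}.reduce((a, b) => (a._1.addi(b._1), a._2 + b._2))
```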
```scala
          }.reduce((a, b) => (a._1.addi(b._1), a._2 + b._2))
```
Contributor: An edge case: do we need to consider an empty partition here (`reduce` fails on an empty collection)?

Contributor (author): Hmm... Sounds good. I should take care of it, especially when the data in each partition is small and the sampled mini batch may end up empty.
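One possible way to guard against that edge case is to seed the reduction with an explicit zero so an empty sample cannot make `reduce` throw. This is only a sketch under that assumption, not the author's fix:

```scala
// Fold from an explicit zero: an empty `sampled` then yields (zeros, 0.0)
// instead of reduce throwing UnsupportedOperationException.
val zero = (DoubleMatrix.zeros(weights.length), 0.0)
val (gradientSum, lossSum) = sampled.map { case (y, features) =>
  val featuresCol = new DoubleMatrix(features.length, 1, features: _*)
  gradient.compute(featuresCol, y, weights)
}.foldLeft(zero) { case ((gSum, lSum), (g, l)) => (gSum.addi(g), lSum + l) }
```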
```scala
          localLossHistory += lossSum / miniBatchSize + regVal

          val update = updater.compute(weights, gradientSum.div(miniBatchSize),
            stepSize, (i - 1) + numOuterIterations + j, regParam)

          weights = update._1
          regVal = update._2

          iterReserved = iterNext
        }

        List((weights, localLossHistory.toArray)).iterator
      }

      val c = weightsAndLosses.collect()
      val (ws, ls) = c.unzip

      stochasticLossHistory.append(ls.head.reduce(_ + _) / ls.head.size)

      val weightsSum = ws.reduce(_ addi _)
      weights = weightsSum.divi(c.size)
    }

    logInfo("GradientDescentWithLocalUpdate finished. Last 10 stochastic losses %s".format(
```
Contributor: How about changing 10 to "a few"? The length of `stochasticLossHistory` may be smaller than 10.
```scala
      stochasticLossHistory.takeRight(10).mkString(", ")))

    (weights.toArray, stochasticLossHistory.toArray)
  }
}
```
@@ -0,0 +1,71 @@
```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.optimization

import org.scalatest.FunSuite
import org.scalatest.matchers.ShouldMatchers

import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.util.LocalSparkContext

class GradientDescentWithLocalUpdateSuite extends FunSuite with LocalSparkContext with ShouldMatchers {
```
Contributor: 100 columns exceeded.

Contributor: Also, it seems there's no need to mix in `ShouldMatchers`.
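With both comments addressed, the declaration could look like this (assuming none of the tests actually use matcher syntax):

```scala
class GradientDescentWithLocalUpdateSuite extends FunSuite with LocalSparkContext {
```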
```scala
  import GradientDescentSuite._

  test("Assert the loss is decreasing.") {
    val nPoints = 10000
    val A = 2.0
    val B = -1.5

    val initialB = -1.0
    val initialWeights = Array(initialB)

    val gradient = new LogisticGradient()
    val updater = new SimpleUpdater()
    val stepSize = 1.0
    val numIterations = 10
    val numLocalIterations = 10
    val regParam = 0
    val miniBatchFrac = 1.0

    // Add a extra variable consisting of all 1.0's for the intercept.
    val testData = generateGDInput(A, B, nPoints, 42)
    val data = testData.map { case LabeledPoint(label, features) =>
      label -> Array(1.0, features: _*)
    }

    val dataRDD = sc.parallelize(data, 2).cache()
    val initialWeightsWithIntercept = Array(1.0, initialWeights: _*)

    val (_, loss) = GradientDescentWithLocalUpdate.runMiniBatchSGD(
      dataRDD,
      gradient,
      updater,
      stepSize,
      numIterations,
      numLocalIterations,
      regParam,
      miniBatchFrac,
      initialWeightsWithIntercept)

    assert(loss.last - loss.head < 0, "loss isn't decreasing.")

    val lossDiff = loss.init.zip(loss.tail).map { case (lhs, rhs) => lhs - rhs }
    assert(lossDiff.count(_ > 0).toDouble / lossDiff.size > 0.8)
  }
}
```
Contributor: Move the right brace to the end of the last line.