Skip to content

Commit 86b9e34

Browse files
committed
one pass over APIs of GLMs, NaiveBayes, and ALS
1 parent f21d862 commit 86b9e34

23 files changed

+157
-94
lines changed

mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,10 @@ import org.apache.spark.mllib.regression._
2828
import org.apache.spark.rdd.RDD
2929

3030
/**
31+
* <span class="badge badge-red" style="float: right;">DEVELOPER API</span>
32+
*
3133
* The Java stubs necessary for the Python mllib bindings.
34+
* Users should not call the methods defined in this class directly.
3235
*/
3336
class PythonMLLibAPI extends Serializable {
3437
private def deserializeDoubleVector(bytes: Array[Byte]): Array[Double] = {

mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ class LogisticRegressionModel(
5555
this
5656
}
5757

58-
override def predictPoint(dataMatrix: Vector, weightMatrix: Vector,
58+
override protected def predictPoint(dataMatrix: Vector, weightMatrix: Vector,
5959
intercept: Double) = {
6060
val margin = weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept
6161
val score = 1.0/ (1.0 + math.exp(-margin))
@@ -70,28 +70,28 @@ class LogisticRegressionModel(
7070
* Train a classification model for Logistic Regression using Stochastic Gradient Descent.
7171
* NOTE: Labels used in Logistic Regression should be {0, 1}
7272
*/
73-
class LogisticRegressionWithSGD private (
74-
var stepSize: Double,
75-
var numIterations: Int,
76-
var regParam: Double,
77-
var miniBatchFraction: Double)
73+
class LogisticRegressionWithSGD (
74+
private var stepSize: Double,
75+
private var numIterations: Int,
76+
private var regParam: Double,
77+
private var miniBatchFraction: Double)
7878
extends GeneralizedLinearAlgorithm[LogisticRegressionModel] with Serializable {
7979

80-
val gradient = new LogisticGradient()
81-
val updater = new SimpleUpdater()
80+
private val gradient = new LogisticGradient()
81+
private val updater = new SimpleUpdater()
8282
override val optimizer = new GradientDescent(gradient, updater)
8383
.setStepSize(stepSize)
8484
.setNumIterations(numIterations)
8585
.setRegParam(regParam)
8686
.setMiniBatchFraction(miniBatchFraction)
87-
override val validators = List(DataValidators.classificationLabels)
87+
override protected val validators = List(DataValidators.binaryLabelValidator)
8888

8989
/**
9090
* Construct a LogisticRegression object with default parameters
9191
*/
9292
def this() = this(1.0, 100, 0.0, 1.0)
9393

94-
def createModel(weights: Vector, intercept: Double) = {
94+
override protected def createModel(weights: Vector, intercept: Double) = {
9595
new LogisticRegressionModel(weights, intercept)
9696
}
9797
}

mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,17 @@ class NaiveBayesModel(
4040
private val brzPi = new BDV[Double](pi)
4141
private val brzTheta = new BDM[Double](theta.length, theta(0).length)
4242

43-
var i = 0
44-
while (i < theta.length) {
45-
var j = 0
46-
while (j < theta(i).length) {
47-
brzTheta(i, j) = theta(i)(j)
48-
j += 1
43+
{
44+
// Need to put an extra pair of braces to prevent Scala treat `i` as a member.
45+
var i = 0
46+
while (i < theta.length) {
47+
var j = 0
48+
while (j < theta(i).length) {
49+
brzTheta(i, j) = theta(i)(j)
50+
j += 1
51+
}
52+
i += 1
4953
}
50-
i += 1
5154
}
5255

5356
override def predict(testData: RDD[Vector]): RDD[Double] = testData.map(predict)
@@ -65,7 +68,7 @@ class NaiveBayesModel(
6568
* document classification. By making every vector a 0-1 vector, it can also be used as
6669
* Bernoulli NB ([[http://tinyurl.com/p7c96j6]]).
6770
*/
68-
class NaiveBayes private (var lambda: Double) extends Serializable with Logging {
71+
class NaiveBayes (private var lambda: Double) extends Serializable with Logging {
6972

7073
def this() = this(1.0)
7174

mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,9 @@ class SVMModel(
5555
this
5656
}
5757

58-
override def predictPoint(dataMatrix: Vector, weightMatrix: Vector,
58+
override protected def predictPoint(
59+
dataMatrix: Vector,
60+
weightMatrix: Vector,
5961
intercept: Double) = {
6062
val margin = weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept
6163
threshold match {
@@ -69,29 +71,28 @@ class SVMModel(
6971
* Train a Support Vector Machine (SVM) using Stochastic Gradient Descent.
7072
* NOTE: Labels used in SVM should be {0, 1}.
7173
*/
72-
class SVMWithSGD private (
73-
var stepSize: Double,
74-
var numIterations: Int,
75-
var regParam: Double,
76-
var miniBatchFraction: Double)
74+
class SVMWithSGD(
75+
private var stepSize: Double,
76+
private var numIterations: Int,
77+
private var regParam: Double,
78+
private var miniBatchFraction: Double)
7779
extends GeneralizedLinearAlgorithm[SVMModel] with Serializable {
7880

79-
val gradient = new HingeGradient()
80-
val updater = new SquaredL2Updater()
81+
private val gradient = new HingeGradient()
82+
private val updater = new SquaredL2Updater()
8183
override val optimizer = new GradientDescent(gradient, updater)
8284
.setStepSize(stepSize)
8385
.setNumIterations(numIterations)
8486
.setRegParam(regParam)
8587
.setMiniBatchFraction(miniBatchFraction)
86-
87-
override val validators = List(DataValidators.classificationLabels)
88+
override protected val validators = List(DataValidators.binaryLabelValidator)
8889

8990
/**
9091
* Construct a SVM object with default parameters
9192
*/
9293
def this() = this(1.0, 100, 1.0, 1.0)
9394

94-
def createModel(weights: Vector, intercept: Double) = {
95+
override protected def createModel(weights: Vector, intercept: Double) = {
9596
new SVMModel(weights, intercept)
9697
}
9798
}

mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,13 @@ import org.apache.spark.util.random.XORShiftRandom
3636
* This is an iterative algorithm that will make multiple passes over the data, so any RDDs given
3737
* to it should be cached by the user.
3838
*/
39-
class KMeans private (
40-
var k: Int,
41-
var maxIterations: Int,
42-
var runs: Int,
43-
var initializationMode: String,
44-
var initializationSteps: Int,
45-
var epsilon: Double) extends Serializable with Logging {
39+
class KMeans(
40+
private var k: Int,
41+
private var maxIterations: Int,
42+
private var runs: Int,
43+
private var initializationMode: String,
44+
private var initializationSteps: Int,
45+
private var epsilon: Double) extends Serializable with Logging {
4646
def this() = this(2, 20, 1, KMeans.K_MEANS_PARALLEL, 5, 1e-4)
4747

4848
/** Set the number of clusters to create (k). Default: 2. */
@@ -71,6 +71,8 @@ class KMeans private (
7171
}
7272

7373
/**
74+
* <span class="badge" style="float: right; background-color: #257080;">EXPERIMENTAL</span>
75+
*
7476
* Set the number of runs of the algorithm to execute in parallel. We initialize the algorithm
7577
* this many times with random starting conditions (configured by the initialization mode), then
7678
* return the best clustering found over any run. Default: 1.

mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -64,11 +64,13 @@ trait Vector extends Serializable {
6464

6565
/**
6666
* Factory methods for [[org.apache.spark.mllib.linalg.Vector]].
67+
* We don't use the name `Vector` because Scala imports
68+
* [[scala.collection.immutable.Vector]] by default.
6769
*/
6870
object Vectors {
6971

7072
/**
71-
* Creates a dense vector.
73+
* Creates a dense vector from its values.
7274
*/
7375
@varargs
7476
def dense(firstValue: Double, otherValues: Double*): Vector =
@@ -158,20 +160,21 @@ class DenseVector(val values: Array[Double]) extends Vector {
158160
/**
159161
* A sparse vector represented by an index array and an value array.
160162
*
161-
* @param n size of the vector.
163+
* @param size size of the vector.
162164
* @param indices index array, assume to be strictly increasing.
163165
* @param values value array, must have the same length as the index array.
164166
*/
165-
class SparseVector(val n: Int, val indices: Array[Int], val values: Array[Double]) extends Vector {
166-
167-
override def size: Int = n
167+
class SparseVector(
168+
override val size: Int,
169+
val indices: Array[Int],
170+
val values: Array[Double]) extends Vector {
168171

169172
override def toString: String = {
170-
"(" + n + "," + indices.zip(values).mkString("[", "," ,"]") + ")"
173+
"(" + size + "," + indices.zip(values).mkString("[", "," ,"]") + ")"
171174
}
172175

173176
override def toArray: Array[Double] = {
174-
val data = new Array[Double](n)
177+
val data = new Array[Double](size)
175178
var i = 0
176179
val nnz = indices.length
177180
while (i < nnz) {
@@ -181,5 +184,5 @@ class SparseVector(val n: Int, val indices: Array[Int], val values: Array[Double
181184
data
182185
}
183186

184-
private[mllib] override def toBreeze: BV[Double] = new BSV[Double](indices, values, n)
187+
private[mllib] override def toBreeze: BV[Double] = new BSV[Double](indices, values, size)
185188
}

mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ import breeze.linalg.{axpy => brzAxpy}
2222
import org.apache.spark.mllib.linalg.{Vectors, Vector}
2323

2424
/**
25+
* <span class="badge" style="float: right; background-color: #44751E;">DEVELOPER API</span>
26+
*
2527
* Class used to compute the gradient for a loss function, given a single data point.
2628
*/
2729
abstract class Gradient extends Serializable {
@@ -51,6 +53,8 @@ abstract class Gradient extends Serializable {
5153
}
5254

5355
/**
56+
* <span class="badge" style="float: right; background-color: #44751E;">DEVELOPER API</span>
57+
*
5458
* Compute gradient and loss for a logistic loss function, as used in binary classification.
5559
* See also the documentation for the precise formulation.
5660
*/
@@ -92,6 +96,8 @@ class LogisticGradient extends Gradient {
9296
}
9397

9498
/**
99+
* <span class="badge" style="float: right; background-color: #44751E;">DEVELOPER API</span>
100+
*
95101
* Compute gradient and loss for a Least-squared loss function, as used in linear regression.
96102
* This is correct for the averaged least squares loss function (mean squared error)
97103
* L = 1/n ||A weights-y||^2
@@ -124,6 +130,8 @@ class LeastSquaresGradient extends Gradient {
124130
}
125131

126132
/**
133+
* <span class="badge" style="float: right; background-color: #44751E;">DEVELOPER API</span>
134+
*
127135
* Compute gradient and loss for a Hinge loss function, as used in SVM binary classification.
128136
* See also the documentation for the precise formulation.
129137
* NOTE: This assumes that the labels are {0,1}

mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,20 @@ package org.apache.spark.mllib.optimization
1919

2020
import scala.collection.mutable.ArrayBuffer
2121

22-
import breeze.linalg.{Vector => BV, DenseVector => BDV}
22+
import breeze.linalg.{DenseVector => BDV}
2323

2424
import org.apache.spark.Logging
2525
import org.apache.spark.rdd.RDD
2626
import org.apache.spark.mllib.linalg.{Vectors, Vector}
2727

2828
/**
29+
* <span class="badge" style="float: right; background-color: #44751E;">DEVELOPER API</span>
30+
*
2931
* Class used to solve an optimization problem using Gradient Descent.
3032
* @param gradient Gradient function to be used.
3133
* @param updater Updater to be used to update weights after every iteration.
3234
*/
33-
class GradientDescent(var gradient: Gradient, var updater: Updater)
35+
class GradientDescent(private var gradient: Gradient, private var updater: Updater)
3436
extends Optimizer with Logging
3537
{
3638
private var stepSize: Double = 1.0
@@ -107,7 +109,11 @@ class GradientDescent(var gradient: Gradient, var updater: Updater)
107109

108110
}
109111

110-
// Top-level method to run gradient descent.
112+
/**
113+
* <span class="badge" style="float: right; background-color: #44751E;">DEVELOPER API</span>
114+
*
115+
* Top-level method to run gradient descent.
116+
*/
111117
object GradientDescent extends Logging {
112118
/**
113119
* Run stochastic gradient descent (SGD) in parallel using mini batches.

mllib/src/main/scala/org/apache/spark/mllib/optimization/Optimizer.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,11 @@ import org.apache.spark.rdd.RDD
2121

2222
import org.apache.spark.mllib.linalg.Vector
2323

24+
/**
25+
* <span class="badge" style="float: right; background-color: #44751E;">DEVELOPER API</span>
26+
*
27+
* Trait for optimization problem solvers.
28+
*/
2429
trait Optimizer extends Serializable {
2530

2631
/**

mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ import breeze.linalg.{norm => brzNorm, axpy => brzAxpy, Vector => BV}
2424
import org.apache.spark.mllib.linalg.{Vectors, Vector}
2525

2626
/**
27+
* <span class="badge" style="float: right; background-color: #44751E;">DEVELOPER API</span>
28+
*
2729
* Class used to perform steps (weight update) using Gradient Descent methods.
2830
*
2931
* For general minimization problems, or for regularized problems of the form
@@ -59,6 +61,8 @@ abstract class Updater extends Serializable {
5961
}
6062

6163
/**
64+
* <span class="badge" style="float: right; background-color: #44751E;">DEVELOPER API</span>
65+
*
6266
* A simple updater for gradient descent *without* any regularization.
6367
* Uses a step-size decreasing with the square root of the number of iterations.
6468
*/
@@ -78,6 +82,8 @@ class SimpleUpdater extends Updater {
7882
}
7983

8084
/**
85+
* <span class="badge" style="float: right; background-color: #44751E;">DEVELOPER API</span>
86+
*
8187
* Updater for L1 regularized problems.
8288
* R(w) = ||w||_1
8389
* Uses a step-size decreasing with the square root of the number of iterations.
@@ -120,6 +126,8 @@ class L1Updater extends Updater {
120126
}
121127

122128
/**
129+
* <span class="badge" style="float: right; background-color: #44751E;">DEVELOPER API</span>
130+
*
123131
* Updater for L2 regularized problems.
124132
* R(w) = 1/2 ||w||^2
125133
* Uses a step-size decreasing with the square root of the number of iterations.

0 commit comments

Comments
 (0)