Skip to content

Commit cff73e0

Browse files
committed
Replaced accumulators with RDD.aggregate
1 parent 20ebca1 commit cff73e0

1 file changed

Lines changed: 60 additions & 65 deletions

File tree

mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModelEM.scala

Lines changed: 60 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,50 @@ class GaussianMixtureModelEM private (
5252
private type DenseDoubleVector = BreezeVector[Double]
5353
private type DenseDoubleMatrix = BreezeMatrix[Double]
5454

55+
private type ExpectationSum = (
56+
Array[Double], // log-likelihood in index 0
57+
Array[Double], // array of weights
58+
Array[DenseDoubleVector], // array of means
59+
Array[DenseDoubleMatrix]) // array of cov matrices
60+
61+
// create a zero'd ExpectationSum instance
62+
private def zeroExpectationSum(k: Int, d: Int): ExpectationSum = {
63+
(Array(0.0),
64+
new Array[Double](k),
65+
(0 until k).map(_ => BreezeVector.zeros[Double](d)).toArray,
66+
(0 until k).map(_ => BreezeMatrix.zeros[Double](d,d)).toArray)
67+
}
68+
69+
// add two ExpectationSum objects (allowed to use modify m1)
70+
// (U, U) => U for aggregation
71+
private def addExpectationSums(m1: ExpectationSum, m2: ExpectationSum): ExpectationSum = {
72+
m1._1(0) += m2._1(0)
73+
for (i <- 0 until m1._2.length) {
74+
m1._2(i) += m2._2(i)
75+
m1._3(i) += m2._3(i)
76+
m1._4(i) += m2._4(i)
77+
}
78+
m1
79+
}
80+
81+
// compute cluster contributions for each input point
82+
// (U, T) => U for aggregation
83+
private def computeExpectation(weights: Array[Double], dists: Array[MultivariateGaussian])
84+
(model: ExpectationSum, x: DenseDoubleVector): ExpectationSum = {
85+
val k = model._2.length
86+
val p = (0 until k).map(i => eps + weights(i) * dists(i).pdf(x)).toArray
87+
val pSum = p.sum
88+
model._1(0) += math.log(pSum)
89+
val xxt = x * new Transpose(x)
90+
for (i <- 0 until k) {
91+
p(i) /= pSum
92+
model._2(i) += p(i)
93+
model._3(i) += x * p(i)
94+
model._4(i) += xxt * p(i)
95+
}
96+
model
97+
}
98+
5599
// number of samples per cluster to use when initializing Gaussians
56100
private val nSamples = 5
57101

@@ -115,7 +159,7 @@ class GaussianMixtureModelEM private (
115159
val ctx = data.sparkContext
116160

117161
// we will operate on the data as breeze data
118-
val breezeData = data.map( u => u.toBreeze.toDenseVector ).cache()
162+
val breezeData = data.map(u => u.toBreeze.toDenseVector).cache()
119163

120164
// Get length of the input vectors
121165
val d = breezeData.first.length
@@ -143,55 +187,28 @@ class GaussianMixtureModelEM private (
143187
}
144188
}
145189

146-
val accW = new Array[Accumulator[Double]](k)
147-
val accMu = new Array[Accumulator[DenseDoubleVector]](k)
148-
val accSigma = new Array[Accumulator[DenseDoubleMatrix]](k)
149-
150190
var llh = Double.MinValue // current log-likelihood
151191
var llhp = 0.0 // previous log-likelihood
152192

153193
var iter = 0
154194
do {
155-
// reset accumulators
156-
for (i <- 0 until k) {
157-
accW(i) = ctx.accumulator(0.0)
158-
accMu(i) = ctx.accumulator(
159-
BreezeVector.zeros[Double](d))(DenseDoubleVectorAccumulatorParam)
160-
accSigma(i) = ctx.accumulator(
161-
BreezeMatrix.zeros[Double](d,d))(DenseDoubleMatrixAccumulatorParam)
162-
}
195+
// pivot gaussians into weight and distribution arrays
196+
val weights = (0 until k).map(i => gaussians(i)._1).toArray
197+
val dists = (0 until k).map{ i =>
198+
new MultivariateGaussian(gaussians(i)._2, gaussians(i)._3)
199+
}.toArray
163200

164-
val logLikelihood = ctx.accumulator(0.0)
165-
166-
// broadcast the current weights and distributions to all nodes
167-
val dists = ctx.broadcast{
168-
(0 until k).map(i => new MultivariateGaussian(gaussians(i)._2, gaussians(i)._3)).toArray
169-
}
170-
val weights = ctx.broadcast((0 until k).map(i => gaussians(i)._1).toArray)
201+
// create and broadcast curried cluster contribution function
202+
val compute = ctx.broadcast(computeExpectation(weights, dists)_)
171203

172-
// calculate partial assignments for each sample in the data
173-
// (often referred to as the "E" step in literature)
174-
breezeData.foreach{ x =>
175-
val p = (0 until k).map(i => eps + weights.value(i) * dists.value(i).pdf(x)).toArray
176-
177-
val pSum = p.sum
178-
179-
logLikelihood += math.log(pSum)
180-
181-
// accumulate weighted sums
182-
val xxt = x * new Transpose(x)
183-
for (i <- 0 until k) {
184-
p(i) /= pSum
185-
accW(i) += p(i)
186-
accMu(i) += x * p(i)
187-
accSigma(i) += xxt * p(i)
188-
}
189-
}
204+
// aggregate the cluster contribution for all sample points
205+
val sums = breezeData.aggregate(zeroExpectationSum(k, d))(compute.value, addExpectationSums)
190206

191-
// Collect the computed sums
192-
val W = (0 until k).map(i => accW(i).value).toArray
193-
val MU = (0 until k).map(i => accMu(i).value).toArray
194-
val SIGMA = (0 until k).map(i => accSigma(i).value).toArray
207+
// Assignments to make the code more readable
208+
val logLikelihood = sums._1(0)
209+
val W = sums._2
210+
val MU = sums._3
211+
val SIGMA = sums._4
195212

196213
// Create new distributions based on the partial assignments
197214
// (often referred to as the "M" step in literature)
@@ -203,7 +220,7 @@ class GaussianMixtureModelEM private (
203220
}.toArray
204221

205222
llhp = llh // current becomes previous
206-
llh = logLikelihood.value // this is the freshly computed log-likelihood
223+
llh = logLikelihood // this is the freshly computed log-likelihood
207224
iter += 1
208225
} while(iter < maxIterations && Math.abs(llh-llhp) > convergenceTol)
209226

@@ -264,26 +281,4 @@ class GaussianMixtureModelEM private (
264281
}
265282
p
266283
}
267-
268-
/** AccumulatorParam for Dense Breeze Vectors */
269-
private object DenseDoubleVectorAccumulatorParam extends AccumulatorParam[DenseDoubleVector] {
270-
def zero(initialVector: DenseDoubleVector): DenseDoubleVector = {
271-
BreezeVector.zeros[Double](initialVector.length)
272-
}
273-
274-
def addInPlace(a: DenseDoubleVector, b: DenseDoubleVector): DenseDoubleVector = {
275-
a += b
276-
}
277-
}
278-
279-
/** AccumulatorParam for Dense Breeze Matrices */
280-
private object DenseDoubleMatrixAccumulatorParam extends AccumulatorParam[DenseDoubleMatrix] {
281-
def zero(initialMatrix: DenseDoubleMatrix): DenseDoubleMatrix = {
282-
BreezeMatrix.zeros[Double](initialMatrix.rows, initialMatrix.cols)
283-
}
284-
285-
def addInPlace(a: DenseDoubleMatrix, b: DenseDoubleMatrix): DenseDoubleMatrix = {
286-
a += b
287-
}
288-
}
289284
}

0 commit comments

Comments
 (0)