@@ -52,6 +52,50 @@ class GaussianMixtureModelEM private (
5252 private type DenseDoubleVector = BreezeVector [Double ]
5353 private type DenseDoubleMatrix = BreezeMatrix [Double ]
5454
55+ private type ExpectationSum = ( // running sums carried through the aggregation over all points
56+ Array [Double ], // single-element array; index 0 holds the summed log-likelihood
57+ Array [Double ], // per-cluster sums of the normalized cluster probabilities p(i)
58+ Array [DenseDoubleVector ], // per-cluster sums of x * p(i) (unnormalized; not yet means)
59+ Array [DenseDoubleMatrix ]) // per-cluster sums of (x * x^T) * p(i) (unnormalized; not yet covariances)
60+
61+ // Create an all-zero ExpectationSum for k clusters of dimension d (used as the zero value of the aggregation)
62+ private def zeroExpectationSum (k : Int , d : Int ): ExpectationSum = {
63+ (Array (0.0 ), // log-likelihood accumulator, stored at index 0
64+ new Array [Double ](k), // k weight sums; a new Array[Double] is filled with 0.0
65+ (0 until k).map(_ => BreezeVector .zeros[Double ](d)).toArray, // k separate length-d zero vectors
66+ (0 until k).map(_ => BreezeMatrix .zeros[Double ](d,d)).toArray) // k separate d x d zero matrices
67+ }
68+
69+ // Add two ExpectationSum objects component-wise; m1 is modified in place and returned, m2 is only read
70+ // (U, U) => U for aggregation
71+ private def addExpectationSums (m1 : ExpectationSum , m2 : ExpectationSum ): ExpectationSum = {
72+ m1._1(0 ) += m2._1(0 ) // combine the log-likelihood totals
73+ for (i <- 0 until m1._2.length) { // one pass per cluster; assumes m2 has the same number of clusters — TODO confirm
74+ m1._2(i) += m2._2(i) // weight sums
75+ m1._3(i) += m2._3(i) // weighted point sums (in-place vector add)
76+ m1._4(i) += m2._4(i) // weighted outer-product sums (in-place matrix add)
77+ }
78+ m1 // hand back the mutated left operand as the merged result
79+ }
80+
81+ // Fold one input point x into model: adds its log-likelihood and per-cluster contributions; model is mutated in place and returned
82+ // (U, T) => U for aggregation
83+ private def computeExpectation (weights : Array [Double ], dists : Array [MultivariateGaussian ])
84+ (model : ExpectationSum , x : DenseDoubleVector ): ExpectationSum = {
85+ val k = model._2.length // number of clusters, taken from the length of the weight-sum array
86+ val p = (0 until k).map(i => eps + weights(i) * dists(i).pdf(x)).toArray // weighted density of x under each cluster, offset by eps
87+ val pSum = p.sum // total over all clusters, used for the log-likelihood and for normalizing
88+ model._1(0 ) += math.log(pSum) // this point's contribution to the log-likelihood
89+ val xxt = x * new Transpose (x) // x * x^T, computed once and reused for every cluster
90+ for (i <- 0 until k) {
91+ p(i) /= pSum // normalize so that the p(i) sum to 1 across clusters
92+ model._2(i) += p(i) // accumulate the weight sum
93+ model._3(i) += x * p(i) // accumulate the weighted point
94+ model._4(i) += xxt * p(i) // accumulate the weighted x * x^T
95+ }
96+ model // return the same (mutated) accumulator instance
97+ }
98+
5599 // number of samples per cluster to use when initializing Gaussians
56100 private val nSamples = 5
57101
@@ -115,7 +159,7 @@ class GaussianMixtureModelEM private (
115159 val ctx = data.sparkContext
116160
117161 // we will operate on the data as breeze data
118- val breezeData = data.map( u => u.toBreeze.toDenseVector ).cache()
162+ val breezeData = data.map(u => u.toBreeze.toDenseVector).cache()
119163
120164 // Get length of the input vectors
121165 val d = breezeData.first.length
@@ -143,55 +187,28 @@ class GaussianMixtureModelEM private (
143187 }
144188 }
145189
146- val accW = new Array [Accumulator [Double ]](k)
147- val accMu = new Array [Accumulator [DenseDoubleVector ]](k)
148- val accSigma = new Array [Accumulator [DenseDoubleMatrix ]](k)
149-
150190 var llh = Double .MinValue // current log-likelihood
151191 var llhp = 0.0 // previous log-likelihood
152192
153193 var iter = 0
154194 do {
155- // reset accumulators
156- for (i <- 0 until k) {
157- accW(i) = ctx.accumulator(0.0 )
158- accMu(i) = ctx.accumulator(
159- BreezeVector .zeros[Double ](d))(DenseDoubleVectorAccumulatorParam )
160- accSigma(i) = ctx.accumulator(
161- BreezeMatrix .zeros[Double ](d,d))(DenseDoubleMatrixAccumulatorParam )
162- }
195+ // pivot gaussians into weight and distribution arrays
196+ val weights = (0 until k).map(i => gaussians(i)._1).toArray
197+ val dists = (0 until k).map{ i =>
198+ new MultivariateGaussian (gaussians(i)._2, gaussians(i)._3)
199+ }.toArray
163200
164- val logLikelihood = ctx.accumulator(0.0 )
165-
166- // broadcast the current weights and distributions to all nodes
167- val dists = ctx.broadcast{
168- (0 until k).map(i => new MultivariateGaussian (gaussians(i)._2, gaussians(i)._3)).toArray
169- }
170- val weights = ctx.broadcast((0 until k).map(i => gaussians(i)._1).toArray)
201+ // create and broadcast curried cluster contribution function
202+ val compute = ctx.broadcast(computeExpectation(weights, dists)_)
171203
172- // calculate partial assignments for each sample in the data
173- // (often referred to as the "E" step in literature)
174- breezeData.foreach{ x =>
175- val p = (0 until k).map(i => eps + weights.value(i) * dists.value(i).pdf(x)).toArray
176-
177- val pSum = p.sum
178-
179- logLikelihood += math.log(pSum)
180-
181- // accumulate weighted sums
182- val xxt = x * new Transpose (x)
183- for (i <- 0 until k) {
184- p(i) /= pSum
185- accW(i) += p(i)
186- accMu(i) += x * p(i)
187- accSigma(i) += xxt * p(i)
188- }
189- }
204+ // aggregate the cluster contribution for all sample points
205+ val sums = breezeData.aggregate(zeroExpectationSum(k, d))(compute.value, addExpectationSums)
190206
191- // Collect the computed sums
192- val W = (0 until k).map(i => accW(i).value).toArray
193- val MU = (0 until k).map(i => accMu(i).value).toArray
194- val SIGMA = (0 until k).map(i => accSigma(i).value).toArray
207+ // Assignments to make the code more readable
208+ val logLikelihood = sums._1(0 )
209+ val W = sums._2
210+ val MU = sums._3
211+ val SIGMA = sums._4
195212
196213 // Create new distributions based on the partial assignments
197214 // (often referred to as the "M" step in literature)
@@ -203,7 +220,7 @@ class GaussianMixtureModelEM private (
203220 }.toArray
204221
205222 llhp = llh // current becomes previous
206- llh = logLikelihood.value // this is the freshly computed log-likelihood
223+ llh = logLikelihood // this is the freshly computed log-likelihood
207224 iter += 1
208225 } while (iter < maxIterations && Math .abs(llh- llhp) > convergenceTol)
209226
@@ -264,26 +281,4 @@ class GaussianMixtureModelEM private (
264281 }
265282 p
266283 }
267-
268- /** AccumulatorParam for Dense Breeze Vectors */
269- private object DenseDoubleVectorAccumulatorParam extends AccumulatorParam [DenseDoubleVector ] {
270- def zero (initialVector : DenseDoubleVector ): DenseDoubleVector = {
271- BreezeVector .zeros[Double ](initialVector.length)
272- }
273-
274- def addInPlace (a : DenseDoubleVector , b : DenseDoubleVector ): DenseDoubleVector = {
275- a += b
276- }
277- }
278-
279- /** AccumulatorParam for Dense Breeze Matrices */
280- private object DenseDoubleMatrixAccumulatorParam extends AccumulatorParam [DenseDoubleMatrix ] {
281- def zero (initialMatrix : DenseDoubleMatrix ): DenseDoubleMatrix = {
282- BreezeMatrix .zeros[Double ](initialMatrix.rows, initialMatrix.cols)
283- }
284-
285- def addInPlace (a : DenseDoubleMatrix , b : DenseDoubleMatrix ): DenseDoubleMatrix = {
286- a += b
287- }
288- }
289284}
0 commit comments