@@ -20,6 +20,7 @@ import breeze.linalg.{Vector => BV}
2020
2121import org .apache .spark .mllib .linalg .{Vector , Vectors }
2222import org .apache .spark .rdd .RDD
23+ import breeze .linalg .axpy
2324
2425case class VectorRDDStatisticalSummary (
2526 mean : Vector ,
@@ -58,17 +59,22 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
5859 BV .fill(size){Double .MaxValue }))(
5960 seqOp = (c, v) => (c, v) match {
6061 case ((prevMean, prevM2n, cnt, nnzVec, maxVec, minVec), currData) =>
61- val currMean = ((prevMean :* cnt) + currData) :/ (cnt + 1.0 )
62- val nonZeroCnt = Vectors
63- .sparse(size, currData.activeKeysIterator.toSeq.map(x => (x, 1.0 ))).toBreeze
62+ val currMean = prevMean :* (cnt / (cnt + 1.0 ))
63+ axpy(1.0 / (cnt+ 1.0 ), currData, currMean)
64+ axpy(- 1.0 , currData, prevMean)
65+ prevMean :*= (currMean - currData)
66+ axpy(1.0 , prevMean, prevM2n)
67+ axpy(1.0 ,
68+ Vectors .sparse(size, currData.activeKeysIterator.toSeq.map(x => (x, 1.0 ))).toBreeze,
69+ nnzVec)
6470 currData.activeIterator.foreach { case (id, value) =>
6571 if (maxVec(id) < value) maxVec(id) = value
6672 if (minVec(id) > value) minVec(id) = value
6773 }
6874 (currMean,
69- prevM2n + ((currData - prevMean) :* (currData - currMean)) ,
75+ prevM2n,
7076 cnt + 1.0 ,
71- nnzVec + nonZeroCnt ,
77+ nnzVec,
7278 maxVec,
7379 minVec)
7480 },
@@ -77,23 +83,30 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
7783 (lhsMean, lhsM2n, lhsCnt, lhsNNZ, lhsMax, lhsMin),
7884 (rhsMean, rhsM2n, rhsCnt, rhsNNZ, rhsMax, rhsMin)) =>
7985 val totalCnt = lhsCnt + rhsCnt
80- val totalMean = (lhsMean :* lhsCnt) + (rhsMean :* rhsCnt) :/ totalCnt
8186 val deltaMean = rhsMean - lhsMean
82- val totalM2n =
83- lhsM2n + rhsM2n + (((deltaMean :* deltaMean) :* (lhsCnt * rhsCnt)) :/ totalCnt)
87+ lhsMean :*= (lhsCnt / totalCnt)
88+ axpy(rhsCnt/ totalCnt, rhsMean, lhsMean)
89+ val totalMean = lhsMean
90+ deltaMean :*= deltaMean
91+ axpy(lhsCnt* rhsCnt/ totalCnt, deltaMean, lhsM2n)
92+ axpy(1.0 , rhsM2n, lhsM2n)
93+ val totalM2n = lhsM2n
8494 rhsMax.activeIterator.foreach { case (id, value) =>
8595 if (lhsMax(id) < value) lhsMax(id) = value
8696 }
8797 rhsMin.activeIterator.foreach { case (id, value) =>
8898 if (lhsMin(id) > value) lhsMin(id) = value
8999 }
90- (totalMean, totalM2n, totalCnt, lhsNNZ + rhsNNZ, lhsMax, lhsMin)
100+ axpy(1.0 , rhsNNZ, lhsNNZ)
101+ (totalMean, totalM2n, totalCnt, lhsNNZ, lhsMax, lhsMin)
91102 }
92103 )
93104
105+ results._2 :/= results._3
106+
94107 VectorRDDStatisticalSummary (
95108 Vectors .fromBreeze(results._1),
96- Vectors .fromBreeze(results._2 :/ results._3 ),
109+ Vectors .fromBreeze(results._2),
97110 results._3.toLong,
98111 Vectors .fromBreeze(results._4),
99112 Vectors .fromBreeze(results._5),
0 commit comments