Skip to content
This repository was archived by the owner on Nov 15, 2024. It is now read-only.

Commit f584af2

Browse files
WeichenXu123 authored and MatthewRBruce committed
[SPARK-21818][ML][MLLIB] Fix MultivariateOnlineSummarizer.variance generating a negative result
Because of numerical error, `MultivariateOnlineSummarizer.variance` can produce a negative variance. **This is a serious bug: many algorithms in MLlib use the stddev computed from** `sqrt(variance)`**, so a negative variance yields NaN and crashes the whole algorithm.** We can reproduce this bug using the following code: ``` val summarizer1 = (new MultivariateOnlineSummarizer) .add(Vectors.dense(3.0), 0.7) val summarizer2 = (new MultivariateOnlineSummarizer) .add(Vectors.dense(3.0), 0.4) val summarizer3 = (new MultivariateOnlineSummarizer) .add(Vectors.dense(3.0), 0.5) val summarizer4 = (new MultivariateOnlineSummarizer) .add(Vectors.dense(3.0), 0.4) val summarizer = summarizer1 .merge(summarizer2) .merge(summarizer3) .merge(summarizer4) println(summarizer.variance(0)) ``` This PR fixes the bug in `mllib.stat.MultivariateOnlineSummarizer.variance`, in `ml.stat.SummarizerBuffer.variance`, and in several places in `WeightedLeastSquares`. Test cases are added. Author: WeichenXu <WeichenXu123@outlook.com> Closes apache#19029 from WeichenXu123/fix_summarizer_var_bug. (cherry picked from commit 0456b40) Signed-off-by: Sean Owen <sowen@cloudera.com>
1 parent f8deaf0 commit f584af2

3 files changed

Lines changed: 30 additions & 5 deletions

File tree

mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -440,7 +440,11 @@ private[ml] object WeightedLeastSquares {
440440
/**
441441
* Weighted population standard deviation of labels.
442442
*/
443-
def bStd: Double = math.sqrt(bbSum / wSum - bBar * bBar)
443+
def bStd: Double = {
444+
// We prevent variance from negative value caused by numerical error.
445+
val variance = math.max(bbSum / wSum - bBar * bBar, 0.0)
446+
math.sqrt(variance)
447+
}
444448

445449
/**
446450
* Weighted mean of (label * features).
@@ -471,7 +475,8 @@ private[ml] object WeightedLeastSquares {
471475
while (i < triK) {
472476
val l = j - 2
473477
val aw = aSum(l) / wSum
474-
std(l) = math.sqrt(aaValues(i) / wSum - aw * aw)
478+
// We prevent variance from negative value caused by numerical error.
479+
std(l) = math.sqrt(math.max(aaValues(i) / wSum - aw * aw, 0.0))
475480
i += j
476481
j += 1
477482
}
@@ -489,7 +494,8 @@ private[ml] object WeightedLeastSquares {
489494
while (i < triK) {
490495
val l = j - 2
491496
val aw = aSum(l) / wSum
492-
variance(l) = aaValues(i) / wSum - aw * aw
497+
// We prevent variance from negative value caused by numerical error.
498+
variance(l) = math.max(aaValues(i) / wSum - aw * aw, 0.0)
493499
i += j
494500
j += 1
495501
}

mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -213,8 +213,9 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
213213
var i = 0
214214
val len = currM2n.length
215215
while (i < len) {
216-
realVariance(i) = (currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) *
217-
(totalWeightSum - weightSum(i)) / totalWeightSum) / denominator
216+
// We prevent variance from negative value caused by numerical error.
217+
realVariance(i) = math.max((currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) *
218+
(totalWeightSum - weightSum(i)) / totalWeightSum) / denominator, 0.0)
218219
i += 1
219220
}
220221
}

mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,4 +270,22 @@ class MultivariateOnlineSummarizerSuite extends SparkFunSuite {
270270
assert(summarizer3.max ~== Vectors.dense(10.0, 0.0) absTol 1e-14)
271271
assert(summarizer3.min ~== Vectors.dense(0.0, -10.0) absTol 1e-14)
272272
}
273+
274+
test ("test zero variance (SPARK-21818)") {
275+
val summarizer1 = (new MultivariateOnlineSummarizer)
276+
.add(Vectors.dense(3.0), 0.7)
277+
val summarizer2 = (new MultivariateOnlineSummarizer)
278+
.add(Vectors.dense(3.0), 0.4)
279+
val summarizer3 = (new MultivariateOnlineSummarizer)
280+
.add(Vectors.dense(3.0), 0.5)
281+
val summarizer4 = (new MultivariateOnlineSummarizer)
282+
.add(Vectors.dense(3.0), 0.4)
283+
284+
val summarizer = summarizer1
285+
.merge(summarizer2)
286+
.merge(summarizer3)
287+
.merge(summarizer4)
288+
289+
assert(summarizer.variance(0) >= 0.0)
290+
}
273291
}

0 commit comments

Comments
 (0)