Skip to content
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -439,8 +439,9 @@ private[ml] object WeightedLeastSquares {

/**
* Weighted population standard deviation of labels.
* The variance is clamped at zero to guard against negative values caused by numerical error.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not so against this, but this is really an implementation detail and not relevant to the caller. It's a value that is by definition nonnegative.

*/
def bStd: Double = math.sqrt(bbSum / wSum - bBar * bBar)
def bStd: Double = math.sqrt(math.max(bbSum / wSum - bBar * bBar, 0.0))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a comment here and below to clarify that we are preventing a negative value caused by numerical error.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are a couple more places where variance is computed in this file -- I think they need this too?


/**
* Weighted mean of (label * features).
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -436,8 +436,9 @@ private[ml] object SummaryBuilderImpl extends Logging {
var i = 0
val len = currM2n.length
while (i < len) {
realVariance(i) = (currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) *
(totalWeightSum - weightSum(i)) / totalWeightSum) / denominator
// Clamp at zero: numerical error can otherwise make the computed variance slightly negative.
realVariance(i) = math.max((currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) *
(totalWeightSum - weightSum(i)) / totalWeightSum) / denominator, 0.0)
i += 1
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,9 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
var i = 0
val len = currM2n.length
while (i < len) {
realVariance(i) = (currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) *
(totalWeightSum - weightSum(i)) / totalWeightSum) / denominator
// Clamp at zero: numerical error can otherwise make the computed variance slightly negative.
realVariance(i) = math.max((currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) *
(totalWeightSum - weightSum(i)) / totalWeightSum) / denominator, 0.0)
i += 1
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,24 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(summarizer.count === 6)
}

test("summarizer buffer zero variance test (SPARK-21818)") {
  // Four single-sample buffers that all hold the same value (3.0) and differ only in
  // weight, so the mathematically exact weighted variance of their union is zero.
  val sampleWeights = Seq(0.7, 0.4, 0.5, 0.4)
  val partials = sampleWeights.map { w =>
    new SummarizerBuffer().add(Vectors.dense(3.0), w)
  }

  // Merge strictly left to right: ((b1 merge b2) merge b3) merge b4.
  val merged = partials.reduceLeft((acc, next) => acc.merge(next))

  // Floating-point rounding during the merges must never surface as a negative variance.
  assert(merged.variance(0) >= 0.0)
}

test("summarizer buffer merging summarizer with empty summarizer") {
// If one of two is non-empty, this should return the non-empty summarizer.
// If both of them are empty, then just return the empty summarizer.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -270,4 +270,22 @@ class MultivariateOnlineSummarizerSuite extends SparkFunSuite {
assert(summarizer3.max ~== Vectors.dense(10.0, 0.0) absTol 1e-14)
assert(summarizer3.min ~== Vectors.dense(0.0, -10.0) absTol 1e-14)
}

test ("test zero variance (SPARK-21818)") {
  // Four single-sample summarizers that all hold the same value (3.0) and differ only in
  // weight, so the mathematically exact weighted variance of their union is zero.
  val sampleWeights = Seq(0.7, 0.4, 0.5, 0.4)
  val partials = sampleWeights.map { w =>
    (new MultivariateOnlineSummarizer).add(Vectors.dense(3.0), w)
  }

  // Merge strictly left to right: ((s1 merge s2) merge s3) merge s4.
  val merged = partials.reduceLeft((acc, next) => acc.merge(next))

  // Floating-point rounding during the merges must never surface as a negative variance.
  assert(merged.variance(0) >= 0.0)
}
}