20 changes: 19 additions & 1 deletion core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala
@@ -21,5 +21,23 @@ package org.apache.spark.partial
  * A Double value with error bars and associated confidence.
  */
 class BoundedDouble(val mean: Double, val confidence: Double, val low: Double, val high: Double) {
-  override def toString(): String = "[%.3f, %.3f]".format(low, high)
+  override def toString(): String =
Member (Sean Owen):
OK, I think this is all good, except I think the toString should be left alone. I forgot to mention this. Not that I really expect anyone to depend on the format, but let's leave it since it's a public class.

Contributor Author:
I definitely can put it back, but the previous toString was just weird - it only printed the bounds. Anyway, I'll update this in a sec (to go back). Let me know if you change your mind.

Contributor Author:
Done.

"BoundedDouble(%.3f, %.3f, %.3f, %.3f)".format(mean, confidence, low, high)

override def hashCode: Int =
this.mean.hashCode ^ this.confidence.hashCode ^ this.low.hashCode ^ this.high.hashCode

/**
* Note that consistent with Double, any NaN value will make equality false
*/
override def equals(that: Any): Boolean =
that match {
case that: BoundedDouble => {
this.mean == that.mean &&
this.confidence == that.confidence &&
this.low == that.low &&
this.high == that.high
}
case _ => false
}
}
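
As a quick illustration of the semantics the new equals and hashCode give the class, here is a minimal sketch (the standalone object and the numeric values are illustrative, not part of the PR):

import org.apache.spark.partial.BoundedDouble

object BoundedDoubleEqualitySketch {
  def main(args: Array[String]): Unit = {
    val a = new BoundedDouble(1.0, 0.95, 0.5, 1.5)
    val b = new BoundedDouble(1.0, 0.95, 0.5, 1.5)
    // field-wise equality, with hashCode consistent with equals
    assert(a == b && a.hashCode == b.hashCode)

    // consistent with Double, NaN == NaN is false, so an instance holding
    // a NaN field is not equal even to itself
    val nan = new BoundedDouble(Double.NaN, 0.95, 0.5, 1.5)
    assert(nan != nan)
  }
}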
36 changes: 23 additions & 13 deletions core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala
@@ -29,8 +29,9 @@ import org.apache.spark.util.StatCounter
 private[spark] class SumEvaluator(totalOutputs: Int, confidence: Double)
   extends ApproximateEvaluator[StatCounter, BoundedDouble] {
 
+  // modified in merge
   var outputsMerged = 0
-  var counter = new StatCounter
+  val counter = new StatCounter
 
   override def merge(outputId: Int, taskResult: StatCounter) {
     outputsMerged += 1
@@ -40,30 +41,39 @@ private[spark] class SumEvaluator(totalOutputs: Int, confidence: Double)
   override def currentResult(): BoundedDouble = {
     if (outputsMerged == totalOutputs) {
       new BoundedDouble(counter.sum, 1.0, counter.sum, counter.sum)
-    } else if (outputsMerged == 0) {
+    } else if (outputsMerged == 0 || counter.count == 0) {
       new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity)
     } else {
       val p = outputsMerged.toDouble / totalOutputs
       val meanEstimate = counter.mean
-      val meanVar = counter.sampleVariance / counter.count
       val countEstimate = (counter.count + 1 - p) / p
-      val countVar = (counter.count + 1) * (1 - p) / (p * p)
       val sumEstimate = meanEstimate * countEstimate
-      val sumVar = (meanEstimate * meanEstimate * countVar) +
-        (countEstimate * countEstimate * meanVar) +
-        (meanVar * countVar)
-      val sumStdev = math.sqrt(sumVar)
-      val confFactor = {
-        if (counter.count > 100) {
+
+      val meanVar = counter.sampleVariance / counter.count
+
+      // branch at this point because counter.count == 1 implies counter.sampleVariance == NaN
+      // and we don't want to ever return a bound of NaN
+      if (meanVar.isNaN || counter.count == 1) {
+        new BoundedDouble(sumEstimate, confidence, Double.NegativeInfinity, Double.PositiveInfinity)
+      } else {
+        val countVar = (counter.count + 1) * (1 - p) / (p * p)
+        val sumVar = (meanEstimate * meanEstimate * countVar) +
+          (countEstimate * countEstimate * meanVar) +
+          (meanVar * countVar)
+        val sumStdev = math.sqrt(sumVar)
+        val confFactor = if (counter.count > 100) {
           new NormalDistribution().inverseCumulativeProbability(1 - (1 - confidence) / 2)
         } else {
+          // note that if this goes to 0, TDistribution will throw an exception.
+          // Hence special casing 1 above.
           val degreesOfFreedom = (counter.count - 1).toInt
           new TDistribution(degreesOfFreedom).inverseCumulativeProbability(1 - (1 - confidence) / 2)
         }
-      }
-      val low = sumEstimate - confFactor * sumStdev
-      val high = sumEstimate + confFactor * sumStdev
-      new BoundedDouble(sumEstimate, confidence, low, high)
+
+        val low = sumEstimate - confFactor * sumStdev
+        val high = sumEstimate + confFactor * sumStdev
+        new BoundedDouble(sumEstimate, confidence, low, high)
+      }
     }
   }
 }
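
For context on the maths this change preserves: sumVar is the variance of a product of (assumed independent) estimates, Var(M * C) = E[M]^2 Var(C) + E[C]^2 Var(M) + Var(M) Var(C), and confFactor is a two-sided quantile. A minimal, self-contained sketch of the quantile logic, using the same commons-math3 classes as the diff (all numeric values here are illustrative):

import org.apache.commons.math3.distribution.{NormalDistribution, TDistribution}

object ConfFactorSketch {
  def main(args: Array[String]): Unit = {
    val confidence = 0.95
    val count = 50L          // illustrative sample count; > 1 so t is defined
    val sumEstimate = 78.0   // illustrative point estimate
    val sumStdev = 45.0      // illustrative standard deviation

    // two-sided interval: (1 - confidence) / 2 probability in each tail
    val q = 1 - (1 - confidence) / 2
    val confFactor = if (count > 100) {
      // large samples: normal approximation
      new NormalDistribution().inverseCumulativeProbability(q)
    } else {
      // small samples: Student's t; needs count >= 2, hence the
      // count == 1 special case in the code above
      new TDistribution((count - 1).toDouble).inverseCumulativeProbability(q)
    }

    println(s"low=${sumEstimate - confFactor * sumStdev}, high=${sumEstimate + confFactor * sumStdev}")
  }
}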
107 changes: 107 additions & 0 deletions core/src/test/scala/org/apache/spark/partial/SumEvaluatorSuite.scala
@@ -0,0 +1,107 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.partial

import org.apache.spark._
import org.apache.spark.util.StatCounter

class SumEvaluatorSuite extends SparkFunSuite with SharedSparkContext {

test("correct handling of count 1") {
Member:
While you're here, throw in a basic test for count == 0 too, and ideally some normal-path case for completeness.

Contributor Author:
Done.


// setup
val counter = new StatCounter(List(2.0))
// totalOutputs of 10 because it's larger than 1,
// and confidence of 0.95 because that's the default
val evaluator = new SumEvaluator(10, 0.95)
// arbitrarily assign id 1
evaluator.merge(1, counter)

// execute
val res = evaluator.currentResult()
// 38.0 - 7.1E-15 because that's how the maths shakes out:
// 2.0 * ((1 + 1 - 0.1) / 0.1) lands one ulp below 38.0 (see the sketch after this test)
val targetMean = 38.0 - 7.1E-15

// Sanity check that equality works on BoundedDouble
assert(new BoundedDouble(2.0, 0.95, 1.1, 1.2) == new BoundedDouble(2.0, 0.95, 1.1, 1.2))

// actual test
assert(res ==
new BoundedDouble(targetMean, 0.950, Double.NegativeInfinity, Double.PositiveInfinity))
}
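
The targetMean above falls straight out of the estimator's formulas: with 1 of 10 outputs merged, p = 0.1, and the count estimate scales the single observed value 2.0 up to a sum of roughly 38. A sketch of that arithmetic (plain Scala, mirroring the formulas in SumEvaluator; the standalone object is illustrative):

object Count1Arithmetic {
  def main(args: Array[String]): Unit = {
    val p = 1.0 / 10                          // outputsMerged.toDouble / totalOutputs
    val meanEstimate = 2.0                    // mean of the single value 2.0
    val count = 1L                            // counter.count
    val countEstimate = (count + 1 - p) / p   // (2 - 0.1) / 0.1 = 19, up to rounding
    val sumEstimate = meanEstimate * countEstimate
    // floating-point rounding lands one ulp below 38.0, i.e. 38.0 - 7.1E-15
    println(sumEstimate)
  }
}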

test("correct handling of count 0") {

// setup
val counter = new StatCounter(List())
// totalOutputs of 10 because it's larger than 0,
// and confidence of 0.95 because that's the default
val evaluator = new SumEvaluator(10, 0.95)
// arbitrarily assign id 1
evaluator.merge(1, counter)

// execute
val res = evaluator.currentResult()
// assert
assert(res == new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity))
}

test("correct handling of NaN") {

// setup
val counter = new StatCounter(List(1, Double.NaN, 2))
// totalOutputs of 10 because it's larger than 0,
// and confidence of 0.95 because that's the default
val evaluator = new SumEvaluator(10, 0.95)
// arbitrarily assign id 1
evaluator.merge(1, counter)

// execute
val res = evaluator.currentResult()
// assert - note semantics of == in face of NaN
assert(res.mean.isNaN)
assert(res.confidence == 0.95)
assert(res.low == Double.NegativeInfinity)
assert(res.high == Double.PositiveInfinity)
}
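
Why those expectations hold: StatCounter's running statistics are poisoned by a single NaN, so SumEvaluator's meanVar.isNaN branch fires and returns infinite bounds with the requested confidence. A small sketch of the propagation (the printed values noted in comments are what I'd expect, not PR output):

import org.apache.spark.util.StatCounter

object NaNPropagationSketch {
  def main(args: Array[String]): Unit = {
    val c = new StatCounter(List(1.0, Double.NaN, 2.0))
    // NaN flows through the running mean and variance updates
    println(c.mean)            // NaN
    println(c.sampleVariance)  // NaN
    // hence sumEstimate is NaN too, and the evaluator falls back to
    // (-Infinity, +Infinity) bounds rather than NaN bounds
  }
}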

test("correct handling of > 1 values") {

// setup
val counter = new StatCounter(List(1, 3, 2))
// totalOutputs of 10 because it's larger than 0,
// and confidence of 0.95 because that's the default
val evaluator = new SumEvaluator(10, 0.95)
// arbitrarily assign id 1
evaluator.merge(1, counter)

// execute
val res = evaluator.currentResult()

// These expected values are what the maths works out to (see the sketch after this test)
val targetMean = 78.0
val targetLow = -117.617 + 2.732357258139473E-5
val targetHigh = 273.617 - 2.7323572624027292E-5
val target = new BoundedDouble(targetMean, 0.95, targetLow, targetHigh)


// check that the result matches the expected BoundedDouble exactly, using the == defined above
assert(res == target)
}
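
The normal-path numbers can be reproduced by hand from the estimator's formulas: p = 0.1, mean 2.0 over 3 values, sample variance 1.0, and a t quantile with 2 degrees of freedom. A sketch of that arithmetic (using the same commons-math3 TDistribution; the printed values in comments are approximate):

import org.apache.commons.math3.distribution.TDistribution

object NormalPathArithmetic {
  def main(args: Array[String]): Unit = {
    val p = 1.0 / 10
    val meanEstimate = 2.0                          // mean of 1, 3, 2
    val count = 3L
    val countEstimate = (count + 1 - p) / p         // 39
    val sumEstimate = meanEstimate * countEstimate  // 78
    val meanVar = 1.0 / count                       // sampleVariance 1.0 over count 3
    val countVar = (count + 1) * (1 - p) / (p * p)  // 360
    val sumVar = meanEstimate * meanEstimate * countVar +
      countEstimate * countEstimate * meanVar +
      meanVar * countVar                            // ~2067
    val sumStdev = math.sqrt(sumVar)                // ~45.46
    val confFactor = new TDistribution(2).inverseCumulativeProbability(0.975)  // ~4.303
    println(sumEstimate - confFactor * sumStdev)    // ~ -117.617
    println(sumEstimate + confFactor * sumStdev)    // ~  273.617
  }
}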

}