diff --git a/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala b/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala
index 48b9434153172..d06b2c67d2077 100644
--- a/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala
+++ b/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala
@@ -21,5 +21,23 @@ package org.apache.spark.partial
  * A Double value with error bars and associated confidence.
  */
 class BoundedDouble(val mean: Double, val confidence: Double, val low: Double, val high: Double) {
+
   override def toString(): String = "[%.3f, %.3f]".format(low, high)
+
+  override def hashCode: Int =
+    this.mean.hashCode ^ this.confidence.hashCode ^ this.low.hashCode ^ this.high.hashCode
+
+  /**
+   * Note that, consistent with Double, any NaN value will make equality false.
+   */
+  override def equals(that: Any): Boolean =
+    that match {
+      case that: BoundedDouble => {
+        this.mean == that.mean &&
+        this.confidence == that.confidence &&
+        this.low == that.low &&
+        this.high == that.high
+      }
+      case _ => false
+    }
 }
diff --git a/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala
index 44295e5a1affe..5fe33583166c3 100644
--- a/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala
+++ b/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala
@@ -29,8 +29,9 @@ import org.apache.spark.util.StatCounter
 private[spark] class SumEvaluator(totalOutputs: Int, confidence: Double)
   extends ApproximateEvaluator[StatCounter, BoundedDouble] {
 
+  // modified in merge
   var outputsMerged = 0
-  var counter = new StatCounter
+  val counter = new StatCounter
 
   override def merge(outputId: Int, taskResult: StatCounter) {
     outputsMerged += 1
@@ -40,30 +41,39 @@ private[spark] class SumEvaluator(totalOutputs: Int, confidence: Double)
   override def currentResult(): BoundedDouble = {
     if (outputsMerged == totalOutputs) {
       new BoundedDouble(counter.sum, 1.0, counter.sum, counter.sum)
-    } else if (outputsMerged == 0) {
+    } else if (outputsMerged == 0 || counter.count == 0) {
       new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity)
     } else {
       val p = outputsMerged.toDouble / totalOutputs
       val meanEstimate = counter.mean
-      val meanVar = counter.sampleVariance / counter.count
       val countEstimate = (counter.count + 1 - p) / p
-      val countVar = (counter.count + 1) * (1 - p) / (p * p)
       val sumEstimate = meanEstimate * countEstimate
-      val sumVar = (meanEstimate * meanEstimate * countVar) +
-        (countEstimate * countEstimate * meanVar) +
-        (meanVar * countVar)
-      val sumStdev = math.sqrt(sumVar)
-      val confFactor = {
-        if (counter.count > 100) {
+
+      val meanVar = counter.sampleVariance / counter.count
+
+      // branch at this point because counter.count == 1 implies counter.sampleVariance == NaN
+      // and we don't want to ever return a bound of NaN
+      if (meanVar.isNaN || counter.count == 1) {
+        new BoundedDouble(sumEstimate, confidence, Double.NegativeInfinity, Double.PositiveInfinity)
+      } else {
+        val countVar = (counter.count + 1) * (1 - p) / (p * p)
+        val sumVar = (meanEstimate * meanEstimate * countVar) +
+          (countEstimate * countEstimate * meanVar) +
+          (meanVar * countVar)
+        val sumStdev = math.sqrt(sumVar)
+        val confFactor = if (counter.count > 100) {
           new NormalDistribution().inverseCumulativeProbability(1 - (1 - confidence) / 2)
         } else {
+          // note that if this goes to 0, TDistribution will throw an exception.
+          // Hence the special-casing of 1 above.
           val degreesOfFreedom = (counter.count - 1).toInt
           new TDistribution(degreesOfFreedom).inverseCumulativeProbability(1 - (1 - confidence) / 2)
         }
+
+        val low = sumEstimate - confFactor * sumStdev
+        val high = sumEstimate + confFactor * sumStdev
+        new BoundedDouble(sumEstimate, confidence, low, high)
       }
-      val low = sumEstimate - confFactor * sumStdev
-      val high = sumEstimate + confFactor * sumStdev
-      new BoundedDouble(sumEstimate, confidence, low, high)
     }
   }
 }
diff --git a/core/src/test/scala/org/apache/spark/partial/SumEvaluatorSuite.scala b/core/src/test/scala/org/apache/spark/partial/SumEvaluatorSuite.scala
new file mode 100644
index 0000000000000..a79f5b4d74467
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/partial/SumEvaluatorSuite.scala
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.partial
+
+import org.apache.spark._
+import org.apache.spark.util.StatCounter
+
+class SumEvaluatorSuite extends SparkFunSuite with SharedSparkContext {
+
+  test("correct handling of count 1") {
+
+    // setup
+    val counter = new StatCounter(List(2.0))
+    // count of 10 because it's larger than 1,
+    // and 0.95 because that's the default
+    val evaluator = new SumEvaluator(10, 0.95)
+    // arbitrarily assign id 1
+    evaluator.merge(1, counter)
+
+    // execute
+    val res = evaluator.currentResult()
+    // 38.0 - 7.1E-15 because that's how the maths shakes out
+    val targetMean = 38.0 - 7.1E-15
+
+    // Sanity check that equality works on BoundedDouble
+    assert(new BoundedDouble(2.0, 0.95, 1.1, 1.2) == new BoundedDouble(2.0, 0.95, 1.1, 1.2))
+
+    // actual test
+    assert(res ==
+      new BoundedDouble(targetMean, 0.950, Double.NegativeInfinity, Double.PositiveInfinity))
+  }
+
+  test("correct handling of count 0") {
+
+    // setup
+    val counter = new StatCounter(List())
+    // count of 10 because it's larger than 0,
+    // and 0.95 because that's the default
+    val evaluator = new SumEvaluator(10, 0.95)
+    // arbitrarily assign id 1
+    evaluator.merge(1, counter)
+
+    // execute
+    val res = evaluator.currentResult()
+    // assert
+    assert(res == new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity))
+  }
+
+  test("correct handling of NaN") {
+
+    // setup
+    val counter = new StatCounter(List(1, Double.NaN, 2))
+    // count of 10 because it's larger than 0,
+    // and 0.95 because that's the default
+    val evaluator = new SumEvaluator(10, 0.95)
+    // arbitrarily assign id 1
+    evaluator.merge(1, counter)
+
+    // execute
+    val res = evaluator.currentResult()
+    // assert - note semantics of == in face of NaN
+    assert(res.mean.isNaN)
+    assert(res.confidence == 0.95)
+    assert(res.low == Double.NegativeInfinity)
+    assert(res.high == Double.PositiveInfinity)
+  }
+
+  test("correct handling of > 1 values") {
+
+    // setup
+    val counter = new StatCounter(List(1, 3, 2))
+    // count of 10 because it's larger than 0,
+    // and 0.95 because that's the default
+    val evaluator = new SumEvaluator(10, 0.95)
+    // arbitrarily assign id 1
+    evaluator.merge(1, counter)
+
+    // execute
+    val res = evaluator.currentResult()
+
+    // These vals because that's how the maths shakes out
+    val targetMean = 78.0
+    val targetLow = -117.617 + 2.732357258139473E-5
+    val targetHigh = 273.617 - 2.7323572624027292E-5
+    val target = new BoundedDouble(targetMean, 0.95, targetLow, targetHigh)
+
+
+    // check that the result matches the expectation exactly
+    // (the tolerance is already baked into the target values above)
+    assert(res == target)
+  }
+
+}
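
A note on the interval math for reviewers: the sketch below replays the confFactor computation from currentResult() against the same commons-math3 distributions, outside of Spark. It is a standalone illustration, not part of the patch; ConfFactorSketch and confFactor are names invented for this example.

import org.apache.commons.math3.distribution.{NormalDistribution, TDistribution}

object ConfFactorSketch {
  // Two-sided critical value for a given confidence level and sample count,
  // mirroring the branch in SumEvaluator.currentResult().
  def confFactor(count: Long, confidence: Double): Double = {
    val q = 1 - (1 - confidence) / 2
    if (count > 100) {
      // Large samples: the t-distribution is close enough to the normal.
      new NormalDistribution().inverseCumulativeProbability(q)
    } else {
      // Small samples: Student's t with count - 1 degrees of freedom.
      // TDistribution throws for 0 degrees of freedom, which is why the
      // patch special-cases count == 1 with infinite bounds instead.
      new TDistribution((count - 1).toDouble).inverseCumulativeProbability(q)
    }
  }

  def main(args: Array[String]): Unit = {
    println(confFactor(3, 0.95))    // ~4.303: tiny samples get wide intervals
    println(confFactor(1000, 0.95)) // ~1.960: the familiar normal z-value
  }
}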
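For context on how these bounds reach users: SumEvaluator backs sumApprox on an RDD[Double], which returns a PartialResult[BoundedDouble]. A minimal usage sketch, assuming a local master; the printed interval is illustrative, not actual output.

import org.apache.spark.{SparkConf, SparkContext}

object SumApproxSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[4]").setAppName("SumApproxSketch"))
    try {
      val rdd = sc.parallelize(1 to 1000000).map(_.toDouble)
      // Block for at most 100 ms; if not all partitions have finished, the
      // result is an estimate printed as "[low, high]" via the toString above.
      val approx = rdd.sumApprox(timeout = 100, confidence = 0.95)
      println(approx) // e.g. (partial: [4.88E11, 5.13E11])
    } finally {
      sc.stop()
    }
  }
}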