34 changes: 28 additions & 6 deletions python/pyspark/statcounter.py
@@ -20,6 +20,14 @@
import copy
import math

+_have_numpy = False
+try:
+    from numpy import maximum, minimum, sqrt
+    _have_numpy = True
+except:
Contributor
It's better to have ImportError here.

Contributor
How about doing it this way:

try:
    from numpy import maximum, minimum, sqrt
except ImportError:
    maximum = max
    minimum = min
    sqrt = math.sqrt

This will simplify the later code.

Contributor Author
Nice! This is much better, updating the PR now...
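
For readers skimming the thread, a minimal sketch of what the suggested fallback looks like in isolation (plain Python; the sample values below are illustrative and not part of the patch). The point is that maximum, minimum, and sqrt are always bound to something callable, so the statistics code can call them unconditionally instead of branching on a _have_numpy flag:

import math

try:
    from numpy import maximum, minimum, sqrt
except ImportError:
    # No NumPy: fall back to scalar equivalents with the same call shape
    # as the NumPy functions used here.
    maximum = max
    minimum = min
    sqrt = math.sqrt

running_max = 3.0
running_max = maximum(running_max, 5.0)   # numpy.maximum if available, builtin max otherwise
print(sqrt(running_max))                  # ~2.236 either way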

+    # no NumPy, so fall back on scalar operators
+    pass
+

class StatCounter(object):

@@ -39,10 +47,14 @@ def merge(self, value):
        self.n += 1
        self.mu += delta / self.n
        self.m2 += delta * (value - self.mu)
-        if self.maxValue < value:
-            self.maxValue = value
-        if self.minValue > value:
-            self.minValue = value
+        if not _have_numpy:
+            if self.maxValue < value:
+                self.maxValue = value
+            if self.minValue > value:
+                self.minValue = value
+        else:
+            self.maxValue = maximum(self.maxValue, value)
+            self.minValue = minimum(self.minValue, value)

        return self

@@ -70,8 +82,12 @@ def mergeStats(self, other):
                else:
                    self.mu = (self.mu * self.n + other.mu * other.n) / (self.n + other.n)

+                if not _have_numpy:
+                    self.maxValue = max(self.maxValue, other.maxValue)
+                    self.minValue = min(self.minValue, other.minValue)
+                else:
+                    self.maxValue = maximum(self.maxValue, other.maxValue)
+                    self.minValue = minimum(self.minValue, other.minValue)

                self.m2 += other.m2 + (delta * delta * self.n * other.n) / (self.n + other.n)
                self.n += other.n
@@ -115,14 +131,20 @@ def sampleVariance(self):

    # Return the standard deviation of the values.
    def stdev(self):
-        return math.sqrt(self.variance())
+        if not _have_numpy:
+            return math.sqrt(self.variance())
+        else:
+            return sqrt(self.variance())

    #
    # Return the sample standard deviation of the values, which corrects for bias in estimating the
    # variance by dividing by N-1 instead of N.
    #
    def sampleStdev(self):
-        return math.sqrt(self.sampleVariance())
+        if not _have_numpy:
+            return math.sqrt(self.sampleVariance())
+        else:
+            return sqrt(self.sampleVariance())
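
As a quick sanity check of the N versus N-1 distinction the comment above describes (plain Python, independent of the patch; the sample values are arbitrary):

import math

values = [1.0, 2.0, 3.0]
n = len(values)
mu = sum(values) / n
m2 = sum((v - mu) ** 2 for v in values)   # sum of squared deviations, like StatCounter's m2

variance = m2 / n                 # population variance -> 2/3
sample_variance = m2 / (n - 1)    # unbiased sample variance -> 1.0

print(math.sqrt(variance))          # stdev()       ~ 0.816
print(math.sqrt(sample_variance))   # sampleStdev() = 1.0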

    def __repr__(self):
        return ("(count: %s, mean: %s, stdev: %s, max: %s, min: %s)" %
25 changes: 25 additions & 0 deletions python/pyspark/tests.py
@@ -38,12 +38,19 @@
from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger

_have_scipy = False
+_have_numpy = False
try:
    import scipy.sparse
    _have_scipy = True
except:
    # No SciPy, but that's okay, we'll skip those tests
    pass
+try:
+    from numpy import array
Contributor
Just import numpy here instead; this array will overwrite array.array and make other unit tests fail.

Contributor Author
@davies thanks, good catch, should be fixed now!
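
For readers unfamiliar with the failure mode @davies points out, here is a small self-contained sketch (separate from tests.py; the helper name is illustrative) of why a module-level from numpy import array is risky, and one way to probe for NumPy without rebinding the name:

from array import array            # stdlib array type, which other tests rely on

try:
    import numpy                    # detect availability without shadowing 'array'
    _have_numpy = True
except ImportError:
    _have_numpy = False

def test_statcounter_array():
    from numpy import array         # local import: shadows 'array' only inside this test
    return array([1.0, 2.0]).tolist()

# The module-level 'array' is still the stdlib type:
assert array('d', [1.0, 2.0])[0] == 1.0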

+    _have_numpy = True
+except:
+    # No NumPy, but that's okay, we'll skip those tests
+    pass


SPARK_HOME = os.environ["SPARK_HOME"]
@@ -914,9 +921,27 @@ def test_serialize(self):
        self.assertEqual(expected, observed)


+@unittest.skipIf(not _have_numpy, "NumPy not installed")
+class NumPyTests(PySparkTestCase):
+    """General PySpark tests that depend on numpy """
+
+    def test_statcounter_array(self):
+        from numpy import array
+        x = self.sc.parallelize([array([1.0,1.0]), array([2.0,2.0]), array([3.0,3.0])])
+        s = x.stats()
+        self.assertSequenceEqual([2.0,2.0], s.mean().tolist())
+        self.assertSequenceEqual([1.0,1.0], s.min().tolist())
+        self.assertSequenceEqual([3.0,3.0], s.max().tolist())
+        self.assertSequenceEqual([1.0,1.0], s.sampleStdev().tolist())
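
To see what this test exercises without a SparkContext, StatCounter can also be driven directly (a rough sketch assuming NumPy is installed; the exact printed formatting may differ by NumPy version):

from numpy import array
from pyspark.statcounter import StatCounter

s = StatCounter()
for v in [array([1.0, 1.0]), array([2.0, 2.0]), array([3.0, 3.0])]:
    s.merge(v)                 # element-wise running count/mean/min/max/variance

print(s.mean())                # [ 2.  2.]
print(s.min())                 # [ 1.  1.]
print(s.max())                 # [ 3.  3.]
print(s.sampleStdev())         # [ 1.  1.]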


if __name__ == "__main__":
    if not _have_scipy:
        print "NOTE: Skipping SciPy tests as it does not seem to be installed"
+    if not _have_numpy:
+        print "NOTE: Skipping NumPy tests as it does not seem to be installed"
    unittest.main()
    if not _have_scipy:
        print "NOTE: SciPy tests were skipped as it does not seem to be installed"
+    if not _have_numpy:
+        print "NOTE: NumPy tests were skipped as it does not seem to be installed"