Skip to content
This repository was archived by the owner on Dec 21, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/core/data/sframe/gl_sarray.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,6 @@ class gl_sarray {
/* Make Friends */
/* */
/**************************************************************************/

friend gl_sarray operator+(const flexible_type& opnd, const gl_sarray& opnd2);
friend gl_sarray operator-(const flexible_type& opnd, const gl_sarray& opnd2);
friend gl_sarray operator*(const flexible_type& opnd, const gl_sarray& opnd2);
Expand Down
2 changes: 2 additions & 0 deletions src/core/storage/sframe_interface/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,6 @@ unity_sarray_builder.cpp
unity_sframe.cpp
unity_sframe_builder.cpp
unity_sgraph.cpp
REQUIRES
unity_sketches
)
53 changes: 53 additions & 0 deletions src/core/storage/sframe_interface/unity_sarray.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1102,6 +1102,59 @@ flexible_type unity_sarray::mean() {
}
}

/**
 * Returns the median of the non-missing values of a numeric SArray.
 *
 * \param approx If true, the 0.5 quantile reported by a unity_sketch is
 *               returned directly. If false, the sketch is only used to pick
 *               a window of candidate values and the exact element at sorted
 *               position size/2 (after dropping missing values) is returned.
 *
 * \return flex_undefined if the array is empty or contains only missing
 *         values; otherwise the (approximate or exact) median.
 *
 * Throws if the array is not of INTEGER or FLOAT type.
 */
flexible_type unity_sarray::median(bool approx) {
  log_func_entry();

  flex_type_enum type = dtype();
  if(type != flex_type_enum::INTEGER && type != flex_type_enum::FLOAT) {
    log_and_throw("Can not calculate median on non-numeric SArray.");
  }
  if (size() == 0) {
    return flex_undefined();
  }

  gl_sarray data = std::static_pointer_cast<unity_sarray>(shared_from_this());
  data = data.dropna();
  const size_t num_values = data.size();
  if (num_values == 0) {
    // Every element was missing: there is nothing to take a median of.
    return flex_undefined();
  }

  // Use sketch to get a fast approximate answer. The sketch is held by a
  // smart pointer so it is released on every exit path, including exceptions.
  auto sketch = std::make_shared<turi::unity_sketch>();
  sketch->construct_from_sarray(data);
  const double approx_median = sketch->get_quantile(0.5);

  flexible_type result;
  if (!approx) {
    const double epsilon = sketch->numeric_epsilon();
    const double upper_bound = approx_median + (epsilon * num_values);
    const double lower_bound = approx_median - (epsilon * num_values);

    std::vector<flexible_type> candidates;
    std::mutex candidate_lock;

    // Scans `data` once. Values below lower_bound are only counted (the count
    // is returned); values in [lower_bound, upper_bound] are stored in
    // `candidates`. With take_all == true the window is ignored and every
    // value is stored, so the returned count is 0.
    auto collect_candidates = [&](bool take_all) -> size_t {
      atomic<size_t> n_below(0);
      candidates.clear();
      auto visit_rows = [&](size_t, const std::shared_ptr<sframe_rows>& rows) {
        for (const auto& row : *rows) {
          const flexible_type x = row[0];
          if (!take_all && x < lower_bound) {
            ++n_below;
          } else if (take_all || x <= upper_bound) {
            // The callback may run concurrently; guard the shared vector.
            std::lock_guard<std::mutex> lg(candidate_lock);
            candidates.push_back(x);
          }
        }
        return false;
      };
      data.materialize_to_callback(visit_rows);
      return n_below;
    };

    const size_t target_rank = num_values / 2;
    size_t n_below = collect_candidates(false);

    // The window is only a heuristic. If the element of rank `target_rank`
    // did not land inside it (including the case of an inverted or NaN
    // window), subtracting would underflow or index past the end of
    // `candidates`. Fall back to an exact selection over all values.
    if (n_below > target_rank || target_rank - n_below >= candidates.size()) {
      n_below = collect_candidates(true);
    }

    const size_t median_index = target_rank - n_below;
    std::nth_element(candidates.begin(), candidates.begin() + median_index, candidates.end());
    result = candidates[median_index];
  } else {
    result = approx_median;
  }

  return result;
}

flexible_type unity_sarray::std(size_t ddof) {
log_func_entry();
flexible_type variance = this->var(ddof);
Expand Down
9 changes: 9 additions & 0 deletions src/core/storage/sframe_interface/unity_sarray.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,15 @@ class unity_sarray: public unity_sarray_base {
*/
flexible_type mean();

/**
* Returns the median of the elements in the sarray.
*
* Invoking on an empty sarray returns flex_undefined.
* Invoking on a non-numeric type throws an exception.
* Undefined values in the array are skipped.
*/
flexible_type median(bool approx);

/**
* Returns the standard deviation of the elements in sarray as a flex_float.
*
Expand Down
2 changes: 1 addition & 1 deletion src/core/storage/sframe_interface/unity_sframe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@
#include <visualization/server/table.hpp>

#include <model_server/lib/image_util.hpp>
#include <ml/sketches/unity_sketch.hpp>
#include <algorithm>
#include <random>
#include <string>
Expand All @@ -59,6 +58,7 @@
#ifdef TC_HAS_PYTHON
#include <core/system/lambda/pylambda_function.hpp>
#endif

namespace turi {

using namespace turi::query_eval;
Expand Down
4 changes: 2 additions & 2 deletions src/ml/sketches/quantile_sketch.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class streaming_quantile_sketch;
*
* Usage is simple:
* \code
* // constrcut sketch at a particular size
* // construct sketch at a particular size
* quantile_sketch<double> sketch(100000, 0.01);
* ...
* insert any number elements into the sketch using
Expand Down Expand Up @@ -89,7 +89,7 @@ class streaming_quantile_sketch;
* -----------------
* The basic mechanic of the sketch lies in making a hierarchy of sketches, each
* having size m_b. The first sketch has error 0. The second sketch has error
* 1/b, the 3rd sketch has error 2/b and so on. These sketches are treated as
* 1/b, the 3rd sketch has error 2/b and so on. These sketches are treated
* like a binary sequence:
* - When a buffer of size b has been accumulated, it gets sorted and
* inserted into sketch 1
Expand Down
4 changes: 4 additions & 0 deletions src/ml/sketches/streaming_quantile_sketch.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,10 @@ class streaming_quantile_sketch {
m_final.m_epsilon = m_epsilon;
}

double get_epsilon() const {
return m_epsilon;
}

void save(oarchive& oarc) const {
oarc << m_epsilon << m_elements_inserted
<< m_initial_sketch_size << m_levels << m_final;
Expand Down
7 changes: 4 additions & 3 deletions src/ml/sketches/unity_sketch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -466,16 +466,14 @@ void unity_sketch::combine_global(std::vector<turi::mutex>& thr_locks) {
reset_global_sketches_and_statistics();

// merge all the sketches
for (size_t i = 0;i < m_thrlocal.size(); ++i) {
for (size_t i = 0; i < m_thrlocal.size(); ++i) {
std::unique_lock<turi::mutex> thrlocal_lock(thr_locks[i]);
m_num_elements_processed += m_thrlocal[i].num_elements_processed;
m_discrete_sketch.combine(m_thrlocal[i].discrete_sketch);
m_undefined_count += m_thrlocal[i].undefined_count;
if (m_is_numeric) {
m_numeric_sketch.combine(m_thrlocal[i].numeric_sketch);
}


}

if (m_is_child_sketch) {
Expand Down Expand Up @@ -563,6 +561,9 @@ void unity_sketch::numeric_sketch_struct::combine(const numeric_sketch_struct& o
max = std::max(other.max, max);
sum += other.sum;

DASSERT_TRUE(epsilon == -1 || epsilon == other.epsilon);
epsilon = other.epsilon;

if (num_items + other.num_items > 0) {

double delta = other.mean - mean;
Expand Down
19 changes: 15 additions & 4 deletions src/ml/sketches/unity_sketch.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,9 @@
#include <core/logging/logger.hpp>
namespace turi {

// forward declarations
class unity_sarray;
class unity_sketch;

namespace sketches {

// forward declarations
template <typename T, typename Comparator>
class streaming_quantile_sketch;
template <typename T, typename Comparator>
Expand Down Expand Up @@ -273,6 +270,18 @@ class unity_sketch: public unity_sketch_base {
return m_numeric_sketch.min;
}

/**
 * Returns the epsilon value recorded for the numeric sketch. The value is NaN
 * while the global numeric statistics are in their reset state (e.g. nothing
 * has been merged in yet). Throws an exception if called on an sarray with
 * non-numeric type.
 */
inline double numeric_epsilon() {
  if (!m_is_numeric) log_and_throw("Epsilon value not available for a non-numeric column");
  commit_global_if_out_of_date();
  std::unique_lock<turi::mutex> global_lock(lock);
  // Previously returned m_numeric_sketch.max, which is the column maximum,
  // not the sketch epsilon.
  return m_numeric_sketch.epsilon;
}

/**
* Returns the sum of the values in the sarray. Returns 0 on an empty
* array. Throws an exception if called on an sarray with non-numeric
Expand Down Expand Up @@ -330,6 +339,7 @@ class unity_sketch: public unity_sketch_base {
double mean = 0.0;
size_t num_items = 0;
double m2;
double epsilon = -1;

void reset();

Expand Down Expand Up @@ -452,6 +462,7 @@ class unity_sketch: public unity_sketch_base {
m_numeric_sketch.sum = 0;
m_numeric_sketch.m2 = 0;
m_numeric_sketch.num_items = 0;
m_numeric_sketch.epsilon = NAN;
}

inline void increase_nested_element_count(unity_sketch& nested_sketch, size_t thr, size_t count) {
Expand Down
1 change: 1 addition & 0 deletions src/model_server/lib/api/unity_sarray_interface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ GENERATE_INTERFACE_AND_PROXY(unity_sarray_base, unity_sarray_proxy,
(flexible_type, min, )
(flexible_type, sum, )
(flexible_type, mean, )
(flexible_type, median, (bool))
(flexible_type, std, (size_t))
(flexible_type, var, (size_t))
(size_t, num_missing, )
Expand Down
1 change: 0 additions & 1 deletion src/model_server/server/registration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
#include <core/storage/sframe_interface/unity_sframe.hpp>
#include <core/storage/sframe_interface/unity_sframe_builder.hpp>
#include <core/storage/sframe_interface/unity_sgraph.hpp>
#include <ml/sketches/unity_sketch.hpp>

#include <model_server/lib/extensions/ml_model.hpp>

Expand Down
3 changes: 1 addition & 2 deletions src/model_server/server/unity_server_init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,8 @@
#include <core/storage/sframe_interface/unity_sframe.hpp>
#include <core/storage/sframe_interface/unity_sframe_builder.hpp>
#include <core/storage/sframe_interface/unity_sgraph.hpp>
#include <ml/sketches/unity_sketch.hpp>
#include <model_server/lib/simple_model.hpp>
#

namespace turi {
unity_server_initializer::~unity_server_initializer() {}

Expand Down
3 changes: 3 additions & 0 deletions src/python/turicreate/_cython/cy_sarray.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ cdef extern from "<core/storage/sframe_interface/unity_sarray.hpp>" namespace "t
flexible_type min() except +
flexible_type sum() except +
flexible_type mean() except +
flexible_type median(bint) except +
flexible_type std(size_t) except +
flexible_type var(size_t) except +
size_t nnz() except +
Expand Down Expand Up @@ -142,6 +143,8 @@ cdef class UnitySArrayProxy:

cpdef mean(self)

cpdef median(self, bint approx)

cpdef std(self, size_t ddof)

cpdef var(self, size_t ddof)
Expand Down
6 changes: 6 additions & 0 deletions src/python/turicreate/_cython/cy_sarray.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,12 @@ cdef class UnitySArrayProxy:
tmp = self.thisptr.mean()
return pyobject_from_flexible_type(tmp)

cpdef median(self, bint approx):
    # Run the C++ median with the GIL released, then convert the resulting
    # flexible_type into a Python object.
    cdef flexible_type result
    with nogil:
        result = self.thisptr.median(approx)
    return pyobject_from_flexible_type(result)

cpdef std(self, size_t ddof):
cdef flexible_type tmp
with nogil:
Expand Down
33 changes: 33 additions & 0 deletions src/python/turicreate/data_structures/sarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2285,6 +2285,10 @@ def mean(self):
out : float | turicreate.Image
Mean of all values in SArray, or image holding per-pixel mean
across the input SArray.

See Also
--------
median
"""
with cython_context():
if self.dtype == _Image:
Expand All @@ -2294,6 +2298,35 @@ def mean(self):
else:
return self.__proxy__.mean()

def median(self, approximate=False):
    """
    Median of all the values in the SArray.

    Note: no linear smoothing is performed. If the length of the SArray is
    an even number, a value between `a` and `b` (inclusive) is returned,
    where `a` and `b` are the two middle values; the two are not averaged.

    Missing values are ignored. An empty SArray yields None.

    Parameters
    ----------
    approximate : bool
        If True an approximate value will be returned. Calculating
        an approximate value is faster. The approximate value is
        intended to be within 5% of the exact value.

    Returns
    -------
    out : int | float | None
        Median of all values in SArray

    Raises
    ------
    TypeError
        If `approximate` is not a bool.
    RuntimeError
        If the SArray is not of a numeric type.

    See Also
    --------
    mean
    """
    if not isinstance(approximate, bool):
        # Raise a real exception type; raising a bare string only produced
        # a TypeError by accident and discarded this message.
        raise TypeError("\"approximate\" must be a bool.")

    # Same wrapper the sibling `mean` uses around proxy calls.
    with cython_context():
        return self.__proxy__.median(approximate)

def std(self, ddof=0):
"""
Standard deviation of all the values in the SArray.
Expand Down
58 changes: 58 additions & 0 deletions src/python/turicreate/test/test_sarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1037,6 +1037,64 @@ def test_max_min_sum_mean(self):
with self.assertRaises(RuntimeError):
a.mean()

def test_median(self):
    """Checks SArray.median in exact and approximate mode."""

    def check_correctness(values):
        # Exact mode: odd lengths must match numpy; even lengths must fall
        # between the two middle values (missing values are ignored).
        sarray = SArray(values)
        present = [v for v in values if v is not None]
        if len(present) % 2 == 1:
            self.assertAlmostEqual(sarray.median(), np.median(present))
            return
        ordered = sorted(present)
        half = len(ordered) // 2
        self.assertTrue(ordered[half - 1] <= sarray.median() <= ordered[half])

    def check_approximate_correctness(values):
        # Approximate mode: accept anything within 50% of the middle
        # value (odd length) or of the two middle values (even length).
        approx = SArray(values).median(approximate=True)
        if len(values) % 2 == 1:
            lower = upper = np.median(values)
        else:
            ordered = sorted(values)
            half = len(ordered) // 2
            lower, upper = ordered[half - 1], ordered[half]
        fuzzy_lower_bound = lower - (abs(lower) * 0.5)
        fuzzy_upper_bound = upper + (abs(upper) * 0.5)
        self.assertTrue(fuzzy_lower_bound <= approx <= fuzzy_upper_bound)

    # Integers
    check_correctness([10, 201, -3])                    # odd length
    check_correctness([12, 3, -1, 5])                   # even length
    check_approximate_correctness([1, 30, 99, 0, 10])   # odd length
    check_approximate_correctness([-4, 10, -1, -100])   # even length

    # Floats
    check_correctness([-2.22, 0.9, 34.])                      # odd length
    check_correctness([2.3, -3.14])                           # even length
    check_approximate_correctness([99.9, -48.3, -14.3])       # odd length
    check_approximate_correctness([-10.1, 14.8, 12.99, 0.])   # even length

    # Bigger test
    import random
    a = [random.randint(-20000, 20000) for _ in range(10000)]
    check_correctness(a)
    check_approximate_correctness(a)
    check_correctness(a + [1])
    check_approximate_correctness(a + [1])

    # Test SArray with Nones
    a += [None] * 20
    random.shuffle(a)
    check_correctness(a)

    # Empty input
    self.assertIsNone(SArray().median())

    # Bad inputs
    with self.assertRaises(TypeError):
        SArray([1]).median(approximate="this is not a bool")
    with self.assertRaises(RuntimeError):
        SArray(["this", "is", "not", "numeric"]).median()

def test_max_min_sum_mean_missing(self):
# negative and positive
s = SArray([-2, 0, None, None, None], int)
Expand Down