diff --git a/src/core/data/sframe/gl_sarray.hpp b/src/core/data/sframe/gl_sarray.hpp index 49145d5099..9c8b6f2e38 100644 --- a/src/core/data/sframe/gl_sarray.hpp +++ b/src/core/data/sframe/gl_sarray.hpp @@ -506,7 +506,6 @@ class gl_sarray { /* Make Friends */ /* */ /**************************************************************************/ - friend gl_sarray operator+(const flexible_type& opnd, const gl_sarray& opnd2); friend gl_sarray operator-(const flexible_type& opnd, const gl_sarray& opnd2); friend gl_sarray operator*(const flexible_type& opnd, const gl_sarray& opnd2); diff --git a/src/core/storage/sframe_interface/CMakeLists.txt b/src/core/storage/sframe_interface/CMakeLists.txt index e8c5eaf74c..1b72214703 100644 --- a/src/core/storage/sframe_interface/CMakeLists.txt +++ b/src/core/storage/sframe_interface/CMakeLists.txt @@ -6,4 +6,6 @@ unity_sarray_builder.cpp unity_sframe.cpp unity_sframe_builder.cpp unity_sgraph.cpp + REQUIRES +unity_sketches ) diff --git a/src/core/storage/sframe_interface/unity_sarray.cpp b/src/core/storage/sframe_interface/unity_sarray.cpp index f28e28cf8e..5e05e8e5eb 100644 --- a/src/core/storage/sframe_interface/unity_sarray.cpp +++ b/src/core/storage/sframe_interface/unity_sarray.cpp @@ -1102,6 +1102,60 @@ flexible_type unity_sarray::mean() { } } +flexible_type unity_sarray::median(bool approx) { + log_func_entry(); + + flex_type_enum type = dtype(); + if(type != flex_type_enum::INTEGER && type != flex_type_enum::FLOAT) { + log_and_throw("Can not calculate median on non-numeric SArray."); + } + + // Use sketch to get a fast approximate answer + turi::unity_sketch* sketch = new turi::unity_sketch(); + gl_sarray data = std::static_pointer_cast(shared_from_this()); + data = data.dropna(); + if (data.size() == 0) { + return flex_undefined(); + } + sketch->construct_from_sarray(data); + const float quantile = 0.5; + double approx_median = sketch->get_quantile(quantile); + + flexible_type result; + if (!approx) { + const double epsilon = sketch->numeric_epsilon(); + const double upper_bound = sketch->get_quantile(quantile + epsilon); + const double lower_bound = sketch->get_quantile(quantile - epsilon); + + // Count the number below lower_bound. + // Store all values between lower_bound and upper_bound. + atomic n_below_a = 0; + std::vector candidates; + std::mutex candidate_lock; + auto count_median = [&](size_t, const std::shared_ptr& rows) { + for (const auto& row : *rows) { + const flexible_type x = row[0]; + if(x < lower_bound) { + ++n_below_a; + } else if(x <= upper_bound) { + std::lock_guard lg(candidate_lock); + candidates.push_back(x); + } + } + return false; + }; + data.materialize_to_callback(count_median); + + size_t median_index = (data.size() / 2) - n_below_a; + std::nth_element(candidates.begin(), candidates.begin() + median_index, candidates.end()); + result = candidates[median_index]; + } else { + result = approx_median; + } + + return result; +} + flexible_type unity_sarray::std(size_t ddof) { log_func_entry(); flexible_type variance = this->var(ddof); diff --git a/src/core/storage/sframe_interface/unity_sarray.hpp b/src/core/storage/sframe_interface/unity_sarray.hpp index 4b251c5c36..0645ca8275 100644 --- a/src/core/storage/sframe_interface/unity_sarray.hpp +++ b/src/core/storage/sframe_interface/unity_sarray.hpp @@ -356,6 +356,15 @@ class unity_sarray: public unity_sarray_base { */ flexible_type mean(); + /** + * Returns the medain of the elements in the sarray. + * + * Invoking on an empty sarray returns flex_undefined. + * Invoking on a non-numeric type throws an exception. + * Undefined values in the array are skipped. + */ + flexible_type median(bool approx); + /** * Returns the standard deviation of the elements in sarray as a flex_float. * diff --git a/src/core/storage/sframe_interface/unity_sframe.cpp b/src/core/storage/sframe_interface/unity_sframe.cpp index 67474b84d8..b299515864 100644 --- a/src/core/storage/sframe_interface/unity_sframe.cpp +++ b/src/core/storage/sframe_interface/unity_sframe.cpp @@ -48,7 +48,6 @@ #include #include -#include #include #include #include @@ -59,6 +58,7 @@ #ifdef TC_HAS_PYTHON #include #endif + namespace turi { using namespace turi::query_eval; diff --git a/src/ml/sketches/quantile_sketch.hpp b/src/ml/sketches/quantile_sketch.hpp index 07a6b514b8..763dca24b0 100644 --- a/src/ml/sketches/quantile_sketch.hpp +++ b/src/ml/sketches/quantile_sketch.hpp @@ -38,7 +38,7 @@ class streaming_quantile_sketch; * * Usage is simple: * \code - * // constrcut sketch at a particular size + * // construct sketch at a particular size * quantile_sketch sketch(100000, 0.01); * ... * insert any number elements into the sketch using @@ -89,7 +89,7 @@ class streaming_quantile_sketch; * ----------------- * The basic mechanic of the sketch lies in making a hierarchy of sketches, each * having size m_b. The first sketch has error 0. The second sketch has error - * 1/b, the 3rd sketch has error 2/b and so on. These sketches are treated as + * 1/b, the 3rd sketch has error 2/b and so on. These sketches are treated * like a binary sequence: * - When a buffer of size b has been accumulated, it gets sorted and * inserted into sketch 1 diff --git a/src/ml/sketches/streaming_quantile_sketch.hpp b/src/ml/sketches/streaming_quantile_sketch.hpp index 9a13f17f9c..0837a44a97 100644 --- a/src/ml/sketches/streaming_quantile_sketch.hpp +++ b/src/ml/sketches/streaming_quantile_sketch.hpp @@ -393,6 +393,10 @@ class streaming_quantile_sketch { m_final.m_epsilon = m_epsilon; } + double get_epsilon() const { + return m_epsilon; + } + void save(oarchive& oarc) const { oarc << m_epsilon << m_elements_inserted << m_initial_sketch_size << m_levels << m_final; diff --git a/src/ml/sketches/unity_sketch.cpp b/src/ml/sketches/unity_sketch.cpp index 43a1d72916..45fbaaa699 100644 --- a/src/ml/sketches/unity_sketch.cpp +++ b/src/ml/sketches/unity_sketch.cpp @@ -466,7 +466,7 @@ void unity_sketch::combine_global(std::vector& thr_locks) { reset_global_sketches_and_statistics(); // merge all the sketches - for (size_t i = 0;i < m_thrlocal.size(); ++i) { + for (size_t i = 0; i < m_thrlocal.size(); ++i) { std::unique_lock thrlocal_lock(thr_locks[i]); m_num_elements_processed += m_thrlocal[i].num_elements_processed; m_discrete_sketch.combine(m_thrlocal[i].discrete_sketch); @@ -474,8 +474,6 @@ void unity_sketch::combine_global(std::vector& thr_locks) { if (m_is_numeric) { m_numeric_sketch.combine(m_thrlocal[i].numeric_sketch); } - - } if (m_is_child_sketch) { @@ -563,6 +561,9 @@ void unity_sketch::numeric_sketch_struct::combine(const numeric_sketch_struct& o max = std::max(other.max, max); sum += other.sum; + DASSERT_TRUE(epsilon == -1 || epsilon == other.epsilon); + epsilon = other.epsilon; + if (num_items + other.num_items > 0) { double delta = other.mean - mean; diff --git a/src/ml/sketches/unity_sketch.hpp b/src/ml/sketches/unity_sketch.hpp index d64f375d96..5ef35ac284 100644 --- a/src/ml/sketches/unity_sketch.hpp +++ b/src/ml/sketches/unity_sketch.hpp @@ -18,12 +18,9 @@ #include namespace turi { -// forward declarations -class unity_sarray; -class unity_sketch; - namespace sketches { +// forward declarations template class streaming_quantile_sketch; template @@ -273,6 +270,18 @@ class unity_sketch: public unity_sketch_base { return m_numeric_sketch.min; } + /** + * Returns the epsilon value used by the numeric sketch. Returns NaN on an + * empty array. Throws an exception if called on an sarray with non-numeric + * type. + */ + inline double numeric_epsilon() { + if (!m_is_numeric) log_and_throw("Epsilon value not available for a non-numeric column"); + commit_global_if_out_of_date(); + std::unique_lock global_lock(lock); + return m_numeric_sketch.max; + } + /** * Returns the sum of the values in the sarray. Returns 0 on an empty * array. Throws an exception if called on an sarray with non-numeric @@ -330,6 +339,7 @@ class unity_sketch: public unity_sketch_base { double mean = 0.0; size_t num_items = 0; double m2; + double epsilon = -1; void reset(); @@ -452,6 +462,7 @@ class unity_sketch: public unity_sketch_base { m_numeric_sketch.sum = 0; m_numeric_sketch.m2 = 0; m_numeric_sketch.num_items = 0; + m_numeric_sketch.epsilon = NAN; } inline void increase_nested_element_count(unity_sketch& nested_sketch, size_t thr, size_t count) { diff --git a/src/model_server/lib/api/unity_sarray_interface.hpp b/src/model_server/lib/api/unity_sarray_interface.hpp index 2bee385515..8f93029a11 100644 --- a/src/model_server/lib/api/unity_sarray_interface.hpp +++ b/src/model_server/lib/api/unity_sarray_interface.hpp @@ -43,6 +43,7 @@ GENERATE_INTERFACE_AND_PROXY(unity_sarray_base, unity_sarray_proxy, (flexible_type, min, ) (flexible_type, sum, ) (flexible_type, mean, ) + (flexible_type, median, (bool)) (flexible_type, std, (size_t)) (flexible_type, var, (size_t)) (size_t, num_missing, ) diff --git a/src/model_server/server/registration.cpp b/src/model_server/server/registration.cpp index 47820a0cb7..2e5c721fc3 100644 --- a/src/model_server/server/registration.cpp +++ b/src/model_server/server/registration.cpp @@ -11,7 +11,6 @@ #include #include #include -#include #include diff --git a/src/model_server/server/unity_server_init.cpp b/src/model_server/server/unity_server_init.cpp index a4bf3752b9..466d701cf0 100644 --- a/src/model_server/server/unity_server_init.cpp +++ b/src/model_server/server/unity_server_init.cpp @@ -13,9 +13,8 @@ #include #include #include -#include #include -# + namespace turi { unity_server_initializer::~unity_server_initializer() {} diff --git a/src/python/turicreate/_cython/cy_sarray.pxd b/src/python/turicreate/_cython/cy_sarray.pxd index c13ffe20c7..a40bd3aa93 100644 --- a/src/python/turicreate/_cython/cy_sarray.pxd +++ b/src/python/turicreate/_cython/cy_sarray.pxd @@ -40,6 +40,7 @@ cdef extern from "" namespace "t flexible_type min() except + flexible_type sum() except + flexible_type mean() except + + flexible_type median(bint) except + flexible_type std(size_t) except + flexible_type var(size_t) except + size_t nnz() except + @@ -142,6 +143,8 @@ cdef class UnitySArrayProxy: cpdef mean(self) + cpdef median(self, bint approx) + cpdef std(self, size_t ddof) cpdef var(self, size_t ddof) diff --git a/src/python/turicreate/_cython/cy_sarray.pyx b/src/python/turicreate/_cython/cy_sarray.pyx index 825f915a9f..1a937322f0 100644 --- a/src/python/turicreate/_cython/cy_sarray.pyx +++ b/src/python/turicreate/_cython/cy_sarray.pyx @@ -217,6 +217,12 @@ cdef class UnitySArrayProxy: tmp = self.thisptr.mean() return pyobject_from_flexible_type(tmp) + cpdef median(self, bint approx): + cdef flexible_type tmp + with nogil: + tmp = self.thisptr.median(approx) + return pyobject_from_flexible_type(tmp) + cpdef std(self, size_t ddof): cdef flexible_type tmp with nogil: diff --git a/src/python/turicreate/data_structures/sarray.py b/src/python/turicreate/data_structures/sarray.py index 36de82ca9c..39e3d75d84 100644 --- a/src/python/turicreate/data_structures/sarray.py +++ b/src/python/turicreate/data_structures/sarray.py @@ -2285,6 +2285,10 @@ def mean(self): out : float | turicreate.Image Mean of all values in SArray, or image holding per-pixel mean across the input SArray. + + See Also + -------- + median """ with cython_context(): if self.dtype == _Image: @@ -2294,6 +2298,35 @@ def mean(self): else: return self.__proxy__.mean() + def median(self, approximate=False): + """ + Median of all the values in the SArray. + + Note: no linear smoothing is performed. If the lenght of the SArray is + an odd number. Then a value between `a` and `b` will be used, where + `a` and `b` are the two middle values. + + Parameters + ---------- + approximate : bool + If True an approximate value will be returned. Calculating + an approximate value is faster. The approximate value will + be within 5% of the exact value. + + Returns + ------- + out : float | turicreate.Image + Median of all values in SArray + + See Also + -------- + mean + """ + if not isinstance(approximate, bool): + raise("\"approximate\" must be a bool.") + + return self.__proxy__.median(approximate) + def std(self, ddof=0): """ Standard deviation of all the values in the SArray. diff --git a/src/python/turicreate/test/test_sarray.py b/src/python/turicreate/test/test_sarray.py index 014e21cb65..0b8a743920 100644 --- a/src/python/turicreate/test/test_sarray.py +++ b/src/python/turicreate/test/test_sarray.py @@ -1037,6 +1037,64 @@ def test_max_min_sum_mean(self): with self.assertRaises(RuntimeError): a.mean() + def test_median(self): + + def check_correctness(l): + sa = SArray(l) + l = list(filter(lambda x: x is not None, l)) + if(len(l) % 2 == 1): + self.assertAlmostEqual(sa.median(), np.median(l)) + else: + l = sorted(l) + self.assertTrue(l[(len(l)//2)-1] <= sa.median() <= l[len(l)//2]) + + def check_approximate_correctness(l): + sa = SArray(l) + approx = sa.median(approximate=True) + # The approximate answer should be within 5% + if(len(l) % 2 == 1): + exact = np.median(l) + fuzzy_lower_bound = exact - (abs(exact) * 0.5) + fuzzy_upper_bound = exact + (abs(exact) * 0.5) + else: + l = sorted(l) + lower, upper = l[(len(l)//2)-1], l[len(l)//2] + fuzzy_lower_bound = lower - (abs(lower) * 0.5) + fuzzy_upper_bound = upper + (abs(upper) * 0.5) + self.assertTrue(fuzzy_lower_bound <= approx <= fuzzy_upper_bound) + + check_correctness([10, 201, -3]) # int, odd length + check_correctness([12, 3, -1, 5]) # int, even length + check_approximate_correctness([1, 30, 99, 0, 10]) # int, odd length + check_approximate_correctness([-4, 10, -1, -100]) # int, even length + + check_correctness([-2.22, 0.9, 34.]) # float, odd length + check_correctness([2.3, -3.14]) # float, even length + check_approximate_correctness([99.9, -48.3, -14.3]) # float, odd length + check_approximate_correctness([-10.1, 14.8, 12.99, 0.]) # float, even length + + # Bigger test + import random + a = [random.randint(-20000, 20000) for _ in range(10000)] + check_correctness(a) + check_approximate_correctness(a) + check_correctness(a + [1]) + check_approximate_correctness(a + [1]) + + # Test SArray with Nones + a += [None] * 20 + random.shuffle(a) + check_correctness(a) + + # Empty input + self.assertIsNone(SArray().median()) + + # Bad inputs + with self.assertRaises(TypeError): + SArray([1]).median(approximate="this is not a bool") + with self.assertRaises(RuntimeError): + SArray(["this", "is", "not", "numeric"]).median() + def test_max_min_sum_mean_missing(self): # negative and positive s = SArray([-2, 0, None, None, None], int)