-
-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Feature/weighted #2922
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feature/weighted #2922
Changes from 8 commits
0f2da8e
5f64492
685e5c4
c9d612d
a20a4cf
26c24b6
f3c6758
25c3c29
5d37d11
b1c572b
d1d1f2c
6be1414
059263c
8b1904b
8cad145
49d4e43
527256e
739568f
f01305d
2e3880d
ae8d048
dc7f605
c646568
e2ad69e
bd4f048
3c7695a
ef07edd
064b5a9
fec1a35
72c7942
0e91411
1eb2913
118dfed
e08c921
0fafe0b
a8d330d
ae0012f
111259b
5afc6f3
668b54b
d877022
ead681e
c4598ba
866fba5
3cc00c1
8f34167
9f0a8cd
98929f1
62c43e6
2e8aba2
d14f668
7fa78ae
3ebb9d4
f01d47a
4b184f6
1e06adc
706579a
b2718db
4c17108
8acc78e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,120 @@ | ||
|
|
||
|
|
||
| _doc_ = """ | ||
| Reduce this DataArray's data by a weighted `{fcn}` along some dimension(s). | ||
| Parameters | ||
| ---------- | ||
| dim : str or sequence of str, optional | ||
| Dimension(s) over which to apply the weighted `{fcn}`. | ||
dcherian marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| axis : int or sequence of int, optional | ||
| Axis(es) over which to apply the weighted `{fcn}`. Only one of the | ||
| 'dim' and 'axis' arguments can be supplied. If neither are supplied, | ||
| then the weighted `{fcn}` is calculated over all axes. | ||
|
||
| skipna : bool, optional | ||
mathause marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| If True, skip missing values (as marked by NaN). By default, only | ||
| skips missing values for float dtypes; other dtypes either do not | ||
| have a sentinel missing value (int) or skipna=True has not been | ||
| implemented (object, datetime64 or timedelta64). | ||
| Note: Missing values in the weights are replaced with 0 (i.e. no | ||
| weight). | ||
| keep_attrs : bool, optional | ||
| If True, the attributes (`attrs`) will be copied from the original | ||
mathause marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| object to the new one. If False (default), the new object will be | ||
| returned without attributes. | ||
| **kwargs : dict | ||
| Additional keyword arguments passed on to the appropriate array | ||
| function for calculating `{fcn}` on this object's data. | ||
| Returns | ||
| ------- | ||
| reduced : DataArray | ||
| New DataArray object with weighted `{fcn}` applied to its data and | ||
| the indicated dimension(s) removed. | ||
| """ | ||
|
|
||
|
|
||
class DataArrayWeighted(object):
    """Weighted reductions (``sum`` and ``mean``) over a DataArray.

    Holds the data object together with its weights; missing weights are
    replaced with 0 when the instance is created.
    """

    def __init__(self, obj, weights):
        """
        Weighted operations for DataArray.

        Parameters
        ----------
        obj : DataArray
            Object over which the weighted reduction operation is applied.
        weights : DataArray
            An array of weights associated with the values in `obj`.
            Each value in the DataArray contributes to the reduction operation
            according to its associated weight.

        Raises
        ------
        AssertionError
            If `weights` is not a DataArray.

        Note
        ----
        Missing values in the weights are replaced with 0 (i.e. no weight).
        """

        super(DataArrayWeighted, self).__init__()

        # NOTE(review): imported locally, presumably to avoid a circular
        # import with the dataarray module -- confirm before moving it.
        from .dataarray import DataArray

        # Raised explicitly (rather than through an ``assert`` statement) so
        # the validation is not stripped when Python runs with ``-O``. The
        # exception type and message are unchanged for existing callers.
        if not isinstance(weights, DataArray):
            raise AssertionError("'weights' must be a DataArray")

        self.obj = obj
        self.weights = weights.fillna(0)

    def sum_of_weights(self, dim=None, axis=None):
        """
        Calculate the sum of weights, accounting for missing values.

        Parameters
        ----------
        dim : str or sequence of str, optional
            Dimension(s) over which to sum the weights.
        axis : int or sequence of int, optional
            Axis(es) over which to sum the weights. Only one of the 'dim' and
            'axis' arguments can be supplied. If neither is supplied, then
            the weights are summed over all axes.

        Returns
        -------
        The sum of the weights at the positions where the data is not null;
        entries where that sum equals 0 are replaced with NaN.
        """

        # we need to mask DATA values that are nan; else the weights are wrong
        masked_weights = self.weights.where(self.obj.notnull())

        sum_of_weights = masked_weights.sum(dim=dim, axis=axis, skipna=True)

        # find all weights that are valid (not 0)
        valid_weights = sum_of_weights != 0.

        # set invalid weights to nan
        return sum_of_weights.where(valid_weights)

    # The docstring of ``sum`` is attached after the class definition from
    # the module-level ``_doc_`` template.
    def sum(self, dim=None, axis=None, skipna=None, **kwargs):

        # calculate weighted sum
        return (self.obj * self.weights).sum(dim, axis=axis, skipna=skipna,
                                             **kwargs)

    # The docstring of ``mean`` is attached after the class definition from
    # the module-level ``_doc_`` template.
    def mean(self, dim=None, axis=None, skipna=None, **kwargs):

        # get the sum of weights
        sum_of_weights = self.sum_of_weights(dim=dim, axis=axis)

        # get weighted sum
        weighted_sum = self.sum(dim=dim, axis=axis, skipna=skipna, **kwargs)

        # calculate weighted mean
        return weighted_sum / sum_of_weights

    def __repr__(self):
        """provide a nice str repr of our weighted object"""

        msg = "{klass} with weights along dimensions: {weight_dims}"
        return msg.format(klass=self.__class__.__name__,
                          weight_dims=", ".join(self.weights.dims))
|
|
||
|
|
||
# Fill the shared `_doc_` template in for each weighted reduction method.
DataArrayWeighted.sum.__doc__ = _doc_.format(fcn="sum")
DataArrayWeighted.mean.__doc__ = _doc_.format(fcn="mean")
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,157 @@ | ||
| import pytest | ||
|
|
||
| import numpy as np | ||
|
|
||
| import xarray as xr | ||
| from xarray import ( | ||
| DataArray,) | ||
|
|
||
| from xarray.tests import assert_equal, raises_regex | ||
|
|
||
|
|
||
def test_weigted_non_DataArray_weights():
    """Weights that are not a DataArray (here a plain list) are rejected."""
    plain_weights = [1, 2]
    data = DataArray([1, 2])

    with raises_regex(AssertionError, "'weights' must be a DataArray"):
        data.weighted(plain_weights)
|
|
||
|
|
||
@pytest.mark.parametrize(
    'weights',
    (
        [1, 2],
        [np.nan, 2],
        [np.nan, np.nan],
    ),
)
def test_weighted_weights_nan_replaced(weights):
    """The weights stored on the weighted object have NaN replaced by 0."""
    weighted = DataArray([1, 2]).weighted(DataArray(weights))
    without_nan = DataArray(weights).fillna(0.)

    assert_equal(without_nan, weighted.weights)
|
|
||
|
|
||
@pytest.mark.parametrize(
    ('weights', 'expected'),
    (
        ([1, 2], 3),
        ([0, 2], 2),
        ([0, 0], np.nan),
        ([-1, 1], np.nan),
    ),
)
def test_weigted_sum_of_weights_no_nan(weights, expected):
    """Without missing data all weights are summed; a total of 0 gives NaN."""
    weighted = DataArray([1, 2]).weighted(DataArray(weights))
    total = weighted.sum_of_weights()

    assert_equal(DataArray(expected), total)
|
|
||
|
|
||
@pytest.mark.parametrize(
    ('weights', 'expected'),
    (
        ([1, 2], 2),
        ([0, 2], 2),
        ([0, 0], np.nan),
        ([-1, 1], 1),
    ),
)
def test_weigted_sum_of_weights_nan(weights, expected):
    """Weights belonging to missing data values are left out of the sum."""
    data_with_nan = DataArray([np.nan, 2])
    total = data_with_nan.weighted(DataArray(weights)).sum_of_weights()

    assert_equal(DataArray(expected), total)
|
|
||
|
|
||
@pytest.mark.parametrize('da', ([1, 2], [1, np.nan], [np.nan, np.nan]))
@pytest.mark.parametrize('factor', [0, 1, 2, 3.14])
@pytest.mark.parametrize('skipna', (True, False))
def test_weighted_sum_equal_weights(da, factor, skipna):
    """A constant weight ``factor`` scales the ordinary sum by ``factor``."""
    da = DataArray(da)
    constant_weights = xr.zeros_like(da) + factor

    plain_sum_scaled = da.sum(skipna=skipna) * factor
    weighted_sum = da.weighted(constant_weights).sum(skipna=skipna)

    assert_equal(plain_sum_scaled, weighted_sum)
|
|
||
|
|
||
@pytest.mark.parametrize(
    ('weights', 'expected'),
    (
        ([1, 2], 5),
        ([0, 2], 4),
        ([0, 0], 0),
    ),
)
def test_weighted_sum_no_nan(weights, expected):
    """Weighted sum of data without missing values."""
    weighted = DataArray([1, 2]).weighted(DataArray(weights))

    assert_equal(DataArray(expected), weighted.sum())
|
|
||
|
|
||
@pytest.mark.parametrize(
    ('weights', 'expected'),
    (
        ([1, 2], 4),
        ([0, 2], 4),
        ([1, 0], 0),
        ([0, 0], 0),
    ),
)
@pytest.mark.parametrize('skipna', (True, False))
def test_weighted_sum_nan(weights, expected, skipna):
    """Weighted sum of data with a missing value, with and without skipna."""
    data_with_nan = DataArray([np.nan, 2])
    weighted_sum = data_with_nan.weighted(DataArray(weights)).sum(
        skipna=skipna)

    # without skipna the missing data value propagates into the result
    wanted = DataArray(expected) if skipna else DataArray(np.nan)

    assert_equal(wanted, weighted_sum)
|
|
||
|
|
||
@pytest.mark.filterwarnings("ignore:Mean of empty slice")
@pytest.mark.parametrize('da', ([1, 2], [1, np.nan], [np.nan, np.nan]))
@pytest.mark.parametrize('skipna', (True, False))
def test_weigted_mean_equal_weights(da, skipna):
    """With all weights equal to 1 the weighted mean equals the plain mean."""
    da = DataArray(da)
    unit_weights = xr.zeros_like(da) + 1

    plain_mean = da.mean(skipna=skipna)
    weighted_mean = da.weighted(unit_weights).mean(skipna=skipna)

    assert_equal(plain_mean, weighted_mean)
|
|
||
|
|
||
@pytest.mark.parametrize(
    ('weights', 'expected'),
    (
        ([4, 6], 1.6),
        ([0, 1], 2.0),
        ([0, 2], 2.0),
        ([0, 0], np.nan),
    ),
)
def test_weigted_mean_no_nan(weights, expected):
    """Weighted mean of data without missing values."""
    weighted = DataArray([1, 2]).weighted(DataArray(weights))

    assert_equal(DataArray(expected), weighted.mean())
|
|
||
|
|
||
@pytest.mark.parametrize(
    ('weights', 'expected'),
    (
        ([4, 6], 2.0),
        ([0, 1], 2.0),
        ([0, 2], 2.0),
        ([0, 0], np.nan),
    ),
)
@pytest.mark.parametrize('skipna', (True, False))
def test_weigted_mean_nan(weights, expected, skipna):
    """Weighted mean of data with a missing value, with and without skipna."""
    data_with_nan = DataArray([np.nan, 2])

    # without skipna the missing data value propagates into the result
    wanted = DataArray(expected) if skipna else DataArray(np.nan)

    weighted_mean = data_with_nan.weighted(DataArray(weights)).mean(
        skipna=skipna)

    assert_equal(wanted, weighted_mean)
Uh oh!
There was an error while loading. Please reload this page.