|
15 | 15 | from .utils import ( |
16 | 16 | either_dict_or_kwargs, |
17 | 17 | hashable, |
| 18 | + is_scalar, |
18 | 19 | maybe_wrap_array, |
19 | 20 | peek_at, |
20 | 21 | safe_cast_to_index, |
21 | 22 | ) |
22 | 23 | from .variable import IndexVariable, Variable, as_variable |
23 | 24 |
|
24 | 25 |
|
| 26 | +def check_reduce_dims(reduce_dims, dimensions): |
| 27 | + |
| 28 | + if reduce_dims is not ...: |
| 29 | + if is_scalar(reduce_dims): |
| 30 | + reduce_dims = [reduce_dims] |
| 31 | + if any([dim not in dimensions for dim in reduce_dims]): |
| 32 | + raise ValueError( |
| 33 | + "cannot reduce over dimensions %r. expected either '...' to reduce over all dimensions or one or more of %r." |
| 34 | + % (reduce_dims, dimensions) |
| 35 | + ) |
| 36 | + |
| 37 | + |
25 | 38 | def unique_value_groups(ar, sort=True): |
26 | 39 | """Group an array by its unique values. |
27 | 40 |
|
@@ -348,6 +361,13 @@ def __init__( |
348 | 361 | group_indices = [slice(i, i + 1) for i in group_indices] |
349 | 362 | unique_coord = group |
350 | 363 | else: |
| 364 | + if group.isnull().any(): |
| 365 | + # drop any NaN valued groups. |
| 366 | + # also drop obj values where group was NaN |
| 367 | + # Use where instead of reindex to account for duplicate coordinate labels. |
| 368 | + obj = obj.where(group.notnull(), drop=True) |
| 369 | + group = group.dropna(group_dim) |
| 370 | + |
351 | 371 | # look through group to find the unique values |
352 | 372 | unique_values, group_indices = unique_value_groups( |
353 | 373 | safe_cast_to_index(group), sort=(bins is None) |
@@ -794,15 +814,11 @@ def reduce( |
794 | 814 | if keep_attrs is None: |
795 | 815 | keep_attrs = _get_keep_attrs(default=False) |
796 | 816 |
|
797 | | - if dim is not ... and dim not in self.dims: |
798 | | - raise ValueError( |
799 | | - "cannot reduce over dimension %r. expected either '...' to reduce over all dimensions or one or more of %r." |
800 | | - % (dim, self.dims) |
801 | | - ) |
802 | | - |
803 | 817 | def reduce_array(ar): |
804 | 818 | return ar.reduce(func, dim, axis, keep_attrs=keep_attrs, **kwargs) |
805 | 819 |
|
| 820 | + check_reduce_dims(dim, self.dims) |
| 821 | + |
806 | 822 | return self.apply(reduce_array, shortcut=shortcut) |
807 | 823 |
|
808 | 824 |
|
@@ -895,11 +911,7 @@ def reduce(self, func, dim=None, keep_attrs=None, **kwargs): |
895 | 911 | def reduce_dataset(ds): |
896 | 912 | return ds.reduce(func, dim, keep_attrs, **kwargs) |
897 | 913 |
|
898 | | - if dim is not ... and dim not in self.dims: |
899 | | - raise ValueError( |
900 | | - "cannot reduce over dimension %r. expected either '...' to reduce over all dimensions or one or more of %r." |
901 | | - % (dim, self.dims) |
902 | | - ) |
| 914 | + check_reduce_dims(dim, self.dims) |
903 | 915 |
|
904 | 916 | return self.apply(reduce_dataset) |
905 | 917 |
|
|
0 commit comments