From b47d1abce5d6aca3cb4f8d77c665d468c58068c8 Mon Sep 17 00:00:00 2001 From: Anderson Banihirwe Date: Fri, 15 Mar 2024 17:38:05 -0700 Subject: [PATCH 1/7] Implement setitem syntax for `.oindex` and `.vindex` properties --- xarray/core/indexing.py | 142 +++++++++++++++++++++++++++++----------- xarray/core/variable.py | 2 +- 2 files changed, 105 insertions(+), 39 deletions(-) diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index ea8ae44bb4d..f44f9f4d202 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -326,18 +326,23 @@ def as_integer_slice(value): class IndexCallable: - """Provide getitem syntax for a callable object.""" + """Provide getitem and setitem syntax for callable objects.""" - __slots__ = ("func",) + __slots__ = ("func_get", "func_set") - def __init__(self, func): - self.func = func + def __init__(self, func_get, func_set=None): + self.func_get = func_get + self.func_set = func_set def __getitem__(self, key): - return self.func(key) + return self.func_get(key) def __setitem__(self, key, value): - raise NotImplementedError + if self.func_set is None: + raise NotImplementedError( + "Setting values is not supported for this indexer." + ) + self.func_set(key, value) class BasicIndexer(ExplicitIndexer): @@ -486,10 +491,24 @@ def __array__(self, dtype: np.typing.DTypeLike = None) -> np.ndarray: return np.asarray(self.get_duck_array(), dtype=dtype) def _oindex_get(self, key): - raise NotImplementedError("This method should be overridden") + raise NotImplementedError( + f"{self.__class__.__name__}._oindex_get method should be overridden" + ) def _vindex_get(self, key): - raise NotImplementedError("This method should be overridden") + raise NotImplementedError( + f"{self.__class__.__name__}._vindex_get method should be overridden" + ) + + def _oindex_set(self, key, value): + raise NotImplementedError( + f"{self.__class__.__name__}._oindex_set method should be overridden" + ) + + def _vindex_set(self, key, value): + raise NotImplementedError( + f"{self.__class__.__name__}._vindex_set method should be overridden" + ) def _check_and_raise_if_non_basic_indexer(self, key): if isinstance(key, (VectorizedIndexer, OuterIndexer)): @@ -500,11 +519,11 @@ def _check_and_raise_if_non_basic_indexer(self, key): @property def oindex(self): - return IndexCallable(self._oindex_get) + return IndexCallable(self._oindex_get, self._oindex_set) @property def vindex(self): - return IndexCallable(self._vindex_get) + return IndexCallable(self._vindex_get, self._vindex_set) class ImplicitToExplicitIndexingAdapter(NDArrayMixin): @@ -616,12 +635,18 @@ def __getitem__(self, indexer): self._check_and_raise_if_non_basic_indexer(indexer) return type(self)(self.array, self._updated_key(indexer)) + def _vindex_set(self, key, value): + raise NotImplementedError( + "Lazy item assignment with the vectorized indexer is not yet " + "implemented. Load your data first by .load() or compute()." + ) + + def _oindex_set(self, key, value): + full_key = self._updated_key(key) + self.array[full_key] = value + def __setitem__(self, key, value): - if isinstance(key, VectorizedIndexer): - raise NotImplementedError( - "Lazy item assignment with the vectorized indexer is not yet " - "implemented. Load your data first by .load() or compute()." - ) + self._check_and_raise_if_non_basic_indexer(key) full_key = self._updated_key(key) self.array[full_key] = value @@ -657,7 +682,6 @@ def shape(self) -> tuple[int, ...]: return np.broadcast(*self.key.tuple).shape def get_duck_array(self): - if isinstance(self.array, ExplicitlyIndexedNDArrayMixin): array = apply_indexer(self.array, self.key) else: @@ -739,8 +763,18 @@ def __getitem__(self, key): def transpose(self, order): return self.array.transpose(order) + def _vindex_set(self, key, value): + self._ensure_copied() + self.array.vindex[key] = value + + def _oindex_set(self, key, value): + self._ensure_copied() + self.array.oindex[key] = value + def __setitem__(self, key, value): + self._check_and_raise_if_non_basic_indexer(key) self._ensure_copied() + self.array[key] = value def __deepcopy__(self, memo): @@ -779,7 +813,14 @@ def __getitem__(self, key): def transpose(self, order): return self.array.transpose(order) + def _vindex_set(self, key, value): + self.array.vindex[key] = value + + def _oindex_set(self, key, value): + self.array.oindex[key] = value + def __setitem__(self, key, value): + self._check_and_raise_if_non_basic_indexer(key) self.array[key] = value @@ -950,6 +991,16 @@ def apply_indexer(indexable, indexer): return indexable[indexer] +def set_with_indexer(indexable, indexer, value): + """Set values in an indexable object using an indexer.""" + if isinstance(indexer, VectorizedIndexer): + indexable.vindex[indexer] = value + elif isinstance(indexer, OuterIndexer): + indexable.oindex[indexer] = value + else: + indexable[indexer] = value + + def decompose_indexer( indexer: ExplicitIndexer, shape: tuple[int, ...], indexing_support: IndexingSupport ) -> tuple[ExplicitIndexer, ExplicitIndexer]: @@ -1433,11 +1484,10 @@ def __getitem__(self, key): array, key = self._indexing_array_and_key(key) return array[key] - def __setitem__(self, key, value): - array, key = self._indexing_array_and_key(key) + def _safe_setitem(self, array, key, value): try: array[key] = value - except ValueError: + except ValueError as exc: # More informative exception if read-only view if not array.flags.writeable and not array.flags.owndata: raise ValueError( @@ -1445,7 +1495,20 @@ def __setitem__(self, key, value): "Do you want to .copy() array first?" ) else: - raise + raise exc + + def _oindex_set(self, key, value): + key = _outer_to_numpy_indexer(key, self.array.shape) + self._safe_setitem(self.array, key, value) + + def _vindex_set(self, key, value): + array = NumpyVIndexAdapter(self.array) + self._safe_setitem(array, key.tuple, value) + + def __setitem__(self, key, value): + self._check_and_raise_if_non_basic_indexer(key) + array, key = self._indexing_array_and_key(key) + self._safe_setitem(array, key, value) class NdArrayLikeIndexingAdapter(NumpyIndexingAdapter): @@ -1488,13 +1551,15 @@ def __getitem__(self, key): self._check_and_raise_if_non_basic_indexer(key) return self.array[key.tuple] + def _oindex_set(self, key, value): + self.array[key.tuple] = value + + def _vindex_set(self, key, value): + raise TypeError("Vectorized indexing is not supported") + def __setitem__(self, key, value): - if isinstance(key, (BasicIndexer, OuterIndexer)): - self.array[key.tuple] = value - elif isinstance(key, VectorizedIndexer): - raise TypeError("Vectorized indexing is not supported") - else: - raise TypeError(f"Unrecognized indexer: {key}") + self._check_and_raise_if_non_basic_indexer(key) + self.array[key.tuple] = value def transpose(self, order): xp = self.array.__array_namespace__() @@ -1530,19 +1595,20 @@ def __getitem__(self, key): self._check_and_raise_if_non_basic_indexer(key) return self.array[key.tuple] + def _oindex_set(self, key, value): + num_non_slices = sum(0 if isinstance(k, slice) else 1 for k in key.tuple) + if num_non_slices > 1: + raise NotImplementedError( + "xarray can't set arrays with multiple " "array indices to dask yet." + ) + self.array[key.tuple] = value + + def _vindex_set(self, key, value): + self.array.vindex[key.tuple] = value + def __setitem__(self, key, value): - if isinstance(key, BasicIndexer): - self.array[key.tuple] = value - elif isinstance(key, VectorizedIndexer): - self.array.vindex[key.tuple] = value - elif isinstance(key, OuterIndexer): - num_non_slices = sum(0 if isinstance(k, slice) else 1 for k in key.tuple) - if num_non_slices > 1: - raise NotImplementedError( - "xarray can't set arrays with multiple " - "array indices to dask yet." - ) - self.array[key.tuple] = value + self._check_and_raise_if_non_basic_indexer(key) + self.array[key.tuple] = value def transpose(self, order): return self.array.transpose(order) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index a03e93ac699..29938ed85da 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -842,7 +842,7 @@ def __setitem__(self, key, value): value = np.moveaxis(value, new_order, range(len(new_order))) indexable = as_indexable(self._data) - indexable[index_tuple] = value + indexing.set_with_indexer(indexable, index_tuple, value) @property def encoding(self) -> dict[Any, Any]: From 341328915ae1bc7be77d60a048717ce3db1626ea Mon Sep 17 00:00:00 2001 From: Anderson Banihirwe <13301940+andersy005@users.noreply.github.com> Date: Mon, 18 Mar 2024 14:51:01 -0700 Subject: [PATCH 2/7] Apply suggestions from code review Co-authored-by: Deepak Cherian --- xarray/core/indexing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index f44f9f4d202..9a9bedec4fb 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -328,7 +328,7 @@ def as_integer_slice(value): class IndexCallable: """Provide getitem and setitem syntax for callable objects.""" - __slots__ = ("func_get", "func_set") + __slots__ = ("getter", "setter") def __init__(self, func_get, func_set=None): self.func_get = func_get @@ -643,7 +643,7 @@ def _vindex_set(self, key, value): def _oindex_set(self, key, value): full_key = self._updated_key(key) - self.array[full_key] = value + self.array.oindex[full_key] = value def __setitem__(self, key, value): self._check_and_raise_if_non_basic_indexer(key) From 2b0888c84d659c6d4ed1a8539ae669d463589446 Mon Sep 17 00:00:00 2001 From: Anderson Banihirwe Date: Mon, 18 Mar 2024 14:56:46 -0700 Subject: [PATCH 3/7] use getter and setter properties instead of func_get and func_set methods --- xarray/core/indexing.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 9a9bedec4fb..80d614f2b00 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -330,19 +330,19 @@ class IndexCallable: __slots__ = ("getter", "setter") - def __init__(self, func_get, func_set=None): - self.func_get = func_get - self.func_set = func_set + def __init__(self, getter, setter=None): + self.getter = getter + self.setter = setter def __getitem__(self, key): - return self.func_get(key) + return self.getter(key) def __setitem__(self, key, value): - if self.func_set is None: + if self.setter is None: raise NotImplementedError( "Setting values is not supported for this indexer." ) - self.func_set(key, value) + self.setter(key, value) class BasicIndexer(ExplicitIndexer): From bd26644287c70f0953de394bcb965e1f9f40503b Mon Sep 17 00:00:00 2001 From: Anderson Banihirwe Date: Mon, 18 Mar 2024 15:02:06 -0700 Subject: [PATCH 4/7] delete unnecessary _indexing_array_and_key method --- xarray/core/indexing.py | 31 +++++++++++-------------------- 1 file changed, 11 insertions(+), 20 deletions(-) diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 80d614f2b00..407fda610fc 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -1450,24 +1450,6 @@ def __init__(self, array): ) self.array = array - def _indexing_array_and_key(self, key): - if isinstance(key, OuterIndexer): - array = self.array - key = _outer_to_numpy_indexer(key, self.array.shape) - elif isinstance(key, VectorizedIndexer): - array = NumpyVIndexAdapter(self.array) - key = key.tuple - elif isinstance(key, BasicIndexer): - array = self.array - # We want 0d slices rather than scalars. This is achieved by - # appending an ellipsis (see - # https://numpy.org/doc/stable/reference/arrays.indexing.html#detailed-notes). - key = key.tuple + (Ellipsis,) - else: - raise TypeError(f"unexpected key type: {type(key)}") - - return array, key - def transpose(self, order): return self.array.transpose(order) @@ -1481,7 +1463,12 @@ def _vindex_get(self, key): def __getitem__(self, key): self._check_and_raise_if_non_basic_indexer(key) - array, key = self._indexing_array_and_key(key) + + array = self.array + # We want 0d slices rather than scalars. This is achieved by + # appending an ellipsis (see + # https://numpy.org/doc/stable/reference/arrays.indexing.html#detailed-notes). + key = key.tuple + (Ellipsis,) return array[key] def _safe_setitem(self, array, key, value): @@ -1507,7 +1494,11 @@ def _vindex_set(self, key, value): def __setitem__(self, key, value): self._check_and_raise_if_non_basic_indexer(key) - array, key = self._indexing_array_and_key(key) + array = self.array + # We want 0d slices rather than scalars. This is achieved by + # appending an ellipsis (see + # https://numpy.org/doc/stable/reference/arrays.indexing.html#detailed-notes). + key = key.tuple + (Ellipsis,) self._safe_setitem(array, key, value) From d015df4d7f1fd63ed1b0039b5b89460ba7f1e50d Mon Sep 17 00:00:00 2001 From: Anderson Banihirwe Date: Mon, 18 Mar 2024 15:10:16 -0700 Subject: [PATCH 5/7] Add tests for IndexCallable class --- xarray/tests/test_indexing.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/xarray/tests/test_indexing.py b/xarray/tests/test_indexing.py index c3989bbf23e..41c241832a8 100644 --- a/xarray/tests/test_indexing.py +++ b/xarray/tests/test_indexing.py @@ -23,6 +23,28 @@ B = IndexerMaker(indexing.BasicIndexer) +class TestIndexCallable: + def test_getitem(self): + def getter(key): + return key * 2 + + indexer = indexing.IndexCallable(getter) + assert indexer[3] == 6 + assert indexer[0] == 0 + assert indexer[-1] == -2 + + def test_setitem(self): + def getter(key): + return key * 2 + + def setter(key, value): + raise NotImplementedError("Setter not implemented") + + indexer = indexing.IndexCallable(getter, setter) + with pytest.raises(NotImplementedError): + indexer[3] = 6 + + class TestIndexers: def set_to_zero(self, x, i): x = x.copy() From 8d58b6ffede48fac31776d9fd488a9930fabbf0b Mon Sep 17 00:00:00 2001 From: Anderson Banihirwe Date: Mon, 18 Mar 2024 15:49:00 -0700 Subject: [PATCH 6/7] fix bug/unnecessary code introduced in #8790 --- xarray/tests/test_indexing.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/xarray/tests/test_indexing.py b/xarray/tests/test_indexing.py index 41c241832a8..a3f0d821530 100644 --- a/xarray/tests/test_indexing.py +++ b/xarray/tests/test_indexing.py @@ -383,15 +383,8 @@ def test_vectorized_lazily_indexed_array(self) -> None: def check_indexing(v_eager, v_lazy, indexers): for indexer in indexers: - if isinstance(indexer, indexing.VectorizedIndexer): - actual = v_lazy.vindex[indexer] - expected = v_eager.vindex[indexer] - elif isinstance(indexer, indexing.OuterIndexer): - actual = v_lazy.oindex[indexer] - expected = v_eager.oindex[indexer] - else: - actual = v_lazy[indexer] - expected = v_eager[indexer] + actual = v_lazy[indexer] + expected = v_eager[indexer] assert expected.shape == actual.shape assert isinstance( actual._data, From c4547d3c6a582567200ba85a373fc7ba3c25eedf Mon Sep 17 00:00:00 2001 From: Anderson Banihirwe Date: Mon, 18 Mar 2024 17:05:04 -0700 Subject: [PATCH 7/7] add unit tests --- xarray/tests/test_indexing.py | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/xarray/tests/test_indexing.py b/xarray/tests/test_indexing.py index a3f0d821530..e650c454eac 100644 --- a/xarray/tests/test_indexing.py +++ b/xarray/tests/test_indexing.py @@ -421,6 +421,41 @@ def check_indexing(v_eager, v_lazy, indexers): ] check_indexing(v_eager, v_lazy, indexers) + def test_lazily_indexed_array_vindex_setitem(self) -> None: + + lazy = indexing.LazilyIndexedArray(np.random.rand(10, 20, 30)) + + # vectorized indexing + indexer = indexing.VectorizedIndexer( + (np.array([0, 1]), np.array([0, 1]), slice(None, None, None)) + ) + with pytest.raises( + NotImplementedError, + match=r"Lazy item assignment with the vectorized indexer is not yet", + ): + lazy.vindex[indexer] = 0 + + @pytest.mark.parametrize( + "indexer_class, key, value", + [ + (indexing.OuterIndexer, (0, 1, slice(None, None, None)), 10), + (indexing.BasicIndexer, (0, 1, slice(None, None, None)), 10), + ], + ) + def test_lazily_indexed_array_setitem(self, indexer_class, key, value) -> None: + original = np.random.rand(10, 20, 30) + x = indexing.NumpyIndexingAdapter(original) + lazy = indexing.LazilyIndexedArray(x) + + if indexer_class is indexing.BasicIndexer: + indexer = indexer_class(key) + lazy[indexer] = value + elif indexer_class is indexing.OuterIndexer: + indexer = indexer_class(key) + lazy.oindex[indexer] = value + + assert_array_equal(original[key], value) + class TestCopyOnWriteArray: def test_setitem(self) -> None: