-
-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Use strict type hinting for namedarray #8241
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 87 commits
b894d25
895c2cb
15a2a8b
89c8fea
4e10650
fc7f69a
43f4e20
e6147d3
b84b1bc
7cb3a09
4a83aa4
7ad7634
cebd1eb
71a942b
4cdbed5
c80ff30
2e84c31
f6c6f44
4b897eb
7a5cb43
7081ee1
87958d9
f913206
3b0c122
027f300
23ec9fe
d94b766
c353336
5e26eba
c5a9594
ecb50c0
9ab9dae
19b3304
41bd67c
5f58cee
7f1a94e
685ca7c
7aa2f57
99b0aca
84b6894
08d11ef
196a5c6
07e3085
707f244
a3901bc
df13d47
f76aeb1
cce278c
3865264
b61d9a8
4dec3ca
26ac902
9d23245
1f1a25d
d8007e8
762e808
459b38a
1f93f5f
b2570dd
6835c09
1bac4af
5d72861
ebf4752
6c8fac9
ee49c5e
5de4142
99bf8aa
fa41cbe
b27145e
7f262d5
401a93a
bcda5a4
2535a5f
2c5b49d
9d29827
2fba5a9
80842ea
cf8d5cc
32439fe
946bd3d
130c894
e0064b9
fec9f1b
877f0f1
a5eddb1
2194715
77e05f2
1348df6
025e9cc
2305216
13c8953
5559548
a177ce7
ca7ee37
ce77930
99f6c9b
0fa4fd3
5b98dd5
56a7755
b321a84
6a33331
863ed1d
11b36fa
2bd6f8c
476dda2
27c18b8
931659f
c2a1fb7
892e83d
58266c4
00cba3d
114c45c
c414c0a
ec9c173
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,7 +2,6 @@ | |
|
|
||
| import copy | ||
| import math | ||
| import sys | ||
| import typing | ||
| from collections.abc import Hashable, Iterable, Mapping | ||
|
|
||
|
|
@@ -11,30 +10,38 @@ | |
| # TODO: get rid of this after migrating this class to array API | ||
| from xarray.core import dtypes | ||
| from xarray.core.indexing import ExplicitlyIndexed | ||
| from xarray.core.utils import Default, _default | ||
| from xarray.namedarray.utils import ( | ||
| Default, | ||
| T_DuckArray, | ||
| _default, | ||
| is_chunked_duck_array, | ||
| is_duck_array, | ||
| is_duck_dask_array, | ||
| to_0d_object_array, | ||
| ) | ||
|
|
||
| if typing.TYPE_CHECKING: | ||
| T_NamedArray = typing.TypeVar("T_NamedArray", bound="NamedArray") | ||
| from xarray.namedarray.utils import Self # type: ignore[attr-defined] | ||
|
|
||
| try: | ||
| from dask.typing import ( | ||
| Graph, | ||
| NestedKeys, | ||
| PostComputeCallable, | ||
| PostPersistCallable, | ||
| SchedulerGetCallable, | ||
| ) | ||
| except ImportError: | ||
| Graph: typing.Any # type: ignore[no-redef] | ||
| NestedKeys: typing.Any # type: ignore[no-redef] | ||
| SchedulerGetCallable: typing.Any # type: ignore[no-redef] | ||
| PostComputeCallable: typing.Any # type: ignore[no-redef] | ||
| PostPersistCallable: typing.Any # type: ignore[no-redef] | ||
|
|
||
| # T_NamedArray = typing.TypeVar("T_NamedArray", bound="NamedArray") | ||
| DimsInput = typing.Union[str, Iterable[Hashable]] | ||
| Dims = tuple[Hashable, ...] | ||
|
|
||
|
|
||
| try: | ||
| if sys.version_info >= (3, 11): | ||
| from typing import Self | ||
| else: | ||
| from typing_extensions import Self | ||
| except ImportError: | ||
| if typing.TYPE_CHECKING: | ||
| raise | ||
| else: | ||
| Self: typing.Any = None | ||
| AttrsInput = typing.Union[Mapping[typing.Any, typing.Any], None] | ||
|
|
||
|
|
||
| # TODO: Add tests! | ||
|
|
@@ -46,7 +53,7 @@ def as_compatible_data( | |
| return typing.cast(T_DuckArray, data) | ||
|
|
||
| if isinstance(data, np.ma.MaskedArray): | ||
| mask = np.ma.getmaskarray(data) | ||
| mask = np.ma.getmaskarray(data) # type: ignore[no-untyped-call] | ||
| if mask.any(): | ||
| # TODO: requires refactoring/vendoring xarray.core.dtypes and xarray.core.duck_array_ops | ||
| raise NotImplementedError("MaskedArray is not supported yet") | ||
|
|
@@ -74,13 +81,17 @@ class NamedArray(typing.Generic[T_DuckArray]): | |
| Numeric operations on this object implement array broadcasting and dimension alignment based on dimension names, | ||
| rather than axis order.""" | ||
|
|
||
| __slots__ = ("_dims", "_data", "_attrs") | ||
| __slots__ = ("_data", "_dims", "_attrs") | ||
|
|
||
| _data: T_DuckArray | ||
| _dims: Dims | ||
| _attrs: dict[typing.Any, typing.Any] | None | ||
|
|
||
| def __init__( | ||
| self, | ||
| dims: DimsInput, | ||
| data: T_DuckArray | np.typing.ArrayLike, | ||
| attrs: dict | None = None, | ||
| attrs: AttrsInput = None, | ||
| fastpath: bool = False, | ||
| ): | ||
| """ | ||
|
|
@@ -105,9 +116,9 @@ def __init__( | |
|
|
||
|
|
||
| """ | ||
| self._data: T_DuckArray = as_compatible_data(data, fastpath=fastpath) | ||
| self._dims: Dims = self._parse_dimensions(dims) | ||
| self._attrs: dict | None = dict(attrs) if attrs else None | ||
| self._data = as_compatible_data(data, fastpath=fastpath) | ||
| self._dims = self._parse_dimensions(dims) | ||
| self._attrs = dict(attrs) if attrs else None | ||
|
|
||
| @property | ||
| def ndim(self) -> int: | ||
|
|
@@ -140,7 +151,7 @@ def __len__(self) -> int: | |
| raise TypeError("len() of unsized object") from exc | ||
|
|
||
| @property | ||
| def dtype(self) -> np.dtype: | ||
| def dtype(self) -> np.dtype[typing.Any]: | ||
| """ | ||
| Data-type of the array’s elements. | ||
|
|
||
|
|
@@ -178,7 +189,7 @@ def nbytes(self) -> int: | |
| the bytes consumed based on the ``size`` and ``dtype``. | ||
| """ | ||
| if hasattr(self._data, "nbytes"): | ||
| return self._data.nbytes | ||
| return self._data.nbytes # type: ignore[no-any-return] | ||
| else: | ||
| return self.size * self.dtype.itemsize | ||
|
|
||
|
|
@@ -208,7 +219,7 @@ def attrs(self) -> dict[typing.Any, typing.Any]: | |
| return self._attrs | ||
|
|
||
| @attrs.setter | ||
| def attrs(self, value: Mapping) -> None: | ||
| def attrs(self, value: Mapping[typing.Any, typing.Any]) -> None: | ||
| self._attrs = dict(value) | ||
|
|
||
| def _check_shape(self, new_data: T_DuckArray) -> None: | ||
|
|
@@ -256,43 +267,78 @@ def imag(self) -> Self: | |
| """ | ||
| return self._replace(data=self.data.imag) | ||
|
|
||
| def __dask_tokenize__(self): | ||
| # Use v.data, instead of v._data, in order to cope with the wrappers | ||
| # around NetCDF and the like | ||
| from dask.base import normalize_token | ||
| def __dask_tokenize__(self) -> Hashable | None: | ||
| if is_duck_dask_array(self._data): | ||
| # Use v.data, instead of v._data, in order to cope with the wrappers | ||
| # around NetCDF and the like | ||
| from dask.base import normalize_token | ||
|
|
||
| return normalize_token((type(self), self._dims, self.data, self.attrs)) | ||
| s, d, a, attrs = type(self), self._dims, self.data, self.attrs | ||
| return normalize_token((s, d, a, attrs)) # type: ignore[no-any-return] | ||
| else: | ||
| return None | ||
|
|
||
| def __dask_graph__(self): | ||
| return self._data.__dask_graph__() if is_duck_dask_array(self._data) else None | ||
| def __dask_graph__(self) -> Graph | None: | ||
| if is_duck_dask_array(self._data): | ||
| return self._data.__dask_graph__() | ||
| else: | ||
| # TODO: Should this method just raise instead? | ||
| # raise NotImplementedError("Method requires self.data to be a dask array") | ||
| return None | ||
|
|
||
| def __dask_keys__(self): | ||
| return self._data.__dask_keys__() | ||
| def __dask_keys__(self) -> NestedKeys: | ||
| if is_duck_dask_array(self._data): | ||
| return self._data.__dask_keys__() | ||
| else: | ||
| raise AttributeError("Method requires self.data to be a dask array.") | ||
|
|
||
| def __dask_layers__(self): | ||
| return self._data.__dask_layers__() | ||
| def __dask_layers__(self) -> typing.Sequence[str]: | ||
| if is_duck_dask_array(self._data): | ||
| return self._data.__dask_layers__() | ||
| else: | ||
| raise AttributeError("Method requires self.data to be a dask array.") | ||
|
|
||
| @property | ||
| def __dask_optimize__(self) -> typing.Callable: | ||
| return self._data.__dask_optimize__ | ||
| def __dask_optimize__( | ||
| self, | ||
| ) -> typing.Callable[..., dict[typing.Any, typing.Any]]: | ||
| if is_duck_dask_array(self._data): | ||
| return self._data.__dask_optimize__() # type: ignore[no-any-return] | ||
| else: | ||
| raise AttributeError("Method requires self.data to be a dask array.") | ||
|
|
||
| @property | ||
| def __dask_scheduler__(self) -> typing.Callable: | ||
| return self._data.__dask_scheduler__ | ||
| def __dask_scheduler__(self) -> staticmethod[SchedulerGetCallable]: | ||
| if is_duck_dask_array(self._data): | ||
| return self._data.__dask_scheduler__() # type: ignore[no-any-return] | ||
| else: | ||
| raise AttributeError("Method requires self.data to be a dask array.") | ||
|
|
||
| def __dask_postcompute__( | ||
| self, | ||
| ) -> tuple[typing.Callable, tuple[typing.Any, ...]]: | ||
| array_func, array_args = self._data.__dask_postcompute__() | ||
| return self._dask_finalize, (array_func,) + array_args | ||
| ) -> tuple[PostComputeCallable, tuple[typing.Any, ...]]: | ||
| if is_duck_dask_array(self._data): | ||
| array_func, array_args = self._data.__dask_postcompute__() # type: ignore[no-untyped-call] | ||
| return self._dask_finalize, (array_func,) + array_args | ||
| else: | ||
| raise AttributeError("Method requires self.data to be a dask array.") | ||
|
|
||
| def __dask_postpersist__( | ||
| self, | ||
| ) -> tuple[typing.Callable, tuple[typing.Any, ...]]: | ||
| array_func, array_args = self._data.__dask_postpersist__() | ||
| return self._dask_finalize, (array_func,) + array_args | ||
| ) -> tuple[PostPersistCallable, tuple[typing.Any, ...]]: | ||
| if is_duck_dask_array(self._data): | ||
| array_func, array_args = self._data.__dask_postpersist__() # type: ignore[no-untyped-call] | ||
| return self._dask_finalize, (array_func,) + array_args | ||
| else: | ||
| raise AttributeError("Method requires self.data to be a dask array.") | ||
|
|
||
| def _dask_finalize(self, results, array_func, *args, **kwargs) -> Self: | ||
| def _dask_finalize( | ||
| self, | ||
| results: T_DuckArray, | ||
| array_func: typing.Callable[..., T_DuckArray], | ||
| *args: typing.Any, | ||
| **kwargs: typing.Any, | ||
| ) -> Self: | ||
| data = array_func(results, *args, **kwargs) | ||
| return type(self)(self._dims, data, attrs=self._attrs) | ||
|
|
||
|
|
@@ -308,7 +354,13 @@ def chunks(self) -> tuple[tuple[int, ...], ...] | None: | |
| NamedArray.chunksizes | ||
| xarray.unify_chunks | ||
| """ | ||
| return getattr(self._data, "chunks", None) | ||
| data = self._data | ||
| # reveal_type(data) | ||
| if is_chunked_duck_array(data): | ||
| # reveal_type(data) | ||
| return data.chunks | ||
| else: | ||
| return None | ||
|
|
||
| @property | ||
| def chunksizes( | ||
|
|
@@ -328,8 +380,9 @@ def chunksizes( | |
| NamedArray.chunks | ||
| xarray.unify_chunks | ||
| """ | ||
| if hasattr(self._data, "chunks"): | ||
| return dict(zip(self.dims, self.data.chunks)) | ||
| data = self._data | ||
| if is_chunked_duck_array(data): | ||
| return dict(zip(self.dims, data.chunks)) | ||
| else: | ||
| return {} | ||
|
|
||
|
|
@@ -338,7 +391,12 @@ def sizes(self) -> dict[Hashable, int]: | |
| """Ordered mapping from dimension names to lengths.""" | ||
| return dict(zip(self.dims, self.shape)) | ||
|
|
||
| def _replace(self, dims=_default, data=_default, attrs=_default) -> Self: | ||
| def _replace( | ||
| self, | ||
| dims: DimsInput | Default = _default, | ||
| data: T_DuckArray | np.typing.ArrayLike | Default = _default, | ||
| attrs: AttrsInput | Default = _default, | ||
| ) -> Self: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If the duck array type differs from Self the generic part of the return type should change as well. Probably need a second TypeVar for this. |
||
| if dims is _default: | ||
| dims = copy.copy(self._dims) | ||
| if data is _default: | ||
|
|
@@ -415,7 +473,7 @@ def _nonzero(self) -> tuple[Self, ...]: | |
| def _as_sparse( | ||
| self, | ||
| sparse_format: str | Default = _default, | ||
| fill_value=dtypes.NA, | ||
| fill_value: typing.Any = dtypes.NA, | ||
|
||
| ) -> Self: | ||
| """ | ||
| use sparse-array as backend. | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@headtr1ck Do you (or anyone else) understand why
dataisn't narrowed down toT_ChunkedArray?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure actually.
But the "`1" in the first reveal_type usually is already an indication that something is not correct. Even though I have never figured out why Mypy does this and what it means (somehow the type is not exactly known at this time or part of a Union or something like that).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Somehow I like the
getattrimplementation better than the explicit check... it's just much easier to read and understand.Still your issue indicates that something is wrong somewhere...