21 changes: 10 additions & 11 deletions xarray/core/_typed_ops.pyi
@@ -9,24 +9,23 @@ from .dataarray import DataArray
from .dataset import Dataset
from .groupby import DataArrayGroupBy, DatasetGroupBy, GroupBy
from .npcompat import ArrayLike
from .types import (
DaCompatible,
DsCompatible,
GroupByIncompatible,
ScalarOrArray,
T_DataArray,
T_Dataset,
T_Variable,
VarCompatible,
)
from .variable import Variable

try:
from dask.array import Array as DaskArray
except ImportError:
DaskArray = np.ndarray

# DatasetOpsMixin etc. are parent classes of Dataset etc.
T_Dataset = TypeVar("T_Dataset", bound="DatasetOpsMixin")
T_DataArray = TypeVar("T_DataArray", bound="DataArrayOpsMixin")
T_Variable = TypeVar("T_Variable", bound="VariableOpsMixin")

ScalarOrArray = Union[ArrayLike, np.generic, np.ndarray, DaskArray]
DsCompatible = Union[Dataset, DataArray, Variable, GroupBy, ScalarOrArray]
DaCompatible = Union[DataArray, Variable, DataArrayGroupBy, ScalarOrArray]
VarCompatible = Union[Variable, ScalarOrArray]
GroupByIncompatible = Union[Variable, GroupBy]

class DatasetOpsMixin:
__slots__ = ()
def _binary_op(self, other, f, reflexive=...): ...
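The aliases the stub now pulls from `.types` exist so that operator signatures don't repeat the same union over and over. A minimal, hypothetical sketch of that pattern — the alias members and class name below are simplified stand-ins, not xarray's generated stub:

```python
from __future__ import annotations

from typing import TYPE_CHECKING, Union

if TYPE_CHECKING:
    # Real code would import the xarray classes here; keeping them
    # type-checking-only avoids a runtime import cycle.
    from xarray.core.dataarray import DataArray
    from xarray.core.dataset import Dataset

# One central alias instead of repeating the union in every signature.
DsCompatible = Union["Dataset", "DataArray", int, float]


class DatasetOpsSketch:
    def __add__(self, other: DsCompatible) -> "Dataset":
        ...
```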
8 changes: 4 additions & 4 deletions xarray/core/common.py
@@ -1,3 +1,5 @@
from __future__ import annotations

import warnings
from contextlib import suppress
from html import escape
@@ -36,10 +38,10 @@
if TYPE_CHECKING:
from .dataarray import DataArray
from .dataset import Dataset
from .types import T_DataWithCoords, T_DSorDA
from .variable import Variable
from .weighted import Weighted

T_DataWithCoords = TypeVar("T_DataWithCoords", bound="DataWithCoords")

C = TypeVar("C")
T = TypeVar("T")
@@ -795,9 +797,7 @@ def groupby_bins(
},
)

def weighted(
self: T_DataWithCoords, weights: "DataArray"
) -> "Weighted[T_DataWithCoords]":
def weighted(self: T_DataWithCoords, weights: "DataArray") -> Weighted[T_DSorDA]:
"""
Weighted operations.

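The key idea behind annotating `self` in `weighted` is that a bound TypeVar carries the concrete subclass through to the returned object. A simplified, hypothetical sketch of the self-typed pattern (these are stand-in classes, and the return parametrization here may differ from the signature above):

```python
from __future__ import annotations

from typing import Generic, TypeVar

T_DataWithCoords = TypeVar("T_DataWithCoords", bound="DataWithCoords")


class Weighted(Generic[T_DataWithCoords]):
    def __init__(self, obj: T_DataWithCoords, weights: list[float]) -> None:
        self.obj = obj
        self.weights = weights


class DataWithCoords:
    def weighted(
        self: T_DataWithCoords, weights: list[float]
    ) -> Weighted[T_DataWithCoords]:
        # Called on a subclass instance, mypy infers Weighted[<subclass>].
        return Weighted(self, weights)


class DataArraySketch(DataWithCoords):
    pass


w = DataArraySketch().weighted([0.5, 0.5])  # typed as Weighted[DataArraySketch]
```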
11 changes: 4 additions & 7 deletions xarray/core/computation.py
@@ -21,7 +21,6 @@
Optional,
Sequence,
Tuple,
TypeVar,
Union,
)

@@ -36,11 +35,9 @@
from .variable import Variable

if TYPE_CHECKING:
from .coordinates import Coordinates # noqa
from .dataarray import DataArray
from .coordinates import Coordinates
from .dataset import Dataset

T_DSorDA = TypeVar("T_DSorDA", DataArray, Dataset)
from .types import T_DSorDA

_NO_FILL_VALUE = utils.ReprObject("<no-fill-value>")
_DEFAULT_NAME = utils.ReprObject("<default-name>")
@@ -199,7 +196,7 @@ def result_name(objects: list) -> Any:
return name


def _get_coords_list(args) -> List["Coordinates"]:
def _get_coords_list(args) -> List[Coordinates]:
coords_list = []
for arg in args:
try:
@@ -401,7 +398,7 @@ def apply_dict_of_variables_vfunc(

def _fast_dataset(
variables: Dict[Hashable, Variable], coord_variables: Mapping[Hashable, Variable]
) -> "Dataset":
) -> Dataset:
"""Create a dataset as quickly as possible.

Beware: the `variables` dict is modified INPLACE.
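These hunks drop the quotes around `Coordinates` and `Dataset` even though those names are only imported under `TYPE_CHECKING`. That is only safe when annotations are never evaluated at runtime, e.g. under `from __future__ import annotations` as added at the top of several modules in this PR; presumably the same applies here. A minimal sketch of the pattern, with a hypothetical helper name:

```python
from __future__ import annotations  # postpone evaluation of annotations (PEP 563)

from typing import TYPE_CHECKING, List

if TYPE_CHECKING:
    # Never executed at runtime, so no circular import and no hard dependency.
    from xarray.core.dataset import Dataset


def first_dataset(objs: List[Dataset]) -> Dataset:
    # The annotations above are stored as strings and only resolved by the
    # type checker, so the TYPE_CHECKING-only import is sufficient.
    return objs[0]
```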
7 changes: 4 additions & 3 deletions xarray/core/dataarray.py
@@ -1,3 +1,5 @@
from __future__ import annotations

import datetime
import warnings
from typing import (
@@ -12,7 +14,6 @@
Optional,
Sequence,
Tuple,
TypeVar,
Union,
cast,
)
@@ -76,8 +77,6 @@
assert_unique_multiindex_level_names,
)

T_DataArray = TypeVar("T_DataArray", bound="DataArray")
T_DSorDA = TypeVar("T_DSorDA", "DataArray", Dataset)
if TYPE_CHECKING:
try:
from dask.delayed import Delayed
@@ -92,6 +91,8 @@
except ImportError:
iris_Cube = None

from .types import T_DSorDA


def _infer_coords_and_dims(
shape, coords, dims
4 changes: 1 addition & 3 deletions xarray/core/dataset.py
@@ -25,7 +25,6 @@
Sequence,
Set,
Tuple,
TypeVar,
Union,
cast,
overload,
@@ -110,8 +109,7 @@
from ..backends import AbstractDataStore, ZarrStore
from .dataarray import DataArray
from .merge import CoercibleMapping

T_DSorDA = TypeVar("T_DSorDA", DataArray, "Dataset")
from .types import T_DSorDA

try:
from dask.delayed import Delayed
7 changes: 5 additions & 2 deletions xarray/core/parallel.py
@@ -1,3 +1,5 @@
from __future__ import annotations

try:
import dask
import dask.array
@@ -11,6 +13,7 @@
import itertools
import operator
from typing import (
TYPE_CHECKING,
Any,
Callable,
DefaultDict,
@@ -21,7 +24,6 @@
Mapping,
Sequence,
Tuple,
TypeVar,
Union,
)

@@ -33,7 +35,8 @@
from .dataarray import DataArray
from .dataset import Dataset

T_DSorDA = TypeVar("T_DSorDA", DataArray, Dataset)
if TYPE_CHECKING:
from .types import T_DSorDA


def unzip(iterable):
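Moving the `T_DSorDA` import behind `TYPE_CHECKING` works in `parallel.py` because the TypeVar only appears in annotations there, which are not evaluated at runtime once evaluation is postponed. A small sketch of that situation (the function name is illustrative):

```python
from __future__ import annotations

from typing import TYPE_CHECKING, Callable

if TYPE_CHECKING:
    # Resolved by mypy only; absent at runtime.
    from xarray.core.types import T_DSorDA


def apply_to_obj(func: Callable[[T_DSorDA], T_DSorDA], obj: T_DSorDA) -> T_DSorDA:
    # Fine at runtime even though T_DSorDA is undefined here:
    # the annotations are never evaluated.
    return func(obj)
```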
19 changes: 13 additions & 6 deletions xarray/core/rolling_exp.py
@@ -1,3 +1,5 @@
from __future__ import annotations

from distutils.version import LooseVersion
from typing import TYPE_CHECKING, Generic, Hashable, Mapping, TypeVar, Union

@@ -7,12 +9,6 @@
from .pdcompat import count_not_none
from .pycompat import is_duck_dask_array

if TYPE_CHECKING:
from .dataarray import DataArray # noqa: F401
from .dataset import Dataset # noqa: F401

T_DSorDA = TypeVar("T_DSorDA", "DataArray", "Dataset")


def _get_alpha(com=None, span=None, halflife=None, alpha=None):
# pandas defines in terms of com (converting to alpha in the algo)
@@ -79,6 +75,17 @@ def _get_center_of_mass(comass, span, halflife, alpha):
return float(comass)


# We seem to need to redefine T_DSorDA here, rather than importing it from `core.types`,
# because a) it needs to be defined at runtime (it can't be a string), b) it can't be
# behind an `if TYPE_CHECKING` branch, and c) we get import errors if we import it at
# the module level, like: from .types import T_DSorDA

if TYPE_CHECKING:
from .dataarray import DataArray
from .dataset import Dataset
T_DSorDA = TypeVar("T_DSorDA", "DataArray", "Dataset")


class RollingExp(Generic[T_DSorDA]):
"""
Exponentially-weighted moving window object.
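The comment in this hunk describes a real runtime constraint: `Generic[...]` is evaluated when the class body executes, so the TypeVar must exist as an object at import time and cannot be a string or a `TYPE_CHECKING`-only name. A tiny illustration:

```python
from typing import Generic, TypeVar

T = TypeVar("T")  # must be a real object at import time


class Wrapper(Generic[T]):  # Generic[T] is evaluated when this class is created
    def __init__(self, obj: T) -> None:
        self.obj = obj


# Generic["T"], or a TypeVar defined only under `if TYPE_CHECKING:`,
# would raise a TypeError/NameError here instead.
box = Wrapper(42)
```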
31 changes: 31 additions & 0 deletions xarray/core/types.py
@@ -0,0 +1,31 @@
from __future__ import annotations

from typing import TYPE_CHECKING, TypeVar, Union

import numpy as np

if TYPE_CHECKING:
from .common import DataWithCoords
from .dataarray import DataArray
from .dataset import Dataset
from .groupby import DataArrayGroupBy, GroupBy
from .npcompat import ArrayLike
from .variable import Variable

try:
from dask.array import Array as DaskArray
except ImportError:
DaskArray = np.ndarray

T_Dataset = TypeVar("T_Dataset", bound=Dataset)
T_DataArray = TypeVar("T_DataArray", bound=DataArray)
T_Variable = TypeVar("T_Variable", bound=Variable)
# Maybe we rename this to T_Data or something less Fortran-y?
T_DSorDA = TypeVar("T_DSorDA", "DataArray", Dataset)
T_DataWithCoords = TypeVar("T_DataWithCoords", bound=DataWithCoords)

ScalarOrArray = ArrayLike | np.generic | np.ndarray | DaskArray
DsCompatible = Dataset | DataArray | Variable | GroupBy | ScalarOrArray
DaCompatible = DataArray | Variable | DataArrayGroupBy | ScalarOrArray
VarCompatible = Variable | ScalarOrArray
GroupByIncompatible = Variable | GroupBy
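Two different TypeVar flavours appear in the new module: the `bound=` TypeVars accept any subclass and preserve it, while `T_DSorDA` is value-constrained to exactly `DataArray` or `Dataset`. A small illustration of the difference, using stand-in classes rather than the real ones:

```python
from typing import TypeVar


class Dataset: ...
class DataArray: ...
class SubDataset(Dataset): ...


# Bound: any subclass is allowed, and the subclass type is preserved.
T_Dataset = TypeVar("T_Dataset", bound=Dataset)

# Constrained: the checker resolves this to exactly DataArray or Dataset.
T_DSorDA = TypeVar("T_DSorDA", DataArray, Dataset)


def copy_bound(obj: T_Dataset) -> T_Dataset:
    return obj  # SubDataset in, SubDataset out


def copy_constrained(obj: T_DSorDA) -> T_DSorDA:
    return obj  # SubDataset in, Dataset out (collapsed to a constraint)
```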
33 changes: 13 additions & 20 deletions xarray/core/variable.py
@@ -1,10 +1,13 @@
from __future__ import annotations

import copy
import itertools
import numbers
import warnings
from collections import defaultdict
from datetime import timedelta
from typing import (
TYPE_CHECKING,
Any,
Dict,
Hashable,
@@ -13,7 +16,6 @@
Optional,
Sequence,
Tuple,
TypeVar,
Union,
)

@@ -58,17 +60,8 @@
# https://github.com/python/mypy/issues/224
BASIC_INDEXING_TYPES = integer_types + (slice,)

VariableType = TypeVar("VariableType", bound="Variable")
"""Type annotation to be used when methods of Variable return self or a copy of self.
When called from an instance of a subclass, e.g. IndexVariable, mypy identifies the
output as an instance of the subclass.

Usage::

class Variable:
def f(self: VariableType, ...) -> VariableType:
...
"""
if TYPE_CHECKING:
from .types import T_Variable


class MissingDimensionsError(ValueError):
@@ -357,15 +350,15 @@ def data(self, data):
self._data = data

def astype(
self: VariableType,
self: T_Variable,
dtype,
*,
order=None,
casting=None,
subok=None,
copy=None,
keep_attrs=True,
) -> VariableType:
) -> T_Variable:
"""
Copy of the Variable object, with data cast to a specified type.

@@ -763,7 +756,7 @@ def _broadcast_indexes_vectorized(self, key):

return out_dims, VectorizedIndexer(tuple(out_key)), new_order

def __getitem__(self: VariableType, key) -> VariableType:
def __getitem__(self: T_Variable, key) -> T_Variable:
"""Return a new Variable object whose contents are consistent with
getting the provided key from the underlying data.

@@ -782,7 +775,7 @@ def __getitem__(self: VariableType, key) -> VariableType:
data = np.moveaxis(data, range(len(new_order)), new_order)
return self._finalize_indexing_result(dims, data)

def _finalize_indexing_result(self: VariableType, dims, data) -> VariableType:
def _finalize_indexing_result(self: T_Variable, dims, data) -> T_Variable:
"""Used by IndexVariable to return IndexVariable objects when possible."""
return self._replace(dims=dims, data=data)

@@ -962,12 +955,12 @@ def copy(self, deep=True, data=None):
return self._replace(data=data)

def _replace(
self: VariableType,
self: T_Variable,
dims=_default,
data=_default,
attrs=_default,
encoding=_default,
) -> VariableType:
) -> T_Variable:
if dims is _default:
dims = copy.copy(self._dims)
if data is _default:
@@ -1100,11 +1093,11 @@ def _to_dense(self):
return self.copy(deep=False)

def isel(
self: VariableType,
self: T_Variable,
indexers: Mapping[Hashable, Any] = None,
missing_dims: str = "raise",
**indexers_kwargs: Any,
) -> VariableType:
) -> T_Variable:
"""Return a new array indexed along the specified dimension(s).

Parameters
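The removed `VariableType` docstring explained why these methods annotate `self`: with a bound TypeVar, a method that returns self (or a copy) keeps the subclass type, so on an `IndexVariable` the result is seen as `IndexVariable` rather than plain `Variable`. A minimal sketch of that behaviour with simplified stand-ins:

```python
from __future__ import annotations

from typing import TypeVar

T_Variable = TypeVar("T_Variable", bound="Variable")


class Variable:
    def _replace(self: T_Variable) -> T_Variable:
        # type(self)() keeps the runtime class; the TypeVar keeps the
        # static type in sync with it.
        return type(self)()

    def __getitem__(self: T_Variable, key: int) -> T_Variable:
        return self._replace()


class IndexVariable(Variable):
    pass


subset = IndexVariable()[0]  # mypy (and runtime): IndexVariable
```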