Skip to content

Commit 193f539

Browse files
[MNT] Remove mutable objects from defaults (#1699)
### Description This PR is a fix for #1668. Removes mutable default arguments and replaces them with internal newly initialized mutable defaults.
1 parent a884c4d commit 193f539

17 files changed

Lines changed: 511 additions & 303 deletions

File tree

docs/source/tutorials/building.ipynb

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -875,13 +875,13 @@
875875
" assert dataset.max_prediction_length == dataset.min_prediction_length, \"Decoder only supports a fixed length\"\n",
876876
" assert dataset.min_encoder_length == dataset.max_encoder_length, \"Encoder only supports a fixed length\"\n",
877877
" assert (\n",
878-
" len(dataset.time_varying_known_categoricals) == 0\n",
879-
" and len(dataset.time_varying_known_reals) == 0\n",
880-
" and len(dataset.time_varying_unknown_categoricals) == 0\n",
881-
" and len(dataset.static_categoricals) == 0\n",
882-
" and len(dataset.static_reals) == 0\n",
883-
" and len(dataset.time_varying_unknown_reals) == 1\n",
884-
" and dataset.time_varying_unknown_reals[0] == dataset.target\n",
878+
" len(dataset._time_varying_known_categoricals) == 0\n",
879+
" and len(dataset._time_varying_known_reals) == 0\n",
880+
" and len(dataset._time_varying_unknown_categoricals) == 0\n",
881+
" and len(dataset._static_categoricals) == 0\n",
882+
" and len(dataset._static_reals) == 0\n",
883+
" and len(dataset._time_varying_unknown_reals) == 1\n",
884+
" and dataset._time_varying_unknown_reals[0] == dataset.target\n",
885885
" ), \"Only covariate should be the target in 'time_varying_unknown_reals'\"\n",
886886
"\n",
887887
" return super().from_dataset(dataset, **new_kwargs)"
@@ -1587,12 +1587,12 @@
15871587
" assert dataset.max_prediction_length == dataset.min_prediction_length, \"Decoder only supports a fixed length\"\n",
15881588
" assert dataset.min_encoder_length == dataset.max_encoder_length, \"Encoder only supports a fixed length\"\n",
15891589
" assert (\n",
1590-
" len(dataset.time_varying_known_categoricals) == 0\n",
1591-
" and len(dataset.time_varying_known_reals) == 0\n",
1592-
" and len(dataset.time_varying_unknown_categoricals) == 0\n",
1593-
" and len(dataset.static_categoricals) == 0\n",
1594-
" and len(dataset.static_reals) == 0\n",
1595-
" and len(dataset.time_varying_unknown_reals) == 1\n",
1590+
" len(dataset._time_varying_known_categoricals) == 0\n",
1591+
" and len(dataset._time_varying_known_reals) == 0\n",
1592+
" and len(dataset._time_varying_unknown_categoricals) == 0\n",
1593+
" and len(dataset._static_categoricals) == 0\n",
1594+
" and len(dataset._static_reals) == 0\n",
1595+
" and len(dataset._time_varying_unknown_reals) == 1\n",
15961596
" ), \"Only covariate should be in 'time_varying_unknown_reals'\"\n",
15971597
"\n",
15981598
" return super().from_dataset(dataset, **new_kwargs)\n",
@@ -2136,12 +2136,12 @@
21362136
" assert dataset.max_prediction_length == dataset.min_prediction_length, \"Decoder only supports a fixed length\"\n",
21372137
" assert dataset.min_encoder_length == dataset.max_encoder_length, \"Encoder only supports a fixed length\"\n",
21382138
" assert (\n",
2139-
" len(dataset.time_varying_known_categoricals) == 0\n",
2140-
" and len(dataset.time_varying_known_reals) == 0\n",
2141-
" and len(dataset.time_varying_unknown_categoricals) == 0\n",
2142-
" and len(dataset.static_categoricals) == 0\n",
2143-
" and len(dataset.static_reals) == 0\n",
2144-
" and len(dataset.time_varying_unknown_reals)\n",
2139+
" len(dataset._time_varying_known_categoricals) == 0\n",
2140+
" and len(dataset._time_varying_known_reals) == 0\n",
2141+
" and len(dataset._time_varying_unknown_categoricals) == 0\n",
2142+
" and len(dataset._static_categoricals) == 0\n",
2143+
" and len(dataset._static_reals) == 0\n",
2144+
" and len(dataset._time_varying_unknown_reals)\n",
21452145
" == len(dataset.target_names) # Expect as many unknown reals as targets\n",
21462146
" ), \"Only covariate should be in 'time_varying_unknown_reals'\"\n",
21472147
"\n",
@@ -3414,13 +3414,13 @@
34143414
" assert dataset.max_prediction_length == dataset.min_prediction_length, \"Decoder only supports a fixed length\"\n",
34153415
" assert dataset.min_encoder_length == dataset.max_encoder_length, \"Encoder only supports a fixed length\"\n",
34163416
" assert (\n",
3417-
" len(dataset.time_varying_known_categoricals) == 0\n",
3418-
" and len(dataset.time_varying_known_reals) == 0\n",
3419-
" and len(dataset.time_varying_unknown_categoricals) == 0\n",
3420-
" and len(dataset.static_categoricals) == 0\n",
3421-
" and len(dataset.static_reals) == 0\n",
3422-
" and len(dataset.time_varying_unknown_reals) == 1\n",
3423-
" and dataset.time_varying_unknown_reals[0] == dataset.target\n",
3417+
" len(dataset._time_varying_known_categoricals) == 0\n",
3418+
" and len(dataset._time_varying_known_reals) == 0\n",
3419+
" and len(dataset._time_varying_unknown_categoricals) == 0\n",
3420+
" and len(dataset._static_categoricals) == 0\n",
3421+
" and len(dataset._static_reals) == 0\n",
3422+
" and len(dataset._time_varying_unknown_reals) == 1\n",
3423+
" and dataset._time_varying_unknown_reals[0] == dataset.target\n",
34243424
" ), \"Only covariate should be the target in 'time_varying_unknown_reals'\"\n",
34253425
"\n",
34263426
" return super().from_dataset(dataset, **new_kwargs)\n",

pytorch_forecasting/data/encoders.py

Lines changed: 50 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@
22
Encoders for encoding categorical variables and scaling continuous data.
33
"""
44

5-
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
5+
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union, Optional
66
import warnings
77

88
import numpy as np
99
import pandas as pd
10+
from copy import deepcopy
1011
from sklearn.base import BaseEstimator, TransformerMixin
1112
import torch
1213
from torch.distributions import constraints
@@ -396,7 +397,7 @@ def __init__(
396397
method: str = "standard",
397398
center: bool = True,
398399
transformation: Union[str, Tuple[Callable, Callable]] = None,
399-
method_kwargs: Dict[str, Any] = {},
400+
method_kwargs: Optional[Dict[str, Any]] = None,
400401
):
401402
"""
402403
Args:
@@ -428,6 +429,7 @@ def __init__(
428429
self.center = center
429430
self.transformation = transformation
430431
self.method_kwargs = method_kwargs
432+
self._method_kwargs = deepcopy(method_kwargs) if method_kwargs is not None else {}
431433

432434
def get_parameters(self, *args, **kwargs) -> torch.Tensor:
433435
"""
@@ -496,17 +498,17 @@ def _set_parameters(
496498

497499
elif self.method == "robust":
498500
if isinstance(y_center, torch.Tensor):
499-
self.center_ = y_center.quantile(self.method_kwargs.get("center", 0.5), dim=-1)
500-
q_75 = y_scale.quantile(self.method_kwargs.get("upper", 0.75), dim=-1)
501-
q_25 = y_scale.quantile(self.method_kwargs.get("lower", 0.25), dim=-1)
501+
self.center_ = y_center.quantile(self._method_kwargs.get("center", 0.5), dim=-1)
502+
q_75 = y_scale.quantile(self._method_kwargs.get("upper", 0.75), dim=-1)
503+
q_25 = y_scale.quantile(self._method_kwargs.get("lower", 0.25), dim=-1)
502504
elif isinstance(y_center, np.ndarray):
503-
self.center_ = np.percentile(y_center, self.method_kwargs.get("center", 0.5) * 100, axis=-1)
504-
q_75 = np.percentile(y_scale, self.method_kwargs.get("upper", 0.75) * 100, axis=-1)
505-
q_25 = np.percentile(y_scale, self.method_kwargs.get("lower", 0.25) * 100, axis=-1)
505+
self.center_ = np.percentile(y_center, self._method_kwargs.get("center", 0.5) * 100, axis=-1)
506+
q_75 = np.percentile(y_scale, self._method_kwargs.get("upper", 0.75) * 100, axis=-1)
507+
q_25 = np.percentile(y_scale, self._method_kwargs.get("lower", 0.25) * 100, axis=-1)
506508
else:
507-
self.center_ = np.percentile(y_center, self.method_kwargs.get("center", 0.5) * 100, axis=-1)
508-
q_75 = np.percentile(y_scale, self.method_kwargs.get("upper", 0.75) * 100)
509-
q_25 = np.percentile(y_scale, self.method_kwargs.get("lower", 0.25) * 100)
509+
self.center_ = np.percentile(y_center, self._method_kwargs.get("center", 0.5) * 100, axis=-1)
510+
q_75 = np.percentile(y_scale, self._method_kwargs.get("upper", 0.75) * 100)
511+
q_25 = np.percentile(y_scale, self._method_kwargs.get("lower", 0.25) * 100)
510512
self.scale_ = (q_75 - q_25) / 2.0 + eps
511513
if not self.center and self.method != "identity":
512514
self.scale_ = self.center_
@@ -623,7 +625,7 @@ def __init__(
623625
center: bool = True,
624626
max_length: Union[int, List[int]] = None,
625627
transformation: Union[str, Tuple[Callable, Callable]] = None,
626-
method_kwargs: Dict[str, Any] = {},
628+
method_kwargs: Dict[str, Any] = None,
627629
):
628630
"""
629631
Initialize
@@ -655,6 +657,7 @@ def __init__(
655657
should be defined if ``reverse`` is not the inverse of the forward transformation. ``inverse_torch``
656658
can be defined to provide a torch distribution transform for inverse transformations.
657659
"""
660+
method_kwargs = deepcopy(method_kwargs) if method_kwargs is not None else {}
658661
super().__init__(method=method, center=center, transformation=transformation, method_kwargs=method_kwargs)
659662
self.max_length = max_length
660663

@@ -726,11 +729,11 @@ class GroupNormalizer(TorchNormalizer):
726729
def __init__(
727730
self,
728731
method: str = "standard",
729-
groups: List[str] = [],
732+
groups: Optional[List[str]] = None,
730733
center: bool = True,
731734
scale_by_group: bool = False,
732-
transformation: Union[str, Tuple[Callable, Callable]] = None,
733-
method_kwargs: Dict[str, Any] = {},
735+
transformation: Optional[Union[str, Tuple[Callable, Callable]]] = None,
736+
method_kwargs: Optional[Dict[str, Any]] = None,
734737
):
735738
"""
736739
Group normalizer to normalize a given entry by groups. Can be used as target normalizer.
@@ -765,7 +768,9 @@ def __init__(
765768
766769
"""
767770
self.groups = groups
771+
self._groups = list(groups) if groups is not None else []
768772
self.scale_by_group = scale_by_group
773+
method_kwargs = deepcopy(method_kwargs) if method_kwargs is not None else {}
769774
super().__init__(method=method, center=center, transformation=transformation, method_kwargs=method_kwargs)
770775

771776
def fit(self, y: pd.Series, X: pd.DataFrame):
@@ -781,17 +786,17 @@ def fit(self, y: pd.Series, X: pd.DataFrame):
781786
"""
782787
y = self.preprocess(y)
783788
eps = np.finfo(np.float16).eps
784-
if len(self.groups) == 0:
789+
if len(self._groups) == 0:
785790
assert not self.scale_by_group, "No groups are defined, i.e. `scale_by_group=[]`"
786791
if self.method == "standard":
787792
self.norm_ = {"center": np.mean(y), "scale": np.std(y) + eps} # center and scale
788793
else:
789794
quantiles = np.quantile(
790795
y,
791796
[
792-
self.method_kwargs.get("lower", 0.25),
793-
self.method_kwargs.get("center", 0.5),
794-
self.method_kwargs.get("upper", 0.75),
797+
self._method_kwargs.get("lower", 0.25),
798+
self._method_kwargs.get("center", 0.5),
799+
self._method_kwargs.get("upper", 0.75),
795800
],
796801
)
797802
self.norm_ = {
@@ -810,7 +815,7 @@ def fit(self, y: pd.Series, X: pd.DataFrame):
810815
.groupby(g, observed=True)
811816
.agg(center=("y", "mean"), scale=("y", "std"))
812817
.assign(center=lambda x: x["center"], scale=lambda x: x.scale + eps)
813-
for g in self.groups
818+
for g in self._groups
814819
}
815820
else:
816821
self.norm_ = {
@@ -819,21 +824,21 @@ def fit(self, y: pd.Series, X: pd.DataFrame):
819824
.groupby(g, observed=True)
820825
.y.quantile(
821826
[
822-
self.method_kwargs.get("lower", 0.25),
823-
self.method_kwargs.get("center", 0.5),
824-
self.method_kwargs.get("upper", 0.75),
827+
self._method_kwargs.get("lower", 0.25),
828+
self._method_kwargs.get("center", 0.5),
829+
self._method_kwargs.get("upper", 0.75),
825830
]
826831
)
827832
.unstack(-1)
828833
.assign(
829-
center=lambda x: x[self.method_kwargs.get("center", 0.5)],
834+
center=lambda x: x[self._method_kwargs.get("center", 0.5)],
830835
scale=lambda x: (
831-
x[self.method_kwargs.get("upper", 0.75)] - x[self.method_kwargs.get("lower", 0.25)]
836+
x[self._method_kwargs.get("upper", 0.75)] - x[self._method_kwargs.get("lower", 0.25)]
832837
)
833838
/ 2.0
834839
+ eps,
835840
)[["center", "scale"]]
836-
for g in self.groups
841+
for g in self._groups
837842
}
838843
# calculate missings
839844
if not self.center: # swap center and scale
@@ -849,29 +854,29 @@ def swap_parameters(norm):
849854
else:
850855
if self.method == "standard":
851856
self.norm_ = (
852-
X[self.groups]
857+
X[self._groups]
853858
.assign(y=y)
854-
.groupby(self.groups, observed=True)
859+
.groupby(self._groups, observed=True)
855860
.agg(center=("y", "mean"), scale=("y", "std"))
856861
.assign(center=lambda x: x["center"], scale=lambda x: x.scale + eps)
857862
)
858863
else:
859864
self.norm_ = (
860-
X[self.groups]
865+
X[self._groups]
861866
.assign(y=y)
862-
.groupby(self.groups, observed=True)
867+
.groupby(self._groups, observed=True)
863868
.y.quantile(
864869
[
865-
self.method_kwargs.get("lower", 0.25),
866-
self.method_kwargs.get("center", 0.5),
867-
self.method_kwargs.get("upper", 0.75),
870+
self._method_kwargs.get("lower", 0.25),
871+
self._method_kwargs.get("center", 0.5),
872+
self._method_kwargs.get("upper", 0.75),
868873
]
869874
)
870875
.unstack(-1)
871876
.assign(
872-
center=lambda x: x[self.method_kwargs.get("center", 0.5)],
877+
center=lambda x: x[self._method_kwargs.get("center", 0.5)],
873878
scale=lambda x: (
874-
x[self.method_kwargs.get("upper", 0.75)] - x[self.method_kwargs.get("lower", 0.25)]
879+
x[self._method_kwargs.get("upper", 0.75)] - x[self._method_kwargs.get("lower", 0.25)]
875880
)
876881
/ 2.0
877882
+ eps,
@@ -883,7 +888,7 @@ def swap_parameters(norm):
883888
self.missing_ = self.norm_.median().to_dict()
884889

885890
if (
886-
(self.scale_by_group and any((self.norm_[group]["scale"] < 1e-7).any() for group in self.groups))
891+
(self.scale_by_group and any((self.norm_[group]["scale"] < 1e-7).any() for group in self._groups))
887892
or (not self.scale_by_group and isinstance(self.norm_["scale"], float) and self.norm_["scale"] < 1e-7)
888893
or (
889894
not self.scale_by_group
@@ -973,13 +978,13 @@ def get_parameters(self, groups: Union[torch.Tensor, list, tuple], group_names:
973978
if isinstance(groups, list):
974979
groups = tuple(groups)
975980
if group_names is None:
976-
group_names = self.groups
981+
group_names = self._groups
977982
else:
978983
# filter group names
979-
group_names = [name for name in group_names if name in self.groups]
980-
assert len(group_names) == len(self.groups), "Passed groups and fitted do not match"
984+
group_names = [name for name in group_names if name in self._groups]
985+
assert len(group_names) == len(self._groups), "Passed groups and fitted do not match"
981986

982-
if len(self.groups) == 0:
987+
if len(self._groups) == 0:
983988
params = np.array([self.norm_["center"], self.norm_["scale"]])
984989
elif self.scale_by_group:
985990
norm = np.array([1.0, 1.0])
@@ -988,7 +993,7 @@ def get_parameters(self, groups: Union[torch.Tensor, list, tuple], group_names:
988993
norm = norm * self.norm_[group_name].loc[group].to_numpy()
989994
except KeyError:
990995
norm = norm * np.asarray([self.missing_[group_name][name] for name in self.names])
991-
norm = np.power(norm, 1.0 / len(self.groups))
996+
norm = np.power(norm, 1.0 / len(self._groups))
992997
params = norm
993998
else:
994999
try:
@@ -1007,7 +1012,7 @@ def get_norm(self, X: pd.DataFrame) -> pd.DataFrame:
10071012
Returns:
10081013
pd.DataFrame: dataframe with scaling parameters where each row corresponds to the input dataframe
10091014
"""
1010-
if len(self.groups) == 0:
1015+
if len(self._groups) == 0:
10111016
norm = np.asarray([self.norm_["center"], self.norm_["scale"]]).reshape(1, -1)
10121017
elif self.scale_by_group:
10131018
norm = [
@@ -1017,15 +1022,15 @@ def get_norm(self, X: pd.DataFrame) -> pd.DataFrame:
10171022
.map(self.norm_[group_name][name])
10181023
.fillna(self.missing_[group_name][name])
10191024
.to_numpy()
1020-
for group_name in self.groups
1025+
for group_name in self._groups
10211026
],
10221027
axis=0,
10231028
)
10241029
for name in self.names
10251030
]
1026-
norm = np.power(np.stack(norm, axis=1), 1.0 / len(self.groups))
1031+
norm = np.power(np.stack(norm, axis=1), 1.0 / len(self._groups))
10271032
else:
1028-
norm = X[self.groups].set_index(self.groups).join(self.norm_).fillna(self.missing_).to_numpy()
1033+
norm = X[self._groups].set_index(self._groups).join(self.norm_).fillna(self.missing_).to_numpy()
10291034
return norm
10301035

10311036

0 commit comments

Comments
 (0)