22Encoders for encoding categorical variables and scaling continuous data.
33"""
44
5- from typing import Any , Callable , Dict , Iterable , List , Tuple , Union
5+ from typing import Any , Callable , Dict , Iterable , List , Tuple , Union , Optional
66import warnings
77
88import numpy as np
99import pandas as pd
10+ from copy import deepcopy
1011from sklearn .base import BaseEstimator , TransformerMixin
1112import torch
1213from torch .distributions import constraints
@@ -396,7 +397,7 @@ def __init__(
396397 method : str = "standard" ,
397398 center : bool = True ,
398399 transformation : Union [str , Tuple [Callable , Callable ]] = None ,
399- method_kwargs : Dict [str , Any ] = {} ,
400+ method_kwargs : Optional [ Dict [str , Any ]] = None ,
400401 ):
401402 """
402403 Args:
@@ -428,6 +429,7 @@ def __init__(
428429 self .center = center
429430 self .transformation = transformation
430431 self .method_kwargs = method_kwargs
432+ self ._method_kwargs = deepcopy (method_kwargs ) if method_kwargs is not None else {}
431433
432434 def get_parameters (self , * args , ** kwargs ) -> torch .Tensor :
433435 """
@@ -496,17 +498,17 @@ def _set_parameters(
496498
497499 elif self .method == "robust" :
498500 if isinstance (y_center , torch .Tensor ):
499- self .center_ = y_center .quantile (self .method_kwargs .get ("center" , 0.5 ), dim = - 1 )
500- q_75 = y_scale .quantile (self .method_kwargs .get ("upper" , 0.75 ), dim = - 1 )
501- q_25 = y_scale .quantile (self .method_kwargs .get ("lower" , 0.25 ), dim = - 1 )
501+ self .center_ = y_center .quantile (self ._method_kwargs .get ("center" , 0.5 ), dim = - 1 )
502+ q_75 = y_scale .quantile (self ._method_kwargs .get ("upper" , 0.75 ), dim = - 1 )
503+ q_25 = y_scale .quantile (self ._method_kwargs .get ("lower" , 0.25 ), dim = - 1 )
502504 elif isinstance (y_center , np .ndarray ):
503- self .center_ = np .percentile (y_center , self .method_kwargs .get ("center" , 0.5 ) * 100 , axis = - 1 )
504- q_75 = np .percentile (y_scale , self .method_kwargs .get ("upper" , 0.75 ) * 100 , axis = - 1 )
505- q_25 = np .percentile (y_scale , self .method_kwargs .get ("lower" , 0.25 ) * 100 , axis = - 1 )
505+ self .center_ = np .percentile (y_center , self ._method_kwargs .get ("center" , 0.5 ) * 100 , axis = - 1 )
506+ q_75 = np .percentile (y_scale , self ._method_kwargs .get ("upper" , 0.75 ) * 100 , axis = - 1 )
507+ q_25 = np .percentile (y_scale , self ._method_kwargs .get ("lower" , 0.25 ) * 100 , axis = - 1 )
506508 else :
507- self .center_ = np .percentile (y_center , self .method_kwargs .get ("center" , 0.5 ) * 100 , axis = - 1 )
508- q_75 = np .percentile (y_scale , self .method_kwargs .get ("upper" , 0.75 ) * 100 )
509- q_25 = np .percentile (y_scale , self .method_kwargs .get ("lower" , 0.25 ) * 100 )
509+ self .center_ = np .percentile (y_center , self ._method_kwargs .get ("center" , 0.5 ) * 100 , axis = - 1 )
510+ q_75 = np .percentile (y_scale , self ._method_kwargs .get ("upper" , 0.75 ) * 100 )
511+ q_25 = np .percentile (y_scale , self ._method_kwargs .get ("lower" , 0.25 ) * 100 )
510512 self .scale_ = (q_75 - q_25 ) / 2.0 + eps
511513 if not self .center and self .method != "identity" :
512514 self .scale_ = self .center_
@@ -623,7 +625,7 @@ def __init__(
623625 center : bool = True ,
624626 max_length : Union [int , List [int ]] = None ,
625627 transformation : Union [str , Tuple [Callable , Callable ]] = None ,
626- method_kwargs : Dict [str , Any ] = {} ,
628+ method_kwargs : Dict [str , Any ] = None ,
627629 ):
628630 """
629631 Initialize
@@ -655,6 +657,7 @@ def __init__(
655657 should be defined if ``reverse`` is not the inverse of the forward transformation. ``inverse_torch``
656658 can be defined to provide a torch distribution transform for inverse transformations.
657659 """
660+ method_kwargs = deepcopy (method_kwargs ) if method_kwargs is not None else {}
658661 super ().__init__ (method = method , center = center , transformation = transformation , method_kwargs = method_kwargs )
659662 self .max_length = max_length
660663
@@ -726,11 +729,11 @@ class GroupNormalizer(TorchNormalizer):
726729 def __init__ (
727730 self ,
728731 method : str = "standard" ,
729- groups : List [str ] = [] ,
732+ groups : Optional [ List [str ]] = None ,
730733 center : bool = True ,
731734 scale_by_group : bool = False ,
732- transformation : Union [str , Tuple [Callable , Callable ]] = None ,
733- method_kwargs : Dict [str , Any ] = {} ,
735+ transformation : Optional [ Union [str , Tuple [Callable , Callable ] ]] = None ,
736+ method_kwargs : Optional [ Dict [str , Any ]] = None ,
734737 ):
735738 """
736739 Group normalizer to normalize a given entry by groups. Can be used as target normalizer.
@@ -765,7 +768,9 @@ def __init__(
765768
766769 """
767770 self .groups = groups
771+ self ._groups = list (groups ) if groups is not None else []
768772 self .scale_by_group = scale_by_group
773+ method_kwargs = deepcopy (method_kwargs ) if method_kwargs is not None else {}
769774 super ().__init__ (method = method , center = center , transformation = transformation , method_kwargs = method_kwargs )
770775
771776 def fit (self , y : pd .Series , X : pd .DataFrame ):
@@ -781,17 +786,17 @@ def fit(self, y: pd.Series, X: pd.DataFrame):
781786 """
782787 y = self .preprocess (y )
783788 eps = np .finfo (np .float16 ).eps
784- if len (self .groups ) == 0 :
789+ if len (self ._groups ) == 0 :
785790 assert not self .scale_by_group , "No groups are defined, i.e. `scale_by_group=[]`"
786791 if self .method == "standard" :
787792 self .norm_ = {"center" : np .mean (y ), "scale" : np .std (y ) + eps } # center and scale
788793 else :
789794 quantiles = np .quantile (
790795 y ,
791796 [
792- self .method_kwargs .get ("lower" , 0.25 ),
793- self .method_kwargs .get ("center" , 0.5 ),
794- self .method_kwargs .get ("upper" , 0.75 ),
797+ self ._method_kwargs .get ("lower" , 0.25 ),
798+ self ._method_kwargs .get ("center" , 0.5 ),
799+ self ._method_kwargs .get ("upper" , 0.75 ),
795800 ],
796801 )
797802 self .norm_ = {
@@ -810,7 +815,7 @@ def fit(self, y: pd.Series, X: pd.DataFrame):
810815 .groupby (g , observed = True )
811816 .agg (center = ("y" , "mean" ), scale = ("y" , "std" ))
812817 .assign (center = lambda x : x ["center" ], scale = lambda x : x .scale + eps )
813- for g in self .groups
818+ for g in self ._groups
814819 }
815820 else :
816821 self .norm_ = {
@@ -819,21 +824,21 @@ def fit(self, y: pd.Series, X: pd.DataFrame):
819824 .groupby (g , observed = True )
820825 .y .quantile (
821826 [
822- self .method_kwargs .get ("lower" , 0.25 ),
823- self .method_kwargs .get ("center" , 0.5 ),
824- self .method_kwargs .get ("upper" , 0.75 ),
827+ self ._method_kwargs .get ("lower" , 0.25 ),
828+ self ._method_kwargs .get ("center" , 0.5 ),
829+ self ._method_kwargs .get ("upper" , 0.75 ),
825830 ]
826831 )
827832 .unstack (- 1 )
828833 .assign (
829- center = lambda x : x [self .method_kwargs .get ("center" , 0.5 )],
834+ center = lambda x : x [self ._method_kwargs .get ("center" , 0.5 )],
830835 scale = lambda x : (
831- x [self .method_kwargs .get ("upper" , 0.75 )] - x [self .method_kwargs .get ("lower" , 0.25 )]
836+ x [self ._method_kwargs .get ("upper" , 0.75 )] - x [self ._method_kwargs .get ("lower" , 0.25 )]
832837 )
833838 / 2.0
834839 + eps ,
835840 )[["center" , "scale" ]]
836- for g in self .groups
841+ for g in self ._groups
837842 }
838843 # calculate missings
839844 if not self .center : # swap center and scale
@@ -849,29 +854,29 @@ def swap_parameters(norm):
849854 else :
850855 if self .method == "standard" :
851856 self .norm_ = (
852- X [self .groups ]
857+ X [self ._groups ]
853858 .assign (y = y )
854- .groupby (self .groups , observed = True )
859+ .groupby (self ._groups , observed = True )
855860 .agg (center = ("y" , "mean" ), scale = ("y" , "std" ))
856861 .assign (center = lambda x : x ["center" ], scale = lambda x : x .scale + eps )
857862 )
858863 else :
859864 self .norm_ = (
860- X [self .groups ]
865+ X [self ._groups ]
861866 .assign (y = y )
862- .groupby (self .groups , observed = True )
867+ .groupby (self ._groups , observed = True )
863868 .y .quantile (
864869 [
865- self .method_kwargs .get ("lower" , 0.25 ),
866- self .method_kwargs .get ("center" , 0.5 ),
867- self .method_kwargs .get ("upper" , 0.75 ),
870+ self ._method_kwargs .get ("lower" , 0.25 ),
871+ self ._method_kwargs .get ("center" , 0.5 ),
872+ self ._method_kwargs .get ("upper" , 0.75 ),
868873 ]
869874 )
870875 .unstack (- 1 )
871876 .assign (
872- center = lambda x : x [self .method_kwargs .get ("center" , 0.5 )],
877+ center = lambda x : x [self ._method_kwargs .get ("center" , 0.5 )],
873878 scale = lambda x : (
874- x [self .method_kwargs .get ("upper" , 0.75 )] - x [self .method_kwargs .get ("lower" , 0.25 )]
879+ x [self ._method_kwargs .get ("upper" , 0.75 )] - x [self ._method_kwargs .get ("lower" , 0.25 )]
875880 )
876881 / 2.0
877882 + eps ,
@@ -883,7 +888,7 @@ def swap_parameters(norm):
883888 self .missing_ = self .norm_ .median ().to_dict ()
884889
885890 if (
886- (self .scale_by_group and any ((self .norm_ [group ]["scale" ] < 1e-7 ).any () for group in self .groups ))
891+ (self .scale_by_group and any ((self .norm_ [group ]["scale" ] < 1e-7 ).any () for group in self ._groups ))
887892 or (not self .scale_by_group and isinstance (self .norm_ ["scale" ], float ) and self .norm_ ["scale" ] < 1e-7 )
888893 or (
889894 not self .scale_by_group
@@ -973,13 +978,13 @@ def get_parameters(self, groups: Union[torch.Tensor, list, tuple], group_names:
973978 if isinstance (groups , list ):
974979 groups = tuple (groups )
975980 if group_names is None :
976- group_names = self .groups
981+ group_names = self ._groups
977982 else :
978983 # filter group names
979- group_names = [name for name in group_names if name in self .groups ]
980- assert len (group_names ) == len (self .groups ), "Passed groups and fitted do not match"
984+ group_names = [name for name in group_names if name in self ._groups ]
985+ assert len (group_names ) == len (self ._groups ), "Passed groups and fitted do not match"
981986
982- if len (self .groups ) == 0 :
987+ if len (self ._groups ) == 0 :
983988 params = np .array ([self .norm_ ["center" ], self .norm_ ["scale" ]])
984989 elif self .scale_by_group :
985990 norm = np .array ([1.0 , 1.0 ])
@@ -988,7 +993,7 @@ def get_parameters(self, groups: Union[torch.Tensor, list, tuple], group_names:
988993 norm = norm * self .norm_ [group_name ].loc [group ].to_numpy ()
989994 except KeyError :
990995 norm = norm * np .asarray ([self .missing_ [group_name ][name ] for name in self .names ])
991- norm = np .power (norm , 1.0 / len (self .groups ))
996+ norm = np .power (norm , 1.0 / len (self ._groups ))
992997 params = norm
993998 else :
994999 try :
@@ -1007,7 +1012,7 @@ def get_norm(self, X: pd.DataFrame) -> pd.DataFrame:
10071012 Returns:
10081013 pd.DataFrame: dataframe with scaling parameters where each row corresponds to the input dataframe
10091014 """
1010- if len (self .groups ) == 0 :
1015+ if len (self ._groups ) == 0 :
10111016 norm = np .asarray ([self .norm_ ["center" ], self .norm_ ["scale" ]]).reshape (1 , - 1 )
10121017 elif self .scale_by_group :
10131018 norm = [
@@ -1017,15 +1022,15 @@ def get_norm(self, X: pd.DataFrame) -> pd.DataFrame:
10171022 .map (self .norm_ [group_name ][name ])
10181023 .fillna (self .missing_ [group_name ][name ])
10191024 .to_numpy ()
1020- for group_name in self .groups
1025+ for group_name in self ._groups
10211026 ],
10221027 axis = 0 ,
10231028 )
10241029 for name in self .names
10251030 ]
1026- norm = np .power (np .stack (norm , axis = 1 ), 1.0 / len (self .groups ))
1031+ norm = np .power (np .stack (norm , axis = 1 ), 1.0 / len (self ._groups ))
10271032 else :
1028- norm = X [self .groups ].set_index (self .groups ).join (self .norm_ ).fillna (self .missing_ ).to_numpy ()
1033+ norm = X [self ._groups ].set_index (self ._groups ).join (self .norm_ ).fillna (self .missing_ ).to_numpy ()
10291034 return norm
10301035
10311036
0 commit comments