@@ -42,10 +42,10 @@ class Base_pkg(_BasePtForecasterV2):
4242
4343 def __init__ (
4444 self ,
45- model_cfg : Optional [ Union [ dict [str , Any ], str , Path ]] = None ,
46- trainer_cfg : Optional [ Union [ dict [str , Any ], str , Path ]] = None ,
47- datamodule_cfg : Optional [ Union [ dict [str , Any ], str , Path ]] = None ,
48- ckpt_path : Optional [ Union [ str , Path ]] = None ,
45+ model_cfg : dict [str , Any ] | str | Path | None = None ,
46+ trainer_cfg : dict [str , Any ] | str | Path | None = None ,
47+ datamodule_cfg : dict [str , Any ] | str | Path | None = None ,
48+ ckpt_path : str | Path | None = None ,
4949 ):
5050 self .ckpt_path = Path (ckpt_path ) if ckpt_path else None
5151 self .model_cfg = self ._load_config (
@@ -74,9 +74,9 @@ def __init__(
7474
7575 @staticmethod
7676 def _load_config (
77- config : Union [ dict , str , Path , None ] ,
78- ckpt_path : Optional [ Union [ str , Path ]] = None ,
79- auto_file_name : Optional [ str ] = None ,
77+ config : dict | str | Path | None ,
78+ ckpt_path : str | Path | None = None ,
79+ auto_file_name : str | None = None ,
8080 ) -> dict :
8181 """
8282 Loads configuration from a dictionary, YAML file, or Pickle file.
@@ -157,7 +157,7 @@ def _build_datamodule(self, data: TimeSeries) -> LightningDataModule:
157157 return datamodule_cls (data , ** self .datamodule_cfg )
158158
159159 def _load_dataloader (
160- self , data : Union [ TimeSeries , LightningDataModule , DataLoader ]
160+ self , data : TimeSeries | LightningDataModule | DataLoader
161161 ) -> DataLoader :
162162 """Converts various data input types into a DataLoader for prediction."""
163163 if isinstance (data , TimeSeries ): # D1 Layer
@@ -191,11 +191,11 @@ def _save_artifact(self, output_dir: Path):
191191
192192 def fit (
193193 self ,
194- data : Union [ TimeSeries , LightningDataModule ] ,
194+ data : TimeSeries | LightningDataModule ,
195195 # todo: we should create a base data_module for different data_modules
196196 save_ckpt : bool = True ,
197- ckpt_dir : Union [ str , Path ] = "checkpoints" ,
198- ckpt_kwargs : Optional [ dict [str , Any ]] = None ,
197+ ckpt_dir : str | Path = "checkpoints" ,
198+ ckpt_kwargs : dict [str , Any ] | None = None ,
199199 ** trainer_fit_kwargs ,
200200 ):
201201 """
@@ -265,10 +265,10 @@ def fit(
265265
266266 def predict (
267267 self ,
268- data : Union [ TimeSeries , LightningDataModule , DataLoader ] ,
269- output_dir : Optional [ Union [ str , Path ]] = None ,
268+ data : TimeSeries | LightningDataModule | DataLoader ,
269+ output_dir : str | Path | None = None ,
270270 ** kwargs ,
271- ) -> Union [ dict [str , torch .Tensor ], None ] :
271+ ) -> dict [str , torch .Tensor ] | None :
272272 """
273273 Generate predictions by wrapping the model's predict method.
274274
0 commit comments