Skip to content

Commit 99d899f

Browse files
committed
refactor evaluation pipeline
1 parent f61d8a5 commit 99d899f

33 files changed

+2166
-1936
lines changed

experiments/EXP_REAL_DATA.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import multiprocessing
2+
import pathlib as pathlib
3+
import time
4+
5+
from experiments.pipeline.helpers.arg_parsers import get_real_data_arg_parser
6+
from experiments.pipeline.helpers.data_loaders import DATASETS
7+
from experiments.pipeline.helpers.misc import build_real_data_name
8+
from experiments.pipeline.pipeline import Pipeline
9+
10+
if __name__ == "__main__":
11+
args = get_real_data_arg_parser()
12+
13+
# Freezing support for multiprocessing
14+
multiprocessing.freeze_support()
15+
16+
data = DATASETS[args.dataset]["load"]()
17+
18+
pipeline_name = build_real_data_name(args.dataset, args.gen_func)
19+
20+
pipeline = Pipeline(
21+
model_name=args.model,
22+
params_name=args.params,
23+
data=data,
24+
freq=DATASETS[args.dataset]["freq"],
25+
pipeline_name=pipeline_name,
26+
base_dir_name=pathlib.Path(__file__).parent.absolute(),
27+
)
28+
29+
# kwargs could contain:
30+
# scalers,
31+
# scaling_levels,
32+
# weighted_loss,
33+
# norm_types,
34+
# norm_modes,
35+
# norm_affines,
36+
# e.g. kwargs = {"scalers": [StandardScaler()], "scaling_levels": ["per_time_series"]}
37+
kwargs = {}
38+
39+
start_time = time.time()
40+
pipeline.run(
41+
save=True, test_percentage=0.25, params_generator_name=args.gen_func, with_scalers=args.with_scalers, **kwargs
42+
)
43+
end_time = time.time()
44+
45+
print("Pipeline execution time: ", end_time - start_time)
46+
pipeline.summary()

0 commit comments

Comments
 (0)