File tree Expand file tree Collapse file tree 1 file changed +18
-0
lines changed
bsuite/experiments/discounting_chain Expand file tree Collapse file tree 1 file changed +18
-0
lines changed Original file line number Diff line number Diff line change @@ -39,9 +39,27 @@ def score(df: pd.DataFrame) -> float:
3939 return np .clip (raw_score , 0 , 1 )
4040
4141
42+ def _mapping_seed_compatibility (df : pd .DataFrame ) -> pd .DataFrame :
43+ """Utility function to maintain compatibility with old bsuite runs."""
44+ # Discounting chain kwarg "seed" was renamed to "mapping_seed"
45+ if 'mapping_seed' in df .columns :
46+ nan_seeds = df .mapping_seed .isna ()
47+ if np .any (nan_seeds ):
48+ df .loc [nan_seeds , 'mapping_seed' ] = df .loc [nan_seeds , 'seed' ]
49+ print ('WARNING: seed renamed to "mapping_seed" for compatibility.' )
50+ else :
51+ if 'seed' in df .columns :
52+ print ('WARNING: seed renamed to "mapping_seed" for compatibility.' )
53+ df ['mapping_seed' ] = df .seed
54+ else :
55+ print ('ERROR: outdated bsuite run, please relaunch.' )
56+ return df
57+
58+
4259def dc_preprocess (df_in : pd .DataFrame ) -> pd .DataFrame :
4360 """Preprocess discounting chain data for use with regret metrics."""
4461 df = df_in .copy ()
62+ df = _mapping_seed_compatibility (df )
4563 df ['optimal_horizon' ] = _HORIZONS [
4664 (df .mapping_seed % len (_HORIZONS )).astype (int )]
4765 df ['total_regret' ] = 1.1 * df .episode - df .total_return
You can’t perform that action at this time.
0 commit comments