Skip to content

Commit 828f9bf

Browse files
iosbandcopybara-github
authored andcommitted
Utility function in analysis to maintain compatibility between old/new bsuite runs.
Does not change performance. PiperOrigin-RevId: 330719181 Change-Id: Id32f50649d86f688c17163c57343f567fbc095be
1 parent 333b09f commit 828f9bf

File tree

1 file changed

+18
-0
lines changed

1 file changed

+18
-0
lines changed

bsuite/experiments/discounting_chain/analysis.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,27 @@ def score(df: pd.DataFrame) -> float:
3939
return np.clip(raw_score, 0, 1)
4040

4141

42+
def _mapping_seed_compatibility(df: pd.DataFrame) -> pd.DataFrame:
43+
"""Utility function to maintain compatibility with old bsuite runs."""
44+
# Discounting chain kwarg "seed" was renamed to "mapping_seed"
45+
if 'mapping_seed' in df.columns:
46+
nan_seeds = df.mapping_seed.isna()
47+
if np.any(nan_seeds):
48+
df.loc[nan_seeds, 'mapping_seed'] = df.loc[nan_seeds, 'seed']
49+
print('WARNING: seed renamed to "mapping_seed" for compatibility.')
50+
else:
51+
if 'seed' in df.columns:
52+
print('WARNING: seed renamed to "mapping_seed" for compatibility.')
53+
df['mapping_seed'] = df.seed
54+
else:
55+
print('ERROR: outdated bsuite run, please relaunch.')
56+
return df
57+
58+
4259
def dc_preprocess(df_in: pd.DataFrame) -> pd.DataFrame:
4360
"""Preprocess discounting chain data for use with regret metrics."""
4461
df = df_in.copy()
62+
df = _mapping_seed_compatibility(df)
4563
df['optimal_horizon'] = _HORIZONS[
4664
(df.mapping_seed % len(_HORIZONS)).astype(int)]
4765
df['total_regret'] = 1.1 * df.episode - df.total_return

0 commit comments

Comments
 (0)