Skip to content

Commit 68d5399

Browse files
authored
[trainer] fix: avoid loading duplicated custom reward function to fix issue volcengine#3150 (volcengine#3404)
1 parent 277036e commit 68d5399

File tree

1 file changed

+18
-15
lines changed

1 file changed

+18
-15
lines changed

verl/trainer/ppo/reward.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -64,25 +64,28 @@ def get_custom_reward_fn(config: DictConfig) -> Optional[RawRewardFn]:
6464
if not file_path:
6565
return None
6666

67-
if not os.path.exists(file_path):
68-
raise FileNotFoundError(f"Reward function file '{file_path}' not found.")
69-
70-
spec = importlib.util.spec_from_file_location("custom_module", file_path)
71-
assert spec is not None
72-
module = importlib.util.module_from_spec(spec)
73-
try:
74-
sys.modules["custom_module"] = module
75-
assert spec.loader is not None
76-
spec.loader.exec_module(module)
77-
except Exception as e:
78-
raise RuntimeError(f"Error loading module from '{file_path}': {e}") from e
79-
8067
function_name = reward_fn_config.get("name")
8168
assert function_name is not None
69+
70+
module = sys.modules.get("custom_module", None)
71+
if module is None:
72+
if not os.path.exists(file_path):
73+
raise FileNotFoundError(f"Reward function file '{file_path}' not found.")
74+
75+
spec = importlib.util.spec_from_file_location("custom_module", file_path)
76+
assert spec is not None
77+
module = importlib.util.module_from_spec(spec)
78+
try:
79+
sys.modules["custom_module"] = module
80+
assert spec.loader is not None
81+
spec.loader.exec_module(module)
82+
except Exception as e:
83+
raise RuntimeError(f"Error loading module from '{file_path}': {e}") from e
84+
8285
if not hasattr(module, function_name):
83-
raise AttributeError(f"Reward function '{function_name}' not found in '{file_path}'.")
86+
raise AttributeError(f"Reward function '{function_name}' not found in '{module.__file__}'.")
8487

85-
print(f"using customized reward function '{function_name}' from '{file_path}'")
88+
print(f"using customized reward function '{function_name}' from '{module.__file__}'")
8689
raw_fn = getattr(module, function_name)
8790

8891
reward_kwargs = dict(reward_fn_config.get("reward_kwargs", {}))

0 commit comments

Comments
 (0)