@@ -64,25 +64,28 @@ def get_custom_reward_fn(config: DictConfig) -> Optional[RawRewardFn]:
6464 if not file_path :
6565 return None
6666
67- if not os .path .exists (file_path ):
68- raise FileNotFoundError (f"Reward function file '{ file_path } ' not found." )
69-
70- spec = importlib .util .spec_from_file_location ("custom_module" , file_path )
71- assert spec is not None
72- module = importlib .util .module_from_spec (spec )
73- try :
74- sys .modules ["custom_module" ] = module
75- assert spec .loader is not None
76- spec .loader .exec_module (module )
77- except Exception as e :
78- raise RuntimeError (f"Error loading module from '{ file_path } ': { e } " ) from e
79-
8067 function_name = reward_fn_config .get ("name" )
8168 assert function_name is not None
69+
70+ module = sys .modules .get ("custom_module" , None )
71+ if module is None :
72+ if not os .path .exists (file_path ):
73+ raise FileNotFoundError (f"Reward function file '{ file_path } ' not found." )
74+
75+ spec = importlib .util .spec_from_file_location ("custom_module" , file_path )
76+ assert spec is not None
77+ module = importlib .util .module_from_spec (spec )
78+ try :
79+ sys .modules ["custom_module" ] = module
80+ assert spec .loader is not None
81+ spec .loader .exec_module (module )
82+ except Exception as e :
83+ raise RuntimeError (f"Error loading module from '{ file_path } ': { e } " ) from e
84+
8285 if not hasattr (module , function_name ):
83- raise AttributeError (f"Reward function '{ function_name } ' not found in '{ file_path } '." )
86+ raise AttributeError (f"Reward function '{ function_name } ' not found in '{ module . __file__ } '." )
8487
85- print (f"using customized reward function '{ function_name } ' from '{ file_path } '" )
88+ print (f"using customized reward function '{ function_name } ' from '{ module . __file__ } '" )
8689 raw_fn = getattr (module , function_name )
8790
8891 reward_kwargs = dict (reward_fn_config .get ("reward_kwargs" , {}))
0 commit comments