3434 ToolResponse ,
3535)
3636from verl .tools .search_tool import SearchTool
37+ from verl .utils .config import omega_conf_to_dataclass
38+ from verl .workers .config import HFModelConfig , RolloutConfig
3739from verl .workers .rollout .schemas import AsyncRolloutRequest , AsyncRolloutRequestStateEnum , Message
3840from verl .workers .rollout .sglang_rollout .sglang_rollout import SGLangRollout
3941
@@ -87,18 +89,18 @@ def get_search_messages():
8789
8890
8991class TestRolloutWithSearchTools :
92+ local_model_path = "Qwen/Qwen2.5-0.5B"
93+
9094 @pytest .fixture
9195 def qwen_tokenizer (self ):
92- local_model_path = "Qwen/Qwen2.5-0.5B"
93- tokenizer = AutoTokenizer .from_pretrained (local_model_path , padding_side = "left" )
96+ tokenizer = AutoTokenizer .from_pretrained (self .local_model_path , padding_side = "left" )
9497 tokenizer .pad_token = tokenizer .eos_token
9598 return tokenizer
9699
97100 # we only need this for tokenizer
98101 @pytest .fixture
99102 def qwen_model_config (self ):
100- local_model_path = "Qwen/Qwen2.5-0.5B"
101- config = AutoConfig .from_pretrained (local_model_path )
103+ config = AutoConfig .from_pretrained (self .local_model_path )
102104 return config
103105
104106 @pytest .fixture
@@ -172,11 +174,12 @@ def mock_rollout(self, search_rollout_config, qwen_tokenizer, qwen_model_config)
172174 patch .object (SGLangRollout , "_init_inference_engine" , return_value = None ),
173175 patch .object (SGLangRollout , "_init_sampling_params" , return_value = None ),
174176 ):
177+ rollout_config : RolloutConfig = omega_conf_to_dataclass (search_rollout_config , dataclass_type = RolloutConfig )
178+ model_config = HFModelConfig (path = self .local_model_path )
175179 rollout = SGLangRollout (
176- actor_module = "" ,
177- config = search_rollout_config ,
178- processing_class = qwen_tokenizer ,
179- model_hf_config = qwen_model_config ,
180+ config = rollout_config ,
181+ model_config = model_config ,
182+ device_mesh = None ,
180183 )
181184 rollout .sampling_params = {
182185 "n" : 1 ,
@@ -193,11 +196,12 @@ def mock_rollout(self, search_rollout_config, qwen_tokenizer, qwen_model_config)
193196 def test_tools_registration (
194197 self , mock_env , mock_engine , mock_sampling , search_rollout_config , qwen_tokenizer , qwen_model_config
195198 ):
199+ rollout_config : RolloutConfig = omega_conf_to_dataclass (search_rollout_config , dataclass_type = RolloutConfig )
200+ model_config = HFModelConfig (path = self .local_model_path )
196201 rollout = SGLangRollout (
197- actor_module = "" ,
198- config = search_rollout_config ,
199- processing_class = qwen_tokenizer ,
200- model_hf_config = qwen_model_config ,
202+ config = rollout_config ,
203+ model_config = model_config ,
204+ device_mesh = None ,
201205 )
202206 assert len (rollout ._tool_schemas ) == 1
203207 assert "search" in rollout ._tool_map .keys ()
@@ -220,11 +224,12 @@ def test_rollout_req_creation(
220224 qwen_model_config ,
221225 search_data_proto ,
222226 ):
227+ rollout_config : RolloutConfig = omega_conf_to_dataclass (search_rollout_config , dataclass_type = RolloutConfig )
228+ model_config = HFModelConfig (path = self .local_model_path )
223229 rollout = SGLangRollout (
224- actor_module = "" ,
225- config = search_rollout_config ,
226- processing_class = qwen_tokenizer ,
227- model_hf_config = qwen_model_config ,
230+ config = rollout_config ,
231+ model_config = model_config ,
232+ device_mesh = None ,
228233 )
229234 req_list = rollout ._preprocess_prompt_to_async_rollout_requests (search_data_proto , n = 1 )
230235 assert len (req_list ) == 1
0 commit comments