@@ -184,8 +184,9 @@ def grpo_trainer__prepare_inputs(function_name, function):
184184 rest = re .sub (r"^[ \t]*free, total = torch.cuda.mem_get_info\(\)\s*\n" , "" , rest )
185185 rest = re .sub (r"^[ \t]*print\(f?\".*cuda.*\"\)\s*\n" , "" , rest )
186186 insert = (
187- " if getattr(self.llm.llm_engine.vllm_config.model_config, 'enable_sleep_mode', False):\n "
188- " self.llm.wake_up()\n "
187+ " if hasattr(self, 'llm'):\n "
188+ " if getattr(self.llm.llm_engine.vllm_config.model_config, 'enable_sleep_mode', False):\n "
189+ " self.llm.wake_up()\n "
189190 )
190191 function = function [:sig_end ] + insert + rest
191192 else :
@@ -199,8 +200,9 @@ def grpo_trainer__prepare_inputs(function_name, function):
199200 rest = re .sub (r"^[ \t]*free, total = torch.cuda.mem_get_info\(\)\s*\n" , "" , rest )
200201 rest = re .sub (r"^[ \t]*print\(f?\".*cuda.*\"\)\s*\n" , "" , rest )
201202 insert = (
202- " if getattr(self.llm.llm_engine.vllm_config.model_config, 'enable_sleep_mode', False):\n "
203- " self.llm.wake_up()\n "
203+ "        if hasattr(self, 'llm'):\n "
204+ " if getattr(self.llm.llm_engine.vllm_config.model_config, 'enable_sleep_mode', False):\n "
205+ " self.llm.wake_up()\n "
204206 )
205207 function = header_and_comments + insert + rest
206208
@@ -218,8 +220,9 @@ def grpo_trainer__prepare_inputs(function_name, function):
218220 "self.accelerator.unwrap_model(self.model, keep_fp32_wrapper = False)" ,
219221 )
220222 sleep_and_cache = (
221- "if getattr(self.llm.llm_engine.vllm_config.model_config, 'enable_sleep_mode', False):\n "
222- " self.llm.sleep(os.environ.get('VLLM_SLEEP_MODE', 1))\n "
223+ "if hasattr(self, 'llm'):\n "
224+ " if getattr(self.llm.llm_engine.vllm_config.model_config, 'enable_sleep_mode', False):\n "
225+ " self.llm.sleep(os.environ.get('VLLM_SLEEP_MODE', 1))\n "
223226 " "
224227 )
225228 if re .search (r"\n\s*return " , function ):
@@ -310,7 +313,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
310313 logits_to_keep = completion_ids .size (1 ) # we only need to compute the logits for the completion tokens
311314 _input_ids = input_ids
312315 _logits_to_keep = logits_to_keep
313-
316+
314317 per_token_logps = self ._get_per_token_logps (model , input_ids , attention_mask , logits_to_keep )
315318
316319 # Compute the KL divergence between the model and the reference model
@@ -330,12 +333,12 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
330333 # loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
331334 if "old_per_token_logps" in inputs .keys ():
332335 old_hidden_states = inputs ["old_per_token_logps" ]
333- else :
336+ else :
334337 old_hidden_states = None
335338 input_ids = input_ids [:, - logits_to_keep :]
336339 if per_token_logps is not None :
337340 loss , completion_length , mean_kl = grpo_compute_loss_slow (
338- ref_per_token_logps , per_token_logps , old_hidden_states , input_ids , completion_mask , self .beta , advantages ,
341+ ref_per_token_logps , per_token_logps , old_hidden_states , input_ids , completion_mask , self .beta , advantages ,
339342 loss_type = self .args .loss_type ,
340343 epsilon_low = self .epsilon_low , epsilon_high = self .epsilon_high ,
341344 max_completion_length = self .args .max_completion_length ,
@@ -356,7 +359,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
356359 loss , completion_length , mean_kl = grpo_accumulated_loss (
357360 self , _input_ids , logits_to_keep , completion_mask , advantages , old_hidden_states ,
358361 n_chunks = self .args .unsloth_num_chunks ,
359- )
362+ )
360363
361364 # Log the metrics
362365 # completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
0 commit comments