
Commit db8185d

Merge pull request #2780 from rolandtannous/fix/gemma3-grpo-self-llm
Fix AttributeError in GRPO trainer for models without llm attribute
2 parents 46795df + 88b2a9c commit db8185d
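
The underlying failure: TRL's GRPO trainer only creates a `self.llm` attribute when vLLM generation is enabled, so the sleep/wake code that unsloth injects crashed with an `AttributeError` for models trained without vLLM. Below is a minimal sketch of the before/after behavior, using a hypothetical stand-in trainer; only the `hasattr(self, 'llm')` guard comes from the diff.

```python
# Everything here is a stand-in except the hasattr guard taken from the diff.
class FakeVLLM:
    def wake_up(self):
        print("vLLM engine woken up")

class ToyGRPOTrainer:
    def __init__(self, use_vllm: bool):
        # Like TRL, only create `self.llm` when vLLM generation is enabled.
        if use_vllm:
            self.llm = FakeVLLM()

    def prepare_inputs_before_fix(self):
        self.llm.wake_up()  # AttributeError when use_vllm=False

    def prepare_inputs_after_fix(self):
        if hasattr(self, "llm"):  # the guard this PR adds
            self.llm.wake_up()

ToyGRPOTrainer(use_vllm=False).prepare_inputs_after_fix()  # no-op, no crash
ToyGRPOTrainer(use_vllm=True).prepare_inputs_after_fix()   # "vLLM engine woken up"
```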

1 file changed: unsloth/models/rl_replacements.py (+13 -10)
@@ -184,8 +184,9 @@ def grpo_trainer__prepare_inputs(function_name, function):
         rest = re.sub(r"^[ \t]*free, total = torch.cuda.mem_get_info\(\)\s*\n", "", rest)
         rest = re.sub(r"^[ \t]*print\(f?\".*cuda.*\"\)\s*\n", "", rest)
         insert = (
-            "    if getattr(self.llm.llm_engine.vllm_config.model_config, 'enable_sleep_mode', False):\n"
-            "        self.llm.wake_up()\n"
+            "    if hasattr(self, 'llm'):\n"
+            "        if getattr(self.llm.llm_engine.vllm_config.model_config, 'enable_sleep_mode', False):\n"
+            "            self.llm.wake_up()\n"
         )
         function = function[:sig_end] + insert + rest
     else:
@@ -199,8 +200,9 @@ def grpo_trainer__prepare_inputs(function_name, function):
         rest = re.sub(r"^[ \t]*free, total = torch.cuda.mem_get_info\(\)\s*\n", "", rest)
         rest = re.sub(r"^[ \t]*print\(f?\".*cuda.*\"\)\s*\n", "", rest)
         insert = (
-            "    if getattr(self.llm.llm_engine.vllm_config.model_config, 'enable_sleep_mode', False):\n"
-            "        self.llm.wake_up()\n"
+            "    if hasattr(self, 'llm'):\n"
+            "        if getattr(self.llm.llm_engine.vllm_config.model_config, 'enable_sleep_mode', False):\n"
+            "            self.llm.wake_up()\n"
         )
         function = header_and_comments + insert + rest

@@ -218,8 +220,9 @@ def grpo_trainer__prepare_inputs(function_name, function):
             "self.accelerator.unwrap_model(self.model, keep_fp32_wrapper = False)",
         )
         sleep_and_cache = (
-            "if getattr(self.llm.llm_engine.vllm_config.model_config, 'enable_sleep_mode', False):\n"
-            "    self.llm.sleep(os.environ.get('VLLM_SLEEP_MODE', 1))\n"
+            "if hasattr(self, 'llm'):\n"
+            "    if getattr(self.llm.llm_engine.vllm_config.model_config, 'enable_sleep_mode', False):\n"
+            "        self.llm.sleep(os.environ.get('VLLM_SLEEP_MODE', 1))\n"
             "    "
         )
         if re.search(r"\n\s*return ", function):
@@ -310,7 +313,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
     logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
     _input_ids = input_ids
    _logits_to_keep = logits_to_keep
-
+
    per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)

    # Compute the KL divergence between the model and the reference model
@@ -330,12 +333,12 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
    # loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
    if "old_per_token_logps" in inputs.keys():
        old_hidden_states = inputs["old_per_token_logps"]
-    else:
+    else:
        old_hidden_states = None
    input_ids = input_ids[:, -logits_to_keep:]
    if per_token_logps is not None:
        loss, completion_length, mean_kl = grpo_compute_loss_slow(
-            ref_per_token_logps, per_token_logps, old_hidden_states, input_ids, completion_mask, self.beta, advantages,
+            ref_per_token_logps, per_token_logps, old_hidden_states, input_ids, completion_mask, self.beta, advantages,
            loss_type = self.args.loss_type,
            epsilon_low = self.epsilon_low, epsilon_high = self.epsilon_high,
            max_completion_length = self.args.max_completion_length,
@@ -356,7 +359,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
        loss, completion_length, mean_kl = grpo_accumulated_loss(
            self, _input_ids, logits_to_keep, completion_mask, advantages, old_hidden_states,
            n_chunks = self.args.unsloth_num_chunks,
-        )
+        )

    # Log the metrics
    # completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
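
The context lines in these hunks also show the mechanism unsloth uses to apply the fix: it rewrites the source text of the trainer method, stripping debug lines with `re.sub` and splicing the guard string in right after the signature. Below is a hedged sketch of that splice, with a toy input source and a hand-computed `sig_end`; the real code derives both from the actual TRL source before compiling the result.

```python
import re

# Toy stand-in for the real TRL method source that unsloth receives.
function = (
    "def _prepare_inputs(self, inputs):\n"
    "    free, total = torch.cuda.mem_get_info()\n"
    "    return inputs\n"
)
sig_end = function.index("\n") + 1  # first byte after the `def ...:` line
rest = function[sig_end:]

# Strip the memory probe, mirroring the re.sub in the hunk's context lines.
rest = re.sub(r"^[ \t]*free, total = torch.cuda.mem_get_info\(\)\s*\n", "", rest)

# Splice the guarded wake-up block in right after the signature.
insert = (
    "    if hasattr(self, 'llm'):\n"
    "        if getattr(self.llm.llm_engine.vllm_config.model_config, 'enable_sleep_mode', False):\n"
    "            self.llm.wake_up()\n"
)
function = function[:sig_end] + insert + rest
print(function)  # patched source, ready for the caller to compile/exec
```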
