Commit 1aa7aa1

no unnecessary logits upcast. fix naming
Signed-off-by: datta0 <[email protected]>
1 parent 478ec60 commit 1aa7aa1


2 files changed: +5 -6 lines changed


unsloth/models/rl.py

Lines changed: 2 additions & 2 deletions
@@ -672,10 +672,10 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import
     if trl_version >= "0.18":
         # Replace LLM init with already existing vLLM engine for colocate mode
         vllm_llm_init_pattern = r"self\.llm\s*=\s*LLM\([^)]*\)*\)"
-        vllm_llm_repalcement = "self.llm = model.vllm_engine\n"
+        vllm_llm_replacement = "self.llm = model.vllm_engine\n"
         new_vllm_part = re.sub(
             vllm_llm_init_pattern,
-            vllm_llm_repalcement,
+            vllm_llm_replacement,
             new_vllm_part,
             flags=re.DOTALL # Ensure . matches newlines [[5]]
         )
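
For reference, a minimal runnable sketch of the substitution whose name this hunk fixes: the pattern matches the trainer's `self.llm = LLM(...)` construction and `re.sub` swaps it for the vLLM engine already attached to the model. The `new_vllm_part` string below is a hypothetical stand-in for the trl trainer source being patched, not the actual file contents.

import re

# Pattern and (correctly named) replacement, as in the patched rl.py.
vllm_llm_init_pattern = r"self\.llm\s*=\s*LLM\([^)]*\)*\)"
vllm_llm_replacement = "self.llm = model.vllm_engine\n"

# Hypothetical stand-in for the trainer source text being rewritten.
new_vllm_part = '''
if self.vllm_mode == "colocate":
    self.llm = LLM(
        model=model.name_or_path,
        gpu_memory_utilization=0.9,
    )
'''

patched = re.sub(
    vllm_llm_init_pattern,
    vllm_llm_replacement,
    new_vllm_part,
    flags=re.DOTALL,  # as in the patch, so the multi-line LLM(...) call is handled
)
print(patched)  # the LLM(...) construction becomes "self.llm = model.vllm_engine"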

unsloth/models/rl_replacements.py

Lines changed: 3 additions & 4 deletions
@@ -215,18 +215,17 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep,
     if not hasattr(self, '_autocast_dtype'):
         self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16
         if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1': self._autocast_dtype = torch.float16
-    os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "0"
+    # os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "0"
     with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype):
         # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
         logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
         logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
-        logits = logits.to(torch.float32)
+        # logits = logits.to(torch.float32)
         input_ids = input_ids[:, -logits_to_keep:]
         # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
         # See https://github.com/huggingface/trl/issues/2770
         logits = logits[:, -logits_to_keep:]
-        #return logits
-        return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens
+        return logits
     pass
 pass

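Since the patched function now returns the sliced logits rather than per-token log-probabilities, the float32 upcast and the log-softmax are left to the caller. As a reference point, here is a from-scratch sketch of what the removed `selective_log_softmax(logits, input_ids)` line computed; the real helper lives in trl and may differ in details.

import torch

def per_token_logps(logits: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
    # logits: (B, L, V) scores per position; input_ids: (B, L) observed token ids.
    logps = torch.log_softmax(logits.float(), dim=-1)  # normalise over the vocabulary
    # Pick out the log-probability assigned to each observed token.
    return torch.gather(logps, -1, input_ids.unsqueeze(-1)).squeeze(-1)

# After this commit, _get_per_token_logps hands back logits of shape (B, logits_to_keep, V)
# in the autocast dtype (fp16/bf16); a computation like the one above happens downstream.
logits = torch.randn(2, 5, 32, dtype=torch.bfloat16)
ids = torch.randint(0, 32, (2, 5))
print(per_token_logps(logits, ids).shape)  # torch.Size([2, 5])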